배경
간단한 knowledge distillation을 구현하고 싶은데 그러면 segmentation loss에 대해 빠삭하게 파악하고 있어야 한다. loss는 분류 문제와 동일하다하고 자세한 설명은 없는 글들이 대부분이다. 이에 도움이 되고자 글을 작성한다.
segmentation의 경우 각 object에 대한 mask데이터로 정답(ground truth)을 표현한다.
아래 그림처럼 object가 있는 부분만 흰색이고 없는 경우 검은 배경이다. 과제가 조금 달라 loss를 이해하는 방식에도 조금 차이가 있다.
Classification loss
우선 분류 문제 부터 살펴보자.
binary classification 에서 softmax후 binary_crossentropy를 적용한다고 생각해보면
softmax의 출력은 입력과 같은 차원이므로 출력 값도 노드가 2개인 (2,)이런 shape을 가진다. 또한 softmax 값과 비교하기 위한 실제 정답은 one_hot vector로 표현되어야 한다. 0번 클래스의 경우 [1,0],
1번 클래스의 경우 [0,1]에 해당한다.
pred_x : 예측 값, y : 라벨링 데이터. 이에대한 loss는 아래와 같다.
Segmentation loss
segmentation은 이미지 내의 pixel 별로 위 작업을 시행한다.
그러기 위해 우선 예측 값의 형태부터 알아보자.
sementation에서는 출력 되는 결과의 형태는 이미지이다 (가령 100x100이미지에 클래스가 2개라고 하자.) 그러면 output은 100x100x2가 나오고. 여기서 100x100이미지 각 pixel별로 softmax를 적용한다. 그럼 확률값을 가진 100x100x2 형태의 softmax_output이 나온다.
이제 y는 어떤 형태일까?
위 loss 계산식을 그대로 쓰기 위해선…
위와 같은 마스크 데이터를 위 예측 값 100x100x2와 비교할 수 있는 자료형으로 표현해 줘야한다.
분류 문제 같은 경우엔 정답 데이터를 [0,0,1] 혹은 [0,1,0], [1,0,0]으로 표현할 수 있다. 그런데 segmentation labeling 은 이미지 전체 각 pixel 별로 이 작업을 거친다고 보면 된다.
클래스가 2개인 경우를 살펴보자. y는 100x100x2의 형태를 가지고 마지막 2는 채널 수(=클래스 수)에 해당한다.
y(ground_truth) 의 0번 채널은 onehot vector의 0번 클래스에 해당하므로, 해당 클래스 부분의 데이터를 모두 1로 표현하고,
y의 1번 채널은 onehot vector의 1번 클래스에 해당한 값들만 1로 가지고 나머지 클래스 데이터들은 모두 0으로 표현한다.
위 마스크 데이터에서 흰색 부분을 1번 클래스에 해당한다고 하면,
- y의 0번 채널은 흰색 부분을 제외한 모든 부분은 1의 값을 가지고 흰색부분은 0을 가지게 된다.
- y의 1번 채널은 흰색 부분만 1의 값을 가지고, 나머지 부분은 모두 0을 가지게 된다.
이를 mask데이터를 onehot형식으로 표현한다고 한다.
관련 코드는 다음과 같다.
def get_one_hot_encoded_mask(self, mask_img):
y_img = np.squeeze(mask_img, axis=2)
one_hot_mask = np.zeros((self.image_height, self.image_width, self.n_classes))
back = (y_img == 0)
object = (y_img > 0)
one_hot_mask[:, :, 0] = np.where(back, 1, 0)
one_hot_mask[:, :, 1] = np.where(object, 1, 0)
return one_hot_mask
back은 background, 0번 클래스를 나타낸다. object는 말 그대로 object이고, 1번 클래스를 나타낸다.
- mask_img에서 0번 클래스에 해당하는 값들(값이 0)에 대한 mask 데이터를 뽑는다.
- mask_img에서 1번 클래스에 해당하는 값들(값이 1)에 대한 mask 데이터를 뽑는다.
- one_hot_mask의 0번 채널(0번 클래스)에서 0번 클래스에 해당하는 값을 모두 1로, 1번 클래스에 해당하는 값은 모두 0으로 한다. (1에서 획득한 mask 데이터를 이용)
- one_hot_mask의 1번 채널(1번 클래스)에서 1번 클래스에 해당하는 값을 모두 1로, 0번 클래스에 해당하는 값은 모두 0으로 한다. (2에서 획득한 mask 데이터를 이용)
그리고 pixel 별로 다시 아래 식을 계산한다.
그리고 그 전체 픽셀에 각 값에 대한 평균을 계산하는 것으로 loss를 계산한다.
이 loss를 최소화 한다는 건, 아래 그림처럼 같은 클래스 이미지 사이에 유사도를 최대화 하는 것으로도 이해할 수 있다.
참고 : https://papers.nips.cc/paper/9291-region-mutual-information-loss-for-semantic-segmentation.pdf