본문 바로가기

Data-science/deep learning

[pytorch] one-hot encoding이 반드시 필요할까?

728x90

그렇지 않다.

F.cross_entropy 에선 one-hot 인코딩을 하지 않아도 cross entropy loss를 잘 계산해준다.

예를 들어 카테고리가 3개라고 하면 pred 값은 (B, 3) 형태일 것이다.

B는 배치 사이즈

이때 target의 형태는 [B] 이면 된다. 각 Batch 마다 하나의 값을 int 값으로 갖고 있으면 된단 말이다. pytorch 에선 .long()으로 케스팅해주면된다. [B, 3]처럼 one-hot encoding 해줄 필요 없음.

 

 

 

Is One-Hot Encoding required for using PyTorch's Cross Entropy Loss Function?

 

Is One-Hot Encoding required for using PyTorch's Cross Entropy Loss Function?

For example, if I want to solve the MNIST classification problem, we have 10 output classes. With PyTorch, I would like to use the torch.nn.CrossEntropyLoss function. Do I have to format the target...

stackoverflow.com