728x90
nn.CrossEntropyLoss() 혹은 F.cross_entropy 를 사용했을 때 나타나는 에러일 것이다.
nn.CrossEntropyLoss()(pred, target) 이렇게 계산이 되는데
가령 pred의 shape의 [ B, C]라면 C는 클래스 갯수 B는 배치 사이즈
target의 shape은 [B] 가 되어야 하는데 [B, 1]이렇게 돼서 문제가 발생하는 거다.
그래서 문제를 해결하려면 target의 shape를 축소해주자.
nn.CrossEntropyLoss()(pred, target.squeeze(dim=-1))
https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
https://discuss.pytorch.org/t/runtimeerror-multi-target-not-supported-newbie/10216/24