728x90
This happens because the opt_D.step() modifies the parameters of your discriminator inplace. But these parameters are required to compute the gradient for the generator. Hence the error.
We fixed the inplace detection for the optimizers in 1.5, this is why it works in 1.4.
You should re-organize your code to only do the steps() after all the gradients have been computed or make sure you don't modify parameters that are required.
Something like that should work.
위 글을 보니 제너레이터를 먼저 학습 시키고 그 다음 디스크리미네이터를 학습시키는 방향으로 변환하니 코드가 잘 돌아감을 알 수 있다.
중요한 건 detach가 들어간 부분!
Generator가 이미 생성한 이미지에 대해서는 backpropagation을 진행하면 안된다. 따라서 나는 아래와 같이 코드를 변경해주었다.
변경 이전 코드
for i, batch in enumerate(train_batch_iter):
if with_charid:
font_ids, char_ids, batch_images = batch
else:
font_ids, batch_images = batch
embedding_ids = font_ids
if self.GPU:
batch_images = batch_images.cuda()
if flip_labels:
np.random.shuffle(embedding_ids)
# target / source images
real_target = batch_images[:, 0, :, :]
real_target = real_target.view([self.batch_size, 1, self.img_size, self.img_size])
real_source = batch_images[:, 1, :, :]
real_source = real_source.view([self.batch_size, 1, self.img_size, self.img_size])
# centering
for idx, (image_S, image_T) in enumerate(zip(real_source, real_target)):
image_S = image_S.cpu().detach().numpy().reshape(self.img_size, self.img_size)
image_S = centering_image(image_S, resize_fix=90)
real_source[idx] = torch.tensor(image_S).view([1, self.img_size, self.img_size])
image_T = image_T.cpu().detach().numpy().reshape(self.img_size, self.img_size)
image_T = centering_image(image_T, resize_fix=resize_fix)
real_target[idx] = torch.tensor(image_T).view([1, self.img_size, self.img_size])
# generate fake image form source image
fake_target, encoded_source, _ = Generator(real_source, En, De, \
self.embeddings, embedding_ids, \
GPU=self.GPU, encode_layers=True)
real_TS = torch.cat([real_source, real_target], dim=1)
fake_TS = torch.cat([real_source, fake_target], dim=1)
# Scoring with Discriminator
real_score, real_score_logit, real_cat_logit = D(real_TS)
fake_score, fake_score_logit, fake_cat_logit = D(fake_TS)
# Get encoded fake image to calculate constant loss
encoded_fake = En(fake_target)[0]
const_loss = Lconst_penalty * mse_criterion(encoded_source, encoded_fake)
# category loss
real_category = torch.from_numpy(np.eye(self.fonts_num)[embedding_ids]).float()
if self.GPU:
real_category = real_category.cuda()
real_category_loss = bce_criterion(real_cat_logit, real_category)
fake_category_loss = bce_criterion(fake_cat_logit, real_category)
category_loss = 0.5 * (real_category_loss + fake_category_loss)
# labels
if self.GPU:
one_labels = torch.ones([self.batch_size, 1]).cuda()
zero_labels = torch.zeros([self.batch_size, 1]).cuda()
else:
one_labels = torch.ones([self.batch_size, 1])
zero_labels = torch.zeros([self.batch_size, 1])
# binary loss - T/F
real_binary_loss = bce_criterion(real_score_logit, one_labels)
fake_binary_loss = bce_criterion(fake_score_logit, zero_labels)
binary_loss = real_binary_loss + fake_binary_loss
# L1 loss between real and fake images
l1_loss = L1_penalty * l1_criterion(real_target, fake_target)
# cheat loss for generator to fool discriminator
cheat_loss = bce_criterion(fake_score_logit, one_labels)
# g_loss, d_loss
g_loss = cheat_loss + l1_loss + fake_category_loss + const_loss
d_loss = binary_loss + category_loss
# train Discriminator
D.zero_grad()
d_loss.backward(retain_graph=True)
d_optimizer.step()
# train Generator
En.zero_grad()
De.zero_grad()
g_loss.backward(retain_graph=True)
g_optimizer.step()
변경 이후 코드
for i, batch in enumerate(train_batch_iter):
labels, batch_images = batch
embedding_ids = labels
if GPU:
batch_images = batch_images.cuda()
if flip_labels:
np.random.shuffle(embedding_ids)
# target / source images
real_target = batch_images[:, 0, :, :].view([BATCH_SIZE, 1, IMG_SIZE, IMG_SIZE])
real_source = batch_images[:, 1, :, :].view([BATCH_SIZE, 1, IMG_SIZE, IMG_SIZE])
# generate fake image form source image
fake_target, encoded_source, encoder_layers = Generator(real_source, En, De, embeddings, embedding_ids, GPU=GPU)
real_TS = torch.cat([real_source, real_target], dim=1)
fake_TS = torch.cat([real_source, fake_target], dim=1)
# Scoring with Discriminator
fake_score, fake_score_logit, fake_cat_logit = D(fake_TS)
# Get encoded fake image to calculate constant loss
encoded_fake = En(fake_target)[0]
const_loss = Lconst_penalty * mse_criterion(encoded_source, encoded_fake)
# category loss
real_category = torch.from_numpy(np.eye(FONTS_NUM)[embedding_ids]).float()
if GPU:
real_category = real_category.cuda()
fake_category_loss = bce_criterion(fake_cat_logit, real_category)
# labels
if GPU:
one_labels = torch.ones([BATCH_SIZE, 1]).cuda()
zero_labels = torch.zeros([BATCH_SIZE, 1]).cuda()
else:
one_labels = torch.ones([BATCH_SIZE, 1])
zero_labels = torch.zeros([BATCH_SIZE, 1])
# L1 loss between real and fake images
l1_loss = L1_penalty * l1_criterion(real_target, fake_target)
# cheat loss for generator to fool discriminator
cheat_loss = bce_criterion(fake_score_logit, one_labels)
# g_loss, d_loss
g_loss = cheat_loss+ l1_loss + fake_category_loss + const_loss
# train Generator
En.zero_grad()
De.zero_grad()
g_loss.backward(retain_graph=True)
g_optimizer.step()
fake_TS = torch.cat([real_source, fake_target.detach().clone()], dim=1)
real_score, real_score_logit, real_cat_logit = D(real_TS)
fake_score, fake_score_logit, fake_cat_logit = D(fake_TS.detach().clone())
# binary loss for discriminator
real_binary_loss = bce_criterion(real_score_logit, one_labels)
fake_binary_loss = bce_criterion(fake_score_logit, zero_labels)
binary_loss = real_binary_loss + fake_binary_loss
#category loss for discriminator
fake_category_loss = bce_criterion(fake_cat_logit, real_category)
real_category_loss = bce_criterion(real_cat_logit, real_category)
category_loss = 0.5 * (real_category_loss + fake_category_loss)
d_loss = binary_loss + category_loss
# train Discriminator
D.zero_grad()
d_loss.backward(retain_graph=True)
https://jeinalog.tistory.com/15
https://github.com/pytorch/pytorch/issues/39141