[PyTorch error] What to do when GAN training fails with "RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation"
This happens because opt_D.step() modifies the parameters of your discriminator in place. But these parameters are required to compute the gradient for the generator, hence the error.
We fixed the in-place detection for the optimizers in 1.5, which is why it works in 1.4.
You should reorganize your code to call step() only after all the gradients have been computed, or make sure you don't modify parameters that are still required.
Something like that should work.
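The failure mode described above can be reproduced in a few lines. This is a minimal sketch with a hypothetical tiny generator and discriminator (plain Linear layers standing in for real models), not the actual code from this post:

```python
import torch
import torch.nn as nn

# Hypothetical minimal G and D; PyTorch >= 1.5 assumed.
G = nn.Linear(2, 2)
D = nn.Linear(2, 1)
opt_D = torch.optim.SGD(D.parameters(), lr=0.1)

fake = G(torch.randn(4, 2))
d_loss = D(fake.detach()).mean()
g_loss = -D(fake).mean()

d_loss.backward()
opt_D.step()              # in-place update of D's parameters

raised = False
try:
    g_loss.backward()     # still needs D's pre-step parameters
except RuntimeError:      # "one of the variables needed for gradient computation ..."
    raised = True
print(raised)
```

Backpropagating g_loss through D requires D's weights as they were during the forward pass, but opt_D.step() has already overwritten them in place, so autograd's version check raises the error.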
Following the advice in that post, I reorganized the code to train the generator first and then the discriminator, and the code ran without the error.
The crucial part is where detach comes in!
The discriminator's backward pass must not propagate gradients into images the generator has already produced. So I changed my code as shown below.
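Before the diff, a quick note on what detach() actually does: it returns a tensor sharing the same data but cut out of the autograd graph, so gradients stop there.

```python
import torch

x = torch.ones(3, requires_grad=True)
y = (x * 2).detach()       # same values, but excluded from the autograd graph
print(y.requires_grad)     # False

z = (x * 2).sum()
z.backward()               # gradients flow only through the non-detached path
print(x.grad.tolist())     # [2.0, 2.0, 2.0]
```

This is why feeding the discriminator a detached fake image lets D train on it without the backward pass reaching the generator.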
Code before the change
for i, batch in enumerate(train_batch_iter):
    if with_charid:
        font_ids, char_ids, batch_images = batch
    else:
        font_ids, batch_images = batch
    embedding_ids = font_ids
    if self.GPU:
        batch_images = batch_images.cuda()
    if flip_labels:
        np.random.shuffle(embedding_ids)
    # target / source images
    real_target = batch_images[:, 0, :, :]
    real_target = real_target.view([self.batch_size, 1, self.img_size, self.img_size])
    real_source = batch_images[:, 1, :, :]
    real_source = real_source.view([self.batch_size, 1, self.img_size, self.img_size])
    # centering
    for idx, (image_S, image_T) in enumerate(zip(real_source, real_target)):
        image_S = image_S.cpu().detach().numpy().reshape(self.img_size, self.img_size)
        image_S = centering_image(image_S, resize_fix=90)
        real_source[idx] = torch.tensor(image_S).view([1, self.img_size, self.img_size])
        image_T = image_T.cpu().detach().numpy().reshape(self.img_size, self.img_size)
        image_T = centering_image(image_T, resize_fix=resize_fix)
        real_target[idx] = torch.tensor(image_T).view([1, self.img_size, self.img_size])
    # generate fake image from source image
    fake_target, encoded_source, _ = Generator(real_source, En, De,
                                               self.embeddings, embedding_ids,
                                               GPU=self.GPU, encode_layers=True)
    real_TS = torch.cat([real_source, real_target], dim=1)
    fake_TS = torch.cat([real_source, fake_target], dim=1)
    # Scoring with Discriminator
    real_score, real_score_logit, real_cat_logit = D(real_TS)
    fake_score, fake_score_logit, fake_cat_logit = D(fake_TS)
    # Get encoded fake image to calculate constant loss
    encoded_fake = En(fake_target)[0]
    const_loss = Lconst_penalty * mse_criterion(encoded_source, encoded_fake)
    # category loss
    real_category = torch.from_numpy(np.eye(self.fonts_num)[embedding_ids]).float()
    if self.GPU:
        real_category = real_category.cuda()
    real_category_loss = bce_criterion(real_cat_logit, real_category)
    fake_category_loss = bce_criterion(fake_cat_logit, real_category)
    category_loss = 0.5 * (real_category_loss + fake_category_loss)
    # labels
    if self.GPU:
        one_labels = torch.ones([self.batch_size, 1]).cuda()
        zero_labels = torch.zeros([self.batch_size, 1]).cuda()
    else:
        one_labels = torch.ones([self.batch_size, 1])
        zero_labels = torch.zeros([self.batch_size, 1])
    # binary loss - T/F
    real_binary_loss = bce_criterion(real_score_logit, one_labels)
    fake_binary_loss = bce_criterion(fake_score_logit, zero_labels)
    binary_loss = real_binary_loss + fake_binary_loss
    # L1 loss between real and fake images
    l1_loss = L1_penalty * l1_criterion(real_target, fake_target)
    # cheat loss for generator to fool discriminator
    cheat_loss = bce_criterion(fake_score_logit, one_labels)
    # g_loss, d_loss
    g_loss = cheat_loss + l1_loss + fake_category_loss + const_loss
    d_loss = binary_loss + category_loss
    # train Discriminator
    D.zero_grad()
    d_loss.backward(retain_graph=True)
    d_optimizer.step()  # modifies D's parameters in place
    # train Generator
    En.zero_grad()
    De.zero_grad()
    g_loss.backward(retain_graph=True)  # needs D's pre-step parameters -> RuntimeError on PyTorch >= 1.5
    g_optimizer.step()
Code after the change
for i, batch in enumerate(train_batch_iter):
    labels, batch_images = batch
    embedding_ids = labels
    if GPU:
        batch_images = batch_images.cuda()
    if flip_labels:
        np.random.shuffle(embedding_ids)
    # target / source images
    real_target = batch_images[:, 0, :, :].view([BATCH_SIZE, 1, IMG_SIZE, IMG_SIZE])
    real_source = batch_images[:, 1, :, :].view([BATCH_SIZE, 1, IMG_SIZE, IMG_SIZE])
    # generate fake image from source image
    fake_target, encoded_source, encoder_layers = Generator(real_source, En, De, embeddings, embedding_ids, GPU=GPU)
    real_TS = torch.cat([real_source, real_target], dim=1)
    fake_TS = torch.cat([real_source, fake_target], dim=1)
    # Scoring with Discriminator
    fake_score, fake_score_logit, fake_cat_logit = D(fake_TS)
    # Get encoded fake image to calculate constant loss
    encoded_fake = En(fake_target)[0]
    const_loss = Lconst_penalty * mse_criterion(encoded_source, encoded_fake)
    # category loss
    real_category = torch.from_numpy(np.eye(FONTS_NUM)[embedding_ids]).float()
    if GPU:
        real_category = real_category.cuda()
    fake_category_loss = bce_criterion(fake_cat_logit, real_category)
    # labels
    if GPU:
        one_labels = torch.ones([BATCH_SIZE, 1]).cuda()
        zero_labels = torch.zeros([BATCH_SIZE, 1]).cuda()
    else:
        one_labels = torch.ones([BATCH_SIZE, 1])
        zero_labels = torch.zeros([BATCH_SIZE, 1])
    # L1 loss between real and fake images
    l1_loss = L1_penalty * l1_criterion(real_target, fake_target)
    # cheat loss for generator to fool discriminator
    cheat_loss = bce_criterion(fake_score_logit, one_labels)
    # g_loss
    g_loss = cheat_loss + l1_loss + fake_category_loss + const_loss
    # train Generator first, before D's parameters are updated
    En.zero_grad()
    De.zero_grad()
    g_loss.backward(retain_graph=True)
    g_optimizer.step()
    # re-score with a detached fake image: D's backward pass must not reach the generator
    fake_TS = torch.cat([real_source, fake_target.detach().clone()], dim=1)
    real_score, real_score_logit, real_cat_logit = D(real_TS)
    fake_score, fake_score_logit, fake_cat_logit = D(fake_TS)
    # binary loss for discriminator
    real_binary_loss = bce_criterion(real_score_logit, one_labels)
    fake_binary_loss = bce_criterion(fake_score_logit, zero_labels)
    binary_loss = real_binary_loss + fake_binary_loss
    # category loss for discriminator
    fake_category_loss = bce_criterion(fake_cat_logit, real_category)
    real_category_loss = bce_criterion(real_cat_logit, real_category)
    category_loss = 0.5 * (real_category_loss + fake_category_loss)
    d_loss = binary_loss + category_loss
    # train Discriminator
    D.zero_grad()
    d_loss.backward(retain_graph=True)
    d_optimizer.step()
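Stripped of the font-specific details, the reordered loop follows this pattern. Again a minimal sketch with hypothetical Linear layers standing in for the real generator and discriminator:

```python
import torch
import torch.nn as nn

# Hypothetical minimal G and D, not the post's actual models.
G = nn.Linear(2, 2)
D = nn.Linear(2, 1)
g_optimizer = torch.optim.SGD(G.parameters(), lr=0.1)
d_optimizer = torch.optim.SGD(D.parameters(), lr=0.1)

fake = G(torch.randn(4, 2))

# 1) generator step first, while D's parameters are still untouched
g_loss = -D(fake).mean()
G.zero_grad()
g_loss.backward()
g_optimizer.step()

# 2) discriminator step on a detached fake, so nothing flows back into G
d_loss = D(fake.detach()).mean()
D.zero_grad()
d_loss.backward()
d_optimizer.step()

ok = True  # reached without a RuntimeError
print(ok)
```

Because d_loss is computed from a fresh forward pass on the detached fake, its backward pass only needs D's current parameters, and neither optimizer step invalidates a tensor the other backward pass still depends on.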
References:
https://jeinalog.tistory.com/15 - "An AI fairy that copies my handwriting, Wrinie (1) - Theory" (the personal project this code comes from)
https://github.com/pytorch/pytorch/issues/39141 - pytorch/pytorch issue #39141, "RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation?"