

[PyTorch error] How to handle "RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation" when training a GAN


The answer on the PyTorch GitHub issue (linked below) explains:

This happens because opt_D.step() modifies the parameters of your discriminator in place. But these parameters are required to compute the gradient for the generator. Hence the error.
The in-place detection for the optimizers was fixed in 1.5, which is why the same code runs under 1.4.

You should reorganize your code to call step() only after all the gradients have been computed, or make sure you don't modify parameters that are still required for a pending backward pass. Something like that should work.
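The safe ordering can be sketched with tiny stand-in modules (the linear G and D and the losses here are illustrative only, not the networks from this post): train D on a detached fake first, then rebuild the generator's graph with a fresh forward pass through the already-updated D.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-ins for the real networks; only the ordering matters.
G, D = nn.Linear(2, 2), nn.Linear(2, 1)
g_opt = torch.optim.SGD(G.parameters(), lr=0.1)
d_opt = torch.optim.SGD(D.parameters(), lr=0.1)
bce = nn.BCEWithLogitsLoss()

z, real = torch.randn(8, 2), torch.randn(8, 2)
ones, zeros = torch.ones(8, 1), torch.zeros(8, 1)

fake = G(z)

# Train D on a detached fake: D's loss graph never touches G's parameters.
D.zero_grad()
d_loss = bce(D(real), ones) + bce(D(fake.detach()), zeros)
d_loss.backward()
d_opt.step()                  # in-place update of D's parameters

# Train G with a FRESH forward pass through the updated D, so the
# generator's graph saves D's current parameters, not stale ones.
G.zero_grad()
g_loss = bce(D(fake), ones)
g_loss.backward()             # no inplace-modification RuntimeError
g_opt.step()
```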

 

Following the answer above, I reorganized the code to train the generator first and the discriminator afterwards, and it runs without the error.

The key part is where detach comes in!

The discriminator's loss must not backpropagate through images the generator has already produced. So I changed the code as follows.
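What detach() does can be seen in isolation (a minimal illustration, unrelated to the networks in this post): it returns a tensor with the same values but cut out of the autograd graph, so a loss computed on it cannot reach the generator's parameters.

```python
import torch

x = torch.randn(3, requires_grad=True)   # stands in for generator parameters
y = x * 2                                # stands in for a generated image
d = y.detach()                           # same values, no graph behind it

# The detached tensor carries identical values...
assert torch.equal(d, y)
# ...but gradients cannot flow through it back to x.
assert y.requires_grad and not d.requires_grad

y.sum().backward()                       # backward through y reaches x
assert x.grad is not None
```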

Code before the change

for i, batch in enumerate(train_batch_iter):
    if with_charid:
        font_ids, char_ids, batch_images = batch
    else:
        font_ids, batch_images = batch
    embedding_ids = font_ids
    if self.GPU:
        batch_images = batch_images.cuda()
    if flip_labels:
        np.random.shuffle(embedding_ids)

    # target / source images
    real_target = batch_images[:, 0, :, :]
    real_target = real_target.view([self.batch_size, 1, self.img_size, self.img_size])
    real_source = batch_images[:, 1, :, :]
    real_source = real_source.view([self.batch_size, 1, self.img_size, self.img_size])

    # centering
    for idx, (image_S, image_T) in enumerate(zip(real_source, real_target)):
        image_S = image_S.cpu().detach().numpy().reshape(self.img_size, self.img_size)
        image_S = centering_image(image_S, resize_fix=90)
        real_source[idx] = torch.tensor(image_S).view([1, self.img_size, self.img_size])
        image_T = image_T.cpu().detach().numpy().reshape(self.img_size, self.img_size)
        image_T = centering_image(image_T, resize_fix=resize_fix)
        real_target[idx] = torch.tensor(image_T).view([1, self.img_size, self.img_size])

    # generate fake image from source image
    fake_target, encoded_source, _ = Generator(real_source, En, De, \
                                               self.embeddings, embedding_ids, \
                                               GPU=self.GPU, encode_layers=True)

    real_TS = torch.cat([real_source, real_target], dim=1)
    fake_TS = torch.cat([real_source, fake_target], dim=1)

    # Scoring with Discriminator
    real_score, real_score_logit, real_cat_logit = D(real_TS)
    fake_score, fake_score_logit, fake_cat_logit = D(fake_TS)

    # Get encoded fake image to calculate constant loss
    encoded_fake = En(fake_target)[0]
    const_loss = Lconst_penalty * mse_criterion(encoded_source, encoded_fake)

    # category loss
    real_category = torch.from_numpy(np.eye(self.fonts_num)[embedding_ids]).float()
    if self.GPU:
        real_category = real_category.cuda()
    real_category_loss = bce_criterion(real_cat_logit, real_category)
    fake_category_loss = bce_criterion(fake_cat_logit, real_category)
    category_loss = 0.5 * (real_category_loss + fake_category_loss)

    # labels
    if self.GPU:
        one_labels = torch.ones([self.batch_size, 1]).cuda()
        zero_labels = torch.zeros([self.batch_size, 1]).cuda()
    else:
        one_labels = torch.ones([self.batch_size, 1])
        zero_labels = torch.zeros([self.batch_size, 1])

    # binary loss - T/F
    real_binary_loss = bce_criterion(real_score_logit, one_labels)
    fake_binary_loss = bce_criterion(fake_score_logit, zero_labels)
    binary_loss = real_binary_loss + fake_binary_loss

    # L1 loss between real and fake images
    l1_loss = L1_penalty * l1_criterion(real_target, fake_target)

    # cheat loss for generator to fool discriminator
    cheat_loss = bce_criterion(fake_score_logit, one_labels)

    # g_loss, d_loss
    g_loss = cheat_loss + l1_loss + fake_category_loss + const_loss
    d_loss = binary_loss + category_loss

    # train Discriminator
    D.zero_grad()
    d_loss.backward(retain_graph=True)
    d_optimizer.step()

    # train Generator
    En.zero_grad()
    De.zero_grad()
    g_loss.backward(retain_graph=True)
    g_optimizer.step()

 

Code after the change

for i, batch in enumerate(train_batch_iter):
    labels, batch_images = batch
    embedding_ids = labels
    if GPU:
        batch_images = batch_images.cuda()
    if flip_labels:
        np.random.shuffle(embedding_ids)

    # target / source images
    real_target = batch_images[:, 0, :, :].view([BATCH_SIZE, 1, IMG_SIZE, IMG_SIZE])
    real_source = batch_images[:, 1, :, :].view([BATCH_SIZE, 1, IMG_SIZE, IMG_SIZE])

    # generate fake image from source image
    fake_target, encoded_source, encoder_layers = Generator(real_source, En, De, embeddings, embedding_ids, GPU=GPU)

    real_TS = torch.cat([real_source, real_target], dim=1)
    fake_TS = torch.cat([real_source, fake_target], dim=1)

    # Scoring with Discriminator
    fake_score, fake_score_logit, fake_cat_logit = D(fake_TS)

    # Get encoded fake image to calculate constant loss
    encoded_fake = En(fake_target)[0]
    const_loss = Lconst_penalty * mse_criterion(encoded_source, encoded_fake)

    # category loss
    real_category = torch.from_numpy(np.eye(FONTS_NUM)[embedding_ids]).float()
    if GPU:
        real_category = real_category.cuda()
    fake_category_loss = bce_criterion(fake_cat_logit, real_category)

    # labels
    if GPU:
        one_labels = torch.ones([BATCH_SIZE, 1]).cuda()
        zero_labels = torch.zeros([BATCH_SIZE, 1]).cuda()
    else:
        one_labels = torch.ones([BATCH_SIZE, 1])
        zero_labels = torch.zeros([BATCH_SIZE, 1])

    # L1 loss between real and fake images
    l1_loss = L1_penalty * l1_criterion(real_target, fake_target)

    # cheat loss for generator to fool discriminator
    cheat_loss = bce_criterion(fake_score_logit, one_labels)

    # g_loss, d_loss
    g_loss = cheat_loss + l1_loss + fake_category_loss + const_loss

    # train Generator
    En.zero_grad()
    De.zero_grad()

    g_loss.backward()  # nothing else backpropagates through this graph, so retain_graph is unnecessary
    g_optimizer.step()

    # detach so the discriminator's loss cannot backpropagate into the generator
    fake_TS = torch.cat([real_source, fake_target.detach()], dim=1)
    real_score, real_score_logit, real_cat_logit = D(real_TS)
    fake_score, fake_score_logit, fake_cat_logit = D(fake_TS)
    # binary loss for discriminator
    real_binary_loss = bce_criterion(real_score_logit, one_labels)
    fake_binary_loss = bce_criterion(fake_score_logit, zero_labels)
    binary_loss = real_binary_loss + fake_binary_loss
    #category loss for discriminator
    fake_category_loss = bce_criterion(fake_cat_logit, real_category)
    real_category_loss = bce_criterion(real_cat_logit, real_category)
    category_loss = 0.5 * (real_category_loss + fake_category_loss)

    d_loss = binary_loss + category_loss
    # train Discriminator
    D.zero_grad()
    d_loss.backward()  # fake_TS is detached, so retain_graph is not needed
    d_optimizer.step()

 

https://jeinalog.tistory.com/15

 

Wrinie, an AI fairy that writes in my handwriting (1) - Theory

 

https://github.com/pytorch/pytorch/issues/39141

 

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation? · Issue #39141 · pytorch/pytorch

 
