본문 바로가기

Data-science/deep learning

[pytorch] stylegan2 학습, freezeD + freezeG 동시 적용

728x90

22k 학습시
30k 학습시

뭔가 학습할수록 이미지가 더 일그러진다.... generaotr학습을 제대로 하지 못한다고 생각했다. 15k 기준으로 이미지가 점점 깨지기 시작한다. 그래서 기존의 generator에서 미세조정하면 되지 않을까란 생각이 들었다. 구체적으로는 14k 정도에서의 weight를 조금만 조정해보면 어떨까? 하는 생각!

freezeD만 해선 생성할 때 많이 일그러지는 것 같다. generator를 고정시켜보자! Generator도 일부는 고정하고 싶단 생각을 했다. 그러던 차에 누가 이미 실험을 해봤다. freezeG라는 github에 구현한 코드도 있었다. 감사! 감사

우선 이해를 해야했기에, 노트북에 정리해보았다. 어떤 layer를 학습하는지.

Generator 구조중 학습하는 부분

stylegan2에서는 skip generator를 이용한다. 여기서 마지막 tRGB layer와, 마지막 resoultion에서 이루어지는 2개의 conv layer를 fine tuning하기로 했다.

 

코드는 이러하다.

#         requires_grad(generator, True)
#         requires_grad(discriminator, False)

        '''growth-kim add'''
        if args.freezeD:
            for loc in range(args.training_features):
                requires_grad(discriminator, False, target_layer=f'convs.{6-loc}')
            requires_grad(discriminator, False, target_layer=f'final_')
        else:
            requires_grad(discriminator, False)
            
        if args.freezeG:
            for loc in range(args.training_features_G):#size=256 -> num_layers = (self.log_size - 2) * 2 + 1 :  6*2 + 1 = 13
                requires_grad(generator, True, target_layer=f'convs.{generator.num_layers-2-2*loc}') # 11
                requires_grad(generator, True, target_layer=f'convs.{generator.num_layers-3-2*loc}') # 10
                requires_grad(generator, True, target_layer=f'to_rgbs.{generator.log_size-3-loc}')   # 5
        else:
            requires_grad(generator, True)
        ''''''

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

초기 엔트리 포인트에서 arguments는 새로 만들었다.

쭉~ 내려가서 학습할 때... 둘 다 기본적으로 False로 세팅후 train에 들어가야 위 코드가 먹힌다.

이렇게하면 안 돌아가는데

'RuntimeError: One of the differentiated Variables does not require grad'

이런 에러가 나온다.

geneartor의 regularizer를 사용하지 못하도록 설정해야 한다. freezeG를 쓸 경우엔 g_regularize를 하지 않도록 한다.

 if not args.freezeG and g_regularize:
     path_batch_size = max(1, args.batch // args.path_batch_shrink)
     noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
     fake_img, latents = generator(noise, return_latents=True)

학습 결과는 어떠할까? 시험해봐야 한다.

 

bryandlee/FreezeG

Freezing generator for pseudo image translation. Contribute to bryandlee/FreezeG development by creating an account on GitHub.

github.com