
[pytorch] Applying freezeD to StyleGAN2


I want to apply freezeD to StyleGAN2.

Previous post: [pytorch] Understanding and analyzing the StyleGAN1 freezeD code

study-grow.tistory.com/entry/pytorch-stylegan1-freezeD-%EC%BD%94%EB%93%9C-%EB%B6%84%EC%84%9D

In that post I analyzed the freezeD code used for StyleGAN1, since the only existing PyTorch implementation of freezeD targets StyleGAN1. Now it is time to apply it to StyleGAN2.

freezeD freezes the lower discriminator layers and fine-tunes only the upper ones, so I will train only the last blocks of the discriminator: convs.5, convs.6, and the final layers.
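The training loop below relies on a `requires_grad` helper that accepts a `target_layer` argument. The stock rosinality stylegan2-pytorch helper has no such argument; the substring-matching variant sketched here follows the FreezeD-style pattern, but it is an assumption on my part, so check it against the finetune.py linked in the references:

```python
import torch.nn as nn

def requires_grad(model, flag=True, target_layer=None):
    # With target_layer=None this behaves like the stock helper and
    # toggles every parameter; otherwise it only touches parameters
    # whose name contains the target_layer substring.
    for name, param in model.named_parameters():
        if target_layer is None or target_layer in name:
            param.requires_grad = flag

# Toy module whose parameter names mimic the discriminator's layout
# (convs.0 ... convs.6 plus a final_linear head).
class ToyD(nn.Module):
    def __init__(self):
        super().__init__()
        self.convs = nn.ModuleList([nn.Linear(2, 2) for _ in range(7)])
        self.final_linear = nn.Linear(2, 1)

d = ToyD()
requires_grad(d, False)                         # freeze everything
requires_grad(d, True, target_layer='convs.6')  # unfreeze last conv block
requires_grad(d, True, target_layer='final_')   # unfreeze final layers
assert d.convs[6].weight.requires_grad and d.final_linear.weight.requires_grad
assert not d.convs[0].weight.requires_grad
```

Substring matching is convenient here because parameter names like `convs.6.weight` and `final_linear.bias` all share the prefixes we care about.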

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        real_img = next(loader)
        real_img = real_img.to(device)

        requires_grad(generator, False)
        #requires_grad(discriminator, True)
        
        '''dhkim add'''
        if args.freezeD:
            # freezeD: freeze the whole discriminator first, then unfreeze
            # only the last `training_features` conv blocks (convs.6,
            # convs.5, ...) and the final layers. Without the initial
            # freeze-all, the lower layers keep requires_grad=True from
            # initialization and get trained anyway.
            requires_grad(discriminator, False)
            for j in range(args.training_features):
                requires_grad(discriminator, True, target_layer=f'convs.{6 - j}')
            requires_grad(discriminator, True, target_layer='final_')
            # NB: the loop variable is `j`, not `i`, so the outer
            # iteration counter used below is not clobbered.
        else:
            requires_grad(discriminator, True)
        ''''''
        
        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)

        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            # Lazy R1: the penalty only runs every d_reg_every steps, so it
            # is scaled by d_reg_every; the `0 * real_pred[0]` term keeps the
            # discriminator output in the graph (avoids unused-parameter
            # issues under DistributedDataParallel).
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        requires_grad(generator, True)
#         requires_grad(discriminator, False)

        '''dhkim add'''
        if args.freezeD:
            # Re-freeze the fine-tuned layers for the generator step
            # (again using `j`, not the outer iteration counter `i`).
            for j in range(args.training_features):
                requires_grad(discriminator, False, target_layer=f'convs.{6 - j}')
            requires_grad(discriminator, False, target_layer='final_')
        else:
            requires_grad(discriminator, False)
        ''''''

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
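It is easy to get the freeze/unfreeze ordering wrong, so it is worth double-checking which discriminator parameters actually remain trainable. A minimal self-contained sketch, using a stand-in module that only mimics the convs.N / final_linear naming of the rosinality discriminator (not the real one):

```python
import torch.nn as nn

# Stand-in discriminator: reuses only the parameter naming scheme
# (convs.0 ... convs.6 plus final_linear) of the rosinality port.
class ToyDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.convs = nn.ModuleList([nn.Conv2d(3, 3, 3) for _ in range(7)])
        self.final_linear = nn.Linear(3, 1)

disc = ToyDiscriminator()

# Mirror the freezeD branch with training_features = 2: freeze everything,
# then unfreeze convs.5, convs.6 and the final layers.
for p in disc.parameters():
    p.requires_grad = False
for name, p in disc.named_parameters():
    if name.startswith(('convs.5', 'convs.6', 'final_')):
        p.requires_grad = True

# Collect the distinct trainable layer names (strip .weight/.bias).
trainable = sorted({n.rsplit('.', 1)[0] for n, p in disc.named_parameters()
                    if p.requires_grad})
print(trainable)  # ['convs.5', 'convs.6', 'final_linear']
```

If anything besides the intended layers shows up here, the discriminator optimizer will silently keep updating the supposedly frozen weights.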

 

References

arxiv.org/pdf/1812.04948.pdf (StyleGAN)

arxiv.org/pdf/2002.10964.pdf (FreezeD)

github.com/sangwoomo/FreezeD/blob/master/stylegan/finetune.py

(sangwoomo/FreezeD: "Freeze the Discriminator: a Simple Baseline for Fine-Tuning GANs", CVPRW 2020)