I want to apply FreezeD to StyleGAN2.
study-grow.tistory.com/entry/pytorch-stylegan1-freezeD-%EC%BD%94%EB%93%9C-%EB%B6%84%EC%84%9D
Before doing so, I analyzed the FreezeD code used for StyleGAN1 in the post above; now it is time to apply the same idea to StyleGAN2.
I will fine-tune only convs.5 and convs.6 at the end of the discriminator, plus the final layers.
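The calls below pass a target_layer argument, so they rely on a requires_grad helper extended beyond the stock rosinality train.py version. A minimal sketch of such a helper, following the pattern of the FreezeD reference code linked at the bottom (the exact implementation used here may differ):

    def requires_grad(model, flag=True, target_layer=None):
        # Toggle gradients for every parameter, or only for parameters whose
        # name contains target_layer (e.g. 'convs.6' or 'final_').
        for name, param in model.named_parameters():
            if target_layer is None or target_layer in name:
                param.requires_grad = flag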
for idx in pbar:
    i = idx + args.start_iter

    if i > args.iter:
        print("Done!")
        break

    real_img = next(loader)
    real_img = real_img.to(device)

    requires_grad(generator, False)
    # requires_grad(discriminator, True)

    '''dhkim add'''
    if args.freezeD:
        # Unfreeze only the top conv blocks and the final layers of the
        # discriminator. convs.6 is the block closest to final_conv /
        # final_linear (index 6 assumes a 256px model). A separate loop
        # variable is used so the iteration counter i is not clobbered.
        for k in range(args.training_features):
            requires_grad(discriminator, True, target_layer=f'convs.{6 - k}')
        requires_grad(discriminator, True, target_layer='final_')
    else:
        requires_grad(discriminator, True)
    ''''''

    noise = mixing_noise(args.batch, args.latent, args.mixing, device)
    fake_img, _ = generator(noise)

    if args.augment:
        real_img_aug, _ = augment(real_img, ada_aug_p)
        fake_img, _ = augment(fake_img, ada_aug_p)
    else:
        real_img_aug = real_img

    fake_pred = discriminator(fake_img)
    real_pred = discriminator(real_img_aug)
    d_loss = d_logistic_loss(real_pred, fake_pred)

    loss_dict["d"] = d_loss
    loss_dict["real_score"] = real_pred.mean()
    loss_dict["fake_score"] = fake_pred.mean()

    discriminator.zero_grad()
    d_loss.backward()
    d_optim.step()

    if args.augment and args.augment_p == 0:
        ada_aug_p = ada_augment.tune(real_pred)
        r_t_stat = ada_augment.r_t_stat

    d_regularize = i % args.d_reg_every == 0

    if d_regularize:
        real_img.requires_grad = True
        real_pred = discriminator(real_img)
        r1_loss = d_r1_loss(real_pred, real_img)

        discriminator.zero_grad()
        (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()

        d_optim.step()

    loss_dict["r1"] = r1_loss

    requires_grad(generator, True)
    # requires_grad(discriminator, False)

    '''dhkim add'''
    if args.freezeD:
        # Freeze the same layers again before the generator update so the
        # G step does not accumulate gradients in the discriminator.
        for k in range(args.training_features):
            requires_grad(discriminator, False, target_layer=f'convs.{6 - k}')
        requires_grad(discriminator, False, target_layer='final_')
    else:
        requires_grad(discriminator, False)
    ''''''

    noise = mixing_noise(args.batch, args.latent, args.mixing, device)
    fake_img, _ = generator(noise)
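To confirm that the unfreeze step touches only the intended layers, a quick check placed right after the FreezeD block can print the discriminator parameters that will actually receive gradients; with training_features = 2 this should list only the convs.5.*, convs.6.*, and final_* parameters:

    # List every discriminator parameter that is currently trainable.
    for name, param in discriminator.named_parameters():
        if param.requires_grad:
            print(name)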
Reference
github.com/sangwoomo/FreezeD/blob/master/stylegan/finetune.py
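The two flags used above, args.freezeD and args.training_features, are not part of the original train.py. A hedged sketch of how they could be registered, with names and defaults assumed rather than taken from this post:

    import argparse

    parser = argparse.ArgumentParser(description="StyleGAN2 FreezeD fine-tuning")
    parser.add_argument("--freezeD", action="store_true",
                        help="freeze the lower discriminator blocks, fine-tune only the top ones")
    parser.add_argument("--training_features", type=int, default=2,
                        help="number of top conv blocks (convs.6, convs.5, ...) left trainable")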