stylegan2에 freezeD를 적용하길 원한다.
이에 앞서 stylegan1에 쓰인 freezeD 코드를 이해하고 분석해본다. 왜냐하면 pytorch 구현은 stylegan 1에 대한 freezeD만 있기 때문이다.
freezeD는 custom한 데이터셋? 혹은 다른 도메인의 데이터셋으로 학습할 때 Discriminator의 특정 layer를 freezing 시키는 방법이다. 이게 효과적인 이유는 gan에서 FFHQ를 통해 학습한 Discriminator는 훌륭한 진짜/가짜 판별자이고, 이를 transfer learning으로 그대로 이용할 수 있기 때문이다. 이 판별자의 처음 몇 layer를 고정시키고 후 단의 layer만을 학습시키는 방식이다. 이는 기존의 전체 layer를 미세 조정하는 finetuning 방법보다 훨씬 더 우수한 성능을 나타냈다!
stylegan1에서 freezeD 코드 이해 및 분석
Stylegan1 에서의 freezeD 코드를 보자. 명령줄 인수로는 layer 3개를 얼리는 것처럼 보인다.
CUDA_VISIBLE_DEVICES=1 python finetune.py --name DATASET_freezeD --mixing --loss r1 --sched --dataset DATASET --freeze_D --feature_loc 3
# Note that feature_loc = 7 - layer_num
parser.add_argument('--only_adain', action='store_true', help='only optimize AdaIN layers')
parser.add_argument('--supervised', action='store_true', help='use supervised loss instead of GAN loss')
parser.add_argument('--miner', action='store_true', help='use miner network instead of full generator')
parser.add_argument('--lambda_l2_G', type=float, default=0, help='weight for l2 loss for G')
parser.add_argument('--lambda_l2_D', type=float, default=0, help='weight for l2 loss for D')
parser.add_argument('--lambda_FM', type=float, default=0, help='weight for FM loss for D')
parser.add_argument('--feature_loc', type=int, default=3, help='feature location for discriminator (default: 3)')
parser.add_argument('--freeze_D', action='store_true', help='freeze layers of discriminator D')
# 초기 entry point에서의 코드
if args.freeze_D:
requires_grad(D_target, False) # freeze D
우선 D_target 전체 layer를 얼린다.
....
for i in pbar:
...
gen_in1, gen_in2 = sample_noise(len(real_image))
### update D ###
D_target.zero_grad()
requires_grad(G_target, False)
if args.freeze_D:
for loc in range(args.feature_loc):
requires_grad(D_target, True, target_layer=f'progression.{8 - loc}')
requires_grad(D_target, True, target_layer=f'linear')
else:
requires_grad(D_target, True)
D_loss_val, grad_loss_val = backward_D(args, G_target, D_target, real_image, gen_in1)
D_optimizer.step()
### update G ###
G_target.zero_grad()
if not args.miner:
requires_grad(G_target, True) # do not update G
if args.freeze_D:
for loc in range(args.feature_loc):
requires_grad(D_target, False, target_layer=f'progression.{8 - loc}')
requires_grad(D_target, False, target_layer=f'linear')
else:
requires_grad(D_target, False)
G_loss_val = backward_G(args, G_target, D_target, gen_in2)
G_optimizer.step()
accumulate(G_running_target, G_target.module)
freeze D의 자세한 방법은 위 코드에서 확인할 수 있다.
# override requires_grad function
def requires_grad(model, flag=True, target_layer=None):
for name, param in model.named_parameters():
if target_layer is None: # every layer
param.requires_grad = flag
elif target_layer in name: # target layer
param.requires_grad = flag
requires_grad는 target_layer 이름을 포함한 모든 layer를 freezing 시킬지 말지를 결정하는 함수다.
처음에 학습할 때 Discriminator만 학습하고, 이후에 Generator를 학습하는 건 같다. Discriminator에서 feature location이 3이면 progression layer중 8,7,6과 마지막 linear layer만을 학습하고 0,1,2,3,4,5는 freezing 시킨다. backward_D란 함수를 호출하는데 이 함수는 아래에 적어두었다.
Geneorator를 학습시킬 때는 Discriminator에서 progression layer중 8,7,6과 마지막 linear layer만을 freezing시킨다. 0, 1, 2, 3, 4, 5 progession layer는 학습한다는 의미일까? 아니다. 왜냐하면 애초에 entrypoint에서 Discriminator의 모든 layer를 freezing시킨 상태였기 때문에 나머지를 굳이 freezing 하지 않아도 되는 것이다.
backward_G 함수도 아래에 첨부하였다.
정리하면, 단순하다. Discriminator 학습시 끝 단 몇개만 지정해서 학습을 하고 초기 나머지 layer는 학습 시킨다. Generator 학습시에는 Discriminator는 학습하지 않는다.
def backward_D(args, G_target, D_target, real_image, gen_in):
### update D (GAN loss) ###
real_image = real_image.cuda()
real_image.requires_grad = True
real_predict = D_target(real_image, step=step, alpha=alpha) # before activation
D_loss_real = F.softplus(-real_predict).mean()
grad_real = grad(outputs=real_predict.sum(), inputs=real_image, create_graph=True)[0]
grad_penalty = (grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2).mean()
grad_penalty = 10 / 2 * grad_penalty
if args.miner:
gen_in = miner(gen_in)
fake_image_tgt = G_target(gen_in, step=step, alpha=alpha)
fake_predict = D_target(fake_image_tgt, step=step, alpha=alpha)
D_loss_fake = F.softplus(fake_predict).mean()
### update D (regularizer) ###
if args.lambda_FM > 0:
FM_loss = FM_reg(real_image, args.feature_loc) * args.lambda_FM
else:
FM_loss = 0
if args.lambda_l2_D > 0:
l2_D_loss = l2_reg(D_source, D_target) * args.lambda_l2_D
else:
l2_D_loss = 0
(D_loss_real + D_loss_fake + grad_penalty + FM_loss + l2_D_loss).backward()
D_loss_val = (D_loss_real + D_loss_fake).item()
grad_loss_val = grad_penalty.item() if grad_penalty > 0 else 0
return D_loss_val, grad_loss_val
def backward_G(args, G_target, D_target, gen_in):
### update G (GAN loss) ###
if args.miner:
gen_in = miner(gen_in)
fake_image_tgt = G_target(gen_in, step=step, alpha=alpha)
predict = D_target(fake_image_tgt, step=step, alpha=alpha)
gen_loss = F.softplus(-predict).mean()
### update G (regularizer) ###
if args.lambda_l2_G > 0:
l2_G_loss = l2_reg(G_source, G_target) * args.lambda_l2_G
else:
l2_G_loss = 0
(gen_loss + l2_G_loss).backward()
G_loss_val = gen_loss.item()
return G_loss_val
stylegan2에서는 freezeD를 어떻게 적용할까?
참고
github.com/sangwoomo/FreezeD/blob/master/stylegan/finetune.py
'Data-science > deep learning' 카테고리의 다른 글
g_ema?, EMA 구하는 공식 (0) | 2020.12.26 |
---|---|
[pytorch] stylegan2 pretrained model load error (0) | 2020.12.26 |
[pytorch] stylegan2 freezeD 적용 (0) | 2020.12.24 |
[pytorch] requires grad 확인. (0) | 2020.12.23 |
[pytorch] stylegan1 Vs stylegan2 Discriminator 차이 (0) | 2020.12.23 |