본문 바로가기

Data-science/deep learning

[pytorch] stylegan1 freezeD 이해 및 코드 분석

728x90

stylegan2에 freezeD를 적용하길 원한다.

이에 앞서 stylegan1에 쓰인 freezeD 코드를 이해하고 분석해본다. 왜냐하면 pytorch 구현은 stylegan 1에 대한 freezeD만 있기 때문이다.

freezeD 초록
freezeD의 효과

freezeD는 custom한 데이터셋? 혹은 다른 도메인의 데이터셋으로 학습할 때 Discriminator의 특정 layer를 freezing 시키는 방법이다. 이게 효과적인 이유는 gan에서 FFHQ를 통해 학습한 Discriminator는 훌륭한 진짜/가짜 판별자이고, 이를 transfer learning으로 그대로 이용할 수 있기 때문이다. 이 판별자의 처음 몇 layer를 고정시키고 후 단의 layer만을 학습시키는 방식이다. 이는 기존의 전체 layer를 미세 조정하는 finetuning 방법보다 훨씬 더 우수한 성능을 나타냈다! 

stylegan1에서 freezeD 코드 이해 및 분석

Stylegan1 에서의 freezeD 코드를 보자. 명령줄 인수로는 layer 3개를 얼리는 것처럼 보인다.

CUDA_VISIBLE_DEVICES=1 python finetune.py --name DATASET_freezeD --mixing --loss r1 --sched --dataset DATASET --freeze_D --feature_loc 3
# Note that feature_loc = 7 - layer_num
parser.add_argument('--only_adain', action='store_true', help='only optimize AdaIN layers')
parser.add_argument('--supervised', action='store_true', help='use supervised loss instead of GAN loss')
parser.add_argument('--miner', action='store_true', help='use miner network instead of full generator')
parser.add_argument('--lambda_l2_G', type=float, default=0, help='weight for l2 loss for G')
parser.add_argument('--lambda_l2_D', type=float, default=0, help='weight for l2 loss for D')
parser.add_argument('--lambda_FM', type=float, default=0, help='weight for FM loss for D')
parser.add_argument('--feature_loc', type=int, default=3, help='feature location for discriminator (default: 3)')
parser.add_argument('--freeze_D', action='store_true', help='freeze layers of discriminator D')
# 초기 entry point에서의 코드

if args.freeze_D:
	requires_grad(D_target, False)  # freeze D
우선 D_target 전체 layer를 얼린다.

....
for i in pbar:
	
    ...
    
    gen_in1, gen_in2 = sample_noise(len(real_image))

    ### update D ###

    D_target.zero_grad()

    requires_grad(G_target, False)
    if args.freeze_D:
        for loc in range(args.feature_loc):
            requires_grad(D_target, True, target_layer=f'progression.{8 - loc}')
        requires_grad(D_target, True, target_layer=f'linear')
    else:
        requires_grad(D_target, True)

    D_loss_val, grad_loss_val = backward_D(args, G_target, D_target, real_image, gen_in1)

    D_optimizer.step()

    ### update G ###

    G_target.zero_grad()

    if not args.miner:
        requires_grad(G_target, True)  # do not update G
    if args.freeze_D:
        for loc in range(args.feature_loc):
            requires_grad(D_target, False, target_layer=f'progression.{8 - loc}')
        requires_grad(D_target, False, target_layer=f'linear')
    else:
        requires_grad(D_target, False)

    G_loss_val = backward_G(args, G_target, D_target, gen_in2)

    G_optimizer.step()
    accumulate(G_running_target, G_target.module)

freeze D의 자세한 방법은 위 코드에서 확인할 수 있다.

# override requires_grad function
def requires_grad(model, flag=True, target_layer=None):
    for name, param in model.named_parameters():
        if target_layer is None:  # every layer
            param.requires_grad = flag
        elif target_layer in name:  # target layer
            param.requires_grad = flag

requires_grad는 target_layer 이름을 포함한 모든 layer를 freezing 시킬지 말지를 결정하는 함수다.

처음에 학습할 때 Discriminator만 학습하고, 이후에 Generator를 학습하는 건 같다. Discriminator에서 feature location이 3이면 progression layer중 8,7,6과 마지막 linear layer만을 학습하고 0,1,2,3,4,5는 freezing 시킨다. backward_D란 함수를 호출하는데 이 함수는 아래에 적어두었다.

Geneorator를 학습시킬 때는 Discriminator에서 progression layer중 8,7,6과 마지막 linear layer만을 freezing시킨다. 0, 1, 2, 3, 4, 5 progession layer는 학습한다는 의미일까? 아니다. 왜냐하면 애초에 entrypoint에서 Discriminator의 모든 layer를 freezing시킨 상태였기 때문에 나머지를 굳이 freezing 하지 않아도 되는 것이다. 

backward_G 함수도 아래에 첨부하였다.

정리하면, 단순하다. Discriminator 학습시 끝 단 몇개만 지정해서 학습을 하고 초기 나머지 layer는 학습 시킨다. Generator 학습시에는 Discriminator는 학습하지 않는다.
def backward_D(args, G_target, D_target, real_image, gen_in):
    ### update D (GAN loss) ###

    real_image = real_image.cuda()

    real_image.requires_grad = True
    real_predict = D_target(real_image, step=step, alpha=alpha)  # before activation
    D_loss_real = F.softplus(-real_predict).mean()

    grad_real = grad(outputs=real_predict.sum(), inputs=real_image, create_graph=True)[0]
    grad_penalty = (grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2).mean()
    grad_penalty = 10 / 2 * grad_penalty

    if args.miner:
        gen_in = miner(gen_in)

    fake_image_tgt = G_target(gen_in, step=step, alpha=alpha)
    fake_predict = D_target(fake_image_tgt, step=step, alpha=alpha)
    D_loss_fake = F.softplus(fake_predict).mean()

    ### update D (regularizer) ###

    if args.lambda_FM > 0:
        FM_loss = FM_reg(real_image, args.feature_loc) * args.lambda_FM
    else:
        FM_loss = 0

    if args.lambda_l2_D > 0:
        l2_D_loss = l2_reg(D_source, D_target) * args.lambda_l2_D
    else:
        l2_D_loss = 0

    (D_loss_real + D_loss_fake + grad_penalty + FM_loss + l2_D_loss).backward()

    D_loss_val = (D_loss_real + D_loss_fake).item()
    grad_loss_val = grad_penalty.item() if grad_penalty > 0 else 0

    return D_loss_val, grad_loss_val
def backward_G(args, G_target, D_target, gen_in):
	### update G (GAN loss) ###

	if args.miner:
		gen_in = miner(gen_in)

	fake_image_tgt = G_target(gen_in, step=step, alpha=alpha)
	predict = D_target(fake_image_tgt, step=step, alpha=alpha)
	gen_loss = F.softplus(-predict).mean()

	### update G (regularizer) ###

	if args.lambda_l2_G > 0:
		l2_G_loss = l2_reg(G_source, G_target) * args.lambda_l2_G
	else:
		l2_G_loss = 0

	(gen_loss + l2_G_loss).backward()

	G_loss_val = gen_loss.item()

	return G_loss_val

stylegan2에서는 freezeD를 어떻게 적용할까?

 

참고

arxiv.org/pdf/1812.04948.pdf

arxiv.org/pdf/2002.10964.pdf

github.com/sangwoomo/FreezeD/blob/master/stylegan/finetune.py

 

sangwoomo/FreezeD

Freeze the Discriminator: a Simple Baseline for Fine-Tuning GANs (CVPRW 2020) - sangwoomo/FreezeD

github.com