본문 바로가기

Data-science/deep learning

[pytorch] nn.module의 zero_grad VS optimizer의 zero_grad의 차이

728x90

 

tutorials.pytorch.kr/beginner/pytorch_with_examples.html

 

예제로 배우는 파이토치(PyTorch) — PyTorch Tutorials 1.6.0 documentation

예제로 배우는 파이토치(PyTorch) Author: Justin Johnson번역: 박정환 이 튜토리얼에서는 PyTorch 의 핵심적인 개념을 예제를 통해 소개합니다. 본질적으로, PyTorch에는 2가지 주요한 특징이 있습니다: NumPy

tutorials.pytorch.kr

optimizer 객체에서 zero_grad()를 call하는 것. 이건 역전파를 하기 전에, 모든 변화도를 0으로 만들어준 거다.

그런데 nn.module 과 같이 모델 객체에도 zero_grad()를 똑같이 call하더라. 이게 왜 그런지 궁금했다. 

모델 객체에서도 변화도를 0으로 만드는 건 같다. 그런데 다음과 같은 경우가 있을 때 이게 좀 유용하단다.

gen_in1, gen_in2 = sample_noise(len(real_image))

### update D ###

D_target.zero_grad()

requires_grad(G_target, False)
if args.freeze_D:
    for loc in range(args.feature_loc):
        requires_grad(D_target, True, target_layer=f'progression.{8 - loc}')
    requires_grad(D_target, True, target_layer=f'linear')
else:
    requires_grad(D_target, True)

D_loss_val, grad_loss_val = backward_D(args, G_target, D_target, real_image, gen_in1)

D_optimizer.step()

### update G ###

G_target.zero_grad()

if not args.miner:
    requires_grad(G_target, True)  # do not update G
if args.freeze_D:
    for loc in range(args.feature_loc):
        requires_grad(D_target, False, target_layer=f'progression.{8 - loc}')
    requires_grad(D_target, False, target_layer=f'linear')
else:
    requires_grad(D_target, False)

G_loss_val = backward_G(args, G_target, D_target, gen_in2)

G_optimizer.step()
accumulate(G_running_target, G_target.module)

위 코드 처럼.  D_target과 G_target은 모델 객체이다. D_optimizer와 G_optimizer는 optimizer 객체이다.

효과는 같다. optimizer는 하나인데, 모델은 여러개 일때 쓸 수 있다, 혹은 그냥 구분하기 쉬워라고 쓴다.

예를 들어) gan network의 optimizer는 하나인데, Generator 모델, Discriminator 모델 이렇게 2개가 있을 수 있다. 그런데 학습 방법에서 처음에 Discriminator의 역전파만 고려하고 싶을 거다. 그럴 때 Discriminator 객체에서만 zero_grad를 호출하면 되는 것이다.

D_optimizer = optim.Adam(D_target.parameters(), lr=args.lr, betas=(0.0, 0.99))
...

왜 D_optimzer.zero_grad()를 안 하고 D_target.zero_grad()를 한 이유는 잘 모르겠다. 

예제 코드 출처:

github.com/sangwoomo/FreezeD/blob/master/stylegan/finetune.py

 

sangwoomo/FreezeD

Freeze the Discriminator: a Simple Baseline for Fine-Tuning GANs (CVPRW 2020) - sangwoomo/FreezeD

github.com

해답 출처:

discuss.pytorch.org/t/whats-the-difference-between-optimizer-zero-grad-vs-nn-module-zero-grad/59233/2

 

Whats the difference between Optimizer.zero_grad() vs nn.Module.zero_grad()

The nn.Module.zero_grad() also sets the gradients to 0 for all parameters. If you ceated your optimizer like opt = optim.SGD(model.paremeters(), xxx), then opt.zero_grad() and model.zero_grad() will have the same effect. The distinction is useful for peopl

discuss.pytorch.org