본문 바로가기

Data-science/deep learning

[imaginaire] coco-funit, mode collapse 문제 해결 삽질

반응형

coco-funit에서 발생한 문제 

github.com/NVlabs/imaginaire/issues/43

 

mode collapse · Issue #43 · NVlabs/imaginaire

hi, when I trained with coco_funit, In the first few epochs, the results are normal, but mode collapse appears from the 59th epoch. Is this normal? Will it also appear during your training?

github.com

Mode collapse...

학습이 특정 epoch이후 안 된다... mode collapse에 빠짐

해결 방안

1. learning rate 줄이기

- discriminator learning rate : 1e-4 -> 1e-5

- generator learning rate : 1e-4 -> 1e-5

2. reconstruction loss 비중 높이기

- train.py를 살펴보면 loss에 대한 항이 추상화되어 trainer안에 다 감싸져 있음을 알 수 있다. trainer를 살펴봐야 한다. 어떤 부분을 살펴봐야 하나면 trainer.dis_update(data), trainer.gen_update(data) 부분이다.

즉, dis_update / gen_update 부분을 확인해보자.

    # Start training.
    for epoch in range(current_epoch, cfg.max_epoch):
        print('Epoch {} ...'.format(epoch))
        if not args.single_gpu:
            train_data_loader.sampler.set_epoch(current_epoch)
        trainer.start_of_epoch(current_epoch)
        for it, data in enumerate(train_data_loader):
            data = trainer.start_of_iteration(data, current_iteration)

            for _ in range(cfg.trainer.dis_step):
                trainer.dis_update(data)
            for _ in range(cfg.trainer.gen_step):
                trainer.gen_update(data)

            current_iteration += 1
            trainer.end_of_iteration(data, current_epoch, current_iteration)
            if current_iteration >= cfg.max_iter:
                print('Done with training!!!')
                return

dis_update / gen_update의 경우 imaginaire/trainers/base.py에서 확인이 가능하다. 그중 gen_update는 아래와 같다.

    def gen_update(self, data):
        r"""Update the generator.
        Args:
            data (dict): Data used for the current iteration.
        """
        self.opt_G.zero_grad()

        # Set requires_grad flags.
        requires_grad(self.net_G_module, True)
        requires_grad(self.net_D, False)

        # Compute the loss.
        self._time_before_forward()
        total_loss = self.gen_forward(data)
        if total_loss is None:
            return

        # Backpropagate the loss.
        self._time_before_backward()
        with amp.scale_loss(total_loss, self.opt_G, loss_id=0) as scaled_loss:
            scaled_loss.backward()

        # Optionally clip gradient norm.
        if hasattr(self.cfg.gen_opt, 'clip_grad_norm'):
            nn.utils.clip_grad_norm_(
                amp.master_params(self.opt_G), self.cfg.gen_opt.clip_grad_norm)

        # Perform an optimizer step.
        self._time_before_step()
        self.opt_G.step()

        # Update model average.
        self._time_before_model_avg()
        if self.cfg.trainer.model_average:
            self.net_G.module.update_average()

        self._detach_losses()
        self._time_before_leave_gen()

    def gen_forward(self, data):
        r"""Every trainer should implement its own generator forward."""
        raise NotImplementedError

total_loss를 계산하는 부분이 중요한데, 이 부분은 self.gen_foward로 계산된다. 

gen_foward를 살펴보니, 이를 상속받은 클래스에서 구현해야만 하는 함수이다. 즉 base.py의 BaseTrainer를 상속받는 클래스를 살펴보자.

- coco-funit의 경우 trainer는 imaginaire/trainers/funit.py에 있다. 여기 부분에서 gen_foward가 보인다.

loss가 들어간 부분이 있다.

    def gen_forward(self, data):
        r"""Compute the loss for FUNIT generator.
        Args:
            data (dict): Training data at the current iteration.
        """

        net_G_output = self.net_G(data)
        net_D_output = self.net_D(data, net_G_output)

        self._time_before_loss()

        # GAN loss
        # We use both the translation and reconstruction streams.
        self.gen_losses['gan'] = 0.5 * (
            self.criteria['gan'](
                net_D_output['fake_out_trans'], True, dis_update=False) +
            self.criteria['gan'](
                net_D_output['fake_out_recon'], True, dis_update=False))

        # Image reconstruction loss
        self.gen_losses['image_recon'] = \
            self.criteria['image_recon'](net_G_output['images_recon'],
                                         data['images_content'])

        # Feature matching loss
        self.gen_losses['feature_matching'] = \
            self.criteria['feature_matching'](
                net_D_output['fake_features_trans'],
                net_D_output['real_features_style'])

        # Compute total loss
        total_loss = self._get_total_loss(gen_forward=True)
        return total_loss

'gan' , 'image_recon', 'feature_matching' 이렇게 3가지 loss를 self._get_total_loss에서 합산하는 것으로 보인다.

이제 self._get_total_loss를 보자. 이는 다시 base.py에 있는 BaseTrainer에서 확인할 수 있다.

    def _get_total_loss(self, gen_forward):
        r"""Return the total loss to be backpropagated.
        Args:
            gen_forward (bool): If ``True``, backpropagates the generator loss,
                otherwise the discriminator loss.
        """
        losses = self.gen_losses if gen_forward else self.dis_losses
        total_loss = torch.tensor(0., device=torch.device('cuda'))
        # Iterates over all possible losses.
        for loss_name in self.weights:
            # If it is for the current model (gen/dis).
            if loss_name in losses:
                # Multiply it with the corresponding weight
                # and add it to the total loss.
                total_loss += losses[loss_name] * self.weights[loss_name]
        losses['total'] = total_loss  # logging purpose
        return total_loss

gen_forawrd 인자가 True였으므로 losses는 위 self.gen_forward에서 정의한 self.gen_losses로 계산된다. 

total_loss는 각각의 loss항에 self.weights를 곱한 뒤 합산된다.

가장 중요한 부분이 각 loss의 가중치인 self.weights이다.

    def _init_loss(self, cfg):
        r"""Initialize loss terms. In FUNIT, we have several loss terms
        including the GAN loss, the image reconstruction loss, the feature
        matching loss, and the gradient penalty loss.
        Args:
            cfg (obj): Global configuration.
        """
        self.criteria['gan'] = GANLoss(cfg.trainer.gan_mode)
        self.criteria['image_recon'] = nn.L1Loss()
        self.criteria['feature_matching'] = nn.L1Loss()

        for loss_name, loss_weight in cfg.trainer.loss_weight.__dict__.items():
            if loss_weight > 0:
                self.weights[loss_name] = loss_weight

이는 초기 imaginaire/trainser/funit.py의 Trainer에서 self._init_loss를 살펴보면 된다. cfg.trainer.loss_weight가 dictionary 형태로 정의되어 있다.

cfg는 초기 config 파일이다. yaml파일인데 imaginaire/configs/ 안에 있다. 이를 살펴보면, 아래와 같다.

pretrained_weight: 1Wf0BhcIpVJgHQunipdt8r-KtQ9mRvKxt
inference_args:
    keep_original_size: True

image_save_iter: 1000
snapshot_save_iter: 5000
max_iter: 150000
logging_iter: 50
trainer:
    type: imaginaire.trainers.funit
    model_average: True
    model_average_beta: 0.999
    amp: O1
    gan_mode: hinge
    loss_weight:
        gan: 1
        feature_matching: 1
        image_recon: 0.1
    init:
        type: none
        gain: 1
gen_opt:
    type: adam
    lr: 0.00001
    adam_beta1: 0.
    adam_beta2: 0.999
    lr_policy:
        iteration_mode: True
        type: step
        step_size: 100000
        gamma: 0.5

trainer 부분에 loss_weight가 분명히 있다. 저 부분에서 image_recon이 0.1인데 이를 좀 더 올려보면 될듯하다.

config 파일에 없는 줄 알고 코드 분석을 했다.... 그런데 config 파일에 있었다. 허탈하지만 공부가 됐다.

 

반응형