

[pytorch] StyleGAN1 vs StyleGAN2: Discriminator differences


Dissecting the StyleGAN1 Discriminator

progression.0.conv1.0.conv.bias
progression.0.conv1.0.conv.weight_orig
progression.0.conv2.1.weight
progression.0.conv2.1.bias
progression.1.conv1.0.conv.bias
progression.1.conv1.0.conv.weight_orig
progression.1.conv2.1.weight
progression.1.conv2.1.bias
progression.2.conv1.0.conv.bias
progression.2.conv1.0.conv.weight_orig
progression.2.conv2.1.weight
progression.2.conv2.1.bias
progression.3.conv1.0.conv.bias
progression.3.conv1.0.conv.weight_orig
progression.3.conv2.1.weight
progression.3.conv2.1.bias
progression.4.conv1.0.conv.bias
progression.4.conv1.0.conv.weight_orig
progression.4.conv2.1.conv.bias
progression.4.conv2.1.conv.weight_orig
progression.5.conv1.0.conv.bias
progression.5.conv1.0.conv.weight_orig
progression.5.conv2.1.conv.bias
progression.5.conv2.1.conv.weight_orig
progression.6.conv1.0.conv.bias
progression.6.conv1.0.conv.weight_orig
progression.6.conv2.1.conv.bias
progression.6.conv2.1.conv.weight_orig
progression.7.conv1.0.conv.bias
progression.7.conv1.0.conv.weight_orig
progression.7.conv2.1.conv.bias
progression.7.conv2.1.conv.weight_orig
progression.8.conv1.0.conv.bias
progression.8.conv1.0.conv.weight_orig
progression.8.conv2.0.conv.bias
progression.8.conv2.0.conv.weight_orig
from_rgb.0.conv.bias
from_rgb.0.conv.weight_orig
from_rgb.1.conv.bias
from_rgb.1.conv.weight_orig
from_rgb.2.conv.bias
from_rgb.2.conv.weight_orig
from_rgb.3.conv.bias
from_rgb.3.conv.weight_orig
from_rgb.4.conv.bias
from_rgb.4.conv.weight_orig
from_rgb.5.conv.bias
from_rgb.5.conv.weight_orig
from_rgb.6.conv.bias
from_rgb.6.conv.weight_orig
from_rgb.7.conv.bias
from_rgb.7.conv.weight_orig
from_rgb.8.conv.bias
from_rgb.8.conv.weight_orig
linear.linear.bias
linear.linear.weight_orig

Dissecting the StyleGAN2 Discriminator

convs.0.0.weight
convs.0.1.bias
convs.1.conv1.0.weight
convs.1.conv1.1.bias
convs.1.conv2.1.weight
convs.1.conv2.2.bias
convs.1.skip.1.weight
convs.2.conv1.0.weight
convs.2.conv1.1.bias
convs.2.conv2.1.weight
convs.2.conv2.2.bias
convs.2.skip.1.weight
convs.3.conv1.0.weight
convs.3.conv1.1.bias
convs.3.conv2.1.weight
convs.3.conv2.2.bias
convs.3.skip.1.weight
convs.4.conv1.0.weight
convs.4.conv1.1.bias
convs.4.conv2.1.weight
convs.4.conv2.2.bias
convs.4.skip.1.weight
convs.5.conv1.0.weight
convs.5.conv1.1.bias
convs.5.conv2.1.weight
convs.5.conv2.2.bias
convs.5.skip.1.weight
convs.6.conv1.0.weight
convs.6.conv1.1.bias
convs.6.conv2.1.weight
convs.6.conv2.2.bias
convs.6.skip.1.weight
final_conv.0.weight
final_conv.1.bias
final_linear.0.weight
final_linear.0.bias
final_linear.1.weight
final_linear.1.bias

The two discriminators differ markedly in architecture.

StyleGAN1 downsamples with average pooling and uses LeakyReLU as its activation function.
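This structure can be sketched as a minimal PyTorch module (illustrative only; the original repo uses an equalized-learning-rate conv wrapper rather than plain nn.Conv2d):

```python
import torch
from torch import nn

class StyleGAN1Block(nn.Module):
    """Sketch of a StyleGAN1-style discriminator block:
    two 3x3 convs with LeakyReLU, then average pooling halves the resolution."""

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channel, in_channel, 3, padding=1)
        self.conv2 = nn.Conv2d(in_channel, out_channel, 3, padding=1)
        self.act = nn.LeakyReLU(0.2)
        self.pool = nn.AvgPool2d(2)  # downsampling via average pooling

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        return self.pool(x)

block = StyleGAN1Block(64, 128)
x = torch.randn(1, 64, 32, 32)
y = block(x)
print(y.shape)  # torch.Size([1, 128, 16, 16]) -- spatial size halved
```

Note that the downsampling is done by a separate pooling layer after the convolutions, which is exactly what changes in StyleGAN2.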

StyleGAN2, on the other hand, does not use average pooling: it downsamples inside the conv layer itself, applying a 3x3 filter with stride 2. It also uses FusedLeakyReLU as the activation and adds a skip connection via a residual block, summing the two paths.

More precisely:

import math
from torch import nn

# ConvLayer is the equalized-learning-rate conv wrapper defined elsewhere in
# the StyleGAN2 model code; with downsample=True it performs the blurred,
# strided 3x3 convolution described above.
class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

        self.skip = ConvLayer(
            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
        )

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        skip = self.skip(input)
        out = (out + skip) / math.sqrt(2)

        return out

That is, the skip path is added to the main path and the sum is divided by √2.
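The 1/√2 factor keeps the variance of the activations roughly constant after the sum: adding two independent signals of equal variance doubles the variance, and scaling by 1/√2 undoes that. A quick numerical check (illustrative, not from the original repo):

```python
import math
import torch

torch.manual_seed(0)
a = torch.randn(1_000_000)  # stand-in for the main path, unit variance
b = torch.randn(1_000_000)  # stand-in for the skip path, unit variance

summed = a + b                   # variance roughly doubles to ~2
scaled = summed / math.sqrt(2)   # variance restored to ~1

print(summed.var().item())  # ~2.0
print(scaled.var().item())  # ~1.0
```

This is why the residual blocks can be stacked many times without the signal magnitude growing at every merge.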