본문 바로가기

Data-science/deep learning

[논문 읽기] ConvLSTM 이해하기

728x90

시계열에 좋은 LSTM! 이미지에 좋은 Conv!

이 둘을 짬뽕시킨 Convolutional LSTM을 알아보자.

 

LSTM에 대한 사전지식이 필요하다.

 

LSTM을 잘 설명해놓은 그림.

https://limitsinx.tistory.com/62

 

[코드로 이해하는 딥러닝 2-11] - RNN(Recurrent Neural Network)/LSTM(Long-Short-Term-Memory)

[코드로 이해하는 딥러닝 0] - 글연재에 앞서 https://limitsinx.tistory.com/27 [코드로 이해하는 딥러닝 1] - Tensorflow 시작 https://limitsinx.tistory.com/28 [코드로 이해하는 딥러닝 2] - Tensorflow 변..

limitsinx.tistory.com

$i_t$는 입력 게이트, $f_t$는 망각 게이트(forget gate인데 이렇게 번역하는게 맞나...?)

$o_t$는 출력 게이트이다. $h_t$은 t시점에서의 hidden state, $c_t$는 cell state를 나타낸다.

W는 convolution 연산에서 나오는 필터의 weight를 의미하고 *는 convolution 연산을 의미한다.
위에서 동그라미 연산은 Hadamard product로 elementwise multiplication을 의미한다. 출처 : https://medium.com/linear-algebra/part-14-dot-and-hadamard-product-b7e0723b9133

$h_(t-1)$, $X_(t-1)$ 두 개가 concat하게 conv 연산에 들어가서 같은 크기로 4개가 나눠져 생기고 이게 각각 연산을 통해 input_gate, forget_gate, output_gate 다음 cell state, hidden state가 만들어지는 것이다.

 

1. input tensor와 hidden state를 체널별로 이어준다.


2. 채널별로 이어준 녀석을 컨볼루션 레이어를 통과시켜준다.


3. combined_conv를 채널 엑시스 기준으로 split해주면 4개로 나눠진다.


위 그림과 매칭 시켜보면 여기서 conv filter 종류를 4가지 사용했다고 보면된다. 


4. cc_i에 sigmoid를 취하면 input_gate값이,


5. cc_f에 sigmoid를 취하면 forget_gate값이,


6. cc_o에 sigmoid를 취하면 output_gate값이,


7. cc_g에 tanh를 취하면 input gate와 element-wise하게 곱해질 값이 나타난다.


8. 다음 c state는 forget에 이전 c state를 곱한 것과 input gate의 출력에 g를 곱한 것을 더한 것이다.


9. 다음 hidden state는 output gate의 출력에 다음 c state에 tanh를 취한 값을 곱한 것이다.

 

코드로 살펴보자

class ConvLSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias
		
        # 이 컨볼루션 레이어는 입력값으로 $X_(t-1)$, $H_(t-1)$를 받을 것이다.
        # 위 그림과 같이 출력이 총4개 이고, 각각 hidden state와 shape이 같다.
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        h_cur, c_cur = cur_state
		# input tensor와 hidden state를 체널별로 이어준다.
        combined = torch.cat([input_tensor, h_cur], dim=1)  
		
        # 채널별로 이어준 녀석을 컨볼루션 레이어를 통과시켜준다.
        combined_conv = self.conv(combined)
        
        # combined_conv를 채널 엑시스 기준으로 split해주면 4개로 나눠진다.
        # 위 그림과 매칭 시켜보면 여기서 conv filter를 4가지 사용했다고 보면된다. 
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        
        # cc_i에 sigmoid를 취하면 input_gate값이,
        i = torch.sigmoid(cc_i)
        # cc_f에 sigmoid를 취하면 forget_gate값이,
        f = torch.sigmoid(cc_f)
        # cc_o에 sigmoid를 취하면 output_gate값이,
        o = torch.sigmoid(cc_o)
        # cc_g에 tanh를 취하면 input gate와 element-wise하게 곱해질 값이 나타난다.
        g = torch.tanh(cc_g)

		# 다음 c state는 forget에 이전 c state를 곱한 것과 input gate의 출력에 g를 곱한 것을 더한 것이다.
        c_next = f * c_cur + i * g
        # 다음 hidden state는 output gate의 출력에 다음 c state에 tanh를 취한 값을 곱한 것이다.
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        height, width = image_size
        return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
                torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))

 

 

 

 

https://arxiv.org/pdf/1506.04214.pdf

http://homepage.divms.uiowa.edu/~zhuoning/papers/p984-yuan.pdf