728x90
시계열에 좋은 LSTM! 이미지에 좋은 Conv!
이 둘을 짬뽕시킨 Convolutional LSTM을 알아보자.
LSTM에 대한 사전지식이 필요하다.
https://limitsinx.tistory.com/62
$i_t$는 입력 게이트, $f_t$는 망각 게이트(forget gate인데 이렇게 번역하는게 맞나...?)
$o_t$는 출력 게이트이다. $h_t$은 t시점에서의 hidden state, $c_t$는 cell state를 나타낸다.
$h_(t-1)$, $X_(t-1)$ 두 개가 concat하게 conv 연산에 들어가서 같은 크기로 4개가 나눠져 생기고 이게 각각 연산을 통해 input_gate, forget_gate, output_gate 다음 cell state, hidden state가 만들어지는 것이다.
1. input tensor와 hidden state를 체널별로 이어준다.
2. 채널별로 이어준 녀석을 컨볼루션 레이어를 통과시켜준다.
3. combined_conv를 채널 엑시스 기준으로 split해주면 4개로 나눠진다.
위 그림과 매칭 시켜보면 여기서 conv filter 종류를 4가지 사용했다고 보면된다.
4. cc_i에 sigmoid를 취하면 input_gate값이,
5. cc_f에 sigmoid를 취하면 forget_gate값이,
6. cc_o에 sigmoid를 취하면 output_gate값이,
7. cc_g에 tanh를 취하면 input gate와 element-wise하게 곱해질 값이 나타난다.
8. 다음 c state는 forget에 이전 c state를 곱한 것과 input gate의 출력에 g를 곱한 것을 더한 것이다.
9. 다음 hidden state는 output gate의 출력에 다음 c state에 tanh를 취한 값을 곱한 것이다.
코드로 살펴보자
class ConvLSTMCell(nn.Module):
def __init__(self, input_dim, hidden_dim, kernel_size, bias):
super(ConvLSTMCell, self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.kernel_size = kernel_size
self.padding = kernel_size[0] // 2, kernel_size[1] // 2
self.bias = bias
# 이 컨볼루션 레이어는 입력값으로 $X_(t-1)$, $H_(t-1)$를 받을 것이다.
# 위 그림과 같이 출력이 총4개 이고, 각각 hidden state와 shape이 같다.
self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
out_channels=4 * self.hidden_dim,
kernel_size=self.kernel_size,
padding=self.padding,
bias=self.bias)
def forward(self, input_tensor, cur_state):
h_cur, c_cur = cur_state
# input tensor와 hidden state를 체널별로 이어준다.
combined = torch.cat([input_tensor, h_cur], dim=1)
# 채널별로 이어준 녀석을 컨볼루션 레이어를 통과시켜준다.
combined_conv = self.conv(combined)
# combined_conv를 채널 엑시스 기준으로 split해주면 4개로 나눠진다.
# 위 그림과 매칭 시켜보면 여기서 conv filter를 4가지 사용했다고 보면된다.
cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
# cc_i에 sigmoid를 취하면 input_gate값이,
i = torch.sigmoid(cc_i)
# cc_f에 sigmoid를 취하면 forget_gate값이,
f = torch.sigmoid(cc_f)
# cc_o에 sigmoid를 취하면 output_gate값이,
o = torch.sigmoid(cc_o)
# cc_g에 tanh를 취하면 input gate와 element-wise하게 곱해질 값이 나타난다.
g = torch.tanh(cc_g)
# 다음 c state는 forget에 이전 c state를 곱한 것과 input gate의 출력에 g를 곱한 것을 더한 것이다.
c_next = f * c_cur + i * g
# 다음 hidden state는 output gate의 출력에 다음 c state에 tanh를 취한 값을 곱한 것이다.
h_next = o * torch.tanh(c_next)
return h_next, c_next
def init_hidden(self, batch_size, image_size):
height, width = image_size
return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))
https://arxiv.org/pdf/1506.04214.pdf
http://homepage.divms.uiowa.edu/~zhuoning/papers/p984-yuan.pdf