TrajGRU code
Original paper: Deep Learning for Precipitation Nowcasting: A Benchmark and A New Model
The convolutional recurrence structure in ConvLSTM-based models is location-invariant,
while natural motion and transformation (e.g., rotation) are location-variant in general.
Specifically, we go beyond ConvLSTM and propose the Trajectory GRU (TrajGRU) model that can actively learn the location-variant structure for recurrent connections.
https://arxiv.org/pdf/1706.03458.pdf
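Roughly, the TrajGRU state transition works as follows (paraphrasing the paper's formulation; see the PDF above for the exact equations): a small subnetwork γ produces L flow fields from the current input and the previous hidden state, and every recurrent connection aggregates warped copies of the previous state along those learned trajectories instead of using a fixed convolution:

\begin{aligned}
\mathcal{U}_t,\ \mathcal{V}_t &= \gamma(\mathcal{X}_t, \mathcal{H}_{t-1}) \\
\mathcal{Z}_t &= \sigma\Big(\mathcal{W}_{xz} * \mathcal{X}_t + \textstyle\sum_{l=1}^{L} \mathcal{W}^{l}_{hz} * \mathrm{warp}(\mathcal{H}_{t-1}, \mathcal{U}_{t,l}, \mathcal{V}_{t,l})\Big) \\
\mathcal{R}_t &= \sigma\Big(\mathcal{W}_{xr} * \mathcal{X}_t + \textstyle\sum_{l=1}^{L} \mathcal{W}^{l}_{hr} * \mathrm{warp}(\mathcal{H}_{t-1}, \mathcal{U}_{t,l}, \mathcal{V}_{t,l})\Big) \\
\mathcal{H}'_t &= f\Big(\mathcal{W}_{xh} * \mathcal{X}_t + \mathcal{R}_t \circ \textstyle\sum_{l=1}^{L} \mathcal{W}^{l}_{hh} * \mathrm{warp}(\mathcal{H}_{t-1}, \mathcal{U}_{t,l}, \mathcal{V}_{t,l})\Big) \\
\mathcal{H}_t &= (1 - \mathcal{Z}_t) \circ \mathcal{H}'_t + \mathcal{Z}_t \circ \mathcal{H}_{t-1}
\end{aligned}

Here * is convolution, ∘ is the element-wise product, and warp() bilinearly samples H_{t-1} along the l-th flow field. With fixed trajectories pointing to a local neighborhood this reduces to an ordinary ConvGRU.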
Code library: MXNet (an AI library developed by Amazon)
https://aws.amazon.com/ko/mxnet/
def train(args):
assert cfg.MODEL.FRAME_STACK == 1 and cfg.MODEL.FRAME_SKIP == 1
base_dir = args.save_dir
logging_config(folder=base_dir, name="training")
save_movingmnist_cfg(base_dir)
# For a custom dataset, you need to implement an iterator like MovingMNISTAdvancedIterator (see the sketch after this function)
mnist_iter = MovingMNISTAdvancedIterator(
distractor_num=cfg.MOVINGMNIST.DISTRACTOR_NUM,
initial_velocity_range=(cfg.MOVINGMNIST.VELOCITY_LOWER,
cfg.MOVINGMNIST.VELOCITY_UPPER),
rotation_angle_range=(cfg.MOVINGMNIST.ROTATION_LOWER,
cfg.MOVINGMNIST.ROTATION_UPPER),
scale_variation_range=(cfg.MOVINGMNIST.SCALE_VARIATION_LOWER,
cfg.MOVINGMNIST.SCALE_VARIATION_UPPER),
illumination_factor_range=(cfg.MOVINGMNIST.ILLUMINATION_LOWER,
cfg.MOVINGMNIST.ILLUMINATION_UPPER))
# Define the MovingMNISTFactory
mnist_rnn = MovingMNISTFactory(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE // len(args.ctx),
in_seq_len=cfg.MODEL.IN_LEN,
out_seq_len=cfg.MODEL.OUT_LEN)
encoder_net, forecaster_net, loss_net = \
encoder_forecaster_build_networks(
factory=mnist_rnn,
context=args.ctx)
t_encoder_net, t_forecaster_net, t_loss_net = \
encoder_forecaster_build_networks(
factory=mnist_rnn,
context=args.ctx[0],
shared_encoder_net=encoder_net,
shared_forecaster_net=forecaster_net,
shared_loss_net=loss_net,
for_finetune=True)
encoder_net.summary()
forecaster_net.summary()
loss_net.summary()
# Begin to load the model if load_dir is not empty
if len(cfg.MODEL.LOAD_DIR) > 0:
load_mnist_params(load_dir=cfg.MODEL.LOAD_DIR, load_iter=cfg.MODEL.LOAD_ITER,
encoder_net=encoder_net, forecaster_net=forecaster_net)
states = EncoderForecasterStates(factory=mnist_rnn, ctx=args.ctx[0])
states.reset_all()
for info in mnist_rnn.init_encoder_state_info:
assert info["__layout__"].find('N') == 0, "Layout=%s is not supported!" %info["__layout__"]
iter_id = 0
while iter_id < cfg.MODEL.TRAIN.MAX_ITER:
frame_dat, _ = mnist_iter.sample(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
seqlen=cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN)
data_nd = mx.nd.array(frame_dat[0:cfg.MOVINGMNIST.IN_LEN, ...], ctx=args.ctx[0]) / 255.0
target_nd = mx.nd.array(
frame_dat[cfg.MOVINGMNIST.IN_LEN:(cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN), ...],
ctx=args.ctx[0]) / 255.0
train_step(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
encoder_net=encoder_net, forecaster_net=forecaster_net,
loss_net=loss_net, init_states=states,
data_nd=data_nd, gt_nd=target_nd, mask_nd=None,
iter_id=iter_id)
if (iter_id + 1) % 100 == 0:
new_frame_dat, _ = mnist_iter.sample(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
seqlen=cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN)
data_nd = mx.nd.array(new_frame_dat[0:cfg.MOVINGMNIST.IN_LEN, ...], ctx=args.ctx[0]) / 255.0
target_nd = mx.nd.array(
new_frame_dat[cfg.MOVINGMNIST.IN_LEN:(cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN), ...],
ctx=args.ctx[0]) / 255.0
pred_nd = mnist_get_prediction(data_nd=data_nd, states=states,
encoder_net=encoder_net, forecaster_net=forecaster_net)
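# The gifs below slice [:, 0, 0, :, :]: every timestep of sample 0, channel 0, i.e. one (T, H, W) clip per gif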
save_gif(pred_nd.asnumpy()[:, 0, 0, :, :], os.path.join(base_dir, "pred.gif"))
save_gif(data_nd.asnumpy()[:, 0, 0, :, :], os.path.join(base_dir, "in.gif"))
save_gif(target_nd.asnumpy()[:, 0, 0, :, :], os.path.join(base_dir, "gt.gif"))
# Save the model (two checkpoints are written: encoder_net and forecaster_net)
if (iter_id + 1) % cfg.MODEL.SAVE_ITER == 0:
encoder_net.save_checkpoint(
prefix=os.path.join(base_dir, "encoder_net"),
epoch=iter_id)
forecaster_net.save_checkpoint(
prefix=os.path.join(base_dir, "forecaster_net"),
epoch=iter_id)
iter_id += 1
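As the comment at the top of train() notes, running this loop on a custom dataset only requires an iterator exposing the same sample() interface as MovingMNISTAdvancedIterator. A minimal hypothetical sketch (the class name and internal storage layout are assumptions; the only hard requirement inferred from the loop above is that sample() returns a time-major (seqlen, batch_size, 1, H, W) array of raw pixel values plus a second, unused value):

import numpy as np

class CustomFrameIterator(object):
    """Hypothetical drop-in for MovingMNISTAdvancedIterator on a custom dataset."""
    def __init__(self, clips):
        # clips: np.ndarray of shape (num_clips, clip_len, 1, H, W), values in [0, 255]
        self._clips = clips

    def sample(self, batch_size, seqlen):
        clip_ids = np.random.randint(0, self._clips.shape[0], size=batch_size)
        starts = np.random.randint(0, self._clips.shape[1] - seqlen + 1, size=batch_size)
        # Stack along axis=1 so the result is time-major like frame_dat: (seqlen, batch_size, 1, H, W)
        frame_dat = np.stack([self._clips[c, s:s + seqlen] for c, s in zip(clip_ids, starts)], axis=1)
        return frame_dat, None  # the second return value is ignored by train()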
encoder_forecaster_build_networks
def encoder_forecaster_build_networks(factory, context,
shared_encoder_net=None,
shared_forecaster_net=None,
shared_loss_net=None,
for_finetune=False):
"""
Parameters
----------
factory : EncoderForecasterBaseFactory
context : list
shared_encoder_net : MyModule or None
shared_forecaster_net : MyModule or None
shared_loss_net : MyModule or None
for_finetune : bool
Returns
-------
"""
encoder_net = MyModule(factory.encoder_sym(),
data_names=[ele.name for ele in factory.encoder_data_desc()],
label_names=[],
context=context,
name="encoder_net")
encoder_net.bind(data_shapes=factory.encoder_data_desc(),
label_shapes=None,
inputs_need_grad=True,
shared_module=shared_encoder_net)
if shared_encoder_net is None:
encoder_net.init_params(mx.init.MSRAPrelu(slope=0.2))
init_optimizer_using_cfg(encoder_net, for_finetune=for_finetune)
forecaster_net = MyModule(factory.forecaster_sym(),
data_names=[ele.name for ele in
factory.forecaster_data_desc()],
label_names=[],
context=context,
name="forecaster_net")
forecaster_net.bind(data_shapes=factory.forecaster_data_desc(),
label_shapes=None,
inputs_need_grad=True,
shared_module=shared_forecaster_net)
if shared_forecaster_net is None:
forecaster_net.init_params(mx.init.MSRAPrelu(slope=0.2))
init_optimizer_using_cfg(forecaster_net, for_finetune=for_finetune)
loss_net = MyModule(factory.loss_sym(),
data_names=[ele.name for ele in
factory.loss_data_desc()],
label_names=[ele.name for ele in
factory.loss_label_desc()],
context=context,
name="loss_net")
loss_net.bind(data_shapes=factory.loss_data_desc(),
label_shapes=factory.loss_label_desc(),
inputs_need_grad=True,
shared_module=shared_loss_net)
if shared_loss_net is None:
loss_net.init_params()
return encoder_net, forecaster_net, loss_net
MovingMNISTFactory class
class MovingMNISTFactory(EncoderForecasterBaseFactory):
def __init__(self,
batch_size,
in_seq_len,
out_seq_len):
super(MovingMNISTFactory, self).__init__(batch_size=batch_size,
in_seq_len=in_seq_len,
out_seq_len=out_seq_len,
height=cfg.MOVINGMNIST.IMG_SIZE,
width=cfg.MOVINGMNIST.IMG_SIZE)
def loss_sym(self):
self.reset_all()
pred = mx.sym.Variable('pred') # Shape: (out_seq_len, batch_size, 1, H, W)
target = mx.sym.Variable('target') # Shape: (out_seq_len, batch_size, 1, H, W)
avg_mse = mx.sym.mean(mx.sym.square(target - pred))
avg_mse = mx.sym.MakeLoss(avg_mse,
name="mse")
loss = mx.sym.Group([avg_mse])
return loss
→ Only the loss needs to be defined separately (see the sketch below).
→ Everything else is already defined in the abstract base class EncoderForecasterBaseFactory!
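For instance, a hypothetical factory that trains with an L1 loss instead of MSE would only have to override loss_sym (a sketch, not code from the repository):

import mxnet as mx

class MovingMNISTL1Factory(MovingMNISTFactory):  # reuses everything else from the factory above
    def loss_sym(self):
        self.reset_all()
        pred = mx.sym.Variable('pred')      # Shape: (out_seq_len, batch_size, 1, H, W)
        target = mx.sym.Variable('target')  # Shape: (out_seq_len, batch_size, 1, H, W)
        avg_mae = mx.sym.MakeLoss(mx.sym.mean(mx.sym.abs(target - pred)), name="mae")
        return mx.sym.Group([avg_mae])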
EncoderForecasterBaseFactory abstract class
class EncoderForecasterBaseFactory(PredictionBaseFactory):
def __init__(self,
batch_size,
in_seq_len,
out_seq_len,
height,
width,
ctx_num=1,
name="encoder_forecaster"):
super(EncoderForecasterBaseFactory, self).__init__(batch_size=batch_size,
in_seq_len=in_seq_len,
out_seq_len=out_seq_len,
height=height,
width=width,
name=name)
self._ctx_num = ctx_num
def _init_rnn(self):
self._encoder_rnn_blocks, self._forecaster_rnn_blocks, self._gan_rnn_blocks =\
get_encoder_forecaster_rnn_blocks(batch_size=self._batch_size)
return self._encoder_rnn_blocks + self._forecaster_rnn_blocks + self._gan_rnn_blocks
@property
def init_encoder_state_info(self):
init_state_info = []
for block in self._encoder_rnn_blocks:
for state in block.init_state_vars():
init_state_info.append({'name': state.name,
'shape': state.attr('__shape__'),
'__layout__': state.list_attr()['__layout__']})
return init_state_info
@property
def init_forecaster_state_info(self):
init_state_info = []
for block in self._forecaster_rnn_blocks:
for state in block.init_state_vars():
init_state_info.append({'name': state.name,
'shape': state.attr('__shape__'),
'__layout__': state.list_attr()['__layout__']})
return init_state_info
@property
def init_gan_state_info(self):
init_gan_state_info = []
for block in self._gan_rnn_blocks:
for state in block.init_state_vars():
init_gan_state_info.append({'name': state.name,
'shape': state.attr('__shape__'),
'__layout__': state.list_attr()['__layout__']})
return init_gan_state_info
def stack_rnn_encode(self, data):
CONFIG = cfg.MODEL.ENCODER_FORECASTER
pre_encoded_data = self._pre_encode_frame(frame_data=data, seqlen=self._in_seq_len)
reshape_data = mx.sym.Reshape(pre_encoded_data, shape=(-1, 0, 0, 0), reverse=True)
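# Reshape with reverse=True keeps the trailing (C, H, W) dims and folds time into the batch axis:
# (in_seq_len, batch, C, H, W) -> (in_seq_len * batch, C, H, W), so econv1 below processes every frame in one pass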
# Encoder Part
conv1 = conv2d_act(data=reshape_data,
num_filter=CONFIG.FIRST_CONV[0],
kernel=(CONFIG.FIRST_CONV[1], CONFIG.FIRST_CONV[1]),
stride=(CONFIG.FIRST_CONV[2], CONFIG.FIRST_CONV[2]),
pad=(CONFIG.FIRST_CONV[3], CONFIG.FIRST_CONV[3]),
act_type=cfg.MODEL.CNN_ACT_TYPE,
name="econv1")
rnn_block_num = len(CONFIG.RNN_BLOCKS.NUM_FILTER)
encoder_rnn_block_states = []
for i in range(rnn_block_num):
if i == 0:
inputs = conv1
else:
inputs = downsample
rnn_out, states = self._encoder_rnn_blocks[i].unroll(
length=self._in_seq_len,
inputs=inputs,
begin_states=None,
ret_mid=False)
encoder_rnn_block_states.append(states)
if i < rnn_block_num - 1:
downsample = downsample_module(data=rnn_out[-1],
num_filter=CONFIG.RNN_BLOCKS.NUM_FILTER[i + 1],
kernel=(CONFIG.DOWNSAMPLE[i][0],
CONFIG.DOWNSAMPLE[i][0]),
stride=(CONFIG.DOWNSAMPLE[i][1],
CONFIG.DOWNSAMPLE[i][1]),
pad=(CONFIG.DOWNSAMPLE[i][2],
CONFIG.DOWNSAMPLE[i][2]),
b_h_w=(self._batch_size,
CONFIG.FEATMAP_SIZE[i + 1],
CONFIG.FEATMAP_SIZE[i + 1]),
name="edown%d" %(i + 1))
return encoder_rnn_block_states
def stack_rnn_forecast(self, block_state_list, last_frame):
CONFIG = cfg.MODEL.ENCODER_FORECASTER
block_state_list = [self._forecaster_rnn_blocks[i].to_split(block_state_list[i])
for i in range(len(self._forecaster_rnn_blocks))]
rnn_block_num = len(CONFIG.RNN_BLOCKS.NUM_FILTER)
rnn_block_outputs = []
# RNN Forecaster Part
curr_inputs = None
for i in range(rnn_block_num - 1, -1, -1):
rnn_out, rnn_state = self._forecaster_rnn_blocks[i].unroll(
length=self._out_seq_len, inputs=curr_inputs,
begin_states=block_state_list[i][::-1], # Reverse the order of states for the forecaster
ret_mid=False)
rnn_block_outputs.append(rnn_out)
if i > 0:
upsample = upsample_module(data=rnn_out[-1],
num_filter=CONFIG.RNN_BLOCKS.NUM_FILTER[i],
kernel=(CONFIG.UPSAMPLE[i - 1][0],
CONFIG.UPSAMPLE[i - 1][0]),
stride=(CONFIG.UPSAMPLE[i - 1][1],
CONFIG.UPSAMPLE[i - 1][1]),
pad=(CONFIG.UPSAMPLE[i - 1][2],
CONFIG.UPSAMPLE[i - 1][2]),
b_h_w=(self._batch_size, CONFIG.FEATMAP_SIZE[i - 1], CONFIG.FEATMAP_SIZE[i - 1]),
name="fup%d" %i)
curr_inputs = upsample
# Output
if cfg.MODEL.OUT_TYPE == "DFN":
concat_fbrnn1_out = mx.sym.concat(*rnn_out[-1], dim=0)
dynamic_filter = deconv2d(data=concat_fbrnn1_out,
num_filter=121,
kernel=(CONFIG.LAST_DECONV[1], CONFIG.LAST_DECONV[1]),
stride=(CONFIG.LAST_DECONV[2], CONFIG.LAST_DECONV[2]),
pad=(CONFIG.LAST_DECONV[3], CONFIG.LAST_DECONV[3]))
flow = dynamic_filter
dynamic_filter = mx.sym.SliceChannel(dynamic_filter, axis=0, num_outputs=self._out_seq_len)
prev_frame = last_frame
preds = []
for i in range(self._out_seq_len):
pred_ele = DFN(data=prev_frame, local_kernels=dynamic_filter[i], K=11, batch_size=self._batch_size)
preds.append(pred_ele)
prev_frame = pred_ele
pred = mx.sym.concat(*preds, dim=0)
elif cfg.MODEL.OUT_TYPE == "direct":
flow = None
deconv1 = deconv2d_act(data=mx.sym.concat(*rnn_out[-1], dim=0),
num_filter=CONFIG.LAST_DECONV[0],
kernel=(CONFIG.LAST_DECONV[1], CONFIG.LAST_DECONV[1]),
stride=(CONFIG.LAST_DECONV[2], CONFIG.LAST_DECONV[2]),
pad=(CONFIG.LAST_DECONV[3], CONFIG.LAST_DECONV[3]),
act_type=cfg.MODEL.CNN_ACT_TYPE,
name="fdeconv1")
conv_final = conv2d_act(data=deconv1,
num_filter=CONFIG.LAST_DECONV[0],
kernel=(3, 3), stride=(1, 1), pad=(1, 1),
act_type=cfg.MODEL.CNN_ACT_TYPE, name="conv_final")
pred = conv2d(data=conv_final,
num_filter=1, kernel=(1, 1), name="out")
else:
raise NotImplementedError
pred = mx.sym.Reshape(pred,
shape=(self._out_seq_len, self._batch_size,
1, self._height, self._width),
__layout__="TNCHW")
return pred, flow
def encoder_sym(self):
self.reset_all()
data = mx.sym.Variable('data') # Shape: (in_seq_len, batch_size, C, H, W)
block_state_list = self.stack_rnn_encode(data=data)
states = []
for i, rnn_block in enumerate(self._encoder_rnn_blocks):
states.extend(rnn_block.flatten_add_layout(block_state_list[i]))
return mx.sym.Group(states)
def encoder_data_desc(self):
ret = list()
ret.append(mx.io.DataDesc(name='data',
shape=(self._in_seq_len,
self._batch_size * self._ctx_num,
1,
self._height,
self._width),
layout="TNCHW"))
for info in self.init_encoder_state_info:
state_shape = safe_eval(info['shape'])
assert info['__layout__'].find('N') == 0,\
"Layout=%s is not supported!" %info["__layout__"]
state_shape = (state_shape[0] * self._ctx_num, ) + state_shape[1:]
ret.append(mx.io.DataDesc(name=info['name'],
shape=state_shape,
layout=info['__layout__']))
return ret
def forecaster_sym(self):
self.reset_all()
block_state_list = []
for block in self._forecaster_rnn_blocks:
block_state_list.append(block.init_state_vars())
if cfg.MODEL.OUT_TYPE == "direct":
pred, _ = self.stack_rnn_forecast(block_state_list=block_state_list,
last_frame=None)
return mx.sym.Group([pred])
else:
last_frame = mx.sym.Variable('last_frame') # Shape: (batch_size, C, H, W)
pred, flow = self.stack_rnn_forecast(block_state_list=block_state_list,
last_frame=last_frame)
return mx.sym.Group([pred, mx.sym.BlockGrad(flow)])
def forecaster_data_desc(self):
ret = list()
for info in self.init_forecaster_state_info:
state_shape = safe_eval(info['shape'])
assert info['__layout__'].find('N') == 0, \
"Layout=%s is not supported!" % info["__layout__"]
state_shape = (state_shape[0] * self._ctx_num,) + state_shape[1:]
ret.append(mx.io.DataDesc(name=info['name'],
shape=state_shape,
layout=info['__layout__']))
if cfg.MODEL.OUT_TYPE != "direct":
ret.append(mx.io.DataDesc(name="last_frame",
shape=(self._ctx_num * self._batch_size,
1, self._height, self._width),
layout="NCHW"))
return ret
def loss_sym(self):
raise NotImplementedError
def loss_data_desc(self):
ret = list()
ret.append(mx.io.DataDesc(name='pred',
shape=(self._out_seq_len,
self._ctx_num * self._batch_size,
1,
self._height,
self._width),
layout="TNCHW"))
return ret
def loss_label_desc(self):
ret = list()
ret.append(mx.io.DataDesc(name='target',
shape=(self._out_seq_len,
self._ctx_num * self._batch_size,
1,
self._height,
self._width),
layout="TNCHW"))
if cfg.MODEL.ENCODER_FORECASTER.HAS_MASK:
ret.append(mx.io.DataDesc(name='mask',
shape=(self._out_seq_len,
self._ctx_num * self._batch_size,
1,
self._height,
self._width),
layout="TNCHW"))
return ret
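encoder_data_desc() / forecaster_data_desc() call safe_eval on the stringified __shape__ attribute of each initial-state variable. Presumably it is a restricted literal evaluator; a sketch under that assumption (the real helper lives elsewhere in the repository):

import ast

def safe_eval(expr_str):
    # Hypothetical equivalent: parse Python literals only, e.g. "(4, 64, 32, 32)" -> (4, 64, 32, 32)
    return ast.literal_eval(expr_str)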
RNN block construction depending on the encoder type (ConvGRU, TrajGRU)
def get_encoder_forecaster_rnn_blocks(batch_size):
encoder_rnn_blocks = []
forecaster_rnn_blocks = []
gan_rnn_blocks = []
CONFIG = cfg.MODEL.ENCODER_FORECASTER.RNN_BLOCKS
for vec, block_prefix in [(encoder_rnn_blocks, "ebrnn"),
(forecaster_rnn_blocks, "fbrnn"),
(gan_rnn_blocks, "dbrnn")]:
for i in range(len(CONFIG.NUM_FILTER)):
name = "%s%d" % (block_prefix, i + 1)
if CONFIG.LAYER_TYPE[i] == "ConvGRU":
rnn_block = BaseStackRNN(base_rnn_class=ConvGRU,
stack_num=CONFIG.STACK_NUM[i],
name=name,
residual_connection=CONFIG.RES_CONNECTION,
num_filter=CONFIG.NUM_FILTER[i],
b_h_w=(batch_size,
cfg.MODEL.ENCODER_FORECASTER.FEATMAP_SIZE[i],
cfg.MODEL.ENCODER_FORECASTER.FEATMAP_SIZE[i]),
h2h_kernel=CONFIG.H2H_KERNEL[i],
h2h_dilate=CONFIG.H2H_DILATE[i],
i2h_kernel=CONFIG.I2H_KERNEL[i],
i2h_pad=CONFIG.I2H_PAD[i],
act_type=cfg.MODEL.RNN_ACT_TYPE)
elif CONFIG.LAYER_TYPE[i] == "TrajGRU":
rnn_block = BaseStackRNN(base_rnn_class=TrajGRU,
stack_num=CONFIG.STACK_NUM[i],
name=name,
L=CONFIG.L[i],
residual_connection=CONFIG.RES_CONNECTION,
num_filter=CONFIG.NUM_FILTER[i],
b_h_w=(batch_size,
cfg.MODEL.ENCODER_FORECASTER.FEATMAP_SIZE[i],
cfg.MODEL.ENCODER_FORECASTER.FEATMAP_SIZE[i]),
h2h_kernel=CONFIG.H2H_KERNEL[i],
h2h_dilate=CONFIG.H2H_DILATE[i],
i2h_kernel=CONFIG.I2H_KERNEL[i],
i2h_pad=CONFIG.I2H_PAD[i],
act_type=cfg.MODEL.RNN_ACT_TYPE)
else:
raise NotImplementedError
vec.append(rnn_block)
return encoder_rnn_blocks, forecaster_rnn_blocks, gan_rnn_blocks
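Every per-block hyperparameter above comes from cfg.MODEL.ENCODER_FORECASTER.RNN_BLOCKS. Purely for illustration (these numbers are made up, not the repository defaults), a three-block, all-TrajGRU stack would be described by parallel lists such as:

# Illustrative values only -- each list is indexed by RNN block
rnn_blocks_cfg = {
    "LAYER_TYPE": ["TrajGRU", "TrajGRU", "TrajGRU"],  # cell type per block
    "NUM_FILTER": [64, 96, 96],                       # hidden-state channels per block
    "STACK_NUM":  [1, 1, 1],                          # cells stacked inside each block
    "L":          [13, 13, 9],                        # trajectory links per TrajGRU connection
}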
Train step and other helpers
Model(Encoder & Forecaster) States
class EncoderForecasterStates(object):
def __init__(self, factory, ctx):
self._factory = factory
self._ctx = ctx
self._encoder_state_info = factory.init_encoder_state_info
self._forecaster_state_info = factory.init_forecaster_state_info
self._states_nd = []
for info in self._encoder_state_info:
state_shape = safe_eval(info['shape'])
state_shape = (state_shape[0] * factory._ctx_num, ) + state_shape[1:]
self._states_nd.append(mx.nd.zeros(shape=state_shape, ctx=ctx))
def reset_all(self):
for ele, info in zip(self._states_nd, self._encoder_state_info):
ele[:] = 0
def reset_batch(self, batch_id):
for ele, info in zip(self._states_nd, self._encoder_state_info):
ele[batch_id][:] = 0
def update(self, states_nd):
for target, src in zip(self._states_nd, states_nd):
target[:] = src
def get_encoder_states(self):
return self._states_nd
def get_forecaster_state(self):
return self._states_nd
train step
def train_step(batch_size, encoder_net, forecaster_net,
loss_net, init_states,
data_nd, gt_nd, mask_nd, iter_id=None):
"""Finetune the encoder, forecaster and GAN for one step
Parameters
----------
batch_size : int
encoder_net : MyModule
forecaster_net : MyModule
loss_net : MyModule
init_states : EncoderForecasterStates
data_nd : mx.nd.ndarray
gt_nd : mx.nd.ndarray
mask_nd : mx.nd.ndarray
iter_id : int
Returns
-------
init_states: EncoderForecasterStates
loss_dict: dict
"""
# Forward Encoder
encoder_net.forward(is_train=True,
data_batch=mx.io.DataBatch(data=[data_nd] + init_states.get_encoder_states()))
encoder_states_nd = encoder_net.get_outputs()
init_states.update(encoder_states_nd)
# Forward Forecaster
if cfg.MODEL.OUT_TYPE == "direct":
forecaster_net.forward(is_train=True,
data_batch=mx.io.DataBatch(data=init_states.get_forecaster_state()))
else:
last_frame_nd = data_nd[data_nd.shape[0] - 1]
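# Last input frame; for the non-direct (e.g. DFN) output head the forecaster applies its predicted dynamic filters to this frame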
forecaster_net.forward(is_train=True,
data_batch=mx.io.DataBatch(data=init_states.get_forecaster_state() +
[last_frame_nd]))
forecaster_outputs = forecaster_net.get_outputs()
pred_nd = forecaster_outputs[0]
# Calculate the gradient of the loss functions
if cfg.MODEL.ENCODER_FORECASTER.HAS_MASK:
loss_net.forward_backward(data_batch=mx.io.DataBatch(data=[pred_nd],
label=[gt_nd, mask_nd]))
else:
loss_net.forward_backward(data_batch=mx.io.DataBatch(data=[pred_nd],
label=[gt_nd]))
pred_grad = loss_net.get_input_grads()[0]
loss_dict = loss_net.get_output_dict()
for k in loss_dict:
loss_dict[k] = nd.mean(loss_dict[k]).asscalar()
# Backward Forecaster
forecaster_net.backward(out_grads=[pred_grad])
if cfg.MODEL.OUT_TYPE == "direct":
encoder_states_grad_nd = forecaster_net.get_input_grads()
else:
encoder_states_grad_nd = forecaster_net.get_input_grads()[:-1]
# Backward Encoder
encoder_net.backward(encoder_states_grad_nd)
# Update forecaster and encoder
forecaster_grad_norm = forecaster_net.clip_by_global_norm(max_norm=cfg.MODEL.TRAIN.GRAD_CLIP)
encoder_grad_norm = encoder_net.clip_by_global_norm(max_norm=cfg.MODEL.TRAIN.GRAD_CLIP)
forecaster_net.update()
encoder_net.update()
loss_str = ", ".join(["%s=%g" %(k, v) for k, v in loss_dict.items()])
if iter_id is not None:
logging.info("Iter:%d, %s, e_gnorm=%g, f_gnorm=%g"
% (iter_id, loss_str, encoder_grad_norm, forecaster_grad_norm))
return init_states, loss_dict
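The mnist_get_prediction call used in the training loop is not shown here; presumably it mirrors the forward half of train_step with is_train=False. A rough sketch under that assumption (the function name below is mine, not the repository's):

def get_prediction_sketch(data_nd, states, encoder_net, forecaster_net):
    # Forward-only encoder: feed input frames plus the current initial states
    encoder_net.forward(is_train=False,
                        data_batch=mx.io.DataBatch(data=[data_nd] + states.get_encoder_states()))
    states.update(encoder_net.get_outputs())
    # Forward-only forecaster, appending the last input frame for non-direct output heads
    if cfg.MODEL.OUT_TYPE == "direct":
        forecaster_net.forward(is_train=False,
                               data_batch=mx.io.DataBatch(data=states.get_forecaster_state()))
    else:
        last_frame_nd = data_nd[data_nd.shape[0] - 1]
        forecaster_net.forward(is_train=False,
                               data_batch=mx.io.DataBatch(data=states.get_forecaster_state() +
                                                          [last_frame_nd]))
    return forecaster_net.get_outputs()[0]  # predicted frames, shape (out_seq_len, batch, 1, H, W)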
load_encoder_forecaster_params
def load_encoder_forecaster_params(load_dir, load_iter, encoder_net, forecaster_net):
logging.info("Loading parameters from {}, Iter = {}"
.format(os.path.realpath(load_dir), load_iter))
encoder_arg_params, encoder_aux_params = load_params(prefix=os.path.join(load_dir,
"encoder_net"),
epoch=load_iter)
encoder_net.init_params(arg_params=encoder_arg_params, aux_params=encoder_aux_params,
allow_missing=False, force_init=True)
forecaster_arg_params, forecaster_aux_params = load_params(prefix=os.path.join(load_dir,
"forecaster_net"),
epoch=load_iter)
forecaster_net.init_params(arg_params=forecaster_arg_params,
aux_params=forecaster_aux_params,
allow_missing=False,
force_init=True)
logging.info("Loading Complete!")
MyModule
class MyModule(Module):
"""Some enhancement to the mx.mod.Module
"""
def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
logger=logging, context=mx.context.gpu(), work_load_list=None,
fixed_param_names=None, state_names=None, name=None):
self._name = name
super(MyModule, self).__init__(symbol=symbol,
data_names=data_names,
label_names=label_names,
logger=logger,
context=context,
work_load_list=work_load_list,
fixed_param_names=fixed_param_names,
state_names=state_names)
self._tmp_grads = None
def clip_by_global_norm(self, max_norm=1.0):
"""Clips gradient norm.
The norm is computed over all gradients together, as if they were
concatenated into a single vector. Gradients are modified in-place.
The method is first used in
`[ICML2013] On the difficulty of training recurrent neural networks`
Parameters
----------
max_norm : float or int
The maximum clipping threshold of the gradient norm.
Returns
-------
norm_val : float
The computed norm of the gradients.
Examples
--------
An example of using clip_by_global_norm to clip the gradient before updating the parameters::
>>> #Get the gradient via back-propagation
>>> net.forward_backward(data_batch=data_batch)
>>> norm_val = net.clip_by_global_norm(max_norm=1.0)
>>> net.update()
"""
assert self.binded and self.params_initialized and self.optimizer_initialized
norm_val = self.global_grad_norm()
if norm_val > max_norm:
ratio = max_norm / float(norm_val)
for grads in self._exec_group.grad_arrays:
for grad in grads:
grad *= ratio
return norm_val
def global_grad_norm(self):
"""Calculate global gradient norm.
The L2 norm is computed over all gradients together, as if they were
concatenated into a single vector.
Could be used to debug the optimization process.
See http://videolectures.net/deeplearning2015_goodfellow_network_optimization/
Returns
-------
norm_val : float
The computed norm of the gradients.
Examples
--------
An example of using global_grad_norm to calculate the gradient norm after back-propagation::
>>> #Get the gradient via back-propagation
>>> net.forward_backward(data_batch=data_batch)
>>> norm_val = net.global_grad_norm()
>>> print(norm_val)
"""
assert self.binded and self.params_initialized and self.optimizer_initialized
# The code in the following will cause the estimated norm to be different for multiple gpus
norm_val = 0.0
for exe in self._exec_group.execs:
norm_val += nd_global_norm(exe.grad_arrays).asscalar()
norm_val /= float(len(self._exec_group.execs))
norm_val *= self._optimizer.rescale_grad
return norm_val
def debug_norm_all(self, debug_gnorm=True):
if debug_gnorm:
for k, v, grad_v in zip(self._param_names, self._exec_group.param_arrays,
self._exec_group.grad_arrays):
logging.debug("%s: v-norm: %g, g-norm: %g"
%(k,
nd.norm(v[0]).asnumpy()[0],
nd.norm(grad_v[0]).asnumpy()[0]))
else:
for k, v in zip(self._param_names, self._exec_group.param_arrays):
logging.debug("%s: v-norm: %g"
%(k,
nd.norm(v[0]).asnumpy()[0]))
def summary(self, level=2):
"""Summarize the network parameters.
Parameters
----------
level : int, optional
Level of the summarization logs to print.
The log becomes more verbose with higher summary level.
- Level = 0
Print the total param number + aux param number
- Level = 1
Print the shape of all parameters + the total number of parameters
- Level = 2
Print the shape of the data/state and other available information in Level 1
"""
self.logger.info("Summary of %s" %self._name)
assert self.binded and self.params_initialized
assert 0 <= level <= 2, \
"Level must be between 0 and 2, level=%d is not supported" % level
def _log_var(key, value, typ="param"):
if typ == "param":
if k in self._fixed_param_names:
self.logger.info(" %s: %s, %d, req = %s, fixed"
% (key,
str(value.shape),
np.prod(value.shape),
self._exec_group.grad_req[k]))
else:
self.logger.info(" %s: %s, %d, req = %s"
% (key,
str(value.shape),
np.prod(value.shape),
self._exec_group.grad_req[k]))
elif typ == "data" or typ == "aux":
self.logger.info(" %s: %s, %d"
% (key,
str(value.shape),
np.prod(value.shape)))
total_param_num = 0
total_fixed_param_num = 0
total_aux_param_num = 0
if level >= 2:
if len(self.data_names) == 0:
self.logger.info("Data: None")
else:
self.logger.info("Data:")
for k, v in zip(self.data_names, self.data_shapes):
_log_var(k, v, typ="data")
if len(self._state_names) == 0:
self.logger.info("State: None")
else:
self.logger.info("State:")
for k in self._state_names:
v = self._exec_group.execs[0].arg_dict[k]
_log_var(k, v, typ="data")
if level >= 1:
if len(self._param_names) == 0:
self.logger.info("Param: None")
else:
self.logger.info("Params:")
for k in self._param_names:
v = self._arg_params[k]
_log_var(k, v)
if k in self._fixed_param_names:
total_fixed_param_num += np.prod(v.shape)
else:
total_param_num += np.prod(v.shape)
if len(self._aux_names) == 0:
self.logger.info("Aux States: None")
else:
self.logger.info("Aux States: ")
for k in self._aux_names:
v = self._aux_params[k]
_log_var(k, v, typ="aux")
total_aux_param_num += np.prod(v.shape)
else:
for k in self._param_names:
v = self._arg_params[k]
total_param_num += np.prod(v.shape)
for k in self._aux_names:
v = self._aux_params[k]
total_aux_param_num += np.prod(v.shape)
self.logger.info("Total Param Num (exclude fixed ones): " + str(total_param_num))
self.logger.info("Total Fixed Param Num: " + str(total_fixed_param_num))
self.logger.info("Total Aux Param Num: " + str(total_aux_param_num))
def get_output_dict(self):
outputs = self.get_outputs()
return OrderedDict([(k, v) for k, v in zip(self._output_names, outputs)])
def clear_grad(self):
assert self.binded and self.params_initialized and self.optimizer_initialized
# clear the gradient
for grads in self._exec_group.grad_arrays:
for grad in grads:
grad[:] = 0
def save_tmp_grad(self):
if self._tmp_grads is None:
self._tmp_grads = []
for grads in self._exec_group.grad_arrays:
vec = []
for grad in grads:
vec.append(grad.copyto(grad.context))
self._tmp_grads.append(vec)
else:
for i, grads in enumerate(self._exec_group.grad_arrays):
for j, grad in enumerate(grads):
self._tmp_grads[i][j][:] = grad
def acc_grad_with_tmp(self):
assert self._tmp_grads is not None
for i, grads in enumerate(self._exec_group.grad_arrays):
for j, grad in enumerate(grads):
grad += self._tmp_grads[i][j]
def load_params_allow_missing(self, fname):
"""Loads model parameters from file.
Parameters
----------
fname : str
Path to input param file.
Examples
--------
>>> # An example of loading module parameters.
>>> mod.load_params('myfile')
"""
logging.info("Load Param From %s" %fname)
save_dict = mx.nd.load(fname)
arg_params = {}
aux_params = {}
for k, value in save_dict.items():
arg_type, name = k.split(':', 1)
if arg_type == 'arg':
if name in self._param_names:
logging.info("set %s" %name)
arg_params[name] = value
elif arg_type == 'aux':
if name in self._aux_names:
logging.info("set %s" % name)
aux_params[name] = value
else:
raise ValueError("Invalid param file " + fname)
self.set_params(arg_params, aux_params, allow_missing=True)
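The 'arg:'/'aux:' prefixes it splits on are the standard key format MXNet writes into .params files, so checkpoints produced by save_checkpoint above can be partially loaded. A hypothetical usage example (path and epoch number are illustrative):

# Reuse encoder weights from an earlier run, skipping any params missing from the current symbol
encoder_net.load_params_allow_missing(os.path.join(base_dir, "encoder_net-0100.params"))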
Other links worth a look
ai.googleblog.com/2020/03/a-neural-weather-model-for-eight-hour.html
ai.googleblog.com/2020/01/using-machine-learning-to-nowcast.html