

Trajectory GRU (TrajGRU) code


TrajGRU code

Original paper: Deep Learning for Precipitation Nowcasting: A Benchmark and A New Model (NIPS 2017)

 

The convolutional recurrence structure in ConvLSTM-based models is location-invariant,

while natural motion and transformation (e.g., rotation) are location-variant in general.

Specifically, we go beyond ConvLSTM and propose the Trajectory GRU (TrajGRU) model that can actively learn the location-variant structure for recurrent connections.

https://arxiv.org/pdf/1706.03458.pdf

GitHub: github.com/scholltan/Deep-Learning-for-Precipitation-Nowcasting-A-Benchmark-and-A-New-Model./blob/master/experiments/movingmnist/mnist_rnn_main.py
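Before walking through the training script, here is a minimal sketch of the idea behind TrajGRU. Instead of a fixed convolutional state-to-state transition, the previous hidden state is resampled along L learned flow fields, and the warped copies are then combined (via 1x1 convolutions) inside the GRU gates. The snippet below only illustrates the warping step, assuming MXNet's GridGenerator/BilinearSampler operators; it is not the repo's actual TrajGRU cell.

import mxnet as mx

def warp_hidden(h, flow, name):
    # h    : previous hidden state symbol, shape (N, C, H, W)
    # flow : one learned flow field symbol, shape (N, 2, H, W)
    # GridGenerator turns the flow field into a sampling grid and
    # BilinearSampler gathers h along that grid, so the recurrent
    # connection follows a learned, location-variant trajectory.
    grid = mx.sym.GridGenerator(data=flow, transform_type="warp",
                                name=name + "_grid")
    return mx.sym.BilinearSampler(data=h, grid=grid, name=name + "_warped")

ConvGRU is the special case where this warping is dropped and the hidden-to-hidden transition is a plain convolution.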

 


Code library: MXNet (Apache MXNet, an AI library backed by Amazon)

https://aws.amazon.com/ko/mxnet/
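Everything below uses MXNet's symbolic API together with the Module class (MyModule at the end of this post is a thin extension of mx.mod.Module). For readers new to that style, here is a tiny, generic sketch of the build-then-bind-then-train pattern; the network itself is a throwaway example, not part of the repo.

import mxnet as mx

# 1) Build a symbolic graph.
data = mx.sym.Variable('data')
fc = mx.sym.FullyConnected(data=data, num_hidden=10, name='fc')
net = mx.sym.SoftmaxOutput(data=fc, name='softmax')

# 2) Bind it to concrete shapes/devices through a Module.
mod = mx.mod.Module(symbol=net, data_names=['data'],
                    label_names=['softmax_label'], context=mx.cpu())
mod.bind(data_shapes=[('data', (32, 100))],
         label_shapes=[('softmax_label', (32,))])
mod.init_params()
mod.init_optimizer(optimizer='adam')

# 3) Training then proceeds with explicit forward/backward/update calls,
#    which is exactly how encoder_net / forecaster_net / loss_net are
#    driven in train_step() below.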

 


 

 

def train(args):

    assert cfg.MODEL.FRAME_STACK == 1 and cfg.MODEL.FRAME_SKIP == 1

    base_dir = args.save_dir

    logging_config(folder=base_dir, name="training")

    save_movingmnist_cfg(base_dir)

    # For a custom dataset you need to implement an iterator like MovingMNISTAdvancedIterator (see the interface sketch after this function)

    mnist_iter = MovingMNISTAdvancedIterator(

        distractor_num=cfg.MOVINGMNIST.DISTRACTOR_NUM,

        initial_velocity_range=(cfg.MOVINGMNIST.VELOCITY_LOWER,

                                cfg.MOVINGMNIST.VELOCITY_UPPER),

        rotation_angle_range=(cfg.MOVINGMNIST.ROTATION_LOWER,

                              cfg.MOVINGMNIST.ROTATION_UPPER),

        scale_variation_range=(cfg.MOVINGMNIST.SCALE_VARIATION_LOWER,

                               cfg.MOVINGMNIST.SCALE_VARIATION_UPPER),

        illumination_factor_range=(cfg.MOVINGMNIST.ILLUMINATION_LOWER,

                                   cfg.MOVINGMNIST.ILLUMINATION_UPPER))

    # Define the MovingMNISTFactory

    mnist_rnn = MovingMNISTFactory(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE // len(args.ctx),

                                   in_seq_len=cfg.MODEL.IN_LEN,

                                   out_seq_len=cfg.MODEL.OUT_LEN)



    encoder_net, forecaster_net, loss_net = \

        encoder_forecaster_build_networks(

            factory=mnist_rnn,

            context=args.ctx)

    t_encoder_net, t_forecaster_net, t_loss_net = \

        encoder_forecaster_build_networks(

            factory=mnist_rnn,

            context=args.ctx[0],

            shared_encoder_net=encoder_net,

            shared_forecaster_net=forecaster_net,

            shared_loss_net=loss_net,

            for_finetune=True)

    encoder_net.summary()

    forecaster_net.summary()

    loss_net.summary()

    # Begin to load the model if load_dir is not empty

    if len(cfg.MODEL.LOAD_DIR) > 0:

        load_mnist_params(load_dir=cfg.MODEL.LOAD_DIR, load_iter=cfg.MODEL.LOAD_ITER,

                          encoder_net=encoder_net, forecaster_net=forecaster_net)

    states = EncoderForecasterStates(factory=mnist_rnn, ctx=args.ctx[0])

    states.reset_all()

    for info in mnist_rnn.init_encoder_state_info:

        assert info["__layout__"].find('N') == 0, "Layout=%s is not supported!" %info["__layout__"]

    iter_id = 0

    while iter_id < cfg.MODEL.TRAIN.MAX_ITER:

        frame_dat, _ = mnist_iter.sample(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,

                                         seqlen=cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN)

        data_nd = mx.nd.array(frame_dat[0:cfg.MOVINGMNIST.IN_LEN, ...], ctx=args.ctx[0]) / 255.0

        target_nd = mx.nd.array(

            frame_dat[cfg.MOVINGMNIST.IN_LEN:(cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN), ...],

            ctx=args.ctx[0]) / 255.0

        train_step(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,

                   encoder_net=encoder_net, forecaster_net=forecaster_net,

                   loss_net=loss_net, init_states=states,

                   data_nd=data_nd, gt_nd=target_nd, mask_nd=None,

                   iter_id=iter_id)

        if (iter_id + 1) % 100 == 0:

            new_frame_dat, _ = mnist_iter.sample(batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,

                                         seqlen=cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN)

            data_nd = mx.nd.array(new_frame_dat[0:cfg.MOVINGMNIST.IN_LEN, ...], ctx=args.ctx[0]) / 255.0

            target_nd = mx.nd.array(

                new_frame_dat[cfg.MOVINGMNIST.IN_LEN:(cfg.MOVINGMNIST.IN_LEN + cfg.MOVINGMNIST.OUT_LEN), ...],

                ctx=args.ctx[0]) / 255.0

            pred_nd = mnist_get_prediction(data_nd=data_nd, states=states,

                                           encoder_net=encoder_net, forecaster_net=forecaster_net)

            save_gif(pred_nd.asnumpy()[:, 0, 0, :, :], os.path.join(base_dir, "pred.gif"))

            save_gif(data_nd.asnumpy()[:, 0, 0, :, :], os.path.join(base_dir, "in.gif"))

            save_gif(target_nd.asnumpy()[:, 0, 0, :, :], os.path.join(base_dir, "gt.gif"))

        # Save the model (checkpointed as two networks: encoder_net and forecaster_net)

        if (iter_id + 1) % cfg.MODEL.SAVE_ITER == 0:

            encoder_net.save_checkpoint(

                prefix=os.path.join(base_dir, "encoder_net"),

                epoch=iter_id)

            forecaster_net.save_checkpoint(

                prefix=os.path.join(base_dir, "forecaster_net"),

                epoch=iter_id)

        iter_id += 1
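As noted in the comment at the top of train(), the only thing the loop needs from the data source is a sample(batch_size, seqlen) method that returns a frame array of shape (seqlen, batch_size, 1, H, W) with pixel values in [0, 255] (plus a second return value that train() ignores). A hypothetical skeleton for a custom-dataset iterator, matching only that interface:

import numpy as np

class MyDatasetIterator(object):
    """Hypothetical stand-in for MovingMNISTAdvancedIterator on a custom dataset."""

    def __init__(self, frames, img_size=64):
        # frames: any source of (1, img_size, img_size) uint8 frames
        self._frames = frames
        self._img_size = img_size

    def sample(self, batch_size, seqlen):
        # Must return (frame_dat, extra); train() slices the first IN_LEN
        # steps as the input and the next OUT_LEN steps as the target,
        # and divides by 255 to normalize.
        frame_dat = np.zeros((seqlen, batch_size, 1,
                              self._img_size, self._img_size), dtype=np.uint8)
        # ... fill frame_dat from self._frames ...
        return frame_dat, None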

encoder_forecaster_build_networks

 

def encoder_forecaster_build_networks(factory, context,

                                      shared_encoder_net=None,

                                      shared_forecaster_net=None,

                                      shared_loss_net=None,

                                      for_finetune=False):

    """

    

    Parameters

    ----------

    factory : EncoderForecasterBaseFactory

    context : list

    shared_encoder_net : MyModule or None

    shared_forecaster_net : MyModule or None

    shared_loss_net : MyModule or None

    for_finetune : bool



    Returns

    -------



    """

    encoder_net = MyModule(factory.encoder_sym(),

                           data_names=[ele.name for ele in factory.encoder_data_desc()],

                           label_names=[],

                           context=context,

                           name="encoder_net")

    encoder_net.bind(data_shapes=factory.encoder_data_desc(),

                     label_shapes=None,

                     inputs_need_grad=True,

                     shared_module=shared_encoder_net)

    if shared_encoder_net is None:

        encoder_net.init_params(mx.init.MSRAPrelu(slope=0.2))

        init_optimizer_using_cfg(encoder_net, for_finetune=for_finetune)

    forecaster_net = MyModule(factory.forecaster_sym(),

                                   data_names=[ele.name for ele in

                                               factory.forecaster_data_desc()],

                                   label_names=[],

                                   context=context,

                                   name="forecaster_net")

    forecaster_net.bind(data_shapes=factory.forecaster_data_desc(),

                        label_shapes=None,

                        inputs_need_grad=True,

                        shared_module=shared_forecaster_net)

    if shared_forecaster_net is None:

        forecaster_net.init_params(mx.init.MSRAPrelu(slope=0.2))

        init_optimizer_using_cfg(forecaster_net, for_finetune=for_finetune)



    loss_net = MyModule(factory.loss_sym(),

                        data_names=[ele.name for ele in

                                    factory.loss_data_desc()],

                        label_names=[ele.name for ele in

                                     factory.loss_label_desc()],

                        context=context,

                        name="loss_net")

    loss_net.bind(data_shapes=factory.loss_data_desc(),

                  label_shapes=factory.loss_label_desc(),

                  inputs_need_grad=True,

                  shared_module=shared_loss_net)

    if shared_loss_net is None:

        loss_net.init_params()

    return encoder_net, forecaster_net, loss_net

MovingMNISTFactory class

class MovingMNISTFactory(EncoderForecasterBaseFactory):

    def __init__(self,

                 batch_size,

                 in_seq_len,

                 out_seq_len):

        super(MovingMNISTFactory, self).__init__(batch_size=batch_size,

                                                 in_seq_len=in_seq_len,

                                                 out_seq_len=out_seq_len,

                                                 height=cfg.MOVINGMNIST.IMG_SIZE,

                                                 width=cfg.MOVINGMNIST.IMG_SIZE)



    def loss_sym(self):

        self.reset_all()

        pred = mx.sym.Variable('pred')  # Shape: (out_seq_len, batch_size, 1, H, W)

        target = mx.sym.Variable('target')  # Shape: (out_seq_len, batch_size, 1, H, W)

        avg_mse = mx.sym.mean(mx.sym.square(target - pred))

        avg_mse = mx.sym.MakeLoss(avg_mse,

                                  name="mse")

        loss = mx.sym.Group([avg_mse])

        return loss

→ Only the loss needs to be defined here.

→ Everything else is already defined in the methods of the EncoderForecasterBaseFactory base class!
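For example, a hypothetical factory with a masked MSE loss would only have to override loss_sym(); everything else (encoder_sym, forecaster_sym, the data/label descriptors) is inherited. This sketch assumes cfg.MODEL.ENCODER_FORECASTER.HAS_MASK is enabled so that loss_label_desc() below exposes the 'mask' label.

class MaskedMSEFactory(EncoderForecasterBaseFactory):
    """Hypothetical example: only the loss symbol is overridden."""

    def loss_sym(self):
        self.reset_all()
        pred = mx.sym.Variable('pred')      # (out_seq_len, N, 1, H, W)
        target = mx.sym.Variable('target')  # (out_seq_len, N, 1, H, W)
        mask = mx.sym.Variable('mask')      # same shape, 1 = valid pixel
        masked_mse = mx.sym.mean(mask * mx.sym.square(target - pred))
        masked_mse = mx.sym.MakeLoss(masked_mse, name="masked_mse")
        return mx.sym.Group([masked_mse])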

EncoderForecasterBaseFactory (the base class)

 

class EncoderForecasterBaseFactory(PredictionBaseFactory):

    def __init__(self,

                 batch_size,

                 in_seq_len,

                 out_seq_len,

                 height,

                 width,

                 ctx_num=1,

                 name="encoder_forecaster"):

        super(EncoderForecasterBaseFactory, self).__init__(batch_size=batch_size,

                                                           in_seq_len=in_seq_len,

                                                           out_seq_len=out_seq_len,

                                                           height=height,

                                                           width=width,

                                                           name=name)

        self._ctx_num = ctx_num



    def _init_rnn(self):

        self._encoder_rnn_blocks, self._forecaster_rnn_blocks, self._gan_rnn_blocks =\

            get_encoder_forecaster_rnn_blocks(batch_size=self._batch_size)

        return self._encoder_rnn_blocks + self._forecaster_rnn_blocks + self._gan_rnn_blocks



    @property

    def init_encoder_state_info(self):

        init_state_info = []

        for block in self._encoder_rnn_blocks:

            for state in block.init_state_vars():

                init_state_info.append({'name': state.name,

                                        'shape': state.attr('__shape__'),

                                        '__layout__': state.list_attr()['__layout__']})

        return init_state_info



    @property

    def init_forecaster_state_info(self):

        init_state_info = []

        for block in self._forecaster_rnn_blocks:

            for state in block.init_state_vars():

                init_state_info.append({'name': state.name,

                                        'shape': state.attr('__shape__'),

                                        '__layout__': state.list_attr()['__layout__']})

        return init_state_info



    @property

    def init_gan_state_info(self):

        init_gan_state_info = []

        for block in self._gan_rnn_blocks:

            for state in block.init_state_vars():

                init_gan_state_info.append({'name': state.name,

                                            'shape': state.attr('__shape__'),

                                            '__layout__': state.list_attr()['__layout__']})

        return init_gan_state_info



    def stack_rnn_encode(self, data):

        CONFIG = cfg.MODEL.ENCODER_FORECASTER

        pre_encoded_data = self._pre_encode_frame(frame_data=data, seqlen=self._in_seq_len)

        reshape_data = mx.sym.Reshape(pre_encoded_data, shape=(-1, 0, 0, 0), reverse=True)



        # Encoder Part

        conv1 = conv2d_act(data=reshape_data,

                           num_filter=CONFIG.FIRST_CONV[0],

                           kernel=(CONFIG.FIRST_CONV[1], CONFIG.FIRST_CONV[1]),

                           stride=(CONFIG.FIRST_CONV[2], CONFIG.FIRST_CONV[2]),

                           pad=(CONFIG.FIRST_CONV[3], CONFIG.FIRST_CONV[3]),

                           act_type=cfg.MODEL.CNN_ACT_TYPE,

                           name="econv1")

        rnn_block_num = len(CONFIG.RNN_BLOCKS.NUM_FILTER)

        encoder_rnn_block_states = []

        for i in range(rnn_block_num):

            if i == 0:

                inputs = conv1

            else:

                inputs = downsample

            rnn_out, states = self._encoder_rnn_blocks[i].unroll(

                length=self._in_seq_len,

                inputs=inputs,

                begin_states=None,

                ret_mid=False)

            encoder_rnn_block_states.append(states)

            if i < rnn_block_num - 1:

                downsample = downsample_module(data=rnn_out[-1],

                                               num_filter=CONFIG.RNN_BLOCKS.NUM_FILTER[i + 1],

                                               kernel=(CONFIG.DOWNSAMPLE[i][0],

                                                       CONFIG.DOWNSAMPLE[i][0]),

                                               stride=(CONFIG.DOWNSAMPLE[i][1],

                                                       CONFIG.DOWNSAMPLE[i][1]),

                                               pad=(CONFIG.DOWNSAMPLE[i][2],

                                                    CONFIG.DOWNSAMPLE[i][2]),

                                               b_h_w=(self._batch_size,

                                                      CONFIG.FEATMAP_SIZE[i + 1],

                                                      CONFIG.FEATMAP_SIZE[i + 1]),

                                               name="edown%d" %(i + 1))

        return encoder_rnn_block_states



    def stack_rnn_forecast(self, block_state_list, last_frame):

        CONFIG = cfg.MODEL.ENCODER_FORECASTER

        block_state_list = [self._forecaster_rnn_blocks[i].to_split(block_state_list[i])

                            for i in range(len(self._forecaster_rnn_blocks))]

        rnn_block_num = len(CONFIG.RNN_BLOCKS.NUM_FILTER)

        rnn_block_outputs = []

        # RNN Forecaster Part

        curr_inputs = None

        for i in range(rnn_block_num - 1, -1, -1):

            rnn_out, rnn_state = self._forecaster_rnn_blocks[i].unroll(

                length=self._out_seq_len, inputs=curr_inputs,

                begin_states=block_state_list[i][::-1],  # Reverse the order of states for the forecaster

                ret_mid=False)

            rnn_block_outputs.append(rnn_out)

            if i > 0:

                upsample = upsample_module(data=rnn_out[-1],

                                           num_filter=CONFIG.RNN_BLOCKS.NUM_FILTER[i],

                                           kernel=(CONFIG.UPSAMPLE[i - 1][0],

                                                   CONFIG.UPSAMPLE[i - 1][0]),

                                           stride=(CONFIG.UPSAMPLE[i - 1][1],

                                                   CONFIG.UPSAMPLE[i - 1][1]),

                                           pad=(CONFIG.UPSAMPLE[i - 1][2],

                                                CONFIG.UPSAMPLE[i - 1][2]),

                                           b_h_w=(self._batch_size,
                                                  CONFIG.FEATMAP_SIZE[i - 1],
                                                  CONFIG.FEATMAP_SIZE[i - 1]),

                                           name="fup%d" %i)

                curr_inputs = upsample

        # Output

        if cfg.MODEL.OUT_TYPE == "DFN":

            concat_fbrnn1_out = mx.sym.concat(*rnn_out[-1], dim=0)

            dynamic_filter = deconv2d(data=concat_fbrnn1_out,

                                      num_filter=121,

                                      kernel=(CONFIG.LAST_DECONV[1], CONFIG.LAST_DECONV[1]),

                                      stride=(CONFIG.LAST_DECONV[2], CONFIG.LAST_DECONV[2]),

                                      pad=(CONFIG.LAST_DECONV[3], CONFIG.LAST_DECONV[3]))

            flow = dynamic_filter

            dynamic_filter = mx.sym.SliceChannel(dynamic_filter, axis=0, num_outputs=self._out_seq_len)

            prev_frame = last_frame

            preds = []

            for i in range(self._out_seq_len):

                pred_ele = DFN(data=prev_frame, local_kernels=dynamic_filter[i], K=11, batch_size=self._batch_size)

                preds.append(pred_ele)

                prev_frame = pred_ele

            pred = mx.sym.concat(*preds, dim=0)

        elif cfg.MODEL.OUT_TYPE == "direct":

            flow = None

            deconv1 = deconv2d_act(data=mx.sym.concat(*rnn_out[-1], dim=0),

                                   num_filter=CONFIG.LAST_DECONV[0],

                                   kernel=(CONFIG.LAST_DECONV[1], CONFIG.LAST_DECONV[1]),

                                   stride=(CONFIG.LAST_DECONV[2], CONFIG.LAST_DECONV[2]),

                                   pad=(CONFIG.LAST_DECONV[3], CONFIG.LAST_DECONV[3]),

                                   act_type=cfg.MODEL.CNN_ACT_TYPE,

                                   name="fdeconv1")

            conv_final = conv2d_act(data=deconv1,

                                    num_filter=CONFIG.LAST_DECONV[0],

                                    kernel=(3, 3), stride=(1, 1), pad=(1, 1),

                                    act_type=cfg.MODEL.CNN_ACT_TYPE, name="conv_final")

            pred = conv2d(data=conv_final,

                          num_filter=1, kernel=(1, 1), name="out")

        else:

            raise NotImplementedError

        pred = mx.sym.Reshape(pred,

                              shape=(self._out_seq_len, self._batch_size,

                                     1, self._height, self._width),

                              __layout__="TNCHW")

        return pred, flow



    def encoder_sym(self):

        self.reset_all()

        data = mx.sym.Variable('data')  # Shape: (in_seq_len, batch_size, C, H, W)

        block_state_list = self.stack_rnn_encode(data=data)

        states = []

        for i, rnn_block in enumerate(self._encoder_rnn_blocks):

            states.extend(rnn_block.flatten_add_layout(block_state_list[i]))

        return mx.sym.Group(states)



    def encoder_data_desc(self):

        ret = list()

        ret.append(mx.io.DataDesc(name='data',

                                  shape=(self._in_seq_len,

                                         self._batch_size * self._ctx_num,

                                         1,

                                         self._height,

                                         self._width),

                                  layout="TNCHW"))

        for info in self.init_encoder_state_info:

            state_shape = safe_eval(info['shape'])

            assert info['__layout__'].find('N') == 0,\

                "Layout=%s is not supported!" %info["__layout__"]

            state_shape = (state_shape[0] * self._ctx_num, ) + state_shape[1:]

            ret.append(mx.io.DataDesc(name=info['name'],

                                      shape=state_shape,

                                      layout=info['__layout__']))

        return ret



    def forecaster_sym(self):

        self.reset_all()

        block_state_list = []

        for block in self._forecaster_rnn_blocks:

            block_state_list.append(block.init_state_vars())



        if cfg.MODEL.OUT_TYPE == "direct":

            pred, _ = self.stack_rnn_forecast(block_state_list=block_state_list,

                                              last_frame=None)

            return mx.sym.Group([pred])

        else:

            last_frame = mx.sym.Variable('last_frame')  # Shape: (batch_size, C, H, W)

            pred, flow = self.stack_rnn_forecast(block_state_list=block_state_list,

                                                 last_frame=last_frame)

            return mx.sym.Group([pred, mx.sym.BlockGrad(flow)])



    def forecaster_data_desc(self):

        ret = list()

        for info in self.init_forecaster_state_info:

            state_shape = safe_eval(info['shape'])

            assert info['__layout__'].find('N') == 0, \

                "Layout=%s is not supported!" % info["__layout__"]

            state_shape = (state_shape[0] * self._ctx_num,) + state_shape[1:]

            ret.append(mx.io.DataDesc(name=info['name'],

                                      shape=state_shape,

                                      layout=info['__layout__']))

        if cfg.MODEL.OUT_TYPE != "direct":

            ret.append(mx.io.DataDesc(name="last_frame",

                                      shape=(self._ctx_num * self._batch_size,

                                             1, self._height, self._width),

                                      layout="NCHW"))

        return ret



    def loss_sym(self):

        raise NotImplementedError



    def loss_data_desc(self):

        ret = list()

        ret.append(mx.io.DataDesc(name='pred',

                                  shape=(self._out_seq_len,

                                         self._ctx_num * self._batch_size,

                                         1,

                                         self._height,

                                         self._width),

                                  layout="TNCHW"))

        return ret



    def loss_label_desc(self):

        ret = list()

        ret.append(mx.io.DataDesc(name='target',

                                  shape=(self._out_seq_len,

                                         self._ctx_num * self._batch_size,

                                         1,

                                         self._height,

                                         self._width),

                                  layout="TNCHW"))

        if cfg.MODEL.ENCODER_FORECASTER.HAS_MASK:

            ret.append(mx.io.DataDesc(name='mask',

                                      shape=(self._out_seq_len,

                                             self._ctx_num * self._batch_size,

                                             1,

                                             self._height,

                                             self._width),

                                      layout="TNCHW"))

        return ret

 

Building the RNN blocks by layer type (ConvGRU, TrajGRU)

def get_encoder_forecaster_rnn_blocks(batch_size):

    encoder_rnn_blocks = []

    forecaster_rnn_blocks = []

    gan_rnn_blocks = []

    CONFIG = cfg.MODEL.ENCODER_FORECASTER.RNN_BLOCKS

    for vec, block_prefix in [(encoder_rnn_blocks, "ebrnn"),

                              (forecaster_rnn_blocks, "fbrnn"),

                              (gan_rnn_blocks, "dbrnn")]:

        for i in range(len(CONFIG.NUM_FILTER)):

            name = "%s%d" % (block_prefix, i + 1)

            if CONFIG.LAYER_TYPE[i] == "ConvGRU":

                rnn_block = BaseStackRNN(base_rnn_class=ConvGRU,

                                         stack_num=CONFIG.STACK_NUM[i],

                                         name=name,

                                         residual_connection=CONFIG.RES_CONNECTION,

                                         num_filter=CONFIG.NUM_FILTER[i],

                                         b_h_w=(batch_size,

                                                cfg.MODEL.ENCODER_FORECASTER.FEATMAP_SIZE[i],

                                                cfg.MODEL.ENCODER_FORECASTER.FEATMAP_SIZE[i]),

                                         h2h_kernel=CONFIG.H2H_KERNEL[i],

                                         h2h_dilate=CONFIG.H2H_DILATE[i],

                                         i2h_kernel=CONFIG.I2H_KERNEL[i],

                                         i2h_pad=CONFIG.I2H_PAD[i],

                                         act_type=cfg.MODEL.RNN_ACT_TYPE)

            elif CONFIG.LAYER_TYPE[i] == "TrajGRU":

                rnn_block = BaseStackRNN(base_rnn_class=TrajGRU,

                                         stack_num=CONFIG.STACK_NUM[i],

                                         name=name,

                                         L=CONFIG.L[i],

                                         residual_connection=CONFIG.RES_CONNECTION,

                                         num_filter=CONFIG.NUM_FILTER[i],

                                         b_h_w=(batch_size,

                                                cfg.MODEL.ENCODER_FORECASTER.FEATMAP_SIZE[i],

                                                cfg.MODEL.ENCODER_FORECASTER.FEATMAP_SIZE[i]),

                                         h2h_kernel=CONFIG.H2H_KERNEL[i],

                                         h2h_dilate=CONFIG.H2H_DILATE[i],

                                         i2h_kernel=CONFIG.I2H_KERNEL[i],

                                         i2h_pad=CONFIG.I2H_PAD[i],

                                         act_type=cfg.MODEL.RNN_ACT_TYPE)

            else:

                raise NotImplementedError

            vec.append(rnn_block)

    return encoder_rnn_blocks, forecaster_rnn_blocks, gan_rnn_blocks
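For reference, here is a hypothetical illustration of the cfg fields this function reads, written as an EasyDict-style config. The field names follow the code above; the concrete values are only an example (roughly in line with the paper's three-layer setup), not the repo's defaults.

from easydict import EasyDict as edict

cfg = edict()
cfg.MODEL = edict()
cfg.MODEL.RNN_ACT_TYPE = "leaky"
cfg.MODEL.ENCODER_FORECASTER = edict()
# spatial size of the feature map entering each RNN block
cfg.MODEL.ENCODER_FORECASTER.FEATMAP_SIZE = [32, 16, 8]
cfg.MODEL.ENCODER_FORECASTER.RNN_BLOCKS = edict(
    LAYER_TYPE=["TrajGRU", "TrajGRU", "TrajGRU"],   # or "ConvGRU"
    STACK_NUM=[1, 1, 1],
    NUM_FILTER=[64, 96, 96],
    L=[13, 13, 9],            # number of flow links per TrajGRU layer
    RES_CONNECTION=True,
    H2H_KERNEL=[(5, 5), (5, 5), (3, 3)],
    H2H_DILATE=[(1, 1), (1, 1), (1, 1)],
    I2H_KERNEL=[(3, 3), (3, 3), (3, 3)],
    I2H_PAD=[(1, 1), (1, 1), (1, 1)],
)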

Train step and other utilities

 

Model(Encoder & Forecaster) States

class EncoderForecasterStates(object):

    def __init__(self, factory, ctx):

        self._factory = factory

        self._ctx = ctx

        self._encoder_state_info = factory.init_encoder_state_info

        self._forecaster_state_info = factory.init_forecaster_state_info

        self._states_nd = []

        for info in self._encoder_state_info:

            state_shape = safe_eval(info['shape'])

            state_shape = (state_shape[0] * factory._ctx_num, ) + state_shape[1:]

            self._states_nd.append(mx.nd.zeros(shape=state_shape, ctx=ctx))



    def reset_all(self):

        for ele, info in zip(self._states_nd, self._encoder_state_info):

            ele[:] = 0



    def reset_batch(self, batch_id):

        for ele, info in zip(self._states_nd, self._encoder_state_info):

            ele[batch_id][:] = 0



    def update(self, states_nd):

        for target, src in zip(self._states_nd, states_nd):

            target[:] = src



    def get_encoder_states(self):

        return self._states_nd



    def get_forecaster_state(self):

        return self._states_nd

 

train step

def train_step(batch_size, encoder_net, forecaster_net,

               loss_net, init_states,

               data_nd, gt_nd, mask_nd, iter_id=None):

    """Finetune the encoder, forecaster and GAN for one step



    Parameters

    ----------

    batch_size : int

    encoder_net : MyModule

    forecaster_net : MyModule

    loss_net : MyModule

    init_states : EncoderForecasterStates

    data_nd : mx.nd.ndarray

    gt_nd : mx.nd.ndarray

    mask_nd : mx.nd.ndarray

    iter_id : int



    Returns

    -------

    init_states: EncoderForecasterStates

    loss_dict: dict

    """

    # Forward Encoder

    encoder_net.forward(is_train=True,

                        data_batch=mx.io.DataBatch(data=[data_nd] + init_states.get_encoder_states()))

    encoder_states_nd = encoder_net.get_outputs()

    init_states.update(encoder_states_nd)

    # Forward Forecaster

    if cfg.MODEL.OUT_TYPE == "direct":

        forecaster_net.forward(is_train=True,

                               data_batch=mx.io.DataBatch(data=init_states.get_forecaster_state()))

    else:

        last_frame_nd = data_nd[data_nd.shape[0] - 1]

        forecaster_net.forward(is_train=True,

                               data_batch=mx.io.DataBatch(data=init_states.get_forecaster_state() +

                                                                    [last_frame_nd]))

    forecaster_outputs = forecaster_net.get_outputs()

    pred_nd = forecaster_outputs[0]



    # Calculate the gradient of the loss functions

    if cfg.MODEL.ENCODER_FORECASTER.HAS_MASK:

        loss_net.forward_backward(data_batch=mx.io.DataBatch(data=[pred_nd],

                                                             label=[gt_nd, mask_nd]))

    else:

        loss_net.forward_backward(data_batch=mx.io.DataBatch(data=[pred_nd],

                                                             label=[gt_nd]))

    pred_grad = loss_net.get_input_grads()[0]

    loss_dict = loss_net.get_output_dict()

    for k in loss_dict:

        loss_dict[k] = nd.mean(loss_dict[k]).asscalar()

    # Backward Forecaster

    forecaster_net.backward(out_grads=[pred_grad])

    if cfg.MODEL.OUT_TYPE == "direct":

        encoder_states_grad_nd = forecaster_net.get_input_grads()

    else:

        encoder_states_grad_nd = forecaster_net.get_input_grads()[:-1]

    # Backward Encoder

    encoder_net.backward(encoder_states_grad_nd)

    # Update forecaster and encoder

    forecaster_grad_norm = forecaster_net.clip_by_global_norm(max_norm=cfg.MODEL.TRAIN.GRAD_CLIP)

    encoder_grad_norm = encoder_net.clip_by_global_norm(max_norm=cfg.MODEL.TRAIN.GRAD_CLIP)

    forecaster_net.update()

    encoder_net.update()

    loss_str = ", ".join(["%s=%g" %(k, v) for k, v in loss_dict.items()])

    if iter_id is not None:

        logging.info("Iter:%d, %s, e_gnorm=%g, f_gnorm=%g"

                     % (iter_id, loss_str, encoder_grad_norm, forecaster_grad_norm))

    return init_states, loss_dict

load_encoder_forecaster_params

def load_encoder_forecaster_params(load_dir, load_iter, encoder_net, forecaster_net):

    logging.info("Loading parameters from {}, Iter = {}"

                 .format(os.path.realpath(load_dir), load_iter))

    encoder_arg_params, encoder_aux_params = load_params(prefix=os.path.join(load_dir,

                                                                             "encoder_net"),

                                                         epoch=load_iter)

    encoder_net.init_params(arg_params=encoder_arg_params, aux_params=encoder_aux_params,

                            allow_missing=False, force_init=True)

    forecaster_arg_params, forecaster_aux_params = load_params(prefix=os.path.join(load_dir,

                                                                             "forecaster_net"),

                                                               epoch=load_iter)

    forecaster_net.init_params(arg_params=forecaster_arg_params,

                               aux_params=forecaster_aux_params,

                               allow_missing=False,

                               force_init=True)

    logging.info("Loading Complete!")

 

MyModule

class MyModule(Module):

    """Some enhancement to the mx.mod.Module



    """



    def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),

                 logger=logging, context=mx.context.gpu(), work_load_list=None,

                 fixed_param_names=None, state_names=None, name=None):

        self._name = name

        super(MyModule, self).__init__(symbol=symbol,

                                       data_names=data_names,

                                       label_names=label_names,

                                       logger=logger,

                                       context=context,

                                       work_load_list=work_load_list,

                                       fixed_param_names=fixed_param_names,

                                       state_names=state_names)

        self._tmp_grads = None



    def clip_by_global_norm(self, max_norm=1.0):

        """Clips gradient norm.

        The norm is computed over all gradients together, as if they were

         concatenated into a single vector. Gradients are modified in-place.

        The method is first used in

         `[ICML2013] On the difficulty of training recurrent neural networks`

        Parameters

        ----------

        max_norm : float or int

            The maximum clipping threshold of the gradient norm.

        Returns

        -------

        norm_val : float

            The computed norm of the gradients.

        Examples

        --------

        An example of using clip_by_global_norm to clip the gradient before updating the parameters::

            >>> #Get the gradient via back-propagation

            >>> net.forward_backward(data_batch=data_batch)

            >>> norm_val = net.clip_by_global_norm(max_norm=1.0)

            >>> net.update()

        """

        assert self.binded and self.params_initialized and self.optimizer_initialized

        norm_val = self.global_grad_norm()

        if norm_val > max_norm:

            ratio = max_norm / float(norm_val)

            for grads in self._exec_group.grad_arrays:

                for grad in grads:

                    grad *= ratio

        return norm_val



    def global_grad_norm(self):

        """Calculate global gradient norm.

        The L2 norm is computed over all gradients together, as if they were

         concatenated into a single vector.

        Could be used to debug the optimization process.

         See http://videolectures.net/deeplearning2015_goodfellow_network_optimization/

        Returns

        -------

        norm_val : float

            The computed norm of the gradients.

        Examples

        --------

        An example of using global_grad_norm to calculate the gradient norm after back-propagation::

            >>> #Get the gradient via back-propagation

            >>> net.forward_backward(data_batch=data_batch)

            >>> norm_val = net.global_grad_norm()

            >>> print(norm_val)

        """

        assert self.binded and self.params_initialized and self.optimizer_initialized

        # The code in the following will cause the estimated norm to be different for multiple gpus

        norm_val = 0.0

        for exe in self._exec_group.execs:

            norm_val += nd_global_norm(exe.grad_arrays).asscalar()

        norm_val /= float(len(self._exec_group.execs))

        norm_val *= self._optimizer.rescale_grad

        return norm_val



    def debug_norm_all(self, debug_gnorm=True):

        if debug_gnorm:

            for k, v, grad_v in zip(self._param_names, self._exec_group.param_arrays,

                                    self._exec_group.grad_arrays):

                logging.debug("%s: v-norm: %g, g-norm: %g"

                              %(k,

                                nd.norm(v[0]).asnumpy()[0],

                                nd.norm(grad_v[0]).asnumpy()[0]))

        else:

            for k, v in zip(self._param_names, self._exec_group.param_arrays):

                logging.debug("%s: v-norm: %g"

                              %(k,

                                nd.norm(v[0]).asnumpy()[0]))



    def summary(self, level=2):

        """Summarize the network parameters.



        Parameters

        ----------

        level : int, optional

            Level of the summarization logs to print.

            The log becomes more verbose with higher summary level.

            - Level = 0

                Print the total param number + aux param number

            - Level = 1

                Print the shape of all parameters + the total number of parameters

            - Level = 2

                Print the shape of the data/state and other available information in Level 1

        """

        self.logger.info("Summary of %s" %self._name)

        assert self.binded and self.params_initialized

        assert 0 <= level <= 2, \

            "Level must be between 0 and 2, level=%d is not supported" % level



        def _log_var(key, value, typ="param"):

            if typ == "param":

                if k in self._fixed_param_names:

                    self.logger.info("   %s: %s, %d, req = %s, fixed"

                                     % (key,

                                        str(value.shape),

                                        np.prod(value.shape),

                                        self._exec_group.grad_req[k]))

                else:

                    self.logger.info("   %s: %s, %d, req = %s"

                                     % (key,

                                        str(value.shape),

                                        np.prod(value.shape),

                                        self._exec_group.grad_req[k]))

            elif typ == "data" or typ == "aux":

                self.logger.info("   %s: %s, %d"

                                 % (key,

                                    str(value.shape),

                                    np.prod(value.shape)))



        total_param_num = 0

        total_fixed_param_num = 0

        total_aux_param_num = 0

        if level >= 2:

            if len(self.data_names) == 0:

                self.logger.info("Data: None")

            else:

                self.logger.info("Data:")

                for k, v in zip(self.data_names, self.data_shapes):

                    _log_var(k, v, typ="data")

            if len(self._state_names) == 0:

                self.logger.info("State: None")

            else:

                self.logger.info("State:")

                for k in self._state_names:

                    v = self._exec_group.execs[0].arg_dict[k]

                    _log_var(k, v, typ="data")

        if level >= 1:

            if len(self._param_names) == 0:

                self.logger.info("Param: None")

            else:

                self.logger.info("Params:")

                for k in self._param_names:

                    v = self._arg_params[k]

                    _log_var(k, v)

                    if k in self._fixed_param_names:

                        total_fixed_param_num += np.prod(v.shape)

                    else:

                        total_param_num += np.prod(v.shape)

            if len(self._aux_names) == 0:

                self.logger.info("Aux States: None")

            else:

                self.logger.info("Aux States: ")

                for k in self._aux_names:

                    v = self._aux_params[k]

                    _log_var(k, v, typ="aux")

                    total_aux_param_num += np.prod(v.shape)

        else:

            for k in self._param_names:

                v = self._arg_params[k]

                total_param_num += np.prod(v.shape)

            for k in self._aux_names:

                v = self._aux_params[k]

                total_aux_param_num += np.prod(v.shape)

        self.logger.info("Total Param Num (exclude fixed ones): " + str(total_param_num))

        self.logger.info("Total Fixed Param Num: " + str(total_fixed_param_num))

        self.logger.info("Total Aux Param Num: " + str(total_aux_param_num))



    def get_output_dict(self):

        outputs = self.get_outputs()

        return OrderedDict([(k, v) for k, v in zip(self._output_names, outputs)])



    def clear_grad(self):

        assert self.binded and self.params_initialized and self.optimizer_initialized

        # clear the gradient

        for grads in self._exec_group.grad_arrays:

            for grad in grads:

                grad[:] = 0



    def save_tmp_grad(self):

        if self._tmp_grads is None:

            self._tmp_grads = []

            for grads in self._exec_group.grad_arrays:

                vec = []

                for grad in grads:

                    vec.append(grad.copyto(grad.context))

                self._tmp_grads.append(vec)

        else:

            for i, grads in enumerate(self._exec_group.grad_arrays):

                for j, grad in enumerate(grads):

                    self._tmp_grads[i][j][:] = grad



    def acc_grad_with_tmp(self):

        assert self._tmp_grads is not None

        for i, grads in enumerate(self._exec_group.grad_arrays):

            for j, grad in enumerate(grads):

                grad += self._tmp_grads[i][j]





    def load_params_allow_missing(self, fname):

        """Loads model parameters from file.



        Parameters

        ----------

        fname : str

            Path to input param file.



        Examples

        --------

        >>> # An example of loading module parameters.

        >>> mod.load_params('myfile')

        """

        logging.info("Load Param From %s" %fname)

        save_dict = mx.nd.load(fname)

        arg_params = {}

        aux_params = {}

        for k, value in save_dict.items():

            arg_type, name = k.split(':', 1)

            if arg_type == 'arg':

                if name in self._param_names:

                    logging.info("set %s" %name)

                    arg_params[name] = value

            elif arg_type == 'aux':

                if name in self._aux_names:

                    logging.info("set %s" % name)

                    aux_params[name] = value

            else:

                raise ValueError("Invalid param file " + fname)

        self.set_params(arg_params, aux_params, allow_missing=True)

 

Other links worth checking out

A Neural Weather Model for Eight-Hour Precipitation Forecasting
ai.googleblog.com/2020/03/a-neural-weather-model-for-eight-hour.html

Using Machine Learning to “Nowcast” Precipitation in High Resolution
ai.googleblog.com/2020/01/using-machine-learning-to-nowcast.html