[mxnet]核心接口(一)

视频来源:https://www.bilibili.com/video/BV1EW411n7b8?p=2
根据视频整理的文档


(一)mxnet 的核心接口

接口 功效
Context 指定运行设备
NDArray python与C++交互数据对象
DataIter 为训练提供batch数据
Symbol 定义网络
LR Scheduler 定义学习率衰减策略
Optimizer 优化器
Executor 图的前向计算与反向梯度推导
Metric 查看模型训练过程指标
Callback 回调函数
KVStore 跨设备的键值储存
Module ALL in one 将前面的模块封装

一、Context

定义方式:

  1. CPU Context:通过mxnet.cpu(0)定义,这里的设备id为0,会默认使用所有CPU核心

  2. GPU Context:通过mxnet.gpu(0)定义,这里的设备id,决定使用那块GPU设备,代码前后所

    [mx.gpu(0),mx.gpu(1)] 表示并行

  3. 使用要求整个代码中的context,必须保持一致,且如果使用GPU

对比MXNet和Tensorflow:

  1. tf.device("/gpu:1")=mx.gpu(device_id=1)
  2. `tf.device("/cpu:0")=mx.cpu(device_id=0)

二、NDArray

简单理解:一个同时支持CPU与GPU的Numpy,两者可以进行转换

为什么mxnet需要ndarray?

  • ndarray 之于mxnet 相当于 numpy 之于 tensorflow
  • mxnet 执行网络,输入的数据,与获取的节点数据对象都是ndarray,是客户端与底层C++交互的数据接口
  • mxnet使用python定义Op,其中使用的面向过程的计算,都是基于ndaray
  • mxnet 新出的实现动态狼罗的 Gluon API ,大量使用了ndarray提供操作

nd中的与np中的ndarry运算基本相同,只是ndarray设计到了不同设备上的运算、复制、移动等。

2.1 ndarray基本运算相同

import numpy as np
from mxnet import nd
np_array=np.arange(10,dtype=np.float32).reshape((2,5))
mx_array=nd.arange(10,dtype=np.float32).reshape((2,5))
print('np_array:\n',np_array)
print('mx_array:\n',mx_array)

2.2 线性代数上相似

print("np_array inner product:\n",np_array.dot(np_array.T))
print("mx_array inner product:\n",nd.dot(mx_array,mx_array.T))

2.3 Numpy与NDarry数组的相互转化

nd.array()asnumpy()

# numpy-->mxnet
np_ones=np.ones(shape=(2,2))
from_np_array=nd.array(np_ones,dtype=np.float32)
print(from_np_array)

# mxnet-->numpy
mx_ones=nd.ones(shape=(2,2),dtype=np.float32)
from_mx_array=mx_ones.asnumpy()
print(from_mx_array)

2.4 mxnet ndarray 在CPU和GPU上的执行

import mxnet as mx
# init mxnet ndarray on cpu
mx_cpu=nd.ones(shape=(2,2),ctx=mx.cpu(0))
# init mxnet ndarray on gpu
mx_gpu=nd.ones(shape=(2,2),ctx=mx.gpu(0))
print(mx_cpu+mx_gpu)

提示:会报错,因为在不同设备上进行运算。

2.5 Host与Device上的NDarray的移动操作

# init mxnet nadarray on cpu
mx_cpu=nd.ones(shape=(2,2),ctx=mx.cpu(0))
print("mx_cpu_context:",mx_cpu.context)
host_to_gpu=mx_cpu.as_in_context(mx.gpu(0))
print("host_to_gpu context",host_to_gpu.context)

# init mxnet ndarray on cpu
mx_cpu=nd.ones(shape=(2,2),ctx=mx.cpu(0))

# init mxnet nadarray on gpu
mx_gpu=nd.ones(shape=(2,2),ctx=mx.gpu(0))
print(mx_cpu.as_in_context(mx.gpu(0)+mx_gpu))

2.6 Host与Device上的NDarray的复制操作

mx_gpu_0=nd.ones(shape=(2,2),ctx=mx.gpu(0))
mx_gpu_1=nd.zeros(shape=(2,2),ctx=mx.gpu(1))
print("before copy:\n",mx_gpu_1)
mx_gpu_0.copyto(mx_gpu_1)
print("after copy:\n",mx_gpu_1)

三、mx.io.DataIter用于提供输入数据的接口

功能

MXNet中的所有I / O均由此类的专业化处理。 MXNet中的数据迭代器类似于Python中的标准迭代器。 在每次调用“ next”时,它们都会返回一个“ DataBatch”,该数据代表下一批数据。 当没有更多数据要返回时,它将引发StopIteration异常。

3.0 初步认识

mx.io.NDArrayIter(mnist['train_data'],mnist['train_label'],batch_size,shuttle=True)

将numpy的数据对象传入改为DataBatch对象,具体的迭代数据对象类型为ndarray

3.1 工作原理

  • 数据迭代器一个Epoch结束后,会抛出StopIteration异常
  • epoch结束后的callback函数的调用,就是依靠数据迭代器的StopIteration

3.2 自定义一个DataIter

# mnist 是一个(60000,1,28,28)的对象

class DemoDataIter(mx.io.DataIter):
    def __init__(self,batch_size,mnist):
        super(DemoDataIter,self).__init__(batch_size)
        self.idx=0
        self.sample_size=60000
        self.batch_size=batch_size
        self.mnist_data=mnist['train_data']
        self.mnist_label=mnist["train_label"]
        self.provide_data=[("data",(batch_size,1,28,28))]
        self.provide_label=[("label",(batch_size,))]
        
    def next(self):
        if self.idx+self.batch_size>self.sample_size:
            raise StopIteration
        data =self.mnist_data[self.idx:self.idx+self.batch_size]
        label=self.mnist_label[self.idx:self.idx+self.batch_size]
        data_batch=mx.io.DataBatch(data=[data],label=[label])
        self.idx+=self.batch_size
        return data_batch
    
    def reset(self):
        self.idx=0

3.3 利用异常进行epoch

iter=DemoDataIter(2,mnist)
epoch=0
while True:
    try:
        iter.next()
        except StopIteration:
            iter.reset()
            epoch+=1
            print("epoch num:" epoch)
            if epoch >=10 :break

3.4 rec文件的读取

help

mxnet 的高性能数据读取文件格式:rec,类似tensorflow的tfrecords,caffe的lmdb

  1. 目的都是为了将零碎的图片文件等(不限于图片)格式化为一个连续存储的二进制序列文件,加速训练
  2. 官方提供了适用于多数情况下,图片到rec文件的转换工具,Im2rec.py,如果自己从源码编译mxnet可以得到一个C++版本的im2rec可执行文件,但是速度上python和C++版本的速度差异不打,主要速度瓶颈在IO(图片的大小)

写rec文件,对原始数据(图片等)的要求:

  1. 图片数据,不同个图片的尺寸可以任意大小
  2. 不仅仅使用于图片,可以支持一般类型的数据,如numpy数组
  3. 可以支持多label,如检测任务中不同Bounding Box的位置信息,可以当作不同label进行写入

3.5 mxnet提供的常用rec格式数据迭代器

  1. ImageRecordIter

    一般用于分类任务图片读取与运行期间的数据增广的rec迭代器

  2. ImageDetRecordIter

    用于检测任务图片的读取与运行期间的数据增广的rec迭代器

  3. BucketSwntenceIter

    用于不定长序列数据的迭代,常用于RNN

案例1:ImageRecordIter

mx.io.ImageRecordIter(
    # 需要迭代的rec文件路径
    path_image        ='data.rec',
    
    label_width       =1,
    
    # r,g.b的均值,也可以增加标准差std
    mean_r            =123.68,
    mean_g            =116.779,
    mean_b            =109.939,
    
    data_name         ='data',
    label_name        ='softmax_label',
    data_shape        =(3,224,224),
    batch_size        =128,
    rand_crop         =True,
    
    # 数据的在线增广
    min_random_scale  =1,
    max_random_scale  =1,
    pad               =0,
    fill_vale         =1,
    max_aspect_ratio  =0, #[0,1]
    random_h          =0,#[0,180]
    random_s          =0,#[0,255]
    random_l          =0,#[0,255]
    max_rotate_angle  =0, #[90,360]
    max_shear_ratio   =0,#[0,1]
    rand_mirror       =True,
    preprocess_threads=8,
    shuffle           =True,
    num_parts         =nworker,#kvstore
    
    part_index        =rank  #kvstore)

案例2:ImageDetRecordIter

mx.io.ImageDetRecordIter(
    # 需要迭代的rec文件路径
    path_image        ='data.rec',
    
    label_width       =-1,#varibale  label size
    label_pad_width   =350.
    label_pad_value   =-1,
    
    # r,g.b的均值,也可以增加标准差std
    mean_r            =123.68,
    mean_g            =116.779,
    mean_b            =109.939,
    
    data_shape        =(3,224,224),
    batch_size        =128,
    
    # 数据增广功能
    shuttle           =True,
    random_hue_prob   =0.5,
    max_random_hue    =18,
    random_saturation_prob=0.5,
    max_random_saturation =32,
    random_illumination_prob=0.5
    max_random_illumination=32,
    random_contrast_prob   =0.5,
    max_random_contrast    =0.5
    
     #...
)

四、symbol

symbol:用于符号式编程的接口

4.1 符号式的定义只是构建了图的结构,没有立即执行

# 命令式编程
data=nd.ones(shape=(1,2),dtype=np.float32)
weight=nd.random.normal(shape=(12,2))
bias=nd.random.normal(shape=(12,))
fc=nd.FullyConnected(data,weight=weight,bias=bias,num_hidden=12)
print("imperative:",fc)
# 符号式编程
data=mx.sym.Variable('data',shape=(1,2))
fc=mx.sym.FullyConnected(data=data,num_hidden=12)
print('symbolic:',fc)

out:

imperative: 
[[-2.5673738  -2.8626602  -1.3854902  -1.8239176   2.4033828  -3.8420382
  -4.0500565   0.43445504 -1.5064126   0.4659688  -0.5043662  -2.2939835 ]]

symbolic: 

提示

  1. mx.sym.FullyConnected(data=data,num_hidden)=12中隐含着自动定义了weight和bias,不用显示性的定义
  2. fc的打印结果为symbol,注意到并没有执行这个数据,同时data中也没有实际的数据

4.2 symbol的基本函数-获取symbol

  • symbol.infer_type:推导当前symbol所依赖的所有的symbol的数据类型
  • symbol.infer_shape:推导当前symbol所依赖的所有symbol的形状
  • symbol.list_arguments:列出当前symbol所用到的基本参量的名称
  • symbol.list_outputs:列出当前symbol的输出名称
  • symbol.list_auxiliary_states:列出当前symbol的辅助参量名称
  1. arguments=输入数据symbol+权值参数symbol
  2. auxiliary_states=辅助symbol,比如BN中的gamma和beta
  3. 我们只要拿到最终的一个symbol,就可以查看它所依赖的所有的symbol。参数需要初始化,数据需要迭代器放入
  4. 对于推导数据的类型,我们只需要告知输入数据的类型即可,如数据类型np.float32
  5. 对于推导数据的形状,我们只需要告知输入数据的形状即可,如形状data=(1,2)

4.3 symbol几个基本函数输出的关系

infer_type和infer_shape输出都是一个具有三个元素的元组,这三个元素分别对应基本参量、输出和辅助参量

# fc在前面
print("name:",(fc.list_arguments(),fc.list_outputs(),fc.list_auxiliary_states()))
print("type:",fc.infer_type(data=(np.float32,np.float32)))
print("shape:",fc.infer_shape(data=(1,2)))

out:

name: (['data', 'fullyconnected0_weight', 'fullyconnected0_bias'], ['fullyconnected0_output'], [])
type: ([, , ], [], [])
shape: ([(1, 2), (12, 2), (12,)], [(1, 12)], [])

至此我们可以获取图中每个symbol的相关信息,如相撞,数据类型等

name:

infer_type:推导类型

infer_shape:推导形状

以上三者一一对应。

4.4 查看结构,利用json文件

print(fc.tojson())
{
  "nodes": [
    {
      "op": "null", 
      "name": "data", 
      "attrs": {"__shape__": "(1, 2)"}, 
      "inputs": []
    }, 
    {
      "op": "null", 
      "name": "fullyconnected0_weight", 
      "attrs": {"num_hidden": "12"}, 
      "inputs": []
    }, 
    {
      "op": "null", 
      "name": "fullyconnected0_bias", 
      "attrs": {"num_hidden": "12"}, 
      "inputs": []
    }, 
    {
      "op": "FullyConnected", 
      "name": "fullyconnected0", 
      "attrs": {"num_hidden": "12"}, 
      "inputs": [[0, 0, 0], [1, 0, 0], [2, 0, 0]]
    }
  ], 
  "arg_nodes": [0, 1, 2], 
  "node_row_ptr": [0, 1, 2, 3, 4], 
  "heads": [[3, 0, 0]], 
  "attrs": {"mxnet_version": ["int", 10600]}
}

[mxnet]核心接口(一)_第1张图片

4.5 symbol如何执行,并获取节点输出值

data=mx.sym.Variable('data',shape=(1,2))
weight=mx.sym.Variable('weight',shape=(1,2))
bias=mx.sym.Variable('bias',shape=(1.2))
fc=mx.sym.FullyConnected(data=data,weight=weight,bias=bias,num_hidden=12)

executor=fc.bind(ctx=mx.cpu(),args={'data':mx.nd.ones([1,2]),
'weight':nd.random.normal(shape=(12,2),
'bias':nd.random.normal(shape=(12,)

executor.forward()
print(executor.outputs[0].asnumpy())

注意:

  • 在这里 weight,bias 需要显示的初始化,显示性的原因在于当下没有学习到自动初始化的高阶函数
  • 对于神经节点层 fc 来说,有三个输入:data,weight,bias
  • bind ,意味捆绑
  • fc.bind()是将流程图数据捆绑在一起
    ctx是指运行的设备环境,args是绑定的具体数据,注意到这里绑定的数据是根据名称去查找的,而不是根据变量名
  • 所有的输入数据的对象为ndarray
  • 在执行完forward后,调用outputs属性得出ndarray的结果,然后转化层numpy
  • outputs[0],显然在fc层有12个神经元,而它是第一个神经元输出的结果

4.6 symbol 如何获取中间节点

data
fc1
fc2

如果当下只有fc2这个变量,不存在fc1这个变量,那么如何获取中间变量呢?

data=mx.sym.Variable('data')
fc1=mx.sym.FullyConnected(data=data,num_hidden=12,name='fc1')
fc2=mx.sym.FullyConnected(data=fc1,num_hidden=12,name='fc2')

print(fc2.get_internals().list_outputs())
print(type(fc2.get_internals()['fc1_output'])) # also a  symbol class

out:

['data', 'fc1_weight', 'fc1_bias', 'fc1_output', 'fc2_weight', 'fc2_bias', 'fc2_output']

提示:
fc2.get_internals()获取fc2的所依赖的所有symbol变量
list_outputs()是列出当前symbol的所有变量名称,如上out所示
现如今要获取fc1的类型,即找到其变量名称[‘fc1_output’],再对这个变量进行type()
我们定义fc1的名称为fc1之后在查看这个节点时,程序自动加了后缀_output,变为fc1_output

通过此方式,可以获取程序中的任意一个节点,从而可以获取节点的相关信息

4.7图的拼接

可以在图的尾部补上额外的symbol节点,但是在原始图的头部替换输入节点比较傲困难
案例
图1:数据结构如下,并保存为json格式

data
fc1
fc2
softmax
data=mx.sym.Variable("data",shape=(1,2))
weight=mx.sym.Variable("weight",shape=(1,2))
bias=mx.sym.Variable("bias",shape=(1,2))

fc1=mx.sym.FullyConnected(data=data,weight=weight,bias=bias,
                          num_hidden=12,name='fc1')
fc2=mx.sym.FullyConnected(data=fc1,weight=weight,bias=bias,
                          num_hidden=12,name='fc2')
softmax=mx.sym.SoftmaxOutput(fc2,name='softmax')
softmax.save('model.symbol.json')

图2:数据结构如下

data
fc1
fc2_new
softmax_new

fc1来源与图1中的fc1节点,而fc2_new和softmax_new是重新定义的节点。算法如下:

symbol=mx.sym.load('model.symbol.json')
print(symbol.get_internals().list_outputs())
fc1=symbol.get_internals()['fc1_output']
fc2_new=mx.sym.FullyConnected(data=fc1,weight=weight,bias=bias
                              ,num_hidden=99,name='fc2_new')
softmax_new=mx.sym.SoftmaxOutput(fc2_new,name='softmax_new')
print(softmax_new.get_internals().list_outputs)

out:

['data', 'weight', 'bias', 'fc1_output', 'fc2_output', 'softmax_label', 'softmax_output']
>

5.Executor

用于执行图的接口
负责图的前向计算和方向梯度推导

5.1 executor的说明

  1. 基本Executor:mxnet.executor.Executor,可类比于tf.Sesseion,当symbol 绑定 了Executor后,当前executor对应的图就不能再做更改了,与其他静态图框架相同
  2. 用于数据并行的Executor:mxnet.executor_group.DataParallelExecutorGroup官方解释:A group of executors that lives on a group of devices.This is a helper class used to implement data parallelization .Each mini-batch will be split and run on the devices.简单理解就是包装了在不同设备上的Executor, 使其可以协作完成不同设备上的并行训练。

如4.5中所示的executor

5.2 executor执行栗子

hidden_num=4
data=mx.sym.Variable("data")
weight=mx.sym.Variable("weight")
bias=mx.sym.Variable("bias")

fc=mx.sym.FullyConnected(data=data,weight=weight,bias=bias,
                          num_hidden=hidden_num,name='fc')
softmax=mx.sym.SoftmaxOutput(fc,name='softmax')

executor=sotfmax.blind(ctx=mx.cpu(),
                      args={'data':mx.nd.ones([1,2]),
                            'softmax_label':mx.nd.ones((1,)),
                            'weight':nd.random.normal(shape=(hidden_num,2)),
                            'bias':nd.random.normal(shape=(hidden_num,))},
                      args_grad={'weight':mx.nd.zeros((hidden_num,2)),
                                'bias':mx.nd.zeros((hidden,))})

executor.forwards(is_train=True)
print('output:',executor.output_dict['softmax_output'].asnumpy())
executor.backward(
print('weight_grad:\n',executor.grad_dict['weight'].asnumpy()))
print('bias_grad:\n',executor.grad_dict['bias'].asnumpy())

注意:
1.executor.forwards(is_train=True)参数是会分配梯度空间以及保存推导过程,(存储中间值)
2. 整个途中,context都须保持一致

6.Metric

用于衡量模型效果的接口

6.1计算分类任务的正确率的例子

class Accuracy(mx.metric.EvalMetric):
    def __init__(self,axis=1,name='accuracy',output_names=None,label_names=None):
        super(Accuracy,self).__init__(
        name,axis=axis,output_names=output_names,label_names=label_names)
        self.axis=axis
        
    def updata(self,labels, preds):
        for label,pred_label in zip(labels,preds):
            if pred_label.shape!=label.shape:
                pred_label=mx.nd.argmax(pred_label,axis=self.axis)
            pred_label=pred_label.asnumpy().astype('int32')
            label=label.asnumpy().astype('int32')
            
            self.sum_metric+=(pred_label.flat==label.flat).sum()
            self.num_inst+=len(pred_label.flat)

假设最后一个节点为softmax节点,计算它的精确度(accuracy)的话。label是实际标签,pred是预测标签。但是有可能会存在多分类的情况。也就是labels和preds.

class mx.metric.EvalMetric(object):
    def get(self):
        if self.num_inst==0:
            return(self.name,float('nan'))
        else:
            return(self.name,self.sum_metric/self.num_inst)

6.2 Metric Hack分析

  1. 需要继承mx.metric.EvalMetric接口,重写update方法
  2. update传入的参数分析
    a. labels:list 类型,每个元素对应DataBatch中的label
    b. predicts:list类型,是Loss Symbol中的label外的输入,因此list的个数与网络上loss的个数有关
  3. update函数需要完成:
    a. 更新属性sum_metric和num_ints的值,mxnet会调用num_inst/sum_metric 来计算当前metric的输出值
    b. 与一个特殊的Callback类有关:Speedometer,Speedometer会打印所有的metric的值

7. Callback

用于模型训练过程中的回调接口
### 7.1 统计模型训练速度的callback例子
可以统计每秒钟处理的样本数量

class Speedmeter:
    def __init__(self, batch_size, frequent=50):
        self.bach_size=batch_size
        self.frequent=frequent
        self.init=False
        self.tic=0
        self.last_cont=0
        
    def __call__(self,param):
        cont=param.nbatch
        if self.last_cont>cont:
            self.init=False
        self.last_cont=cont
        if self.init:
            if cont% self.frequent==0:
                speed=self.frequent * self.batch_size/(time.time()-self.tic)
                logging.info('Iter[%d] Batch [%d] \tSpeed: %.2f samples/sec',
                            param.epoch,cont,speed)
        else:
            self.init=True\
            self.tic=time.time()

### 7.2 callback hack 分析

  1. 只要是callable对象即可,但通常采用实现了__call__方法的类 函数,实现更复杂的功能
  2. 分为两类:
    a. epoch结束后的回调函数,如 用于保存模型的回调函数
    b. 训练一个Batch后的回调函数,如 用于统计训练速度的回调函数
  3. mxnet给Callback函数传入什么可使用字段?
    [mxnet]核心接口(一)_第2张图片
    [mxnet]核心接口(一)_第3张图片

8. LR Scheduler

用于指定模型训练过程学习率衰减策略的接口

8.1 学习率阶梯式衰减策略例子

class FactorScheduler(mx.lr_scheduler.LRScheduler):
    def __init__ (self, step, factor=1, stop_factor_lr=1e-8):
        super(FactorScheduler, self).__init__()
        self.step = step
        self.factor = factor
        self.stop_factor_lr = stop_factor_lr
        self.count = 0
    def __call__(self, num_update):
        while num_update > self.count + self.step:
            self.count += self.step
            self.base_lr *= self.factor
            if self.base_lr < self.stop_factor_lr:
                self.base_lr = self.stop_factor_lr
                logging.info( "Update[%d]: learning rate arrived at %0.5e",
                             num_update,self.base_lr)
            else:
                logging.info( "Update[%d]: Change learning rate to 0.5e",
                         num_update, self.base_lr)
        return self.base_lr

8.2

[mxnet]核心接口(一)_第4张图片

9. KVStore

用于跨device的数据操作借口,可以理解为参数服务器
三个基本函数

  • init:在KVStore中初始化数据
  • pull:从KVStore中把数据拿出俩,可以向多个device分发数据
  • push:将数据更新到KVStore中去,可以从多个的车ive收集数据

[mxnet]核心接口(一)_第5张图片

10.Optimizer

Optimizer :用于使用梯度更新权值参数的接口
MXNet的Optimizer肩负什么责任:

  1. 调用LR Scheduler获取当前"基础”学习率,然后根据用户设置的不同层的学习率乘子( Ir _mutl )计算不同权值参数对应的最终学习率
  2. 进行正则化:根据初始化的正则化系数,以及用户设置的不同层的正则化乘子( wd_mult )计算最终每个权值参数的正则化系数
  3. 给梯度做rescale (因为MXNet的梯度没有除上batch size )

最后, Optimizer会被KVStore或Updater调用,传入前向计算出的参数值和反向计算出的梯度值等(NDArray对象),由Optimizer根据上述计算出的Ir , weight decayit算权值参数的更新梯度值,完成次参数迭代更新
以一眼蔽之,optimizer可以对梯度进行任意操作后更新梯度。给用户提供了很强的hack能力

11. Module

11.1 简介

  1. 是MXNet中集大成的接口 ,将几乎所有的模块封装成一个可以一 步完成训
    练和测试的接口,方便用户训练与测试模型
  2. 所有的Module都继承 了BaseModule,MXNet官方提供的两个常用实现
    mx.mod.Module ,mx.mod.BucketingModule
  3. 可以自定义Module ,如官方example RCNN中为了适应不同尺寸大小的输入,定义了MutableModule

11.2 工作

以mx.mod.Module为例,简要分析都完成了什么

  1. 将Symbol绑定Executor ,使当前图可以被执行。如果是多卡并行则会绑定DataParallelExecutorGroup自动切分数据成几份,分 别送入不同的device进行训练
  2. 初始化图中的权值参数 ,首先restore用户提供的模型参数,没有提供的参数则采取随机初始化
  3. 初始化Optimizer (创建Optimizer时会初始化LR Scheduler )
  4. 创建EvalMetric
  5. 开始读取数据进行训练 ,训练每个Batch后,执行Batch End Callback函数,一个Epoch结束后,执行Epoch End Callback函数
  6. 由于BucketingModule用于 变长的时序训练数据,因此BucketingModule会根据DataBatch中的提供的bucket key,去决定是否生成新的Module,维护bucket key与Module的一对应关系 ,并共享不同Module之间的参数,实现RNN模型的训练

你可能感兴趣的:(深度学习,#,mxnet)