MXNet支持使用命令式和符号式两种混合编程。
一般大家写的基本都是命令式编程,代码很直观,易于调试,但相比较符号式编程在效率上要差点,符号式编程看起来稍微复杂点,类似于建房子的图纸,先把整个架构给画出来,然后将整个流程bind绑定到执行器,最后给定输入,执行器运行即可。命令式编程属于交互式的,而符号式编程是非交互式的。
import mxnet as mx
from mxnet import nd
A=nd.array([1,2,3])
B=nd.array([4,5,6])
C=nd.array([2,3,4])
Q=(A+B)*C
print(Q)
'''
[10. 21. 36.]
'''
声明变量为Symbol类型:
SA=mx.sym.Variable('A')
SB=mx.sym.Variable('B')
SC=mx.sym.Variable('C')
SQ=(SA+SB)*SC
EX=SQ.bind(ctx=mx.cpu(), args={'A':nd.array([1,2,3]), 'B':nd.array([4,5,6]),'C':nd.array([2,3,4])})
print(EX.forward())
#Variable可以简写为var,如SA=mx.sym.var('A')
'''
[
[10. 21. 36.]
]
'''
符号式编程用到的是MXNet框架中的Symbol接口,定义好计算图之后调用bind方法运行并返回,其中这个方法可以指定到CPU或GPU上进行计算。由于在命令式编程中,因为不确定变量后面是否还需要用到,所以会一直占用着内存,而符号式编程,首先就已经定义好了计算图,所以能够知道有些变量只是临时的,这样有些内存空间可以共享,这点也是它们最主要的区别。
import mxnet as mx
from mxnet import nd
data=mx.sym.Variable('data')#定义输入数据
conv=mx.sym.Convolution(data=data, num_filter=100, kernel=(3,3), pad=(1,1), name='conv1')#卷积层
bn=mx.sym.BatchNorm(data=conv, name='bn1')#批标准化层
relu=mx.sym.Activation(data=bn, act_type='relu', name='relu1')#激活函数relu的激活层
pool=mx.sym.Pooling(data=relu, kernel=(2,2), stride=(2,2), pool_type='max', name='pool1')#最大池化层
fc=mx.sym.FullyConnected(data=pool, num_hidden=3, name='fc1')#全连接层,分类为3
sym=mx.sym.SoftmaxOutput(data=fc, name='softmax')#最后得到损失函数输出层
#Symbol对象参数列表
print(sym.list_arguments())
#['data1', 'conv1_weight', 'conv1_bias', 'bn1_gamma', 'bn1_beta', 'fc1_weight', 'fc1_bias', 'softmax_label']
arg_shape,out_shape,aux_shape=sym.infer_shape(data=(1,3,10,10))
print('参数形状:',arg_shape)
print('输出形状:',out_shape)
print('辅助层的形状:',aux_shape)#如BN层
'''
参数形状: [(1, 3, 10, 10), (100, 3, 3, 3), (100,), (100,), (100,), (3, 2500), (3,), (1,)]
输出形状: [(1, 3)]
辅助层的形状: [(100,), (100,)]
'''
#截取输入到池化层的信息
sym1=sym.get_internals()['pool1_output']
print(sym1.list_arguments())#['data1', 'conv1_weight', 'conv1_bias', 'bn1_gamma', 'bn1_beta']
#在这个后面新增3个全连接层
fc1=mx.sym.FullyConnected (data=sym_mini, num_hidden=3, name='fc1')
sym2=mx.sym.SoftmaxOutput (data=fc1, name='softmax')
print(sym2.list_arguments())
#['data1', 'conv1_weight', 'conv1_bias', 'bn1_gamma', 'bn1_beta', 'fc1_weight', 'fc1_bias', 'softmax_label']
#训练模型
data=mx.nd.random.uniform(0,1,shape=(100,3,224,224))
label=mx.nd.round(mx.nd.random.uniform(0,1,shape=(100)))
train_data = mx.io.NDArrayIter(data={'data':data},label={'softmax_label':label},batch_size=8,shuffle=True)
print(train_data.provide_data)#[DataDesc[data,(8, 3, 224, 224),,NCHW]]
print(train_data.provide_label)#[DataDesc[softmax_label,(8,),,NCHW]]
import logging
mod=mx.mod.Module(symbol=sym,context=mx.cpu())
logger=logging.getLogger()
logger.setLevel(logging.INFO)
mod.fit(train_data=train_data, num_epoch=5)
'''
INFO:root:Epoch[0] Train-accuracy=0.471154
INFO:root:Epoch[0] Time cost=10.652
INFO:root:Epoch[1] Train-accuracy=0.701923
INFO:root:Epoch[1] Time cost=25.408
INFO:root:Epoch[2] Train-accuracy=0.884615
INFO:root:Epoch[2] Time cost=10.553
INFO:root:Epoch[3] Train-accuracy=0.903846
INFO:root:Epoch[3] Time cost=9.904
INFO:root:Epoch[4] Train-accuracy=1.000000
INFO:root:Epoch[4] Time cost=10.006
'''
将常用的训练操作都封装在了fit方法里面,使用起来很方便,可以看到默认优化器是随机梯度下降法(sgd),学习率0.01
fit(train_data, eval_data=None, eval_metric='acc', epoch_end_callback=None, batch_end_callback=None, kvstore='local', optimizer='sgd', optimizer_params=(('learning_rate', 0.01),), eval_end_callback=None, eval_batch_end_callback=None, initializer=
, arg_params=None, aux_params=None, allow_missing=False, force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None, validation_metric=None, monitor=None, sparse_row_id_fn=None)
从代码训练结果也可以看出fit()方法封装了bind、初始化参数、选择优化器、前向计算、反向传播、参数更新和计算评价指标等操作,以及显示训练结果操作,对于如何实现,有兴趣的可以继续往下看:
mod=mx.mod.Module(symbol=sym,context=mx.cpu())
mod.bind(data_shapes=train_data.provide_data,label_shapes=train_data.provide_label)
mod.init_params()
mod.init_optimizer()
eval_metric=mx.metric.create('acc')#评价方式:精度,也可以自定义评价函数
for epoch in range(5):
end_of_batch=False
eval_metric.reset()
data_iter=iter(train_data)
next_data_batch=next(data_iter)
while not end_of_batch:
data_batch=next_data_batch
mod.forward(data_batch)
mod.backward()
mod.update()
mod.update_metric(eval_metric, labels=data_batch.label)
try:
next_data_batch=next(data_iter)
mod.prepare(next_data_batch)
except StopIteration:
end_of_batch=True
eval_name_vals=eval_metric.get_name_value()
print("Epoch:{} Train_Acc:{:.4f}".format(epoch, eval_name_vals[0][1]))
arg_params, aux_params=mod.get_params()
mod.set_params(arg_params, aux_params)
train_data.reset()
'''
Epoch:0 Train_Acc:0.4327
Epoch:1 Train_Acc:0.7115
Epoch:2 Train_Acc:0.8654
Epoch:3 Train_Acc:1.0000
Epoch:4 Train_Acc:1.0000
'''