获取smbol的内部节点(获取部分网络结构)

首先定义一个网络

data = mx.sym.Variable('data')
fc1 = mx.sym.FullyConnected(data=data, name='fc1', num_hidden=1000)
act = mx.sym.Activation(data=fc1, act_type='relu',name='act')
fc2 = mx.sym.FullyConnected(data=act, name='fc2', num_hidden=10)
net = mx.sym.SoftmaxOutput(fc2,name="softmax")
net.save('model.symbol.json')

将该网络通过save序列化保存为json文件。 # mx.sym.Symbol 类型自带save函数

载入json,获取内部节点

net = mx.sym.load('model.symbol.json')
net.get_internals  #  获取所有内部节点,将其group起来
<Symbol group [data, fc1_weight, fc1_bias, fc1, act, fc2_weight, fc2_bias, fc2, softmax_label, softmax]>
net.get_internals().list_outputs()  #  获取所有内部节点的输出
['data',
 'fc1_weight',
 'fc1_bias',
 'fc1_output',
 'act_output',
 'fc2_weight',
 'fc2_bias',
 'fc2_output',
 'softmax_label',
 'softmax_output']
new_net = net.get_internals()['act_output'] #  通过key value方式,获取网络的内部节点数据,后续可以接着对new_net进行操作,构建新的网络

对于var,其list_outputs()list_inputs()都是他自己,对于其他op,比如+,-,relu等,其输出要在该op的name后面加上_output后缀

你可能感兴趣的:(mxnet-symbol)