定义一个ResNet风格的网络（以ResNet-18为蓝本，但这里每个阶段包含4个残差块，并在输出层之前加了一个带BatchNorm和Dropout的1024维全连接层）
import gluonbook as gb
from mxnet.gluon import Trainer,data as gdata, nn
from mxnet import init, nd
import os
import sys
class Residual(nn.Block):  # This class is also saved in the gluonbook package for later reuse.
    """Basic residual unit: two 3x3 convolutions with batch norm plus a shortcut.

    The main path is conv1 -> bn1 -> relu -> conv2 -> bn2; the shortcut is the
    input itself, or a 1x1 convolution of it when ``use_1x1conv`` is True.
    The two are summed and passed through a final ReLU.

    Args:
        num_channels: Number of output channels of every convolution in the unit.
        use_1x1conv: If True, send the shortcut through a 1x1 convolution with
            ``num_channels`` outputs and the same stride as ``conv1`` so it can be
            added to the main path.
        strides: Stride of the first 3x3 convolution (and of the 1x1 shortcut).
        **kwargs: Forwarded unchanged to ``nn.Block.__init__``.
    """

    def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
        super(Residual, self).__init__(**kwargs)
        # NOTE: attribute assignment order registers the children, and the
        # child names appear in saved parameter keys -- keep this order.
        self.conv1 = nn.Conv2D(num_channels, kernel_size=3, padding=1,
                               strides=strides)
        self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2D(num_channels, kernel_size=1,
                                   strides=strides)
        else:
            # No projection: the input is added to the main path unchanged.
            self.conv3 = None
        self.bn1 = nn.BatchNorm()
        self.bn2 = nn.BatchNorm()

    def forward(self, X):
        """Return relu(main_path(X) + shortcut(X))."""
        Y = nd.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        # Compare against None explicitly rather than relying on the
        # truthiness of a Block object.
        if self.conv3 is not None:
            X = self.conv3(X)
        return nd.relu(Y + X)
def resnet_block(num_channels, num_residuals, first_block=False):
    """Build one ResNet stage as a Sequential of ``num_residuals`` Residual units.

    Unless ``first_block`` is True, the first unit uses ``strides=2`` and a 1x1
    convolution shortcut, so the stage halves the spatial size and switches to
    ``num_channels`` channels. All remaining units keep shape unchanged.

    Args:
        num_channels: Output channels of every unit in the stage.
        num_residuals: Number of Residual units in the stage.
        first_block: If True, the first unit is a plain (stride-1, identity
            shortcut) unit. In this file the first stage receives the 64-channel
            output of the stem, so no projection is needed there.

    Returns:
        An ``nn.Sequential`` containing the units in order.
    """
    blk = nn.Sequential()
    for i in range(num_residuals):
        if i == 0 and not first_block:
            # Downsampling unit: the projection shortcut matches the new shape.
            blk.add(Residual(num_channels, use_1x1conv=True, strides=2))
        else:
            blk.add(Residual(num_channels))
    return blk
net = nn.Sequential()

# Stem: strided 11x11 convolution, batch norm, ReLU, then 3x3 max-pooling.
net.add(nn.Conv2D(64, kernel_size=11, padding=3, strides=2),
        nn.BatchNorm(),
        nn.Activation('relu'),
        nn.MaxPool2D(pool_size=3, strides=2, padding=1))

# Body: four residual stages of 4 units each; only the first keeps resolution.
net.add(resnet_block(64, 4, first_block=True),
        resnet_block(128, 4),
        resnet_block(256, 4),
        resnet_block(512, 4))

# Head: global average pooling, a 1024-unit hidden layer with
# BN/ReLU/dropout, and a 10-way output layer.
net.add(nn.GlobalAvgPool2D(),
        nn.Dense(1024),
        nn.BatchNorm(),
        nn.Activation('relu'),
        nn.Dropout(0.4),
        nn.Dense(10))
随机初始化网络，并使用Block类自带的成员函数save_parameters()将网络参数保存到文件：
# Device to initialize on: gluonbook's try_gpu() gives a GPU context when one
# is usable and falls back to the CPU otherwise.
ctx = gb.try_gpu()
# Randomly (re)initialize every parameter with Xavier initialization.
net.initialize(force_reinit=True, ctx=ctx, init=init.Xavier())
# None of the layers declare their input size, so Gluon defers allocating the
# parameters until the first forward pass. Saving before that raises
# DeferredInitializationError, so push one dummy batch through the network first.
# NOTE(review): a 1x1x96x96 (batch, channel, height, width) input is assumed
# here; only the shape of conv 0's weight depends on the channel count -- match
# it to the real data.
net(nd.zeros((1, 1, 96, 96), ctx=ctx))
new_filename = 'tmp.params'
# Write every parameter of `net` to disk.
net.save_parameters(new_filename)
然后把刚才保存的参数文件读回来，分析其中保存的结构：
# Load the saved parameter file back for analysis. Reuse `new_filename`
# instead of repeating the literal so the two steps cannot drift apart.
params = nd.load(new_filename)
# A dict of arrays was saved, so the loaded result is expected to be a dict.
print(isinstance(params, dict))
# Print the name (key) of every saved parameter.
for key in params:
    print(key)
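nd.load的结果是一个字典，打印出的key形如：0.weight、0.bias、1.gamma、1.beta、1.running_mean、1.running_var、4.0.conv1.weight、4.0.conv1.bias、……、5.0.conv3.weight、……、13.bias。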
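从字典的key可以看出key的组成规则：它由从顶层网络到该参数所经过的各级子block的名称，加上参数自身的名称，依次用“.”连接而成。在Sequential中，子block的名称是它的序号（0、1、2……）；在自定义block中，子block的名称是它的属性名（如conv1）；参数名如weight、bias。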
每个block的_reg_params以字典的形式保存该层自身直接注册的参数（不包括子block的参数），字典的值是Parameter对象（保存时才通过_reduce()转换成NDArray）。例如conv2d层的_reg_params为：
conv2d._reg_params = {'weight': Parameter, 'bias': Parameter}
下面看block.save_parameters()这个函数如何把block对象的参数保存成上面的样子
def _collect_params_with_prefix(self, prefix=''):
    """Recursively collect the parameters of block ``self`` and all its children.

    This is a Block member function. Blocks can be nested arbitrarily deep, so
    the tree is walked depth-first. Each parameter is keyed by the '.'-joined
    path of child names leading to it, followed by its own name. In a
    Sequential container the child names are the positions '0', '1', '2', ...;
    in a custom block they are the names under which the children were
    registered (e.g. 'conv1').

    Args:
        self: The block to collect from; it must expose ``_reg_params`` (its own
            parameters) and ``_children`` (its sub-blocks) as dicts.
        prefix: Dotted path of ``self`` relative to the root block; '' for the
            root itself.

    Returns:
        A dict mapping each full dotted name to the corresponding parameter.
    """
    if prefix:
        prefix += '.'
    # Parameters registered directly on this block.
    ret = {prefix + key: val for key, val in self._reg_params.items()}
    # Recurse into each child with the extended prefix and merge its
    # parameters into the result (dict.update concatenates the two dicts).
    for name, child in self._children.items():
        ret.update(child._collect_params_with_prefix(prefix + name))
    return ret
def save_parameters(self, filename):
    """Save all parameters of block ``self`` and its descendants to ``filename``.

    Keys in the saved file are the dotted structural names produced by
    ``self._collect_params_with_prefix()``.

    Args:
        self: The block whose parameters are saved.
        filename: Path of the output file.
    """
    # Gather {dotted_name: Parameter} for the whole block tree.
    params = self._collect_params_with_prefix()
    # Turn each Parameter into a single NDArray. Presumably Parameter._reduce
    # gathers the data onto the CPU -- confirm against the MXNet version in use.
    arg_dict = {key: val._reduce() for key, val in params.items()}
    # Save the dict of arrays via mxnet.ndarray, imported in this file as `nd`
    # (the bare name `ndarray` is not defined here).
    nd.save(filename, arg_dict)