MXNet: Loading and Saving Models

1. Loading a model whose network matches the pretrained model

# loading a pretrained checkpoint into a Module whose network matches it;
# get_iterators, Multi_mnist_iterator, MAE_zz, batch_size, device, lr,
# num_epochs and checkpoint are defined elsewhere in the original script
from collections import namedtuple
import mxnet as mx

data_shape_G = 96                        # input image side length
Batch = namedtuple('Batch', ['data'])    # lightweight batch for prediction

# load the saved symbol and the trained weights from epoch 2
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix=r"~/meh_cla", epoch=2)

train, val = get_iterators(batch_size=batch_size, data_shape=(3, 96, 96))
train = Multi_mnist_iterator(train)
val = Multi_mnist_iterator(val)

model = mx.mod.Module(      # rebuild the module around the loaded symbol
    symbol=sym,
    context=device,
    data_names=['data'],
    label_names=['softmax1_label', 'softmax2_label', 'softmax3_label']  # same network structure
)
model.bind(data_shapes=train.provide_data, label_shapes=train.provide_label)
# copy the loaded weights into the module before resuming training
model.set_params(arg_params, aux_params, allow_missing=True)
model.fit(train, val,
          optimizer_params={'learning_rate': lr, 'momentum': 0.9},
          num_epoch=num_epochs,
          eval_metric=MAE_zz(name="mae"),
          batch_end_callback=mx.callback.Speedometer(batch_size, 2),
          epoch_end_callback=checkpoint)
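
The Batch namedtuple and data_shape_G defined above are the usual ingredients for running inference once training is done. A minimal prediction sketch (an assumption, not code from the original post; the random input stands in for a real preprocessed image):

# bind a separate, label-free module for prediction (sketch; shapes and
# the random input below are illustrative placeholders)
pred_mod = mx.mod.Module(symbol=sym, context=device,
                         data_names=['data'], label_names=None)
pred_mod.bind(data_shapes=[('data', (1, 3, data_shape_G, data_shape_G))],
              for_training=False)
pred_mod.set_params(arg_params, aux_params, allow_missing=True)

img = mx.nd.random.uniform(shape=(1, 3, data_shape_G, data_shape_G))
pred_mod.forward(Batch([img]))
print([out.shape for out in pred_mod.get_outputs()])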

2. Loading a model whose network differs from the pretrained model
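When the new network is not identical to the checkpointed one, the usual approach is to truncate the loaded symbol at an internal layer, attach new layers, and copy over only the weights whose names still match. The following is a minimal sketch, not the original post's code; the prefix 'mx_mlp', the layer name 'fc1_output', and all shapes are illustrative assumptions:

import mxnet as mx

# load the checkpointed symbol and weights (prefix/epoch are placeholders)
sym, arg_params, aux_params = mx.model.load_checkpoint('mx_mlp', 5)

# truncate the loaded graph at an internal layer and attach a new head
internals = sym.get_internals()
backbone = internals['fc1_output']       # hypothetical internal layer name
net = mx.sym.FullyConnected(data=backbone, num_hidden=10, name='fc_new')
net = mx.sym.SoftmaxOutput(data=net, name='softmax')

# keep only the saved weights whose names still exist in the new symbol
new_args = {k: v for k, v in arg_params.items() if k in net.list_arguments()}
new_aux = {k: v for k, v in aux_params.items() if k in net.list_auxiliary_states()}

mod = mx.mod.Module(symbol=net, context=mx.cpu())
mod.bind(data_shapes=[('data', (32, 784))],
         label_shapes=[('softmax_label', (32,))])
# matched weights come from the checkpoint; the new fc_new layer falls
# back to the initializer because of allow_missing=True
mod.init_params(initializer=mx.init.Xavier(), arg_params=new_args,
                aux_params=new_aux, allow_missing=True)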
3. Saving a model

# use a checkpoint callback to save the parameters once after every epoch;
# construct the callback function, then pass it to fit
model_prefix = 'mx_mlp'
checkpoint = mx.callback.do_checkpoint(model_prefix)
mod = mx.mod.Module(symbol=net)   # net and train_iter are defined elsewhere
mod.fit(train_iter, num_epoch=5, epoch_end_callback=checkpoint)
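
Each epoch this writes two files: mx_mlp-symbol.json (the network definition) and mx_mlp-%04d.params (the weights for that epoch). They can be reloaded later, either into raw symbol/weight objects or directly into a Module:

# reload the checkpoint saved after epoch 5
sym, arg_params, aux_params = mx.model.load_checkpoint('mx_mlp', 5)
# or rebuild a Module in one step (it still needs bind() before use)
mod = mx.mod.Module.load('mx_mlp', 5)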

# a look at the relevant excerpt from fit's source: at the end of each
# epoch it syncs parameters and then invokes every epoch_end_callback
# sync aux params across devices
arg_params, aux_params = self.get_params()
self.set_params(arg_params, aux_params)
if epoch_end_callback is not None:
    for callback in _as_list(epoch_end_callback):
        callback(epoch, self.symbol, arg_params, aux_params)
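
So at the end of every epoch, fit hands the current epoch number, symbol, and parameters to each callback; mx.callback.do_checkpoint is essentially a thin wrapper that forwards these to mx.model.save_checkpoint. A minimal hand-rolled equivalent (a sketch; the real helper also takes a period argument to save less often):

import mxnet as mx

def my_checkpoint(prefix):
    # matches the epoch_end_callback signature shown in the excerpt above
    def _callback(epoch, sym, arg_params, aux_params):
        # writes prefix-symbol.json and prefix-%04d.params
        mx.model.save_checkpoint(prefix, epoch + 1, sym, arg_params, aux_params)
    return _callback

mod.fit(train_iter, num_epoch=5, epoch_end_callback=my_checkpoint('mx_mlp'))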

Reference

https://blog.csdn.net/u012436149/article/details/78174260
