MXNet中读写NDArray和Gluon模型参数

        我们在MNIST数据集手写数字识别(二)训练MNIST这个数据集的时候,它们的权重参数保存在一个pkl的文件里面,我们使用的是pickle

import pickle

with open('sample_weight.pkl','rb') as f:
        print(pickle.load(f))

with open('f.pkl','wb') as f:
            pickle.dump(params,f)

打开pkl文件通过pickle.load方法来加载读取,通过pickle.dump方法来写入到f.pkl文件,更多关于这个pkl文件的存储与读取可以参看Python基础知识汇总

那么在MXNet框架中是如何保存和读取参数的,先看一段代码:

from mxnet import nd
from mxnet.gluon import nn

a=nd.zeros(shape=(3,4))
b=nd.ones(shape=(2,3))

nd.save('f_v',a)
nd.save('f_list',[a,b])
nd.save('f_dict',{'x':a,'y':b})

print('普通变量:%s \n' % nd.load('f_v'))
print('字典类型:%s \n'  % nd.load('f_dict'))
print('列表类型:%s \n'  % nd.load('f_list'))

可以看出是通过saveload分别存储和读取NDArray,将会在保存到硬盘为f_v、f_list、f_dict的文件,现在来读写Gluon模型的参数

class MLP(nn.Block):
    def __init__(self,**kwargs):
        super(MLP,self).__init__(**kwargs)
        self.hidden=nn.Dense(256,activation='relu')
        self.output=nn.Dense(10)
    
    def forward(self,x):
        return self.output(self.hidden(x))

net1=MLP()
net1.initialize()
X=nd.random.uniform(shape=(3,4))
Y=net1(X)

filename='mlp.params'
net1.save_params(filename)

net2=MLP()
net2.load_params(filename)

print(Y==net2(X))

'''
[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]

'''

可以看出,使用了save_paramsload_params来保存与读取模型参数,由于net2使用的参数是来自net1保存的参数,所以这两个实例都有着同样的模型参数,那输出的结果肯定是一样的。

你可能感兴趣的:(深度学习框架(MXNet),mxnet,Gluon模型参数)