【MXNet】(十七):模型参数的读取与存储

MXNet中可以使用save函数和load函数分别存储和读取NDArray。

下面创建一个NDArray变量并将其存储在文件中。

from mxnet import nd
from mxnet.gluon import nn

x = nd.ones(3)
nd.save('x', x)

然后将其从文件中读回内存。

x2 = nd.load('x')
x2

还可以存储一列NDArray并读回内存。

y = nd.zeros(4)
nd.save('xy', [x, y])
x2, y2 = nd.load('xy')
(x2, y2)

【MXNet】(十七):模型参数的读取与存储_第1张图片

还可以存储并读取一个从字符串映射到NDArray的字典。

mydict = {'x': x, 'y': y}
nd.save('mydict', mydict)
mydict2 = nd.load('mydict')
mydict2

【MXNet】(十七):模型参数的读取与存储_第2张图片

此外,还可以读写Gluon模型的参数。

先创建一个多层感知机,并将其初始化。

class MLP(nn.Block):
    def __init__(self, **kwargs):
        super(MLP, self).__init__(**kwargs)
        self.hidden = nn.Dense(256, activation='relu')
        self.output = nn.Dense(10)
        
    def forward(self, x):
        return self.output(self.hidden(x))
    
net = MLP()
net.initialize()
X = nd.random.uniform(shape=(2, 20))
Y = net(X)

将模型参数存入文件。

filename = 'mlp.params'
net.save_parameters(filename)

再次实例化一个多层感知机,将其参数初始化为文件里保存的参数。

net2 = MLP()
net2.load_parameters(filename)

判断一下两个实例的计算结果是否一致。

Y2 = net2(X)
Y2 == Y

你可能感兴趣的:(深度学习,MXNet)