import torch
from torch import nn
class MLP(nn.Module):
    # Declare layers with model parameters; here, two fully connected layers
    def __init__(self, **kwargs):
        # Call the constructor of MLP's parent class Module to do the necessary initialization,
        # so that other arguments can still be specified when constructing an instance
        super(MLP, self).__init__(**kwargs)
        self.hidden = nn.Linear(784, 256)  # linear transformation of the input into the hidden layer
        self.act = nn.ReLU()               # followed by a ReLU activation
        self.output = nn.Linear(256, 10)   # a final linear transformation produces the output
    # Define the forward computation, i.e. how to compute the model output from input x
    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)
There is no need to define a backward function in the MLP class above: the system automatically generates the backward function needed for backpropagation through automatic differentiation.
We can instantiate the MLP class to obtain a model variable net. The code below initializes net and passes input data x through it for one forward computation. Here, net(x) calls the __call__ function that MLP inherits from the Module class, and __call__ in turn calls the forward function defined in the MLP class to carry out the forward computation.
x = torch.rand(2,784)
net = MLP()
print(net)
net(x)
MLP(
(hidden): Linear(in_features=784, out_features=256, bias=True)
(act): ReLU()
(output): Linear(in_features=256, out_features=10, bias=True)
)
tensor([[ 0.0840, -0.0177, 0.0372, 0.2128, -0.1802, 0.2334, 0.0339, -0.1636,
0.0673, 0.0214],
[ 0.1027, 0.1027, -0.0777, 0.1907, -0.1190, 0.1237, -0.0201, -0.1508,
0.1732, 0.0766]], grad_fn=<AddmmBackward>)
Note that the Module class is not named Layer or Model, because it is a building block that can be composed freely: its subclasses can be a layer (such as the Linear class provided by PyTorch), a whole model (such as the MLP class defined here), or part of a model. The two examples below demonstrate this flexibility.
Below we implement a MySequential class with the same functionality as the Sequential class, to understand more clearly how Sequential works.
from collections import OrderedDict

class MySequential(nn.Module):
    def __init__(self, *args):
        super(MySequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):  # a single OrderedDict was passed in
            for key, module in args[0].items():
                self.add_module(key, module)  # add_module adds the module to self._modules (an OrderedDict)
        else:  # individual modules were passed in
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)
    def forward(self, input):
        # self._modules is an OrderedDict, so members are traversed in the order they were added
        for module in self._modules.values():
            input = module(input)
        return input
We now use the MySequential class to implement the MLP described earlier and run one forward computation with the randomly initialized model.
net = MySequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
print(net)
net(x)
MySequential(
(0): Linear(in_features=784, out_features=256, bias=True)
(1): ReLU()
(2): Linear(in_features=256, out_features=10, bias=True)
)
tensor([[ 9.6760e-02, 6.1526e-01, -2.1479e-01, -2.4838e-01, -1.8044e-01,
1.0633e-01, -4.3294e-02, 1.5417e-01, 2.6871e-01, -2.3465e-01,
-1.7879e-01, 5.9953e-02, 1.7104e-01, 5.0568e-01, -2.8405e-01,
-3.0190e-01, -3.2859e-01, -3.7216e-01, 3.6683e-01, 1.3583e-01,
-1.4334e-01, -5.9070e-01, -2.3902e-01, 8.6598e-02, -3.2121e-01,
-1.4389e-01, 4.8549e-01, 1.8579e-01, -5.2170e-02, -1.5570e-02,
1.1181e-01, -1.9714e-01, 2.5036e-01, -5.8424e-02, -1.3689e-02,
3.6068e-01, -4.7498e-01, 4.4805e-01, 2.6055e-01, -1.0720e-01,
2.8770e-01, -1.0109e-01, 4.5815e-01, -4.0124e-01, 2.5889e-01,
-3.3336e-02, 6.0167e-01, 1.8990e-01, -5.8802e-02, 6.8722e-01,
-3.3514e-01, -9.3957e-01, -2.8368e-01, 3.1755e-01, 7.4712e-02,
5.6128e-01, -2.1928e-02, 6.3363e-03, -7.1676e-01, -5.8925e-01,
-7.0805e-02, -3.2069e-01, -3.3287e-01, -1.5045e-01, -6.0647e-01,
-9.4337e-02, 4.6517e-01, 2.5196e-01, 2.9131e-01, 6.4817e-01,
-1.7585e-01, 4.3723e-01, 1.0636e-01, 7.5438e-01, 5.5950e-01,
-2.4501e-01, -5.7343e-01, -3.1766e-01, 6.8117e-02, -6.5833e-01,
-6.0973e-01, 2.7250e-01, -1.4511e-01, 2.7163e-01, -5.4788e-01,
-2.9311e-02, 4.2360e-01, -4.0402e-01, -3.1681e-01, -3.5965e-01,
3.2306e-01, -7.2167e-02, 6.1791e-01, 8.5589e-02, -6.8449e-02,
9.2522e-02, 1.7102e-01, 3.8423e-01, -1.0283e-01, 4.5553e-01,
4.5290e-01, -1.6429e-01, 2.6184e-02, -3.3518e-01, -1.4742e-01,
-1.0841e-01, 2.6545e-01, 1.9037e-01, -1.7716e-01, 1.4499e-01,
2.0683e-01, -1.6765e-02, -4.0886e-01, -1.4158e-03, -2.4529e-01,
4.4866e-01, 1.9679e-01, 1.5016e-01, 5.2420e-01, 9.7283e-01,
1.3504e-01, -1.7518e-01, 5.9695e-01, -2.2792e-01, -3.1842e-01,
-2.8035e-01, 2.2754e-01, -2.5845e-01, -3.7324e-02, -2.8803e-02,
-3.4054e-01, -4.8280e-01, 2.2013e-01, 4.2498e-01, 1.3680e-01,
-4.6749e-01, -1.3055e-01, -5.7328e-01, -4.9055e-01, -2.2279e-01,
4.7637e-02, -5.7239e-01, 7.9942e-02, -3.0113e-01, -4.4272e-01,
2.4327e-01, -2.0071e-01, 2.6980e-01, 1.7690e-01, -1.4942e-01,
1.0565e-01, 5.8500e-02, -4.6605e-03, 7.9855e-01, -4.6251e-02,
2.8216e-01, -3.4840e-01, -3.5109e-01, -1.5388e-01, 4.0279e-01,
-1.0993e-03, 1.0501e-02, -9.4468e-02, 1.5324e-01, 4.7143e-02,
-5.6766e-01, -2.3478e-01, 1.1706e-01, -3.9058e-01, 6.1886e-03,
8.6564e-02, 3.1441e-02, -1.4241e-01, 2.0486e-01, 6.1886e-03,
1.7367e-01, -8.8623e-02, -2.5678e-01, 3.1567e-01, -8.5611e-01,
-2.4832e-01, 3.3563e-01, 1.5656e-02, -3.5225e-01, 1.7295e-02,
-9.2993e-02, 6.5053e-01, 8.0615e-02, -1.5109e-01, 2.0775e-01,
2.5568e-01, -2.0957e-01, -3.3211e-01, -6.3330e-01, 2.8433e-01,
-1.5971e-01, 7.5023e-03, 4.4780e-01, -3.9100e-01, -4.1025e-01,
-5.7477e-02, -8.6253e-02, 3.8070e-01, -2.8796e-02, -1.8143e-01,
6.0662e-01, -3.5896e-01, 2.8587e-01, -5.6008e-01, 3.0484e-01,
4.5223e-01, 4.1040e-02, 2.4319e-02, -3.3379e-02, 2.9209e-01,
-2.8169e-01, 1.4224e-02, 7.1399e-02, -3.3856e-02, 1.1149e-01,
-2.1588e-01, 1.2100e-02, -7.3755e-02, 1.8370e-01, -2.8106e-01,
3.1099e-01, 3.9101e-01, -1.6609e-01, -1.1564e-01, -3.2101e-01,
-2.5801e-01, -3.4583e-01, -2.2928e-01, 5.4415e-01, 1.7075e-01,
4.7242e-01, -1.3852e-01, -2.6206e-01, -2.8216e-01, 4.5729e-02,
5.3571e-01, 2.0759e-01, -4.7854e-02, 4.7861e-01, 9.9984e-02,
1.3719e-01, 6.1629e-01, 5.1949e-01, -1.6346e-01, 5.5077e-01,
-2.0594e-01, -6.3007e-02, -4.7834e-01, -1.7159e-01, -3.5195e-02,
-3.3245e-01],
[ 2.4010e-01, 3.0471e-01, -5.2543e-01, -2.6730e-01, -3.3221e-02,
1.9604e-01, 1.3777e-01, 3.4451e-01, 3.4591e-01, -2.6930e-01,
1.3081e-02, 5.6063e-01, 2.5105e-01, 2.2183e-01, -5.9729e-02,
-3.0842e-01, -3.8761e-01, -3.7497e-01, 3.1502e-01, 2.3410e-02,
-3.5446e-02, -1.4785e-01, -6.3958e-01, 3.0649e-01, -1.6719e-01,
-3.5391e-01, 4.7776e-01, 6.4315e-01, 1.3822e-01, -3.4816e-01,
3.9006e-02, -4.2801e-01, 2.4051e-02, -9.2707e-02, 6.2816e-02,
4.1986e-01, -4.2527e-01, 5.2899e-01, 5.9827e-01, -1.5147e-01,
4.1955e-01, -3.5571e-02, 1.9278e-01, -8.0548e-02, 2.4398e-01,
-2.8606e-01, 7.6591e-01, 4.3425e-02, 2.5114e-01, 3.8419e-01,
-3.3522e-01, -1.2938e+00, -3.2376e-01, 3.5667e-01, -5.2941e-02,
6.0979e-01, 1.5105e-02, 3.0921e-02, -2.8994e-01, -4.2888e-01,
8.4356e-02, -5.8704e-01, -2.1208e-01, -4.2752e-02, -4.8370e-01,
-1.1967e-01, 6.2766e-01, 3.0056e-02, 2.4747e-01, 3.9887e-01,
-1.5808e-01, 2.9959e-01, 6.6251e-01, 7.0864e-01, 5.6218e-01,
-2.1177e-02, -1.9826e-01, -1.9828e-01, 2.1884e-01, -5.1883e-01,
-3.9410e-02, 4.5899e-01, 2.0575e-01, -1.2067e-01, -3.3744e-01,
-2.9476e-01, 5.1977e-02, -1.7410e-01, -9.9014e-03, -4.7142e-01,
2.6167e-01, -3.6820e-01, 1.1605e-01, -1.8416e-01, -2.2039e-02,
2.8389e-01, 7.8144e-02, 7.1102e-02, -1.0808e-01, 2.3506e-01,
3.4647e-01, -2.2646e-01, 1.2892e-02, 5.4415e-02, 1.9749e-01,
-2.6919e-01, 2.0519e-01, 5.1301e-01, -2.5568e-01, 6.9206e-02,
1.1940e-01, -9.6597e-02, -4.3718e-01, -3.2612e-01, -2.9065e-01,
1.1352e-01, 1.0605e-01, 4.2826e-01, 4.5905e-01, 6.6613e-01,
-8.4032e-02, -2.5239e-01, 5.4654e-01, -3.7485e-01, -9.8932e-02,
-1.6790e-01, 2.7925e-01, -1.2701e-01, 1.0140e-01, 2.7005e-01,
-2.7344e-01, -6.0573e-01, 2.3084e-01, 3.3079e-01, 1.8777e-01,
-5.1973e-01, -2.6797e-01, -1.0291e+00, -3.6658e-01, -2.9196e-01,
3.2473e-02, -6.0835e-01, -8.1029e-02, -3.4202e-01, 1.9162e-02,
1.7709e-01, -2.7436e-01, 5.4803e-01, 2.4869e-01, -1.4216e-01,
3.1039e-01, 3.1186e-01, -4.0879e-01, 5.9306e-01, -1.8724e-01,
2.2229e-01, -8.1658e-02, -3.7945e-01, 6.7537e-02, 1.2900e-01,
1.1815e-01, -1.8873e-01, 2.1726e-02, 5.6741e-02, 4.9844e-02,
-7.6093e-01, 1.6821e-01, 1.5666e-01, 1.6226e-02, -2.1582e-01,
1.0466e-01, -7.5733e-02, 1.0996e-01, 3.5633e-01, -6.2467e-03,
3.3784e-01, -2.3040e-01, -1.2743e-01, 7.6373e-01, -7.0520e-01,
-3.3275e-01, 5.3384e-02, -9.7886e-02, -1.2004e-01, 6.4745e-03,
1.7047e-01, 7.6440e-01, 3.0851e-02, -1.8241e-01, 2.4145e-01,
3.4846e-01, -4.5113e-02, 4.5672e-02, -5.1739e-01, 8.7621e-02,
-3.0733e-01, 2.9596e-01, 3.2907e-01, -3.4796e-01, 5.4672e-02,
3.9778e-01, 2.7795e-02, 2.0549e-01, 9.3713e-02, -1.2914e-03,
7.0675e-01, -6.9141e-01, -2.3190e-01, -6.7074e-01, 2.2619e-01,
4.5718e-01, 2.1936e-01, 1.9920e-02, -7.0701e-02, 4.2070e-01,
-1.4190e-01, 2.0451e-01, -1.3693e-01, 3.6872e-01, -3.2106e-01,
-1.3463e-01, 1.1542e-01, -1.0738e-01, 2.8464e-01, -6.0397e-02,
4.8289e-01, 3.0147e-01, -1.5671e-01, -4.1534e-01, -5.3990e-01,
-1.8717e-01, -1.6201e-01, -1.2952e-01, 4.2304e-01, -2.3148e-01,
1.2543e-01, -4.5570e-01, 2.8618e-02, -2.8430e-01, -2.4973e-01,
3.7866e-01, 3.4822e-01, -1.2130e-01, 6.9779e-01, 4.8406e-03,
1.5225e-02, 5.1282e-01, 2.1857e-01, -5.3451e-02, 6.3913e-01,
-2.0348e-01, 8.1404e-03, -1.1061e-01, 1.7438e-01, -2.4503e-01,
2.5902e-01]], grad_fn=<AddmmBackward>)
ModuleList holds submodules in a list and can be indexed and appended like an ordinary Python list:
net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net.append(nn.Linear(256, 10))  # append, as with a Python list
print(net[-1])  # index access, as with a Python list
print(net)
Linear(in_features=256, out_features=10, bias=True)
ModuleList(
(0): Linear(in_features=784, out_features=256, bias=True)
(1): ReLU()
(2): Linear(in_features=256, out_features=10, bias=True)
)
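One detail worth noting (an observation added here, not shown in the original output): modules stored in an nn.ModuleList are registered with the containing Module, so their parameters appear in parameters(), whereas modules kept in a plain Python list are not. A minimal sketch; ListVsModuleList is a hypothetical class used only for this check.
class ListVsModuleList(nn.Module):  # hypothetical, for illustration only
    def __init__(self):
        super(ListVsModuleList, self).__init__()
        self.registered = nn.ModuleList([nn.Linear(4, 4)])  # registered as a submodule
        self.plain = [nn.Linear(4, 4)]                       # ordinary attribute, not registered

m = ListVsModuleList()
print(len(list(m.parameters())))  # 2: only the weight and bias of the layer inside the ModuleList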
ModuleDict holds submodules in a dictionary and supports dictionary-style access and insertion:
net = nn.ModuleDict({
    'linear': nn.Linear(784, 256),
    'act': nn.ReLU(),
})
net['output'] = nn.Linear(256, 10)  # add a module
print(net['linear'])  # access by key
print(net.output)     # access as an attribute
print(net)
Linear(in_features=784, out_features=256, bias=True)
Linear(in_features=256, out_features=10, bias=True)
ModuleDict(
(act): ReLU()
(linear): Linear(in_features=784, out_features=256, bias=True)
(output): Linear(in_features=256, out_features=10, bias=True)
)
Construction with Sequential-style classes is simple, but defining forward directly in a Module subclass gives more flexibility. The FancyMLP class below uses this to build a slightly more involved network:
class FancyMLP(nn.Module):
    def __init__(self, **kwargs):
        super(FancyMLP, self).__init__(**kwargs)
        self.rand_weight = torch.rand((20, 20), requires_grad=False)  # constant (non-trainable) parameter
        self.linear = nn.Linear(20, 20)
    def forward(self, x):
        x = self.linear(x)
        # use the constant parameter defined above, together with the relu and mm functions from nn.functional
        x = nn.functional.relu(torch.mm(x, self.rand_weight.data) + 1)
        # reuse the fully connected layer, which is equivalent to two fully connected layers sharing parameters
        x = self.linear(x)
        # control flow; item() is called here to get a Python scalar for the comparison
        while x.norm().item() > 1:
            x /= 2
        if x.norm().item() < 0.8:
            x *= 10
        return x.sum()
In this FancyMLP model, we use a constant weight rand_weight (which is not a trainable model parameter), perform a matrix multiplication (torch.mm), and reuse the same Linear layer. Let us test the model's forward computation.
x = torch.rand(2,20)
net = FancyMLP()
print(net)
net(x)
FancyMLP(
(linear): Linear(in_features=20, out_features=20, bias=True)
)
tensor(-9.4793, grad_fn=<SumBackward0>)
Because FancyMLP and Sequential are both subclasses of the Module class, we can nest them.
class NestMLP(nn.Module):
    def __init__(self, **kwargs):
        super(NestMLP, self).__init__(**kwargs)
        self.net = nn.Sequential(nn.Linear(40, 30), nn.ReLU())
    def forward(self, x):
        return self.net(x)
net = nn.Sequential(NestMLP(), nn.Linear(30, 20), FancyMLP())
X = torch.rand(2, 40)
print(net)
net(X)
Sequential(
(0): NestMLP(
(net): Sequential(
(0): Linear(in_features=40, out_features=30, bias=True)
(1): ReLU()
)
)
(1): Linear(in_features=30, out_features=20, bias=True)
(2): FancyMLP(
(linear): Linear(in_features=20, out_features=20, bias=True)
)
)
tensor(-0.0132, grad_fn=<SumBackward0>)
import torch
from torch import nn
from torch.nn import init
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))  # PyTorch applies default initialization
print(net)
x = torch.rand(2,4)
y = net(x).sum()
Sequential(
(0): Linear(in_features=4, out_features=3, bias=True)
(1): ReLU()
(2): Linear(in_features=3, out_features=1, bias=True)
)
print(type(net.named_parameters()))
for name, param in net.named_parameters():
    print(name, param.size())
<class 'generator'>
0.weight torch.Size([3, 4])
0.bias torch.Size([3])
2.weight torch.Size([1, 3])
2.bias torch.Size([1])
As we can see, the returned names automatically carry the layer index as a prefix. Let us now access the parameters of a single layer in net. For a network constructed with the Sequential class, we can access any layer with square brackets []; index 0 refers to the hidden layer, the layer added to the Sequential instance first.
for name, param in net[0].named_parameters():
    print(name, param.size(), type(param))
weight torch.Size([3, 4])
bias torch.Size([3])
Because this is a single layer, there is no layer-index prefix here. Also note that the returned param has type torch.nn.parameter.Parameter, which is in fact a subclass of Tensor. Unlike an ordinary Tensor, if a Tensor is a Parameter it is automatically added to the model's parameter list, as the following example shows.
class MyModule(nn.Module):
    def __init__(self, **kwargs):
        super(MyModule, self).__init__(**kwargs)
        self.weight1 = nn.Parameter(torch.rand(20, 20))
        self.weight2 = torch.rand(20, 20)
    def forward(self, x):
        pass

n = MyModule()
for name, param in n.named_parameters():
    print(name)
weight1
In the code above, weight1 appears in the parameter list while weight2 does not.
Since a Parameter is a Tensor, it has every attribute a Tensor has: for example, the parameter values can be accessed through data and the parameter gradients through grad.
weight_0 = list(net[0].parameters())[0]
print(weight_0.data)
print(weight_0.grad)  # the gradient is None before backpropagation
y.backward()
print(weight_0.grad)
tensor([[ 0.0184, -0.0354, -0.0146, 0.0747],
[ 0.3217, 0.4394, 0.3784, -0.0991],
[ 0.4981, 0.1648, 0.1335, -0.1365]])
None
tensor([[ 0.0000, 0.0000, 0.0000, 0.0000],
[-0.6390, -0.8518, -0.7378, -0.3493],
[-0.0444, -0.0592, -0.0513, -0.0243]])
PyTorch's init module provides a variety of preset initialization methods. In the example below, we initialize the weight parameters with normally distributed random numbers of mean 0 and standard deviation 0.01:
for name, param in net.named_parameters():
    if 'weight' in name:
        init.normal_(param, mean=0, std=0.01)
        print(name, param.data)
0.weight tensor([[ 0.0055, 0.0014, -0.0135, -0.0004],
[ 0.0099, -0.0193, -0.0063, -0.0028],
[-0.0098, -0.0035, -0.0005, 0.0108]])
2.weight tensor([[-0.0016, 0.0036, -0.0049]])
Below we use a constant to initialize the bias parameters to zero.
for name, param in net.named_parameters():
    if 'bias' in name:
        init.constant_(param, val=0)
        print(name, param.data)
0.bias tensor([0., 0., 0.])
2.bias tensor([0.])
Sometimes the initialization method we need is not provided in the init module. In that case we can implement an initialization method ourselves and then use it just like the built-in ones. Before doing so, let us look at how PyTorch implements these initialization methods, taking torch.nn.init.normal_ as an example:
def normal_(tensor, mean=0, std=1):
    with torch.no_grad():
        return tensor.normal_(mean, std)
As we can see, this is just a function that changes the values of a Tensor in place without recording gradients. We can implement a custom initialization method in the same way. In the example below, each weight has probability one half of being initialized to 0 and probability one half of being drawn uniformly at random from the two intervals [-10, -5] and [5, 10].
def init_weight(tensor):
    with torch.no_grad():
        tensor.uniform_(-10, 10)
        tensor *= (tensor.abs() >= 5).float()

for name, param in net.named_parameters():
    if 'weight' in name:
        init_weight(param)
        print(name, param.data)
0.weight tensor([[ 7.7090, 6.8918, -9.0180, -0.0000],
[-8.0763, -0.0000, -7.9557, 0.0000],
[-0.0000, 6.7024, 0.0000, 0.0000]])
2.weight tensor([[-0.0000, -5.7995, 8.2611]])
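As an aside (not part of the original example), the same kind of per-layer initialization is often written with Module.apply, which visits every submodule; init_linear below is a hypothetical helper name chosen for illustration.
def init_linear(m):
    # apply the normal initialization only to Linear layers
    if isinstance(m, nn.Linear):
        init.normal_(m.weight, mean=0, std=0.01)

net.apply(init_linear)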
In addition, we can modify a parameter's values by changing its data, which does not affect its gradient:
for name, param in net.named_parameters():
    if 'bias' in name:
        param.data += 1
        print(name, param.data)
0.bias tensor([1., 1., 1.])
2.bias tensor([1.])
In some cases we want to share model parameters across multiple layers. One way to do this is to use the same Module instance more than once when constructing the model, so that the corresponding layers share parameters:
linear = nn.Linear(1, 1, bias=False)
net = nn.Sequential(linear, linear)
print(net)
for name, param in net.named_parameters():
    init.constant_(param, val=3)
    print(name, param.data)
Sequential(
(0): Linear(in_features=1, out_features=1, bias=False)
(1): Linear(in_features=1, out_features=1, bias=False)
)
0.weight tensor([[3.]])
In memory, these two linear layers are in fact the same object:
print(id(net[0]) == id(net[1]))
print(id(net[0].weight) == id(net[1].weight))
True
True
Because the model parameters carry gradients, the gradients of these shared parameters accumulate during backpropagation:
x = torch.ones(1, 1)
y = net(x).sum()
y.backward()
print(net[0].weight.grad)  # the gradient from one pass is 3, so two passes give 6
tensor([[6.]])
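Because these shared gradients accumulate, it is standard practice to clear them before the next backward pass, for example with zero_grad. A minimal sketch continuing the example above:
net.zero_grad()            # clear the accumulated gradients of all parameters
y2 = net(x).sum()
y2.backward()
print(net[0].weight.grad)  # expected to be tensor([[6.]]) again, since the weight is unchanged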
Part of the appeal of deep learning lies in the wide variety of layers in neural networks, for example fully connected layers and the convolutional, pooling, and recurrent layers introduced in later chapters. Although PyTorch provides a large number of commonly used layers, we sometimes still want to define our own. This section shows how to use Module to define a custom layer that can then be reused.
We first show how to define a custom layer that holds no model parameters. In fact, this is similar to constructing a model with the Module class, as introduced in Section 4.1 (model construction). The CenteredLayer class below inherits from Module and defines a layer that subtracts the mean from its input, with the computation defined in the forward function. The layer contains no model parameters.
import torch
from torch import nn
class CenteredLayer(nn.Module):
    def __init__(self, **kwargs):
        super(CenteredLayer, self).__init__(**kwargs)
    def forward(self, x):
        return x - x.mean()
We instantiate the layer and run a forward computation.
layer = CenteredLayer()
layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float))
tensor([-2., -1., 0., 1., 2.])
We can also use it to construct more complex models.
net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
Below we print the mean of the custom layer's output. Because the mean is a floating-point number, its value is a number very close to 0.
y = net(torch.rand(4, 8))
y.mean().item()
-1.862645149230957e-09
We can also define custom layers that contain model parameters, whose values can be learned through training.
Section 4.2 (accessing, initializing, and sharing model parameters) introduced the Parameter class, a subclass of Tensor: if a Tensor is a Parameter, it is automatically added to the model's parameter list. So when defining a layer with model parameters, we should define the parameters as Parameter instances. Besides defining them directly, as in Section 4.2.1, we can also use ParameterList and ParameterDict to define lists and dictionaries of parameters, respectively.
ParameterList takes a list of Parameter instances as input and yields a parameter list; individual parameters can be accessed by index, and append and extend can be used to add parameters at the end of the list.
class MyDense(nn.Module):
    def __init__(self):
        super(MyDense, self).__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for i in range(3)])
        self.params.append(nn.Parameter(torch.randn(4, 1)))
    def forward(self, x):
        for i in range(len(self.params)):
            x = torch.mm(x, self.params[i])
        return x
net = MyDense()
print(net)
MyDense(
(params): ParameterList(
(0): Parameter containing: [torch.FloatTensor of size 4x4]
(1): Parameter containing: [torch.FloatTensor of size 4x4]
(2): Parameter containing: [torch.FloatTensor of size 4x4]
(3): Parameter containing: [torch.FloatTensor of size 4x1]
)
)
ParameterDict takes a dictionary of Parameter instances as input and yields a parameter dictionary, which can then be used with the usual dictionary operations, for example update() to add parameters, keys() to return all keys, and items() to return all key-value pairs.
class MyDictDense(nn.Module):
    def __init__(self):
        super(MyDictDense, self).__init__()
        self.params = nn.ParameterDict({
            'linear1': nn.Parameter(torch.randn(4, 4)),
            'linear2': nn.Parameter(torch.randn(4, 1))
        })
        self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))})  # add a new parameter
    def forward(self, x, choice='linear1'):
        return torch.mm(x, self.params[choice])
net = MyDictDense()
print(net)
MyDictDense(
(params): ParameterDict(
(linear1): Parameter containing: [torch.FloatTensor of size 4x4]
(linear2): Parameter containing: [torch.FloatTensor of size 4x1]
(linear3): Parameter containing: [torch.FloatTensor of size 4x2]
)
)
x = torch.ones(1, 4)
print(net(x, 'linear1'))
print(net(x, 'linear2'))
print(net(x, 'linear3'))
tensor([[ 3.7930, -0.7466, -0.0307, -0.0046]], grad_fn=<MmBackward>)
tensor([[2.3354]], grad_fn=<MmBackward>)
tensor([[-1.0751, 1.1118]], grad_fn=<MmBackward>)
We can also use custom layers to construct a model; they are used in much the same way as PyTorch's built-in layers.
class MyListDense(nn.Module):
    def __init__(self):
        super(MyListDense, self).__init__()
        self.params = nn.ParameterList([
            nn.Parameter(torch.randn(4, 4)),
            nn.Parameter(torch.randn(4, 4)),
            nn.Parameter(torch.randn(4, 4)),
            nn.Parameter(torch.randn(4, 1))
        ])
    def forward(self, x):
        for i in range(len(self.params)):
            x = torch.mm(x, self.params[i])
        return x
net = nn.Sequential(
MyDictDense(),
MyListDense(),
)
print(net)
print(net(x))
Sequential(
(0): MyDictDense(
(params): ParameterDict(
(linear1): Parameter containing: [torch.FloatTensor of size 4x4]
(linear2): Parameter containing: [torch.FloatTensor of size 4x1]
(linear3): Parameter containing: [torch.FloatTensor of size 4x2]
)
)
(1): MyListDense(
(params): ParameterList(
(0): Parameter containing: [torch.FloatTensor of size 4x4]
(1): Parameter containing: [torch.FloatTensor of size 4x4]
(2): Parameter containing: [torch.FloatTensor of size 4x4]
(3): Parameter containing: [torch.FloatTensor of size 4x1]
)
)
)
tensor([[21.9177]], grad_fn=<MmBackward>)
So far, we have covered how to process data and how to build, train, and test deep learning models. In practice, however, we sometimes need to deploy a trained model to many different devices. In that case, we can store the trained model parameters on disk for later reading.
The example below creates a Tensor variable x and saves it in a file named x.pt.
import torch
from torch import nn
x = torch.ones(3)
torch.save(x, 'x.pt')
x2 = torch.load('x.pt')
x2
tensor([1., 1., 1.])
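torch.save and torch.load also handle standard Python containers of tensors, such as lists and dictionaries. A small sketch (the file names here are our own choice):
y = torch.zeros(4)
torch.save([x, y], 'xy.pt')                  # a list of tensors
xy_list = torch.load('xy.pt')
torch.save({'x': x, 'y': y}, 'xy_dict.pt')   # a dict mapping names to tensors
xy_dict = torch.load('xy_dict.pt')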
In PyTorch, a Module's learnable parameters (i.e. its weights and biases) are contained in its parameters (accessible via model.parameters()). state_dict is a dictionary object that maps each parameter name to its parameter Tensor.
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)
    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)
net = MLP()
net.state_dict()
OrderedDict([('hidden.weight', tensor([[ 0.3454, 0.4994, 0.4346],
[ 0.2998, 0.1045, -0.4370]])),
('hidden.bias', tensor([ 0.0977, -0.3321])),
('output.weight', tensor([[-0.1467, 0.3071]])),
('output.bias', tensor([0.3328]))])
Note that only layers with learnable parameters (convolutional layers, linear layers, and so on) have entries in the state_dict. Optimizers (torch.optim) also have a state_dict, which contains information about the optimizer's state and the hyperparameters used.
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer.state_dict()
{'state': {},
'param_groups': [{'lr': 0.001,
'momentum': 0.9,
'dampening': 0,
'weight_decay': 0,
'nesterov': False,
'params': [1700416806776, 1700416805048, 1700416806344, 1700416803176]}]}
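Since the optimizer has its own state_dict, a common pattern is to bundle model and optimizer state into a single checkpoint so that training can be resumed later. A sketch with dictionary keys of our own choosing:
checkpoint = {
    'model_state_dict': net.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}
torch.save(checkpoint, 'checkpoint.pt')
ckpt = torch.load('checkpoint.pt')
net.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])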
There are two common ways to save and load trained models in PyTorch:
Save and load the state_dict (the recommended way)
# save
torch.save(model.state_dict(), PATH)  # the recommended file extension is pt or pth
# load
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
Save and load the whole model
# save
torch.save(model, PATH)
# load
model = torch.load(PATH)
Let us try out the recommended method.
X = torch.randn(2, 3)
Y = net(X)
PATH = "./net.pt"
torch.save(net.state_dict(), PATH)
net2 = MLP()
net2.load_state_dict(torch.load(PATH))
Y2 = net2(X)
Y2 == Y
tensor([[True],
[True]])
Because net and net2 both have the same model parameters, the computation results for the same input X are identical, which the output above confirms.
There are also other use cases, for example saving and loading models between GPU and CPU, and storing models that use multiple GPUs; the cross-device case is sketched below.
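For the GPU/CPU case, a minimal sketch (assuming a CUDA device may or may not be available; net3 is just a fresh instance for illustration) is to pass map_location to torch.load and then move the model with to():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.save(net.state_dict(), PATH)                          # parameters are saved from whatever device they live on
state = torch.load(PATH, map_location=torch.device('cpu'))  # remap storages to the CPU when loading
net3 = MLP()
net3.load_state_dict(state)
net3.to(device)                                             # then move the model to the target device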