Pytorch官网对torch.nn.Parameter()
的解释:
torch.nn.Parameter
是继承自torch.Tensor
的子类,其主要作用是作为nn.Module
中的参数使用。
它与torch.Tensor的区别就是nn.Parameter会自动被认为是module的可训练参数,即加入到parameter()这个迭代器中去;而module中非nn.Parameter()的普通tensor是不在parameter中的。
注意到,nn.Parameter的对象的requires_grad属性的默认值是True,即是可被训练的,这与torh.Tensor对象的默认值相反。
在nn.Module类中,pytorch也是使用nn.Parameter来对每一个module的参数进行初始化的。
下面举一个实际应用的例子:
import torch
from torch import nn
from d2l import torch as d2l
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
num_inputs, num_outputs, num_hiddens = 784, 10, 256
w1 = nn.Parameter(torch.randn(num_inputs, num_hiddens, requires_grad=True) * 0.01)
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
w2 = nn.Parameter(torch.randn(num_hiddens, num_outputs, requires_grad=True) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))
params = [w1, b1, w2, b2]
print(params)
我们的输出结果为:
(PyTorch) D:\Code Project>D:/Anaconda/envs/PyTorch/python.exe "d:/Code Project/15.动手学深度学习代码手撸/MLP.py"
[Parameter containing:
tensor([[ 0.0011, 0.0245, 0.0086, ..., 0.0178, 0.0028, 0.0007],
[-0.0109, 0.0161, -0.0044, ..., -0.0075, -0.0115, 0.0167],
[-0.0030, -0.0010, 0.0069, ..., -0.0103, 0.0200, -0.0165],
...,
[ 0.0026, 0.0117, -0.0006, ..., -0.0134, -0.0076, 0.0070],
[ 0.0089, -0.0054, 0.0077, ..., 0.0184, 0.0148, 0.0176],
[ 0.0067, 0.0064, -0.0108, ..., 0.0061, 0.0206, -0.0238]],
requires_grad=True), Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
requires_grad=True), Parameter containing:
tensor([[-8.4045e-03, -9.1918e-03, -8.5745e-04, ..., -9.0292e-05,
9.4313e-03, -3.8237e-03],
[-2.7639e-03, -4.3183e-03, 3.0515e-03, ..., 3.5514e-03,
8.0173e-03, 1.8906e-02],
[ 5.2295e-03, -9.2716e-04, 8.9173e-03, ..., 1.2121e-02,
1.2078e-02, -3.1057e-02],
...,
[-3.4958e-03, -1.0638e-02, -3.9941e-03, ..., -9.9821e-03,
-1.9388e-02, 2.4327e-03],
[ 2.8740e-03, -9.1803e-03, -1.0169e-02, ..., -2.1247e-03,
1.1494e-02, -6.7526e-03],
[-5.2698e-04, 6.6234e-03, 2.2721e-02, ..., 1.9378e-02,
-4.2416e-03, -6.5305e-03]], requires_grad=True), Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)]