Deep Learning: Accessing, Initializing, and Sharing Model Parameters (PyTorch)

Accessing Model Parameters

  • Use the Module class's parameters() or named_parameters() method to access all parameters (returned as an iterator); the latter yields each parameter's name along with the parameter Tensor.
  • For a network built with the Sequential class, we can use square brackets [] to index any layer.
  • Each param has type torch.nn.parameter.Parameter, which is a subclass of Tensor; unlike an ordinary Tensor, if a Tensor is a Parameter it is automatically added to the model's parameter list (see the sketch after this list).
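
A minimal sketch of the last point (the attribute name extra is made up for illustration):

import torch
from torch import nn

layer = nn.Linear(4, 3)
print(type(layer.weight))                      # <class 'torch.nn.parameter.Parameter'>
print(isinstance(layer.weight, torch.Tensor))  # True: Parameter is a Tensor subclass

layer.extra = torch.rand(3)  # a plain Tensor attribute is not registered
print([name for name, _ in layer.named_parameters()])  # ['weight', 'bias'] only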

Initializing Model Parameters

The module parameters of nn.Module in PyTorch already come with reasonable default initialization strategies. PyTorch's init module provides a variety of preset initialization methods, and you can also define your own.
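
As a minimal sketch, the common Module.apply pattern with two preset initializers from init (the layer sizes here are arbitrary):

from torch import nn
from torch.nn import init

net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))

def init_linear(m):
    # apply() visits every submodule; re-initialize only the Linear layers
    if isinstance(m, nn.Linear):
        init.xavier_uniform_(m.weight)
        init.zeros_(m.bias)

net.apply(init_linear)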

Sharing Model Parameters

  1. Call the same layer multiple times inside the Module's forward function (sketched right below).
  2. If the modules passed to Sequential are the same Module instance, their parameters are shared as well; the full demo at the end uses this approach.
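A minimal sketch of the first approach (the class name SharedMLP and the layer sizes are made up for illustration):

import torch
from torch import nn

class SharedMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        # the same layer object is called twice, so its parameters are reused
        return self.linear(self.linear(x))

net = SharedMLP()
print(len(list(net.parameters())))  # 2 (one weight, one bias), despite two calls

The full demo below covers parameter access, initialization, and the second sharing approach:
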
import torch
from torch import nn
from torch.nn import init


class MyModel(nn.Module):
    def __init__(self, **kwargs):
        super(MyModel, self).__init__(**kwargs)
        self.weight1 = nn.Parameter(torch.rand(20, 20))  # Parameter: registered automatically
        self.weight2 = torch.rand(20, 20)                # plain Tensor: not registered

    def forward(self, x):
        pass
# Custom initialization: uniform in [-10, 10], then zero out entries with absolute value below 5
def init_weight_(tensor):
    with torch.no_grad():  # modify the tensor in place without autograd tracking
        tensor.uniform_(-10, 10)
        tensor *= (tensor.abs() >= 5).float()

if __name__ == '__main__':
    net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))  # PyTorch applies a default initialization

    print(net)
    X = torch.rand(2, 4)
    Y = net(X).sum()

    # Access model parameters
    for name, param in net.named_parameters():
        print(name, param.size())  # y = xA^T + b explains why net[0].weight has shape torch.Size([3, 4]) = (out_features, in_features)

    # Use square brackets to index any layer of the Sequential
    for name, param in net[0].named_parameters():
        print(name, param.size(), type(param))

    # Parameters are added to the model automatically; weight2 (a plain Tensor) does not appear
    n = MyModel()
    for name, param in n.named_parameters():
        print(name, param.size())

    # Access parameter values via .data and gradients via .grad
    weight_0 = list(net[0].parameters())[0]
    print(weight_0.data)
    print(weight_0.grad)  # gradient is None before backpropagation
    Y.backward()
    print(weight_0.grad)

    # Initialize model parameters from a normal distribution
    for name, param in net.named_parameters():
        if 'weight' in name:
            init.normal_(param, mean=0, std=0.01)
            print(name, param.data)

    # Initialize model parameters with a constant
    for name, param in net.named_parameters():
        if 'bias' in name:
            init.constant_(param, val=0)
            print(name, param.data)

    # Apply the custom initialization
    for name, param in net.named_parameters():
        if 'weight' in name:
            init_weight_(param)
            print(name, param.data)

    # Changing param.data modifies values without being recorded by autograd
    for name, param in net.named_parameters():
        if 'bias' in name:
            param.data += 1
            print(name, param.data)

    # Shared parameters: pass the same Module instance to Sequential twice
    linear = nn.Linear(1, 1, bias=False)
    net = nn.Sequential(linear, linear)
    print(net)
    for name, param in net.named_parameters():
        init.constant_(param, val=3)
        print(name, param.data)

    print(id(net[0]) == id(net[1]))
    print(id(net[0].weight) == id(net[1].weight))

    x = torch.ones(1, 1)
    y = net(x).sum()
    print(y)
    y.backward()
    print(net[0].weight.grad)  # each of the two uses of the layer contributes a gradient of 3, so the total is 6
# Output
Sequential(
  (0): Linear(in_features=4, out_features=3, bias=True)
  (1): ReLU()
  (2): Linear(in_features=3, out_features=1, bias=True)
)
0.weight torch.Size([3, 4])
0.bias torch.Size([3])
2.weight torch.Size([1, 3])
2.bias torch.Size([1])
weight torch.Size([3, 4]) <class 'torch.nn.parameter.Parameter'>
bias torch.Size([3]) <class 'torch.nn.parameter.Parameter'>
weight1 torch.Size([20, 20])
tensor([[-0.2916, -0.1304,  0.2056,  0.2799],
        [-0.0017, -0.0641,  0.3796, -0.0258],
        [-0.4972,  0.2411,  0.3270,  0.2545]])
None
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000],
        [-0.1125, -0.1632, -0.1011, -0.0985],
        [ 0.4213,  0.6114,  0.3785,  0.3688]])
0.weight tensor([[-0.0132,  0.0176,  0.0088, -0.0067],
        [ 0.0055, -0.0027,  0.0016, -0.0141],
        [ 0.0030,  0.0038,  0.0025,  0.0113]])
2.weight tensor([[0.0045, 0.0230, 0.0127]])
0.bias tensor([0., 0., 0.])
2.bias tensor([0.])
0.weight tensor([[ 0.0000, -7.3592, -6.9281, -0.0000],
        [-0.0000, -8.8139,  0.0000,  8.1443],
        [-0.0000, -0.0000,  6.1967, -8.2414]])
2.weight tensor([[-0., -0., -0.]])
0.bias tensor([1., 1., 1.])
2.bias tensor([1.])
Sequential(
  (0): Linear(in_features=1, out_features=1, bias=False)
  (1): Linear(in_features=1, out_features=1, bias=False)
)
0.weight tensor([[3.]])
True
True
tensor(9., grad_fn=<SumBackward0>)
tensor([[6.]])

Process finished with exit code 0

 
