Pytorch apply函数

apply 函数是nn.Module 中实现的, 递归地调用self.children() 去处理自己以及子模块

我们知道pytorch的任何网络net,都是torch.nn.Module的子类,都算是module, 也就是模块。

pytorch中的model.apply(fn)会递归地将函数fn应用到父模块的每个子模块submodule,也包括model这个父模块自身。经常用于初始化init_weights的操作

from torch import nn

@torch.no_grad()
def init_weights(m):
    print(m)
    if type(m) == nn.Linear:
        print(m.weight)
        m.weight.fill_(1.0)
        print(m.weight)

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
a = net[0];
net.apply(init_weights)
b = net[0];

'''
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[0.3706, 0.5069],
        [0.3494, 0.2712]], requires_grad=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[ 0.3507,  0.2883],
        [-0.2437, -0.4020]], requires_grad=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
'''

你可能感兴趣的:(动手学深度学习,pytorch,深度学习,神经网络)