Pytorch apply() 函数

apply 函数是nn.Module 中实现的, 递归地调用self.children() 去处理自己以及子模块

我们知道pytorch的任何网络net,都是torch.nn.Module的子类,都算是module, 也就是模块。

pytorch中的model.apply(fn)会递归地将函数fn应用到父模块的每个子模块submodule,也包括model这个父模块自身。经常用于初始化init_weights的操作

from torch import nn

def init_weights(m):
    print(m)
    if type(m) == nn.Linear:
        m.weight.data.fill_(1.0)
        m.bias.data.fill_(0)

model = nn.Sequential(
            nn.Linear(2, 2), 
            nn.Linear(2, 2)
        )
model.apply(init_weights)

Pytorch apply() 函数_第1张图片

你可能感兴趣的:(Pytorch)