一个网络的权重初始化方法

class TestNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3,16,3)
        self.hidden = nn.Sequential(
            nn.Linear(100,100),
            nn.ReLU(),
            nn.Linear(100,50),
            nn.ReLU()
        )

        self.cla = nn.Linear(50,10)

    def forward(self,x):
        x = self.conv1(x)
        x = x.view(x.shape[0],-1)
        x = self.hidden(x)
        output = self.cla(x)
        return output
testnet = TestNet()
print(testnet)
def init_weights(m):
    if type(m) == nn.Conv2d:
        #如果是卷积层
        nn.init.normal(m.weight,mean=0,std=0.5)
        nn.init.constant(m.bias,val=0)

    if type(m) == nn.Linear:
        nn.init.uniform(m.weight,a=-0.1,b=0.1)
        m.bias.data.fill_(0.01) # fill_  带下划线表示原地操作,在当前内存地址进行操作

# apply方法,使当前网络绑定自定义初始化方法
# 设置随机数种子,确保可以实验重现
torch.manual_seed(13)
testnet.apply(init_weights)

你可能感兴趣的:(Pytorch,python,深度学习,人工智能)