PyTorch weight initialization

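import torch.nn as nn

def weights_init(m):
    # Conv2d layers: Kaiming (He) normal weights, zero bias
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

# model is an existing nn.Module; apply() calls weights_init on every submodule
model.apply(weights_init)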
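Below is a self-contained sketch of the same pattern that can be run as-is. Several things in it are assumptions made for illustration and are not from the original: the rule is extended to nn.Linear, nonlinearity="relu" is passed explicitly (it gives the same gain, sqrt(2), as the kaiming_normal_ default of leaky_relu with a=0), and the small model with its 8x8 input size is hypothetical. With mode="fan_in" (the default), each weight is drawn from a normal distribution with std = sqrt(2 / fan_in), where fan_in = in_channels * kernel_h * kernel_w for a conv layer and in_features for a linear layer.

```python
import math

import torch
import torch.nn as nn


def weights_init(m: nn.Module) -> None:
    """Kaiming-normal weights and zero biases for Conv2d and Linear layers.

    Covering nn.Linear as well as nn.Conv2d is an assumption of this sketch.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        # std = sqrt(2) / sqrt(fan_in); fan_in is the default mode
        nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


# Hypothetical model, used only to exercise weights_init.
# It assumes 16-channel 8x8 inputs: a 3x3 conv without padding gives 6x6 maps.
model = nn.Sequential(
    nn.Conv2d(16, 32, kernel_size=3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(32 * 6 * 6, 10),
)

torch.manual_seed(0)
model.apply(weights_init)  # visits every submodule as well as model itself

conv_fan_in = 16 * 3 * 3  # in_channels * kernel_h * kernel_w
print(f"conv weight std: {model[0].weight.std().item():.4f}, "
      f"expected about {math.sqrt(2.0 / conv_fan_in):.4f}")
```

Since Module.apply walks the whole module tree, one function that dispatches on layer type with isinstance is enough. Layer types that it does not match keep PyTorch's built-in default initialization.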

 
