PyTorch 参数初始化函数

下面自定义了一个函数 `init_params`，用于初始化网络中卷积层、BatchNorm 层和全连接层（Linear）的参数。

def init_params(net, init_type='kn'):
    """Initialize the parameters of every supported layer in ``net`` in place.

    Layers are handled by type:

    * ``nn.Conv2d`` / ``nn.Conv3d``: weight uses the scheme selected by
      ``init_type``; bias (if present) is set to 0.
    * ``nn.BatchNorm2d`` / ``nn.BatchNorm3d`` / ``SynchronizedBatchNorm2d`` /
      ``SynchronizedBatchNorm3d``: weight is set to 1 and bias to 0, each only
      when the parameter exists (both are ``None`` for ``affine=False``).
    * ``nn.Linear``: weight drawn from a normal distribution with std 1e-3;
      bias (if present) is set to 0.

    All other module types are left untouched.

    Args:
        net: module whose submodules (including ``net`` itself) are visited
            via ``net.modules()``.
        init_type: conv weight scheme. ``'kn'`` is Kaiming normal
            (``mode='fan_out'``), ``'ku'`` is Kaiming uniform
            (``mode='fan_out'``), ``'xn'`` is Xavier normal and ``'xu'`` is
            Xavier uniform.

    Raises:
        ValueError: if ``init_type`` is not one of the schemes above.

    Note:
        ``SynchronizedBatchNorm2d`` / ``SynchronizedBatchNorm3d`` must be in
        module scope. They are presumably imported from a third-party
        synchronized-batchnorm package; confirm against the file's imports.
    """
    conv_weight_inits = {
        'kn': lambda w: init.kaiming_normal_(w, mode='fan_out'),
        'ku': lambda w: init.kaiming_uniform_(w, mode='fan_out'),
        'xn': init.xavier_normal_,
        'xu': init.xavier_uniform_,
    }
    # Fail fast. An unknown or typo'd scheme would otherwise silently leave
    # conv weights at PyTorch's default init while being reported as in use.
    if init_type not in conv_weight_inits:
        raise ValueError(
            'unknown init_type %r; expected one of %s'
            % (init_type, sorted(conv_weight_inits)))
    init_conv_weight = conv_weight_inits[init_type]

    print('use init scheme: %s' % init_type)
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.Conv3d)):
            init_conv_weight(m.weight)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm3d,
                            SynchronizedBatchNorm2d, SynchronizedBatchNorm3d)):
            # Norm layers built with affine=False have weight/bias == None.
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)

你可能感兴趣的:(pytorch,python,人工智能)