PyTorch weight initialization explained: common approaches

Summary

The purpose of initialization is to make the network update its parameters faster; a good initialization scheme is essential. This post walks through the commonly used initialization operations.

Printing the parameters

import torch.nn as nn

def print_weight(m):
    # Print the weights and biases of every Linear and Conv2d layer.
    # Note: m.bias is None for layers built with bias=False.
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        print("weight:", m.weight.data)
        print("bias:", m.bias)
        print("next...")

model.apply(print_weight)

model.apply() visits every submodule recursively, so this hook prints the weights of every matching layer in the network.
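A minimal usage sketch (the toy model below is hypothetical and only there to demonstrate the hook; it is never run forward):

# Hypothetical toy model containing both a Conv2d and a Linear layer.
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Linear(8, 2),
)

# apply() recursively visits every submodule and calls the hook on each one.
toy.apply(print_weight)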

Custom initialization

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        # Conv layers: weights ~ N(0, 0.02)
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        # BatchNorm layers: weights ~ N(1, 0.02), bias = 0
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

model.apply(weights_init_normal)

This hand-written style works for essentially any network. Initialization is also a tuning knob for competitions: try several schemes and keep whichever scores best.
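A quick sanity check (a sketch; the conv1[0] indexing assumes the Resnet50 test model defined at the end of this post): after applying the hook, the empirical mean and std of a conv weight should be close to 0 and 0.02.

model = Resnet50()
model.apply(weights_init_normal)

w = model.conv1[0].weight.data            # first conv layer of the test model
print(w.mean().item(), w.std().item())    # expect roughly 0.0 and 0.02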

Using Xavier initialization

def weights_init_normal2(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.xavier_uniform_(m.weight.data, gain=nn.init.calculate_gain('relu'))
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


The basic idea of Xavier initialization is to keep the variance of the inputs and outputs the same, which prevents all outputs from collapsing toward 0. It is a general-purpose method that can be used with any activation function, and the gain argument lets you scale the standard deviation to match a particular activation; keep in mind, though, that Xavier was originally derived for tanh().
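The recommended gain for each activation can be looked up with nn.init.calculate_gain; a small sketch:

print(nn.init.calculate_gain('tanh'))             # 5/3 ≈ 1.667
print(nn.init.calculate_gain('relu'))             # sqrt(2) ≈ 1.414
print(nn.init.calculate_gain('leaky_relu', 0.2))  # sqrt(2 / (1 + 0.2**2))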

Xavier uniform distribution
nn.init.xavier_uniform_(w, gain=1)
torch.nn.init.xavier_uniform_(tensor, gain=1)
Samples from a uniform distribution U(-a, a), where a = gain * sqrt(6 / (fan_in + fan_out)).

Xavier normal distribution
nn.init.xavier_normal_(b, gain=1)
torch.nn.init.xavier_normal_(tensor, gain=1)
Samples from a normal distribution with
mean = 0, std = gain * sqrt(2 / (fan_in + fan_out)).
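Both formulas are easy to verify numerically (a sketch; the 256x128 tensor is arbitrary, giving fan_in = 128 and fan_out = 256):

import math
import torch
import torch.nn as nn

w = torch.empty(256, 128)
nn.init.xavier_uniform_(w, gain=1.0)
bound = math.sqrt(6.0 / (128 + 256))
print(w.abs().max().item(), "<=", bound)   # every sample lies inside [-a, a]

nn.init.xavier_normal_(w, gain=1.0)
std = math.sqrt(2.0 / (128 + 256))
print(w.std().item(), "vs", std)           # empirical std matches the formula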

Using Kaiming initialization

def weights_init_normal3(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.kaiming_uniform_(m.weight.data, mode='fan_in', nonlinearity='relu')
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

The idea behind He (Kaiming) initialization: in a ReLU network, assume that in each layer half of the neurons are activated and the other half output 0. It works best in ReLU networks.
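A rough illustration of why this matters (a sketch with arbitrary layer sizes): with Kaiming initialization, the activation scale stays roughly constant across a stack of ReLU layers instead of decaying toward 0.

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(512, 256)
for i in range(10):
    layer = nn.Linear(256, 256, bias=False)
    nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
    x = torch.relu(layer(x))
    print(i, x.std().item())   # stays in the same ballpark layer after layer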

Kaiming uniform distribution
torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
Samples from a uniform distribution U(-bound, bound), where bound = sqrt(6 / ((1 + a^2) * fan_in)).
mode: 'fan_in' or 'fan_out'. 'fan_in' preserves the variance of the forward pass; 'fan_out' preserves the variance of the backward pass.
nonlinearity: 'relu' or 'leaky_relu'; the default is 'leaky_relu'.

Kaiming normal distribution
torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
Samples from a zero-mean normal distribution N(0, std), where std = sqrt(2 / ((1 + a^2) * fan_in)).
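As with Xavier, these formulas can be checked numerically (a sketch with an arbitrary 256x128 tensor, so fan_in = 128 and a = 0):

import math
import torch
import torch.nn as nn

w = torch.empty(256, 128)
nn.init.kaiming_uniform_(w, a=0, mode='fan_in', nonlinearity='leaky_relu')
bound = math.sqrt(6.0 / ((1 + 0**2) * 128))
print(w.abs().max().item(), "<=", bound)

nn.init.kaiming_normal_(w, a=0, mode='fan_in', nonlinearity='leaky_relu')
std = math.sqrt(2.0 / ((1 + 0**2) * 128))
print(w.std().item(), "vs", std)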

Sparse initialization

torch.nn.init.sparse_(tensor, sparsity, std=0.01)
sparsity: the fraction of elements in each column to be set to zero
std: the standard deviation of the normal distribution used for the non-zero values
nn.init.sparse_(w, sparsity=0.1)
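A quick look at what sparse_ produces (a sketch; the shape is arbitrary, and note that sparse_ only accepts 2-D tensors):

import torch
import torch.nn as nn

w = torch.empty(100, 20)
nn.init.sparse_(w, sparsity=0.1, std=0.01)

# Roughly 10% of the entries in every column are exactly zero.
print((w == 0).float().mean(dim=0))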

Using it with different models

from efficientnet_pytorch import EfficientNet
net = EfficientNet.from_name('efficientnet-b0').cuda()
print(net)
def weights_init_normal4(m):
    classname = m.__class__.__name__
    if classname.find("Conv2dStaticSamePadding") != -1:
        nn.init.kaiming_uniform_(m.weight.data, mode='fan_in', nonlinearity='relu')
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
net.apply(weights_init_normal4)

Why is this necessary? Different implementations name their layers differently, so for a newer network you should first print() it to inspect the structure, then put whatever convolution class name actually appears (here Conv2dStaticSamePadding) into the find() call. That said, in most cases loading pretrained weights is the better choice; training converges more than twice as fast.
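A small helper for discovering which class names to match in find() (a sketch, not from the original post):

# Collect the distinct submodule class names of a model.
def list_module_classes(model):
    names = {m.__class__.__name__ for m in model.modules()}
    for name in sorted(names):
        print(name)

list_module_classes(net)   # e.g. Conv2dStaticSamePadding, BatchNorm2d, ...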

Test code

import torch
import torch.nn as nn
import numpy as np
Layers = [3, 4, 6, 3]
class Block(nn.Module):
    def __init__(self, in_channels, filters, stride=1, is_1x1conv=False):
        super(Block, self).__init__()
        filter1, filter2, filter3 = filters
        self.is_1x1conv = is_1x1conv
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, filter1, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(filter1),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(filter1, filter2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(filter2),
            nn.ReLU()
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(filter2, filter3, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(filter3),
        )
        if is_1x1conv:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, filter3, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(filter3)
            )
    def forward(self, x):
        x_shortcut = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        if self.is_1x1conv:
            x_shortcut = self.shortcut(x_shortcut)
        x = x + x_shortcut
        x = self.relu(x)
        return x


class Resnet50(nn.Module):

    def __init__(self):
        super(Resnet50, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = self._make_layer(64, (64, 64, 256), Layers[0])
        self.conv3 = self._make_layer(256, (128, 128, 512), Layers[1], 2)
        self.conv4 = self._make_layer(512, (256, 256, 1024), Layers[2], 2)
        self.conv5 = self._make_layer(1024, (512, 512, 2048), Layers[3], 2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Linear(2048, 10)
        )
    def forward(self, input):
        x = self.conv1(input)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.avgpool(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x
    def _make_layer(self, in_channels, filters, blocks, stride=1):
        layers = []
        block_1 = Block(in_channels, filters, stride=stride, is_1x1conv=True)
        layers.append(block_1)
        for i in range(1, blocks):
            layers.append(Block(filters[2], filters, stride=1, is_1x1conv=False))

        return nn.Sequential(*layers)

def Resnet():
    return Resnet50()
def print_weight(m):
    if isinstance(m, nn.Conv2d):
        print("weight", m.weight.data)
        print("bias:", m.bias)
        print("next...")

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)



def weights_init_normal2(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.xavier_normal_(m.weight, gain=1)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

def weights_init_normal3(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.kaiming_uniform_(m.weight.data, mode='fan_in', nonlinearity='relu')
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
model = Resnet50()

model.apply(weights_init_normal2)
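
Finally, a quick sanity check (a sketch): push a dummy batch through the re-initialized network, and optionally inspect the weights with the print_weight hook from above.

x = torch.randn(2, 3, 224, 224)   # dummy batch of 2 RGB images, 224x224
out = model(x)
print(out.shape)                  # torch.Size([2, 10])

# model.apply(print_weight)       # uncomment to dump the initialized weights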

