The purpose of initialization is to make the network update its parameters faster, so a good initialization scheme is a must-have. This post walks through the commonly used initialization methods. First, a small helper to print the weights of each layer:
import torch
import torch.nn as nn

def print_weight(m):
    # Print the parameters of every Linear / Conv2d layer visited by model.apply().
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        print("weight:", m.weight.data)
        print("bias:", m.bias)
        print("next...")

model.apply(print_weight)   # model is any nn.Module that has already been built
Use this printout to inspect the weight values of each layer.
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

model.apply(weights_init_normal)
This is a hand-written initializer that works for basically any network. The choice of initialization is one of the knobs that can squeeze out extra accuracy; in a competition you can try several schemes and keep whichever scores best.
def weights_init_normal2(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.xavier_uniform_(m.weight.data, gain=nn.init.calculate_gain('relu'))
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
The basic idea of Xavier initialization is to keep the variance of a layer's inputs and outputs the same, which prevents the outputs from collapsing toward 0. It is a general-purpose method that works with most activation functions, and the gain argument lets you scale the standard deviation to match a specific activation; note, however, that Xavier was originally derived for tanh().
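For reference, nn.init.calculate_gain returns the recommended gain for a given activation, which is exactly what weights_init_normal2 above passes to xavier_uniform_; a minimal sketch:

import torch.nn as nn

# Recommended gains for common activations (pass as the gain argument of the Xavier initializers).
print(nn.init.calculate_gain('tanh'))        # 5/3 ~= 1.667
print(nn.init.calculate_gain('relu'))        # sqrt(2) ~= 1.414
print(nn.init.calculate_gain('leaky_relu'))  # sqrt(2 / (1 + 0.01**2)), default negative_slope = 0.01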
Xavier uniform distribution

nn.init.xavier_uniform_(w, gain=1)
torch.nn.init.xavier_uniform_(tensor, gain=1)

Xavier uniform initialization samples from the uniform distribution U(-a, a), where a = gain * sqrt(6 / (fan_in + fan_out)).
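A small sketch (with a hypothetical 256x128 weight tensor, so fan_in = 128 and fan_out = 256) that compares the sampled range with this bound:

import math
import torch

w = torch.empty(256, 128)                  # fan_out = 256, fan_in = 128
torch.nn.init.xavier_uniform_(w, gain=1.0)
a = 1.0 * math.sqrt(6.0 / (128 + 256))     # a = gain * sqrt(6 / (fan_in + fan_out))
print(a, w.min().item(), w.max().item())   # all sampled values fall inside [-a, a]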
Xavier normal distribution

nn.init.xavier_normal_(w, gain=1)
torch.nn.init.xavier_normal_(tensor, gain=1)

Xavier normal initialization samples from a normal distribution with mean = 0 and std = gain * sqrt(2 / (fan_in + fan_out)).
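And the analogous check for the normal variant (same hypothetical shape):

import math
import torch

w = torch.empty(256, 128)
torch.nn.init.xavier_normal_(w, gain=1.0)
expected_std = 1.0 * math.sqrt(2.0 / (128 + 256))  # std = gain * sqrt(2 / (fan_in + fan_out))
print(expected_std, w.std().item())                # the empirical std should be close to the formula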
def weights_init_normal3(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.kaiming_uniform_(m.weight.data, mode='fan_in', nonlinearity='relu')
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
The idea behind He (Kaiming) initialization is that in a ReLU network roughly half of the neurons in each layer are activated while the other half output 0, and the variance formula accounts for that. It works best in ReLU networks.
Kaiming uniform distribution

torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

This samples from the uniform distribution U(-bound, bound), where bound = sqrt(6 / ((1 + a^2) * fan_in)).
mode: either 'fan_in' or 'fan_out'. 'fan_in' keeps the variance consistent in the forward pass; 'fan_out' keeps it consistent in the backward pass.
nonlinearity: either 'relu' or 'leaky_relu'; the default is 'leaky_relu'.
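A quick sketch (hypothetical 256x128 tensor, fan_in = 128) that checks the bound for nonlinearity='relu':

import math
import torch

w = torch.empty(256, 128)                            # fan_in = 128
torch.nn.init.kaiming_uniform_(w, a=0, mode='fan_in', nonlinearity='relu')
bound = math.sqrt(6.0 / ((1 + 0 ** 2) * 128))        # bound = sqrt(6 / ((1 + a^2) * fan_in))
print(bound, w.abs().max().item())                   # the largest magnitude stays below the bound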
Kaiming normal distribution

torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

This samples from a zero-mean normal distribution N(0, std), where std = sqrt(2 / ((1 + a^2) * fan_in)).
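Likewise for the normal variant:

import math
import torch

w = torch.empty(256, 128)                             # fan_in = 128
torch.nn.init.kaiming_normal_(w, a=0, mode='fan_in', nonlinearity='relu')
expected_std = math.sqrt(2.0 / ((1 + 0 ** 2) * 128))  # std = sqrt(2 / ((1 + a^2) * fan_in))
print(expected_std, w.std().item())                   # the empirical std should be close to the formula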
Sparse initialization

torch.nn.init.sparse_(tensor, sparsity, std=0.01)

sparsity: the fraction of elements in each column to set to zero.
std: the standard deviation of the normal distribution used to generate the non-zero values.

nn.init.sparse_(w, sparsity=0.1)
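A small usage sketch (hypothetical 10x5 tensor) showing that roughly a sparsity-sized fraction of each column ends up zero:

import torch

w = torch.empty(10, 5)
torch.nn.init.sparse_(w, sparsity=0.1, std=0.01)  # zero out about 10% of each column
print((w == 0).float().mean(dim=0))               # fraction of zeros per column, roughly 0.1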
from efficientnet_pytorch import EfficientNet
net = EfficientNet.from_name('efficientnet-b0').cuda()
print(net)
def weights_init_normal4(m):
    classname = m.__class__.__name__
    if classname.find("Conv2dStaticSamePadding") != -1:
        nn.init.kaiming_uniform_(m.weight.data, mode='fan_in', nonlinearity='relu')
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

net.apply(weights_init_normal4)
Why is this version needed? Because different implementations name their layers differently, so for a newer network you should print() the model first, find the actual class name of its conv layers, and pass that name to find(). In most cases, though, it is better to use pretrained weights: they speed things up by more than a factor of two.
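If pretrained weights suit the task, loading them is usually the better starting point; a minimal sketch with efficientnet_pytorch (assuming the 'efficientnet-b0' checkpoint can be downloaded):

from efficientnet_pytorch import EfficientNet

# Load ImageNet-pretrained weights instead of initializing from scratch.
net = EfficientNet.from_pretrained('efficientnet-b0').cuda()

For completeness, here is a hand-written ResNet50 together with the initialization functions above, applied via model.apply():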
import torch
import torch.nn as nn
import numpy as np

Layers = [3, 4, 6, 3]

class Block(nn.Module):
    def __init__(self, in_channels, filters, stride=1, is_1x1conv=False):
        super(Block, self).__init__()
        filter1, filter2, filter3 = filters
        self.is_1x1conv = is_1x1conv
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, filter1, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(filter1),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(filter1, filter2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(filter2),
            nn.ReLU()
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(filter2, filter3, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(filter3),
        )
        if is_1x1conv:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, filter3, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(filter3)
            )

    def forward(self, x):
        x_shortcut = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        if self.is_1x1conv:
            x_shortcut = self.shortcut(x_shortcut)
        x = x + x_shortcut
        x = self.relu(x)
        return x

class Resnet50(nn.Module):
    def __init__(self):
        super(Resnet50, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = self._make_layer(64, (64, 64, 256), Layers[0])
        self.conv3 = self._make_layer(256, (128, 128, 512), Layers[1], 2)
        self.conv4 = self._make_layer(512, (256, 256, 1024), Layers[2], 2)
        self.conv5 = self._make_layer(1024, (512, 512, 2048), Layers[3], 2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Linear(2048, 10)
        )

    def forward(self, input):
        x = self.conv1(input)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.avgpool(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

    def _make_layer(self, in_channels, filters, blocks, stride=1):
        layers = []
        block_1 = Block(in_channels, filters, stride=stride, is_1x1conv=True)
        layers.append(block_1)
        for i in range(1, blocks):
            layers.append(Block(filters[2], filters, stride=1, is_1x1conv=False))
        return nn.Sequential(*layers)

def Resnet():
    return Resnet50()

def print_weight(m):
    if isinstance(m, nn.Conv2d):
        print("weight", m.weight.data)
        print("bias:", m.bias)
        print("next...")

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

def weights_init_normal2(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.xavier_normal_(m.weight, gain=1)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

def weights_init_normal3(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.kaiming_uniform_(m.weight.data, mode='fan_in', nonlinearity='relu')
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

model = Resnet50()
model.apply(weights_init_normal2)
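As a quick sanity check after applying the initializer, a forward pass with a random batch (hypothetical 224x224 input) should produce 10 logits per image:

x = torch.randn(2, 3, 224, 224)   # a batch of 2 RGB images
out = model(x)
print(out.shape)                  # expected: torch.Size([2, 10])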