Common: the conv + bn + relu combination
# shared imports for the snippets in this section
import math
import numpy as np
import torch
import torch.nn as nn

# conv
nn.init.kaiming_normal_(conv.weight, mode='fan_in')
nn.init.constant_(conv.bias, 0.)  # if the conv is followed by BN, use bias=False instead
# bn
nn.init.normal_(bn.weight, mean=1., std=0.02)
nn.init.constant_(bn.bias, 0.)
# fc
nn.init.kaiming_normal_(fc.weight, mode='fan_out')
nn.init.constant_(fc.bias, 0.)
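For reference, a minimal sketch wiring the recipe above to concrete layers (the shapes here are made up for illustration, not from the original):

conv = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)  # bias=False since BN follows
bn = nn.BatchNorm2d(64)
fc = nn.Linear(64, 10)
nn.init.kaiming_normal_(conv.weight, mode='fan_in')
nn.init.normal_(bn.weight, mean=1., std=0.02)
nn.init.constant_(bn.bias, 0.)
nn.init.kaiming_normal_(fc.weight, mode='fan_out')
nn.init.constant_(fc.bias, 0.)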
resnet:
for m in self.modules():
    if isinstance(m, nn.Conv2d):
        # He init: std = sqrt(2 / fan_out), with fan_out = k*k*out_channels
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
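The manual formula above is He (Kaiming) initialization with fan_out and the ReLU gain, so it can be written as a one-liner; a sketch with a made-up layer shape:

m = nn.Conv2d(64, 128, kernel_size=3)
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
# same distribution: std = sqrt(2. / (k*k*out_channels)) = sqrt(2. / n)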
inception:
import scipy.stats as stats

for m in self.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        # truncated normal: +/-2 standard units, scaled by stddev
        stddev = m.stddev if hasattr(m, 'stddev') else 0.1
        X = stats.truncnorm(-2, 2, scale=stddev)
        values = torch.Tensor(X.rvs(m.weight.data.numel()))
        values = values.view(m.weight.data.size())
        m.weight.data.copy_(values)
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
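On recent PyTorch the scipy dependency can be dropped with nn.init.trunc_normal_; it truncates at absolute bounds, so scipy's +/-2 standard units become +/-2*stddev (a sketch, not the original code):

for m in self.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        stddev = m.stddev if hasattr(m, 'stddev') else 0.1
        nn.init.trunc_normal_(m.weight, mean=0., std=stddev, a=-2 * stddev, b=2 * stddev)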
vgg:
def _initialize_weights(self):
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            m.bias.data.zero_()
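A quick sanity check of the conv formula (a sketch with made-up shapes): after init, the weight's empirical std should be close to sqrt(2/n):

conv = nn.Conv2d(3, 64, kernel_size=3)
n = conv.kernel_size[0] * conv.kernel_size[1] * conv.out_channels
conv.weight.data.normal_(0, math.sqrt(2. / n))
print(conv.weight.std().item(), math.sqrt(2. / n))  # both ~0.059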
Other:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # Xavier/Glorot uniform: bound = sqrt(6 / (fan_in + fan_out))
        weight_shape = list(m.weight.data.size())
        fan_in = np.prod(weight_shape[1:4])
        fan_out = np.prod(weight_shape[2:4]) * weight_shape[0]
        w_bound = np.sqrt(6. / (fan_in + fan_out))
        m.weight.data.uniform_(-w_bound, w_bound)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif classname.find('Linear') != -1:
        weight_shape = list(m.weight.data.size())
        fan_in = weight_shape[1]
        fan_out = weight_shape[0]
        w_bound = np.sqrt(6. / (fan_in + fan_out))
        m.weight.data.uniform_(-w_bound, w_bound)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif classname.find('LSTM') != -1:
        for name, param in m.named_parameters():
            if 'bias' in name:
                torch.nn.init.constant_(param, 0.0)
            elif 'weight' in name:
                torch.nn.init.orthogonal_(param)
        # Initialize the LSTM's forget-gate biases to 1 so it remembers more by
        # default; similarly, initialize the GRU's reset-gate biases to -1.
        # PyTorch packs LSTM biases as (input, forget, cell, output), so the
        # forget gate is the second quarter of each bias vector.
        for names in m._all_weights:
            for name in filter(lambda n: "bias" in n, names):
                bias = getattr(m, name)
                n = bias.size(0)
                start, end = n // 4, n // 2
                bias.data[start:end].fill_(1.)
    elif classname.find('GRU') != -1:
        for name, param in m.named_parameters():
            if 'bias' in name:
                torch.nn.init.constant_(param, 0.0)
            elif 'weight' in name:
                torch.nn.init.orthogonal_(param)
        # GRU gates are packed as (reset, update, new), so the reset gate is
        # the first third of each bias vector; set it to -1 per the note above.
        for names in m._all_weights:
            for name in filter(lambda n: "bias" in n, names):
                bias = getattr(m, name)
                n = bias.size(0)
                bias.data[: n // 3].fill_(-1.)
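A quick check of the forget-gate slice (a sketch; the sizes are made up): PyTorch stores LSTM biases as (input, forget, cell, output), so after weights_init the second quarter should be all ones:

lstm = nn.LSTM(input_size=8, hidden_size=16)
weights_init(lstm)
print(lstm.bias_ih_l0[16:32])  # forget-gate slice: all 1.0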
def initial_model_weight(layers):
    for layer in layers:
        if list(layer.children()) == []:
            weights_init(layer)
            # print('weight initial finished!')
        else:
            for sub_layer in list(layer.children()):
                initial_model_weight([sub_layer])
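Note that Module.apply already walks the module tree recursively, so in practice the manual recursion above can be replaced by a one-liner (a sketch, assuming net is any nn.Module):

net = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Linear(16, 10))
net.apply(weights_init)  # calls weights_init on net and every submodule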