Table of Contents
- 1 A simple network
- 2 nn.Module.init_weights()
1 A simple network
- A PyTorch model is written as a class.
- A PyTorch model should be a subclass of nn.Module.
- At a minimum, a model implements __init__ and forward (initialization and the forward pass).
```python
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # Two convolutions: 1 input channel -> 20, then 20 -> 20, both with 5x5 kernels
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
```
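As a quick sanity check, here is a minimal usage sketch (the input shape is an illustrative assumption, not from the original post): instantiate the model and pass a random single-channel image through it.

```python
import torch

model = Model()
x = torch.randn(1, 1, 28, 28)  # one 1-channel 28x28 image (illustrative shape)
out = model(x)
print(out.shape)  # torch.Size([1, 20, 20, 20]); two 5x5 convs shrink 28 -> 24 -> 20
```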
2 nn.Module.init_weights()
- The code below is the attention module from SENet; it is included here to study init_weights.
```python
import torch
from torch import nn
from torch.nn import init

class SEAttention(nn.Module):
    def __init__(self, channel=512, reduction=16):
        super().__init__()
        # Squeeze: global average pooling down to 1x1 per channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: bottleneck MLP producing a per-channel weight in (0, 1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def init_weights(self):
        # Walk every submodule (including self) and initialize by layer type
        for m in self.modules():
            print(m)  # print each submodule as it is visited
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)   # (b, c, 1, 1) -> (b, c)
        y = self.fc(y).view(b, c, 1, 1)   # per-channel attention weights
        return x * y.expand_as(x)         # rescale each channel of x

if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    se = SEAttention(channel=512, reduction=8)
    output = se(input)
    print(output.shape)
```
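Note that the __main__ block never actually calls init_weights(), so the layers keep PyTorch's default initialization. Also, since SEAttention contains no Conv2d or BatchNorm2d layers, only the nn.Linear branch would fire. A minimal sketch of invoking it explicitly:

```python
se = SEAttention(channel=512, reduction=8)
se.init_weights()  # re-initializes the two Linear weights (and prints each submodule)
output = se(torch.randn(50, 512, 7, 7))
print(output.shape)  # torch.Size([50, 512, 7, 7]) -- the input shape is preserved
```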
2.1 Kaiming Gaussian initialization
- Kaiming Gaussian initialization sets the weight variance so that every convolutional layer's output variance stays at 1 rather than shrinking or growing with depth. The weights are initialized as

$$W \sim \mathcal{N}\!\left(0,\ \frac{2}{n}\right),$$

where $n$ is the layer's fan (for a $k \times k$ convolution, $n = \text{channels} \times k^2$) and the factor 2 compensates for ReLU zeroing roughly half of the activations. PyTorch generalizes this to $\mathrm{std} = \mathrm{gain}/\sqrt{\mathrm{fan}}$, with a gain that depends on the nonlinearity ($\sqrt{2}$ for ReLU).
```python
torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
```
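Here a is the negative slope of the following rectifier (used only for 'leaky_relu'), mode chooses whether fan_in (preserves forward-pass variance) or fan_out (preserves backward-pass variance) is used as the fan, and nonlinearity selects the gain. A small sketch, using only the public torch.nn.init API, that compares the empirical weight std with gain/sqrt(fan):

```python
import math
from torch import nn
from torch.nn import init

conv = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)
init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')

# Expected std in fan_out mode: gain / sqrt(fan_out),
# with gain = sqrt(2) for ReLU and fan_out = out_channels * k * k = 128 * 9
expected_std = math.sqrt(2.0) / math.sqrt(128 * 3 * 3)

print(f"expected std: {expected_std:.5f}")               # ~0.04167
print(f"observed std: {conv.weight.std().item():.5f}")   # close to expected
```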