AlexNet:有5层卷积和2层全连接隐藏层,以及1个全连接输出层; 使⽤ReLU激活函数、
Dropout正则化(类似集成学习的思想,在训练过程中按照⼀定⽐例随机丢弃⼀些神经
元)、图像增强。
# 导入包
from torchvision import models
import torch
from torch import nn
from torch.nn import functional as F
class AlexNet(nn.Module):
def __init__(self):
super(AlexNet, self).__init__()
self.features = nn.Sequential(
# torch.Size([32, 3, 224, 224])
# torch.Size([32, 64, 27, 27])
nn.Conv2d(in_channels=3,
out_channels=64,
kernel_size=(11, 11),
stride=(4, 4),
padding=(2, 2)),
nn.ReLU(), # 这里的inplace=False是指不会改变原始数值,相当于直接copy了一份数据然后在那之上进行修改
nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
# torch.Size([32, 3, 224, 224])
# torch.Size([32, 192, 13, 13])
nn.Conv2d(in_channels=64,
out_channels=192,
kernel_size=(5, 5),
stride=(1, 1),
padding=(2, 2)),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
# torch.Size([32, 3, 224, 224])
# torch.Size([32, 384, 13, 13])
nn.Conv2d(in_channels=192,
out_channels=384,
kernel_size=3,
stride=1,
padding=1),
nn.ReLU(),
nn.Conv2d(in_channels=384,
out_channels=256,
kernel_size=3,
stride=1,
padding=1),
nn.ReLU(),
nn.Conv2d(in_channels=256,
out_channels=256,
kernel_size=3,
stride=1,
padding=1),
nn.ReLU(),
# torch.Size([32, 3, 224, 224])
# torch.Size([32, 256, 6, 6])
nn.MaxPool2d(kernel_size=3, stride=2, padding=0))
self.avgpool = nn.AdaptiveAvgPool2d(output_size=(6, 6)) # 转化(规范)大小,对口型
self.classifier = nn.Sequential(
nn.Dropout(p=0.5),
nn.Flatten(), # 展平层
nn.Linear(in_features=9216, out_features=4096),
nn.ReLU(),
nn.Dropout(p=0.5),
nn.Linear(in_features=4096, out_features=4096),
nn.ReLU(),
nn.Linear(in_features=4096, out_features=10))
def forward(self, x):
# 提取特征
h = self.features(x)
# 规范大小
h = self.avgpool(h)
# 做分类
o = self.classifier(h)
return o
alexnet = AlexNet()
x = torch.randn(32,3,224,224)
y = alexnet(x)
y.shape
# output:torch.Size([32, 10])
自适应平均池化可以将任意大小的矩阵转化为要求矩阵
# 自适应平均池化
pool = nn.AdaptiveAvgPool2d(output_size=(6, 6))
x = torch.randn(8, 3, 128, 128)
x1 = torch.randn(8, 3, 1, 1)
pool(x).shape # output:torch.Size([8, 3, 6, 6])
pool(x1).shape # output: torch.Size([8, 3, 6, 6])