目录
1--VGG16网络
2--代码实现
3--参考
具体原理参考 VGG 网络的原论文:VGG原论文
VGG 16 网络结构如下图所示:
VGG 16 网络的实现流程如下图所示:
from torch import nn
import torch
class Vgg16_net(nn.Module):
def __init__(self):
super(Vgg16_net, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1), # 64 × 224 × 224
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), # 64 × 224 × 224
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2) # 64 × 112 × 112
)
self.layer2 = nn.Sequential(
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), # 128 × 112 × 112
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), # 128 × 112 × 112
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2) # 128 × 64 × 64
)
self.layer3 = nn.Sequential(
nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1), # 256 × 56 × 56
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), # 256 × 56 × 56
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), # 256 × 56 × 56
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2) # 256 × 28 × 28
)
self.layer4 = nn.Sequential(
nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1), # 512 × 28 × 28
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), # 512 × 28 × 28
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), # 512 × 28 × 28
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2) # 512 × 14 × 14
)
self.layer5 = nn.Sequential(
nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), # 512 × 14 × 14
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), # 512 × 14 × 14
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), # 512 × 14 × 14
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2) # 512 × 7 × 7
)
self.conv = nn.Sequential(
self.layer1,
self.layer2,
self.layer3,
self.layer4,
self.layer5
)
self.fc = nn.Sequential(
nn.Linear(512*7*7, 4096), # CHW -> 4096
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
)
self.linear = nn.Linear(4096, 120) # 120类
def forward(self, data):
B, C, H, W = data.shape
data = self.conv(data) # B C(512) H(7) W(7)
# flatten
data = data.reshape(B, -1) # B CHW
data = self.fc(data) # B 4096
# predict
output = self.linear(data) # B Class
# softmax
result = torch.softmax(output, dim = 1) # B Class
return result # B Class
if __name__ == "__main__":
    # Smoke test: push one random batch through the network and report
    # the output shape.
    batch_size, channels, height, width = 8, 3, 224, 224
    # A fake batch of RGB images: B × 3 × 224 × 224.
    rgb_x = torch.rand((batch_size, channels, height, width))
    rgb_cnn = Vgg16_net()
    result = rgb_cnn(rgb_x)
    print('result.shape: ', result.shape)
    print("All done!")
VGG16网络结构与代码
VGG16模型详解