"""
VGG网络实现了更深层次的网络设计。
原文章: Very Deep Convolutional Networks for Large-Scale Image Recognition
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class VGG(nn.Module):
    """VGG-style classifier: a convolutional feature extractor plus a 3-layer FC head.

    Args:
        features: Module applied first to the input batch. The head's first
            linear layer takes 512 * 7 * 7 inputs, so this module must emit
            512 channels.
        num_class: Size of the final linear layer's output (number of classes).
    """

    def __init__(self, features, num_class=1000):
        super().__init__()
        self.features = features
        # Resample the feature map to the 7x7 grid the head expects, so inputs
        # other than 224x224 also work. When the map is already 7x7 this is an
        # identity, and the layer holds no parameters (state_dict is unchanged).
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # Fully connected head.
        self.classify = nn.Sequential(
            # Dropout on the flattened features to reduce overfitting.
            nn.Dropout(0.5),
            # Fully connected layer 1.
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            # Fully connected layer 2.
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            # Fully connected layer 3 (class scores).
            nn.Linear(4096, num_class),
        )
        self.__initialize_weights()

    def forward(self, x):
        """Run features -> 7x7 pooling -> flatten -> FC head and return the scores."""
        x = self.features(x)
        x = self.avgpool(x)
        # Keep the batch dimension, collapse everything else.
        x = torch.flatten(x, start_dim=1)
        x = self.classify(x)
        return x

    def __initialize_weights(self):
        """Xavier-uniform init for Conv2d/Linear weights; zero any existing bias."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
                # Layers built with bias=False have no bias tensor to fill.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
# Layer layouts keyed by model name. An integer is the output channel count of
# one convolution; 'M' marks a max-pooling step. One line per stage.
cfgs = {
    'vgg11': [
        64, 'M',
        128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'vgg13': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'vgg16': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    ],
    'vgg19': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 256, 'M',
        512, 512, 512, 512, 'M',
        512, 512, 512, 512, 'M',
    ],
}
def make_features(arglist, in_channels, batch_norm=False):
    """Build the convolutional feature extractor described by a layer list.

    Args:
        arglist: Iterable of layer specs. 'M' appends a 2x2, stride-2 max-pool;
            any other entry is used as the output channel count of a 3x3,
            padding-1 convolution followed by an in-place ReLU.
        in_channels: Channel count of the input to the first convolution.
        batch_norm: When truthy, a BatchNorm2d is inserted between each
            convolution and its ReLU.

    Returns:
        An nn.Sequential holding the layers in order.
    """
    layers = []
    channels = in_channels
    for item in arglist:
        if item == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        layers.append(
            nn.Conv2d(in_channels=channels, out_channels=item, kernel_size=3, padding=1)
        )
        # Truthiness, not `is False`: falsy values such as 0 or None must not
        # switch batch norm on.
        if batch_norm:
            layers.append(nn.BatchNorm2d(item))
        layers.append(nn.ReLU(inplace=True))
        # The next convolution consumes what this one produced.
        channels = item
    return nn.Sequential(*layers)
def vgg16(model_name='vgg16', **kwargs):
    """Build a VGG model (without batch norm, 3 input channels) from a named layout.

    Args:
        model_name: Key into the module-level ``cfgs`` dict.
        **kwargs: Forwarded unchanged to the ``VGG`` constructor.

    Returns:
        The constructed ``VGG`` instance.

    Raises:
        ValueError: If ``model_name`` is not a key of ``cfgs``.
    """
    try:
        cfg = cfgs[model_name]
    except KeyError:
        # Raise instead of print + exit(): a library function must not kill the
        # interpreter, and a bare except would also hide unrelated errors.
        raise ValueError(
            "model number {} not in cfgs dict! expected one of {}".format(
                model_name, sorted(cfgs)
            )
        ) from None
    return VGG(make_features(cfg, in_channels=3, batch_norm=False), **kwargs)
if __name__ == "__main__":
    # vgg16-D: the default layout.
    model = vgg16()
    print(model)

    # vgg16-C: replace the feature layers at indices 14, 21 and 28 with 1x1
    # convolutions that keep the channel count unchanged.
    for index, channels in ((14, 256), (21, 512), (28, 512)):
        model.features[index] = nn.Conv2d(
            in_channels=channels, out_channels=channels, kernel_size=1
        )
    print(model)

    # Smoke test: push a random batch of eight 3x224x224 inputs through the
    # model and print the output shape.
    batch = torch.rand((8, 3, 224, 224))
    scores = model(batch)
    print(scores.shape)