代码如下:
from torch import nn
import torch
class VGG16(nn.Module):
def __init__(self,num_classes=1000):
super(VGG16,self).__init__()
layers=[]
indim=3
outdim=64
#构造卷积结构,一共有13层
for i in range(13):
layers+=[nn.Conv2d(indim,outdim,3,1,1),nn.ReLU(inplace=True)]
indim=outdim
#在第2,4,7,10,13层后加池化层
if i==1 or i==3 or i==6 or i==9 or i==12:
layers+=[nn.MaxPool2d(2,2)]
#第10层后的卷积层通道数相同
if i!=9:
outdim*=2
self.features=nn.Sequential(*layers)
#下面构建3个全连接层
self.classifiers=nn.Sequential(
#第一层
nn.Linear(512*7*7,4096),
nn.ReLU(inplace=True),
nn.Dropout(),
#第二层
nn.Linear(4096,4096),
nn.ReLU(inplace=True),
nn.Dropout(),
#第三层
nn.Linear(4096,num_classes),
)
def forward(self,x):
x=self.features(x)
#将特征图的维度从[1,512,7,7]变为[1,512*7*7]
x=x.view(x.size(0),-1)
x=self.classifiers(x)
return x
if __name__ == '__main__':
print("...........................................")
vgg=VGG16(21)
input=torch.randn(1,3,224,224)
scores=vgg(input)
print(scores.shape)
print(scores)