5.vgg16网络模块的实现

代码如下:

from torch import nn
import torch

class VGG16(nn.Module):
    def __init__(self,num_classes=1000):
        super(VGG16,self).__init__()

        layers=[]
        indim=3
        outdim=64

        #构造卷积结构,一共有13层
        for i in range(13):
            layers+=[nn.Conv2d(indim,outdim,3,1,1),nn.ReLU(inplace=True)]
            indim=outdim

            #在第2,4,7,10,13层后加池化层
            if i==1 or i==3 or i==6 or i==9 or i==12:
                layers+=[nn.MaxPool2d(2,2)]
                #第10层后的卷积层通道数相同
                if i!=9:
                    outdim*=2
        self.features=nn.Sequential(*layers)

        #下面构建3个全连接层
        self.classifiers=nn.Sequential(
            #第一层
            nn.Linear(512*7*7,4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),

            #第二层
            nn.Linear(4096,4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),

            #第三层
            nn.Linear(4096,num_classes),

        )

    def forward(self,x):

        x=self.features(x)

        #将特征图的维度从[1,512,7,7]变为[1,512*7*7]
        x=x.view(x.size(0),-1)

        x=self.classifiers(x)

        return x


if __name__ == '__main__':

    print("...........................................")

    vgg=VGG16(21)

    input=torch.randn(1,3,224,224)

    scores=vgg(input)
    print(scores.shape)
    print(scores)

你可能感兴趣的:(python,开发语言)