CNN卷积神经网络(三)——VGG网络PyTorch实现

CNN卷积神经网络(三)——VGG网络PyTorch实现

VGG网络通过堆叠多层3×3小卷积核，将网络深度提升到16–19个权重层，实现了更深层次的网络设计。
原文：Karen Simonyan, Andrew Zisserman. Very Deep Convolutional Networks for Large-Scale Image Recognition. ICLR 2015.
（图1：VGG网络结构示意图，原图未能保留）

VGG网络PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class VGG(nn.Module):
    """VGG classifier: a convolutional feature extractor followed by three FC layers.

    Args:
        features: Module mapping the input batch to a 512-channel feature map
            (e.g. the ``nn.Sequential`` returned by ``make_features``).
        num_class: Number of output logits.
        hidden_features: Width of the two hidden fully connected layers
            (4096 in the paper and by default).
        dropout: Dropout probability applied before the first FC layer.
    """

    def __init__(self, features, num_class=1000, hidden_features=4096, dropout=0.5):
        super().__init__()
        self.features = features
        # Pool the feature map to a fixed 7x7 grid so the classifier input size
        # (512*7*7) no longer forces 224x224 images. With the cfgs in this file
        # (five 2x2 max-pools) a 224x224 image already yields a 7x7 map, so this
        # layer is an exact no-op in that case.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # Fully connected classifier head.
        # NOTE: the paper applies dropout to the first two FC layers; this head
        # uses a single dropout before FC1. The order is kept unchanged so that
        # parameter names (classify.1 / classify.3 / classify.5) stay stable.
        self.classify = nn.Sequential(
            nn.Dropout(dropout),  # regularisation against overfitting
            nn.Linear(512 * 7 * 7, hidden_features),  # FC1
            nn.ReLU(inplace=True),
            nn.Linear(hidden_features, hidden_features),  # FC2
            nn.ReLU(inplace=True),
            nn.Linear(hidden_features, num_class),  # FC3 -> logits
        )
        self.__initialize_weights()

    def forward(self, x):
        """Return class logits of shape ``(batch, num_class)`` for input ``x``."""
        x = self.features(x)
        x = self.avgpool(x)
        # Keep the batch dimension, flatten channels x height x width.
        x = torch.flatten(x, start_dim=1)
        return self.classify(x)

    def __initialize_weights(self):
        """Xavier-uniform init for every Conv2d/Linear weight; zero all biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                # Conv2d may be built with bias=False (e.g. by other feature makers).
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)


# Per-model conv layouts. An int is the output-channel count of a 3x3 conv
# layer; 'M' is a 2x2 max-pool. Each line below is one conv stage.
# The weight-layer count in each name = number of convs + 3 FC layers.
cfgs = {
    # configuration "A": 8 conv + 3 FC
    'vgg11': [64, 'M',
              128, 'M',
              256, 256, 'M',
              512, 512, 'M',
              512, 512, 'M'],
    # configuration "B": 10 conv + 3 FC
    'vgg13': [64, 64, 'M',
              128, 128, 'M',
              256, 256, 'M',
              512, 512, 'M',
              512, 512, 'M'],
    # configuration "D": 13 conv + 3 FC
    'vgg16': [64, 64, 'M',
              128, 128, 'M',
              256, 256, 256, 'M',
              512, 512, 512, 'M',
              512, 512, 512, 'M'],
    # configuration "E": 16 conv + 3 FC
    'vgg19': [64, 64, 'M',
              128, 128, 'M',
              256, 256, 256, 256, 'M',
              512, 512, 512, 512, 'M',
              512, 512, 512, 512, 'M'],
}

def make_features(arglist, in_channels, batch_norm=False):
    """Build the convolutional feature extractor described by a VGG config.

    Args:
        arglist: Sequence of layer specs. An int ``c`` adds a 3x3 conv
            (padding 1) with ``c`` output channels followed by ReLU, with a
            BatchNorm2d in between when ``batch_norm`` is truthy. The string
            ``'M'`` adds a 2x2 max-pool with stride 2.
        in_channels: Number of channels of the input tensor.
        batch_norm: Insert BatchNorm2d after every conv when truthy.

    Returns:
        ``nn.Sequential`` containing the layers in order.
    """
    layers = []
    channels = in_channels
    for item in arglist:
        if item == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        # kernel 3 + padding 1 keeps the spatial size; only 'M' halves it.
        layers.append(nn.Conv2d(in_channels=channels, out_channels=item,
                                kernel_size=3, padding=1))
        # Truthiness test (not `is False`) so None / 0 also mean "no BN".
        if batch_norm:
            layers.append(nn.BatchNorm2d(item))
        layers.append(nn.ReLU(inplace=True))
        channels = item
    return nn.Sequential(*layers)

def vgg16(model_name='vgg16', *, batch_norm=False, **kwargs):
    """Build a VGG model from one of the layouts in ``cfgs``.

    Despite the function name, any key of ``cfgs`` may be requested through
    ``model_name``.

    Args:
        model_name: Key into ``cfgs`` selecting the conv layout.
        batch_norm: Insert BatchNorm2d after every conv (default False, the
            previously hard-coded behaviour).
        **kwargs: Forwarded unchanged to ``VGG`` (e.g. ``num_class``).

    Returns:
        A ``VGG`` instance for 3-channel input.

    Raises:
        ValueError: if ``model_name`` is not a key of ``cfgs``.
    """
    try:
        cfg = cfgs[model_name]
    except KeyError:
        # Raise instead of print + exit(): a library factory must not kill the
        # caller's process, and a bare except would also hide unrelated errors.
        raise ValueError(
            "model name {!r} not in cfgs dict; expected one of {}".format(
                model_name, sorted(cfgs))
        ) from None
    return VGG(make_features(cfg, in_channels=3, batch_norm=batch_norm), **kwargs)

if __name__ == "__main__":
    # VGG16, configuration "D": 13 conv layers, all 3x3.
    model = vgg16()
    print(model)

    # Configuration "C": identical to D except that the last conv of each of
    # the final three stages is 1x1. In the vgg16 features built without
    # batch norm, those convs sit at indices 14 (256 ch), 21 and 28 (512 ch).
    for idx, channels in ((14, 256), (21, 512), (28, 512)):
        old = model.features[idx]
        if not (isinstance(old, nn.Conv2d) and old.out_channels == channels):
            raise RuntimeError(
                "unexpected layer at features[{}]: {}".format(idx, old))
        conv1x1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=1)
        # Match the Xavier-uniform / zero-bias init VGG applies in its
        # constructor; a layer swapped in afterwards would otherwise keep
        # PyTorch's default initialisation.
        nn.init.xavier_uniform_(conv1x1.weight)
        nn.init.constant_(conv1x1.bias, 0)
        model.features[idx] = conv1x1
    print(model)

    # Shape smoke test: eval() disables dropout, no_grad() skips autograd.
    model.eval()
    with torch.no_grad():
        A = torch.rand((8, 3, 224, 224))
        B = model(A)
    print(B.shape)

标签：pytorch, cnn, 网络