PyTorch Implementation of Classic Networks: VGG, a Network Built from Blocks


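AlexNet showed that deep neural networks work remarkably well, but its design did not provide a general template to guide later researchers in designing network architectures. What AlexNet does show is the set of basic components of a convolutional neural network: convolutional layers, activation layers, and pooling layers. Building on these, VGG introduced the idea of a convolutional block and a way of designing networks by stacking such blocks.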

1. Block structure

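VGGNet is made up of blocks, each containing a varying number of convolutional layers, a pooling layer, and so on. These blocks can then be reused when assembling the network.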

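import torch
from torch import nn


def vgg_block(num_convs, in_channels, out_channels):
    r'''Build one VGG block.

    num_convs: number of convolutional layers
    in_channels: number of input channels
    out_channels: number of output channels
    '''
    layers = []
    for _ in range(num_convs):
        # A 3x3 convolution with padding 1 keeps height and width unchanged
        layers.append(nn.Conv2d(in_channels, out_channels,
                                kernel_size=3, padding=1))
        layers.append(nn.ReLU())
        in_channels = out_channels
    # 2x2 max pooling with stride 2 halves height and width
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*layers)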
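As a quick check of the block, the sketch below passes a random tensor through a block with two convolutions. It repeats the definition so that it runs on its own; the input size 224 and the channel counts are illustrative choices only. The convolutions keep height and width, so only the final pooling layer halves them.

```python
import torch
from torch import nn


def vgg_block(num_convs, in_channels, out_channels):
    # Same block as above: num_convs x (Conv2d + ReLU), then one MaxPool2d
    layers = []
    for _ in range(num_convs):
        layers.append(nn.Conv2d(in_channels, out_channels,
                                kernel_size=3, padding=1))
        layers.append(nn.ReLU())
        in_channels = out_channels
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*layers)


blk = vgg_block(num_convs=2, in_channels=3, out_channels=64)
x = torch.randn(1, 3, 224, 224)   # illustrative input size
y = blk(x)
print(y.shape)    # torch.Size([1, 64, 112, 112])
print(len(blk))   # 5 modules: 2 x (Conv2d, ReLU) + 1 MaxPool2d
```

The channel count changes from 3 to 64 in the first convolution, and the spatial size drops from 224 to 112 only at the pooling layer.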

2. Building a simple VGG network

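The VGG paper describes six network configurations (A, A-LRN, B, C, D, E), with depths of 11, 11, 13, 16, 16, and 19 layers. Depth here counts only the layers that have parameters, that is, the convolutional and fully connected layers.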
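The depth count can be reproduced with a little arithmetic. The sketch below is an illustration, not code from the original post: the per-block convolution counts are the commonly cited layouts for VGG-11, VGG-13, VGG-16, and VGG-19 (an assumption brought in from outside this article), and the helper name `weight_layer_depth` is hypothetical.

```python
# Number of convolutional layers in each of the five blocks.
# Assumption: the commonly cited VGG-11/13/16/19 layouts.
VGG_CONVS_PER_BLOCK = {
    "VGG-11": (1, 1, 2, 2, 2),
    "VGG-13": (2, 2, 2, 2, 2),
    "VGG-16": (2, 2, 3, 3, 3),
    "VGG-19": (2, 2, 4, 4, 4),
}
NUM_FC_LAYERS = 3  # three fully connected layers in the classifier head


def weight_layer_depth(convs_per_block, num_fc=NUM_FC_LAYERS):
    """Depth = convolutional layers + fully connected layers."""
    return sum(convs_per_block) + num_fc


depths = {name: weight_layer_depth(convs)
          for name, convs in VGG_CONVS_PER_BLOCK.items()}
print(depths)
# {'VGG-11': 11, 'VGG-13': 13, 'VGG-16': 16, 'VGG-19': 19}
```

Pooling and activation layers have no parameters, which is why they do not appear in the count.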
Because the full VGGNet is still rather large for a laptop, we build a small network based on the VGG idea to illustrate the point.

class VGG(nn.Module):
    
    def __init__(self, in_channels, conv_arch):
        r"""
        parameters:
            in_channels: int, number of channels of the input image
            conv_arch: sequence of (num_convs, out_channels) pairs, one per block
        """
        super(VGG, self).__init__()
        self.in_channels = in_channels
        self._conv_blk = self.vgg_strc(self.in_channels, conv_arch)
        
    def forward(self, x):
        for layer in self._conv_blk:
            x = layer(x)
            
        return x
        
    def vgg_strc(self, in_channels, conv_arch):
        conv_blks = []
        for (num_convs, out_channels) in conv_arch:
            conv_blks.append(self.vgg_block(
                                                num_convs, 
                                                in_channels, 
                                                out_channels
                                                ))
            in_channels = out_channels
            
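        # Classifier head. out_channels*7*7 assumes a 224x224 input and
        # five blocks (224 / 2**5 = 7); the final 10 is the number of classes.
        conv_blks.extend([
            nn.Flatten(),
            nn.Linear(out_channels*7*7, 4096),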
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 10)
        ])
        return nn.Sequential(*conv_blks)

    def vgg_block(self, num_convs, in_channels, out_channels):

        layers = []
        for _ in range(num_convs):
            layers.append(nn.Conv2d(
                            in_channels,
                            out_channels,
                            kernel_size=3,
                            padding=1
                                    ))
            layers.append(nn.ReLU())
            in_channels = out_channels
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*layers)

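# Block configuration
# (num_convs, out_channels)
conv_arch = [(1, 32), (1, 64), (2, 128), (2, 256), (2, 256)]

# Build the network
net = VGG(3, conv_arch)
x = torch.randn(1, 3, 224, 224)
# Pass the input through layer by layer and print each output shape
for layer in net._conv_blk:
    x = layer(x)
    print(layer.__class__.__name__, 'output shape: ', x.shape)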

Output:

Sequential output shape:  torch.Size([1, 32, 112, 112])
Sequential output shape:  torch.Size([1, 64, 56, 56])
Sequential output shape:  torch.Size([1, 128, 28, 28])
Sequential output shape:  torch.Size([1, 256, 14, 14])
Sequential output shape:  torch.Size([1, 256, 7, 7])
Flatten output shape:  torch.Size([1, 12544])
Linear output shape:  torch.Size([1, 4096])
ReLU output shape:  torch.Size([1, 4096])
Dropout output shape:  torch.Size([1, 4096])
Linear output shape:  torch.Size([1, 4096])
ReLU output shape:  torch.Size([1, 4096])
Dropout output shape:  torch.Size([1, 4096])
Linear output shape:  torch.Size([1, 10])
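The 12544 in the Flatten row follows from simple arithmetic: each block halves height and width, so five blocks reduce 224 to 224 / 2^5 = 7, and 256 × 7 × 7 = 12544. The sketch below spells this out. The helper `flatten_size` is hypothetical and introduced only for illustration, and 96 is an arbitrary second input size.

```python
def flatten_size(input_hw, conv_arch):
    """Number of features after the last block, before the first Linear layer."""
    h, w = input_hw
    for _ in conv_arch:
        # MaxPool2d(kernel_size=2, stride=2) halves each side (rounding down)
        h, w = h // 2, w // 2
    out_channels = conv_arch[-1][1]
    return out_channels * h * w


conv_arch = [(1, 32), (1, 64), (2, 128), (2, 256), (2, 256)]
size_224 = flatten_size((224, 224), conv_arch)
size_96 = flatten_size((96, 96), conv_arch)
print(size_224)  # 12544 = 256 * 7 * 7
print(size_96)   # 2304  = 256 * 3 * 3
```

This also shows why the hard-coded `out_channels*7*7` only fits 224×224 inputs: with a different input size, the first `nn.Linear` would need a different `in_features`.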

3. On the FashionMNIST dataset

import torch
import torchvision
from torchvision import transforms


def load_datasets_FashionMNIST(batch_size, resize=None):
    trans = [transforms.ToTensor()]
    if resize:
        # Resize the image first, then convert it to a tensor
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    train_data = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
    test_data = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

    print("FashionMNIST download complete...")
    return (torch.utils.data.DataLoader(train_data, batch_size, shuffle=True),
            torch.utils.data.DataLoader(test_data, batch_size, shuffle=False))

train_iter, test_iter = load_datasets_FashionMNIST(128, resize=224)
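The post does not include the training loop itself, so what follows is only a minimal sketch of one training epoch, not the author's code. The function name `train_epoch`, the SGD optimizer, and the learning rate are all assumptions. To run without downloading anything, the sketch uses a small stand-in network and random data. Note that FashionMNIST images have a single channel, so the real network would have to be built as `VGG(1, conv_arch)` rather than `VGG(3, conv_arch)`.

```python
import torch
from torch import nn


def train_epoch(net, data_iter, loss_fn, optimizer, device):
    """Run one pass over data_iter and return (mean loss, accuracy)."""
    net.to(device)
    net.train()  # enables Dropout during training
    total_loss, correct, count = 0.0, 0, 0
    for X, y in data_iter:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        logits = net(X)
        loss = loss_fn(logits, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * y.numel()
        correct += (logits.argmax(dim=1) == y).sum().item()
        count += y.numel()
    return total_loss / count, correct / count


# Stand-in data and model (assumptions) so the sketch is self-contained.
# In the real setting, pass train_iter and VGG(1, conv_arch) instead.
torch.manual_seed(0)
fake_iter = [(torch.randn(8, 1, 32, 32), torch.randint(0, 10, (8,)))
             for _ in range(3)]
tiny_net = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(), nn.Linear(8 * 16 * 16, 10),
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = torch.optim.SGD(tiny_net.parameters(), lr=0.05)  # assumed settings
loss, acc = train_epoch(tiny_net, fake_iter, nn.CrossEntropyLoss(),
                        optimizer, device)
print(f"loss {loss:.3f}, acc {acc:.3f}")
```

Calling `train_epoch` repeatedly, once per epoch, and evaluating on `test_iter` with `net.eval()` in between would give a complete, if basic, training run.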

[Figure 1: simple training on the FashionMNIST dataset]
