使用pytorch实现Inception模块

在pytorch中没有找到Inception模块,自己写了一个,以供调用。
Inception模块的顺序为:
1. 输入 -> 1*1卷积 -> BatchNorm -> ReLU -> 1*5卷积 -> BatchNorm -> ReLU
2. 输入 -> 1*1卷积 -> BatchNorm -> ReLU -> 1*3卷积 -> BatchNorm -> ReLU
3. 输入 -> 池化 -> 1*1卷积 -> BatchNorm -> ReLU
4. 输入 -> 1*1卷积 -> BatchNorm -> ReLU
其中,1和2步骤可以重复多次。最后将所有结果串接起来。
pytorch中实现如下,应用例子见我的下一篇文章:openface(三):卷积网络。

import torch.nn as nn
class Inception(nn.Module):
    def __init__(self, inputSize, kernelSize, kernelStride, outputSize, reduceSize, pool):
         # inputSize:输入尺寸
         # kernelSize:第1步骤和第2步骤中第二个卷积核的尺寸,是一个列表
         # kernelStride:同上
         # outputSize:同上
         # reduceSize:1*1卷积中的输出尺寸,是一个列表
         # pool: 是一个池化层
        super(Inception, self).__init__()
        self.layers = {}
        poolFlag = True
        fname = 0
        for p in kernelSize, kernelStride, outputSize, reduceSize:
            if len(p) == 4:
                (_kernel, _stride, _output, _reduce) = p
                self.layers[str(fname)] = nn.Sequential(
                    # Convolution 1*1
                    nn.Conv2d(inputSize, _reduce, 1),
                    nn.BatchNorm2d(_reduce),
                    nn.ReLU(),
                    # Convolution kernel*kernel
                    nn.Conv2d(_reduce, _output, _kernel, _stride),
                    nn.BatchNorm2d(_output),
                    nn.ReLU())
            else:
                if poolFlag:
                    assert len(p) == 1
                    self.layers[str(fname)] = nn.Sequential(
                        # pool
                        pool, #这里的输出尺寸需要考虑一下
                        nn.Conv2d(inputSize, p, 1),
                        nn.BatchNorm2d(p),
                        nn.ReLU())
                    poolFlag = False
                else:
                    assert len(p) == 1
                    self.layers[str(fname)] = nn.Sequential(
                        # Convolution 1*1
                        nn.Conv2d(inputSize, p, 1),
                        nn.BatchNorm2d(p),
                        nn.ReLU())
            fname += 1

        if poolFlag:
            self.layers[str(fname)] = nn.Sequential(pool)
            poolFlag = False
    def forward(self, x):
        for key, layer in self.layers.items:
            if key == str(0):
                out = layer(x)
            else:
                out = torch.cat((out, layer(x)), 1) #因为有Batch,所以是在第1维方向串接。
        return out

你可能感兴趣的:(人脸识别)