详解pytorch CNN操作

一维卷积

一维卷积主要用作降维或者升维。以下所有例子都以语音/NLP的场景讲述,输入的矩阵为batch x T x d。T为一个batch中语音的最长时间(短的语音会加padding),d为特征的维度。一维卷积的作用是对特征进行降维/升维。

torch.nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

in_channels: 输入的通道数
out_channels:输出的通道数
kernel_size:卷积核的大小
stride:每次移动卷积核的间距
padding:每个边padding的元素数
dilation:空洞卷积的参数
groups:分组卷积的组数

最普通的卷积操作

nn.Conv1d(in_channels=256, out_channels=100, kernel_size=2)

# 输出的特征
input = torch.randn(32, 35, 256)
# 加padding是为了让卷积之后的时间维度不变
padding = nn.ConstantPad2d((1, 0, 0, 0,), 0)
# 由于一维卷积是对最后一维操作的,所以得转置后面两维
input = input.transpose(2, 1)
# 先转置,后padding,这样padding就加在了时间维度了
input = padding(input)
# 卷积之后还得恢复到原来的维度
output = conv1(input).transpose(2, 1)
print(output.shape)   # torch.Size([32, 35, 100])

从结果可以发现,我只把最后的维度从256变成了100。一维卷积的卷积核实际上是[in_channels , kernel_size]

一维卷积示意图

分组卷积

空洞卷积

二维卷积

一维卷积是只在某一个维度上操作,而二维卷积是在两个维度上操作。
一维卷积的输入是三维的,而二维卷积的输入是四维的。

class Net(nn.Module):
    def __init__(self, idim, odim):
        super(Net, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, odim, 4, 2),  # 1为输入通道数,odim为输出通道数,kernel为(4,4),stride为(2,2)
            nn.ReLU(),
            nn.Conv2d(odim, odim, 3, 2), # kernel为(3,3),stride为(2,2)
            nn.ReLU()
        )
        # print((((idim - 1) // 2 - 1) // 2))
        self.out = nn.Sequential(
            nn.Linear(odim * ((((idim - 4) // 2 + 1) - 3) // 2 + 1), odim)
        )

    def forward(self, x):
        # 输入通道为1
        x = x.unsqueeze(1)
        print('x:', x.shape)
        y = self.conv(x)
        print(y.shape)
        b, c, t, f = y.size()
        # 将输出的通道数和时间维度交换,通过全连接转换特征维度
        o = self.out(y.transpose(1, 2).contiguous().view(b, t, c * f))
        print(o.shape)

x = torch.ones(32, 100, 320)
model = Net(320, 512)
print(model)
model(x)

# 输出为 torch.Size([32, 24, 512])

通过二维卷积操作,可以将时间长度减少了4倍,特征维度从320映射到了512

输出维度的计算方法:
输入的矩阵: T x d
Filter: F x F
步长: S x S
padding: P x P
则输出的矩阵大小为 ((T-F+2P)/S+1) x ((d-F+2P)/S)+1

二维卷积。每个filter在时间和特征维度上都要移动

分组卷积

空洞卷积

你可能感兴趣的:(详解pytorch CNN操作)