Pytorch关于卷积核(Conv2d)的简单操作与模型修剪

在使用Pytorch搭建深度学习算法时,torch.nn.Conv2d是用得最多的函数之一。Conv2d函数主要是对输入数据做卷积运算。输入参数如下图:

Pytorch关于卷积核(Conv2d)的简单操作与模型修剪_第1张图片

torch.nn.Conv2d函数所生成的卷积核主要包括weights与bias,及权重与偏置。在深度学习模型训练过程中,模型主要更新的就是卷积核的weights与bias,对于基于大数据集训练的模型,随着训练的迭代卷积核的层数越来越深,不免会有一些卷积核通道是冗余的。可以通过对卷积核通道的操作丢弃一些卷积层与达到精简模型的效果。举个简单例子:

import torch

conv = torch.nn.Conv2d(2,4,3)

Weight = conv.weight
Bias = conv.bias

print(Weight)  
print(Bias)

weight_select = torch.cat((conv.weight[:, 0:1], conv.weight[:, 1 + 1:]), dim=1)
print(conv.weight[:, 0:1].shape)
print(Weight.shape,"     ",weight_select.shape  )

以上代码生成4个3x3x2的卷积核,可以直接索weight和bias查看参数,第七行代码则是对每一个卷积核进行通道的选择然后拼接起来,原始卷积核和进行通道选择后的卷积核的维度打印出来后是不一样的,分别是:

torch.Size([4, 4, 3, 3])       torch.Size([4, 3, 3, 3])

每一个卷积核少了一个通道。需要注意的是,如果是在模型训练过程中对卷积核通道进行挑选的话,单个卷积核的通道数改变的话相对应的输入层的通道数也需改变。例如本来当前卷积核的维度是torch.Size([4, 4, 3, 3]),变成 torch.Size([4, 3, 3, 3])后当前对应的输入数据通道数也要改变否则下一代训练会报错。

你可能感兴趣的:(pytorch,深度学习,人工智能)