Conv2d函数详解(Pytorch)

本文是基于Pytorch框架下的API :Conv2d()。该函数使用在二维输入,另外还有Conv1d()、Conv3d(),其输入分别是一维和三维。下面将介绍Conv2d()的参数。

一、参数介绍

 def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: _size_2_t = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros'  # TODO: refine this type
    ):
  • in_channels:网络输入的通道数。
  • out_channels:网络输出的通道数。
  • kernel_size:卷积核的大小,如果该参数是一个整数q,那么卷积核的大小是qXq。
  • stride:步长。是卷积过程中移动的步长。默认情况下是1。一般卷积核在输入图像上的移动是自左至右,自上至下。如果参数是一个整数那么就默认在水平和垂直方向都是该整数。如果参数是stride=(2, 1),2代表着高(h)进行步长为2,1代表着宽(w)进行步长为1。
  • padding:填充,默认是0填充。
  • dilation:扩张。一般情况下,卷积核与输入图像对应的位置之间的计算是相同尺寸的,也就是说卷积核的大小是3X3,那么它在输入图像上每次作用的区域是3X3,这种情况下dilation=0。当dilation=1时,表示的是下图这种情况。
    Conv2d函数详解(Pytorch)_第1张图片
  • groups:分组。指的是对输入通道进行分组,如果groups=1,那么输入就一组,输出也为一组。如果groups=2,那么就将输入分为两组,那么相应的输出也是两组。另外需要注意的是in_channels和out_channels必须能整除groups。
  • bias:偏置参数,该参数是一个bool类型的,当bias=True时,表示在后向反馈中学习到的参数b被应用。
  • padding_mode:填充模式, padding_mode=‘zeros’表示的是0填充。

二、通过调整参数来感受这些参数
1、结果1

import torch
import torch.nn as nn

# 输入是一个N=20,C=16,H=50,W=100的向量
m = nn.Conv2d(16, 33, 3, stride=2)
input = torch.randn(20, 16, 50, 100)
output = m(input)

print(output.size())

  • 在nn.Conv2d()中第一个参数要和输入的通道数相同(16)。在nn.Conv2d()中第二个参数表示输出的通道数。输出中N=20不变,C=33。通过3X3的卷积核、步长为2,50X100的输入变成了24X49。
torch.Size([20, 33, 24, 49])

2、结果2

import torch
import torch.nn as nn

m = nn.Conv2d(16, 33, 3, stride=(1, 2))
input = torch.randn(20, 16, 50, 100)
output = m(input)

print(output.size())

-上一步中stride=2表示的是stride=(2,2), 这里添加了stride=(1, 2),表示向右步长为1,向下步长为2,输出结果如下:

torch.Size([20, 33, 48, 49])

3、结果3

import torch
import torch.nn as nn

m = nn.Conv2d(16, 33, (3, 5), stride=2, padding=(4, 2))
input = torch.randn(20, 16, 50, 100)
output = m(input)

print(output.size())

  • 这里添加了padding=(4,2),表示在左右方向上添加4圈0填充,在上下方向上添加2圈0,相当于输入是58X104(原来的输入是50X100)。
torch.Size([20, 33, 28, 50])

三、总结
Conv2d()是卷积神经网络的操作函数,了解函数中的参数是用好CNN的关键。

你可能感兴趣的:(深度学习,卷积,神经网络,深度学习,python)