depthwise separable conv = depthwise conv + pointwise conv
Inception is one of the best-performing backbone networks: it learns rich features with relatively few parameters.
The basic module of Inception v3 is structured as follows:
A convolution filter learns in a 3D space (2 spatial dimensions and 1 channel dimension), so it has to learn cross-channel correlations and spatial correlations simultaneously.
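Inception partially decouples the two: it first uses 1x1 convolutions to learn cross-channel correlations, and then uses 3x3 convolutions, which still learn channel and spatial correlations jointly, but in a smaller channel space.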
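A quick way to see the difference is to look at weight shapes. The sketch below uses arbitrary channel counts (32 in, 16 reduced, 64 out) chosen only for illustration, and reduces Inception to a single 1x1 -> 3x3 branch; a real Inception module has several parallel branches.

```python
import torch.nn as nn

# Regular 3x3 conv: each of the 64 filters spans all 32 input channels
# and a 3x3 window, so channel and spatial correlations share one tensor.
regular = nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False)
print(tuple(regular.weight.shape))  # (64, 32, 3, 3)
print(regular.weight.numel())       # 18432

# Single Inception-style branch (illustrative): a 1x1 conv mixes channels
# first, then a 3x3 conv operates on the reduced 16-channel space.
reduce_1x1 = nn.Conv2d(32, 16, kernel_size=1, bias=False)
spatial_3x3 = nn.Conv2d(16, 64, kernel_size=3, padding=1, bias=False)
print(tuple(reduce_1x1.weight.shape))   # (16, 32, 1, 1)
print(tuple(spatial_3x3.weight.shape))  # (64, 16, 3, 3)
print(reduce_1x1.weight.numel() + spatial_3x3.weight.numel())  # 9728
```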
A depthwise separable conv can be viewed as a special case of the Inception module.
A depthwise separable conv differs slightly from the "extreme" version of an Inception module shown in Figure 4:
depthwise separable conv = depthwise conv + pointwise conv, applied in that order: first the depthwise conv, then the 1x1 pointwise conv (the "extreme" Inception module applies the 1x1 convolution first).
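To make the saving concrete, here is the usual parameter count for a k x k kernel with C_in input and C_out output channels, assuming no biases and one depthwise filter per input channel (depth multiplier 1):

```latex
\begin{aligned}
P_{\text{std}} &= k^2 \, C_{\text{in}} \, C_{\text{out}} \\
P_{\text{sep}} &= \underbrace{k^2 \, C_{\text{in}}}_{\text{depthwise}}
               + \underbrace{C_{\text{in}} \, C_{\text{out}}}_{\text{pointwise}} \\
\frac{P_{\text{sep}}}{P_{\text{std}}} &= \frac{1}{C_{\text{out}}} + \frac{1}{k^2}
\end{aligned}
```

For example, with k = 3, C_in = 32 and C_out = 64: P_std = 18432, while P_sep = 288 + 2048 = 2336, a ratio of about 0.127 (1/64 + 1/9).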
pointwise conv: a 1x1 convolution that learns cross-channel correlations and maps the feature map from one channel space to another.
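A 1x1 conv is simply a (C_out x C_in) matrix applied independently at every pixel. The sketch below checks this numerically with PyTorch; the channel counts and input size are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
c_in, c_out = 8, 16                 # illustrative channel counts
x = torch.randn(2, c_in, 5, 5)      # (N, C, H, W)

pointwise = nn.Conv2d(c_in, c_out, kernel_size=1, bias=False)
y_conv = pointwise(x)

# Drop the two trailing 1x1 kernel dims to get a plain (c_out, c_in) matrix,
# then apply it to the channel vector at every (h, w) position.
w = pointwise.weight.view(c_out, c_in)
y_matmul = torch.einsum("oc,nchw->nohw", w, x)

print(y_conv.shape)                                  # torch.Size([2, 16, 5, 5])
print(torch.allclose(y_conv, y_matmul, atol=1e-5))   # True
```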
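depthwise conv: one filter per input channel, i.e. the number of filters equals the number of channels, and each filter is convolved with a single channel only. It learns spatial correlations.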
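In PyTorch a depthwise conv is an `nn.Conv2d` with `groups` equal to the number of channels. The sketch below, using an arbitrary 4-channel input, checks that this matches convolving each channel separately with its own filter via `torch.nn.functional.conv2d`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
channels = 4                          # illustrative channel count
x = torch.randn(1, channels, 6, 6)

# groups=channels: one 3x3 filter per channel, weight shape (C, 1, 3, 3).
depthwise = nn.Conv2d(channels, channels, kernel_size=3, padding=1,
                      groups=channels, bias=False)
print(tuple(depthwise.weight.shape))  # (4, 1, 3, 3)
y = depthwise(x)

# Reference: convolve channel c only with filter c, then stack the results.
per_channel = [
    F.conv2d(x[:, c:c + 1], depthwise.weight[c:c + 1], padding=1)
    for c in range(channels)
]
y_ref = torch.cat(per_channel, dim=1)
print(torch.allclose(y, y_ref, atol=1e-5))  # True
```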
Note: there is no activation function after the depthwise conv.
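The Xception paper's explanation: for deep feature spaces, the non-linearity is helpful, but for shallow ones (such as the one-channel-deep feature spaces a depthwise conv operates on), it becomes harmful, possibly due to a loss of information.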
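The two differences (operation order and the non-linearity after the first operation) can be sketched side by side with `nn.Sequential`. The sizes are illustrative assumptions, and the "extreme" Inception here follows the Xception paper's description: a 1x1 conv, then a separate 3x3 spatial conv per output channel, each followed by ReLU.

```python
import torch
import torch.nn as nn

c_in, c_out, k = 32, 64, 3  # illustrative sizes, not from the original text

# "Extreme" Inception: pointwise first, then per-channel 3x3, ReLU after each.
extreme_inception = nn.Sequential(
    nn.Conv2d(c_in, c_out, 1, bias=False),
    nn.ReLU(),
    nn.Conv2d(c_out, c_out, k, padding=k // 2, groups=c_out, bias=False),
    nn.ReLU(),
)

# Depthwise separable conv: depthwise first, then pointwise, no activation
# between the two operations.
depthwise_separable = nn.Sequential(
    nn.Conv2d(c_in, c_in, k, padding=k // 2, groups=c_in, bias=False),
    nn.Conv2d(c_in, c_out, 1, bias=False),
)

def num_params(module):
    return sum(p.numel() for p in module.parameters())

x = torch.randn(1, c_in, 10, 10)
print(extreme_inception(x).shape)       # torch.Size([1, 64, 10, 10])
print(depthwise_separable(x).shape)     # torch.Size([1, 64, 10, 10])
print(num_params(extreme_inception))    # 2624 = 32*64 + 64*9
print(num_params(depthwise_separable))  # 2336 = 32*9 + 32*64
```

Because the depthwise step runs before the channel expansion, the separable version applies its spatial filters to 32 channels instead of 64, which is why it ends up slightly smaller here.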
PyTorch implementation:
import torch.nn as nn


class ConvBnRelu(nn.Module):
    """Conv2d -> optional norm layer -> optional ReLU."""

    def __init__(self, in_planes, out_planes, ksize, stride, pad, dilation=1,
                 groups=1, has_bn=True, norm_layer=nn.BatchNorm2d, bn_eps=1e-5,
                 has_relu=True, inplace=True, has_bias=False):
        super(ConvBnRelu, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=ksize,
                              stride=stride, padding=pad,
                              dilation=dilation, groups=groups, bias=has_bias)
        self.has_bn = has_bn
        if self.has_bn:
            self.bn = norm_layer(out_planes, eps=bn_eps)
        self.has_relu = has_relu
        if self.has_relu:
            self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        x = self.conv(x)
        if self.has_bn:
            x = self.bn(x)
        if self.has_relu:
            x = self.relu(x)
        return x


class SeparableConvBnRelu(nn.Module):
    """Depthwise conv -> BN (no activation) -> 1x1 pointwise ConvBnRelu."""

    def __init__(self, in_channels, out_channels,
                 kernel_size=1, stride=1, padding=0, dilation=1,
                 has_relu=True, norm_layer=nn.BatchNorm2d):
        super(SeparableConvBnRelu, self).__init__()
        # Depthwise: groups=in_channels gives one filter per input channel.
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride,
                                   padding, dilation, groups=in_channels,
                                   bias=False)
        self.bn = norm_layer(in_channels)
        # Pointwise: 1x1 conv that mixes channels and maps to out_channels.
        self.point_wise_cbr = ConvBnRelu(in_channels, out_channels, 1, 1, 0,
                                         has_bn=True, norm_layer=norm_layer,
                                         has_relu=has_relu, has_bias=False)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.bn(x)  # no activation after the depthwise conv
        x = self.point_wise_cbr(x)
        return x
Simplified version: