ResNeXt

PyTorch implementation of ResNeXt: https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/resnext.py

Take ResNeXt101_64x4d as an example:

(224, 224, 3)

→【self.features】→(7, 7, 2048)

→【self.avg_pool, ks=7, s=1】→(1, 1, 2048)

→【reshape】→(2048,)→【self.last_linear】→(1000,)
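
A minimal sketch to verify this top-level flow, assuming the pretrainedmodels package from the repository linked above is installed (pretrained=None builds the architecture without downloading weights). Note that PyTorch tensors are (N, C, H, W), while the flow above is written channels-last:

import torch
import pretrainedmodels  # pip install pretrainedmodels

model = pretrainedmodels.__dict__['resnext101_64x4d'](num_classes=1000, pretrained=None)
model.eval()

x = torch.rand(1, 3, 224, 224)
f = model.features(x)        # (1, 2048, 7, 7)
p = model.avg_pool(f)        # (1, 2048, 1, 1)
v = p.view(p.size(0), -1)    # (1, 2048)
y = model.last_linear(v)     # (1, 1000)
print(f.shape, p.shape, v.shape, y.shape)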

Below we focus on the structure of self.features. Source code: https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/resnext_features/resnext101_64x4d_features.py

(224, 224, 3)

→【Conv, Cout=64, ks=7, s=2, p=3】→【bn, relu】→(112, 112, 64)

→【max pool, ks=3, s=2, p=1】→(56, 56, 64)
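
Each spatial size above can be checked with the standard output-size formula for convolution and pooling layers, out = floor((in + 2p - k) / s) + 1. A quick sanity check:

def conv_out(n, k, s, p):
    # output size of a conv/pool layer with kernel k, stride s, padding p
    return (n + 2 * p - k) // s + 1

print(conv_out(224, k=7, s=2, p=3))  # 112 (stem convolution)
print(conv_out(112, k=3, s=2, p=1))  # 56  (max pooling)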

→【Sequence_1】→(56, 56, 256)

→【Sequence_2】→(28, 28, 512)

→【Sequence_3】→(14, 14, 1024)

→【Sequence_4】→(7, 7, 2048)
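
ResNeXt-101 stacks 3, 4, 23, and 3 blocks in the four Sequences, the same layout as ResNet-101. Since self.features is a plain nn.Sequential (see the linked features file), the stage-by-stage shapes can be traced with a loop; a sketch, again assuming the pretrainedmodels package:

import torch
import pretrainedmodels

model = pretrainedmodels.__dict__['resnext101_64x4d'](num_classes=1000, pretrained=None)
model.eval()

# top-level children: Conv2d, BatchNorm2d, ReLU, MaxPool2d, then the 4 Sequences
x = torch.rand(1, 3, 224, 224)
for i, layer in enumerate(model.features):
    x = layer(x)
    print(i, type(layer).__name__, tuple(x.shape))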

Now examine the structure of the first Sequence: (56, 56, 64)→【Sequence_1】→(56, 56, 256)

(56, 56, 64)

→【Sequence_1-Block1】→(56, 56, 256)

→【Sequence_1-Block2】→(56, 56, 256)

→【Sequence_1-Block3】→(56, 56, 256)

The first block, (56, 56, 64)→【Sequence_1-Block1】→(56, 56, 256), keeps the spatial dimensions unchanged and raises the channel count from 64 to 256. The code:

nn.Sequential(  # Sequential
    LambdaMap(lambda x: x,  # ConcatTable: feed the same input to both branches
        nn.Sequential(  # branch 1: 1x1 -> 3x3 (grouped) -> 1x1 bottleneck
            nn.Sequential(
                nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 64, bias=False),  # groups=64
                nn.BatchNorm2d(256),
                nn.ReLU(),
            ),
            nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
            nn.BatchNorm2d(256),
        ),
        nn.Sequential(  # branch 2: 1x1 projection shortcut
            nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
            nn.BatchNorm2d(256),
        ),
    ),
    LambdaReduce(lambda x, y: x + y),  # CAddTable: element-wise sum of the branches
    nn.ReLU(),
)

In the code above, LambdaMap creates the branches and LambdaReduce merges them.
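
Lambda, LambdaMap, and LambdaReduce are small helper classes defined at the top of the linked features file; they emulate Torch's nn.ConcatTable / nn.CAddTable. Paraphrased from the source, they work roughly like this:

from functools import reduce

import torch.nn as nn

class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        # apply every child module to the same input (the "branches")
        output = [module(input) for module in self._modules.values()]
        return output if output else input

class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))

class LambdaMap(LambdaBase):
    def forward(self, input):
        # returns one output per branch, as a list
        return list(map(self.lambda_func, self.forward_prepare(input)))

class LambdaReduce(LambdaBase):
    def forward(self, input):
        # folds the branch outputs into a single tensor (here: a sum)
        return reduce(self.lambda_func, self.forward_prepare(input))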

(56, 56, 64)

Branch 1: →【Conv, Cout=256, ks=1】→(56, 56, 256)→【bn, relu】→(56, 56, 256)
→【Conv, Cout=256, ks=3, p=1, groups=64】→(56, 56, 256)→【bn, relu】→(56, 56, 256)
→【Conv, Cout=256, ks=1】→(56, 56, 256)→【bn】→(56, 56, 256)

Branch 2 (shortcut): →【Conv, Cout=256, ks=1】→(56, 56, 256)→【bn】→(56, 56, 256)

Merge (add): →(56, 56, 256)→【relu】→(56, 56, 256)
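
For comparison, the same block can be written in the more conventional nn.Module style. A hypothetical rewrite (the class name is made up, not from the repository):

import torch
import torch.nn as nn

class Seq1Block1(nn.Module):
    # Equivalent to Sequence_1-Block1: a 1x1 -> 3x3 (grouped) -> 1x1 bottleneck
    # plus a 1x1 projection shortcut, summed and passed through ReLU
    def __init__(self):
        super().__init__()
        self.branch = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, groups=64, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=1, bias=False),
            nn.BatchNorm2d(256),
        )
        self.shortcut = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=1, bias=False),
            nn.BatchNorm2d(256),
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.branch(x) + self.shortcut(x))

x = torch.rand(1, 64, 56, 56)
print(Seq1Block1()(x).shape)  # torch.Size([1, 256, 56, 56])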

The second block, (56, 56, 256)→【Sequence_1-Block2】→(56, 56, 256), is an identity block: the shortcut passes the input through unchanged. The code:

nn.Sequential(  # Sequential
    LambdaMap(lambda x: x,  # ConcatTable: feed the same input to both branches
        nn.Sequential(  # branch 1: 1x1 -> 3x3 (grouped) -> 1x1 bottleneck
            nn.Sequential(
                nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 64, bias=False),  # groups=64
                nn.BatchNorm2d(256),
                nn.ReLU(),
            ),
            nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False),
            nn.BatchNorm2d(256),
        ),
        Lambda(lambda x: x),  # Identity shortcut
    ),
    LambdaReduce(lambda x, y: x + y),  # CAddTable: element-wise sum of the branches
    nn.ReLU(),
)

(56, 56, 256)

Branch 1: →【Conv, Cout=256, ks=1】→(56, 56, 256)→【bn, relu】→(56, 56, 256)
→【Conv, Cout=256, ks=3, p=1, groups=64】→(56, 56, 256)→【bn, relu】→(56, 56, 256)
→【Conv, Cout=256, ks=1】→(56, 56, 256)→【bn】→(56, 56, 256)

Branch 2 (shortcut): →【Identity】→(56, 56, 256)

Merge (add): →(56, 56, 256)→【relu】→(56, 56, 256)

【Appendix】The groups parameter
By default groups=1, which is a regular convolution. For example: (7, 7, 6)→【Conv, Cout=12, ks=3, groups=1】→(5, 5, 12), with a weight of shape (12, 6, 3, 3). (PyTorch stores Conv2d weights as (Cout, Cin/groups, kH, kW); activations are written channels-last here, as in the rest of this post.)

If we instead set groups=3: (7, 7, 6)→【Conv, Cout=12, ks=3, groups=3】→(5, 5, 12). The output size is unchanged, but the weight becomes (12, 2, 3, 3), i.e. 12*2*3*3 = 216 parameters instead of 12*6*3*3 = 648, a reduction by a factor of groups. Clearly the computation is organized differently.

groups=3 means that both the input channels and the output channels are split into 3 groups. Specifically:

The input (7, 7, 6) is split into 3 groups of 6/3 = 2 channels each, each group of size (7, 7, 2):
input[:, :, 0:2], input[:, :, 2:4], input[:, :, 4:6]

The 12 filters are split into 3 groups of 12/3 = 4 filters each, each group of size (4, 2, 3, 3) (note the second dimension is 2, matching the channel count of each input group):
weight[0:4], weight[4:8], weight[8:12]

Each input group is convolved with its own filter group, producing 3 output groups of size (5, 5, 4) each.
Concatenating them along the channel dimension gives the final output of size (5, 5, 12).

Here is a snippet that verifies this:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.rand(1, 6, 7, 7)
layer = nn.Conv2d(6, 12, kernel_size=3, groups=3, bias=False)
output = layer(x)

# reproduce the grouped convolution by hand
output2 = []
w = layer.weight  # shape (12, 2, 3, 3)

for i in range(layer.groups):
    c_in = layer.in_channels // layer.groups     # 2 input channels per group
    x_slice = x[:, i*c_in:(i+1)*c_in]

    c_out = layer.out_channels // layer.groups   # 4 filters per group
    w_slice = w[i*c_out:(i+1)*c_out]

    # each group is just a regular convolution
    output2.append(F.conv2d(x_slice, w_slice))

output2 = torch.cat(output2, dim=1)  # concatenate the groups along channels
print((output - output2).abs().max().item())  # ~0, up to floating-point error
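
An equivalent check is print(torch.allclose(output, output2)), which should print True.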
