nn.ModuleList 和 nn.Sequential 的由来、用法和实例
注意：本博文详细记述了 nn.ModuleList 和 nn.Sequential 的区别以及用法，解释比较清晰。
问题：nn.Sequential 对象在构造时并没有传入任何输入参数，为什么之后可以像函数一样直接调用它并传入输入（例如下面 forward 中的 self.conv11(x)）？
实例：
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling (ASPP) block with three parallel branches.

    Every branch maps ``in_channels`` to ``out_channels``:

    * ``conv11`` -- 1x1 convolution + BatchNorm.
    * ``conv33`` -- 3x3 dilated convolution + BatchNorm.
    * ``conv_p`` -- image-level branch: global average pooling, bilinear
      upsampling back to the input's spatial size, then 1x1 convolution +
      BatchNorm.

    The branch outputs are concatenated along the channel dimension and fused
    back to ``out_channels`` by ``concate_conv`` (1x1 convolution + BatchNorm).

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by each branch and by the block.
        paddings: Padding of the 3x3 dilated convolution (int or tuple, as
            accepted by ``nn.Conv2d``).
        dilations: Dilation of the 3x3 convolution (int or tuple).

    Note:
        ``forward`` concatenates the branches, so the 3x3 branch must keep the
        spatial size. With stride 1 this requires ``paddings == dilations``.
    """

    def __init__(self, in_channels, out_channels, paddings, dilations):
        # TODO: depthwise separable conv
        super().__init__()
        self.conv11 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        self.conv33 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3,
                      padding=paddings, dilation=dilations, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        self.conv_p = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        # Fuses the three concatenated branches (3 * out_channels) back to out_channels.
        self.concate_conv = nn.Sequential(
            nn.Conv2d(out_channels * 3, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        """Run the three branches on ``x`` and fuse their outputs.

        Args:
            x: Feature map of shape (N, in_channels, H, W). BatchNorm2d and
                the ``x.shape[2:]`` spatial size used below require a 4-D input.

        Returns:
            Tensor of shape (N, out_channels, H, W) when ``paddings == dilations``.
        """
        # The nn.Sequential attributes were built without any input, yet they are
        # called with x here: nn.Module.__call__ dispatches to Sequential.forward,
        # which feeds x through each sub-module in turn (see the explanation below).
        conv11 = self.conv11(x)
        conv33 = self.conv33(x)

        # Image-level branch: global average pool to 1x1, then upsample back to
        # (H, W). The functional API avoids constructing new nn.AvgPool2d /
        # nn.Upsample modules on every forward call.
        spatial_size = x.shape[2:]
        image_pool = nn.functional.adaptive_avg_pool2d(x, 1)
        upsample = nn.functional.interpolate(
            image_pool, size=spatial_size, mode='bilinear', align_corners=True)
        # NOTE: conv_p is applied after upsampling, as in the original code.
        # Presumably this is so that BatchNorm in training mode sees more than
        # one value per channel even when N == 1.
        upsample = self.conv_p(upsample)

        # Concatenate along the channel dim, then fuse back to out_channels.
        concate = torch.cat([conv11, conv33, upsample], dim=1)
        return self.concate_conv(concate)
nn.Sequential 是继承自 nn.Module 的类，nn.Conv2d 等层同样是继承自 nn.Module 的类；nn.Module 实现了 __call__ 方法，当我们像函数一样调用模块实例（如 self.conv11(x)）时，__call__ 会转而调用子类中定义的 forward 方法，并把参数传给它；
nn.Sequential 内部定义了 forward 方法，该方法接收 input 参数，并把它依次传给容器中的每个子模块（上一个子模块的输出作为下一个子模块的输入），最后返回最终结果；
#source:https://pytorch.org/docs/stable/_modules/torch/nn/modules/container.html#Sequential
class Sequential(Module):
    """Container that chains its sub-modules (excerpt: only ``forward`` is shown)."""

    def forward(self, input):
        """Pass ``input`` through every sub-module, in the order of ``self._modules``.

        Each sub-module receives the previous one's output, and the last output
        is returned. With no sub-modules, ``input`` is returned unchanged.
        """
        output = input
        for layer in self._modules.values():
            output = layer(output)
        return output
#source:https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#Conv2d
class Conv2d(_ConvNd):
    # Excerpt: only forward() is shown; the constructor is not part of this snippet.
    def forward(self, input):
        """Apply the 2-D convolution to ``input``.

        With ``padding_mode == 'circular'`` the input is first padded explicitly
        via ``F.pad(..., mode='circular')``, and the convolution then runs with
        no built-in padding (``_pair(0)``). Otherwise ``F.conv2d`` applies
        ``self.padding`` directly.
        """
        if self.padding_mode == 'circular':
            # F.pad's tuple is ordered from the last dim: (left, right, top, bottom).
            # Width padding therefore comes from padding[1] and height padding from
            # padding[0]. Each dim gets (p + 1) // 2 before and p // 2 after, i.e.
            # p in total.
            # NOTE(review): zero padding via F.conv2d adds p on *each* side, so the
            # two branches can produce different output sizes. Newer PyTorch
            # releases reportedly pad p on both sides here; confirm which version
            # this excerpt was copied from.
            expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
                                (self.padding[0] + 1) // 2, self.padding[0] // 2)
            return F.conv2d(F.pad(input, expanded_padding, mode='circular'),
                            self.weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
#source:https://pytorch.org/docs/stable/_modules/torch/nn/modules/activation.html#ReLU
class ReLU(Module):
    """Element-wise rectified linear unit (excerpt: only ``forward`` is shown)."""

    def forward(self, input):
        """Return ``F.relu(input)``.

        When ``self.inplace`` is true, ``input`` itself is overwritten with the
        result. NOTE: ``self.inplace`` is never assigned in this excerpt; it is
        presumably set by the constructor, which is not shown.
        """
        return F.relu(input, inplace=self.inplace)