window 学习pytorch unet代码之self.inc = DoubleConv(n_channels, 64)

self.inc = DoubleConv(n_channels, 64)

可以猜测,channels是输入channel,64是要输出的channel

看看DoubleConv函数的具体实现

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

可以看到DoubleConv也是nn.Module的子类windows pytorch Unet网络实现二,网络学习的前两层卷积,就在这里实现了。第一~二句mid_channels=none因此mid_channels就是out_channels=64.

第三句将这两层卷积打包进nn.Sequential(有点像tensorflow中的operation可以将各个tensor放进去一样)。既可以放卷阶层,也可以以dic的形式存放。

第四~九句构建了两层网络,第一层输入3channel,输出64channe第二层输入64channel输出64channel,其中激活函数选择了ReLU,比较流行的激活函数。在conv2d和ReLU中间是nn.BatchNorm2d从字面意思理解是归一化操作。

 

你可能感兴趣的:(unet,pytorch)