U-Net的pytorch实现

简介

查看U-Net论文请点击此处。U-Net最初是用于细胞识别的,针对少量的训练数据,作者通过数据增强等方式,实现了很好的效果,获得了比赛冠军,U-Net经过修改也可以用于其他用途。
U-Net论文作者提供了caffe的版本,github上也已经有人提供了pytorch的版本,但是经过了修改,本文提供的实现忠于论文的描述,没有研究过作者提供的版本,所以不保证本实现和作者的实现足够接近,仅是个人对论文的理解,如有错误,还望指正,欢迎留言,另外,由于没有对应的训练数据,所以没有严格的验证。

正文

代码如下, 代码后面有说明

import torch
import torch.nn as nn
from torchsummary import summary

class MyUnetDown(nn.Module):
    def __init__(self, in_channels):
        super(MyUnetDown, self).__init__()
        
        self.down_and_conv = nn.Sequential(
            nn.MaxPool2d(
                kernel_size = 2,
                stride=2,
            ),
            nn.Conv2d(
                in_channels = in_channels,
                out_channels = in_channels*2,
                kernel_size = 3,
                stride=1,
                padding=0, # 论文中说unpadded convolutions
                bias=False,
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels = in_channels*2,
                out_channels = in_channels*2,
                kernel_size = 3,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.ReLU(),
        ) # 这个地方如果不小心写了个逗号,就玩完了!!!
        
    def forward(self, x):
        return self.down_and_conv(x)
    

class MyUnetUp(nn.Module):
    def __init__(self, in_channels, cropsize):
        super(MyUnetUp, self).__init__()
        
        # 和最大池化对应,如果遇到奇数,没有padding最大池化会丢弃多余的,所以这里out_padding=0
        # 最大池化大小(形状)上相当于卷积kernel=2, stride=2, 这里是大小(形状)上的逆过程
        # 参数和对应的Conv2d相同, stride=2 卷积回去有歧义, out_padding区分歧义,这里设置为0,对应偶数
        # 详情参见ConvTranspose2d文档以及文档中给的ConvTranspose2d的图示链接
        #
        # ConvTranspose2d根据参数 dilation * (kernel_size - 1) - padding计算真实的padding
        # stride如果不是1还会在内部对应的填充,类似于padding,然后真实的卷积stride是1,再做卷积运算,
        # 这样就对应回去了,如果stride不是1,存在分歧,通过out_padding区分分歧,即先单面加out_padding,
        # 正常的padding是双面的,加完之后再卷积回去,这样得到的结果就在大小(形状)上和Conv2d之前相同了,
        # 相当于大小(形状)逆运算, shape -> Conv2d -> ConvTranspose2d -> shape
        self.up = nn.ConvTranspose2d(
            in_channels = in_channels,
            out_channels = in_channels//2,
            kernel_size=2,
            stride=2,
        )
        
        # 经过测试,发现这样做可以训练,但是不能保存模型!!!
        # # 根据论文中的图,左侧裁剪,然后和右侧拼接在一起,
        # def _crop_and_copy(left, right):
        #     #print("input", left.shape, right.shape)
        #     _, left, _ = torch.tensor_split(left, (cropsize, cropsize+right.shape[2]), dim = 2)
        #     _, left, _ = torch.tensor_split(left, (cropsize, cropsize+right.shape[3]), dim = 3)
        #     #print(cropsize,left.shape, right.shape, torch.cat([left, right], dim=1).shape)
        #     return torch.cat([left, right], dim=1)
        
        #self.crop_and_copy = _crop_and_copy
        self.cropsize = cropsize
        
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels = in_channels,
                out_channels = in_channels//2,
                kernel_size = 3,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels = in_channels//2,
                out_channels = in_channels//2,
                kernel_size = 3,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.ReLU()
        )

    def forward(self, x, left):
        x = self.up(x)
        # 谁有更好的办法?
        #x = self.crop_and_copy(left = left, right = x)
        _, left, _ = torch.tensor_split(left, (self.cropsize, self.cropsize+x.shape[2]), dim = 2)
        _, left, _ = torch.tensor_split(left, (self.cropsize, self.cropsize+x.shape[3]), dim = 3)
            #print(cropsize,left.shape, right.shape, torch.cat([left, right], dim=1).shape)
        x = torch.cat([left, x], dim=1)
        
        return self.conv(x)

    
class MyUnet(nn.Module):
    def __init__(self, depth=4, num_features=64):
        super(MyUnet,self).__init__()

        #depth = 4
        #num_features = 64
        _num_features_dbg = num_features
                
        # 下采样之前的第一步操作,为了便于在上采样过程中使用下采样过程的输出结果
        # 把 下采样+两个卷积ReLU 定义成一个整体, 上采样+两个卷积ReLU 定义成一个整体
        # 这样第一次下采样之前就多出来两个卷积ReLu, 就是in_conv
        # 
        # 如果定义 两个卷积ReLU+下采样 为一个整体, 由于需要提供给
        # 上采样部分的结果是 两个卷积ReLU 的输出, 而不是最后的输出
        # 处理起来比较麻烦
        self.in_conv = nn.Sequential(
            nn.Conv2d(
                in_channels = 1,#论文中是1, 对应灰度图
                out_channels = num_features,
                kernel_size = 3,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels = num_features,
                out_channels = num_features,
                kernel_size = 3,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.ReLU(),
        )
        
        up = [None] * depth
        down = [None] * depth
        
        for i in range(depth):
            down[i] = MyUnetDown(num_features)
            num_features *= 2
        
        cropsize = 4
        for i in range(depth):
            up[i] = MyUnetUp(in_channels=num_features, cropsize=cropsize)
            num_features //= 2
            cropsize = (cropsize + 4) * 2
        
        # 所有 下采样模块
        self.down_list = nn.ModuleList(down)
        # 所有 上采样模块
        self.up_list = nn.ModuleList(up)
        
        # 最后一步操作
        assert(num_features == _num_features_dbg)
        self.out_conv = nn.Conv2d(
            in_channels=num_features,
            out_channels=2,
            kernel_size=1,
        )


    def forward(self, x):
        
        # 保存下采样过程的输出结果,给上采样用, in_conv也算是下采样过程的一部分
        left_results = []

        x = self.in_conv(x)
        
        for down in self.down_list:
            left_results.append(x)
            x = down(x)
            
        for up in self.up_list:
            x = up(x, left = left_results.pop())
        
        assert(len(left_results) == 0)
            
        return self.out_conv(x)

    
unet = MyUnet()
print(unet)
summary(unet, (1,572,572),batch_size=32, device="cpu")

# hiddenlayer 不支持tensor_split, 不知道是否还有其他更好的办法, 谁知道还望告诉我。
# import hiddenlayer as hl 
# graph = hl.build_graph(unet, torch.zeros([1,1,572,572]))
# graph.theme = hl.graph.THEMES["blue"].copy()
# graph.save("/tmp/unet.png", format="png")


from torchviz import make_dot
x = torch.randn(1,1, 572,572).requires_grad_(True)
y = unet(x)
vis = make_dot(y, params=dict(list(unet.named_parameters()) + [('x', x)]))
vis.format = "png"
vis.directory = "/tmp"
vis.view()

运行上面代码可以看到网络的打印结果,还有生成的图片。
U-Net的pytorch实现_第1张图片

关于U-Net网络的结构,全部都在论文中的那张图上,不清楚授权问题,到论文中看吧,这里就不粘贴了,不得不说,作者这张图质量极高,看似简单,但仔看发现,描述得十分清晰。上面的代码也是主要照着这样图片写和检查的。

上面的代码主要将论文图中的节点分为四部分,in_conv,4个down,4个up,out_conv ,down是 2x2最大池化+两个卷积ReLU组合,up是 转置卷积+copy_and_crop_and_cat+两个卷积ReLU组合,out_conv是输出分类数,每个channel一个类,down和up单独定义了一个类,然后通过参数批量生成四个,最后通过nn.ModuleList保存在MyUnet中,可以正常打印出来。copy_and_crop_and_cat这个过程麻烦一点,需要保存down中的结果,在up中使用,并且还要裁剪,然后再按照通道数拼接,这部分不清楚怎样处理比较好,代码中是我个人的直观理解。

因为U-Net中没有使用padding,所以卷积过程中尺寸变小了,up和down是对称的尺寸增加2倍和减少两倍(通道数相反),但由于卷积损失,两边尺寸不同,需要对down中的裁剪,这样和up中的size就相同了,仔细看论文中的图,就很清晰了,最终的输出结果也是,大小比原始图片小,所以在训练和验证的时候,要做特殊处理,见后文。作者这样做的好处是可以无缝拼接和裁剪输入图像,原文说是overlap-tile,应该是指有重叠部分吧,个人观点是有重叠应该会比没重叠效果好,虽然没有padding,但是3x3的卷积,在最内部,考虑一个维度,每个点贡献三次,但是最边缘的点只贡献了一次,次边缘的点贡献两次,所以边缘和内部还是有差别的。

训练的时候会由于边缘部分浪费,如果训练图像比较小,浪费的部分(四边各裁掉92)可能比有用的部分还要多,作者采用的方式是输入超大的图像,减少batch,这样节约显存。关于输入和输出大小,论文图中举例是输入512x512x1,输出388x388x2,但因为没有全链接层,并且固定的对称结构,左右拼接的时候左边(论文图中是down的一侧)裁剪的大小是可以确定的,和输入无关。所以输入大小不一定是512x512x1,可变,输出跟随输入一起变,变化是可计算的,计算方法参考代码注释中的cropsize和channel的计算,另外打印网络信息(summary),输出图形,比如上面的图,也可以看到对应的大小。关于输入大小的选择,也有限制,因为如果宽或高的大小为奇数,2x2最大池化将会丢弃掉最后的那一个余数部分,为了保证每一步最大池化都遇到偶数,推荐从最下面往上推,这样得到的输入尺寸就没问题了。

验证

如果没有和论文作者相同的数据,可能需要修改,输入图像如果是彩色的,应该把输入通道(1)改成3,输出如果是多分类,应该把输出通道(2)改成对应分类数。

多数数据集的输入图像和标签图像应该是大小相同的吧,由于U-Net输出的分割图比输入的小,一般前景分割的部分不应该贴近边缘,否则可能会受边缘影响,并且还要保证输入大小的图像中的数据的分割部分恰好全部在标签图像中。我的处理方法是,先设定一个大小,作为标签图像的大小,然后对于非标签图像,在标签图像的大小上进行padding,padding的时候为了避免0的影响,使用’edge’的方式填充边缘像素,只要前景分割部分不出现在边缘,这样填充之后的也不应该是前景,所以个人认为可以这样处理而对结果影响比较小。个人认为U-Net更适合连续的可无缝拼接在一起的图,这样可以最大的减少边界效应,如果是不同的图片拼接到一起,交界的地方总是奇怪的突变部分,用这部分数据训练应该会对正常的图片中的内容的训练或多或少有影响吧。

pytorch中提供了resize和padding等常见操作,比如torchvision中的transforms.Padtransforms.Resize,可以查找文档。

网络初始化,论文中给出的方法是参考Delving deep into rectifiers: Surpassing human-
level performance on imagenet classification做的初始化,pytorch中默认的也是,但是似乎初始化参数不太相同。根据论文中的描述,我觉得应该是按照下面的参数初始化3x3卷积

@torch.no_grad()
def init_weights(layer):
    print("layer type is ", type(layer))
    if type(layer) == nn.Conv2d:
        nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
        if layer.bias is not None:
            print("bias not none")
            nn.init.constant_(layer.bias, 0)


unet.apply(init_weights)

转置卷积的初始化还不清楚如何设置比较好。

最终的损失函数使用交叉熵(结合softmax),pytorch中直接使用nn.CrossEntropyLoss()就可以了,输入的参数的shape(batchsize, classize, h, w)nn.CrossEntropyLoss()支持这种shape,根据class那个维度来计算,但是论文中使用了不同的权重,增加分割边界部分的权重,也包括类数量不同的对应权重,不同类的权重函数参数可以直接支持,对于图像中不同位置的权重需要自己预先制作一个权重图,论文中给出了公式,如果没有这个权重图,直接计算就可以了,即所有像素权重相同,如果使用权重图,把参数中的reduction设置成'none',这样得到的结果的shape为(batchsize, h, w),和权重图对应位置进行相乘,然后再求和或求平均。

这个网络挺占空间的,batchsize,如果设置过大可能现存不够,或许可以考虑使用将32位浮点数变成16位浮点数的优化技术。

根据《PyTorch深度学习入门与实战》书中fcn网络训练的例子和对应的数据集(很小),对U-Net进行了训练,效果并不好(没有采用对数据集扩充的手段),当然书中的fcn的训练效果也不好。

你可能感兴趣的:(pytorch,计算机视觉,pytorch)