# UNet walkthrough

import torch
import torch.nn as nn
import torchvision

class UNetFactory(nn.Module):
    """
    Generic U-shaped network: encode, optionally bridge, then decode.

    The encoder returns its deepest feature map plus skip feature maps;
    the decoder concatenates the skips back in so spatial detail lost to
    downsampling can be recovered. The bridge (if any) must contain no
    downsampling or upsampling ops.
    """

    def __init__(self, encoder_blocks, decoder_blocks, bridge=None):
        super().__init__()
        self.encoder = UNetEncoder(encoder_blocks)  # forward returns a list
        self.bridge = bridge
        self.decoder = UNetDecoder(decoder_blocks)

    def forward(self, x):
        encoded = self.encoder(x)
        # First element is the deepest output; the rest are the skips.
        out = encoded[0]
        skips = encoded[1:]
        if self.bridge is not None:
            out = self.bridge(out)
        return self.decoder(out, skips)

class UNetEncoder(nn.Module):
    """
    U-Net encoder: a sequence of blocks separated by downsampling.

    Convention: the first block performs no downsampling and every
    subsequent block contains exactly one downsampling op. The output of
    every block except the last is cached as a skip connection for the
    decoder.
    """

    def __init__(self, blocks):
        """
        Args:
            blocks: non-empty sequence of nn.Module encoder blocks.

        Raises:
            ValueError: if ``blocks`` is empty.
        """
        super(UNetEncoder, self).__init__()
        # Raise instead of assert: asserts are stripped under `python -O`.
        if len(blocks) == 0:
            raise ValueError("UNetEncoder requires at least one block")
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        """Return ``[final_output, skip_0, skip_1, ...]`` as one flat list
        (a single container is the simplest way to return a variable
        number of tensors)."""
        skips = []
        for block in self.blocks[:-1]:
            x = block(x)
            skips.append(x)
        # Fix: index the last block explicitly. The original used
        # `self.blocks[i+1]` with the leaked loop variable `i`, which is
        # undefined when there is only a single block (NameError).
        return [self.blocks[-1](x)] + skips

class UNetDecoder(nn.Module):
    """
    U-Net decoder: a sequence of blocks separated by upsampling.

    Convention: every block except the last contains exactly one
    upsampling op, and every block after the first starts by
    concatenating the matching encoder skip onto its input.
    """

    def __init__(self, blocks):
        """
        Args:
            blocks: sequence of nn.Module decoder blocks; at least two.
        """
        super(UNetDecoder, self).__init__()
        assert len(blocks) > 1
        self.blocks = nn.ModuleList(blocks)

    def _center_crop(self, skip, x):
        """Center-crop whichever of ``skip``/``x`` is larger so both end
        up with the minimal common height/width (ht x wt); the smaller
        tensor is returned unchanged (offset 0, full extent)."""
        _, _, h1, w1 = skip.shape
        _, _, h2, w2 = x.shape
        ht, wt = min(h1, h2), min(w1, w2)
        dh1 = (h1 - ht) // 2 if h1 > ht else 0  # // floors the centering offset
        dw1 = (w1 - wt) // 2 if w1 > wt else 0
        dh2 = (h2 - ht) // 2 if h2 > ht else 0
        dw2 = (w2 - wt) // 2 if w2 > wt else 0
        # Fix: the original return statement was split over two lines
        # without parentheses/backslash, which is a SyntaxError. Each
        # slice spans exactly (dh + ht) - dh = ht rows and wt columns.
        return (skip[:, :, dh1:dh1 + ht, dw1:dw1 + wt],
                x[:, :, dh2:dh2 + ht, dw2:dw2 + wt])

    def forward(self, x, skips, reverse_skips=True):
        """
        Args:
            x: deepest encoder (or bridge) output.
            skips: encoder skip tensors, shallow-to-deep order by default.
            reverse_skips: reverse ``skips`` so they are consumed
                deep-to-shallow, matching the decoder's direction.
        """
        assert len(skips) == len(self.blocks) - 1
        if reverse_skips:
            skips = skips[::-1]
        x = self.blocks[0](x)
        for i in range(1, len(self.blocks)):
            skip, x = self._center_crop(skips[i - 1], x)
            x = torch.cat([skip, x], dim=1)  # channel-wise concatenation
            x = self.blocks[i](x)
        return x

def unet_convs(in_channels, out_channels, padding=1):
    """
    The double conv3x3 (+ BatchNorm + ReLU) block used throughout U-Net.

    Note: the original paper uses *unpadded* convolutions (padding=0),
    which shrink each spatial dimension by 4 per block; the default here
    is padding=1, which preserves the spatial size. Pass padding=0 to
    reproduce the paper exactly.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding, bias=False),
        nn.BatchNorm2d(out_channels),  # bias=False above: BN adds its own shift
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )


def unet(in_channels, out_channels):
    """
    Build the U-Net of Ronneberger et al.
    https://arxiv.org/abs/1505.04597

    Encoder widths double 64 -> 512, a 1024-channel bridge sits at the
    bottom, and the decoder mirrors the encoder with up-convolutions,
    ending in a 1x1 conv that maps to ``out_channels``.
    """
    widths = [64, 128, 256, 512]

    # Encoder: a plain double-conv first, then maxpool + double-conv
    # stages, ending with a bare maxpool that feeds the bridge.
    encoder_blocks = [unet_convs(in_channels, widths[0])]
    for narrow, wide in zip(widths, widths[1:]):
        encoder_blocks.append(nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
            unet_convs(narrow, wide),
        ))
    encoder_blocks.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True))

    # Bridge: double-conv at the deepest resolution (512 -> 1024).
    bridge = nn.Sequential(unet_convs(widths[-1], 2 * widths[-1]))

    # Decoder: a bare up-conv first, then (double-conv, up-conv) pairs,
    # ending with double-conv + 1x1 projection.
    decoder_blocks = [
        nn.ConvTranspose2d(2 * widths[-1], widths[-1], kernel_size=2, stride=2)
    ]
    descending = widths[::-1]
    for wide, narrow in zip(descending, descending[1:]):
        decoder_blocks.append(nn.Sequential(
            unet_convs(2 * wide, wide),  # input doubled by the skip concat
            nn.ConvTranspose2d(wide, narrow, kernel_size=2, stride=2),
        ))
    decoder_blocks.append(nn.Sequential(
        unet_convs(2 * widths[0], widths[0]),
        nn.Conv2d(widths[0], out_channels, kernel_size=1),
    ))

    return UNetFactory(encoder_blocks, decoder_blocks, bridge)

def unet_resnet(resnet_type, in_channels, out_channels, pretrained=False):
    """
    Build a U-Net whose encoder is a torchvision ResNet; the decoder is
    adapted so the output spatial size matches the input.

    Raises:
        ValueError: if ``resnet_type`` is not a supported variant.
    """
    # Per-block encoder output channels (raw input, conv1, conv2_x..conv5_x).
    basic_widths = [in_channels, 64, 64, 128, 256, 512]        # BasicBlock nets
    bottleneck_widths = [in_channels, 64, 256, 512, 1024, 2048]  # Bottleneck nets
    catalog = {
        'resnet18': (torchvision.models.resnet.resnet18, basic_widths),
        'resnet34': (torchvision.models.resnet.resnet34, basic_widths),
        'resnet50': (torchvision.models.resnet.resnet50, bottleneck_widths),
        'resnet101': (torchvision.models.resnet.resnet101, bottleneck_widths),
        'resnet152': (torchvision.models.resnet.resnet152, bottleneck_widths),
        'resnext50_32x4d': (torchvision.models.resnet.resnext50_32x4d, bottleneck_widths),
    }
    if resnet_type not in catalog:
        raise ValueError("unexpected resnet_type")
    build, encoder_out_channels = catalog[resnet_type]
    resnet = build(pretrained)

    # Encoder blocks, split at each downsampling stage.
    encoder_blocks = [
        nn.Sequential(),                                       # raw input as first skip
        nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu),  # conv1
        nn.Sequential(resnet.maxpool, resnet.layer1),          # conv2_x
        resnet.layer2,                                         # conv3_x
        resnet.layer3,                                         # conv4_x
        resnet.layer4,                                         # conv5_x
    ]

    # The ResNet trunk is deep enough; no extra bridge needed.
    bridge = None

    # Decoder: a bare up-conv first, then (double-conv, up-conv) pairs
    # that halve the channel count, ending with a 1x1 projection.
    decoder_blocks = []
    half = encoder_out_channels[-1] // 2
    decoder_blocks.append(
        nn.ConvTranspose2d(encoder_out_channels[-1], half, kernel_size=2, stride=2)
    )
    for depth in range(1, len(encoder_blocks) - 1):
        merged = encoder_out_channels[-depth - 1] + half  # channels after skip concat
        decoder_blocks.append(nn.Sequential(
            unet_convs(merged, half, padding=1),
            nn.ConvTranspose2d(half, half // 2, kernel_size=2, stride=2),
        ))
        half //= 2
    decoder_blocks.append(nn.Sequential(
        unet_convs(encoder_out_channels[0] + half, half, padding=1),
        nn.Conv2d(half, out_channels, kernel_size=1),
    ))

    return UNetFactory(encoder_blocks, decoder_blocks, bridge)


if __name__ == "__main__":
    # Smoke test: build the paper-style U-Net and print its architecture.
    # (A commented-out demo that ran an image through the net and saved
    # the result was removed here as dead code.)
    net = unet(3, 8)
    print(net)

# Related topics: deep learning, natural language processing, pytorch