import torch
import torch.nn as nn
import torchvision
class UNetFactory(nn.Module):
    """Generic U-shaped network: encoder -> optional bridge -> decoder.

    The encoder hands back skip feature maps that the decoder concatenates
    in again to recover spatial detail. The bridge (if provided) must not
    change the spatial resolution (no down-/up-sampling inside it).
    """

    def __init__(self, encoder_blocks, decoder_blocks, bridge=None):
        super(UNetFactory, self).__init__()
        self.encoder = UNetEncoder(encoder_blocks)
        self.bridge = bridge
        self.decoder = UNetDecoder(decoder_blocks)

    def forward(self, x):
        encoded = self.encoder(x)
        bottom, skips = encoded[0], encoded[1:]
        if self.bridge is not None:
            bottom = self.bridge(bottom)
        return self.decoder(bottom, skips)
class UNetEncoder(nn.Module):
    """U-Net encoder.

    The encoder is split into blocks along downsampling boundaries: the
    first block performs no downsampling, and each subsequent block contains
    exactly one downsampling op. The output of every block except the last
    is cached as a skip connection for the decoder.
    """

    def __init__(self, blocks):
        """
        Args:
            blocks: non-empty sequence of nn.Module encoder stages.
        """
        super(UNetEncoder, self).__init__()
        assert len(blocks) > 0
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        """Return ``[final_output, skip_1, skip_2, ...]`` (skips in forward order)."""
        skips = []
        # Every block except the last contributes a skip connection.
        for block in self.blocks[:-1]:
            x = block(x)
            skips.append(x)
        # Apply the last block explicitly instead of reusing the loop index:
        # the original `self.blocks[i+1](x)` raised NameError when only one
        # block was given (the loop body never ran, so `i` was unbound).
        x = self.blocks[-1](x)
        return [x] + skips
class UNetDecoder(nn.Module):
    """U-Net decoder.

    The decoder is split into blocks along upsampling boundaries: every
    block except the last contains exactly one upsampling op. Before every
    block but the first, the feature map is concatenated with the matching
    encoder skip, after center-cropping to reconcile any size mismatch.
    """

    def __init__(self, blocks):
        super(UNetDecoder, self).__init__()
        assert len(blocks) > 1
        self.blocks = nn.ModuleList(blocks)

    def _center_crop(self, skip, x):
        """Center-crop `skip` and `x` to their common (minimal) height/width.

        Whichever tensor is larger in a spatial dimension is cropped in that
        dimension; both returned tensors end up with height `ht` and width
        `wt`, the element-wise minima of the two shapes.
        """
        _, _, h1, w1 = skip.shape
        _, _, h2, w2 = x.shape
        ht, wt = min(h1, h2), min(w1, w2)
        dh1 = (h1 - ht) // 2 if h1 > ht else 0
        dw1 = (w1 - wt) // 2 if w1 > wt else 0
        dh2 = (h2 - ht) // 2 if h2 > ht else 0
        dw2 = (w2 - wt) // 2 if w2 > wt else 0
        # BUGFIX: the original ended the first line with a bare trailing
        # comma, returning a 1-tuple and leaving the second expression
        # unreachable; callers unpacking two values then raised ValueError.
        return (skip[:, :, dh1: (dh1 + ht), dw1: (dw1 + wt)],
                x[:, :, dh2: (dh2 + ht), dw2: (dw2 + wt)])

    def forward(self, x, skips, reverse_skips=True):
        """Decode `x`, concatenating one skip before every block but the first.

        Args:
            x: bottom feature map coming from the encoder/bridge.
            skips: encoder skip tensors, one per upsampling block.
            reverse_skips: if True, `skips` is in encoder (forward) order
                and is reversed here to match the decoder's traversal order.
        """
        assert len(skips) == len(self.blocks) - 1
        if reverse_skips:
            skips = skips[::-1]
        x = self.blocks[0](x)
        for i in range(1, len(self.blocks)):
            skip, x = self._center_crop(skips[i - 1], x)
            x = torch.cat([skip, x], dim=1)
            x = self.blocks[i](x)
        return x
def unet_convs(in_channels, out_channels, padding=1):
    """Two 3x3 convolutions, each followed by BatchNorm and ReLU.

    This is the double-convolution building block that recurs throughout
    the U-Net paper (there the convolutions are unpadded; here `padding`
    defaults to 1, which preserves the spatial size).
    """
    layers = []
    for ch_in in (in_channels, out_channels):
        layers.append(nn.Conv2d(ch_in, out_channels, kernel_size=3,
                                padding=padding, bias=False))
        layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)
def unet(in_channels, out_channels):
    """Build the U-Net architecture from the original paper.

    https://arxiv.org/abs/1505.04597

    Args:
        in_channels: number of channels in the input image.
        out_channels: number of output channels (e.g. class maps).
    """
    def downsample():
        return nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

    widths = [64, 128, 256, 512]

    # Encoder: the first block has no pooling; every later block pools
    # first and then applies the double convolution.
    encoder_blocks = [unet_convs(in_channels, widths[0])]
    for prev, cur in zip(widths, widths[1:]):
        encoder_blocks.append(nn.Sequential(downsample(), unet_convs(prev, cur)))
    encoder_blocks.append(downsample())

    # Bridge: bottom of the "U"; no resolution change inside.
    bridge = nn.Sequential(unet_convs(widths[-1], 2 * widths[-1]))

    # Decoder: upsample, then (after skip concatenation) convolve.
    decoder_blocks = [nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)]
    for cur in (512, 256, 128):
        decoder_blocks.append(nn.Sequential(
            unet_convs(2 * cur, cur),
            nn.ConvTranspose2d(cur, cur // 2, kernel_size=2, stride=2),
        ))
    decoder_blocks.append(nn.Sequential(
        unet_convs(128, 64),
        nn.Conv2d(64, out_channels, kernel_size=1),
    ))

    return UNetFactory(encoder_blocks, decoder_blocks, bridge)
def unet_resnet(resnet_type, in_channels, out_channels, pretrained=False):
    """Build a U-Net that uses a torchvision ResNet as its encoder.

    The decoder mirrors the encoder's downsamplings so the output has the
    same spatial size as the input.

    Args:
        resnet_type: one of 'resnet18', 'resnet34', 'resnet50',
            'resnet101', 'resnet152', 'resnext50_32x4d'.
        in_channels: channel count of the identity skip (the raw input).
            NOTE(review): the ResNet stem itself appears to still expect
            3-channel input regardless of this value — confirm with callers.
        out_channels: number of output channels.
        pretrained: forwarded to the torchvision constructor.

    Raises:
        ValueError: for an unrecognized `resnet_type`.
    """
    # Constructor and per-stage output widths for every supported backbone.
    basic_widths = [64, 64, 128, 256, 512]          # BasicBlock variants
    bottleneck_widths = [64, 256, 512, 1024, 2048]  # Bottleneck variants
    backbones = {
        'resnet18': (torchvision.models.resnet.resnet18, basic_widths),
        'resnet34': (torchvision.models.resnet.resnet34, basic_widths),
        'resnet50': (torchvision.models.resnet.resnet50, bottleneck_widths),
        'resnet101': (torchvision.models.resnet.resnet101, bottleneck_widths),
        'resnet152': (torchvision.models.resnet.resnet152, bottleneck_widths),
        'resnext50_32x4d': (torchvision.models.resnet.resnext50_32x4d,
                            bottleneck_widths),
    }
    try:
        factory, stage_widths = backbones[resnet_type]
    except KeyError:
        raise ValueError("unexpected resnet_type") from None
    resnet = factory(pretrained)
    encoder_out_channels = [in_channels] + stage_widths

    encoder_blocks = [
        nn.Sequential(),  # identity: the raw input serves as the first skip
        nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu),
        nn.Sequential(resnet.maxpool, resnet.layer1),
        resnet.layer2,
        resnet.layer3,
        resnet.layer4,
    ]

    bridge = None

    # First decoder block: plain upsampling of the deepest feature map.
    in_ch = encoder_out_channels[-1]
    out_ch = in_ch // 2
    decoder_blocks = [nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)]
    # Middle blocks: concat with a skip, convolve, then upsample again.
    for i in range(1, len(encoder_blocks) - 1):
        in_ch = encoder_out_channels[-i - 1] + out_ch
        decoder_blocks.append(nn.Sequential(
            unet_convs(in_ch, out_ch, padding=1),
            nn.ConvTranspose2d(out_ch, out_ch // 2, kernel_size=2, stride=2),
        ))
        out_ch //= 2
    # Final block: concat with the raw-input skip, project to out_channels.
    decoder_blocks.append(nn.Sequential(
        unet_convs(encoder_out_channels[0] + out_ch, out_ch, padding=1),
        nn.Conv2d(out_ch, out_channels, kernel_size=1),
    ))

    return UNetFactory(encoder_blocks, decoder_blocks, bridge)
if __name__ == "__main__":
    # Smoke test: build the paper U-Net and print its layer structure.
    # (An older image-inference demo that lived here as a dead, module-level
    # triple-quoted string has been removed; it allocated a no-op constant
    # and only obscured the entry point — see version control for history.)
    net = unet(3, 8)
    print(net)