# U-Net是医学图像处理领域最常用的网络结构之一,很多医学图像处理网络都是在U-Net的基础上改进而来的。U-Net可以看作是对FCN和SegNet的改进:它采用了FCN的全卷积、反卷积上采样和跳跃连接(skip connection),以及SegNet的Encoder-Decoder结构。原文:U-Net: Convolutional Networks for Biomedical Image Segmentation(arXiv:1505.04597)
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import models
class unetDown(nn.Module):
    """One U-Net encoder stage: two unpadded 3x3 convolutions, each followed by ReLU.

    Because ``padding=0``, each convolution trims 1 pixel from every border,
    so the stage shrinks H and W by 4 in total.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super(unetDown, self).__init__()
        # NOTE(review): the pasted source truncated both ``nn.Sequential(`` calls
        # right after the Conv2d; they are completed here as conv + ReLU, following
        # the U-Net paper. Confirm against the original if it also used BatchNorm.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0),
            nn.ReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply both conv+ReLU layers; returns a map with ``out_channels`` channels."""
        x = self.conv1(x)
        x = self.conv2(x)
        return x
class unetUp(nn.Module):
    """One U-Net decoder stage: 2x transposed-conv upsample, fuse the skip, two 3x3 convs.

    Args:
        in_channels: channels of the incoming (coarser) decoder map. This must
            also equal ``out_channels`` + the skip map's channel count, because
            ``conv1`` consumes the concatenation of both.
        out_channels: channels produced by the upsampling and both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super(unetUp, self).__init__()
        self.upConv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=2)
        # NOTE(review): both ``nn.Sequential(`` calls were truncated in the pasted
        # source; completed as conv + ReLU, following the U-Net paper.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0),
            nn.ReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0),
            nn.ReLU(inplace=True),
        )

    def forward(self, x_cur, x_prev):
        """Upsample ``x_cur``, concatenate the resized skip ``x_prev`` on dim 1, then convolve."""
        x = self.upConv(x_cur)
        # The paper center-crops the skip map; this code resizes it bilinearly to
        # the upsampled size instead. interpolate(..., align_corners=True) is the
        # documented equivalent of the deprecated F.upsample_bilinear.
        skip = F.interpolate(x_prev, size=x.size()[2:], mode="bilinear", align_corners=True)
        x = torch.cat([skip, x], 1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x
def cross_entropy2d(input, target, weight=None, size_average=True):
    """Pixel-wise cross-entropy between 2D logits and an integer label map.

    Args:
        input: logits of shape (N, C, H, W).
        target: integer class labels of shape (N, Ht, Wt).
        weight: optional per-class weights, forwarded to ``F.cross_entropy``.
        size_average: if True, average over non-ignored pixels; otherwise sum.

    If the spatial sizes differ, the smaller tensor is upsampled: labels with
    nearest-neighbour interpolation, logits bilinearly. Pixels labelled 250
    are ignored.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if the sizes differ and neither tensor is strictly larger
            in both H and W.
    """
    _, c, h, w = input.size()
    _, ht, wt = target.size()

    # Handle inconsistent size between input and target.
    if h > ht and w > wt:  # upsample labels
        # Nearest-mode interpolate may not support integer dtypes on every
        # backend, so go through float. Label values (<= 250) are exact in float32.
        target = F.interpolate(target.unsqueeze(1).float(), size=(h, w), mode="nearest")
        target = target.squeeze(1).long()
    elif h < ht and w < wt:  # upsample logits
        input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=False)
    elif h != ht or w != wt:
        # Any other mismatch (including a single-axis one) cannot be handled.
        raise ValueError("Only support upsampling")

    # Flatten to (num_pixels, C) logits and (num_pixels,) labels, in matching order.
    input = input.permute(0, 2, 3, 1).reshape(-1, c)
    target = target.reshape(-1)
    reduction = "mean" if size_average else "sum"
    return F.cross_entropy(input, target, weight=weight, reduction=reduction, ignore_index=250)
class unet(nn.Module):
    """U-Net built from unpadded 3x3 conv stages, for 2D semantic segmentation.

    The encoder has four ``unetDown`` + max-pool stages plus a ``center``
    stage. The decoder has four ``unetUp`` stages, each fusing the matching
    encoder output. A final 1x1 convolution produces the class logits.
    Because every 3x3 convolution is unpadded, the decoder output is smaller
    than the input, so ``forward`` resizes the logits back to the input's
    spatial size.

    Args:
        n_classes: number of output channels of the 1x1 classifier.
        pretrained: accepted for interface compatibility; not used here.
    """

    def __init__(self, n_classes=21, pretrained=False):
        super(unet, self).__init__()
        # Encoder: channels 3 -> 64 -> 128 -> 256 -> 512 -> 1024; each pool halves H and W.
        self.down1 = unetDown(in_channels=3, out_channels=64)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down2 = unetDown(in_channels=64, out_channels=128)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down3 = unetDown(in_channels=128, out_channels=256)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down4 = unetDown(in_channels=256, out_channels=512)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.center = unetDown(in_channels=512, out_channels=1024)
        # Decoder: each stage halves the channel count and fuses one skip connection.
        self.up4 = unetUp(in_channels=1024, out_channels=512)
        self.up3 = unetUp(in_channels=512, out_channels=256)
        self.up2 = unetUp(in_channels=256, out_channels=128)
        self.up1 = unetUp(in_channels=128, out_channels=64)
        self.classifier = nn.Conv2d(in_channels=64, out_channels=n_classes, kernel_size=1)

    def forward(self, x):
        """Map a 3-channel image batch (N, 3, H, W) to logits (N, n_classes, H, W)."""
        out_size = x.size()[2:]
        down1_x = self.down1(x)
        down2_x = self.down2(self.maxpool1(down1_x))
        down3_x = self.down3(self.maxpool2(down2_x))
        down4_x = self.down4(self.maxpool3(down3_x))
        center_x = self.center(self.maxpool4(down4_x))

        up4_x = self.up4(center_x, down4_x)
        up3_x = self.up3(up4_x, down3_x)
        up2_x = self.up2(up3_x, down2_x)
        up1_x = self.up1(up2_x, down1_x)

        x = self.classifier(up1_x)
        # Resize the logits back to the input resolution. The unpadded convs
        # shrink the map; align_corners=True matches the former F.upsample_bilinear.
        return F.interpolate(x, size=out_size, mode="bilinear", align_corners=True)
if __name__ == '__main__':
    # Smoke test: one forward pass plus a loss on a dummy all-ones label map.
    n_classes = 21
    image_width = 480
    image_height = 360
    model = unet(n_classes=n_classes, pretrained=False)
    x = torch.randn(1, 3, image_height, image_width)
    # np.int was removed in NumPy 1.24; build the integer label map with torch directly.
    y = torch.ones((1, image_height, image_width), dtype=torch.long)

    # --------------------------- unet forward-pass time -----------------------
    start = time.perf_counter()
    pred = model(x)
    end = time.perf_counter()
    loss = cross_entropy2d(pred, y)
    print('unet forward time: {:.3f}s'.format(end - start))
    print('loss: {:.4f}'.format(loss.item()))
# V-Net是U-Net从2D到3D的扩展,其出发点是很多临床医学影像数据都是3D数据(3D volumes)。原文:V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation(arXiv:1606.04797)
import torch
import torch.nn as nn
import torch.nn.functional as F
def passthrough(x, **kwargs):
    """Identity stand-in for an optional layer: return ``x`` unchanged.

    Any keyword arguments are accepted and ignored.
    """
    del kwargs  # intentionally unused
    return x
def ELUCons(elu, nchan):
    """Build the activation module.

    Args:
        elu: flag selecting the activation. Truthy gives an in-place ``nn.ELU``;
            falsy gives ``nn.PReLU`` with ``nchan`` learnable slopes.
        nchan: channel count, used only by the PReLU branch.
    """
    return nn.ELU(inplace=True) if elu else nn.PReLU(nchan)
class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):
    """Batch norm for 5-D tensors that always normalizes with batch statistics.

    ``forward`` passes ``training=True`` to ``F.batch_norm`` unconditionally.
    As a result, the current batch's mean/variance are used, and the running
    statistics are updated, even after ``.eval()``.
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 5-D (e.g. Conv3d output)."""
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'.format(input.dim()))
        # The base-class _check_input_dim is abstract in current PyTorch
        # (_NormBase raises NotImplementedError), so it is deliberately not called.
        # F.batch_norm itself validates the channel count.

    def forward(self, input):
        self._check_input_dim(input)  # validate rank before normalizing
        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            True, self.momentum, self.eps)
class LUConv(nn.Module):
    """Conv3d (5x5x5, padding 2) -> ContBatchNorm3d -> activation.

    The channel count is ``nchan`` in and out. With stride 1, kernel 5 and
    padding 2, the spatial size is preserved.
    """

    def __init__(self, nchan, elu):
        super(LUConv, self).__init__()
        # Registration order (activation, conv, norm) is kept as-is.
        self.relu1 = ELUCons(elu, nchan)
        self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2)
        self.bn1 = ContBatchNorm3d(nchan)

    def forward(self, x):
        conv_out = self.conv1(x)
        normed = self.bn1(conv_out)
        return self.relu1(normed)
def _make_nConv(nchan, depth, elu):
    """Return an ``nn.Sequential`` of ``depth`` LUConv layers, each ``nchan`` channels wide."""
    return nn.Sequential(*[LUConv(nchan, elu) for _ in range(depth)])
class InputTransition(nn.Module):
    """First V-Net stage: lift a 1-channel volume to 16 channels with a residual add.

    Args:
        outChans: NOTE(review): currently unused; the stage is hard-wired to 16
            output channels. ``VNet`` passes 16.
        elu: activation flag forwarded to ``ELUCons``.
    """

    def __init__(self, outChans, elu):
        super(InputTransition, self).__init__()
        self.conv1 = nn.Conv3d(1, 16, kernel_size=5, padding=2)
        self.bn1 = ContBatchNorm3d(16)
        self.relu1 = ELUCons(elu, 16)

    def forward(self, x):
        """Map a 5-D single-channel input to 16 channels at the same spatial size."""
        out = self.bn1(self.conv1(x))
        # Residual: replicate the single input channel along the CHANNEL axis
        # (dim 1) to match ``out``. The original concatenated on dim 0 (batch),
        # which silently broadcast to a wrong shape for batch size 1 and failed
        # for larger batches.
        x16 = x.repeat(1, out.size(1), 1, 1, 1)
        out = self.relu1(torch.add(out, x16))
        return out
class DownTransition(nn.Module):
    """V-Net encoder stage: strided conv downsampling followed by a residual conv stack.

    Args:
        inChans: input channels. The stage outputs ``2 * inChans`` channels.
        nConvs: number of LUConv layers in the residual branch.
        elu: activation flag forwarded to ``ELUCons``.
        dropout: if True, apply ``nn.Dropout3d`` before the conv stack;
            otherwise the identity is used.
    """

    def __init__(self, inChans, nConvs, elu, dropout=False):
        super(DownTransition, self).__init__()
        outChans = inChans * 2
        # kernel 2 / stride 2 halves each spatial dimension.
        self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2)
        self.bn1 = ContBatchNorm3d(outChans)
        self.relu1 = ELUCons(elu, outChans)
        self.relu2 = ELUCons(elu, outChans)
        self.do1 = nn.Dropout3d() if dropout else passthrough
        self.ops = _make_nConv(outChans, nConvs, elu)

    def forward(self, x):
        down = self.relu1(self.bn1(self.down_conv(x)))
        residual = self.ops(self.do1(down))
        return self.relu2(torch.add(residual, down))
class UpTransition(nn.Module):
    """V-Net decoder stage: transposed-conv upsample, concat the skip, residual conv stack.

    Args:
        inChans: channels of the incoming decoder map.
        outChans: channels after concatenation. The upsampled map contributes
            ``outChans // 2``, so the skip map must supply the rest.
        nConvs: number of LUConv layers in the residual branch.
        elu: activation flag forwarded to ``ELUCons``.
        dropout: if True, apply ``nn.Dropout3d`` to the decoder input. The
            skip input always passes through ``nn.Dropout3d``.
    """

    def __init__(self, inChans, outChans, nConvs, elu, dropout=False):
        super(UpTransition, self).__init__()
        half = outChans // 2
        # kernel 2 / stride 2 doubles each spatial dimension.
        self.up_conv = nn.ConvTranspose3d(inChans, half, kernel_size=2, stride=2)
        self.bn1 = ContBatchNorm3d(half)
        self.do2 = nn.Dropout3d()
        self.relu1 = ELUCons(elu, half)
        self.relu2 = ELUCons(elu, outChans)
        self.do1 = nn.Dropout3d() if dropout else passthrough
        self.ops = _make_nConv(outChans, nConvs, elu)

    def forward(self, x, skipx):
        # Keep the dropout call order (decoder input first, then skip) so the
        # random masks are drawn in the same sequence.
        dropped = self.do1(x)
        skip = self.do2(skipx)
        up = self.relu1(self.bn1(self.up_conv(dropped)))
        merged = torch.cat((up, skip), 1)
        return self.relu2(torch.add(self.ops(merged), merged))
class OutputTransition(nn.Module):
    """Final V-Net stage: reduce to 2 channels and emit per-voxel class scores.

    Args:
        inChans: channels of the incoming feature map.
        elu: activation flag forwarded to ``ELUCons``.
        nll: if True, output log-probabilities (``F.log_softmax``), e.g. for
            use with an NLL loss; otherwise output probabilities (``F.softmax``).
    """

    def __init__(self, inChans, elu, nll):
        super(OutputTransition, self).__init__()
        self.conv1 = nn.Conv3d(inChans, 2, kernel_size=5, padding=2)
        self.bn1 = ContBatchNorm3d(2)
        self.conv2 = nn.Conv3d(2, 2, kernel_size=1)
        self.relu1 = ELUCons(elu, 2)
        # The original lacked ``else`` and always overwrote log_softmax with softmax.
        if nll:
            self.softmax = F.log_softmax
        else:
            self.softmax = F.softmax

    def forward(self, x):
        """Return a (num_voxels_in_batch, 2) tensor of class scores, one row per voxel."""
        # Convolve inChans down to 2 channels.
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.conv2(out)
        # Move channels to the last axis, then flatten every voxel to one row.
        out = out.permute(0, 2, 3, 4, 1).contiguous()
        out = out.view(out.numel() // 2, 2)
        # Normalize across the 2 class columns. dim=1 is explicit (the implicit dim is deprecated).
        out = self.softmax(out, dim=1)
        return out
class VNet(nn.Module):
    """V-Net for 3D volumes.

    The network runs an input stage, four downsampling stages, four upsampling
    stages fed by the matching encoder outputs, and a 2-class output stage.

    The per-stage convolution counts follow the reference prototxt rather than
    the paper's stated intent.

    Args:
        elu: activation flag (ELU if True, PReLU otherwise) used by every stage.
        nll: forwarded to ``OutputTransition``; selects log-softmax vs. softmax.
    """

    def __init__(self, elu=True, nll=False):
        super(VNet, self).__init__()
        self.in_tr = InputTransition(16, elu)
        # Encoder: 16 -> 32 -> 64 -> 128 -> 256 channels.
        self.down_tr32 = DownTransition(16, 1, elu)
        self.down_tr64 = DownTransition(32, 2, elu)
        self.down_tr128 = DownTransition(64, 3, elu, dropout=True)
        self.down_tr256 = DownTransition(128, 2, elu, dropout=True)
        # Decoder: each stage's output channel count includes its skip's channels.
        self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=True)
        self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=True)
        self.up_tr64 = UpTransition(128, 64, 1, elu)
        self.up_tr32 = UpTransition(64, 32, 1, elu)
        self.out_tr = OutputTransition(32, elu, nll)

    def forward(self, x):
        feat = self.in_tr(x)
        # Record each encoder input as the skip connection for the mirrored decoder stage.
        skips = []
        for down in (self.down_tr32, self.down_tr64, self.down_tr128, self.down_tr256):
            skips.append(feat)
            feat = down(feat)
        decoders = (self.up_tr256, self.up_tr128, self.up_tr64, self.up_tr32)
        for up, skip in zip(decoders, reversed(skips)):
            feat = up(feat, skip)
        return self.out_tr(feat)