import torch
import torch.nn as nn
class pub(nn.Module):
    """Two stacked 3x3 convolutions, each followed by an in-place ReLU.

    Args:
        in_channel: number of channels of the input tensor.
        out_channel: number of channels produced by both convolutions.
        batch_norm: when true, a ``BatchNorm2d`` is placed between every
            convolution and its ReLU.
        keep_size: when true the convolutions use ``padding=1`` so the
            spatial size is preserved; otherwise no padding is applied and
            each convolution trims one pixel from every border.
    """

    def __init__(self, in_channel, out_channel, batch_norm=True, keep_size=False):
        super(pub, self).__init__()
        padding = int(bool(keep_size))

        stages = []
        # The first convolution maps in_channel -> out_channel, the second
        # keeps out_channel; both share the same conv / (BN) / ReLU layout.
        for source_channels in (in_channel, out_channel):
            stages.append(
                nn.Conv2d(source_channels, out_channel, 3, padding=padding)
            )
            if batch_norm:
                stages.append(nn.BatchNorm2d(out_channel))
            stages.append(nn.ReLU(True))

        self.pub_con = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the convolution stack to ``x`` and return the result."""
        out = self.pub_con(x)
        return out
class unet_down(nn.Module):
    """Encoder stage: 2x2 max pooling followed by a ``pub`` double-conv block.

    Args:
        in_channel: channels of the incoming feature map.
        out_channel: channels produced by the double-conv block.
        batch_norm: forwarded to ``pub``.
        keep_size: forwarded to ``pub``.
    """

    def __init__(self, in_channel, out_channel, batch_norm=True, keep_size=False):
        super(unet_down, self).__init__()
        self.pub = pub(
            in_channel,
            out_channel,
            batch_norm=batch_norm,
            keep_size=keep_size,
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Halve the spatial resolution of ``x``, then convolve it."""
        return self.pub(self.pool(x))
class unet_up(nn.Module):
    """Decoder stage: upsample the deeper map, fuse it with a skip map, convolve.

    Args:
        in_channel: channels entering the ``pub`` double-conv block, i.e. the
            channel count of the concatenated (skip + upsampled) tensor.
        out_channel: channels produced by this stage. The upsampling path is
            built for an input with ``out_channel * 2`` channels.
        batch_norm: forwarded to ``pub``.
        upsample: when true, use a 1x1 convolution followed by x2
            nearest-neighbour upsampling; otherwise use a stride-2
            transposed convolution.
        keep_size: forwarded to ``pub``; also stored as ``orignal_size``.
    """

    def __init__(self, in_channel, out_channel, batch_norm=True, upsample=True, keep_size=False):
        super(unet_up, self).__init__()
        if upsample:
            layers = [
                nn.Conv2d(out_channel * 2, out_channel, 1),
                nn.Upsample(scale_factor=2, mode='nearest'),
            ]
        else:
            layers = [nn.ConvTranspose2d(out_channel * 2, out_channel, 2, stride=2)]
        self.upsample = nn.Sequential(*layers)
        self.pub = pub(in_channel, out_channel, batch_norm, keep_size)
        # Attribute name (including its spelling) is kept so existing code
        # that reads it keeps working.
        self.orignal_size = keep_size

    def forward(self, x1, x2):
        """Fuse skip map ``x1`` with the upsampled ``x2`` and convolve.

        Args:
            x1: skip-connection feature map from the encoder; it is
                centre-cropped to the spatial size of the upsampled ``x2``.
            x2: feature map coming from the deeper stage.

        Returns:
            Output of the ``pub`` block applied to ``cat((x1, x2), dim=1)``.
        """
        x2 = self.upsample(x2)

        # Centre-crop x1 to x2's height/width. Explicit end indices are used
        # instead of ``c:-c`` because ``0:-0`` is an empty slice, which made
        # the concatenation fail whenever both maps already had the same size
        # (e.g. with keep_size=True). Height and width are handled separately
        # and odd size differences still yield an exact match.
        target_h, target_w = x2.size(2), x2.size(3)
        top = (x1.size(2) - target_h) // 2
        left = (x1.size(3) - target_w) // 2
        x1 = x1[:, :, top:top + target_h, left:left + target_w]

        x = torch.cat((x1, x2), 1)
        x = self.pub(x)
        return x
class Unet(nn.Module):
    """Encoder/decoder network assembled from ``pub``, ``unet_down`` and ``unet_up``.

    The encoder starts at 64 channels and doubles the width at every further
    level; the decoder mirrors it and ends with a 1x1 convolution to
    ``class_nums`` channels followed by a sigmoid.

    Args:
        channels: number of channels of the input tensor.
        class_nums: number of output channels.
        layers: number of encoder levels (the first ``pub`` block included).
        upsample: forwarded to every ``unet_up`` stage.
        batch_norm: forwarded to every ``pub`` / ``unet_down`` / ``unet_up``.
        keep_size: forwarded to every ``pub`` / ``unet_down`` / ``unet_up``.
    """

    def __init__(self, channels, class_nums, layers=5, upsample=True, batch_norm=True, keep_size=False):
        super(Unet, self).__init__()
        self.layers = layers
        down = [pub(channels, 64, batch_norm, keep_size)]
        up = []
        for layer in range(layers - 1):
            down.append(unet_down(64 * (2 ** layer), 128 * (2 ** layer), batch_norm, keep_size))
            # Decoder stages are created deepest-first, so their width is
            # derived from the mirrored encoder level rather than a constant
            # that is only valid for layers == 5.
            depth = layers - 2 - layer
            # Keyword arguments: unet_up takes ``batch_norm`` before
            # ``upsample``, so passing them positionally in the other order
            # silently swapped the two flags.
            up.append(
                unet_up(
                    128 * (2 ** depth),
                    64 * (2 ** depth),
                    batch_norm=batch_norm,
                    upsample=upsample,
                    keep_size=keep_size,
                )
            )
        up.append(nn.Conv2d(64, class_nums, 1))
        self.down = nn.ModuleList(down)
        self.up = nn.ModuleList(up)
        self._initialize_weights()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections.

        Returns:
            Sigmoid of the final 1x1 convolution's output.
        """
        skips = []
        for stage in self.down:
            x = stage(x)
            skips.append(x)

        # ``x`` is now the deepest feature map. Each decoder stage consumes
        # the skip map of the next shallower encoder level; ``zip`` stops
        # before the final 1x1 convolution because there is one skip fewer
        # than there are entries in ``self.up``.
        for stage, skip in zip(self.up, reversed(skips[:-1])):
            x = stage(skip, x)

        # The output head is always the last entry, whatever ``layers`` is.
        x = self.up[-1](x)
        return self.sigmoid(x)

    def _initialize_weights(self):
        """Kaiming-uniform init for Conv2d weights; BatchNorm2d set to identity."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # In-place variant; the non-underscore name is deprecated.
                nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)