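"""Paper: 基于改进 UNet 孪生网络的遥感影像矿区变化检测
(Change detection of mining areas in remote-sensing imagery based on an
improved Siamese UNet).

Paper URL: http://www.chinacaj.net/i,2,425089,0.html
"""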
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from torch.nn import functional as F
import torch
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Both convolutions use a 3x3 kernel with padding 1, so height and width
    are preserved. The first conv maps ``in_ch`` -> ``out_ch`` channels and
    the second keeps ``out_ch``. A BatchNorm layer follows every conv.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for c_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(c_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # A single Sequential named ``conv`` keeps the state_dict keys
        # (conv.0.*, conv.1.*, conv.3.*, conv.4.*) stable for checkpoints.
        self.conv = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` (an image batch with ``in_ch`` channels) through the stack."""
        return self.conv(input)
class Unet(nn.Module):
    """Two-scale Siamese UNet for bi-temporal change detection.

    Both images are encoded with the SAME weights (conv1..conv5), each at
    two scales:

    * "resize": the whole image max-pooled by 2 (wider context);
    * "crop":   the central (H // 2) x (W // 2) window at full resolution.

    At every encoder level the absolute differences |f(A2016) - f(A2019)|
    of the resize and crop streams are concatenated along channels. These
    difference maps are the decoder's skip connections. The deepest one
    (2 * 1024 = 2048 channels) is fused by an ASPP-style block before a
    UNet decoder. The ASPP-style block has a 1x1 conv, 3x3 convs with
    dilation 2 and 4, and a global-average-pooled branch.

    Args:
        in_ch: channels of each input image.
        out_ch: channels of the output map.

    Shape requirements: ``H // 2`` and ``W // 2`` must be divisible by 16
    (e.g. H and W multiples of 32). Otherwise the four 2x poolings and the
    transposed-conv upsamplings do not line up and ``torch.cat`` fails.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.pool = nn.MaxPool2d(2)  # builds the half-resolution "resize" view
        # Shared (Siamese) encoder, applied to all four streams.
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # Decoder. Transposed convs with kernel == stride == 2 double H and W
        # (plain upsampling would also work). Each DoubleConv input is the
        # upsampled map plus a skip whose channels are doubled, because the
        # resize and crop difference maps are concatenated:
        #   512 + 2*512 = 1536, 256 + 2*256 = 768, 128 + 2*128 = 384, 64 + 2*64 = 192
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1536, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(768, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(384, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(192, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)
        # ASPP-style branches on the 2048-channel bottleneck.
        # The dilated branches use 3x3 kernels with padding == dilation, so
        # they preserve the spatial size and can be concatenated. The former
        # 2x2 / 4x4 kernels turned a 7x7 map into 9x9 / 3x3. Those two layers
        # were also never called in forward(), so their old checkpoint
        # weights were untrained; load such checkpoints with strict=False.
        self.conv1_dilation = nn.Conv2d(2048, 256, 1, stride=1, padding=0, bias=False, dilation=1)
        self.conv2_dilation = nn.Conv2d(2048, 256, 3, stride=1, padding=2, bias=False, dilation=2)
        self.conv4_dilation = nn.Conv2d(2048, 256, 3, stride=1, padding=4, bias=False, dilation=4)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        # Kept only so existing attribute access keeps working (it has no
        # parameters). forward() now broadcasts the pooled vector to the
        # actual bottleneck size instead of assuming a 7x7 map.
        self.upsample = nn.Upsample(scale_factor=7, mode='bicubic', align_corners=True)
        # 3 * 256 (conv branches) + 2048 (pooled branch) = 2816 -> 1024.
        self.conv_c = nn.Conv2d(2816, 1024, 1, stride=1, padding=0, bias=False, dilation=1)
        # The decoder works at (H // 2, W // 2); this restores full resolution.
        self.upsample1 = nn.Upsample(scale_factor=2, mode='bicubic', align_corners=True)

    def _two_scale_views(self, x):
        """Return (x max-pooled by 2, central (H // 2) x (W // 2) crop of x)."""
        h, w = x.shape[-2:]
        crop_h, crop_w = h // 2, w // 2
        top, left = (h - crop_h) // 2, (w - crop_w) // 2  # 56, 56 for 224 x 224
        return self.pool(x), x[:, :, top:top + crop_h, left:left + crop_w]

    @staticmethod
    def _change_features(resize_a, resize_b, crop_a, crop_b):
        """Concatenate |a - b| of the resize and crop streams along channels."""
        return torch.cat([torch.abs(resize_a - resize_b), torch.abs(crop_a - crop_b)], dim=1)

    def _context(self, x):
        """ASPP-style fusion of the bottleneck change features (2048 -> 1024 ch)."""
        h, w = x.shape[-2:]
        # Global branch: pool to 1x1, then broadcast back to (h, w). For a
        # 7x7 map this equals the previous bicubic 7x upsampling of a 1x1
        # map, but it works for any bottleneck size.
        pooled = self.global_pool(x).expand(-1, -1, h, w)
        branches = [self.conv1_dilation(x), self.conv2_dilation(x), self.conv4_dilation(x), pooled]
        return self.conv_c(torch.cat(branches, dim=1))

    def forward(self, A2016, A2019):
        """Return the change map between two co-registered image batches.

        Args:
            A2016: first-date images, shape (N, in_ch, H, W).
            A2019: second-date images, same shape as ``A2016``.

        Returns:
            Sigmoid-activated tensor of shape
            (N, out_ch, 2 * (H // 2), 2 * (W // 2)), i.e. (N, out_ch, H, W)
            for even H and W.
        """
        resize_2016, crop_2016 = self._two_scale_views(A2016)
        resize_2019, crop_2019 = self._two_scale_views(A2019)
        # Streams are always processed in this order (resize-2016,
        # resize-2019, crop-2016, crop-2019). BatchNorm running statistics
        # depend on call order in training mode.
        streams = [resize_2016, resize_2019, crop_2016, crop_2019]

        encoders = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5)
        pools = (None, self.pool1, self.pool2, self.pool3, self.pool4)
        skips = []
        for pool, encoder in zip(pools, encoders):
            if pool is not None:
                streams = [pool(s) for s in streams]
            streams = [encoder(s) for s in streams]
            skips.append(self._change_features(*streams))

        x = self._context(skips.pop())  # deepest level -> 1024 channels

        ups = (self.up6, self.up7, self.up8, self.up9)
        decoders = (self.conv6, self.conv7, self.conv8, self.conv9)
        for up, decoder, skip in zip(ups, decoders, reversed(skips)):
            x = decoder(torch.cat([up(x), skip], dim=1))

        return torch.sigmoid(self.upsample1(self.conv10(x)))
if __name__ == "__main__":
    # Quick smoke test on random data: batch of 2, 3 channels, 224 x 224.
    batch_shape = (2, 3, 224, 224)
    image_2016 = torch.randn(*batch_shape)
    image_2019 = torch.randn(*batch_shape)
    net = Unet(3, 3)
    prediction = net(image_2016, image_2019)
    for item in (prediction, prediction.shape):
        print(item)