Paper title: FULLY CONVOLUTIONAL SIAMESE NETWORKS FOR CHANGE DETECTION
Published at: ICIP 2018
Paper: https://arxiv.org/abs/1810.08462v1
Code: https://github.com/rcdaudt/fully_convolutional_change_detection
Change detection (CD) is one of the main problems in Earth observation image analysis.
The goal of a change detection system is to assign a binary label to each pixel, given a pair or a sequence of coregistered images of a region taken at different times. A positive label indicates that the area corresponding to that pixel has changed between the two acquisitions. Change detection is therefore a well-defined classification problem.
Change detection suffers from a lack of large annotated datasets, which limits the complexity of the models that can be used. Still, a few pixel-level annotated change detection datasets are available and can be used to train supervised machine learning systems on image pairs.
Change detection methods first analysed pixels directly with hand-crafted techniques; later, descriptors were combined with machine learning; in recent years deep learning has come to dominate most image analysis problems.
Because of the limited amount of available data, most of these methods rely on some form of transfer learning, taking as a starting point networks trained on larger datasets for different problems. This is limiting in several respects, since it assumes a similarity between those datasets and the change detection data at hand.
Such transfer learning approaches also forgo end-to-end training, which further limits the performance of the system.
CNNs are a class of algorithms particularly well suited to image analysis and have already been used to compare image pairs in various settings.
FCNNs have in turn been proposed for dense prediction problems, i.e. pixel-wise prediction.
Inspiration: Bertinetto et al. proposed fully convolutional Siamese networks for object tracking in video.
The FCNN architectures proposed in this paper learn to perform change detection solely from change detection datasets, without any pre-training or transfer learning from other datasets, and can be trained end to end.
The fully convolutional architectures evolve from "Urban change detection for multispectral earth observation using convolutional neural networks" (paper | notes). The patch-based approach is dropped, which improves both speed and accuracy at prediction time, and the networks can process inputs of any size as long as enough memory is available.
That earlier paper compared two CNN architectures: Early Fusion (EF) and Siamese (Siam). The EF architecture concatenates the two patches before feeding them into the network, treating them as different colour channels. The Siamese architecture processes the two patches separately with two branches of identical structure and shared weights, then merges the two branches after the convolutional layers.
This paper extends these ideas and adopts the skip connection concept used to build U-Net for semantic segmentation. The point of the skip connections is to complement the more abstract, less localised information of the encoded features with the spatial detail available in the early layers of the network, yielding accurate class predictions with precise boundaries in the output image.
The first architecture is Fully Convolutional Early Fusion (FC-EF), based on the U-Net model.
FC-EF contains only 4 max-pooling and 4 upsampling layers, instead of the 5 used in U-Net.
As in the patch-based EF model, the input of the network is the two images of a pair concatenated along the channel dimension.
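A minimal sketch of this input fusion (tensor shapes are illustrative assumptions, not from the paper): the two co-registered images are stacked along the channel axis and fed to a single U-Net encoder.

```python
import torch

img1 = torch.randn(1, 3, 256, 256)   # RGB image at date 1, (B, C, H, W)
img2 = torch.randn(1, 3, 256, 256)   # RGB image at date 2
x = torch.cat((img1, img2), dim=1)   # -> (1, 6, 256, 256), input to the FC-EF encoder
```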
The other two architectures are Siamese and differ only in how the skip connections are formed. The encoder consists of two branches with identical structure and shared weights.
The second, more intuitive, approach is Fully Convolutional Siamese-Concatenation (FC-Siam-conc, Fig. 1(b)). During decoding, two skip connections, one from each encoding branch, bring in the same-scale features of the two inputs and concatenate them on either side of the decoder features. As the figure shows, at each decoding step three feature maps are concatenated (note where they are joined). Since the decoder has 4 levels, skip connections are applied at 4 scales: concatenate first, then convolve to reduce the channel count, then upsample for the next, larger-scale concatenation. Once the original resolution is reached, a final convolutional layer produces the two-class output map.
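A minimal sketch of one FC-Siam-conc decoding step (the names `d`, `f1`, `f2` and the shapes are assumptions for illustration): the upsampled decoder features are concatenated with the same-scale features from both encoder branches.

```python
import torch

d = torch.randn(1, 64, 32, 32)    # upsampled decoder features
f1 = torch.randn(1, 64, 32, 32)   # encoder features of image 1 at this scale
f2 = torch.randn(1, 64, 32, 32)   # encoder features of image 2 at this scale (shared weights)
skip = torch.cat((d, f1, f2), dim=1)   # -> (1, 192, 32, 32); convolutions then reduce the channels
```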
The third approach is Fully Convolutional Siamese-Difference (FC-Siam-diff, Fig. 1(c)). It aims explicitly at detecting the differences between the two images: at each scale, only the absolute difference of the two input feature maps is concatenated to the decoder features.
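A minimal sketch of the corresponding FC-Siam-diff decoding step (same assumed names as above): only the absolute difference of the two encoder feature maps is joined to the decoder features.

```python
import torch

d = torch.randn(1, 64, 32, 32)    # upsampled decoder features
f1 = torch.randn(1, 64, 32, 32)   # encoder features of image 1 at this scale
f2 = torch.randn(1, 64, 32, 32)   # encoder features of image 2 at this scale
skip = torch.cat((d, torch.abs(f1 - f2)), dim=1)   # -> (1, 128, 32, 32)
```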
Two public change detection datasets are used: OSCD and AC.
OSCD: multispectral satellite images; only the RGB bands are used for the experiments here.
AC: RGB aerial images.
Data augmentation: flips and rotations.
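A hedged sketch of such an augmentation (the paper does not spell out the exact implementation): the same random flip and 90-degree rotation are applied to both dates and to the label so that they stay aligned. Tensors are assumed to be image-shaped, with height and width as the last two dimensions.

```python
import random
import torch

def augment(img1, img2, label):
    if random.random() < 0.5:                                        # random horizontal flip
        img1, img2, label = [t.flip([-1]) for t in (img1, img2, label)]
    k = random.randint(0, 3)                                          # random rotation by k * 90 degrees
    img1, img2, label = [t.rot90(k, dims=[-2, -1]) for t in (img1, img2, label)]
    return img1, img2, label
```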
Dropout is used to prevent overfitting during training.
Performance metrics: precision, recall, F1 score, and global accuracy.
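A hedged sketch of how these metrics can be computed from a binary prediction map and a binary ground-truth change map (0/1 tensors; the function name is an assumption, not from the repository):

```python
import torch

def change_detection_metrics(pred, gt):
    pred, gt = pred.bool(), gt.bool()
    tp = (pred & gt).sum().item()           # changed pixels correctly detected
    fp = (pred & ~gt).sum().item()          # false alarms
    fn = (~pred & gt).sum().item()          # missed changes
    tn = (~pred & ~gt).sum().item()
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    global_accuracy = (tp + tn) / (tp + tn + fp + fn)
    return precision, recall, f1, global_accuracy
```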
The experiments show that the FC-Siam-diff architecture is the best suited for change detection, followed by FC-EF. Three main factors account for this:
- fully convolutional networks were developed precisely to handle dense prediction problems;
- the Siamese architecture adds an explicit comparison of the two images to the system;
- the skip connections also explicitly guide the network to compare the differences between the two images.
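The code below appears to come from the repository linked above. The first implementation is FresUNet, a residual U-Net variant in which the two input images are fused by channel concatenation at the very start of the forward pass (an early-fusion design built from residual blocks).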
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.padding import ReplicationPad2d
def conv3x3(in_planes, out_planes, stride=1):
"3x3 convolution with padding"
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
class BasicBlock_ss(nn.Module):
    """Residual encoder block: two 3x3 convolutions, with optional max-pooling subsampling applied to both the main path and the residual."""
def __init__(self, inplanes, planes = None, subsamp=1):
super(BasicBlock_ss, self).__init__()
        if planes is None:
planes = inplanes * subsamp
self.conv1 = conv3x3(inplanes, planes)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.subsamp = subsamp
self.doit = planes != inplanes
if self.doit:
self.couple = nn.Conv2d(inplanes, planes, kernel_size=1)
self.bnc = nn.BatchNorm2d(planes)
def forward(self, x):
if self.doit:
residual = self.couple(x)
residual = self.bnc(residual)
else:
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
if self.subsamp > 1:
out = F.max_pool2d(out, kernel_size=self.subsamp, stride=self.subsamp)
residual = F.max_pool2d(residual, kernel_size=self.subsamp, stride=self.subsamp)
out = self.conv2(out)
out = self.bn2(out)
out += residual
out = self.relu(out)
return out
class BasicBlock_us(nn.Module):
    """Residual decoder block: transposed convolutions upsample both the main path and the residual by the same factor."""
def __init__(self, inplanes, upsamp=1):
super(BasicBlock_us, self).__init__()
planes = int(inplanes / upsamp) # assumes integer result, fix later
self.conv1 = nn.ConvTranspose2d(inplanes, planes, kernel_size=3, padding=1, stride=upsamp, output_padding=1)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.upsamp = upsamp
self.couple = nn.ConvTranspose2d(inplanes, planes, kernel_size=3, padding=1, stride=upsamp, output_padding=1)
self.bnc = nn.BatchNorm2d(planes)
def forward(self, x):
residual = self.couple(x)
residual = self.bnc(residual)
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out += residual
out = self.relu(out)
return out
class FresUNet(nn.Module):
"""FresUNet segmentation network."""
def __init__(self, input_nbr, label_nbr):
"""Init FresUNet fields."""
super(FresUNet, self).__init__()
self.input_nbr = input_nbr
cur_depth = input_nbr
base_depth = 8
# Encoding stage 1
self.encres1_1 = BasicBlock_ss(cur_depth, planes = base_depth)
cur_depth = base_depth
d1 = base_depth
self.encres1_2 = BasicBlock_ss(cur_depth, subsamp=2)
cur_depth *= 2
# Encoding stage 2
self.encres2_1 = BasicBlock_ss(cur_depth)
d2 = cur_depth
self.encres2_2 = BasicBlock_ss(cur_depth, subsamp=2)
cur_depth *= 2
# Encoding stage 3
self.encres3_1 = BasicBlock_ss(cur_depth)
d3 = cur_depth
self.encres3_2 = BasicBlock_ss(cur_depth, subsamp=2)
cur_depth *= 2
# Encoding stage 4
self.encres4_1 = BasicBlock_ss(cur_depth)
d4 = cur_depth
self.encres4_2 = BasicBlock_ss(cur_depth, subsamp=2)
cur_depth *= 2
# Decoding stage 4
self.decres4_1 = BasicBlock_ss(cur_depth)
self.decres4_2 = BasicBlock_us(cur_depth, upsamp=2)
cur_depth = int(cur_depth/2)
# Decoding stage 3
self.decres3_1 = BasicBlock_ss(cur_depth + d4, planes = cur_depth)
self.decres3_2 = BasicBlock_us(cur_depth, upsamp=2)
cur_depth = int(cur_depth/2)
# Decoding stage 2
self.decres2_1 = BasicBlock_ss(cur_depth + d3, planes = cur_depth)
self.decres2_2 = BasicBlock_us(cur_depth, upsamp=2)
cur_depth = int(cur_depth/2)
# Decoding stage 1
self.decres1_1 = BasicBlock_ss(cur_depth + d2, planes = cur_depth)
self.decres1_2 = BasicBlock_us(cur_depth, upsamp=2)
cur_depth = int(cur_depth/2)
# Output
self.coupling = nn.Conv2d(cur_depth + d1, label_nbr, kernel_size=1)
self.sm = nn.LogSoftmax(dim=1)
def forward(self, x1, x2):
x = torch.cat((x1, x2), 1)
# pad5 = ReplicationPad2d((0, x53.size(3) - x5d.size(3), 0, x53.size(2) - x5d.size(2)))
s1_1 = x.size()
x1 = self.encres1_1(x)
x = self.encres1_2(x1)
s2_1 = x.size()
x2 = self.encres2_1(x)
x = self.encres2_2(x2)
s3_1 = x.size()
x3 = self.encres3_1(x)
x = self.encres3_2(x3)
s4_1 = x.size()
x4 = self.encres4_1(x)
x = self.encres4_2(x4)
x = self.decres4_1(x)
x = self.decres4_2(x)
s4_2 = x.size()
pad4 = ReplicationPad2d((0, s4_1[3] - s4_2[3], 0, s4_1[2] - s4_2[2]))
x = pad4(x)
# x = self.decres3_1(x)
x = self.decres3_1(torch.cat((x, x4), 1))
x = self.decres3_2(x)
s3_2 = x.size()
pad3 = ReplicationPad2d((0, s3_1[3] - s3_2[3], 0, s3_1[2] - s3_2[2]))
x = pad3(x)
x = self.decres2_1(torch.cat((x, x3), 1))
x = self.decres2_2(x)
s2_2 = x.size()
pad2 = ReplicationPad2d((0, s2_1[3] - s2_2[3], 0, s2_1[2] - s2_2[2]))
x = pad2(x)
x = self.decres1_1(torch.cat((x, x2), 1))
x = self.decres1_2(x)
s1_2 = x.size()
pad1 = ReplicationPad2d((0, s1_1[3] - s1_2[3], 0, s1_1[2] - s1_2[2]))
x = pad1(x)
x = self.coupling(torch.cat((x, x1), 1))
x = self.sm(x)
return x
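The second implementation is SiamUnet_conc, i.e. the FC-Siam-conc architecture described above: two shared-weight encoders whose same-scale feature maps are both concatenated to the decoder features at each skip connection.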
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.padding import ReplicationPad2d
class SiamUnet_conc(nn.Module):
"""SiamUnet_conc segmentation network."""
def __init__(self, input_nbr, label_nbr):
super(SiamUnet_conc, self).__init__()
self.input_nbr = input_nbr
self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1)
self.bn11 = nn.BatchNorm2d(16)
self.do11 = nn.Dropout2d(p=0.2)
self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
self.bn12 = nn.BatchNorm2d(16)
self.do12 = nn.Dropout2d(p=0.2)
self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.bn21 = nn.BatchNorm2d(32)
self.do21 = nn.Dropout2d(p=0.2)
self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
self.bn22 = nn.BatchNorm2d(32)
self.do22 = nn.Dropout2d(p=0.2)
self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.bn31 = nn.BatchNorm2d(64)
self.do31 = nn.Dropout2d(p=0.2)
self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
self.bn32 = nn.BatchNorm2d(64)
self.do32 = nn.Dropout2d(p=0.2)
self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
self.bn33 = nn.BatchNorm2d(64)
self.do33 = nn.Dropout2d(p=0.2)
self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
self.bn41 = nn.BatchNorm2d(128)
self.do41 = nn.Dropout2d(p=0.2)
self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
self.bn42 = nn.BatchNorm2d(128)
self.do42 = nn.Dropout2d(p=0.2)
self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
self.bn43 = nn.BatchNorm2d(128)
self.do43 = nn.Dropout2d(p=0.2)
self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1)
self.conv43d = nn.ConvTranspose2d(384, 128, kernel_size=3, padding=1)
self.bn43d = nn.BatchNorm2d(128)
self.do43d = nn.Dropout2d(p=0.2)
self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1)
self.bn42d = nn.BatchNorm2d(128)
self.do42d = nn.Dropout2d(p=0.2)
self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
self.bn41d = nn.BatchNorm2d(64)
self.do41d = nn.Dropout2d(p=0.2)
self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1)
self.conv33d = nn.ConvTranspose2d(192, 64, kernel_size=3, padding=1)
self.bn33d = nn.BatchNorm2d(64)
self.do33d = nn.Dropout2d(p=0.2)
self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1)
self.bn32d = nn.BatchNorm2d(64)
self.do32d = nn.Dropout2d(p=0.2)
self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
self.bn31d = nn.BatchNorm2d(32)
self.do31d = nn.Dropout2d(p=0.2)
self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1)
self.conv22d = nn.ConvTranspose2d(96, 32, kernel_size=3, padding=1)
self.bn22d = nn.BatchNorm2d(32)
self.do22d = nn.Dropout2d(p=0.2)
self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
self.bn21d = nn.BatchNorm2d(16)
self.do21d = nn.Dropout2d(p=0.2)
self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1)
self.conv12d = nn.ConvTranspose2d(48, 16, kernel_size=3, padding=1)
self.bn12d = nn.BatchNorm2d(16)
self.do12d = nn.Dropout2d(p=0.2)
self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1)
self.sm = nn.LogSoftmax(dim=1)
def forward(self, x1, x2):
"""Forward method."""
# Stage 1
x11 = self.do11(F.relu(self.bn11(self.conv11(x1))))
x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11))))
x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2)
# Stage 2
x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21))))
x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2)
# Stage 3
x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32))))
x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2)
# Stage 4
x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42))))
x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2)
####################################################
# Stage 1
x11 = self.do11(F.relu(self.bn11(self.conv11(x2))))
x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11))))
x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2)
# Stage 2
x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21))))
x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2)
# Stage 3
x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32))))
x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2)
# Stage 4
x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42))))
x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2)
####################################################
# Stage 4d
x4d = self.upconv4(x4p)
pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2)))
x4d = torch.cat((pad4(x4d), x43_1, x43_2), 1)
x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))
# Stage 3d
x3d = self.upconv3(x41d)
pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2)))
x3d = torch.cat((pad3(x3d), x33_1, x33_2), 1)
x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))
# Stage 2d
x2d = self.upconv2(x31d)
pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2)))
x2d = torch.cat((pad2(x2d), x22_1, x22_2), 1)
x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))
# Stage 1d
x1d = self.upconv1(x21d)
pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2)))
x1d = torch.cat((pad1(x1d), x12_1, x12_2), 1)
x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
x11d = self.conv11d(x12d)
return self.sm(x11d)
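The third implementation is SiamUnet_diff, i.e. the FC-Siam-diff architecture: at each skip connection the decoder receives only the absolute difference of the two encoders' feature maps.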
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.padding import ReplicationPad2d
class SiamUnet_diff(nn.Module):
"""SiamUnet_diff segmentation network."""
def __init__(self, input_nbr, label_nbr):
super(SiamUnet_diff, self).__init__()
self.input_nbr = input_nbr
self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1)
self.bn11 = nn.BatchNorm2d(16)
self.do11 = nn.Dropout2d(p=0.2)
self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
self.bn12 = nn.BatchNorm2d(16)
self.do12 = nn.Dropout2d(p=0.2)
self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.bn21 = nn.BatchNorm2d(32)
self.do21 = nn.Dropout2d(p=0.2)
self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
self.bn22 = nn.BatchNorm2d(32)
self.do22 = nn.Dropout2d(p=0.2)
self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.bn31 = nn.BatchNorm2d(64)
self.do31 = nn.Dropout2d(p=0.2)
self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
self.bn32 = nn.BatchNorm2d(64)
self.do32 = nn.Dropout2d(p=0.2)
self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
self.bn33 = nn.BatchNorm2d(64)
self.do33 = nn.Dropout2d(p=0.2)
self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
self.bn41 = nn.BatchNorm2d(128)
self.do41 = nn.Dropout2d(p=0.2)
self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
self.bn42 = nn.BatchNorm2d(128)
self.do42 = nn.Dropout2d(p=0.2)
self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
self.bn43 = nn.BatchNorm2d(128)
self.do43 = nn.Dropout2d(p=0.2)
self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1)
self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1)
self.bn43d = nn.BatchNorm2d(128)
self.do43d = nn.Dropout2d(p=0.2)
self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1)
self.bn42d = nn.BatchNorm2d(128)
self.do42d = nn.Dropout2d(p=0.2)
self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
self.bn41d = nn.BatchNorm2d(64)
self.do41d = nn.Dropout2d(p=0.2)
self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1)
self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
self.bn33d = nn.BatchNorm2d(64)
self.do33d = nn.Dropout2d(p=0.2)
self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1)
self.bn32d = nn.BatchNorm2d(64)
self.do32d = nn.Dropout2d(p=0.2)
self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
self.bn31d = nn.BatchNorm2d(32)
self.do31d = nn.Dropout2d(p=0.2)
self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1)
self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
self.bn22d = nn.BatchNorm2d(32)
self.do22d = nn.Dropout2d(p=0.2)
self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
self.bn21d = nn.BatchNorm2d(16)
self.do21d = nn.Dropout2d(p=0.2)
self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1)
self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
self.bn12d = nn.BatchNorm2d(16)
self.do12d = nn.Dropout2d(p=0.2)
self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1)
self.sm = nn.LogSoftmax(dim=1)
def forward(self, x1, x2):
"""Forward method."""
# Stage 1
x11 = self.do11(F.relu(self.bn11(self.conv11(x1))))
x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11))))
x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2)
# Stage 2
x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21))))
x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2)
# Stage 3
x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32))))
x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2)
# Stage 4
x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42))))
x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2)
####################################################
# Stage 1
x11 = self.do11(F.relu(self.bn11(self.conv11(x2))))
x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11))))
x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2)
# Stage 2
x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21))))
x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2)
# Stage 3
x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32))))
x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2)
# Stage 4
x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42))))
x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2)
# Stage 4d
x4d = self.upconv4(x4p)
pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2)))
x4d = torch.cat((pad4(x4d), torch.abs(x43_1 - x43_2)), 1)
x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))
# Stage 3d
x3d = self.upconv3(x41d)
pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2)))
x3d = torch.cat((pad3(x3d), torch.abs(x33_1 - x33_2)), 1)
x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))
# Stage 2d
x2d = self.upconv2(x31d)
pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2)))
x2d = torch.cat((pad2(x2d), torch.abs(x22_1 - x22_2)), 1)
x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))
# Stage 1d
x1d = self.upconv1(x21d)
pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2)))
x1d = torch.cat((pad1(x1d), torch.abs(x12_1 - x12_2)), 1)
x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
x11d = self.conv11d(x12d)
return self.sm(x11d)
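A hedged usage sketch (not part of the repository): instantiate the three networks defined above and run a forward pass on a dummy pair of RGB images. The batch/image sizes and variable names are illustrative assumptions.

```python
import torch

net_conc = SiamUnet_conc(input_nbr=3, label_nbr=2)   # FC-Siam-conc: 3-band inputs, change / no-change
net_diff = SiamUnet_diff(input_nbr=3, label_nbr=2)   # FC-Siam-diff
net_fres = FresUNet(input_nbr=6, label_nbr=2)        # FresUNet fuses the pair internally, so 2 x 3 = 6 channels

t1 = torch.randn(1, 3, 128, 128)                     # image at date 1, (B, C, H, W)
t2 = torch.randn(1, 3, 128, 128)                     # image at date 2

out = net_diff(t1, t2)                               # (1, 2, 128, 128) log-probabilities (LogSoftmax output)
change_map = out.argmax(dim=1)                       # per-pixel change / no-change prediction
```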