# Note how [N, C_out, H_out, W_out] is computed after each convolution.
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

"""Building blocks used by the rest of the network"""
def convbn(in_planes, out_planes, kernel_size, stride, pad, dilation):
    # Note the padding: setting padding=dilation when dilation > 1 keeps dilation from changing H_out and W_out (for the 3x3 convolutions used here)
return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=dilation if dilation > 1 else pad, dilation = dilation, bias=False),
nn.BatchNorm2d(out_planes))
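# A quick sanity check of the padding rule above (my own illustration, not part
# of the original code): with kernel_size=3, stride=1, and padding=dilation, the
# size formula H_out = floor((H_in + 2*pad - dilation*(kernel-1) - 1)/stride + 1)
# reduces to H_out = H_in for every dilation.
def _check_convbn_padding():
    x = torch.randn(1, 3, 32, 32)
    for d in (1, 2, 4):
        y = convbn(3, 8, 3, 1, 1, d)(x)
        assert y.shape[-2:] == x.shape[-2:]  # H and W are unchanged by dilation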
def convbn_3d(in_planes, out_planes, kernel_size, stride, pad):
return nn.Sequential(nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, padding=pad, stride=stride,bias=False),
nn.BatchNorm3d(out_planes))
"""方括号里的网络部分"""
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride, downsample, pad, dilation):
super(BasicBlock, self).__init__()
        self.conv1 = nn.Sequential(convbn(inplanes, planes, 3, stride, pad, dilation),  # kernel_size fixed at 3
nn.ReLU(inplace=True))
        self.conv2 = convbn(planes, planes, 3, 1, pad, dilation)  # kernel_size=3, stride=1 (note: no ReLU here or after the addition)
self.downsample = downsample
self.stride = stride
    def forward(self, x):  # the block's forward pass
out = self.conv1(x)
out = self.conv2(out)
if self.downsample is not None:
x = self.downsample(x)
out += x
return out
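# Usage sketch (illustrative, not from the repo): a stride-2 block needs a
# matching downsample on the identity path so the residual addition lines up.
def _basicblock_demo():
    down = nn.Sequential(nn.Conv2d(32, 64, kernel_size=1, stride=2, bias=False),
                         nn.BatchNorm2d(64))
    blk = BasicBlock(32, 64, stride=2, downsample=down, pad=1, dilation=1)
    return blk(torch.randn(2, 32, 64, 64))  # -> (2, 64, 32, 32)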
"""视差回归"""
class disparityregression(nn.Module):
def __init__(self, maxdisp):
super(disparityregression, self).__init__()
        # disparity indices 0..maxdisp-1, shaped (1, maxdisp, 1, 1) for broadcasting
        self.disp = torch.Tensor(np.reshape(np.array(range(maxdisp)),[1, maxdisp,1,1])).cuda()
def forward(self, x):
        out = torch.sum(x*self.disp.data,1, keepdim=True)  # expected disparity per pixel
return out
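# The forward pass implements the paper's soft argmin over disparities:
#   d_hat = sum_d d * softmax(-c_d)
# A CPU illustration of the same arithmetic (my own sketch; the class itself
# assumes CUDA tensors): a volume strongly favoring disparity 3 regresses to ~3.
def _soft_argmin_demo():
    cost = torch.zeros(1, 8, 2, 2)
    cost[:, 3] = 10.0                               # peak at disparity 3
    prob = F.softmax(cost, dim=1)
    disp = torch.arange(8, dtype=torch.float32).view(1, 8, 1, 1)
    return torch.sum(prob * disp, 1, keepdim=True)  # values close to 3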
"""特征提取"""
class feature_extraction(nn.Module):
def __init__(self):
super(feature_extraction, self).__init__()
self.inplanes = 32
"""
conv0_1-3
nn.ReLU(inplace=True):True:将会改变输入的数据,否则不会改变原输入(输入会随输出变化)
"""
self.firstconv = nn.Sequential(convbn(3, 32, 3, 2, 1, 1),#(32,H/2,W/2)
nn.ReLU(inplace=True),
convbn(32, 32, 3, 1, 1, 1),#(32,H/2,W/2)
nn.ReLU(inplace=True),
convbn(32, 32, 3, 1, 1, 1),#(32,H/2,W/2)
nn.ReLU(inplace=True))
"""
conv1_x-conv4_x
_make_layer(block,planes,blocks_num,stride,pad,dilation)
"""
self.layer1 = self._make_layer(BasicBlock, 32, 3, 1,1,1)
self.layer2 = self._make_layer(BasicBlock, 64, 16, 2,1,1)
self.layer3 = self._make_layer(BasicBlock, 128, 3, 1,1,1)
self.layer4 = self._make_layer(BasicBlock, 128, 3, 1,1,2)
"""
SPP module:四个不同池化大小分支,输出大小一致,pad=0
"""
self.branch1 = nn.Sequential(nn.AvgPool2d((64, 64), stride=(64,64)),
convbn(128, 32, 1, 1, 0, 1),
nn.ReLU(inplace=True))
self.branch2 = nn.Sequential(nn.AvgPool2d((32, 32), stride=(32,32)),
convbn(128, 32, 1, 1, 0, 1),
nn.ReLU(inplace=True))
self.branch3 = nn.Sequential(nn.AvgPool2d((16, 16), stride=(16,16)),
convbn(128, 32, 1, 1, 0, 1),
nn.ReLU(inplace=True))
self.branch4 = nn.Sequential(nn.AvgPool2d((8, 8), stride=(8,8)),
convbn(128, 32, 1, 1, 0, 1),
nn.ReLU(inplace=True))
"""
fusion
"""
self.lastconv = nn.Sequential(convbn(320, 128, 3, 1, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 32, kernel_size=1, padding=0, stride = 1, bias=False))
"""
为conv1_x到conv4_x做准备
"""
def _make_layer(self, block, planes, blocks, stride, pad, dilation):
downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:  # self.inplanes = 32 (still 32 after conv1_x)
            downsample = nn.Sequential(  # e.g. conv2_x: stride=2 and 32 != 64*1, so downsample the identity path
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),#nn.Conv2d(in_channels=32,out_channels=64,kernel_size=1,stride=2)
nn.BatchNorm2d(planes * block.expansion),)
layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, pad, dilation))  # only the first block may downsample
        self.inplanes = planes * block.expansion  # inplanes must match the previous block's output channels
for i in range(1, blocks):
            layers.append(block(self.inplanes, planes,1,None,pad,dilation))  # stride=1, no downsample needed
return nn.Sequential(*layers)
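    # For example, self.layer2 = self._make_layer(BasicBlock, 64, 16, 2, 1, 1)
    # expands (sketch) to:
    #   BasicBlock(32, 64, stride=2, downsample=1x1 conv + BN)  # halves H and W
    #   followed by 15 x BasicBlock(64, 64, stride=1, downsample=None)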
def forward(self, x):
        # x.size() = (N, 3, H, W)  [batch, channels, height, width]
        output = self.firstconv(x)         # (N, 32, H/2, W/2)
        output = self.layer1(output)       # (N, 32, H/2, W/2)
        output_raw = self.layer2(output)   # (N, 64, H/4, W/4)
        output = self.layer3(output_raw)   # (N, 128, H/4, W/4)
        output_skip = self.layer4(output)  # (N, 128, H/4, W/4)
output_branch1 = self.branch1(output_skip)
        # F.upsample is deprecated in current PyTorch; F.interpolate is the drop-in replacement
        output_branch1 = F.upsample(output_branch1, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear')
output_branch2 = self.branch2(output_skip)
output_branch2 = F.upsample(output_branch2, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear')
output_branch3 = self.branch3(output_skip)
output_branch3 = F.upsample(output_branch3, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear')
output_branch4 = self.branch4(output_skip)
output_branch4 = F.upsample(output_branch4, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear')
        output_feature = torch.cat((output_raw, output_skip, output_branch4, output_branch3, output_branch2, output_branch1), 1)  # concatenate along dim 1 (channels): 64 + 128 + 4*32 = 320
output_feature = self.lastconv(output_feature)
return output_feature
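# Shape sanity check (my own sketch; assumes H/4 and W/4 are at least 64 so the
# 64x64 pooling branch is non-degenerate, e.g. a 256x512 input):
def _feature_extraction_demo():
    feat = feature_extraction()(torch.randn(1, 3, 256, 512))
    return feat.shape  # (1, 32, 64, 128); 64 + 128 + 4*32 = 320 channels fused to 32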
class hourglass(nn.Module):
def __init__(self, inplanes):
super(hourglass, self).__init__()
        # conv1-conv6 are the successive layers of one hourglass, left to right
self.conv1 = nn.Sequential(convbn_3d(inplanes, inplanes*2, kernel_size=3, stride=2, pad=1),
nn.ReLU(inplace=True))
self.conv2 = convbn_3d(inplanes*2, inplanes*2, kernel_size=3, stride=1, pad=1)
self.conv3 = nn.Sequential(convbn_3d(inplanes*2, inplanes*2, kernel_size=3, stride=2, pad=1),
nn.ReLU(inplace=True))
self.conv4 = nn.Sequential(convbn_3d(inplanes*2, inplanes*2, kernel_size=3, stride=1, pad=1),
nn.ReLU(inplace=True))
self.conv5 = nn.Sequential(nn.ConvTranspose3d(inplanes*2, inplanes*2, kernel_size=3, padding=1, output_padding=1, stride=2,bias=False),
nn.BatchNorm3d(inplanes*2)) #+conv2
self.conv6 = nn.Sequential(nn.ConvTranspose3d(inplanes*2, inplanes, kernel_size=3, padding=1, output_padding=1, stride=2,bias=False),
nn.BatchNorm3d(inplanes)) #+x
    def forward(self, x, presqu, postsqu):
out = self.conv1(x) #in:1/4 out:1/8
        pre = self.conv2(out)  # in:1/8 out:1/8, no ReLU applied yet
        # postsqu: the previous hourglass's conv5 output, fed into this hourglass's conv2
if postsqu is not None:
pre = F.relu(pre + postsqu, inplace=True)
else:
pre = F.relu(pre, inplace=True)
out = self.conv3(pre) #in:1/8 out:1/16
out = self.conv4(out) #in:1/16 out:1/16
        # presqu: conv2 of the first hourglass, added to each later hourglass's conv5
if presqu is not None:
post = F.relu(self.conv5(out)+presqu, inplace=True) #in:1/16 out:1/8
else:
post = F.relu(self.conv5(out)+pre, inplace=True)
out = self.conv6(post) #in:1/8 out:1/4
        # each hourglass returns three tensors: conv6 (out), conv2 (pre), conv5 (post)
return out, pre, post
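# Shape trace (my own illustration): one hourglass on a (1, 32, 48, 64, 64)
# volume, i.e. maxdisp=192 at 1/4 resolution of a 256x256 pair.
def _hourglass_demo():
    out, pre, post = hourglass(32)(torch.randn(1, 32, 48, 64, 64), None, None)
    # out: (1, 32, 48, 64, 64); pre and post: (1, 64, 24, 32, 32)
    return out.shape, pre.shape, post.shape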
.contiguous(): returns a tensor laid out contiguously in memory; needed after the slice assignments that build the cost volume below.
torch.squeeze(C, dim): removes dimension dim if its size is 1 (used below with dim=1 to drop the singleton channel axis).
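A small illustration of both notes (my own sketch, not from the original code):
    t = torch.randn(2, 3).t()                        # transposing yields a non-contiguous view
    t.is_contiguous()                                # False
    t.contiguous().is_contiguous()                   # True
    torch.squeeze(torch.zeros(4, 1, 5, 5), 1).shape  # torch.Size([4, 5, 5])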
class PSMNet(nn.Module):
def __init__(self, maxdisp):
super(PSMNet, self).__init__()
"""设定的最大视差"""
self.maxdisp = maxdisp
"""特征提取"""
self.feature_extraction = feature_extraction()
"""stacked hourglass:3Dconv0和3Dconv1"""
self.dres0 = nn.Sequential(convbn_3d(64, 32, 3, 1, 1),
nn.ReLU(inplace=True),
convbn_3d(32, 32, 3, 1, 1),
nn.ReLU(inplace=True))
self.dres1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
nn.ReLU(inplace=True),
convbn_3d(32, 32, 3, 1, 1))
"""stacked hourglass:三个hourglass,输入通道32"""
self.dres2 = hourglass(32)
self.dres3 = hourglass(32)
self.dres4 = hourglass(32)
self.classif1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1,bias=False))
self.classif2 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1,bias=False))
self.classif3 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1,bias=False))
"""网络参数初始化?"""
for m in self.modules():
#判断m是否为某个类型(如nn.Conv2d)
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.Conv3d):
n = m.kernel_size[0] * m.kernel_size[1]*m.kernel_size[2] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm3d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.bias.data.zero_()
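        # The loop above is He ("Kaiming") normal initialization with fan_out:
        # W ~ N(0, sqrt(2 / (k_h * k_w * C_out))). In current PyTorch the same
        # effect is obtained with
        # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu').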
def forward(self, left, right):
"""左右特征:(,32,H/4,W/4)"""
refimg_fea = self.feature_extraction(left)
targetimg_fea = self.feature_extraction(right)
        # matching
        """Build the cost volume: (N, 64, D/4, H/4, W/4)"""
        # .size() returns the tensor's shape; Variable is a legacy autograd
        # wrapper (a no-op in modern PyTorch)
cost = Variable(torch.FloatTensor(refimg_fea.size()[0], refimg_fea.size()[1]*2, self.maxdisp//4, refimg_fea.size()[2], refimg_fea.size()[3]).zero_()).cuda()
        # concatenate left and right features at every disparity shift: at
        # disparity i, left pixel (h, w) is paired with right pixel (h, w-i);
        # columns w < i have no counterpart and stay zero
for i in range(self.maxdisp//4):
if i > 0 :
cost[:, :refimg_fea.size()[1], i, :,i:] = refimg_fea[:,:,:,i:]
cost[:, refimg_fea.size()[1]:, i, :,i:] = targetimg_fea[:,:,:,:-i]
else:
cost[:, :refimg_fea.size()[1], i, :,:] = refimg_fea
cost[:, refimg_fea.size()[1]:, i, :,:] = targetimg_fea
cost = cost.contiguous()
"""stacked hourglass部分的网络"""
cost0 = self.dres0(cost)
cost0 = self.dres1(cost0) + cost0
out1, pre1, post1 = self.dres2(cost0, None, None)
out1 = out1+cost0
out2, pre2, post2 = self.dres3(out1, pre1, post1)
out2 = out2+cost0
out3, pre3, post3 = self.dres4(out2, pre1, post2)
out3 = out3+cost0
        # three cost outputs (intermediate supervision; each later head adds the previous cost)
cost1 = self.classif1(out1)
cost2 = self.classif2(out2) + cost1
cost3 = self.classif3(out3) + cost2
if self.training:
cost1 = F.upsample(cost1, [self.maxdisp,left.size()[2],left.size()[3]], mode='trilinear')
cost2 = F.upsample(cost2, [self.maxdisp,left.size()[2],left.size()[3]], mode='trilinear')
cost1 = torch.squeeze(cost1,1)
pred1 = F.softmax(cost1,dim=1)
pred1 = disparityregression(self.maxdisp)(pred1)
cost2 = torch.squeeze(cost2,1)
pred2 = F.softmax(cost2,dim=1)
pred2 = disparityregression(self.maxdisp)(pred2)
cost3 = F.upsample(cost3, [self.maxdisp,left.size()[2],left.size()[3]], mode='trilinear')
cost3 = torch.squeeze(cost3,1)
pred3 = F.softmax(cost3,dim=1)
            # For your information: the formulation 'softmax(c)' learns "similarity",
            # while 'softmax(-c)' learns a "matching cost" as mentioned in the paper.
            # However, 'c' versus '-c' does not affect performance, because the
            # feature-based cost volume provides the needed flexibility.
pred3 = disparityregression(self.maxdisp)(pred3)
        if self.training:  # training returns all three predictions (for the weighted multi-stage loss)
return pred1, pred2, pred3
        else:  # inference only needs the final prediction
return pred3
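# End-to-end usage sketch (my own illustration; a CUDA device is required, since
# the cost volume and disparityregression call .cuda() explicitly):
def _psmnet_demo():
    model = PSMNet(maxdisp=192).cuda().eval()
    left = torch.randn(1, 3, 256, 512).cuda()
    right = torch.randn(1, 3, 256, 512).cuda()
    with torch.no_grad():
        return model(left, right)  # (1, 1, 256, 512) disparity map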