在上面一篇文章中,我们对训练代码中inputs和outputs的获取做了简单分析。有了inputs和outputs后,就可以开始计算loss值了。这也是本文重点。
主要讲解下面这个代码。中文注释添加在里面。
def generate_images_pred(self, inputs, outputs):
    """Generate the warped (reprojected) color images for a minibatch.

    Generated images are saved into the `outputs` dictionary. For every scale
    in ``self.opt.scales`` the following entries are written:

    * ``("depth", 0, scale)`` -- depth returned by ``disp_to_depth``.
    * ``("sample", frame_id, scale)`` -- sampling grid returned by
      ``self.project_3d``.
    * ``("color", frame_id, scale)`` -- source image resampled with that grid.
    * ``("color_identity", frame_id, scale)`` -- the unwarped source image,
      only when ``self.opt.disable_automasking`` is false.

    Args:
        inputs: batch dictionary; read for ``"stereo_T"``, ``("K", s)``,
            ``("inv_K", s)`` and ``("color", frame_id, s)``.
        outputs: prediction dictionary; read for ``("disp", scale)`` and the
            pose entries, and updated in place.

    Returns:
        None. Results are stored in ``outputs``.
    """
    # outputs[("disp", scale)] holds the predicted disparity for each scale.
    for scale in self.opt.scales:
        disp = outputs[("disp", scale)]
        if self.opt.v1_multiscale:
            source_scale = scale
        else:
            # Resize the disparity to (opt.height, opt.width) and warp against
            # the scale-0 inputs.
            disp = F.interpolate(
                disp, [self.opt.height, self.opt.width],
                mode="bilinear", align_corners=False)
            source_scale = 0

        # Convert disparity to depth using the configured min/max depth.
        _, depth = disp_to_depth(disp, self.opt.min_depth, self.opt.max_depth)

        # Keep the depth so later stages can read it from `outputs`.
        outputs[("depth", 0, scale)] = depth

        # frame_ids[0] is skipped; the remaining ids are the source frames.
        # A frame id of "s" selects the stereo transform from `inputs`.
        for frame_id in self.opt.frame_ids[1:]:
            if frame_id == "s":
                T = inputs["stereo_T"]
            else:
                T = outputs[("cam_T_cam", 0, frame_id)]

                # from the authors of https://arxiv.org/abs/1712.00175
                # Only applied to non-stereo frames: `frame_id < 0` below
                # raises TypeError when frame_id is the string "s".
                if self.opt.pose_model_type == "posecnn":
                    axisangle = outputs[("axisangle", 0, frame_id)]
                    translation = outputs[("translation", 0, frame_id)]

                    # Mean of 1/depth over dims 3 and 2 (dims kept), used to
                    # scale the predicted translation.
                    inv_depth = 1 / depth
                    mean_inv_depth = inv_depth.mean(3, True).mean(2, True)

                    T = transformation_from_parameters(
                        axisangle[:, 0],
                        translation[:, 0] * mean_inv_depth[:, 0],
                        frame_id < 0)

            # Back-project the depth map with the inverse intrinsics.
            cam_points = self.backproject_depth[source_scale](
                depth, inputs[("inv_K", source_scale)])
            # Project the points with intrinsics K and transform T to get the
            # coordinates at which the source image is sampled.
            pix_coords = self.project_3d[source_scale](
                cam_points, inputs[("K", source_scale)], T)

            outputs[("sample", frame_id, scale)] = pix_coords

            # Each output pixel takes its colour from the source image at the
            # location stored in the sampling grid.
            # NOTE(review): grid_sample's default `align_corners` differs
            # between PyTorch releases -- confirm which convention
            # `project_3d` normalises for before pinning it explicitly.
            outputs[("color", frame_id, scale)] = F.grid_sample(
                inputs[("color", frame_id, source_scale)],
                outputs[("sample", frame_id, scale)],
                padding_mode="border")

            if not self.opt.disable_automasking:
                outputs[("color_identity", frame_id, scale)] = \
                    inputs[("color", frame_id, source_scale)]
loss值由下面这个函数来获取。
def compute_losses(self, inputs, outputs):
    """Compute the reprojection and smoothness losses for a minibatch.

    For every scale in ``self.opt.scales`` a per-scale loss is built from the
    reprojection term (optionally combined with identity reprojection for
    automasking, or weighted by a predicted mask) plus a disparity
    smoothness term.

    Args:
        inputs: batch dictionary; read for ``("color", frame_id, scale)``.
        outputs: prediction dictionary; read for ``("disp", scale)``,
            ``("color", frame_id, scale)`` and, when used,
            ``"predictive_mask"``. When automasking is enabled,
            ``"identity_selection/<scale>"`` is written into it.

    Returns:
        dict: ``"loss/<scale>"`` for each scale and ``"loss"``, the sum of the
        per-scale losses divided by ``self.num_scales``.
    """
    losses = {}
    total_loss = 0

    # Losses are accumulated scale by scale.
    for scale in self.opt.scales:
        loss = 0
        reprojection_losses = []

        if self.opt.v1_multiscale:
            source_scale = scale
        else:
            source_scale = 0

        # Disparity predicted at this scale.
        disp = outputs[("disp", scale)]
        # Input image at this scale (used by the smoothness term).
        color = inputs[("color", 0, scale)]
        # Image the warped predictions are compared against.
        target = inputs[("color", 0, source_scale)]

        # Compare each warped source image with the target.
        for frame_id in self.opt.frame_ids[1:]:
            pred = outputs[("color", frame_id, scale)]
            reprojection_losses.append(
                self.compute_reprojection_loss(pred, target))

        reprojection_losses = torch.cat(reprojection_losses, 1)

        if not self.opt.disable_automasking:
            # Identity reprojection: compare the unwarped source images
            # directly with the target.
            identity_reprojection_losses = []
            for frame_id in self.opt.frame_ids[1:]:
                pred = inputs[("color", frame_id, source_scale)]
                identity_reprojection_losses.append(
                    self.compute_reprojection_loss(pred, target))

            identity_reprojection_losses = torch.cat(
                identity_reprojection_losses, 1)

            if self.opt.avg_reprojection:
                identity_reprojection_loss = \
                    identity_reprojection_losses.mean(1, keepdim=True)
            else:
                # save both images, and do min all at once below
                identity_reprojection_loss = identity_reprojection_losses

        elif self.opt.predictive_mask:
            # use the predicted mask
            mask = outputs["predictive_mask"]["disp", scale]
            if not self.opt.v1_multiscale:
                mask = F.interpolate(
                    mask, [self.opt.height, self.opt.width],
                    mode="bilinear", align_corners=False)

            reprojection_losses *= mask

            # add a loss pushing mask to 1 (using nn.BCELoss for stability)
            # The target is created on the mask's own device rather than
            # via a hard-coded .cuda(), so this also runs without a GPU.
            weighting_loss = 0.2 * nn.BCELoss()(mask, torch.ones_like(mask))
            loss += weighting_loss.mean()

        if self.opt.avg_reprojection:
            reprojection_loss = reprojection_losses.mean(1, keepdim=True)
        else:
            reprojection_loss = reprojection_losses

        if not self.opt.disable_automasking:
            # add random numbers to break ties
            # Noise is drawn on the loss tensor's device instead of being
            # generated on the CPU and moved with .cuda().
            identity_reprojection_loss += torch.randn_like(
                identity_reprojection_loss) * 0.00001

            # Identity channels come first, reprojection channels second.
            combined = torch.cat(
                (identity_reprojection_loss, reprojection_loss), dim=1)
        else:
            combined = reprojection_loss

        if combined.shape[1] == 1:
            to_optimise = combined
        else:
            # Per-pixel minimum over all candidate losses.
            to_optimise, idxs = torch.min(combined, dim=1)

        if not self.opt.disable_automasking:
            # 1.0 where the minimum came from a reprojection channel, i.e.
            # an index past the identity channels; 0.0 otherwise.
            outputs["identity_selection/{}".format(scale)] = (
                idxs > identity_reprojection_loss.shape[1] - 1).float()

        loss += to_optimise.mean()

        # Normalise the disparity by its mean over dims 2 and 3 before
        # computing the smoothness term.
        mean_disp = disp.mean(2, True).mean(3, True)
        norm_disp = disp / (mean_disp + 1e-7)
        smooth_loss = get_smooth_loss(norm_disp, color)

        # The smoothness weight is halved for each successive scale.
        loss += self.opt.disparity_smoothness * smooth_loss / (2 ** scale)
        total_loss += loss
        losses["loss/{}".format(scale)] = loss

    total_loss /= self.num_scales
    losses["loss"] = total_loss
    return losses
关于图片预测和损失值计算,本文只是在代码里面做了一些简单的注释,其实还有很多细节没有深入研究,比如说SSIM值的计算等,留待后面文章的深入。