CasMVSNet原理阐述及代码实战

文章目录

  • 前言
  • 一、论文先进之处
  • 二、网络结构
    • 1.引入特征金字塔fpn
    • 2.多阶段推断深度图
    • 3.多阶段刨析
    • 4.一些问题
  • 三、代码刨析
  • 四、网络实战
    • 1.训练
    • 2.测试


前言

Cascade Cost Volume for High-Resolution Multi-View Stereo and Stereo Matching

出自CVPR 2020,github链接casmvsnet
下图是我修正后的网络图,更好的帮助大家理解网络
CasMVSNet原理阐述及代码实战_第1张图片


一、论文先进之处

使用多阶段策略,由粗到细的推断深度图
DTU准确度提升了35%,GPU和运行时间降了50%

二、网络结构

1.引入特征金字塔fpn

本网络的输入为原图像与标定的相机位姿,其中位姿信息用于单应性变换原图像用于特征提取。特征提取网络包括两部分,特征金字塔和多尺度聚合模块。特征金字塔可视为一个编码解码结构,通过特征金字塔可以得到三个不同尺度的特征图,顶层特征图包含高层语义特征但缺乏底层的细节;高层的特征图虽然包含特征细节,但缺失足够的语义信息。因此从多个尺度上提取特征能够描述准确的图像特征。根据特征金字塔的上采样及下采样结构,输出的特征图有三个尺度,顶层的特征输出大小为W/4×H/4×32,其中W和H为原图像尺度;中间层尺度大小为W/2×H/2×16;最底层尺度为W×H×8。这三个尺度可成为3个阶段

2.多阶段推断深度图

每一阶段都类似MVSNet,不过第二三阶段估计的是残差深度图,第一阶段估计的是稀疏的深度图

3.多阶段刨析

通过特征提取以后得到了一张参考图像的特征和N-1张源图像的特征,再根据单应性变换,将每一张源图像投影到每一层深度上构成特征体,最后利用插值法使每张投影尺寸相同。理论上每一张参考图像会有N-1个对应的特征体,将这些特征体基于方差的形式构建一个代价体。由于在第一阶段生成的代价体尺度为W/4×H/4×32×48是稀疏的,因此使用的正则网络是3D CNN生成粗糙的深度图并且将场景的深度范围进行了一个估计。
第一阶段得到的深度图作为参考,进行第二阶段的深度图估计。使用中层特征图在剩余深度范围内进行采样深度平面构成剩余代价体,其尺度为W/2×H/2×16×32,使用3DCNN残差深度图,最终生成第二阶段深度图。值得一提的是,在构成剩余代价体时需要使用残差量的单应性变换,公式如下:
CasMVSNet原理阐述及代码实战_第2张图片
最终的深度图是依靠第二阶段估计的深度图用相同的方式构成剩余代价体,其尺度大小为W×H×8×8,生成最终的深度图。

4.一些问题

本网络获取残差深度图的方式为通过缩小剩余深度获取,这个范围是通过一个系数实现的
我在这里有详细阐述


三、代码刨析

class CascadeMVSNet(nn.Module):  #返回各个阶段的输出深度图与置信度图(附件一个优化深度图)  2022/6/21
    def __init__(self, refine=False, ndepths=[48, 32, 8], depth_interals_ratio=[4, 2, 1], share_cr=False,
                 grad_method="detach", arch_mode="fpn", cr_base_chs=[8, 8, 8]):
        super(CascadeMVSNet, self).__init__()
        self.refine = refine
        self.share_cr = share_cr
        self.ndepths = ndepths
        self.depth_interals_ratio = depth_interals_ratio
        self.grad_method = grad_method
        self.arch_mode = arch_mode
        self.cr_base_chs = cr_base_chs
        self.num_stage = len(ndepths)
        print("**********netphs:{}, depth_intervals_ratio:{},  grad:{}, chs:{}************".format(ndepths,
              depth_interals_ratio, self.grad_method, self.cr_base_chs))

        assert len(ndepths) == len(depth_interals_ratio)

        self.stage_infos = {
            "stage1":{
                "scale": 4.0,
            },
            "stage2": {
                "scale": 2.0,
            },
            "stage3": {
                "scale": 1.0,
            }
        }

        self.feature = FeatureNet(base_channels=8, stride=4, num_stage=self.num_stage, arch_mode=self.arch_mode)
        if self.share_cr:
            self.cost_regularization = CostRegNet(in_channels=self.feature.out_channels, base_channels=8)
        else:
            self.cost_regularization = nn.ModuleList([CostRegNet(in_channels=self.feature.out_channels[i],
                                                                 base_channels=self.cr_base_chs[i])
                                                      for i in range(self.num_stage)])
        if self.refine:
            self.refine_network = RefineNet()
        self.DepthNet = DepthNet()

    def forward(self, imgs, proj_matrices, depth_values):
        depth_min = float(depth_values[0, 0].cpu().numpy())
        depth_max = float(depth_values[0, -1].cpu().numpy())
        depth_interval = (depth_max - depth_min) / depth_values.size(1)

        # step 1. feature extraction
        features = []
        for nview_idx in range(imgs.size(1)):  #imgs shape (B, N, C, H, W)
            img = imgs[:, nview_idx]
            features.append(self.feature(img))    #提取参考源图像特征 金字塔结构  2022/6/20

        outputs = {}  #各个阶段的输出深度图与置信度图(附件一个优化深度图)  2022/6/21
        depth, cur_depth = None, None
        for stage_idx in range(self.num_stage):  #step 2. 分几个阶段,循环执行多阶段 2022/6/20
            # print("*********************stage{}*********************".format(stage_idx + 1))
            #stage feature, proj_mats, scales
            features_stage = [feat["stage{}".format(stage_idx + 1)] for feat in features]  #feat是一个包含特征金字塔三个阶段的特征图列表,故此步得到所有图片第i层的特征图 2022/6/20
            proj_matrices_stage = proj_matrices["stage{}".format(stage_idx + 1)]
            stage_scale = self.stage_infos["stage{}".format(stage_idx + 1)]["scale"]

            if depth is not None:  #除了第一阶段 ,都会走这  2022/6/21
                if self.grad_method == "detach":  
                    cur_depth = depth.detach()
                else:
                    cur_depth = depth
                cur_depth = F.interpolate(cur_depth.unsqueeze(1),
                                                [img.shape[2], img.shape[3]], mode='bilinear',
                                                align_corners=Align_Corners_Range).squeeze(1)
            else:
                cur_depth = depth_values
            depth_range_samples = get_depth_range_samples(cur_depth=cur_depth,
                                                        ndepth=self.ndepths[stage_idx],
                                                        depth_inteval_pixel=self.depth_interals_ratio[stage_idx] * depth_interval,
                                                        dtype=img[0].dtype,
                                                        device=img[0].device,
                                                        shape=[img.shape[0], img.shape[2], img.shape[3]],
                                                        max_depth=depth_max,
                                                        min_depth=depth_min)

            outputs_stage = self.DepthNet(features_stage, proj_matrices_stage,
                                          depth_values=F.interpolate(depth_range_samples.unsqueeze(1),  #imgs shape (B, N, C, H, W)
                                                                     [self.ndepths[stage_idx], img.shape[2]//int(stage_scale), img.shape[3]//int(stage_scale)], mode='trilinear',
                                                                     align_corners=Align_Corners_Range).squeeze(1),
                                          num_depth=self.ndepths[stage_idx],   #48,32,8  2022/6/221
                                          cost_regularization=self.cost_regularization if self.share_cr else self.cost_regularization[stage_idx])
            #DepthNet返回一个深度图和置信度图  2022/6/21
            depth = outputs_stage['depth']

            outputs["stage{}".format(stage_idx + 1)] = outputs_stage
            outputs.update(outputs_stage)

        # depth map refinement
        if self.refine:
            refined_depth = self.refine_network(torch.cat((imgs[:, 0], depth), 1))
            outputs["refined_depth"] = refined_depth

        return outputs
class DepthNet(nn.Module):   #该网络承接特征金字塔提取,深度间隔构成,进行正则化网络  2022/6/21
    def __init__(self):
        super(DepthNet, self).__init__()
                                              #depth_values为深度采样构成的一个张量(B, D, H, W)  2022/6/21
    def forward(self, features, proj_matrices, depth_values, num_depth, cost_regularization, prob_volume_init=None):
        proj_matrices = torch.unbind(proj_matrices, 1)
        assert len(features) == len(proj_matrices), "Different number of images and projection matrices"
        assert depth_values.shape[1] == num_depth, "depth_values.shape[1]:{}  num_depth:{}".format(depth_values.shapep[1], num_depth)
        num_views = len(features)

        # step 1. feature extraction
        # in: images; out: 32-channel feature maps
        ref_feature, src_features = features[0], features[1:]
        ref_proj, src_projs = proj_matrices[0], proj_matrices[1:]

        # step 2. differentiable homograph, build cost volume
        ref_volume = ref_feature.unsqueeze(2).repeat(1, 1, num_depth, 1, 1)
        volume_sum = ref_volume
        volume_sq_sum = ref_volume ** 2
        del ref_volume
        for src_fea, src_proj in zip(src_features, src_projs):
            #warpped features
            src_proj_new = src_proj[:, 0].clone()
            src_proj_new[:, :3, :4] = torch.matmul(src_proj[:, 1, :3, :3], src_proj[:, 0, :3, :4])
            ref_proj_new = ref_proj[:, 0].clone()
            ref_proj_new[:, :3, :4] = torch.matmul(ref_proj[:, 1, :3, :3], ref_proj[:, 0, :3, :4])
            warped_volume = homo_warping(src_fea, src_proj_new, ref_proj_new, depth_values)  #特征体 [B, C, Ndepth, H, W]  2022/6/21
            # warped_volume = homo_warping(src_fea, src_proj[:, 2], ref_proj[:, 2], depth_values)

            if self.training:
                volume_sum = volume_sum + warped_volume
                volume_sq_sum = volume_sq_sum + warped_volume ** 2
            else:
                # TODO: this is only a temporal solution to save memory, better way?
                volume_sum += warped_volume
                volume_sq_sum += warped_volume.pow_(2)  # the memory of warped_volume has been modified
            del warped_volume
        # aggregate multiple feature volumes by variance
        volume_variance = volume_sq_sum.div_(num_views).sub_(volume_sum.div_(num_views).pow_(2))  #代价体 2022/6/21

        # step 3. cost volume regularization
        cost_reg = cost_regularization(volume_variance)
        # cost_reg = F.upsample(cost_reg, [num_depth * 4, img_height, img_width], mode='trilinear')
        prob_volume_pre = cost_reg.squeeze(1)   #概率体之前 2022/6/21

        if prob_volume_init is not None:
            prob_volume_pre += prob_volume_init

        prob_volume = F.softmax(prob_volume_pre, dim=1)   #概率体  2022/6/21
        depth = depth_regression(prob_volume, depth_values=depth_values)   #初始深度图  2022/6/21

        with torch.no_grad():
            # photometric confidence
            prob_volume_sum4 = 4 * F.avg_pool3d(F.pad(prob_volume.unsqueeze(1), pad=(0, 0, 0, 0, 1, 2)), (4, 1, 1), stride=1, padding=0).squeeze(1)
            depth_index = depth_regression(prob_volume, depth_values=torch.arange(num_depth, device=prob_volume.device, dtype=torch.float)).long()
            depth_index = depth_index.clamp(min=0, max=num_depth-1)
            photometric_confidence = torch.gather(prob_volume_sum4, 1, depth_index.unsqueeze(1)).squeeze(1)

        return {"depth": depth,  "photometric_confidence": photometric_confidence}

def depth_regression(p, depth_values):   #depth_values为深度采样构成的一个张量(B, D, H, W)  2022/6/21
    if depth_values.dim() <= 2:
        # print("regression dim <= 2")
        depth_values = depth_values.view(*depth_values.shape, 1, 1)
    depth = torch.sum(p * depth_values, 1)  #走这  2022/6/21

    return depth

def cas_mvsnet_loss(inputs, depth_gt_ms, mask_ms, **kwargs):
    depth_loss_weights = kwargs.get("dlossw", None)

    total_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)

    for (stage_inputs, stage_key) in [(inputs[k], k) for k in inputs.keys() if "stage" in k]:
        depth_est = stage_inputs["depth"]
        depth_gt = depth_gt_ms[stage_key]
        mask = mask_ms[stage_key]
        mask = mask > 0.5

        depth_loss = F.smooth_l1_loss(depth_est[mask], depth_gt[mask], reduction='mean')

        if depth_loss_weights is not None:
            stage_idx = int(stage_key.replace("stage", "")) - 1
            total_loss += depth_loss_weights[stage_idx] * depth_loss
        else:
            total_loss += 1.0 * depth_loss

    return total_loss, depth_loss


def get_cur_depth_range_samples(cur_depth, ndepth, depth_inteval_pixel, shape, max_depth=192.0, min_depth=0.0):
    #shape, (B, H, W)
    #cur_depth: (B, H, W)
    #return depth_range_values: (B, D, H, W)
    cur_depth_min = (cur_depth - ndepth / 2 * depth_inteval_pixel)  # (B, H, W)
    cur_depth_max = (cur_depth + ndepth / 2 * depth_inteval_pixel)
    # cur_depth_min = (cur_depth - ndepth / 2 * depth_inteval_pixel).clamp(min=0.0)   #(B, H, W)
    # cur_depth_max = (cur_depth_min + (ndepth - 1) * depth_inteval_pixel).clamp(max=max_depth)

    assert cur_depth.shape == torch.Size(shape), "cur_depth:{}, input shape:{}".format(cur_depth.shape, shape)
    new_interval = (cur_depth_max - cur_depth_min) / (ndepth - 1)  # (B, H, W)

    depth_range_samples = cur_depth_min.unsqueeze(1) + (torch.arange(0, ndepth, device=cur_depth.device,
                                                                  dtype=cur_depth.dtype,
                                                                  requires_grad=False).reshape(1, -1, 1,
                                                                                               1) * new_interval.unsqueeze(1))  #本质意思是得到残差深度图,然后加上上一阶段深度图的上采样.这个网络是d对源文改进一下,在深度间隔上直接构成一个新的区间,直接得到了新的深度图  2022/6/22

    return depth_range_samples


def get_depth_range_samples(cur_depth, ndepth, depth_inteval_pixel, device, dtype, shape,
                           max_depth=192.0, min_depth=0.0):
    #shape: (B, H, W)
    #cur_depth: (B, H, W) or (B, D)
    #return depth_range_samples: (B, D, H, W)
    if cur_depth.dim() == 2:  #dim()返回张量维度个数(B,D)返回2  2022/6/21
        cur_depth_min = cur_depth[:, 0]  # (B,)
        cur_depth_max = cur_depth[:, -1]
        new_interval = (cur_depth_max - cur_depth_min) / (ndepth - 1)  # (B, )

        depth_range_samples = cur_depth_min.unsqueeze(1) + (torch.arange(0, ndepth, device=device, dtype=dtype,
                                                                       requires_grad=False).reshape(1, -1) * new_interval.unsqueeze(1)) #(B, D)

        depth_range_samples = depth_range_samples.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, shape[1], shape[2]) #(B, D, H, W)

    else:

        depth_range_samples = get_cur_depth_range_samples(cur_depth, ndepth, depth_inteval_pixel, shape, max_depth, min_depth)

    return depth_range_samples

四、网络实战

1.训练

首先下载dtu训练集,在我的介绍d2hc实战里有给出链接
下载额外的数据集Depths_raw,并将其放在dtu训练集文件下
CasMVSNet原理阐述及代码实战_第3张图片
训练集即可完成。
我在1080ti上是可以训练的占显存约为8g。每个epoch大约6小时

parser = argparse.ArgumentParser(description='A PyTorch Implementation of Cascade Cost Volume MVSNet')
parser.add_argument('--mode', default='train', help='train or test', choices=['train', 'test', 'profile'])
parser.add_argument('--model', default='mvsnet', help='select model')
parser.add_argument('--device', default='cuda', help='select model')

parser.add_argument('--dataset', default='dtu_yao', help='select dataset')
parser.add_argument('--trainpath', default='/dtu', help='train datapath')
parser.add_argument('--testpath', help='test datapath')
parser.add_argument('--trainlist',  default='lists/dtu/train.txt', help='train list')
parser.add_argument('--testlist', default='lists/dtu/test.txt', help='test list')

parser.add_argument('--epochs', type=int, default=16, help='number of epochs to train')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--lrepochs', type=str, default="10,12,14:2", help='epoch ids to downscale lr and the downscale rate')
parser.add_argument('--wd', type=float, default=0.0, help='weight decay')

parser.add_argument('--batch_size', type=int, default=1, help='train batch size')
parser.add_argument('--numdepth', type=int, default=192, help='the number of depth values')
parser.add_argument('--interval_scale', type=float, default=1.06, help='the number of depth values')

parser.add_argument('--loadckpt', default=None, help='load a specific checkpoint')
parser.add_argument('--logdir', default='./checkpoints/debug', help='the directory to save checkpoints/logs')
parser.add_argument('--resume', action='store_true', help='continue to train the model')

parser.add_argument('--summary_freq', type=int, default=50, help='print and summary frequency')
parser.add_argument('--save_freq', type=int, default=1, help='save checkpoint frequency')
parser.add_argument('--eval_freq', type=int, default=3, help='eval freq')

parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed')
parser.add_argument('--pin_m', action='store_true', help='data loader pin memory')
parser.add_argument("--local_rank", type=int, default=0)

parser.add_argument('--share_cr', action='store_true', help='whether share the cost volume regularization')
parser.add_argument('--ndepths', type=str, default="48,32,8", help='ndepths')
parser.add_argument('--depth_inter_r', type=str, default="4,2,1", help='depth_intervals_ratio')
parser.add_argument('--dlossw', type=str, default="0.5,1.0,2.0", help='depth loss weight for different stage')
parser.add_argument('--cr_base_chs', type=str, default="8,8,8", help='cost regularization base channels')
parser.add_argument('--grad_method', type=str, default="detach", choices=["detach", "undetach"], help='grad method')

parser.add_argument('--using_apex', action='store_true', help='using apex, need to install apex')
parser.add_argument('--sync_bn', action='store_true',help='enabling apex sync BN.')
parser.add_argument('--opt-level', type=str, default="O0")
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None)

2.测试

下载测试集即可,链接也在我之前写的d2hc里。
所有的dtu测试集运行时间大概为40分钟
这里我使用了自带的脚本融合,而不是Gipuma

parser = argparse.ArgumentParser(description='Predict depth, filter, and fuse')
parser.add_argument('--model', default='mvsnet', help='select model')

parser.add_argument('--dataset', default='general_eval', help='select dataset')
parser.add_argument('--testpath', default='./data', help='testing data dir for some scenes')
parser.add_argument('--testpath_single_scene', help='testing data path for single scene')
parser.add_argument('--testlist', default='./lists/dtu/test.txt', help='testing scene list')

parser.add_argument('--batch_size', type=int, default=1, help='testing batch size')
parser.add_argument('--numdepth', type=int, default=192, help='the number of depth values')

parser.add_argument('--loadckpt', default='E:/CasMVSNet/checkpoints/debug/model_000015.ckpt', help='load a specific checkpoint')
parser.add_argument('--outdir', default='./outputs_newm', help='output dir')
parser.add_argument('--display', action='store_true', help='display depth images and masks')

parser.add_argument('--share_cr', action='store_true', help='whether share the cost volume regularization')

parser.add_argument('--ndepths', type=str, default="48,32,8", help='ndepths')
parser.add_argument('--depth_inter_r', type=str, default="4,2,1", help='depth_intervals_ratio')
parser.add_argument('--cr_base_chs', type=str, default="8,8,8", help='cost regularization base channels')
parser.add_argument('--grad_method', type=str, default="detach", choices=["detach", "undetach"], help='grad method')

parser.add_argument('--interval_scale', type=float, default=1.06,help='the depth interval scale')  
parser.add_argument('--num_view', type=int, default=5, help='num of view')
parser.add_argument('--max_h', type=int, default=864, help='testing max h')  #864   1000
parser.add_argument('--max_w', type=int, default=1152, help='testing max w')  #1152   2000
parser.add_argument('--fix_res', action='store_true', help='scene all using same retorch.FloatTensors')

parser.add_argument('--num_worker', type=int, default=4, help='depth_filer worker')
parser.add_argument('--save_freq', type=int, default=20, help='save freq of local pcd')


parser.add_argument('--filter_method', type=str, default='normal', choices=["gipuma", "normal"], help="filter method")  #normal表示自有 2022/6/15

你可能感兴趣的:(计算机视觉,人工智能,深度学习,python,神经网络)