KinectFusion Hands-On Notes

    • Introduction
    • Reading material
    • Reconstruction dataset
    • Main steps of the algorithm
      • Key concepts
    • Walking through the code
      • Some prerequisites
      • Data preprocessing
      • Main function
    • Dataset handling
    • Voxel fusion
    • Camera tracking
    • Summary

Introduction

KinectFusion is the pioneering work of RGB-D SLAM reconstruction and a classic algorithm in 3D reconstruction. This post organizes my notes from working through a KinectFusion implementation, in the hope that it pushes me to finish applying the code. My level is limited; the goal is simply to learn, and writing things down helps me settle down and sort them out.

Reading material

  • Code references:

    Python version: https://github.com/JingwenWang95/KinectFusion
    C++ version: https://github.com/Nerei/kinfu_remake

  • Prerequisites:

    RGB-D images, mesh reconstruction, quaternions, the camera model

    OpenCV documentation on camera calibration and 3D reconstruction: https://docs.opencv.org/2.4/modules/calib3d/doc/camera_calibration_and_3d_reconstruction.html#decomposeprojectionmatrix

Reconstruction dataset

The most commonly used dataset for RGB-D SLAM reconstruction is the TUM RGB-D dataset: https://vision.in.tum.de/data/datasets/rgbd-dataset/download

The dataset contains pairs of PNG-format RGB color images and depth images. The depth values are scaled by a factor of 5000: a pixel value of 5000 in the depth image means 1 m from the camera, 10000 means 2 m, and a value of 0 marks a missing measurement, i.e. no data at that pixel. rgb.txt and depth.txt record the capture time and file name of each image (the file names are themselves timestamps). Because the RGB camera and the depth sensor are separate devices, almost no RGB image and depth image were captured at exactly the same instant, so a preprocessing step is needed to establish a one-to-one correspondence between the RGB and depth frames.
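
As a minimal sketch, reading one TUM depth image into metric depth looks like this (the 5000 scale factor and the zero-means-missing convention are from the dataset description above; the file path is just a placeholder):

import numpy as np
import imageio

# placeholder path for one depth frame of a TUM sequence
depth_png = np.array(imageio.imread("path/to/depth/0000.png"))
depth_m = depth_png.astype(np.float32) / 5000.0   # 5000 -> 1 m, 10000 -> 2 m
valid = depth_png > 0                             # pixel value 0 means no measurement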

Besides the depth and color images needed for the reconstruction itself, the TUM dataset also contains the ground-truth camera trajectory, i.e. the position and orientation of the camera at the moment each frame was captured, expressed in a fixed coordinate system. In groundtruth.txt, timestamp is the time in the Unix epoch, $t_x, t_y, t_z$ give the position of the camera's optical center relative to the world origin defined by the motion-capture system, and $q_x, q_y, q_z, q_w$ give the orientation of the camera's optical center relative to the world origin in the form of a quaternion.

In practice, the camera focal lengths $f_x$, $f_y$ and the optical center $c_x$, $c_y$ are essential for associating the depth and RGB images. The camera intrinsic matrix is
$\begin{bmatrix} f_x & 0 & c_x \\ 0 & f_y & c_y \\ 0 & 0 & 1 \end{bmatrix}$
The TUM dataset was recorded with three devices, corresponding to the three sets of intrinsics fr1, fr2, fr3.
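
As a small sketch, the fr1 intrinsics that appear later in get_calib() assemble into K like this:

import numpy as np

fx, fy, cx, cy = 517.306408, 516.469215, 318.643040, 255.313989  # fr1 values from get_calib()
K = np.array([[fx, 0., cx],
              [0., fy, cy],
              [0., 0., 1.]])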

For a beginner (like me), understanding every step of the program is a build-it-from-the-ground-up effort. All sorts of confusion comes up during the actual application, and reading CSDN and cnblogs posts is a good way through it. For the TUM dataset I recommend 半闲居士's series 一起做RGB-D SLAM.

Main steps of the algorithm

Key concepts

TSDF: TSDF can be translated as truncated signed distance function. It is a grid-style map: first pick the 3D volume to be modeled, say 3×3×3 m³, then divide it at some resolution into many small blocks (voxels) and store information inside each block. The TSDF volume is kept entirely in GPU memory rather than main memory; since the computation for each voxel does not depend on the others, the GPU's parallelism can be exploited to compute and update all voxels in parallel (see https://blog.csdn.net/qinqinxiansheng/article/details/119449196).
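
Concretely, each voxel stores a signed distance to the nearest observed surface, truncated to a band of width $\delta$ around it. A sketch of the convention used by the integrate() function later in this post (where $\delta$ is sdf_trunc = margin × voxel_size):

$F(\mathbf{v}) = \min\!\left(1,\ \dfrac{d_{\text{meas}}(\mathbf{v}) - d_{\text{vox}}(\mathbf{v})}{\delta}\right)$, kept only where $d_{\text{meas}}(\mathbf{v}) - d_{\text{vox}}(\mathbf{v}) \ge -\delta$,

where $d_{\text{meas}}$ is the measured depth along the ray through the voxel's pixel and $d_{\text{vox}}$ is the voxel's depth in the camera frame.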

Image pyramid: a stack of progressively downsampled images; the tracker below runs ICP coarse-to-fine over such a pyramid of depth maps.

The core of an RGB-D SLAM algorithm is to reconstruct a 3D scene from multiple RGB-D frames captured by the camera.

The algorithm pipeline figure from the original paper is shown below:
(figure: KinectFusion pipeline overview)

  1. Depth map conversion (Depth Map Conversion)

A depth map is essentially 2.5D information: a color image (pixels indexed by u = (x, y)) together with a depth image carrying the depth D. Capturing the depth image is a transformation from world coordinates into camera coordinates. From the measured depth, the algorithm back-projects the RGB-D frame into a point cloud and then computes a vertex map and a normal map from it (see the formulas after this list).

Depth Map
|— Vertex Map
|— Normal Map

  2. Camera pose estimation

    Camera pose estimation is a crucial step of the reconstruction. ICP is used to register consecutive frames: by establishing point correspondences between the depth frames, the camera pose $T_{g,k}$ is estimated.

  3. Volumetric integration (Volumetric Integration)

    Volumetric integration uses the TSDF (volumetric truncated signed distance) representation. A large 3D volume is allocated in advance to hold the reconstructed scene; its content is the accumulation of the surface slices observed in different frames, and the voxels inside it are updated step by step as the depth maps are read in.

  4. Raycasting (Raycasting)
    Renders the reconstructed surface (and provides the model depth map used for tracking the next frame).
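
As a sketch of step 1: with intrinsics $K$, a pixel $\mathbf{u}=(u, v)$ with depth $D(\mathbf{u})$ is back-projected to a vertex, and the normal comes from the cross product of neighboring vertex differences (the implementation below actually uses Sobel gradients of the vertex map instead of plain forward differences):

$V(\mathbf{u}) = D(\mathbf{u})\, K^{-1} \begin{bmatrix} u \\ v \\ 1 \end{bmatrix}, \qquad N(\mathbf{u}) \propto \big(V(u{+}1, v) - V(u, v)\big) \times \big(V(u, v{+}1) - V(u, v)\big)$,

with $N$ normalized to unit length.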

Walking through the code

KinectFusion needs GPU acceleration to run. Jingwen's version of KinectFusion uses the PyTorch framework instead of programming CUDA directly.

Some prerequisites

  1. Associating the depth and RGB images
    In practice, the camera focal lengths $f_x$, $f_y$ and the optical center $c_x$, $c_y$ are essential for associating the depth and RGB images.
    The camera intrinsic matrix is
    $\begin{bmatrix} f_x & 0 & c_x \\ 0 & f_y & c_y \\ 0 & 0 & 1 \end{bmatrix}$
  2. Estimating the camera trajectory
    There are two ways to measure the error of an estimated camera trajectory. The absolute trajectory error (ATE) is well suited to evaluating the performance of a visual SLAM system; the relative pose error (RPE) is well suited to measuring the drift of a visual odometry system. A camera pose is described by seven numbers: the position parameters $t_x, t_y, t_z$, which give the location of the camera's optical center in space, and the orientation parameters $q_x, q_y, q_z, q_w$, which give its orientation as a quaternion. A sketch of the ATE computation used later is given right after this list.
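
A minimal sketch of the ATE RMSE that the main loop below computes (the helper name ate_rmse is just for illustration; only the camera centers are compared, with no trajectory alignment):

import numpy as np

def ate_rmse(poses, poses_gt):
    # poses, poses_gt: [N, 4, 4] camera-to-world matrices
    traj = poses[:, :3, 3]        # estimated camera centers
    traj_gt = poses_gt[:, :3, 3]  # ground-truth camera centers
    err = np.linalg.norm(traj_gt - traj, axis=-1)
    return np.sqrt(np.mean(err ** 2))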

Data preprocessing

Before the reconstruction itself is run, the author preprocesses the dataset. preprocess.py is executed before the reconstruction program; its purpose is to associate the RGB and depth file names and save them in a format that is easier to work with.
preprocess.py

import os
import math
import shutil
import numpy as np
import argparse
from tum_rgbd import get_calib
from utils import load_config

def read_file_list(filename):
    """
    Reads a trajectory from a text file, following the TUM camera-trajectory format.
    File format:
    The file format is "stamp d1 d2 d3 ...", where stamp denotes the time stamp (to be matched)
    and "d1 d2 d3.." is arbitary data (e.g., a 3D position and 3D orientation) associated to this timestamp.
    Input:
    filename -- File name
    Output:
    dict -- dictionary of (stamp,data) tuples
    """
    
    file = open(filename)
    data = file.read()
    lines = data.replace(","," ").replace("\t"," ").split("\n")
    list = [[v.strip() for v in line.split(" ") if v.strip()!=""] for line in lines if len(line)>0 and line[0]!="#"]
    list = [(float(l[0]),l[1:]) for l in list if len(l)>1]
    return dict(list)

# The depth camera and the color camera capture frames at slightly different times, so they must be aligned; the maximum allowed time difference is 0.02 s.
# RGB and depth images taken at nearly the same time are written on the same line of the association file, which marks the two as one RGB-D pair.
def associate(first_list, second_list, offset=0.0, max_difference=0.02):
    """
    Associate two dictionaries of (stamp,data). As the time stamps never match exactly, we aim
    to find the closest match for every input tuple.
    Input:
    first_list -- first dictionary of (stamp,data) tuples
    second_list -- second dictionary of (stamp,data) tuples
    offset -- time offset between both dictionaries (e.g., to model the delay between the sensors)
    max_difference -- search radius for candidate generation
    Output:
    matches -- list of matched tuples ((stamp1,data1),(stamp2,data2))
    """
    first_keys = list(first_list)
    second_keys = list(second_list)
    potential_matches = [(abs(a - (b + offset)), a, b)
                         for a in first_keys
                         for b in second_keys
                         if abs(a - (b + offset)) < max_difference]
    potential_matches.sort()
    matches = []
    for diff, a, b in potential_matches:
        if a in first_keys and b in second_keys:
            first_keys.remove(a)
            second_keys.remove(b)
            matches.append((a, b))

    matches.sort()
    return matches

# Compute matched pairs from the two input files and write them to out_file
def get_association(file_a, file_b, out_file):
    first_list = read_file_list(file_a)
    second_list = read_file_list(file_b)
    matches = associate(first_list, second_list)
    with open(out_file, "w") as f:
        for a, b in matches:
            line = "%f %s %f %s\n" % (a, " ".join(first_list[a]), b, " ".join(second_list[b]))
            f.write(line)
            
# Camera orientation: convert a TUM pose (translation + quaternion) into a homogeneous transformation matrix
def tum2matrix(pose):
    """
    Return homogeneous rotation matrix from quaternion.
    """
    # the first 3 entries are the translation
    t = pose[:3]
    # under TUM format q is in the order of [x, y, z, w], need change to [w, x, y, z]
    quaternion = [pose[6], pose[3], pose[4], pose[5]]
    q = np.array(quaternion, dtype=np.float64, copy=True)
    n = np.dot(q, q)
    if n < np.finfo(np.float64).eps:
        return np.identity(4)

    q *= math.sqrt(2.0 / n)
    q = np.outer(q, q)
    return np.array([
        [1.0-q[2, 2]-q[3, 3],     q[1, 2]-q[3, 0],     q[1, 3]+q[2, 0], t[0]],
        [    q[1, 2]+q[3, 0], 1.0-q[1, 1]-q[3, 3],     q[2, 3]-q[1, 0], t[1]],
        [    q[1, 3]-q[2, 0],     q[2, 3]+q[1, 0], 1.0-q[1, 1]-q[2, 2], t[2]],
        [0., 0., 0., 1.]])

# Read the camera poses from an association file
def get_poses_from_associations(fname):
    poses = []
    with open(fname) as f:
        for line in f.readlines():
            pose_str = line.strip("\n").split(" ")[-7:]
            pose = [float(p) for p in pose_str]
            poses += [tum2matrix(pose)]
    return poses

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # the dataset to be processed; results are saved in the processed/ folder
    # standard configs
    parser.add_argument('--config', type=str, default="../configs/fr1_desk.yaml", help='Path to config file.')
    args = load_config(parser.parse_args())
    out_dir = os.path.join(args.data_root, "processed")

    # create association files
    get_association(os.path.join(args.data_root, "depth.txt"), os.path.join(args.data_root, "groundtruth.txt"), os.path.join(args.data_root, "dep_traj.txt"))
    get_association(os.path.join(args.data_root, "rgb.txt"), os.path.join(args.data_root, "dep_traj.txt"), os.path.join(args.data_root, "rgb_dep_traj.txt"))

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    out_rgb_dir = os.path.join(out_dir, "rgb")
    if not os.path.exists(out_rgb_dir):
        os.makedirs(out_rgb_dir)
    out_dep_dir = os.path.join(out_dir, "depth")
    if not os.path.exists(out_dep_dir):
        os.makedirs(out_dep_dir)

    # rename image files and save c2w poses
    # copy the image files under sequential numeric names
    poses = []
    with open(os.path.join(args.data_root, "rgb_dep_traj.txt")) as f:
        for i, line in enumerate(f.readlines()):
            line_list = line.strip().split(" ")
            rgb_file = line_list[1]
            shutil.copyfile(os.path.join(args.data_root, rgb_file), os.path.join(out_rgb_dir, "%04d.png" % i))
            dep_file = line_list[3]
            shutil.copyfile(os.path.join(args.data_root, dep_file), os.path.join(out_dep_dir, "%04d.png" % i))
            poses += [tum2matrix([float(x) for x in line_list[5:]])]

    np.savez(os.path.join(out_dir, "raw_poses.npz"), c2w_mats=poses)

    # save projection matrices P = K @ w2c
    K = np.eye(3)
    intri = get_calib()[args.data_type]  # camera intrinsics for this sequence type
    K[0, 0] = intri[0]
    K[1, 1] = intri[1]
    K[0, 2] = intri[2]
    K[1, 2] = intri[3]
    camera_dict = np.load(os.path.join(out_dir, "raw_poses.npz"))
    poses = camera_dict["c2w_mats"]
    P_mats = []
    for c2w in poses:
        w2c = np.linalg.inv(c2w)
        P = K @ w2c[:3, :]
        P_mats += [P]
    np.savez(os.path.join(out_dir, "cameras.npz"), world_mats=P_mats)

Main function

We start from the main function in kinfu.py.
This is the entry point of the program: it takes an RGB-D dataset as input, reconstructs a mesh and saves it as a .ply file.

import os
import argparse  # parse command-line arguments and options
import numpy as np
import torch
import cv2
import trimesh  # build and export the reconstructed mesh
from matplotlib import pyplot as plt
from fusion import TSDFVolumeTorch  # volumetric integration (TSDF fusion)
from dataset.tum_rgbd import TUMDataset, TUMDatasetOnline  # RGB-D dataset loaders
from tracker import ICPTracker  # camera pose tracking (the pose is estimated from the depth images via ICP)
from utils import load_config, get_volume_setting, get_time  # config loading, volume settings, timing


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # standard configs
    # command-line options: which config/dataset to use and where to save the results
    parser.add_argument('--config', type=str, default="configs/fr1_desk.yaml", help='Path to config file.')
    parser.add_argument("--save_dir", type=str, default=None, help="Directory of saving results.")
    args = load_config(parser.parse_args())

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device = torch.device("cpu")
   
    # load the dataset (uses the helpers in tum_rgbd)
    # each sample provides the camera-to-world transform, the intrinsics, the RGB image and the depth image
    dataset = TUMDataset(os.path.join(args.data_root), device, near=args.near, far=args.far, img_scale=0.25)
    # image height and width
    H, W = dataset.H, dataset.W
	
    # the reconstruction volume is defined in the config file
    vol_dims, vol_origin, voxel_size = get_volume_setting(args)
    # TSDF volume that fuses the per-frame surface measurements into one volumetric model (GPU-accelerated)
    tsdf_volume = TSDFVolumeTorch(vol_dims, vol_origin, voxel_size, device, margin=3, fuse_color=True)
    # tracks the camera pose against the model rendered from the TSDF volume
    icp_tracker = ICPTracker(args, device)
    
    # per-frame timing, estimated poses and ground-truth poses
    t, poses, poses_gt = list(), list(), list()
    curr_pose, depth1, color1 = None, None, None
    for i in range(0, len(dataset), 1):
        t0 = get_time()
        sample = dataset[i]
        color0, depth0, pose_gt, K = sample  # use live image as template image (0)
        # depth0[depth0 <= 0.5] = 0.

        # initialize the pose from the first ground-truth pose
        if i == 0:  # initialize
            curr_pose = pose_gt
        else:  # tracking
            # 1. render depth image (1) from tsdf volume
            depth1, color1, vertex01, normal1, mask1 = tsdf_volume.render_model(curr_pose, K, H, W, near=args.near, far=args.far, n_samples=args.n_steps)
            # depth0 is the live measured depth; depth1 is the depth rendered from the TSDF model
            T10 = icp_tracker(depth0, depth1, K)  # transform from 0 to 1
            curr_pose = curr_pose @ T10

        # fusion
        # fuse the current depth frame into the TSDF volume
        tsdf_volume.integrate(depth0,
                              K,
                              curr_pose,
                              obs_weight=1.,
                              color_img=color0
                              )
        t1 = get_time()
        t += [t1 - t0]
        print("processed frame: {:d}, time taken: {:f}s".format(i, t1 - t0))
        poses += [curr_pose.cpu().numpy()]
        poses_gt += [pose_gt.cpu().numpy()]

    avg_time = np.array(t).mean()
    print("average processing time: {:f}s per frame, i.e. {:f} fps".format(avg_time, 1. / avg_time))
    
    # compute tracking ATE
    # ATE is used to evaluate the tracking quality
    poses_gt = np.stack(poses_gt, 0)
    poses = np.stack(poses, 0)
    traj_gt = np.array(poses_gt)[:, :3, 3]
    traj = np.array(poses)[:, :3, 3]
    rmse = np.sqrt(np.mean(np.linalg.norm(traj_gt - traj, axis=-1) ** 2))
    print("RMSE: {:f}".format(rmse))
    # plt.plot(traj[:, 0], traj[:, 1])
    # plt.plot(traj_gt[:, 0], traj_gt[:, 1])
    # plt.legend(['Estimated', 'GT'])
    # plt.show()

    # save results
    if args.save_dir is not None:
        if not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)
        
        # extract and export the reconstructed mesh
        verts, faces, norms, colors = tsdf_volume.get_mesh()
        partial_tsdf = trimesh.Trimesh(vertices=verts, faces=faces, vertex_normals=norms, vertex_colors=colors)
        partial_tsdf.export(os.path.join(args.save_dir, "mesh.ply"))
        np.savez(os.path.join(args.save_dir, "traj.npz"), poses=poses)
        np.savez(os.path.join(args.save_dir, "traj_gt.npz"), poses=poses_gt)

Dataset handling

tum_rgbd.py
This module provides, for every frame, the four arrays needed during reconstruction: the RGB image, the depth image, the camera-to-world transform and the camera intrinsics (rgb, depth, c2w, K).

import torch
from os import path
from tqdm import tqdm
import imageio
import cv2
import numpy as np
import open3d as o3d


def get_calib():  # the TUM dataset was recorded with three cameras; return their intrinsics fx, fy, cx, cy (each camera's focal length and optical center are assumed fixed)
    return {
        "fr1": [517.306408, 516.469215, 318.643040, 255.313989],
        "fr2": [520.908620, 521.007327, 325.141442, 249.701764],
        "fr3": [535.4, 539.2, 320.1, 247.6]
    }


# Note: this step converts w2c (Tcw) to c2w (Twc),
# i.e. turns the world-to-camera transform into the camera-to-world transform,
# and recovers the camera intrinsics together with the rotation and translation
def load_K_Rt_from_P(P):
    """
    modified from IDR https://github.com/lioryariv/idr
    """
    # see the OpenCV documentation for decomposeProjectionMatrix
    out = cv2.decomposeProjectionMatrix(P)
    K = out[0]
    R = out[1]
    t = out[2]

    # normalize the intrinsic matrix
    K = K/K[2,2]
    intrinsics = np.eye(4)
    intrinsics[:3, :3] = K

    pose = np.eye(4, dtype=np.float32)
    pose[:3, :3] = R.transpose()  # convert from w2c to c2w
    pose[:3, 3] = (t[:3] / t[3])[:, 0]

    return intrinsics, pose

# Dataset class for TUM sequences
class TUMDataset(torch.utils.data.Dataset):
    """
    TUM dataset loader, pre-load images in advance
    """
    # reconstruction-related parameters
    def __init__(
            self,
            rootdir,
            device,
            near: float = 0.2,  # closest valid depth
            far: float = 5.,  # farthest valid depth
            img_scale: float = 1.,  # image scale factor
            start: int = -1,
            end: int = -1,
    ):
        super().__init__()
        assert path.isdir(rootdir), f"'{rootdir}' is not a directory"
        self.device = device
        self.c2w_all = []  # camera-to-world transforms
        self.K_all = []  # camera intrinsics
        self.rgb_all = []  # RGB images
        self.depth_all = []  # depth images

        # root should be tum_sequence
        data_path = path.join(rootdir, "processed")  # the re-organized RGB-D dataset
        cam_file = path.join(data_path, "cameras.npz")  # per-frame camera projection matrices
        print("LOAD DATA", data_path)

        # world_mats, normalize_mat
        cam_dict = np.load(cam_file)  # camera projection matrices
        world_mats = cam_dict["world_mats"]  # K @ w2c, one projection matrix per frame

        d_min = []  # minimum depth
        d_max = []  # maximum depth
        
        # TUM saves camera poses in OpenCV convention
        # tqdm shows a progress bar
        for i, world_mat in enumerate(tqdm(world_mats)):
            # ignore all the frames before
            if start > 0 and i < start:
                continue
            # ignore all the frames after
            if 0 < end < i:
                break

            intrinsics, c2w = load_K_Rt_from_P(world_mat)
            c2w = torch.tensor(c2w, dtype=torch.float32)
            # read images
            # load the depth and color images as arrays
            rgb = np.array(imageio.imread(path.join(data_path, "rgb/{:04d}.png".format(i)))).astype(np.float32)
            depth = np.array(imageio.imread(path.join(data_path, "depth/{:04d}.png".format(i)))).astype(np.float32)
            depth /= 5000.  # TODO: put depth factor to args
            d_max += [depth.max()]
            d_min += [depth.min()]
            # depth = cv2.bilateralFilter(depth, 5, 0.2, 15)
            # print(depth[depth > 0.].min())
            # mark depths outside [near, far] as invalid
            invalid = (depth < near) | (depth > far)
            depth[invalid] = -1.
            # downscale the image size if needed
            if img_scale < 1.0:
                full_size = list(rgb.shape[:2])
                rsz_h, rsz_w = [round(hw * img_scale) for hw in full_size]
                # TODO: figure out which way is better: skimage.rescale or cv2.resize
                rgb = cv2.resize(rgb, (rsz_w, rsz_h), interpolation=cv2.INTER_AREA)
                depth = cv2.resize(depth, (rsz_w, rsz_h), interpolation=cv2.INTER_NEAREST)
                intrinsics[0, 0] *= img_scale
                intrinsics[1, 1] *= img_scale
                intrinsics[0, 2] *= img_scale
                intrinsics[1, 2] *= img_scale

            self.c2w_all.append(c2w)
            self.K_all.append(torch.from_numpy(intrinsics[:3, :3]))
            self.rgb_all.append(torch.from_numpy(rgb))
            self.depth_all.append(torch.from_numpy(depth))
        print("Depth min: {:f}".format(np.array(d_min).min()))
        print("Depth max: {:f}".format(np.array(d_max).max()))
        self.n_images = len(self.rgb_all)
        self.H, self.W, _ = self.rgb_all[0].shape

    def __len__(self):
        return self.n_images

    def __getitem__(self, idx):
        return self.rgb_all[idx].to(self.device), self.depth_all[idx].to(self.device), \
               self.c2w_all[idx].to(self.device), self.K_all[idx].to(self.device)


class TUMDatasetOnline(torch.utils.data.Dataset):
    """
    Online TUM dataset loader, load images when __getitem__() is called
    """

    def __init__(
            self,
            rootdir,
            device,
            near: float = 0.2,
            far: float = 5.,
            img_scale: float = 1.,  # image scale factor
            start: int = -1,
            end: int = -1,
    ):
        super().__init__()
        assert path.isdir(rootdir), f"'{rootdir}' is not a directory"
        self.device = device
        self.img_scale = img_scale
        self.near = near
        self.far = far
        self.c2w_all = []
        self.K_all = []
        self.rgb_files_all = []
        self.depth_files_all = []

        # root should be tum_sequence
        data_path = path.join(rootdir, "processed")
        cam_file = path.join(data_path, "cameras.npz")
        print("LOAD DATA", data_path)

        # world_mats, normalize_mat
        cam_dict = np.load(cam_file)
        world_mats = cam_dict["world_mats"]  # K @ w2c

        # TUM saves camera poses in OpenCV convention
        for i, world_mat in enumerate(world_mats):
            # ignore all the frames before
            if start > 0 and i < start:
                continue
            # ignore all the frames after
            if 0 < end < i:
                break

            intrinsics, c2w = load_K_Rt_from_P(world_mat)
            c2w = torch.tensor(c2w, dtype=torch.float32)
            self.c2w_all.append(c2w)
            self.K_all.append(torch.from_numpy(intrinsics[:3, :3]))
            self.rgb_files_all.append(path.join(data_path, "rgb/{:04d}.png".format(i)))
            self.depth_files_all.append(path.join(data_path, "depth/{:04d}.png".format(i)))

        self.n_images = len(self.rgb_files_all)
        H, W, _ = np.array(imageio.imread(self.rgb_files_all[0])).shape
        self.H = round(H * img_scale)
        self.W = round(W * img_scale)

    def __len__(self):
        return self.n_images

    def __getitem__(self, idx):
        K = self.K_all[idx].to(self.device)
        c2w = self.c2w_all[idx].to(self.device)
        # read images
        rgb = np.array(imageio.imread(self.rgb_files_all[idx])).astype(np.float32)
        depth = np.array(imageio.imread(self.depth_files_all[idx])).astype(np.float32)
        depth /= 5000.
        # depth = cv2.bilateralFilter(depth, 5, 0.2, 15)
        depth[depth < self.near] = 0.
        depth[depth > self.far] = -1.
        # downscale the image size if needed
        if self.img_scale < 1.0:
            full_size = list(rgb.shape[:2])
            rsz_h, rsz_w = [round(hw * self.img_scale) for hw in full_size]
            rgb = cv2.resize(rgb, (rsz_w, rsz_h), interpolation=cv2.INTER_AREA)
            depth = cv2.resize(depth, (rsz_w, rsz_h), interpolation=cv2.INTER_NEAREST)
            K[0, 0] *= self.img_scale
            K[1, 1] *= self.img_scale
            K[0, 2] *= self.img_scale
            K[1, 2] *= self.img_scale

        rgb = torch.from_numpy(rgb).to(self.device)
        depth = torch.from_numpy(depth).to(self.device)

        return rgb, depth, c2w, K
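
A small sanity-check sketch of what load_K_Rt_from_P undoes: preprocess.py stores P = K @ w2c[:3, :], and cv2.decomposeProjectionMatrix recovers the intrinsics and the camera-to-world pose (the numbers here are arbitrary):

import numpy as np

K = np.array([[517.3, 0., 318.6],
              [0., 516.5, 255.3],
              [0., 0., 1.]])
c2w = np.eye(4); c2w[:3, 3] = [0.1, 0.2, 0.3]      # arbitrary camera pose
P = K @ np.linalg.inv(c2w)[:3, :]                   # what preprocess.py saves
intr, pose = load_K_Rt_from_P(P)                    # recovers intrinsics and c2w
assert np.allclose(pose, c2w, atol=1e-4)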

Voxel fusion

fusion.py
The image pyramid also appears here (render_pyramid).
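
The core of the fusion step is a per-voxel weighted running average. With $F_{k-1}, W_{k-1}$ the stored TSDF value and weight of a voxel, and $f_k, w_k$ the new truncated distance and observation weight, integrate() below updates

$F_k = \dfrac{W_{k-1} F_{k-1} + w_k f_k}{W_{k-1} + w_k}, \qquad W_k = W_{k-1} + w_k$,

and, when colors are fused, the same weighting is applied channel-wise to the color volume.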

import os
import numpy as np
from skimage import measure
import torch
import cv2
import open3d as o3d
import imageio


def integrate(
        depth_im,
        cam_intr,
        cam_pose,
        obs_weight,  # observation weight; the KinectFusion paper allows a per-observation weighting scheme
        world_c,  # world coordinates grid [nx*ny*nz, 4]
        vox_coords,  # voxel coordinates grid [nx*ny*nz, 3]
        weight_vol,  # weight volume [nx, ny, nz]
        tsdf_vol,  # tsdf volume [nx, ny, nz]
        sdf_trunc,
        im_h,
        im_w,
        color_vol=None,
        color_im=None,
):

    world2cam = torch.inverse(cam_pose)  # cam_pose is the camera-to-world transform (c2w); its inverse gives the world-to-camera transform
    cam_c = torch.matmul(world2cam, world_c.transpose(1, 0)).transpose(1, 0).float()  # [nx*ny*nz, 4], voxel centers expressed in camera coordinates
    # Convert camera coordinates to pixel coordinates
    fx, fy = cam_intr[0, 0], cam_intr[1, 1]
    cx, cy = cam_intr[0, 2], cam_intr[1, 2]
    pix_z = cam_c[:, 2]
    # project all the voxels back to the image plane
    pix_x = torch.round((cam_c[:, 0] * fx / cam_c[:, 2]) + cx).long()  # [nx*ny*nz]
    pix_y = torch.round((cam_c[:, 1] * fy / cam_c[:, 2]) + cy).long()  # [nx*ny*nz]

    # Eliminate pixels outside the view frustum
    valid_pix = (pix_x >= 0) & (pix_x < im_w) & (pix_y >= 0) & (pix_y < im_h) & (pix_z > 0)  # [n_valid]
    valid_vox_x = vox_coords[valid_pix, 0]
    valid_vox_y = vox_coords[valid_pix, 1]
    valid_vox_z = vox_coords[valid_pix, 2]
    depth_val = depth_im[pix_y[valid_pix], pix_x[valid_pix]]  # [n_valid]

    # Integrate tsdf
    depth_diff = depth_val - pix_z[valid_pix]
    dist = torch.clamp(depth_diff / sdf_trunc, max=1)
    valid_pts = (depth_val > 0.) & (depth_diff >= -sdf_trunc)  # all points 1. inside frustum 2. with valid depth 3. outside -truncate_dist
    valid_vox_x = valid_vox_x[valid_pts]
    valid_vox_y = valid_vox_y[valid_pts]
    valid_vox_z = valid_vox_z[valid_pts]
    valid_dist = dist[valid_pts]
    w_old = weight_vol[valid_vox_x, valid_vox_y, valid_vox_z]
    tsdf_vals = tsdf_vol[valid_vox_x, valid_vox_y, valid_vox_z]
    w_new = w_old + obs_weight
    tsdf_vol[valid_vox_x, valid_vox_y, valid_vox_z] = (w_old * tsdf_vals + obs_weight * valid_dist) / w_new
    weight_vol[valid_vox_x, valid_vox_y, valid_vox_z] = w_new

    if color_vol is not None and color_im is not None:
        old_color = color_vol[valid_vox_x, valid_vox_y, valid_vox_z]
        new_color = color_im[pix_y[valid_pix], pix_x[valid_pix]]
        new_color = new_color[valid_pts]
        color_vol[valid_vox_x, valid_vox_y, valid_vox_z, :] = (w_old[:, None] * old_color + obs_weight * new_color) / w_new[:, None]

    return weight_vol, tsdf_vol, color_vol


class TSDFVolumeTorch:
    """
    Volumetric TSDF Fusion of RGB-D Images.
    """

    def __init__(self, voxel_dim, origin, voxel_size, device, margin=3, fuse_color=False):
        """
        Args:
            voxel_dim (ndarray): [3,] stores volume dimensions: Nx, Ny, Nz
            origin (ndarray): [3,] world coordinate of voxel [0, 0, 0]
            voxel_size (float): The volume discretization in meters.
        """

        self.device = device
        # Define voxel volume parameters
        self.voxel_size = float(voxel_size)
        self.sdf_trunc = margin * self.voxel_size
        self.integrate_func = integrate
        self.fuse_color = fuse_color

        # Adjust volume bounds
        if isinstance(voxel_dim, list):
            voxel_dim = torch.Tensor(voxel_dim).to(self.device)
        elif isinstance(voxel_dim, np.ndarray):
            voxel_dim = torch.from_numpy(voxel_dim).to(self.device)
        if isinstance(origin, list):
            origin = torch.Tensor(origin).to(self.device)
        elif isinstance(origin, np.ndarray):
            origin = torch.from_numpy(origin).to(self.device)

        self.vol_dim = voxel_dim.long()
        self.vol_origin = origin
        self.num_voxels = torch.prod(self.vol_dim).item()

        # Get voxel grid coordinates
        xv, yv, zv = torch.meshgrid(
            torch.arange(0, self.vol_dim[0]),
            torch.arange(0, self.vol_dim[1]),
            torch.arange(0, self.vol_dim[2]),
        )
        self.vox_coords = torch.stack([xv.flatten(), yv.flatten(), zv.flatten()], dim=1).long().to(self.device)

        # Convert voxel coordinates to world coordinates
        self.world_c = self.vol_origin + (self.voxel_size * self.vox_coords)
        self.world_c = torch.cat([
            self.world_c, torch.ones(len(self.world_c), 1, device=self.device)], dim=1).float()
        self.reset()

    def reset(self):
        """Set volumes
        """
        self.tsdf_vol = torch.ones(*self.vol_dim).to(self.device)
        self.weight_vol = torch.zeros(*self.vol_dim).to(self.device)
        if self.fuse_color:
            # [nx, ny, nz, 3]
            self.color_vol = torch.zeros(*self.vol_dim, 3).to(self.device)
        else:
            self.color_vol = None

    def data_transfer(self, data):
        if isinstance(data, np.ndarray):
            data = torch.from_numpy(data)
        return data.float().to(self.device)

    @torch.no_grad()
    def integrate(self, depth_im, cam_intr, cam_pose, obs_weight, color_img=None):
        """Integrate an RGB-D frame into the TSDF volume.
        Args:
        depth_im (torch.Tensor): A depth image of shape (H, W).
        cam_intr (torch.Tensor): The camera intrinsics matrix of shape (3, 3).
        cam_pose (torch.Tensor): The camera pose (i.e. extrinsics) of shape (4, 4). T_wc
        obs_weight (float): The weight to assign to the current observation.
        """

        cam_pose = self.data_transfer(cam_pose)
        cam_intr = self.data_transfer(cam_intr)
        depth_im = self.data_transfer(depth_im)
        if color_img is not None:
            color_img = self.data_transfer(color_img)
        else:
            color_img = None
        im_h, im_w = depth_im.shape
        # fuse
        weight_vol, tsdf_vol, color_vol = self.integrate_func(
            depth_im,
            cam_intr,
            cam_pose,
            obs_weight,
            self.world_c,
            self.vox_coords,
            self.weight_vol,
            self.tsdf_vol,
            self.sdf_trunc,
            im_h, im_w,
            self.color_vol,
            color_img,
        )
        self.weight_vol = weight_vol
        self.tsdf_vol = tsdf_vol
        self.color_vol = color_vol

    def get_volume(self):
        return self.tsdf_vol, self.weight_vol, self.color_vol

    def get_mesh(self):
        """Compute a mesh from the voxel volume using marching cubes.
        """
        tsdf_vol, weight_vol, color_vol = self.get_volume()
        verts, faces, norms, vals = measure.marching_cubes(tsdf_vol.cpu().numpy(), level=0)
        verts_ind = np.round(verts).astype(int)
        verts = verts * self.voxel_size + self.vol_origin.cpu().numpy()  # voxel grid coordinates to world coordinates

        if self.fuse_color:
            rgb_vals = color_vol[verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]].cpu().numpy()
            return verts, faces, norms, rgb_vals.astype(np.uint8)
        else:
            return verts, faces, norms

    def to_o3d_mesh(self):
        """Convert to o3d mesh object for visualization
        """
        verts, faces, norms, colors = self.get_mesh()
        mesh = o3d.geometry.TriangleMesh()
        mesh.vertices = o3d.utility.Vector3dVector(verts.astype(float))
        mesh.triangles = o3d.utility.Vector3iVector(faces.astype(np.int32))
        mesh.vertex_colors = o3d.utility.Vector3dVector(colors / 255.)
        return mesh

    def get_normals(self):
        """Compute normal volume
        """
        nx, ny, nz = self.vol_dim
        device = self.device
        # dx = torch.cat([torch.zeros(1, ny, nz).to(device), (self.tsdf_vol[2:, :, :] - self.tsdf_vol[:-2, :, :]) / (2 * self.voxel_size), torch.zeros(1, ny, nz).to(device)], dim=0)
        # dy = torch.cat([torch.zeros(nx, 1, nz).to(device), (self.tsdf_vol[:, 2:, :] - self.tsdf_vol[:, :-2, :]) / (2 * self.voxel_size), torch.zeros(nx, 1, nz).to(device)], dim=1)
        # dz = torch.cat([torch.zeros(nx, ny, 1).to(device), (self.tsdf_vol[:, :, 2:] - self.tsdf_vol[:, :, :-2]) / (2 * self.voxel_size), torch.zeros(nx, ny, 1).to(device)], dim=2)
        # norms = torch.stack([dx, dy, dz], -1)
        dx = torch.cat([(self.tsdf_vol[1:, :, :] - self.tsdf_vol[:-1, :, :]) / self.voxel_size, torch.zeros(1, ny, nz).to(device)], dim=0)
        dy = torch.cat([(self.tsdf_vol[:, 1:, :] - self.tsdf_vol[:, :-1, :]) / self.voxel_size, torch.zeros(nx, 1, nz).to(device)], dim=1)
        dz = torch.cat([(self.tsdf_vol[:, :, 1:] - self.tsdf_vol[:, :, :-1]) / self.voxel_size, torch.zeros(nx, ny, 1).to(device)], dim=2)
        norms = torch.stack([dx, dy, dz], -1)
        n = torch.norm(norms, dim=-1)
        # remove large values
        outliers_mask = n > 1. / (2 * self.voxel_size)
        norms[outliers_mask] = 0.
        # normalize
        eps = 1e-7
        non_zero_grad = n > eps
        norms[non_zero_grad, :] = norms[non_zero_grad, :] / n[non_zero_grad][:, None]
        return norms  # [nx, ny, nz, 3]

    def get_nn(self, field_vol, coords_w):
        """Get nearest-neigbor values from a given volume
        """
        field_dim = field_vol.shape
        assert len(field_dim) == 3 or len(field_dim) == 4
        vox_coord_float = (coords_w - self.vol_origin[None, :]) / self.voxel_size
        vox_coord = torch.floor(vox_coord_float)
        vox_offset = vox_coord_float - vox_coord  # [N, 3]
        vox_coord[vox_offset >= 0.5] += 1.
        vox_coord[:, 0] = torch.clamp(vox_coord[:, 0], 0., self.vol_dim[0] - 1)
        vox_coord[:, 1] = torch.clamp(vox_coord[:, 1], 0., self.vol_dim[1] - 1)
        vox_coord[:, 2] = torch.clamp(vox_coord[:, 2], 0., self.vol_dim[2] - 1)
        vox_coord = vox_coord.long()
        vx, vy, vz = vox_coord[:, 0], vox_coord[:, 1], vox_coord[:, 2]
        v_nn = field_vol[vx, vy, vz]
        return v_nn

    def tril_interp(self, field_vol, coords_w):
        """Get tri-linear interpolated value from a given volume
        """
        field_dim = field_vol.shape
        assert len(field_dim) == 3 or len(field_dim) == 4
        n_pts = coords_w.shape[0]
        vox_coord = torch.floor((coords_w - self.vol_origin[None, :]) / self.voxel_size).long()  # [N, 3]

        # for border points, don't do interpolation
        non_border_mask = (vox_coord[:, 0] < self.vol_dim[0] - 1) & (vox_coord[:, 1] < self.vol_dim[1] - 1) & \
                          (vox_coord[:, 2] < self.vol_dim[2] - 1)
        v_interp = torch.zeros(n_pts) if len(field_dim) == 3 else torch.zeros(n_pts, field_vol.shape[-1])
        v_interp = v_interp.to(self.device)
        vx_, vy_, vz_ = vox_coord[~non_border_mask, 0], vox_coord[~non_border_mask, 1], vox_coord[~non_border_mask, 2]
        v_interp[~non_border_mask] = field_vol[vx_, vy_, vz_]

        # get interpolated values for normal points
        vx, vy, vz = vox_coord[non_border_mask, 0], vox_coord[non_border_mask, 1], vox_coord[non_border_mask, 2]  # [N]
        vox_idx = vz + vy * self.vol_dim[-1] + vx * self.vol_dim[-1] * self.vol_dim[-2]
        vertices_coord = self.world_c[vox_idx][:, :3]  # [N, 3]
        r = (coords_w[non_border_mask] - vertices_coord) / self.voxel_size
        rx, ry, rz = r[:, 0], r[:, 1], r[:, 2]
        if len(field_dim) == 4:
            rx = rx.unsqueeze(1)
            ry = ry.unsqueeze(1)
            rz = rz.unsqueeze(1)
        # get values at eight corners
        v000 = field_vol[vx, vy, vz]
        v001 = field_vol[vx, vy, vz+1]
        v010 = field_vol[vx, vy+1, vz]
        v011 = field_vol[vx, vy+1, vz+1]
        v100 = field_vol[vx+1, vy, vz]
        v101 = field_vol[vx+1, vy, vz+1]
        v110 = field_vol[vx+1, vy+1, vz]
        v111 = field_vol[vx+1, vy+1, vz+1]
        v_interp[non_border_mask] = v000 * (1 - rx) * (1 - ry) * (1 - rz) \
                                   + v001 * (1 - rx) * (1 - ry) * rz \
                                   + v010 * (1 - rx) * ry * (1 - rz) \
                                   + v011 * (1 - rx) * ry * rz \
                                   + v100 * rx * (1 - ry) * (1 - rz) \
                                   + v101 * rx * (1 - ry) * rz \
                                   + v110 * rx * ry * (1 - rz) \
                                   + v111 * rx * ry * rz

        return v_interp

    def get_pts_inside(self, pts, margin=0):
        vox_coord = torch.floor((pts - self.vol_origin[None, :]) / self.voxel_size).long()  # [N, 3]
        valid_pts_mask = (vox_coord[..., 0] >= margin) & (vox_coord[..., 0] < self.vol_dim[0] - margin) \
                         & (vox_coord[..., 1] >= margin) & (vox_coord[..., 1] < self.vol_dim[1] - margin) \
                         & (vox_coord[..., 2] >= margin) & (vox_coord[..., 2] < self.vol_dim[2] - margin)
        return valid_pts_mask

    # use simple root finding
    @torch.no_grad()
    def render_model(self, c2w, intri, imh, imw, near=0.5, far=5., n_samples=192):
        """
        Perform ray-casting for frame-to-model tracking
        :param c2w: camera pose, [4, 4]
        :param intri: camera intrinsics, [3, 3]
        :param imh: image height
        :param imw: image width
        :param near: near bound for ray-casting
        :param far: far bound for ray-casting
        :param n_samples: number of samples along the ray
        :return: rendered depth, color, vertex, normal and valid mask, [H, W, C]
        """
        rays_o, rays_d = self.get_rays(c2w, intri, imh, imw)  # [h, w, 3]
        z_vals = torch.linspace(near, far, n_samples).to(rays_o)  # [n_samples]
        ray_pts_w = (rays_o[:, :, None, :] + rays_d[:, :, None, :] * z_vals[None, None, :, None]).to(self.device)  # [h, w, n_samples, 3]

        # need to query the tsdf and feature grid
        tsdf_vals = torch.ones(imh, imw, n_samples).to(self.device)
        # filter points that are outside the volume
        valid_ray_pts_mask = self.get_pts_inside(ray_pts_w)
        valid_ray_pts = ray_pts_w[valid_ray_pts_mask]  # [n_valid, 3]
        tsdf_vals[valid_ray_pts_mask] = self.tril_interp(self.tsdf_vol, valid_ray_pts)

        # surface prediction by finding zero crossings
        sign_matrix = torch.cat([torch.sign(tsdf_vals[..., :-1] * tsdf_vals[..., 1:]),
                                 torch.ones(imh, imw, 1).to(self.device)], dim=-1)  # [h, w, n_samples]
        cost_matrix = sign_matrix * torch.arange(n_samples, 0, -1).float().to(self.device)[None, None, :]  # [h, w, n_samples]
        # Get first sign change and mask for values where
        # a.) a sign changed occurred and
        # b.) not a neg to pos sign change occurred
        # c.) ignore border points
        values, indices = torch.min(cost_matrix, -1)
        mask_sign_change = values < 0
        hs, ws = torch.meshgrid(torch.arange(imh), torch.arange(imw))
        mask_pos_to_neg = tsdf_vals[hs, ws, indices] > 0
        inside_vol = self.get_pts_inside(ray_pts_w[hs, ws, indices])
        hit_surface_mask = mask_sign_change & mask_pos_to_neg & inside_vol
        hit_pts = ray_pts_w[hs, ws, indices][hit_surface_mask]  # [n_surf_pts, 3]

        # compute normals
        norms = self.get_normals()
        surf_tsdf = self.tril_interp(self.tsdf_vol, hit_pts)  # [n_surf_pts]
        # surf_norms = self.tril_interp(norms, hit_pts)  # [n_surf_pts, 3]
        surf_norms = self.get_nn(norms, hit_pts)
        updated_hit_pts = hit_pts - surf_tsdf[:, None] * self.sdf_trunc * surf_norms
        valid_mask = self.get_pts_inside(updated_hit_pts)
        hit_pts[valid_mask, :] = updated_hit_pts[valid_mask, :]

        # get depth values
        w2c = torch.inverse(c2w).to(self.device)
        hit_pts_c = (w2c[:3, :3] @ hit_pts.transpose(1, 0)).transpose(1, 0) + w2c[:3, 3][None, :]
        hit_pts_z = hit_pts_c[:, -1]
        depth_rend = torch.zeros(imh, imw).to(self.device)
        # depth_rend[hit_surface_mask] = z_vals[indices[hit_surface_mask]]
        depth_rend[hit_surface_mask] = hit_pts_z

        # vertex map
        vertex_rend = torch.zeros(imh, imw, 3).to(self.device)
        vertex_rend[hit_surface_mask] = hit_pts_c
        # normal map
        surf_norms_c = (w2c[:3, :3] @ surf_norms.transpose(1, 0)).transpose(1, 0)  # [h, w, 3]
        normal_rend = torch.zeros(imh, imw, 3).to(self.device)
        normal_rend[hit_surface_mask] = surf_norms_c

        if self.color_vol is not None:
            # hit_colors = self.color_vol[cx, cy, cz, :]
            hit_colors = self.tril_interp(self.color_vol, hit_pts)
            # set color
            color_rend = torch.zeros(imh, imw, 3).to(self.device)
            color_rend[hit_surface_mask] = hit_colors
        else:
            color_rend = None

        return depth_rend, color_rend, vertex_rend, normal_rend, hit_surface_mask

    def render_pyramid(self, c2w, intri, imh, imw, n_pyr=4, near=0.5, far=5., n_samples=192):
        K = intri.clone()
        dep_pyr, rgb_pyr, vtx_pyr, nrm_pyr, mask_pyr = [], [], [], [], []
        for l in range(n_pyr):
            dep, rgb, feat, vtx, nrm, mask = self.render_model(c2w, K, imh, imw, near=near, far=far, n_samples=n_samples)
            dep_pyr += [dep]
            rgb_pyr += [rgb]
            vtx_pyr += [vtx]
            nrm_pyr += [nrm]
            mask_pyr += [mask]
            imh = imh // 2
            imw = imw // 2
            K /= 2
        return dep_pyr, rgb_pyr, vtx_pyr, nrm_pyr, mask_pyr

    # get voxel index given world coordinate
    # used for testing
    def get_voxel_idx(self, x):
        """
        :param x: [N, 3] query points
        :return: [N] voxel indices
        """
        assert len(x.shape) == 2, print("only accept flattened input!!!")
        x.to(self.device)
        vox_coord = torch.floor((x - self.vol_origin[None, :]) / self.voxel_size)  # [N, 3]
        vx, vy, vz = vox_coord[:, 0], vox_coord[:, 1], vox_coord[:, 2]
        # very important! get voxel index from voxel coordinate
        vox_idx = vz + vy * self.vol_dim[-1] + vx * self.vol_dim[-1] * self.vol_dim[-2]
        return vox_idx.long()

    def get_rays(self, c2w, intrinsics, H, W):
        device = self.device
        c2w = c2w.to(device)
        fx = intrinsics[0, 0]
        fy = intrinsics[1, 1]
        cx = intrinsics[0, 2]
        cy = intrinsics[1, 2]

        i, j = torch.meshgrid(torch.linspace(0, W - 1, W), torch.linspace(0, H - 1, H))  # pytorch's meshgrid has indexing='ij'
        i = i.t().to(device).reshape(H * W)  # [hw]
        j = j.t().to(device).reshape(H * W)  # [hw]

        dirs = torch.stack([(i - cx) / fx, (j - cy) / fy, torch.ones_like(i)], -1).to(device)  # [hw, 3]
        # permute for bmm
        dirs = dirs.transpose(1, 0)  # [3, hw]
        rays_d = (c2w[:3, :3] @ dirs).transpose(1, 0)  # [hw, 3]
        rays_o = c2w[:3, 3].expand(rays_d.shape)

        return rays_o.reshape(H, W, 3), rays_d.reshape(H, W, 3)
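
A minimal usage sketch of the class above, mirroring how kinfu.py drives it; the volume bounds are example values, and depth, rgb, K, c2w, H, W are assumed to come from the config and the dataset loader:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vol_dims = [128, 128, 128]       # example grid resolution
vol_origin = [-1.5, -1.5, 0.0]   # example world coordinate of voxel (0, 0, 0)
voxel_size = 0.02                # 2 cm voxels

tsdf = TSDFVolumeTorch(vol_dims, vol_origin, voxel_size, device, margin=3, fuse_color=True)
tsdf.integrate(depth, K, c2w, obs_weight=1., color_img=rgb)                # fuse one frame
depth_r, color_r, vtx_r, nrm_r, mask_r = tsdf.render_model(c2w, K, H, W)   # ray-cast the model
verts, faces, norms, colors = tsdf.get_mesh()                              # marching cubes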

Camera tracking

This is my favorite part and, in my opinion, the core of the whole algorithm.
tracker.py is the entry point of the tracking step; icp.py does the main work.
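
What the solver below minimizes is the point-to-plane error between the template vertices $\mathbf{v}^0_i$ (from depth0), warped by the current estimate $T_{10}$, and their projectively associated model vertices $\mathbf{v}^1_i$ with normals $\mathbf{n}^1_i$ (from the rendered depth1):

$E(T_{10}) = \sum_i \Big( \mathbf{n}^{1\top}_i \big( T_{10}\,\mathbf{v}^0_i - \mathbf{v}^1_i \big) \Big)^2$,

linearized around the current pose and solved with a damped Gauss-Newton step at each pyramid level.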

import torch
import torch.nn as nn
from icp import ICP

class ICPTracker(nn.Module):

    def __init__(self,
                 args,
                 device,
                 ):

        super(ICPTracker, self).__init__()
        self.n_pyr = args.n_pyramids
        self.scales = list(range(self.n_pyr))
        self.n_iters = args.n_iters
        self.dampings = args.dampings
        # KinectFusion tracks coarse-to-fine over an image pyramid
        self.construct_image_pyramids = ImagePyramids(self.scales, pool='avg')
        self.construct_depth_pyramids = ImagePyramids(self.scales, pool='max')
        self.device = device
        # initialize tracker at different levels
        self.icp_solvers = []
        for i in range(self.n_pyr):
            self.icp_solvers += [ICP(self.n_iters[i], damping=self.dampings[i])]

    @torch.no_grad()
    def forward(self, depth0, depth1, K):
        H, W = depth0.shape
        dpt0_pyr = self.construct_depth_pyramids(depth0.view(1, 1, H, W))
        dpt0_pyr = [d.squeeze() for d in dpt0_pyr]
        dpt1_pyr = self.construct_depth_pyramids(depth1.view(1, 1, H, W))
        dpt1_pyr = [d.squeeze() for d in dpt1_pyr]
        # optimization steps
        pose10 = torch.eye(4).to(self.device)  # initialize from the identity matrix
        for i in reversed(range(self.n_pyr)):
            Ki = get_scaled_K(K, i)
            pose10 = self.icp_solvers[i](pose10, dpt0_pyr[i], dpt1_pyr[i], Ki)

        return pose10


class ImagePyramids(nn.Module):
    """ Construct the pyramids in the image / depth space
    """
    def __init__(self, scales, pool='avg'):
        super(ImagePyramids, self).__init__()
        if pool == 'avg':
            self.multiscales = [nn.AvgPool2d(1 << i, 1 << i) for i in scales]
        elif pool == 'max':
            self.multiscales = [nn.MaxPool2d(1 << i, 1 << i) for i in scales]
        else:
            raise NotImplementedError()

    def forward(self, x):
        # apply each pooling operator to obtain one downsampled copy per scale
        return [f(x) for f in self.multiscales]
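
The listing is cut off here; in particular get_scaled_K, which ICPTracker.forward calls above, is not shown. A minimal sketch, assuming it simply halves the intrinsics once per pyramid level:

def get_scaled_K(K, level):
    # scale the intrinsics to match a depth map downsampled by 2**level
    if level == 0:
        return K
    K_scaled = K.clone()
    K_scaled[0, 0] /= 2 ** level  # fx
    K_scaled[1, 1] /= 2 ** level  # fy
    K_scaled[0, 2] /= 2 ** level  # cx
    K_scaled[1, 2] /= 2 ** level  # cy
    return K_scaled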

icp.py

import torch
import torch.nn as nn
import torch.nn.functional as F


# forward ICP
class ICP(nn.Module):
    def __init__(self,
                 max_iter=3,  # maximum number of iterations
                 damping=1e-3,  # damping factor added to the Hessian
                 ):
        """
        :param max_iter, maximum number of iterations
        :param damping, damping added to Hessian matrix
        """
        super(ICP, self).__init__()

        self.max_iterations = max_iter
        self.damping = damping

    def forward(self, pose10, depth0, depth1, K):
        """
        In all cases we refer to 0 as template, and always warp pixels from 0 to 1
        :param pose10: initial pose estimate
        :param depth0: template depth image (0)
        :param depth1: depth image (1)
        :param K: intrinsic matrix
        :return: refined 0-to-1 transformation pose10
        """
        # create vertex and normal for current frame
        vertex0 = compute_vertex(depth0, K)
        normal0 = compute_normal(vertex0)
        mask0 = depth0 > 0.
        vertex1 = compute_vertex(depth1, K)
        normal1 = compute_normal(vertex1)                            

        for idx in range(self.max_iterations):
            # compute residuals
            residuals, J_F_p = self.compute_residuals_jacobian(vertex0, vertex1, normal0, normal1, mask0, pose10, K)
            JtWJ = self.compute_jtj(J_F_p)  # [B, 6, 6]
            JtR = self.compute_jtr(J_F_p, residuals)
            pose10 = self.GN_solver(JtWJ, JtR, pose10, damping=self.damping)

        return pose10

    @staticmethod
    def compute_residuals_jacobian(vertex0, vertex1, normal0, normal1, mask0, pose10, K):
        """
        :param vertex0: vertex map 0
        :param vertex1: vertex map 1
        :param normal0: normal map 0
        :param normal1: normal map 1
        :param mask0: valid mask of template depth image
        :param pose10: current estimate of pose10
        :param K: intrinsics
        :return: residuals and Jacobians
        """
        R = pose10[:3, :3]
        t = pose10[:3, 3]
        H, W, C = vertex0.shape

        rot_vertex0_to1 = (R @ vertex0.view(-1, 3).permute(1, 0)).permute(1, 0).view(H, W, 3)
        vertex0_to1 = rot_vertex0_to1 + t[None, None, :]
        normal0_to1 = (R @ normal0.view(-1, 3).permute(1, 0)).permute(1, 0).view(H, W, 3)

        fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
        x_, y_, z_ = vertex0_to1[..., 0], vertex0_to1[..., 1], vertex0_to1[..., 2]  # [h, w]
        u_ = (x_ / z_) * fx + cx  # [h, w]
        v_ = (y_ / z_) * fy + cy  # [h, w]

        inviews = (u_ > 0) & (u_ < W-1) & (v_ > 0) & (v_ < H-1)
        # projective data association
        r_vertex1 = warp_features(vertex1, u_, v_)  # [h, w, 3]
        r_normal1 = warp_features(normal1, u_, v_)  # [h, w, 3]
        mask1 = r_vertex1[..., -1] > 0.

        diff = vertex0_to1 - r_vertex1  # [h, w, 3]

        # point-to-plane residuals
        res = (r_normal1 * diff).sum(dim=-1)  # [h, w]
        # point-to-plane jacobians
        J_trs = r_normal1.view(-1, 3)  # [hw, 3]
        J_rot = -torch.bmm(J_trs.unsqueeze(dim=1), batch_skew(vertex0_to1.view(-1, 3))).squeeze()   # [hw, 3]

        # compose jacobians
        J_F_p = torch.cat((J_rot, J_trs), dim=-1).view(H, W, 6)  # follow the order of [rot, trs]  [hw, 1, 6]

        # occlusion
        occ = ~inviews | (diff.norm(p=2, dim=-1) > 0.10)
        invalid_mask = occ | ~mask0 | ~mask1
        J_F_p[invalid_mask] = 0.
        res[invalid_mask] = 0.
        res = res.view(-1, 1)  # [hw, 1]
        J_F_p = J_F_p.view(-1, 1, 6)  # [hw, 1, 6]

        return res, J_F_p

    @staticmethod
    def compute_jtj(jac):
        # J in the dimension of (HW, C, 6)
        jacT = jac.transpose(-1, -2)  # [HW, 6, C]
        jtj = torch.bmm(jacT, jac).sum(0)  # [6, 6]
        return jtj  # [6, 6]

    @staticmethod
    def compute_jtr(jac, res):
        # J in the dimension of (HW, C, 6)
        # res in the dimension of [HW, C]
        jacT = jac.transpose(-1, -2)  # [HW, 6, C]
        jtr = torch.bmm(jacT, res.unsqueeze(-1)).sum(0)  # [6, 1]
        return jtr  # [6, 1]

    @staticmethod
    def GN_solver(JtJ, JtR, pose0, damping=1e-6):
        # Add a small diagonal damping. Without it, the training becomes quite unstable
        # Do not see a clear difference by removing the damping in inference though
        Hessian = lev_mar_H(JtJ, damping)
        # Hessian = JtJ
        updated_pose = forward_update_pose(Hessian, JtR, pose0)

        return updated_pose


def warp_features(Feat, u, v, mode='bilinear'):
    """
    Warp the feature map (F) w.r.t. the grid (u, v). This is the non-batch version
    """
    assert len(Feat.shape) == 3
    H, W, C = Feat.shape
    u_norm = u / ((W - 1) / 2) - 1  # [h, w]
    v_norm = v / ((H - 1) / 2) - 1  # [h, w]
    uv_grid = torch.cat((u_norm.view(1, H, W, 1), v_norm.view(1, H, W, 1)), dim=-1)
    Feat_warped = F.grid_sample(Feat.unsqueeze(0).permute(0, 3, 1, 2), uv_grid, mode=mode, padding_mode='border', align_corners=True).squeeze()
    return Feat_warped.permute(1, 2, 0)


def compute_vertex(depth, K):
    H, W = depth.shape
    fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
    device = depth.device

    i, j = torch.meshgrid(torch.linspace(0, W - 1, W), torch.linspace(0, H - 1, H))  # pytorch's meshgrid has indexing='ij'
    i = i.t().to(device)  # [h, w]
    j = j.t().to(device)  # [h, w]

    vertex = torch.stack([(i - cx) / fx, (j - cy) / fy, torch.ones_like(i)], -1).to(device) * depth[..., None]  # [h, w, 3]
    return vertex


def compute_normal(vertex_map):
    """ Calculate the normal map from a depth map
    :param the input depth image
    -----------
    :return the normal map
    """
    H, W, C = vertex_map.shape
    img_dx, img_dy = feature_gradient(vertex_map, normalize_gradient=False)  # [h, w, 3]

    normal = torch.cross(img_dx.view(-1, 3), img_dy.view(-1, 3))
    normal = normal.view(H, W, 3)  # [h, w, 3]

    mag = torch.norm(normal, p=2, dim=-1, keepdim=True)
    normal = normal / (mag + 1e-8)

    # filter out invalid pixels
    depth = vertex_map[:, :, -1]
    # 0.5 and 5.
    invalid_mask = (depth <= depth.min()) | (depth >= depth.max())
    zero_normal = torch.zeros_like(normal)
    normal = torch.where(invalid_mask[..., None], zero_normal, normal)

    return normal


def feature_gradient(img, normalize_gradient=True):
    """ Calculate the gradient on the feature space using Sobel operator
    :param the input image
    -----------
    :return the gradient of the image in x, y direction
    """
    H, W, C = img.shape
    # to filter the image equally in each channel
    wx = torch.FloatTensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).view(1, 1, 3, 3).type_as(img)
    wy = torch.FloatTensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).view(1, 1, 3, 3).type_as(img)

    img_permuted = img.permute(2, 0, 1).view(-1, 1, H, W)  # [c, 1, h, w]
    img_pad = F.pad(img_permuted, (1, 1, 1, 1), mode='replicate')
    img_dx = F.conv2d(img_pad, wx, stride=1, padding=0).squeeze().permute(1, 2, 0)  # [h, w, c]
    img_dy = F.conv2d(img_pad, wy, stride=1, padding=0).squeeze().permute(1, 2, 0)  # [h, w, c]

    if normalize_gradient:
        mag = torch.sqrt((img_dx ** 2) + (img_dy ** 2) + 1e-8)
        img_dx = img_dx / mag
        img_dy = img_dy / mag

    return img_dx, img_dy  # [h, w, c]


def batch_skew(w):
    """ Generate a batch of skew-symmetric matrices.

        function tested in 'test_geometry.py'

    :input
    :param skew symmetric matrix entry Bx3
    ---------
    :return
    :param the skew-symmetric matrix Bx3x3
    """
    B, D = w.shape
    assert(D == 3)
    o = torch.zeros(B).type_as(w)
    w0, w1, w2 = w[:, 0], w[:, 1], w[:, 2]
    return torch.stack((o, -w2, w1, w2, o, -w0, -w1, w0, o), 1).view(B, 3, 3)


def lev_mar_H(JtWJ, damping):
    # Add a small diagonal damping. Without it, the training becomes quite unstable
    # Do not see a clear difference by removing the damping in inference though
    diag_mask = torch.eye(6).to(JtWJ)
    diagJtJ = diag_mask * JtWJ
    traceJtJ = torch.sum(diagJtJ)
    epsilon = (traceJtJ * damping) * diag_mask
    Hessian = JtWJ + epsilon
    return Hessian


def forward_update_pose(H, Rhs, pose):
    """
    :param H:
    :param Rhs:
    :param pose:
    :return:
    """
    xi = least_square_solve(H, Rhs).squeeze()
    pose = exp_se3(xi) @ pose
    return pose


def exp_se3(xi):
    """
    :param x: Cartesian vector of Lie Algebra se(3)
    :return: exponential map of x
    """
    w = xi[:3].squeeze()  # rotation
    v = xi[3:6].squeeze()  # translation
    w_hat = torch.tensor([[0., -w[2], w[1]],
                          [w[2], 0., -w[0]],
                          [-w[1], w[0], 0.]]).to(xi)
    w_hat_second = torch.mm(w_hat, w_hat).to(xi)

    theta = torch.norm(w)
    theta_2 = theta ** 2
    theta_3 = theta ** 3
    sin_theta = torch.sin(theta)
    cos_theta = torch.cos(theta)
    eye_3 = torch.eye(3).to(xi)

    eps = 1e-8

    if theta <= eps:
        e_w = eye_3
        j = eye_3
    else:
        e_w = eye_3 + w_hat * sin_theta / theta + w_hat_second * (1. - cos_theta) / theta_2
        k1 = (1 - cos_theta) / theta_2
        k2 = (theta - sin_theta) / theta_3
        j = eye_3 + k1 * w_hat + k2 * w_hat_second

    T = torch.eye(4).to(xi)
    T[:3, :3] = e_w
    T[:3, 3] = torch.mv(j, v)
    # T[:3, 3] = v

    return T


def invH(H):
    """ Generate (H+damp)^{-1}, with predicted damping values
    :param approximate Hessian matrix JtWJ
    -----------
    :return the inverse of Hessian
    """
    # GPU is much slower for matrix inverse when the size is small (compare to CPU)
    # works (50x faster) than inversing the dense matrix in GPU
    if H.is_cuda:
        invH = torch.inverse(H.cpu()).cuda()
    else:
        invH = torch.inverse(H)
    return invH


def least_square_solve(H, Rhs):
    """
    Solve for JTJ @ xi = -JTR
    """
    inv_H = invH(H)  # [B, 6, 6] square matrix
    xi = -inv_H @ Rhs
    return xi
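
Putting the pieces of icp.py together, each iteration solves the damped normal equations and applies the update on SE(3):

$\boldsymbol{\xi} = -\big(J^\top J + \lambda\,\mathrm{tr}(J^\top J)\, I\big)^{-1} J^\top \mathbf{r}, \qquad T_{10} \leftarrow \exp\big(\hat{\boldsymbol{\xi}}\big)\, T_{10}$,

where $\mathbf{r}$ stacks the point-to-plane residuals, $J$ their Jacobians with respect to $[\boldsymbol{\omega}, \mathbf{v}]$, and $\lambda$ is the damping factor scaled by the trace of $J^\top J$ in lev_mar_H.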

Further reading:
KinectFusion原理介绍 (an introduction to the principles of KinectFusion).

Summary

  • Outlined the KinectFusion algorithm pipeline
  • Introduced the main dataset used for RGB-D reconstruction
  • Walked through a Python implementation of KinectFusion
