KinectFusion is the pioneering work of RGBD SLAM reconstruction and a classic algorithm in 3D reconstruction. This post is a write-up of my hands-on work with KinectFusion, partly to push myself to finish applying the code. My level is limited and this is purely for learning; writing things down helps me settle down and sort out my understanding.
Code references:
Python version: https://github.com/JingwenWang95/KinectFusion
C++ version: https://github.com/Nerei/kinfu_remake
Prerequisites:
RGB-D images, mesh reconstruction, quaternions, the camera model
OpenCV documentation on camera calibration and 3D reconstruction: https://docs.opencv.org/2.4/modules/calib3d/doc/camera_calibration_and_3d_reconstruction.html#decomposeprojectionmatrix
The RGBD SLAM reconstruction dataset used most often is the TUM RGB-D dataset: https://vision.in.tum.de/data/datasets/rgbd-dataset/download
The dataset contains pairs of RGB color images and depth images in PNG format. Depth values are scaled by a factor of 5000: a depth-image pixel value of 5000 means 1 m from the camera, 10000 means 2 m, and a value of 0 marks a missing measurement. rgb.txt and depth.txt record the capture timestamp and file name of every image (the file names are themselves timestamps). Because the RGB camera and the depth sensor are separate devices, an RGB image and a depth image are almost never captured at exactly the same instant, so a preprocessing step is needed to establish a one-to-one association between RGB and depth frames.
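As a quick illustration of the scale factor (not part of the repo; the file name below is just a placeholder), converting one TUM depth PNG to metric depth looks like this:

import imageio
import numpy as np

depth_raw = np.array(imageio.imread("depth/1305031102.160407.png"))  # hypothetical file name
depth_m = depth_raw.astype(np.float32) / 5000.0  # pixel value 5000 -> 1 m
depth_m[depth_raw == 0] = np.nan                 # 0 means no measurement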
Besides the depth and color images needed for reconstruction, the TUM dataset also provides the ground-truth camera trajectory, i.e. the position and orientation of the camera at the moment each frame was captured, in a fixed coordinate system. In groundtruth.txt, timestamp is the time in the Unix epoch, $t_x, t_y, t_z$ give the position of the camera's optical center with respect to the world origin defined by the motion-capture system, and $q_x, q_y, q_z, q_w$ give the orientation of the optical center relative to the world origin as a quaternion.
In practice, the focal lengths $f_x, f_y$ and the optical center $c_x, c_y$ of the camera are essential for associating the depth and RGB images. The camera intrinsic matrix is:
$$
K = \begin{bmatrix} f_x & 0 & c_x \\ 0 & f_y & c_y \\ 0 & 0 & 1 \end{bmatrix}
$$
The TUM dataset was captured with three devices, corresponding to three sets of camera intrinsics: fr1, fr2 and fr3.
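For example, using the fr1 values that get_calib() returns later in this post, the intrinsic matrix is assembled as follows (a minimal sketch):

import numpy as np

fx, fy, cx, cy = 517.306408, 516.469215, 318.643040, 255.313989  # fr1 intrinsics from get_calib()
K = np.array([[fx, 0., cx],
              [0., fy, cy],
              [0., 0., 1.]])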
For a beginner (like me), understanding each step of the program means building from the ground up. All kinds of confusion come up in practice, and CSDN and cnblogs posts are a good way to work through it. For the TUM dataset, I recommend 半闲居士's series 一起做RGB-D SLAM.
TSDF: TSDF stands for truncated signed distance function, a grid-style map. First choose the 3D volume to be modeled, say 3×3×3 m^3, then split it at a chosen resolution into many small voxels and store information inside each one. The TSDF volume lives entirely in GPU memory rather than main memory; since the computation for each voxel is independent of the others, the GPU's parallelism can be exploited to compute and update all voxels in parallel (see https://blog.csdn.net/qinqinxiansheng/article/details/119449196).
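Concretely, each voxel stores the signed distance to the nearest observed surface, truncated to a narrow band. A minimal sketch of how one depth measurement turns into a voxel's value (the names are mine, but this mirrors the clamp used in integrate() later on):

import torch

def truncated_sdf(depth_meas, voxel_cam_z, sdf_trunc):
    # depth_meas, voxel_cam_z: torch tensors (measured depth and voxel depth in the camera frame)
    # positive in front of the surface, negative behind, clipped to the truncation band
    sdf = depth_meas - voxel_cam_z
    return torch.clamp(sdf / sdf_trunc, max=1.0)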
Image pyramid:
The core of an RGBD SLAM algorithm is to reconstruct the 3D scene from multiple RGB-D frames captured by the camera.
A depth map is essentially 2.5D information: a color image (pixels indexed by u = (x, y)) together with a depth image (the depth D at each pixel). Capturing a depth map is a transformation from world coordinates to camera coordinates. From the measured depth, the algorithm back-projects the RGB-D frame into a point cloud and then computes a vertex map and a normal map from it (see the sketch after this list).
Depth Map
|— Vertex Map
|— Normal Map
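The back-projection is just the inverse pinhole model: a pixel $(u, v)$ with depth $D$ maps to the camera-frame point $D \cdot K^{-1}(u, v, 1)^\top$, and normals come from cross products of neighboring vertex differences. A rough numpy sketch (the repo's compute_vertex / compute_normal in icp.py do the same thing in torch, with Sobel gradients instead of plain differences):

import numpy as np

def depth_to_vertex_map(depth, K):
    # back-project a depth map (H, W) into a camera-frame vertex map (H, W, 3)
    H, W = depth.shape
    fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    x = (u - cx) / fx * depth
    y = (v - cy) / fy * depth
    return np.stack([x, y, depth], axis=-1)

def vertex_to_normal_map(vertex):
    # normals from the cross product of finite differences of the vertex map
    du = np.diff(vertex, axis=1, append=vertex[:, -1:])
    dv = np.diff(vertex, axis=0, append=vertex[-1:, :])
    n = np.cross(du, dv)
    return n / (np.linalg.norm(n, axis=-1, keepdims=True) + 1e-8)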
Camera pose estimation
Camera pose estimation is a crucial step of the reconstruction. Here the ICP algorithm is used to register different frames: the camera pose $T_{g,k}$ is estimated by finding point-to-point correspondences between depth frames.
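Concretely, the tracker minimizes the rigid transform over the point-to-plane ICP error, which is the residual assembled later in icp.py:
$$
E(T_{g,k}) = \sum_i \Big( \mathbf{n}_i^\top \big( T_{g,k}\,\mathbf{p}_i - \mathbf{q}_i \big) \Big)^2
$$
where $\mathbf{p}_i$ is a vertex of the current frame, $\mathbf{q}_i$ its associated vertex in the reference frame, and $\mathbf{n}_i$ the corresponding normal.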
Volumetric Integration
Volumetric integration uses the TSDF (truncated signed distance function) representation. A large 3D volume is allocated in advance to hold the reconstructed scene; its content is the accumulation of the spatial slices observed in different frames, and the voxels are updated incrementally as new depth maps are read in.
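Each new frame is fused with a running weighted average per voxel, which is exactly the update performed by integrate() in fusion.py:
$$
F_k(p) = \frac{W_{k-1}(p)\,F_{k-1}(p) + w_k\,f_k(p)}{W_{k-1}(p) + w_k},
\qquad
W_k(p) = W_{k-1}(p) + w_k
$$
where $f_k(p)$ is the truncated signed distance measured for voxel $p$ in frame $k$ and $w_k$ is the observation weight (obs_weight in the code).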
Raycasting
Used to render the surface of the reconstructed model; the rendered depth map is also what the tracker aligns the next frame against.
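The idea is to march along each camera ray, sample the TSDF, and take the first positive-to-negative sign change as the surface. A minimal per-ray sketch (render_model in fusion.py does this for all rays at once on the GPU, with a slightly different refinement step):

import torch

def first_zero_crossing(tsdf_samples, z_vals):
    # depth of the first +/- sign change along one ray, or None if the ray misses the surface
    sign_change = (tsdf_samples[:-1] > 0) & (tsdf_samples[1:] < 0)
    idx = torch.nonzero(sign_change)
    if idx.numel() == 0:
        return None
    i = idx[0, 0]
    # linearly interpolate between the two samples that bracket the surface
    t = tsdf_samples[i] / (tsdf_samples[i] - tsdf_samples[i + 1])
    return z_vals[i] + t * (z_vals[i + 1] - z_vals[i])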
KinectFusion needs GPU acceleration to run; Jingwen's version of KinectFusion uses the PyTorch framework instead of writing CUDA directly.
Before running the reconstruction algorithm itself, the author does some preprocessing of the dataset. preprocess.py is executed before the reconstruction program; its purpose is to associate the RGB and depth file names and to save them in a format that is easier to process.
preprocess.py
import os
import math
import shutil
import numpy as np
import argparse
from tum_rgbd import get_calib
from utils import load_config
def read_file_list(filename):
"""
    Reads a trajectory from a text file, following the TUM camera trajectory data format.
File format:
The file format is "stamp d1 d2 d3 ...", where stamp denotes the time stamp (to be matched)
and "d1 d2 d3.." is arbitary data (e.g., a 3D position and 3D orientation) associated to this timestamp.
Input:
filename -- File name
Output:
dict -- dictionary of (stamp,data) tuples
"""
file = open(filename)
data = file.read()
lines = data.replace(","," ").replace("\t"," ").split("\n")
list = [[v.strip() for v in line.split(" ") if v.strip()!=""] for line in lines if len(line)>0 and line[0]!="#"]
list = [(float(l[0]),l[1:]) for l in list if len(l)>1]
return dict(list)
# The depth camera and the color camera capture frames at slightly different times, so they need to be aligned; the maximum allowed time difference is 0.02 s.
# RGB and depth images taken at (nearly) the same time are written on the same line of the association file, which marks them as one RGB-D pair.
def associate(first_list, second_list, offset=0.0, max_difference=0.02):
"""
Associate two dictionaries of (stamp,data). As the time stamps never match exactly, we aim
to find the closest match for every input tuple.
Input:
first_list -- first dictionary of (stamp,data) tuples
second_list -- second dictionary of (stamp,data) tuples
    offset -- time offset between both dictionaries (e.g., to model the delay between the sensors)
max_difference -- search radius for candidate generation
Output:
matches -- list of matched tuples ((stamp1,data1),(stamp2,data2))
"""
first_keys = list(first_list)
second_keys = list(second_list)
potential_matches = [(abs(a - (b + offset)), a, b)
for a in first_keys
for b in second_keys
if abs(a - (b + offset)) < max_difference]
potential_matches.sort()
matches = []
for diff, a, b in potential_matches:
if a in first_keys and b in second_keys:
first_keys.remove(a)
second_keys.remove(b)
matches.append((a, b))
matches.sort()
return matches
# compute the matched file pairs from the two input files and save them
def get_association(file_a, file_b, out_file):
first_list = read_file_list(file_a)
second_list = read_file_list(file_b)
matches = associate(first_list, second_list)
with open(out_file, "w") as f:
for a, b in matches:
line = "%f %s %f %s\n" % (a, " ".join(first_list[a]), b, " ".join(second_list[b]))
f.write(line)
# camera pose information: convert a translation + quaternion into a 4x4 homogeneous transformation matrix
def tum2matrix(pose):
"""
    Return a homogeneous transformation matrix from a TUM pose (translation + quaternion).
"""
    # the first 3 entries are the translation
t = pose[:3]
# under TUM format q is in the order of [x, y, z, w], need change to [w, x, y, z]
quaternion = [pose[6], pose[3], pose[4], pose[5]]
q = np.array(quaternion, dtype=np.float64, copy=True)
n = np.dot(q, q)
if n < np.finfo(np.float64).eps:
return np.identity(4)
q *= math.sqrt(2.0 / n)
q = np.outer(q, q)
return np.array([
[1.0-q[2, 2]-q[3, 3], q[1, 2]-q[3, 0], q[1, 3]+q[2, 0], t[0]],
[ q[1, 2]+q[3, 0], 1.0-q[1, 1]-q[3, 3], q[2, 3]-q[1, 0], t[1]],
[ q[1, 3]-q[2, 0], q[2, 3]+q[1, 0], 1.0-q[1, 1]-q[2, 2], t[2]],
[0., 0., 0., 1.]])
# read camera poses from an association file
def get_poses_from_associations(fname):
poses = []
with open(fname) as f:
for line in f.readlines():
pose_str = line.strip("\n").split(" ")[-7:]
pose = [float(p) for p in pose_str]
poses += [tum2matrix(pose)]
return poses
if __name__ == "__main__":
parser = argparse.ArgumentParser()
    # input: the dataset to be processed; outputs are saved into the processed/ folder
# standard configs
parser.add_argument('--config', type=str, default="../configs/fr1_desk.yaml", help='Path to config file.')
args = load_config(parser.parse_args())
out_dir = os.path.join(args.data_root, "processed")
# create association files
get_association(os.path.join(args.data_root, "depth.txt"), os.path.join(args.data_root, "groundtruth.txt"), os.path.join(args.data_root, "dep_traj.txt"))
get_association(os.path.join(args.data_root, "rgb.txt"), os.path.join(args.data_root, "dep_traj.txt"), os.path.join(args.data_root, "rgb_dep_traj.txt"))
if not os.path.exists(out_dir):
os.makedirs(out_dir)
out_rgb_dir = os.path.join(out_dir, "rgb")
if not os.path.exists(out_rgb_dir):
os.makedirs(out_rgb_dir)
out_dep_dir = os.path.join(out_dir, "depth")
if not os.path.exists(out_dep_dir):
os.makedirs(out_dep_dir)
# rename image files and save c2w poses
    # copy the image files under new, sequentially numbered names
poses = []
with open(os.path.join(args.data_root, "rgb_dep_traj.txt")) as f:
for i, line in enumerate(f.readlines()):
line_list = line.strip().split(" ")
rgb_file = line_list[1]
shutil.copyfile(os.path.join(args.data_root, rgb_file), os.path.join(out_rgb_dir, "%04d.png" % i))
dep_file = line_list[3]
shutil.copyfile(os.path.join(args.data_root, dep_file), os.path.join(out_dep_dir, "%04d.png" % i))
poses += [tum2matrix([float(x) for x in line_list[5:]])]
np.savez(os.path.join(out_dir, "raw_poses.npz"), c2w_mats=poses)
    # save projection matrices (P = K @ w2c)
K = np.eye(3)
    intri = get_calib()[args.data_type]  # camera intrinsics for this sequence type (fr1/fr2/fr3)
K[0, 0] = intri[0]
K[1, 1] = intri[1]
K[0, 2] = intri[2]
K[1, 2] = intri[3]
camera_dict = np.load(os.path.join(out_dir, "raw_poses.npz"))
poses = camera_dict["c2w_mats"]
P_mats = []
for c2w in poses:
w2c = np.linalg.inv(c2w)
P = K @ w2c[:3, :]
P_mats += [P]
np.savez(os.path.join(out_dir, "cameras.npz"), world_mats=P_mats)
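After the script has run, the processed/ folder can be sanity-checked with a couple of lines (a quick sketch; replace the path with whatever data_root your config points to):

import numpy as np

out_dir = "path/to/tum_sequence/processed"   # <data_root>/processed
cams = np.load(out_dir + "/cameras.npz")
poses = np.load(out_dir + "/raw_poses.npz")
print(cams["world_mats"].shape)   # (N, 3, 4): per-frame P = K @ w2c
print(poses["c2w_mats"].shape)    # (N, 4, 4): per-frame camera-to-world pose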
Let's start from the main function, kinfu.py.
This is the entry point of the whole pipeline: it takes an RGB-D dataset as input, reconstructs a mesh, and saves it as a .ply file.
import os
import argparse  # parses command-line arguments and options
import numpy as np
import torch
import cv2
import trimesh  # used to build and export the mesh
from matplotlib import pyplot as plt
from fusion import TSDFVolumeTorch  # external module used for volumetric integration
from dataset.tum_rgbd import TUMDataset, TUMDatasetOnline  # loaders for the RGB-D dataset
from tracker import ICPTracker  # tracks the camera pose (estimated from the depth images via ICP)
from utils import load_config, get_volume_setting, get_time  # helpers for configs, volume settings and timing
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# standard configs
    # set when launching the script: the dataset config and the directory where results are saved
parser.add_argument('--config', type=str, default="configs/fr1_desk.yaml", help='Path to config file.')
parser.add_argument("--save_dir", type=str, default=None, help="Directory of saving results.")
args = load_config(parser.parse_args())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
    # load the dataset using the helpers in tum_rgbd
    # each sample provides the camera-to-world matrix, the camera intrinsics, the RGB image and the depth image
dataset = TUMDataset(os.path.join(args.data_root), device, near=args.near, far=args.far, img_scale=0.25)
    # image height and width
H, W = dataset.H, dataset.W
    # the reconstruction volume is defined in the config file
vol_dims, vol_origin, voxel_size = get_volume_setting(args)
    # TSDF volume that fuses the per-frame depth measurements into one volumetric model (from surfaces to a volume), running on the GPU
tsdf_volume = TSDFVolumeTorch(vol_dims, vol_origin, voxel_size, device, margin=3, fuse_color=True)
    # tracks the camera pose by aligning depth frames
icp_tracker = ICPTracker(args, device)
    # per-frame time, estimated poses, ground-truth poses
t, poses, poses_gt = list(), list(), list()
curr_pose, depth1, color1 = None, None, None
for i in range(0, len(dataset), 1):
t0 = get_time()
sample = dataset[i]
color0, depth0, pose_gt, K = sample # use live image as template image (0)
# depth0[depth0 <= 0.5] = 0.
        # initialize the pose
if i == 0: # initialize
curr_pose = pose_gt
else: # tracking
# 1. render depth image (1) from tsdf volume
depth1, color1, vertex01, normal1, mask1 = tsdf_volume.render_model(curr_pose, K, H, W, near=args.near, far=args.far, n_samples=args.n_steps)
            # depth0 is the live depth frame; depth1 is the depth rendered from the TSDF model at the current pose estimate
T10 = icp_tracker(depth0, depth1, K) # transform from 0 to 1
curr_pose = curr_pose @ T10
# fusion
        # fuse the current depth frame into the TSDF volume
tsdf_volume.integrate(depth0,
K,
curr_pose,
obs_weight=1.,
color_img=color0
)
t1 = get_time()
t += [t1 - t0]
print("processed frame: {:d}, time taken: {:f}s".format(i, t1 - t0))
poses += [curr_pose.cpu().numpy()]
poses_gt += [pose_gt.cpu().numpy()]
avg_time = np.array(t).mean()
print("average processing time: {:f}s per frame, i.e. {:f} fps".format(avg_time, 1. / avg_time))
# compute tracking ATE
    # the absolute trajectory error (ATE) is used to evaluate tracking quality
poses_gt = np.stack(poses_gt, 0)
poses = np.stack(poses, 0)
traj_gt = np.array(poses_gt)[:, :3, 3]
traj = np.array(poses)[:, :3, 3]
rmse = np.sqrt(np.mean(np.linalg.norm(traj_gt - traj, axis=-1) ** 2))
print("RMSE: {:f}".format(rmse))
# plt.plot(traj[:, 0], traj[:, 1])
# plt.plot(traj_gt[:, 0], traj_gt[:, 1])
# plt.legend(['Estimated', 'GT'])
# plt.show()
# save results
if args.save_dir is not None:
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
        # extract and export the reconstructed mesh
verts, faces, norms, colors = tsdf_volume.get_mesh()
partial_tsdf = trimesh.Trimesh(vertices=verts, faces=faces, vertex_normals=norms, vertex_colors=colors)
partial_tsdf.export(os.path.join(args.save_dir, "mesh.ply"))
np.savez(os.path.join(args.save_dir, "traj.npz"), poses=poses)
np.savez(os.path.join(args.save_dir, "traj_gt.npz"), poses=poses_gt)
tum_rgbd.py
This module's main job is to provide four arrays per frame while the program runs: the RGB image, the depth image, the camera-to-world transformation matrix and the camera intrinsics, i.e. rgb, depth, c2w, K.
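The relation behind load_K_Rt_from_P below: cameras.npz stores the projection matrix $P = K\,[R \mid t]$ (world-to-camera), and the loader recovers the intrinsics and the camera-to-world pose from it,
$$
P = K\,[R \mid t], \qquad
T_{wc} = \begin{bmatrix} R^\top & C \\ \mathbf{0}^\top & 1 \end{bmatrix}, \qquad C = -R^\top t,
$$
where $C$ is the camera center that cv2.decomposeProjectionMatrix returns in homogeneous coordinates.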
import torch
from os import path
from tqdm import tqdm
import imageio
import cv2
import numpy as np
import open3d as o3d
def get_calib():  # The TUM dataset was recorded with three cameras. This returns the intrinsics fx, fy, cx, cy used when the images were captured; the focal length and optical center are assumed fixed for each camera.
return {
"fr1": [517.306408, 516.469215, 318.643040, 255.313989],
"fr2": [520.908620, 521.007327, 325.141442, 249.701764],
"fr3": [535.4, 539.2, 320.1, 247.6]
}
# Note: this step converts w2c (Tcw) into c2w (Twc),
# i.e. the world-to-camera matrix becomes a camera-to-world matrix,
# and also recovers the camera intrinsics and the rotation/translation
def load_K_Rt_from_P(P):
"""
modified from IDR https://github.com/lioryariv/idr
"""
    # see the OpenCV documentation on decomposeProjectionMatrix
out = cv2.decomposeProjectionMatrix(P)
K = out[0]
R = out[1]
t = out[2]
    # normalize the intrinsic matrix
K = K/K[2,2]
intrinsics = np.eye(4)
intrinsics[:3, :3] = K
pose = np.eye(4, dtype=np.float32)
pose[:3, :3] = R.transpose() # convert from w2c to c2w
pose[:3, 3] = (t[:3] / t[3])[:, 0]
return intrinsics, pose
# dataset class for TUM data
class TUMDataset(torch.utils.data.Dataset):
"""
TUM dataset loader, pre-load images in advance
"""
    # the loading parameters are defined here
def __init__(
self,
rootdir,
device,
        near: float = 0.2,  # nearest valid depth
        far: float = 5.,  # farthest valid depth
img_scale: float = 1., # image scale factor
start: int = -1,
end: int = -1,
):
super().__init__()
assert path.isdir(rootdir), f"'{rootdir}' is not a directory"
self.device = device
        self.c2w_all = []  # camera-to-world transformation matrices
        self.K_all = []  # camera intrinsics
        self.rgb_all = []  # RGB images
        self.depth_all = []  # depth images
# root should be tum_sequence
        data_path = path.join(rootdir, "processed")  # the re-organized RGB-D data
        cam_file = path.join(data_path, "cameras.npz")  # camera matrices file
print("LOAD DATA", data_path)
# world_mats, normalize_mat
        cam_dict = np.load(cam_file)  # camera matrices
        world_mats = cam_dict["world_mats"]  # K @ w2c
        d_min = []  # per-frame minimum depth
        d_max = []  # per-frame maximum depth
# TUM saves camera poses in OpenCV convention
        # tqdm shows progress information
for i, world_mat in enumerate(tqdm(world_mats)):
            # ignore all the frames before
if start > 0 and i < start:
continue
# ignore all the frames after
if 0 < end < i:
break
intrinsics, c2w = load_K_Rt_from_P(world_mat)
c2w = torch.tensor(c2w, dtype=torch.float32)
# read images
            # load the depth and color images and keep them as arrays
rgb = np.array(imageio.imread(path.join(data_path, "rgb/{:04d}.png".format(i)))).astype(np.float32)
depth = np.array(imageio.imread(path.join(data_path, "depth/{:04d}.png".format(i)))).astype(np.float32)
depth /= 5000. # TODO: put depth factor to args
d_max += [depth.max()]
d_min += [depth.min()]
# depth = cv2.bilateralFilter(depth, 5, 0.2, 15)
# print(depth[depth > 0.].min())
            # anything outside the depth range is marked as invalid
invalid = (depth < near) | (depth > far)
depth[invalid] = -1.
# downscale the image size if needed
if img_scale < 1.0:
full_size = list(rgb.shape[:2])
rsz_h, rsz_w = [round(hw * img_scale) for hw in full_size]
# TODO: figure out which way is better: skimage.rescale or cv2.resize
rgb = cv2.resize(rgb, (rsz_w, rsz_h), interpolation=cv2.INTER_AREA)
depth = cv2.resize(depth, (rsz_w, rsz_h), interpolation=cv2.INTER_NEAREST)
intrinsics[0, 0] *= img_scale
intrinsics[1, 1] *= img_scale
intrinsics[0, 2] *= img_scale
intrinsics[1, 2] *= img_scale
self.c2w_all.append(c2w)
self.K_all.append(torch.from_numpy(intrinsics[:3, :3]))
self.rgb_all.append(torch.from_numpy(rgb))
self.depth_all.append(torch.from_numpy(depth))
print("Depth min: {:f}".format(np.array(d_min).min()))
print("Depth max: {:f}".format(np.array(d_max).max()))
self.n_images = len(self.rgb_all)
self.H, self.W, _ = self.rgb_all[0].shape
def __len__(self):
return self.n_images
def __getitem__(self, idx):
return self.rgb_all[idx].to(self.device), self.depth_all[idx].to(self.device), \
self.c2w_all[idx].to(self.device), self.K_all[idx].to(self.device)
class TUMDatasetOnline(torch.utils.data.Dataset):
"""
Online TUM dataset loader, load images when __getitem__() is called
"""
def __init__(
self,
rootdir,
device,
near: float = 0.2,
far: float = 5.,
img_scale: float = 1., # image scale factor
start: int = -1,
end: int = -1,
):
super().__init__()
assert path.isdir(rootdir), f"'{rootdir}' is not a directory"
self.device = device
self.img_scale = img_scale
self.near = near
self.far = far
self.c2w_all = []
self.K_all = []
self.rgb_files_all = []
self.depth_files_all = []
# root should be tum_sequence
data_path = path.join(rootdir, "processed")
cam_file = path.join(data_path, "cameras.npz")
print("LOAD DATA", data_path)
# world_mats, normalize_mat
cam_dict = np.load(cam_file)
world_mats = cam_dict["world_mats"] # K @ w2c
# TUM saves camera poses in OpenCV convention
for i, world_mat in enumerate(world_mats):
            # ignore all the frames before
if start > 0 and i < start:
continue
# ignore all the frames after
if 0 < end < i:
break
intrinsics, c2w = load_K_Rt_from_P(world_mat)
c2w = torch.tensor(c2w, dtype=torch.float32)
self.c2w_all.append(c2w)
self.K_all.append(torch.from_numpy(intrinsics[:3, :3]))
self.rgb_files_all.append(path.join(data_path, "rgb/{:04d}.png".format(i)))
self.depth_files_all.append(path.join(data_path, "depth/{:04d}.png".format(i)))
self.n_images = len(self.rgb_files_all)
H, W, _ = np.array(imageio.imread(self.rgb_files_all[0])).shape
self.H = round(H * img_scale)
self.W = round(W * img_scale)
def __len__(self):
return self.n_images
def __getitem__(self, idx):
K = self.K_all[idx].to(self.device)
c2w = self.c2w_all[idx].to(self.device)
# read images
rgb = np.array(imageio.imread(self.rgb_files_all[idx])).astype(np.float32)
depth = np.array(imageio.imread(self.depth_files_all[idx])).astype(np.float32)
depth /= 5000.
# depth = cv2.bilateralFilter(depth, 5, 0.2, 15)
depth[depth < self.near] = 0.
depth[depth > self.far] = -1.
# downscale the image size if needed
if self.img_scale < 1.0:
full_size = list(rgb.shape[:2])
rsz_h, rsz_w = [round(hw * self.img_scale) for hw in full_size]
rgb = cv2.resize(rgb, (rsz_w, rsz_h), interpolation=cv2.INTER_AREA)
depth = cv2.resize(depth, (rsz_w, rsz_h), interpolation=cv2.INTER_NEAREST)
K[0, 0] *= self.img_scale
K[1, 1] *= self.img_scale
K[0, 2] *= self.img_scale
K[1, 2] *= self.img_scale
rgb = torch.from_numpy(rgb).to(self.device)
depth = torch.from_numpy(depth).to(self.device)
return rgb, depth, c2w, K
fusion.py
This file implements the TSDF volume (integration of depth frames and ray casting of the model); the image pyramid also shows up here, in render_pyramid.
import os
import numpy as np
from skimage import measure
import torch
import cv2
import open3d as o3d
import imageio
def integrate(
depth_im,
cam_intr,
cam_pose,
    obs_weight,  # observation weight; the KinectFusion algorithm has its own weighting scheme
world_c, # world coordinates grid [nx*ny*nz, 4]
vox_coords, # voxel coordinates grid [nx*ny*nz, 3]
weight_vol, # weight volume [nx, ny, nz]
tsdf_vol, # tsdf volume [nx, ny, nz]
sdf_trunc,
im_h,
im_w,
color_vol=None,
color_im=None,
):
    world2cam = torch.inverse(cam_pose)  # cam_pose is the camera-to-world transform (T_wc); its inverse gives the world-to-camera transform
    cam_c = torch.matmul(world2cam, world_c.transpose(1, 0)).transpose(1, 0).float()  # [nx*ny*nz, 4], transform every voxel's world coordinate into the camera frame
    # Convert camera coordinates to pixel coordinates
fx, fy = cam_intr[0, 0], cam_intr[1, 1]
cx, cy = cam_intr[0, 2], cam_intr[1, 2]
pix_z = cam_c[:, 2]
    # project all the voxels back to the image plane
pix_x = torch.round((cam_c[:, 0] * fx / cam_c[:, 2]) + cx).long() # [nx*ny*nz]
pix_y = torch.round((cam_c[:, 1] * fy / cam_c[:, 2]) + cy).long() # [nx*ny*nz]
    # Eliminate pixels outside the view frustum
valid_pix = (pix_x >= 0) & (pix_x < im_w) & (pix_y >= 0) & (pix_y < im_h) & (pix_z > 0) # [n_valid]
valid_vox_x = vox_coords[valid_pix, 0]
valid_vox_y = vox_coords[valid_pix, 1]
valid_vox_z = vox_coords[valid_pix, 2]
depth_val = depth_im[pix_y[valid_pix], pix_x[valid_pix]] # [n_valid]
# Integrate tsdf
depth_diff = depth_val - pix_z[valid_pix]
dist = torch.clamp(depth_diff / sdf_trunc, max=1)
valid_pts = (depth_val > 0.) & (depth_diff >= -sdf_trunc) # all points 1. inside frustum 2. with valid depth 3. outside -truncate_dist
valid_vox_x = valid_vox_x[valid_pts]
valid_vox_y = valid_vox_y[valid_pts]
valid_vox_z = valid_vox_z[valid_pts]
valid_dist = dist[valid_pts]
w_old = weight_vol[valid_vox_x, valid_vox_y, valid_vox_z]
tsdf_vals = tsdf_vol[valid_vox_x, valid_vox_y, valid_vox_z]
w_new = w_old + obs_weight
tsdf_vol[valid_vox_x, valid_vox_y, valid_vox_z] = (w_old * tsdf_vals + obs_weight * valid_dist) / w_new
weight_vol[valid_vox_x, valid_vox_y, valid_vox_z] = w_new
if color_vol is not None and color_im is not None:
old_color = color_vol[valid_vox_x, valid_vox_y, valid_vox_z]
new_color = color_im[pix_y[valid_pix], pix_x[valid_pix]]
new_color = new_color[valid_pts]
color_vol[valid_vox_x, valid_vox_y, valid_vox_z, :] = (w_old[:, None] * old_color + obs_weight * new_color) / w_new[:, None]
return weight_vol, tsdf_vol, color_vol
class TSDFVolumeTorch:
"""
Volumetric TSDF Fusion of RGB-D Images.
"""
def __init__(self, voxel_dim, origin, voxel_size, device, margin=3, fuse_color=False):
"""
Args:
voxel_dim (ndarray): [3,] stores volume dimensions: Nx, Ny, Nz
origin (ndarray): [3,] world coordinate of voxel [0, 0, 0]
voxel_size (float): The volume discretization in meters.
"""
self.device = device
# Define voxel volume parameters
self.voxel_size = float(voxel_size)
self.sdf_trunc = margin * self.voxel_size
self.integrate_func = integrate
self.fuse_color = fuse_color
# Adjust volume bounds
if isinstance(voxel_dim, list):
voxel_dim = torch.Tensor(voxel_dim).to(self.device)
elif isinstance(voxel_dim, np.ndarray):
voxel_dim = torch.from_numpy(voxel_dim).to(self.device)
if isinstance(origin, list):
origin = torch.Tensor(origin).to(self.device)
elif isinstance(origin, np.ndarray):
origin = torch.from_numpy(origin).to(self.device)
self.vol_dim = voxel_dim.long()
self.vol_origin = origin
self.num_voxels = torch.prod(self.vol_dim).item()
# Get voxel grid coordinates
xv, yv, zv = torch.meshgrid(
torch.arange(0, self.vol_dim[0]),
torch.arange(0, self.vol_dim[1]),
torch.arange(0, self.vol_dim[2]),
)
self.vox_coords = torch.stack([xv.flatten(), yv.flatten(), zv.flatten()], dim=1).long().to(self.device)
# Convert voxel coordinates to world coordinates
self.world_c = self.vol_origin + (self.voxel_size * self.vox_coords)
self.world_c = torch.cat([
self.world_c, torch.ones(len(self.world_c), 1, device=self.device)], dim=1).float()
self.reset()
def reset(self):
"""Set volumes
"""
self.tsdf_vol = torch.ones(*self.vol_dim).to(self.device)
self.weight_vol = torch.zeros(*self.vol_dim).to(self.device)
if self.fuse_color:
# [nx, ny, nz, 3]
self.color_vol = torch.zeros(*self.vol_dim, 3).to(self.device)
else:
self.color_vol = None
def data_transfer(self, data):
if isinstance(data, np.ndarray):
data = torch.from_numpy(data)
return data.float().to(self.device)
@torch.no_grad()
def integrate(self, depth_im, cam_intr, cam_pose, obs_weight, color_img=None):
"""Integrate an RGB-D frame into the TSDF volume.
Args:
depth_im (torch.Tensor): A depth image of shape (H, W).
cam_intr (torch.Tensor): The camera intrinsics matrix of shape (3, 3).
cam_pose (torch.Tensor): The camera pose (i.e. extrinsics) of shape (4, 4). T_wc
obs_weight (float): The weight to assign to the current observation.
"""
cam_pose = self.data_transfer(cam_pose)
cam_intr = self.data_transfer(cam_intr)
depth_im = self.data_transfer(depth_im)
if color_img is not None:
color_img = self.data_transfer(color_img)
else:
color_img = None
im_h, im_w = depth_im.shape
# fuse
weight_vol, tsdf_vol, color_vol = self.integrate_func(
depth_im,
cam_intr,
cam_pose,
obs_weight,
self.world_c,
self.vox_coords,
self.weight_vol,
self.tsdf_vol,
self.sdf_trunc,
im_h, im_w,
self.color_vol,
color_img,
)
self.weight_vol = weight_vol
self.tsdf_vol = tsdf_vol
self.color_vol = color_vol
def get_volume(self):
return self.tsdf_vol, self.weight_vol, self.color_vol
def get_mesh(self):
"""Compute a mesh from the voxel volume using marching cubes.
"""
tsdf_vol, weight_vol, color_vol = self.get_volume()
verts, faces, norms, vals = measure.marching_cubes(tsdf_vol.cpu().numpy(), level=0)
verts_ind = np.round(verts).astype(int)
verts = verts * self.voxel_size + self.vol_origin.cpu().numpy() # voxel grid coordinates to world coordinates
if self.fuse_color:
rgb_vals = color_vol[verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]].cpu().numpy()
return verts, faces, norms, rgb_vals.astype(np.uint8)
else:
return verts, faces, norms
def to_o3d_mesh(self):
"""Convert to o3d mesh object for visualization
"""
verts, faces, norms, colors = self.get_mesh()
mesh = o3d.geometry.TriangleMesh()
mesh.vertices = o3d.utility.Vector3dVector(verts.astype(float))
mesh.triangles = o3d.utility.Vector3iVector(faces.astype(np.int32))
mesh.vertex_colors = o3d.utility.Vector3dVector(colors / 255.)
return mesh
def get_normals(self):
"""Compute normal volume
"""
nx, ny, nz = self.vol_dim
device = self.device
# dx = torch.cat([torch.zeros(1, ny, nz).to(device), (self.tsdf_vol[2:, :, :] - self.tsdf_vol[:-2, :, :]) / (2 * self.voxel_size), torch.zeros(1, ny, nz).to(device)], dim=0)
# dy = torch.cat([torch.zeros(nx, 1, nz).to(device), (self.tsdf_vol[:, 2:, :] - self.tsdf_vol[:, :-2, :]) / (2 * self.voxel_size), torch.zeros(nx, 1, nz).to(device)], dim=1)
# dz = torch.cat([torch.zeros(nx, ny, 1).to(device), (self.tsdf_vol[:, :, 2:] - self.tsdf_vol[:, :, :-2]) / (2 * self.voxel_size), torch.zeros(nx, ny, 1).to(device)], dim=2)
# norms = torch.stack([dx, dy, dz], -1)
dx = torch.cat([(self.tsdf_vol[1:, :, :] - self.tsdf_vol[:-1, :, :]) / self.voxel_size, torch.zeros(1, ny, nz).to(device)], dim=0)
dy = torch.cat([(self.tsdf_vol[:, 1:, :] - self.tsdf_vol[:, :-1, :]) / self.voxel_size, torch.zeros(nx, 1, nz).to(device)], dim=1)
dz = torch.cat([(self.tsdf_vol[:, :, 1:] - self.tsdf_vol[:, :, :-1]) / self.voxel_size, torch.zeros(nx, ny, 1).to(device)], dim=2)
norms = torch.stack([dx, dy, dz], -1)
n = torch.norm(norms, dim=-1)
# remove large values
outliers_mask = n > 1. / (2 * self.voxel_size)
norms[outliers_mask] = 0.
# normalize
eps = 1e-7
non_zero_grad = n > eps
norms[non_zero_grad, :] = norms[non_zero_grad, :] / n[non_zero_grad][:, None]
return norms # [nx, ny, nz, 3]
def get_nn(self, field_vol, coords_w):
"""Get nearest-neigbor values from a given volume
"""
field_dim = field_vol.shape
assert len(field_dim) == 3 or len(field_dim) == 4
vox_coord_float = (coords_w - self.vol_origin[None, :]) / self.voxel_size
vox_coord = torch.floor(vox_coord_float)
vox_offset = vox_coord_float - vox_coord # [N, 3]
vox_coord[vox_offset >= 0.5] += 1.
vox_coord[:, 0] = torch.clamp(vox_coord[:, 0], 0., self.vol_dim[0] - 1)
vox_coord[:, 1] = torch.clamp(vox_coord[:, 1], 0., self.vol_dim[1] - 1)
vox_coord[:, 2] = torch.clamp(vox_coord[:, 2], 0., self.vol_dim[2] - 1)
vox_coord = vox_coord.long()
vx, vy, vz = vox_coord[:, 0], vox_coord[:, 1], vox_coord[:, 2]
v_nn = field_vol[vx, vy, vz]
return v_nn
def tril_interp(self, field_vol, coords_w):
"""Get tri-linear interpolated value from a given volume
"""
field_dim = field_vol.shape
assert len(field_dim) == 3 or len(field_dim) == 4
n_pts = coords_w.shape[0]
vox_coord = torch.floor((coords_w - self.vol_origin[None, :]) / self.voxel_size).long() # [N, 3]
# for border points, don't do interpolation
non_border_mask = (vox_coord[:, 0] < self.vol_dim[0] - 1) & (vox_coord[:, 1] < self.vol_dim[1] - 1) & \
(vox_coord[:, 2] < self.vol_dim[2] - 1)
v_interp = torch.zeros(n_pts) if len(field_dim) == 3 else torch.zeros(n_pts, field_vol.shape[-1])
v_interp = v_interp.to(self.device)
vx_, vy_, vz_ = vox_coord[~non_border_mask, 0], vox_coord[~non_border_mask, 1], vox_coord[~non_border_mask, 2]
v_interp[~non_border_mask] = field_vol[vx_, vy_, vz_]
# get interpolated values for normal points
vx, vy, vz = vox_coord[non_border_mask, 0], vox_coord[non_border_mask, 1], vox_coord[non_border_mask, 2] # [N]
vox_idx = vz + vy * self.vol_dim[-1] + vx * self.vol_dim[-1] * self.vol_dim[-2]
vertices_coord = self.world_c[vox_idx][:, :3] # [N, 3]
r = (coords_w[non_border_mask] - vertices_coord) / self.voxel_size
rx, ry, rz = r[:, 0], r[:, 1], r[:, 2]
if len(field_dim) == 4:
rx = rx.unsqueeze(1)
ry = ry.unsqueeze(1)
rz = rz.unsqueeze(1)
# get values at eight corners
v000 = field_vol[vx, vy, vz]
v001 = field_vol[vx, vy, vz+1]
v010 = field_vol[vx, vy+1, vz]
v011 = field_vol[vx, vy+1, vz+1]
v100 = field_vol[vx+1, vy, vz]
v101 = field_vol[vx+1, vy, vz+1]
v110 = field_vol[vx+1, vy+1, vz]
v111 = field_vol[vx+1, vy+1, vz+1]
v_interp[non_border_mask] = v000 * (1 - rx) * (1 - ry) * (1 - rz) \
+ v001 * (1 - rx) * (1 - ry) * rz \
+ v010 * (1 - rx) * ry * (1 - rz) \
+ v011 * (1 - rx) * ry * rz \
+ v100 * rx * (1 - ry) * (1 - rz) \
+ v101 * rx * (1 - ry) * rz \
+ v110 * rx * ry * (1 - rz) \
+ v111 * rx * ry * rz
return v_interp
def get_pts_inside(self, pts, margin=0):
vox_coord = torch.floor((pts - self.vol_origin[None, :]) / self.voxel_size).long() # [N, 3]
valid_pts_mask = (vox_coord[..., 0] >= margin) & (vox_coord[..., 0] < self.vol_dim[0] - margin) \
& (vox_coord[..., 1] >= margin) & (vox_coord[..., 1] < self.vol_dim[1] - margin) \
& (vox_coord[..., 2] >= margin) & (vox_coord[..., 2] < self.vol_dim[2] - margin)
return valid_pts_mask
# use simple root finding
@torch.no_grad()
def render_model(self, c2w, intri, imh, imw, near=0.5, far=5., n_samples=192):
"""
Perform ray-casting for frame-to-model tracking
:param c2w: camera pose, [4, 4]
:param intri: camera intrinsics, [3, 3]
:param imh: image height
:param imw: image width
:param near: near bound for ray-casting
:param far: far bound for ray-casting
:param n_samples: number of samples along the ray
:return: rendered depth, color, vertex, normal and valid mask, [H, W, C]
"""
rays_o, rays_d = self.get_rays(c2w, intri, imh, imw) # [h, w, 3]
z_vals = torch.linspace(near, far, n_samples).to(rays_o) # [n_samples]
ray_pts_w = (rays_o[:, :, None, :] + rays_d[:, :, None, :] * z_vals[None, None, :, None]).to(self.device) # [h, w, n_samples, 3]
# need to query the tsdf and feature grid
tsdf_vals = torch.ones(imh, imw, n_samples).to(self.device)
# filter points that are outside the volume
valid_ray_pts_mask = self.get_pts_inside(ray_pts_w)
valid_ray_pts = ray_pts_w[valid_ray_pts_mask] # [n_valid, 3]
tsdf_vals[valid_ray_pts_mask] = self.tril_interp(self.tsdf_vol, valid_ray_pts)
# surface prediction by finding zero crossings
sign_matrix = torch.cat([torch.sign(tsdf_vals[..., :-1] * tsdf_vals[..., 1:]),
torch.ones(imh, imw, 1).to(self.device)], dim=-1) # [h, w, n_samples]
cost_matrix = sign_matrix * torch.arange(n_samples, 0, -1).float().to(self.device)[None, None, :] # [h, w, n_samples]
# Get first sign change and mask for values where
# a.) a sign changed occurred and
# b.) not a neg to pos sign change occurred
# c.) ignore border points
values, indices = torch.min(cost_matrix, -1)
mask_sign_change = values < 0
hs, ws = torch.meshgrid(torch.arange(imh), torch.arange(imw))
mask_pos_to_neg = tsdf_vals[hs, ws, indices] > 0
inside_vol = self.get_pts_inside(ray_pts_w[hs, ws, indices])
hit_surface_mask = mask_sign_change & mask_pos_to_neg & inside_vol
hit_pts = ray_pts_w[hs, ws, indices][hit_surface_mask] # [n_surf_pts, 3]
# compute normals
norms = self.get_normals()
surf_tsdf = self.tril_interp(self.tsdf_vol, hit_pts) # [n_surf_pts]
# surf_norms = self.tril_interp(norms, hit_pts) # [n_surf_pts, 3]
surf_norms = self.get_nn(norms, hit_pts)
updated_hit_pts = hit_pts - surf_tsdf[:, None] * self.sdf_trunc * surf_norms
valid_mask = self.get_pts_inside(updated_hit_pts)
hit_pts[valid_mask, :] = updated_hit_pts[valid_mask, :]
# get depth values
w2c = torch.inverse(c2w).to(self.device)
hit_pts_c = (w2c[:3, :3] @ hit_pts.transpose(1, 0)).transpose(1, 0) + w2c[:3, 3][None, :]
hit_pts_z = hit_pts_c[:, -1]
depth_rend = torch.zeros(imh, imw).to(self.device)
# depth_rend[hit_surface_mask] = z_vals[indices[hit_surface_mask]]
depth_rend[hit_surface_mask] = hit_pts_z
# vertex map
vertex_rend = torch.zeros(imh, imw, 3).to(self.device)
vertex_rend[hit_surface_mask] = hit_pts_c
# normal map
surf_norms_c = (w2c[:3, :3] @ surf_norms.transpose(1, 0)).transpose(1, 0) # [h, w, 3]
normal_rend = torch.zeros(imh, imw, 3).to(self.device)
normal_rend[hit_surface_mask] = surf_norms_c
if self.color_vol is not None:
# hit_colors = self.color_vol[cx, cy, cz, :]
hit_colors = self.tril_interp(self.color_vol, hit_pts)
# set color
color_rend = torch.zeros(imh, imw, 3).to(self.device)
color_rend[hit_surface_mask] = hit_colors
else:
color_rend = None
return depth_rend, color_rend, vertex_rend, normal_rend, hit_surface_mask
def render_pyramid(self, c2w, intri, imh, imw, n_pyr=4, near=0.5, far=5., n_samples=192):
K = intri.clone()
dep_pyr, rgb_pyr, vtx_pyr, nrm_pyr, mask_pyr = [], [], [], [], []
for l in range(n_pyr):
            dep, rgb, vtx, nrm, mask = self.render_model(c2w, K, imh, imw, near=near, far=far, n_samples=n_samples)  # render_model returns 5 values
dep_pyr += [dep]
rgb_pyr += [rgb]
vtx_pyr += [vtx]
nrm_pyr += [nrm]
mask_pyr += [mask]
imh = imh // 2
imw = imw // 2
K /= 2
return dep_pyr, rgb_pyr, vtx_pyr, nrm_pyr, mask_pyr
# get voxel index given world coordinate
# used for testing
def get_voxel_idx(self, x):
"""
:param x: [N, 3] query points
:return: [N] voxel indices
"""
assert len(x.shape) == 2, print("only accept flattened input!!!")
x.to(self.device)
vox_coord = torch.floor((x - self.vol_origin[None, :]) / self.voxel_size) # [N, 3]
vx, vy, vz = vox_coord[:, 0], vox_coord[:, 1], vox_coord[:, 2]
# very important! get voxel index from voxel coordinate
vox_idx = vz + vy * self.vol_dim[-1] + vx * self.vol_dim[-1] * self.vol_dim[-2]
return vox_idx.long()
def get_rays(self, c2w, intrinsics, H, W):
device = self.device
c2w = c2w.to(device)
fx = intrinsics[0, 0]
fy = intrinsics[1, 1]
cx = intrinsics[0, 2]
cy = intrinsics[1, 2]
i, j = torch.meshgrid(torch.linspace(0, W - 1, W), torch.linspace(0, H - 1, H)) # pytorch's meshgrid has indexing='ij'
i = i.t().to(device).reshape(H * W) # [hw]
j = j.t().to(device).reshape(H * W) # [hw]
dirs = torch.stack([(i - cx) / fx, (j - cy) / fy, torch.ones_like(i)], -1).to(device) # [hw, 3]
# permute for bmm
dirs = dirs.transpose(1, 0) # [3, hw]
rays_d = (c2w[:3, :3] @ dirs).transpose(1, 0) # [hw, 3]
rays_o = c2w[:3, 3].expand(rays_d.shape)
return rays_o.reshape(H, W, 3), rays_d.reshape(H, W, 3)
This is my favorite part, and in my view the core of the whole algorithm.
tracker.py is the driver of the tracking process; icp.py does the actual work.
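The update that icp.py applies at every iteration is a damped Gauss-Newton step on the point-to-plane residuals $r$ with Jacobian $J$ (see lev_mar_H and forward_update_pose below):
$$
\delta\xi = -\big(J^\top J + \lambda\,\mathrm{tr}(J^\top J)\,I\big)^{-1} J^\top r,
\qquad
T_{10} \leftarrow \exp\!\big(\widehat{\delta\xi}\big)\,T_{10},
$$
where $\widehat{\delta\xi}$ is the 4×4 matrix form of the twist $\delta\xi \in \mathfrak{se}(3)$ and $\lambda$ is the damping factor.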
import torch
import torch.nn as nn
from icp import ICP
class ICPTracker(nn.Module):
def __init__(self,
args,
device,
):
super(ICPTracker, self).__init__()
self.n_pyr = args.n_pyramids
self.scales = list(range(self.n_pyr))
self.n_iters = args.n_iters
self.dampings = args.dampings
        # KinectFusion uses an image pyramid during tracking (coarse to fine)
self.construct_image_pyramids = ImagePyramids(self.scales, pool='avg')
self.construct_depth_pyramids = ImagePyramids(self.scales, pool='max')
self.device = device
# initialize tracker at different levels
self.icp_solvers = []
for i in range(self.n_pyr):
self.icp_solvers += [ICP(self.n_iters[i], damping=self.dampings[i])]
@torch.no_grad()
def forward(self, depth0, depth1, K):
H, W = depth0.shape
dpt0_pyr = self.construct_depth_pyramids(depth0.view(1, 1, H, W))
dpt0_pyr = [d.squeeze() for d in dpt0_pyr]
dpt1_pyr = self.construct_depth_pyramids(depth1.view(1, 1, H, W))
dpt1_pyr = [d.squeeze() for d in dpt1_pyr]
# optimization steps
        pose10 = torch.eye(4).to(self.device)  # initialize from identity; torch.eye creates the identity matrix
for i in reversed(range(self.n_pyr)):
Ki = get_scaled_K(K, i)
pose10 = self.icp_solvers[i](pose10, dpt0_pyr[i], dpt1_pyr[i], Ki)
return pose10
class ImagePyramids(nn.Module):
""" Construct the pyramids in the image / depth space
"""
def __init__(self, scales, pool='avg'):
super(ImagePyramids, self).__init__()
if pool == 'avg':
            self.multiscales = [nn.AvgPool2d(1 << i, 1 << i) for i in scales]
        elif pool == 'max':
            self.multiscales = [nn.MaxPool2d(1 << i, 1 << i) for i in scales]
        else:
            raise NotImplementedError()

    def forward(self, x):
        return [f(x) for f in self.multiscales]
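get_scaled_K, which forward() above calls at every pyramid level, is not shown in this snippet; here is a sketch consistent with that call site, assuming each pyramid level halves the image resolution (and therefore the intrinsics):

def get_scaled_K(K, level):
    # intrinsics for pyramid level `level`; level 0 is full resolution
    if level == 0:
        return K
    Ks = K.clone()
    Ks[0, 0] /= 2 ** level   # fx
    Ks[1, 1] /= 2 ** level   # fy
    Ks[0, 2] /= 2 ** level   # cx
    Ks[1, 2] /= 2 ** level   # cy
    return Ks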
icp.py
import torch
import torch.nn as nn
import torch.nn.functional as F
# forward ICP
class ICP(nn.Module):
def __init__(self,
        max_iter=3,  # maximum number of iterations
        damping=1e-3,  # damping added to the Hessian
):
"""
:param max_iter, maximum number of iterations
:param damping, damping added to Hessian matrix
"""
super(ICP, self).__init__()
self.max_iterations = max_iter
self.damping = damping
def forward(self, pose10, depth0, depth1, K):
"""
In all cases we refer to 0 as template, and always warp pixels from 0 to 1
:param pose10: initial pose estimate
:param depth0: template depth image (0)
:param depth1: depth image (1)
        :param K: intrinsic matrix
:return: refined 0-to-1 transformation pose10
"""
# create vertex and normal for current frame
vertex0 = compute_vertex(depth0, K)
normal0 = compute_normal(vertex0)
mask0 = depth0 > 0.
vertex1 = compute_vertex(depth1, K)
normal1 = compute_normal(vertex1)
for idx in range(self.max_iterations):
# compute residuals
residuals, J_F_p = self.compute_residuals_jacobian(vertex0, vertex1, normal0, normal1, mask0, pose10, K)
JtWJ = self.compute_jtj(J_F_p) # [B, 6, 6]
JtR = self.compute_jtr(J_F_p, residuals)
pose10 = self.GN_solver(JtWJ, JtR, pose10, damping=self.damping)
return pose10
@staticmethod
def compute_residuals_jacobian(vertex0, vertex1, normal0, normal1, mask0, pose10, K):
"""
:param vertex0: vertex map 0
:param vertex1: vertex map 1
:param normal0: normal map 0
:param normal1: normal map 1
:param mask0: valid mask of template depth image
:param pose10: current estimate of pose10
:param K: intrinsics
:return: residuals and Jacobians
"""
R = pose10[:3, :3]
t = pose10[:3, 3]
H, W, C = vertex0.shape
rot_vertex0_to1 = (R @ vertex0.view(-1, 3).permute(1, 0)).permute(1, 0).view(H, W, 3)
vertex0_to1 = rot_vertex0_to1 + t[None, None, :]
normal0_to1 = (R @ normal0.view(-1, 3).permute(1, 0)).permute(1, 0).view(H, W, 3)
fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
x_, y_, z_ = vertex0_to1[..., 0], vertex0_to1[..., 1], vertex0_to1[..., 2] # [h, w]
u_ = (x_ / z_) * fx + cx # [h, w]
v_ = (y_ / z_) * fy + cy # [h, w]
inviews = (u_ > 0) & (u_ < W-1) & (v_ > 0) & (v_ < H-1)
# projective data association
r_vertex1 = warp_features(vertex1, u_, v_) # [h, w, 3]
r_normal1 = warp_features(normal1, u_, v_) # [h, w, 3]
mask1 = r_vertex1[..., -1] > 0.
diff = vertex0_to1 - r_vertex1 # [h, w, 3]
# point-to-plane residuals
res = (r_normal1 * diff).sum(dim=-1) # [h, w]
# point-to-plane jacobians
J_trs = r_normal1.view(-1, 3) # [hw, 3]
J_rot = -torch.bmm(J_trs.unsqueeze(dim=1), batch_skew(vertex0_to1.view(-1, 3))).squeeze() # [hw, 3]
        # compose jacobians
J_F_p = torch.cat((J_rot, J_trs), dim=-1).view(H, W, 6) # follow the order of [rot, trs] [hw, 1, 6]
# occlusion
occ = ~inviews | (diff.norm(p=2, dim=-1) > 0.10)
invalid_mask = occ | ~mask0 | ~mask1
J_F_p[invalid_mask] = 0.
res[invalid_mask] = 0.
res = res.view(-1, 1) # [hw, 1]
J_F_p = J_F_p.view(-1, 1, 6) # [hw, 1, 6]
return res, J_F_p
@staticmethod
def compute_jtj(jac):
# J in the dimension of (HW, C, 6)
jacT = jac.transpose(-1, -2) # [HW, 6, C]
jtj = torch.bmm(jacT, jac).sum(0) # [6, 6]
return jtj # [6, 6]
@staticmethod
def compute_jtr(jac, res):
# J in the dimension of (HW, C, 6)
# res in the dimension of [HW, C]
jacT = jac.transpose(-1, -2) # [HW, 6, C]
jtr = torch.bmm(jacT, res.unsqueeze(-1)).sum(0) # [6, 1]
return jtr # [6, 1]
@staticmethod
def GN_solver(JtJ, JtR, pose0, damping=1e-6):
# Add a small diagonal damping. Without it, the training becomes quite unstable
# Do not see a clear difference by removing the damping in inference though
Hessian = lev_mar_H(JtJ, damping)
# Hessian = JtJ
updated_pose = forward_update_pose(Hessian, JtR, pose0)
return updated_pose
def warp_features(Feat, u, v, mode='bilinear'):
"""
Warp the feature map (F) w.r.t. the grid (u, v). This is the non-batch version
"""
assert len(Feat.shape) == 3
H, W, C = Feat.shape
u_norm = u / ((W - 1) / 2) - 1 # [h, w]
v_norm = v / ((H - 1) / 2) - 1 # [h, w]
uv_grid = torch.cat((u_norm.view(1, H, W, 1), v_norm.view(1, H, W, 1)), dim=-1)
Feat_warped = F.grid_sample(Feat.unsqueeze(0).permute(0, 3, 1, 2), uv_grid, mode=mode, padding_mode='border', align_corners=True).squeeze()
return Feat_warped.permute(1, 2, 0)
def compute_vertex(depth, K):
H, W = depth.shape
fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
device = depth.device
i, j = torch.meshgrid(torch.linspace(0, W - 1, W), torch.linspace(0, H - 1, H)) # pytorch's meshgrid has indexing='ij'
i = i.t().to(device) # [h, w]
j = j.t().to(device) # [h, w]
vertex = torch.stack([(i - cx) / fx, (j - cy) / fy, torch.ones_like(i)], -1).to(device) * depth[..., None] # [h, w, 3]
return vertex
def compute_normal(vertex_map):
""" Calculate the normal map from a depth map
:param the input depth image
-----------
:return the normal map
"""
H, W, C = vertex_map.shape
img_dx, img_dy = feature_gradient(vertex_map, normalize_gradient=False) # [h, w, 3]
normal = torch.cross(img_dx.view(-1, 3), img_dy.view(-1, 3))
normal = normal.view(H, W, 3) # [h, w, 3]
mag = torch.norm(normal, p=2, dim=-1, keepdim=True)
normal = normal / (mag + 1e-8)
# filter out invalid pixels
depth = vertex_map[:, :, -1]
# 0.5 and 5.
invalid_mask = (depth <= depth.min()) | (depth >= depth.max())
zero_normal = torch.zeros_like(normal)
normal = torch.where(invalid_mask[..., None], zero_normal, normal)
return normal
def feature_gradient(img, normalize_gradient=True):
""" Calculate the gradient on the feature space using Sobel operator
:param the input image
-----------
:return the gradient of the image in x, y direction
"""
H, W, C = img.shape
# to filter the image equally in each channel
wx = torch.FloatTensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).view(1, 1, 3, 3).type_as(img)
wy = torch.FloatTensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).view(1, 1, 3, 3).type_as(img)
img_permuted = img.permute(2, 0, 1).view(-1, 1, H, W) # [c, 1, h, w]
img_pad = F.pad(img_permuted, (1, 1, 1, 1), mode='replicate')
img_dx = F.conv2d(img_pad, wx, stride=1, padding=0).squeeze().permute(1, 2, 0) # [h, w, c]
img_dy = F.conv2d(img_pad, wy, stride=1, padding=0).squeeze().permute(1, 2, 0) # [h, w, c]
if normalize_gradient:
mag = torch.sqrt((img_dx ** 2) + (img_dy ** 2) + 1e-8)
img_dx = img_dx / mag
img_dy = img_dy / mag
return img_dx, img_dy # [h, w, c]
def batch_skew(w):
""" Generate a batch of skew-symmetric matrices.
function tested in 'test_geometry.py'
:input
:param skew symmetric matrix entry Bx3
---------
:return
:param the skew-symmetric matrix Bx3x3
"""
B, D = w.shape
assert(D == 3)
o = torch.zeros(B).type_as(w)
w0, w1, w2 = w[:, 0], w[:, 1], w[:, 2]
return torch.stack((o, -w2, w1, w2, o, -w0, -w1, w0, o), 1).view(B, 3, 3)
def lev_mar_H(JtWJ, damping):
# Add a small diagonal damping. Without it, the training becomes quite unstable
# Do not see a clear difference by removing the damping in inference though
diag_mask = torch.eye(6).to(JtWJ)
diagJtJ = diag_mask * JtWJ
traceJtJ = torch.sum(diagJtJ)
epsilon = (traceJtJ * damping) * diag_mask
Hessian = JtWJ + epsilon
return Hessian
def forward_update_pose(H, Rhs, pose):
"""
:param H:
:param Rhs:
:param pose:
:return:
"""
xi = least_square_solve(H, Rhs).squeeze()
pose = exp_se3(xi) @ pose
return pose
def exp_se3(xi):
"""
:param x: Cartesian vector of Lie Algebra se(3)
:return: exponential map of x
"""
w = xi[:3].squeeze() # rotation
v = xi[3:6].squeeze() # translation
w_hat = torch.tensor([[0., -w[2], w[1]],
[w[2], 0., -w[0]],
[-w[1], w[0], 0.]]).to(xi)
w_hat_second = torch.mm(w_hat, w_hat).to(xi)
theta = torch.norm(w)
theta_2 = theta ** 2
theta_3 = theta ** 3
sin_theta = torch.sin(theta)
cos_theta = torch.cos(theta)
eye_3 = torch.eye(3).to(xi)
eps = 1e-8
if theta <= eps:
e_w = eye_3
j = eye_3
else:
e_w = eye_3 + w_hat * sin_theta / theta + w_hat_second * (1. - cos_theta) / theta_2
k1 = (1 - cos_theta) / theta_2
k2 = (theta - sin_theta) / theta_3
j = eye_3 + k1 * w_hat + k2 * w_hat_second
T = torch.eye(4).to(xi)
T[:3, :3] = e_w
T[:3, 3] = torch.mv(j, v)
# T[:3, 3] = v
return T
def invH(H):
""" Generate (H+damp)^{-1}, with predicted damping values
:param approximate Hessian matrix JtWJ
-----------
:return the inverse of Hessian
"""
# GPU is much slower for matrix inverse when the size is small (compare to CPU)
# works (50x faster) than inversing the dense matrix in GPU
if H.is_cuda:
invH = torch.inverse(H.cpu()).cuda()
else:
invH = torch.inverse(H)
return invH
def least_square_solve(H, Rhs):
"""
Solve for JTJ @ xi = -JTR
"""
inv_H = invH(H) # [B, 6, 6] square matrix
xi = -inv_H @ Rhs
return xi
Other links:
An introduction to the principles of KinectFusion.