源代码:DynamicNeRF
主要思想:利用RAFT模型来预测flow
argparse.ArgumentParser()用法解析参数解析
action='store_true’的区别
#创建对象——添加参数——返回命名空间
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", type=str, help='Dataset path')
parser.add_argument('--model', help="restore RAFT checkpoint")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
args = parser.parse_args()
os.path模块主要用来获取文件属性
os.path.join(path1[, path2[, …]]) 把目录和文件名合成一个路径
def create_dir(dir):
if not os.path.exists(dir):
os.makedirs(dir)
# 将目录和文件名合成一个路径
input_path = os.path.join(args.dataset_path, 'images')
output_path = os.path.join(args.dataset_path, 'flow')
output_img_path = os.path.join(args.dataset_path, 'flow_png')
create_dir(output_path)
create_dir(output_img_path)
torch.nn.DataParallel()表示多GPU训练用多块显卡来加速训练,参考好文
model.eval()使得网络处于测试模式,若处于训练模式:model.train()
glob.glob()返回所有匹配的文件路径列表
GPU上的tensor转成CPU上的numpy:.cpu().numpy()
.transpose(1,2,0)是调节数组维数变换。
np.savez()用来存储.npz文件,numpy数据存储
Image.fromarray()数组到图像的转换
np.arrary() 对应地从图像转数组
DEVICE = 'cuda'
def run(args, input_path, output_path, output_img_path):
#加载模型
#加速训练
model = torch.nn.DataParallel(RAFT(args))
#torch.load 加载state_dicts——使用state_dicts实例化model
model.load_state_dict(torch.load(args.model))
# 将模型送入gpu
model = model.module
model.to(DEVICE)
model.eval()
#获取指定目录下的所有图片
with torch.no_grad():
images = glob.glob(os.path.join(input_path, '*.png')) + \
glob.glob(os.path.join(input_path, '*.jpg'))
images = sorted(images)
for i in range(len(images) - 1):
print(i)
image1 = load_image(images[i])
image2 = load_image(images[i + 1])
#填充图像,使尺寸可被 8 整除,子函数来自RAFT
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)
_, flow_fwd = model(image1, image2, iters=20, test_mode=True)
_, flow_bwd = model(image2, image1, iters=20, test_mode=True)
#GPU上的tensor———>CPU上的数组
flow_fwd = padder.unpad(flow_fwd[0]).cpu().numpy().transpose(1, 2, 0)
flow_bwd = padder.unpad(flow_bwd[0]).cpu().numpy().transpose(1, 2, 0)
#计算前向光流、后向光流的mask
mask_fwd, mask_bwd = compute_fwdbwd_mask(flow_fwd, flow_bwd)
# Save flow
np.savez(os.path.join(output_path, '%03d_fwd.npz'%i), flow=flow_fwd, mask=mask_fwd)
np.savez(os.path.join(output_path, '%03d_bwd.npz'%(i + 1)), flow=flow_bwd, mask=mask_bwd)
# Save flow_img
#numpy 转 image类
Image.fromarray(flow_viz.flow_to_image(flow_fwd)).save(os.path.join(output_img_path, '%03d_fwd.png'%i))
Image.fromarray(flow_viz.flow_to_image(flow_bwd)).save(os.path.join(output_img_path, '%03d_bwd.png'%(i + 1)))
Image.fromarray(mask_fwd).save(os.path.join(output_img_path, '%03d_fwd_mask.png'%i))
Image.fromarray(mask_bwd).save(os.path.join(output_img_path, '%03d_bwd_mask.png'%(i + 1)))
加载图像
像素值表达的两种类型:
float:0-1
uint8:0-255
float到unit8需要image*255;从uint8到float需要image/255
torch.from_numpy()把数组转换成张量,且二者共享内存,对张量进行修改比如重新赋值,那么原始数组也会相应发生改变。
tensor.permute()将tensor的维度换位,跟transpose()作用类似
def load_image(imfile):
img = np.array(Image.open(imfile)).astype(np.uint8)
img = torch.from_numpy(img).permute(2, 0, 1).float()
return img[None].to(DEVICE)
图像warp变换
切片操作
[:2] 表示索引 0至1行 [ :, 2]:表示所有行的第3列
[:,:,0] 取其中的所有0号索引
[:,np.newaxis]表示升维
cv2.remap()函数的使用remap函数实际就是通过修改像素点的位置得到一幅新图像。已知前一帧的图像和光流,通过cv2.remap()来预测恢复下一帧的图像
def warp_flow(img, flow):
h, w = flow.shape[:2]
flow_new = flow.copy()
flow_new[:,:,0] += np.arange(w)
flow_new[:,:,1] += np.arange(h)[:,np.newaxis]
res = cv2.remap(img, flow_new, None, cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT)
return res
np.linalg.norm()求范数
#计算前向光流mask以及后向光流mask
def compute_fwdbwd_mask(fwd_flow, bwd_flow):
alpha_1 = 0.5
alpha_2 = 0.5
bwd2fwd_flow = warp_flow(bwd_flow, fwd_flow)
fwd_lr_error = np.linalg.norm(fwd_flow + bwd2fwd_flow, axis=-1)
fwd_mask = fwd_lr_error < alpha_1 * (np.linalg.norm(fwd_flow, axis=-1) \
+ np.linalg.norm(bwd2fwd_flow, axis=-1)) + alpha_2
fwd2bwd_flow = warp_flow(fwd_flow, bwd_flow)
bwd_lr_error = np.linalg.norm(bwd_flow + fwd2bwd_flow, axis=-1)
bwd_mask = bwd_lr_error < alpha_1 * (np.linalg.norm(bwd_flow, axis=-1) \
+ np.linalg.norm(fwd2bwd_flow, axis=-1)) + alpha_2
return fwd_mask, bwd_mask