PointNet++的SA模块中有不可导的FPS采样操作,梯度如何能够反向传播?

3D检测里面有这样一类操作,用白话讲,可以归类为“路由”:即把特征从哪里取出来,或把特征汇聚到哪里去。“路由”的建立过程本身可以不可导(比如FPS和ball query)。但“路由”一旦建立,特征的传递和映射过程一般都是可导的。所以对于FPS操作在反向传播时的具体实现我认为和torch.scatter的梯度反传类似,初始化一个0张量,然后按照scatter赋值过去。

对于pointnet++中的一个例子(因为不能debug进用C++写的自定义的cuda算子,debug没法进c++的代码的,只能到调用的那一步,所以推荐直接采用全局搜索),对于前向传播中那些不能导的操作,是要重写一下backward函数的,重写时可能会调用自己实现的cuda代码。可以看到重写的backward函数返回值只有第一个有梯度,第二个没有梯度。backward的参数输入来自于forward的输出的梯度,backward返回值对应于forward的输入的梯度,所以说明idx没有梯度,被截断了,说明idx这条路是死的,但是features这条路依然有梯度

案例
PointNet++的SA模块中有不可导的FPS采样操作,梯度如何能够反向传播?_第1张图片
PointNet++的SA模块中有不可导的FPS采样操作,梯度如何能够反向传播?_第2张图片
外部调用逻辑:
PointNet++的SA模块中有不可导的FPS采样操作,梯度如何能够反向传播?_第3张图片

Refer

  1. 【代码阅读】PointNet++代码梳理
  2. 【代码阅读】详解在Pytorch中定义自己写的CUDA编程函数
  3. 【代码阅读】PointNet++具体实现详解
  4. Python 利用setup制作自定义包-打包-安装
  5. pytorch通过torch.utils.cpp_extension构建CUDA/C++拓展
  6. 利用torch.autograd.Function自定义层的forward和backward
  7. PyTorch 源码解读之 torch.autograd:梯度计算详解

你可能感兴趣的:(深度学习,pytorch,深度学习,人工智能,pointnet++)