这里本人的参考源码是grid_sample的CPU内核的CPP实现:https://github.com/pytorch/pytorch/blob/b039a715ce4e9cca82ae3bf72cb84652957b2844/aten/src/ATen/native/cpu/GridSamplerKernel.cpp。
给定一个input(4D或5D,一般指原图像)和一个流场grid(4D或5D,一般指变形流),基于来自grid的像素位置和input的像素值计算output(输出图像)。
例如,对于4D的情况,input的形状为(N,C,Hin,Win),grid的形状为(N,Hout,Wout,2),那么输出结果即为(N,C,Hout,Wout)。对于output的每个位置(i, j),将根据grid在(i, j)上的值(x, y),从input的(x, y)处采样像素值(采样过程要考虑x、y非整值和越界的情况),用作output在(i, j)处的像素值。
(grid的值应当是根据input的空间维度(H,W)归一化到[-1,1]后的像素点坐标。例如x=-1,y=1代表输入的左上角像素。)
如果grid具有超出[-1,1]范围的值,相应的输出由padding_mode来处理:
inline void forward(TensorAccessor<scalar_t, 3>& out_slice,
const TensorAccessor<scalar_t, 3>& inp_slice,
int64_t offset, const Vec& grid_x, const Vec& grid_y,
int64_t len) const {
auto x = compute_W.apply(grid_x);
auto y = compute_H.apply(grid_y); // 首先根据grid算出反归一化后的插入位置
//基于双线性插值,对每个位置(小数)首先获得四个方向(到最近的整数位置)上的距离作为插值的权重
//会返回权重和mask(考虑是否需要处理超出边界的部分)
auto interp_params = compute_interp_params(x, y);
//以下皆为上一个函数的返回值
auto nw = std::get<4>(interp_params);
auto ne = std::get<5>(interp_params);
auto sw = std::get<6>(interp_params);
auto se = std::get<7>(interp_params);
auto nw_mask = std::get<8>(interp_params);
auto ne_mask = std::get<9>(interp_params);
auto sw_mask = std::get<10>(interp_params);
auto se_mask = std::get<11>(interp_params);
auto i_y_n = std::get<12>(interp_params);
auto i_x_w = std::get<13>(interp_params);
//获得原图input上grid所指示的位置附近四个整数像素点的位置
auto i_nw_offset = i_y_n * iVec(inp_sH) + i_x_w * iVec(inp_sW);
auto i_ne_offset = i_nw_offset + iVec(inp_sW);
auto i_sw_offset = i_nw_offset + iVec(inp_sH);
auto i_se_offset = i_sw_offset + iVec(inp_sW);
#ifndef _MSC_VER
# pragma unroll
#endif
for (int64_t c = 0; c < C; ++c) { //C为batch_size
auto inp_slice_C_ptr = inp_slice[c].data();
// mask_gather zeros out the mask, so we need to make copies
Vec nw_mask_copy = nw_mask;
Vec ne_mask_copy = ne_mask;
Vec sw_mask_copy = sw_mask;
Vec se_mask_copy = se_mask;
//获得原图中四个方向位置中的像素值
//这里其实是通过对输入图像的底层指针进行偏移量计算来实现根据索引进行插入的效果的
auto nw_val = mask_gather<sizeof(scalar_t)>(Vec(0), inp_slice_C_ptr, i_nw_offset, nw_mask_copy);
auto ne_val = mask_gather<sizeof(scalar_t)>(Vec(0), inp_slice_C_ptr, i_ne_offset, ne_mask_copy);
auto sw_val = mask_gather<sizeof(scalar_t)>(Vec(0), inp_slice_C_ptr, i_sw_offset, sw_mask_copy);
auto se_val = mask_gather<sizeof(scalar_t)>(Vec(0), inp_slice_C_ptr, i_se_offset, se_mask_copy);
//根据各方向权重计算出最终插值结果
auto interpolated = (nw_val * nw) + (ne_val * ne) + (sw_val * sw) + (se_val * se);
interpolated.store(out_slice[c].data() + offset, len);
}
backward()计算关于grid的梯度关键:
gx = gx + ((ne_val - nw_val) * s + (se_val - sw_val) * n) * gOut;
gy = gy + ((sw_val - nw_val) * e + (se_val - ne_val) * w) * gOut;
这里gOut应指来自下一层传回的梯度,ne_val,nw_val,se_val,sw_val指四个方向位置上的原图像像素值,这四个值都是通过以grid值作为索引查找原图像相邻位置去获取到的;s、n、e、w分别指该grid值到这四个方向整数位置上的一个距离(用作双线性插值的权重)。
3. “grid_sample_2d_grid_slice_iterator”函数
提供一个抽象来有效地迭代一个“grid”分片(不带batch维度)。实质上是遍历了每个实例,然后对每个实例应用上述前向和反向处理。
(在双线性插值下),grid_sample()对grid求导采取了类似图像梯度的方式,直接用每个grid值关联到的周围四个位置上的像素值,将两两的差值乘上一个权重(双线性插值的距离),用作本函数的梯度,然后传回给前一层。
考虑这么一个图像矫正的问题,如果有一张输入的变形图像inp,一个参考恢复网格grid_gt,一个预测恢复网格grid_pred,要衡量grid_pred网格的正确性,有两种做法:
由于任务关心的实际上是最终的矫正结果的效果,而不是grid绝对值的差距,因此本人认为后一种方法更加准确。再结合上述对grid_sample()求导的分析,这两种方法传递的梯度信息是大不相同的(也即后一种做法在实现上是具有意义的),grid_sample()对grid某点值的偏导会考虑该点在原图像上所有相邻点的像素值。
由于本人水平有限,对代码的理解上可能不够深入,如果存在错误之处,请大神在评论区指正!