As the DCNv1 and DCNv2 papers report, DeformConv adds only a modest number of parameters and FLOPs over regular convolution, yet improves network accuracy considerably. Its computation pattern, however, does not lend itself to an efficient implementation, so the overhead it adds to a network is larger than the paper figures suggest:
In Torchvision, as in other frameworks, DeformConv2d is implemented via explicit GEMM, in two steps: a deformable im2col gathers bilinearly sampled input patches into a column buffer, then a (grouped) GEMM with the weights produces the output. Compared with torch.nn.functional.conv2d, the Python entry point takes two extra Tensor inputs, offset and mask, alongside the usual optional bias.
def deform_conv2d(
    input: Tensor,
    offset: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Tuple[int, int] = (1, 1),
    padding: Tuple[int, int] = (0, 0),
    dilation: Tuple[int, int] = (1, 1),
    mask: Optional[Tensor] = None,
) -> Tensor:
    r"""
    Performs Deformable Convolution v2, described in
    `Deformable ConvNets v2: More Deformable, Better Results
    <https://arxiv.org/abs/1811.11168>`__ if :attr:`mask` is not ``None`` and
    Performs Deformable Convolution, described in
    `Deformable Convolutional Networks
    <https://arxiv.org/abs/1703.06211>`__ if :attr:`mask` is ``None``.
    Args:
        input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
        offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width, out_height, out_width]):
            offsets to be applied for each position in the convolution kernel.
        weight (Tensor[out_channels, in_channels // groups, kernel_height, kernel_width]): convolution weights,
            split into groups of size (in_channels // groups)
        bias (Tensor[out_channels]): optional bias of shape (out_channels,). Default: None
        stride (int or Tuple[int, int]): distance between convolution centers. Default: 1
        padding (int or Tuple[int, int]): height/width of padding of zeroes around
            each image. Default: 0
        dilation (int or Tuple[int, int]): the spacing between kernel elements. Default: 1
        mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width, out_height, out_width]):
            masks to be applied for each position in the convolution kernel. Default: None
    Returns:
        Tensor[batch_sz, out_channels, out_h, out_w]: result of convolution
    Examples::
        >>> input = torch.rand(4, 3, 10, 10)
        >>> kh, kw = 3, 3
        >>> weight = torch.rand(5, 3, kh, kw)
        >>> # offset and mask should have the same spatial size as the output
        >>> # of the convolution. In this case, for an input of 10, stride of 1
        >>> # and kernel size of 3, without padding, the output size is 8
        >>> offset = torch.rand(4, 2 * kh * kw, 8, 8)
        >>> mask = torch.rand(4, kh * kw, 8, 8)
        >>> out = deform_conv2d(input, offset, weight, mask=mask)
        >>> print(out.shape)
        >>> # returns
        >>> torch.Size([4, 5, 8, 8])
    """
torch.jit.is_scripting returns True when the function is being compiled with TorchScript and False otherwise; torch.jit.is_tracing returns True during tracing and False otherwise. _log_api_usage_once records API usage (module and name) for an organization's telemetry.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(deform_conv2d)
    _assert_has_ops()
    out_channels = weight.shape[0]
    use_mask = mask is not None
If mask or bias is None, it is replaced by a placeholder created with torch.zeros; for mask this is a tensor with a zero-sized dimension (see the short check after the fragment below).
    if mask is None:
        mask = torch.zeros((input.shape[0], 0), device=input.device, dtype=input.dtype)
    if bias is None:
        bias = torch.zeros(out_channels, device=input.device, dtype=input.dtype)
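The placeholder really is empty; a quick check (the shapes here are arbitrary, just for illustration):

import torch

# A (batch, 0) placeholder keeps the custom op's signature fixed even
# when no mask is supplied; the separate use_mask flag tells the kernel
# to ignore it.
placeholder = torch.zeros((4, 0))
print(placeholder.numel())  # 0 -> allocates no element storage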
_pair converts a single value into a two-element tuple and passes an existing pair through unchanged (a sketch of it follows the next fragment).
    stride_h, stride_w = _pair(stride)
    pad_h, pad_w = _pair(padding)
    dil_h, dil_w = _pair(dilation)
    weights_h, weights_w = weight.shape[-2:]
    _, n_in_channels, _, _ = input.shape
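For reference, _pair comes from torch.nn.modules.utils, where it is generated by _ntuple; a minimal sketch of the equivalent behavior:

import collections.abc
from itertools import repeat

def _ntuple(n):
    # returns a parser that repeats a scalar n times and passes iterables through
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return tuple(x)
        return tuple(repeat(x, n))
    return parse

_pair = _ntuple(2)
assert _pair(1) == (1, 1) and _pair((2, 3)) == (2, 3)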
grps is short for groups. offset has shape [batch_size, 2 * offset_groups * kernel_height * kernel_width, out_height, out_width] and weight has shape [out_channels, in_channels // groups, kernel_height, kernel_width]. n_offset_grps is the number of offset groups, and n_weight_grps is the number of weight groups $G$ of the grouped convolution; both are recovered from the tensor shapes (a worked example follows the function body).
    n_offset_grps = offset.shape[1] // (2 * weights_h * weights_w)
    n_weight_grps = n_in_channels // weight.shape[1]
    if n_offset_grps == 0:
        raise RuntimeError(
            "the shape of the offset tensor at dimension 1 is not valid. It should "
            "be a multiple of 2 * weight.size[2] * weight.size[3].\n"
            f"Got offset.shape[1]={offset.shape[1]}, while 2 * weight.size[2] * weight.size[3]={2 * weights_h * weights_w}"
        )
    return torch.ops.torchvision.deform_conv2d(
        input,
        weight,
        offset,
        mask,
        bias,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dil_h,
        dil_w,
        n_weight_grps,
        n_offset_grps,
        use_mask,
    )
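To make the group arithmetic concrete, here is a worked example with made-up shapes (16 input channels, a 3x3 kernel, 2 weight groups and 2 offset groups; all names are local to the example):

import torch

kh, kw = 3, 3
input = torch.rand(4, 16, 10, 10)               # [N, C_in, H, W]
weight = torch.rand(8, 8, kh, kw)               # [C_out, C_in // groups, kh, kw] -> groups = 2
offset = torch.rand(4, 2 * 2 * kh * kw, 8, 8)   # 2 offset groups
n_offset_grps = offset.shape[1] // (2 * kh * kw)    # 36 // 18 == 2
n_weight_grps = input.shape[1] // weight.shape[1]   # 16 // 8  == 2
print(n_offset_grps, n_weight_grps)  # 2 2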
C10_LOG_API_USAGE_ONCE is a very lightweight way to log the first use of an API. Dispatcher::singleton returns the global Dispatcher object. Dispatcher::findSchemaOrThrow looks the operator up in the operator table and returns an OperatorHandle. OperatorHandle::typed returns a TypedOperatorHandle, a handle to an operator schema registered with the dispatcher, and TypedOperatorHandle::call goes through Dispatcher::call.
at::Tensor deform_conv2d(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
    const at::Tensor& mask,
    const at::Tensor& bias,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t groups,
    int64_t offset_groups,
    bool use_mask) {
  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.deform_conv2d.deform_conv2d");
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("torchvision::deform_conv2d", "")
                       .typed<decltype(deform_conv2d)>();
  return op.call(
      input,
      weight,
      offset,
      mask,
      bias,
      stride_h,
      stride_w,
      pad_h,
      pad_w,
      dilation_h,
      dilation_w,
      groups,
      offset_groups,
      use_mask);
}
Next comes the GPU implementation. The input tensors must be contiguous, so contiguous copies are made up front.
at::Tensor deform_conv2d_forward_kernel(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
    const at::Tensor& mask,
    const at::Tensor& bias,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t n_weight_grps,
    int64_t n_offset_grps,
    bool use_mask) {
  at::Tensor input_c = input.contiguous();
  at::Tensor offset_c = offset.contiguous();
  at::Tensor weight_c = weight.contiguous();
  at::Tensor mask_c = mask.contiguous();
  at::Tensor bias_c = bias.contiguous();
Every tensor argument except bias must be 4-dimensional.
  TORCH_CHECK(input_c.ndimension() == 4);
  TORCH_CHECK(offset_c.ndimension() == 4);
  TORCH_CHECK(!use_mask || mask_c.ndimension() == 4);
  TORCH_CHECK(weight_c.ndimension() == 4);
  TORCH_CHECK(input_c.is_cuda(), "input must be a CUDA tensor");
A DeviceGuard switches to the device that holds the input.
  at::DeviceGuard guard(input_c.device());
  int batch_sz = input_c.size(0);
  int in_channels = input_c.size(1);
  int in_h = input_c.size(2);
  int in_w = input_c.size(3);
get_greatest_divisor_below_bound finds the largest divisor of batch_sz not exceeding kMaxParallelImgs; that many images are processed per block of the batch dimension. Requiring an exact divisor avoids remainder handling, but the resulting parallelism can be poor, e.g. a prime batch size degrades to blocks of one image. (Python sketches of this helper and of the output-size arithmetic follow the next fragment.)
  int n_parallel_imgs =
      get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);
  int out_channels = weight_c.size(0);
  int weight_h = weight_c.size(2);
  int weight_w = weight_c.size(3);
  int ker_h = dilation_h * (weight_h - 1) + 1;
  int ker_w = dilation_w * (weight_w - 1) + 1;
  int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1;
  int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1;
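Both computations are easy to mirror in Python. A sketch, assuming get_greatest_divisor_below_bound simply scans downward from the bound (the actual constant kMaxParallelImgs and the C++ helper live in torchvision's source):

def get_greatest_divisor_below_bound(n: int, bound: int) -> int:
    # largest divisor of n that does not exceed bound; 1 always qualifies
    for d in range(min(n, bound), 0, -1):
        if n % d == 0:
            return d
    return 1

def conv_out_size(in_size: int, pad: int, dilation: int, k: int, stride: int) -> int:
    ker = dilation * (k - 1) + 1          # effective (dilated) kernel extent
    return (in_size + 2 * pad - ker) // stride + 1

print(get_greatest_divisor_below_bound(12, 10))  # 6: batch of 12 -> blocks of 6
print(get_greatest_divisor_below_bound(13, 10))  # 1: a prime batch degrades parallelism
print(conv_out_size(10, 0, 1, 3, 1))             # 8, as in the docstring example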
The dimensions and parameters are then validated, just as for a regular convolution.
  TORCH_CHECK(
      weight_h > 0 && weight_w > 0,
      "weight_h: ",
      weight_h,
      " weight_w: ",
      weight_w);
  TORCH_CHECK(
      stride_h > 0 && stride_w > 0,
      "stride_h: ",
      stride_h,
      " stride_w: ",
      stride_w);
  TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w);
  TORCH_CHECK(
      dilation_h > 0 && dilation_w > 0,
      "dilation_h: ",
      dilation_h,
      " dilation_w: ",
      dilation_w);
  TORCH_CHECK(weight_c.size(1) * n_weight_grps == input_c.size(1));
  TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0);
n_offset_grps, written $G_\Delta$, is the number of groups into which the input channels are split when applying offset and mask. offset_c has shape $N \times 2 G_\Delta k_h k_w \times H_o \times W_o$ and mask_c has shape $N \times G_\Delta k_h k_w \times H_o \times W_o$.
  TORCH_CHECK(
      (offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w),
      "offset.shape[1] is not valid: got: ",
      offset_c.size(1),
      " expected: ",
      n_offset_grps * 2 * weight_h * weight_w);
  TORCH_CHECK(
      (!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w),
      "mask.shape[1] is not valid: got: ",
      mask_c.size(1),
      " expected: ",
      n_offset_grps * weight_h * weight_w);
  TORCH_CHECK(input_c.size(1) % n_offset_grps == 0);
  TORCH_CHECK(
      (offset_c.size(0) == input_c.size(0)), "invalid batch size of offset");
  TORCH_CHECK(
      (offset_c.size(2) == out_h && offset_c.size(3) == out_w),
      "offset output dims: (",
      offset_c.size(2),
      ", ",
      offset_c.size(3),
      ") - ",
      "computed output dims: (",
      out_h,
      ", ",
      out_w,
      ")");
  TORCH_CHECK(
      (mask_c.size(0) == input_c.size(0)), "invalid batch size of mask");
  TORCH_CHECK(
      (!use_mask || (mask_c.size(2) == out_h && mask_c.size(3) == out_w)),
      "mask output dims: (",
      mask_c.size(2),
      ", ",
      mask_c.size(3),
      ") - ",
      "computed output dims: (",
      out_h,
      ", ",
      out_w,
      ")");
  TORCH_CHECK(
      out_h > 0 && out_w > 0,
      "Calculated output size too small - out_h: ",
      out_h,
      " out_w: ",
      out_w);
out is the final output tensor. (For a non-empty batch it is arguably allocated a little early, since it is only written at the very end.)
  auto out =
      at::zeros({batch_sz, out_channels, out_h, out_w}, input_c.options());
  if (batch_sz == 0) {
    return out;
  }
The tensors are then reshaped: the batch dimension of out and of the inputs is split, giving 5-D views, so that n_parallel_imgs images ($n$ below) are handled per iteration. out becomes $\frac{N}{n} \times n \times C_o \times H_o \times W_o$.
  // Separate batches into blocks
  out = out.view(
      {batch_sz / n_parallel_imgs,
       n_parallel_imgs,
       out_channels,
       out_h,
       out_w});
  input_c = input_c.view(
      {batch_sz / n_parallel_imgs, n_parallel_imgs, in_channels, in_h, in_w});
  offset_c = offset_c.view(
      {batch_sz / n_parallel_imgs,
       n_parallel_imgs,
       n_offset_grps * 2 * weight_h * weight_w,
       out_h,
       out_w});
  if (use_mask) {
    mask_c = mask_c.view(
        {batch_sz / n_parallel_imgs,
         n_parallel_imgs,
         n_offset_grps * weight_h * weight_w,
         out_h,
         out_w});
  }
out_buf holds the result before the final transpose; its shape is $\frac{N}{n} \times C_o \times nH_o \times W_o$.
  at::Tensor out_buf = at::zeros(
      {batch_sz / n_parallel_imgs,
       out_channels,
       n_parallel_imgs * out_h,
       out_w},
      out.options());
out_buf is then viewed as $\frac{N}{n} \times G \times \frac{C_o}{G} \times nH_o \times W_o$, and weight_c as $G \times \frac{C_o}{G} \times \frac{C_i}{G} \times k_h \times k_w$.
  // Separate channels into convolution groups
  out_buf = out_buf.view(
      {out_buf.size(0),
       n_weight_grps,
       out_buf.size(1) / n_weight_grps,
       out_buf.size(2),
       out_buf.size(3)});
  weight_c = weight_c.view(
      {n_weight_grps,
       weight_c.size(0) / n_weight_grps,
       weight_c.size(1),
       weight_c.size(2),
       weight_c.size(3)});
columns is a 2-D matrix of shape $C_i k_h k_w \times nH_o W_o$ that stages the input for every kernel position. deformable_im2col is called in a loop, handling n_parallel_imgs images per iteration: since the column buffer replicates the input $k_h k_w$ times, its memory footprint is much larger than the input itself, so the batch can only be processed block by block.
  // Sample points and perform convolution
  auto columns = at::zeros(
      {in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w},
      input_c.options());
  for (int b = 0; b < batch_sz / n_parallel_imgs; b++) {
    deformable_im2col(
        input_c[b],
        offset_c[b],
        mask_c[b],
        in_channels,
        in_h,
        in_w,
        weight_h,
        weight_w,
        pad_h,
        pad_w,
        stride_h,
        stride_w,
        dilation_h,
        dilation_w,
        out_h,
        out_w,
        n_parallel_imgs,
        n_offset_grps,
        use_mask,
        columns);
columns is viewed as $G \times \frac{C_i}{G} k_h k_w \times nH_o W_o$. out_buf[b][g] has shape $\frac{C_o}{G} \times nH_o \times W_o$; torch.Tensor.flatten collapses its last two dimensions, giving $\frac{C_o}{G} \times nH_o W_o$. torch.Tensor.addmm_ is the in-place version of torch.Tensor.addmm and performs the matrix multiplication: weight_c[g], of shape $\frac{C_o}{G} \times \frac{C_i}{G} \times k_h \times k_w$, is flattened into the 2-D matrix $\frac{C_o}{G} \times \frac{C_i}{G} k_h k_w$, and columns[g] is the 2-D matrix $\frac{C_i}{G} k_h k_w \times nH_o W_o$. Each loop iteration therefore computes
$$\mathbb{R}^{\frac{C_o}{G} \times nH_o W_o} = \mathbb{R}^{\frac{C_o}{G} \times \frac{C_i}{G} k_h k_w} \times \mathbb{R}^{\frac{C_i}{G} k_h k_w \times nH_o W_o}.$$
This is how DeformConv2d supports grouped convolution: the rows of columns are split into n_weight_grps segments, and each segment produces one slice of the output. Afterwards columns is viewed back to its original shape for the next iteration (a small PyTorch sketch of this step follows the loop).
    columns = columns.view(
        {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
    for (int g = 0; g < n_weight_grps; g++) {
      out_buf[b][g] = out_buf[b][g]
                          .flatten(1)
                          .addmm_(weight_c[g].flatten(1), columns[g])
                          .view_as(out_buf[b][g]);
    }
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
  }
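The addmm_ call above is just a per-group matrix multiply-accumulate. A PyTorch sketch with made-up sizes ($G = 2$ groups, $n = 2$ images per block):

import torch

G, Co, Ci, kh, kw, n, Ho, Wo = 2, 8, 16, 3, 3, 2, 8, 8
out_buf_bg = torch.zeros(Co // G, n * Ho, Wo)           # out_buf[b][g]
weight_g = torch.rand(Co // G, Ci // G, kh, kw)         # weight_c[g]
columns_g = torch.rand(Ci // G * kh * kw, n * Ho * Wo)  # columns[g]

# (Co/G, n*Ho*Wo) = (Co/G, Ci/G*kh*kw) @ (Ci/G*kh*kw, n*Ho*Wo)
res = out_buf_bg.flatten(1).addmm_(weight_g.flatten(1), columns_g)
print(res.view_as(out_buf_bg).shape)  # torch.Size([4, 16, 8])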
out_buf is viewed as $\frac{N}{n} \times C_o \times n \times H_o \times W_o$ and transposed to $\frac{N}{n} \times n \times C_o \times H_o \times W_o$. (One might ask why torch::transpose is not used here; the two are equivalent, transpose_ simply updates the tensor's own view in place.) Finally bias_c is added to give the convolution result.
  out_buf = out_buf.view(
      {batch_sz / n_parallel_imgs,
       out_channels,
       n_parallel_imgs,
       out_h,
       out_w});
  out_buf.transpose_(1, 2);
  out.copy_(out_buf);
  out = out.view({batch_sz, out_channels, out_h, out_w});
  return out + bias_c.view({1, out_channels, 1, 1});
}
void deformable_im2col(
    const at::Tensor& input,
    const at::Tensor& data_offset,
    const at::Tensor& data_mask,
    int n_in_channels,
    int height,
    int width,
    int weight_h,
    int weight_w,
    int pad_h,
    int pad_w,
    int stride_h,
    int stride_w,
    int dilation_h,
    int dilation_w,
    int out_h,
    int out_w,
    int parallel_imgs,
    int deformable_group,
    bool use_mask,
    at::Tensor data_col) {
  at::cuda::CUDAGuard device_guard(input.get_device());
num_kernels is the total number of work items. Since columns is a $C_i k_h k_w \times nH_o W_o$ matrix, each thread fills $k_h k_w$ of its elements (one input channel at one output location of one image). GET_THREADS returns the number of threads per block depending on whether the device is HIP or CUDA, and GET_BLOCKS returns the block count (a sketch follows the next fragment).
  const int64_t num_kernels =
      (int64_t)n_in_channels * out_h * out_w * parallel_imgs;
  const unsigned int threads = GET_THREADS();
  const unsigned int blocks = GET_BLOCKS(threads, num_kernels);
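A rough Python mirror of this launch-size logic; the thread count and the grid cap below are assumptions in the spirit of common CUDA helpers, not verified torchvision constants:

def ceil_div(n: int, d: int) -> int:
    return (n + d - 1) // d

def get_blocks(threads: int, num_kernels: int, max_grid: int = 65535) -> int:
    # one thread per work item, capped at the grid limit;
    # the kernel's grid-stride loop covers any remainder
    return min(max_grid, ceil_div(num_kernels, threads))

print(get_blocks(1024, 4 * 16 * 8 * 8))  # 4 blocks for Ci=16, 8x8 output, n=4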
If the number of elements the kernel must index exceeds the int32 range, 64-bit indexing is used. Note that the columns buffer holds $k_h k_w$ times num_kernels elements, so the second check below is the one that typically triggers.
  // Checks if we should use 64bits indexing
  // https://github.com/pytorch/vision/issues/4269
  bool use_64bits_indexing = false;
  // Checks if num_kernels or columns numel larger than 2 ** 31
  use_64bits_indexing |= num_kernels > (1 << 31);
  use_64bits_indexing |=
      ((int64_t)n_in_channels * weight_h * weight_w * parallel_imgs * out_h *
           out_w >
       (1 << 31));
deformable_im2col_kernel is then launched to do the actual work.
  if (use_64bits_indexing) {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        input.scalar_type(), "deformable_im2col", ([&] {
          deformable_im2col_kernel<scalar_t, int64_t><<<blocks, threads>>>(
              num_kernels,
              input.data_ptr<scalar_t>(),
              data_offset.data_ptr<scalar_t>(),
              data_mask.data_ptr<scalar_t>(),
              height,
              width,
              weight_h,
              weight_w,
              pad_h,
              pad_w,
              stride_h,
              stride_w,
              dilation_h,
              dilation_w,
              parallel_imgs,
              n_in_channels,
              deformable_group,
              out_h,
              out_w,
              use_mask,
              data_col.data_ptr<scalar_t>());
        }));
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        input.scalar_type(), "deformable_im2col", ([&] {
          deformable_im2col_kernel<scalar_t, int><<<blocks, threads>>>(
              num_kernels,
              input.data_ptr<scalar_t>(),
              data_offset.data_ptr<scalar_t>(),
              data_mask.data_ptr<scalar_t>(),
              height,
              width,
              weight_h,
              weight_w,
              pad_h,
              pad_w,
              stride_h,
              stride_w,
              dilation_h,
              dilation_w,
              parallel_imgs,
              n_in_channels,
              deformable_group,
              out_h,
              out_w,
              use_mask,
              data_col.data_ptr<scalar_t>());
        }));
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
The kernel has two template parameters: the element type and the index type used for addressing. Across adjacent threads, accesses to the output, offset and mask are contiguous and therefore coalesce; accesses to the input are not contiguous and may conflict.
template <typename scalar_t, typename index_t>
__global__ void deformable_im2col_kernel(
    index_t n,
    const scalar_t* input_ptr,
    const scalar_t* offset_ptr,
    const scalar_t* mask_ptr,
    index_t height,
    index_t width,
    index_t weight_h,
    index_t weight_w,
    index_t pad_h,
    index_t pad_w,
    index_t stride_h,
    index_t stride_w,
    index_t dilation_h,
    index_t dilation_w,
    index_t batch_sz,
    index_t n_in_channels,
    index_t n_offset_grps,
    index_t out_h,
    index_t out_w,
    bool use_mask,
    scalar_t* columns_ptr) {
CUDA_1D_KERNEL_LOOP_T is the grid-stride loop macro, parameterized by the index type. columns is a $C_i k_h k_w \times nH_o W_o$ matrix. From the flat index the kernel recovers in_c, out_b, out_y and out_x; out_c $=$ in_c $\cdot k_h k_w$ is the starting row within the $C_i k_h k_w$ dimension. (The same decoding is replayed in Python after the kernel.)
  CUDA_1D_KERNEL_LOOP_T(index, n, index_t) {
    const index_t out_x = index % out_w;
    const index_t out_y = (index / out_w) % out_h;
    const index_t out_b = (index / (out_w * out_h)) % batch_sz;
    const index_t in_c = index / (out_w * out_h * batch_sz);
    const index_t out_c = in_c * weight_h * weight_w;
    index_t c_per_offset_grp = n_in_channels / n_offset_grps;
    const index_t grp_idx = in_c / c_per_offset_grp;
    columns_ptr +=
        (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) +
         out_y * out_w + out_x);
The input is laid out as $n \times C_i \times H_i \times W_i$; offset has shape $N \times 2 G_\Delta k_h k_w \times H_o \times W_o$ and mask $N \times G_\Delta k_h k_w \times H_o \times W_o$, and the base pointers are advanced to the current image, channel and group accordingly.
    input_ptr +=
        (out_b * (n_in_channels * height * width) + in_c * (height * width));
    offset_ptr += (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w *
        out_h * out_w;
    if (use_mask) {
      mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w *
          out_h * out_w;
    }
Each thread processes $k_h k_w$ elements. The deform_conv2d sampling formula is
$$y(\mathbf{p}) = \sum_{k=1}^{K} w_k \cdot x(\mathbf{p} + \mathbf{p}_k + \Delta\mathbf{p}_k) \cdot \Delta m_k,$$
i.e. the code's y and x are the sampling coordinates after applying the offsets, $\mathbf{p}_0 + \mathbf{p}_k + \Delta\mathbf{p}_k$. Because that location is fractional, bilinear_interpolate synthesizes a regular input sample.
    for (int i = 0; i < weight_h; ++i) {
      for (int j = 0; j < weight_w; ++j) {
        const index_t mask_idx = i * weight_w + j;
        const index_t offset_idx = 2 * mask_idx;
        scalar_t mask_value = 1;
        if (use_mask) {
          mask_value =
              mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x];
        }
        const scalar_t offset_h =
            offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x];
        const scalar_t offset_w = offset_ptr
            [(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x];
        const scalar_t y =
            (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
        const scalar_t x =
            (out_x * stride_w - pad_w) + j * dilation_w + offset_w;
        *columns_ptr =
            mask_value * bilinear_interpolate(input_ptr, height, width, y, x);
        columns_ptr += batch_sz * out_h * out_w;
      }
    }
  }
}
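To see what a single thread does, the index decoding at the top of the kernel can be replayed in plain Python (sizes are made up):

def decode(index, out_w, out_h, batch_sz):
    # inverse of index = ((in_c * batch_sz + out_b) * out_h + out_y) * out_w + out_x
    out_x = index % out_w
    out_y = (index // out_w) % out_h
    out_b = (index // (out_w * out_h)) % batch_sz
    in_c = index // (out_w * out_h * batch_sz)
    return in_c, out_b, out_y, out_x

# thread 1234 with out_w=8, out_h=8, batch_sz=2 (i.e. n_parallel_imgs)
print(decode(1234, 8, 8, 2))  # (9, 1, 2, 2)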
bilinear_interpolate computes
$$x(\mathbf{p}) = \sum_{\mathbf{q}} G(\mathbf{q}, \mathbf{p}) \cdot x(\mathbf{q}),$$
where $G$ is the bilinear kernel over the four integer neighbors $\mathbf{q}$ of $\mathbf{p}$. Sampling positions outside the image return 0. (A Python version of this routine closes the section.)
template <typename scalar_t, typename index_t>
__device__ scalar_t bilinear_interpolate(
    const scalar_t* in,
    index_t height,
    index_t width,
    scalar_t h,
    scalar_t w) {
  if (h <= -1 || height <= h || w <= -1 || width <= w) {
    return 0;
  }
$$\begin{aligned} f(x, y_1) &= \frac{x_2 - x}{x_2 - x_1} f(Q_{11}) + \frac{x - x_1}{x_2 - x_1} f(Q_{21}) \\ f(x, y_2) &= \frac{x_2 - x}{x_2 - x_1} f(Q_{12}) + \frac{x - x_1}{x_2 - x_1} f(Q_{22}) \end{aligned}$$
$$\begin{aligned} f(x, y) &= \frac{y_2 - y}{y_2 - y_1} f(x, y_1) + \frac{y - y_1}{y_2 - y_1} f(x, y_2) \\ &= \frac{1}{(x_2 - x_1)(y_2 - y_1)} \big( f(Q_{11})(x_2 - x)(y_2 - y) + f(Q_{21})(x - x_1)(y_2 - y) \\ &\qquad + f(Q_{12})(x_2 - x)(y - y_1) + f(Q_{22})(x - x_1)(y - y_1) \big) \\ &= w_{11} f(Q_{11}) + w_{12} f(Q_{12}) + w_{21} f(Q_{21}) + w_{22} f(Q_{22}) \end{aligned}$$
First get the indices of the four surrounding pixels:
  index_t h_low = floor(h);
  index_t w_low = floor(w);
  index_t h_high = h_low + 1;
  index_t w_high = w_low + 1;
Then compute the interpolation coefficients:
  scalar_t lh = h - h_low;
  scalar_t lw = w - w_low;
  scalar_t hh = 1 - lh, hw = 1 - lw;
Fetch the four neighboring values, substituting 0 for out-of-range neighbors:
  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
    v1 = in[h_low * width + w_low];
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
    v2 = in[h_low * width + w_high];
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
    v3 = in[h_high * width + w_low];
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
    v4 = in[h_high * width + w_high];
Finally blend them bilinearly:
  scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}
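The bilinear sampler translates almost line for line into Python; a sketch for a single-channel, row-major image, reproducing the CUDA version's border handling:

import math

def bilinear_interpolate(img, height, width, h, w):
    # img is a flat row-major buffer of size height * width
    if h <= -1 or h >= height or w <= -1 or w >= width:
        return 0.0
    h_low, w_low = math.floor(h), math.floor(w)
    h_high, w_high = h_low + 1, w_low + 1
    lh, lw = h - h_low, w - w_low          # fractional parts
    hh, hw = 1 - lh, 1 - lw
    v1 = img[h_low * width + w_low] if h_low >= 0 and w_low >= 0 else 0.0
    v2 = img[h_low * width + w_high] if h_low >= 0 and w_high <= width - 1 else 0.0
    v3 = img[h_high * width + w_low] if h_high <= height - 1 and w_low >= 0 else 0.0
    v4 = img[h_high * width + w_high] if h_high <= height - 1 and w_high <= width - 1 else 0.0
    return hh * hw * v1 + hh * lw * v2 + lh * hw * v3 + lh * lw * v4

img = [0.0, 1.0,
       2.0, 3.0]                                  # 2x2 image
print(bilinear_interpolate(img, 2, 2, 0.5, 0.5))  # 1.5, the center average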