class FurthestPointSampling(Function):
    @staticmethod
    def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
        """
        Uses iterative furthest point sampling to select a set of npoint features that have the largest
        minimum distance
        :param ctx: autograd context (nothing is saved; backward returns no gradients)
        :param xyz: (B, N, 3) contiguous float32 CUDA tensor, where N > npoint
        :param npoint: int, number of features in the sampled set
        :return: (B, npoint) int32 tensor containing the indices of the sampled points
        """
        assert xyz.is_contiguous()
        # The extension dereferences xyz's data pointer on the GPU.
        assert xyz.is_cuda

        B, N, _ = xyz.size()
        # The extension launches on the current device's stream, and
        # torch.cuda.*Tensor would allocate on the current device -- which is
        # not necessarily the device xyz lives on. Pin both to xyz's device.
        with torch.cuda.device(xyz.device):
            output = torch.empty((B, npoint), dtype=torch.int32, device=xyz.device)
            # Running minimum squared distance from every point to the selected
            # set; starts effectively "infinitely" far away.
            temp = torch.full((B, N), 1e10, dtype=torch.float32, device=xyz.device)
            pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)

        # Sampled indices carry no gradient.
        ctx.mark_non_differentiable(output)
        return output

    @staticmethod
    def backward(ctx, a=None):
        # Sampling is not differentiable: no gradient for xyz or npoint.
        return None, None


furthest_point_sample = FurthestPointSampling.apply
int furthest_point_sampling_wrapper(int b, int n, int m,
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {
const float *points = points_tensor.data<float>();
float *temp = temp_tensor.data<float>();
int *idx = idx_tensor.data<int>();
cudaStream_t stream = THCState_getCurrentStream(state);
furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream);
return 1;
// Launches furthest_point_sampling_kernel with one block per batch element.
//   dataset: (B, N, 3) input coordinates
//   temp:    (B, N)    running minimum squared distance per point (read/write)
//   idxs:    (B, M)    output indices of the sampled points
// Launch errors are printed to stderr.
void furthest_point_sampling_kernel_launcher(int b, int n, int m,
    const float *dataset, float *temp, int *idxs, cudaStream_t stream) {
    // Nothing to sample: a zero-sized grid is an invalid launch configuration,
    // and the kernel returns immediately when m <= 0 anyway.
    if (b <= 0 || m <= 0) {
        return;
    }

    cudaError_t err;
    unsigned int n_threads = opt_n_threads(n);  // compute the thread count, at most 1024

    // The kernel's template block_size must equal blockDim.x: it is used as the
    // loop stride, the shared-memory array size and the reduction width. Each
    // case therefore launches exactly as many threads as its template argument.
    switch (n_threads) {
        case 1024:
            furthest_point_sampling_kernel<1024><<<b, 1024, 0, stream>>>(b, n, m, dataset, temp, idxs);
            break;
        case 512:
            furthest_point_sampling_kernel<512><<<b, 512, 0, stream>>>(b, n, m, dataset, temp, idxs);
            break;
        case 256:
            furthest_point_sampling_kernel<256><<<b, 256, 0, stream>>>(b, n, m, dataset, temp, idxs);
            break;
        case 128:
            furthest_point_sampling_kernel<128><<<b, 128, 0, stream>>>(b, n, m, dataset, temp, idxs);
            break;
        case 64:
            furthest_point_sampling_kernel<64><<<b, 64, 0, stream>>>(b, n, m, dataset, temp, idxs);
            break;
        case 32:
            furthest_point_sampling_kernel<32><<<b, 32, 0, stream>>>(b, n, m, dataset, temp, idxs);
            break;
        case 16:
            furthest_point_sampling_kernel<16><<<b, 16, 0, stream>>>(b, n, m, dataset, temp, idxs);
            break;
        case 8:
            furthest_point_sampling_kernel<8><<<b, 8, 0, stream>>>(b, n, m, dataset, temp, idxs);
            break;
        case 4:
            furthest_point_sampling_kernel<4><<<b, 4, 0, stream>>>(b, n, m, dataset, temp, idxs);
            break;
        case 2:
            furthest_point_sampling_kernel<2><<<b, 2, 0, stream>>>(b, n, m, dataset, temp, idxs);
            break;
        case 1:
            furthest_point_sampling_kernel<1><<<b, 1, 0, stream>>>(b, n, m, dataset, temp, idxs);
            break;
        default:
            // Fallback for an unexpected (non-power-of-two) thread count: launch
            // exactly 512 threads so blockDim.x matches the template argument.
            furthest_point_sampling_kernel<512><<<b, 512, 0, stream>>>(b, n, m, dataset, temp, idxs);
            break;
    }

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    }
}
// Iterative furthest point sampling. block_size is the template argument chosen
// by the launcher (<1024>, <512>, ...); the kernel requires blockDim.x == block_size
// because it is the loop stride, the shared-array size and the reduction width.
// Grid: one block per batch element (blockIdx.x selects the batch).
template <unsigned int block_size>
__global__ void furthest_point_sampling_kernel(int b, int n, int m,
    const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
    // dataset: (B, N, 3)
    // temp:    (B, N)
    // output:
    //      idxs: (B, M)

    // Uniform across the block, so every thread still reaches the barriers below.
    if (m <= 0) return;

    // Two shared arrays: dists holds the furthest distance found by each thread,
    // dists_i holds the index of that point.
    __shared__ float dists[block_size];
    __shared__ int dists_i[block_size];

    // The number of blocks equals the batch size and each block handles one batch
    // element, so advance the pointers to this block's batch.
    int batch_index = blockIdx.x;
    dataset += batch_index * n * 3;
    temp += batch_index * n;
    idxs += batch_index * m;

    int tid = threadIdx.x;
    const int stride = block_size;

    // FPS always selects point 0 first; thread 0 records it.
    int old = 0;
    if (threadIdx.x == 0)
        idxs[0] = old;

    for (int j = 1; j < m; j++) {
        int besti = 0;
        float best = -1;
        // Coordinates of the point selected in the previous iteration.
        float x1 = dataset[old * 3 + 0];
        float y1 = dataset[old * 3 + 1];
        float z1 = dataset[old * 3 + 2];

        // Each thread handles points tid, tid + block_size, ... (about n / block_size points).
        for (int k = tid; k < n; k += stride) {
            float x2 = dataset[k * 3 + 0];
            float y2 = dataset[k * 3 + 1];
            float z2 = dataset[k * 3 + 2];
            float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
            // temp[k] tracks the minimum squared distance from point k to all points selected so far.
            float d2 = min(d, temp[k]);
            temp[k] = d2;
            besti = d2 > best ? k : besti;
            best = d2 > best ? d2 : best;
        }
        dists[tid] = best;
        dists_i[tid] = besti;
        // Every thread's candidate must be visible before the reduction reads it.
        __syncthreads();

        // Tree reduction over dists / dists_i, leaving the overall maximum in slot 0.
        // The barrier sits outside the divergent `if` so all threads reach it.
#pragma unroll
        for (int s = block_size / 2; s > 0; s >>= 1) {
            if (tid < s)
                __update(dists, dists_i, tid, tid + s);
            __syncthreads();
        }

        // The point with the largest distance is this iteration's selection.
        old = dists_i[0];
        if (tid == 0)
            idxs[j] = old;
        // Prevent the next iteration's writes to dists / dists_i from racing
        // with other threads' read of dists_i[0] above.
        __syncthreads();
    }
}