卷积计算是深度学习模型的常见算子,在3D项目中,比如点云分割,由于点云数据是稀疏的,使用常规的卷积计算,将会加大卷积计算时间,不利于模型推理加速。由此SECOND网络提出了稀疏卷积的概念。
稀疏卷积的主要理念就是由正常的全部数据进行卷积运算,优化了为只计算有效的输入点的卷积结果。稀疏卷积的思路网上已经有很多简明扼要的文章,比如知乎的这一篇就很清晰,本文就是根据这一篇的思路实现的一个简单的稀疏卷积流程。建议先看一下先了解。
稀疏卷积的输入是有效输入点的索引坐标(哈希表)和对应的features值,大概流程是:
1,根据输入坐标得到输出点的索引坐标(哈希表)。每一个输入点,可以最多和kernel个点(比如3d卷积,kernel=3,则kernel点个数是3*3*3=27)相乘,得到kernel点个数的输出坐标。所以rulebook可以建立成kenel点个数的字典,每个kenel对应一个或多个输入点索引和输出点索引。
2,将输入点和对应kernel点进行矩阵乘法,得到卷积结果。
3,将同一个输出点坐标的卷积结果进行累加,根据输出点索引与真实坐标的关系,将结果还原到输出位置,即完成了稀疏卷积运算。
下面是实现的一个简单示例代码,其中稀疏卷积结果和普通卷积结果进行了对比,误差为0。
输入坐标和输出点坐标的映射关系,是遍历每个输出点的坐标,根据输出点坐标,kernel,stride可以得到相关的kernel点个输出点的坐标,如果在有点输出点列表里面,则表示这是一个有效输出点,更新输出点索引哈希表和rulebook字典。
这种方法的时间复杂度较大,需要遍历所有输出点,后面有优化方案,直接有公式计算输入点对应的输出点坐标。但是可以大概看一下整体流程。
# -*- coding: utf-8 -*-
import time
import torch
import torch.nn as nn
import itertools
import numpy as np
def generate_sparse_data(shape,
num_points,
num_channels,
integer=False,
data_range=(-1, 1),
with_dense=True,
dtype=np.float32):
dense_shape = shape
ndim = len(dense_shape)
num_points = np.array(num_points)
batch_size = len(num_points)
batch_indices = []
coors_total = np.stack(np.meshgrid(*[np.arange(0, s) for s in shape]),
axis=-1)
coors_total = coors_total.reshape(-1, ndim)
for i in range(batch_size):
np.random.shuffle(coors_total)
inds_total = coors_total[:num_points[i]]
inds_total = np.pad(inds_total, ((0, 0), (0, 1)),
mode="constant",
constant_values=i)
batch_indices.append(inds_total)
if integer:
sparse_data = np.random.randint(data_range[0],
data_range[1],
size=[num_points.sum(),
num_channels]).astype(dtype)
else:
sparse_data = np.random.uniform(data_range[0],
data_range[1],
size=[num_points.sum(),
num_channels]).astype(dtype)
res = {
"features": sparse_data.astype(dtype),
}
if with_dense:
dense_data = np.zeros([batch_size, num_channels, *dense_shape],
dtype=sparse_data.dtype)
start = 0
for i, inds in enumerate(batch_indices):
for j, ind in enumerate(inds):
dense_slice = (i, slice(None), *ind[:-1])
dense_data[dense_slice] = sparse_data[start + j]
start += len(inds)
res["features_dense"] = dense_data.astype(dtype)
batch_indices = np.concatenate(batch_indices, axis=0)
res["indices"] = batch_indices.astype(np.int32)
return res
def get_Pin2Pout_Rulebook_3d(n,ho, wo,do,ks,stride, in_indice):
'''
根据有效的输入点位置,得到有效的输出点位置,并建立kernel, in_idx, out_indx字典关系。
in_indice:有效点的坐标 [[hi,wi,ni],[hi1,wi1,ni1],...]
return:
offset, {k0:[[pin_idx, pout_idx],...], k2:[[pin_idx, pout_idx],...]}
pout_indice, same to in_indice
'''
offset = {i: [] for i in range(ks**3)}
pout_indice = []
out_count = 0
for b, i, j, d in itertools.product(range(n), range(ho), range(wo), range(do)):
flag = False
for kh, kw, kd in itertools.product(range(ks),range(ks),range(ks)):
if [stride*i + kh, stride*j + kw,stride*d+kd,b] in in_indice:
flag = True
offset[kh*ks*ks+kw*ks+kd].append(
[in_indice.index([ stride*i + kh, stride*j + kw,stride*d+kd,b]), out_count]) # [in_index,out_index]
if flag == True:
pout_indice.append([b, i, j,d])
out_count += 1
return offset, pout_indice
def get_output_3d(rulebook,in_data,weight_data,out_indice,out_data):
'''
遍历每一个kernel, 通过查找pin_idx和对应的kernel, 矩阵乘得到pout的值,并放回位置。
同一个pout结果累加
'''
for key in rulebook.keys():
cur_book=rulebook[key]
w_data=weight_data[key]
for i in range(len(cur_book)):
x=in_data[cur_book[i][0],:]
n,ho,wo,do=out_indice[cur_book[i][1]]
out_data[n,:,ho,wo,do]+=np.matmul(x,w_data)
return out_data
def test_conv3d(sparse_dict,ci,co,kernel,stride):
features=sparse_dict['features']
features_dense=sparse_dict['features_dense']
in_indices=sparse_dict['indices'] #
conv3d=nn.Conv3d(ci,co, kernel,stride=stride, bias=False)
weight = conv3d.weight.detach().numpy() # co,ci,kh,kw
weight = weight.reshape(co, ci, kernel ** 3).transpose(2, 1, 0)
ref_out=conv3d(torch.tensor(features_dense))
bs,co,ho,wo,do=ref_out.shape
spconv_out=np.zeros([bs,co,ho,wo,do])
rulebook,pout_indice=get_Pin2Pout_Rulebook_3d(bs,ho,wo,do,kernel,stride, in_indices.tolist())
spconv_out=get_output_3d(rulebook,features,weight,pout_indice,spconv_out)
dif=np.abs(ref_out.detach().numpy()-spconv_out)
print('max diff is:',round(np.max(dif),4))
print('sparse conv3d test over')
return spconv_out
if __name__ =="__main__":
shapes=(9,19,18) # conv3d:(h,w,d)
bs=1 #batch_size
ks=3 #kernel_size
stride=2
ci=7
co=32
num_points = [100] * bs # 100个有效点个数
sparse_dict=generate_sparse_data(shapes,num_points,ci)
test_conv3d(sparse_dict, ci, co, ks,stride)
备注
该示例代码默认无padding,可以任意定义输入shapes, 其中generate_sparse_data是spconv的github代码里面给产生的稀疏数据代码。