Adapted from: mmdetection源码阅读: cuda拓展之focal loss (Zhihu).
Readers are assumed to be roughly familiar with CUDA programming and with the focal loss itself; this article does not cover either in detail.
The figures come from the reference above (they will be removed on request); the red text in them is annotation I added.
The description above may still feel vague, so the walkthrough below goes through the code step by step.
Note: the code in this section is taken from the reference article above.
step1: the Python-side call (source code lives in the mmdetection toolkit)
import torch
import torch.nn as nn
from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
# sigmoid_focal_loss is SigmoidFocalLossFunction.apply, which dispatches to its forward (see step2)

class FocalLoss(nn.Module):
    # (simplified from mmdet's FocalLoss; only the CUDA path is shown,
    #  mmdet's pure-PyTorch fallback for the CPU path is omitted)
    def __init__(self, gamma=2.0, alpha=0.25):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self,
                pred,    # tensor(num_total_anchors, num_classes)
                target,  # tensor(num_total_anchors, )
                ):
        if torch.cuda.is_available() and pred.is_cuda:
            loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(),
                                       self.gamma, self.alpha, None, 'none')
            # contiguous() guarantees contiguous storage, so the CUDA kernel can
            # safely index the flat memory buffer
        return loss
step2: using autograd.Function to define the forward and backward computation (why: we implemented an operator in CUDA, so after wrapping it we have to tell autograd how to run its forward and backward passes)
from torch.autograd import Function
from torch.autograd.function import once_differentiable
# ext_module is mmcv's compiled C++/CUDA extension (loaded through mmcv's
# ext_loader); it exposes the functions bound with pybind11 in step3

class SigmoidFocalLossFunction(Function):

    @staticmethod
    def forward(ctx,
                input,
                target,
                gamma=2.0,
                alpha=0.25,
                weight=None,
                reduction='mean'):
        # stash reduction_dict, gamma, alpha and reduction on ctx so backward can read them
        ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
        assert reduction in ctx.reduction_dict.keys()
        ctx.gamma = float(gamma)
        ctx.alpha = float(alpha)
        ctx.reduction = ctx.reduction_dict[reduction]

        output = input.new_zeros(input.size())  # allocate the output buffer
        # call the real CUDA extension (in the full mmcv implementation a None
        # weight is replaced by an empty tensor before this call)
        ext_module.sigmoid_focal_loss_forward(
            input, target, weight, output, gamma=ctx.gamma, alpha=ctx.alpha)
        if ctx.reduction == ctx.reduction_dict['mean']:
            output = output.sum() / input.size(0)
        elif ctx.reduction == ctx.reduction_dict['sum']:
            output = output.sum()
        ctx.save_for_backward(input, target, weight)  # saved for the backward pass
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        input, target, weight = ctx.saved_tensors
        grad_input = input.new_zeros(input.size())
        # call the real CUDA extension
        ext_module.sigmoid_focal_loss_backward(
            input, target, weight, grad_input, gamma=ctx.gamma, alpha=ctx.alpha)
        grad_input *= grad_output
        if ctx.reduction == ctx.reduction_dict['mean']:
            grad_input /= input.size(0)
        return grad_input, None, None, None, None, None

# sigmoid_focal_loss is SigmoidFocalLossFunction.apply; calling apply runs
# forward and records the op on the autograd graph
sigmoid_focal_loss = SigmoidFocalLossFunction.apply
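A minimal usage sketch of the functional interface (my own example, assuming mmcv is built with its CUDA ops and a GPU is available; the shapes are made up). Since Function.apply only accepts positional arguments, gamma, alpha, weight and reduction are passed positionally:

import torch
from mmcv.ops import sigmoid_focal_loss

pred = torch.randn(8, 80, device='cuda', requires_grad=True)   # (num_total_anchors, num_classes)
target = torch.randint(0, 80, (8,), device='cuda')             # (num_total_anchors, )

# positional args: input, target, gamma, alpha, weight, reduction
loss = sigmoid_focal_loss(pred, target, 2.0, 0.25, None, 'mean')
loss.backward()                     # routes to SigmoidFocalLossFunction.backward above
print(loss.item(), pred.grad.shape)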
step3: use pybind11 to bind the C++/CUDA side to Python, so that the Python-side ext_module can call the CUDA functions. The C++ function sigmoid_focal_loss_forward in turn calls SigmoidFocalLossForwardCUDAKernelLauncher (see step4); the binding looks like this:

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("sigmoid_focal_loss_forward", &sigmoid_focal_loss_forward,
        "sigmoid_focal_loss_forward ", py::arg("input"), py::arg("target"),
        py::arg("weight"), py::arg("output"), py::arg("gamma"),
        py::arg("alpha"));
}
// sigmoid_focal_loss_backward is bound in the same way and is omitted here
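For context, this is roughly how a pybind11 binding like the one above gets compiled into an importable Python module with PyTorch's extension tooling. This is a generic sketch with hypothetical file and module names (focal_loss_ext, focal_loss.cpp, focal_loss_cuda.cu), not mmcv's actual build script; mmcv compiles all of its operators into a single _ext module through its own setup.py:

# setup.py -- generic sketch; the names below are placeholders, not mmcv's build config
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='focal_loss_ext',
    ext_modules=[
        CUDAExtension(
            name='focal_loss_ext',                        # import focal_loss_ext
            sources=['focal_loss.cpp', 'focal_loss_cuda.cu'],
        ),
    ],
    cmdclass={'build_ext': BuildExtension},
)

After building (for example with pip install -e .), import focal_loss_ext exposes every function registered with m.def, which is exactly the role ext_module plays in step2.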
step4: SigmoidFocalLossForwardCUDAKernelLauncher sets up the CUDA thread organization and a few other basics depending on the inputs (think of it as preparing the environment), then launches the actual CUDA kernel that does the core computation, i.e. the focal loss itself.
.data_ptr<scalar_t>() is a templated member function that returns the address of the first element of the tensor's contiguous storage, cast to a scalar_t* pointer. Keep in mind that a tensor is physically stored as a one-dimensional contiguous array: a tensor of shape (B, C, H, W) occupies a contiguous array of B*C*H*W elements, and data_ptr() returns the base address of that array (a short illustration follows after the launcher code below).

void SigmoidFocalLossForwardCUDAKernelLauncher(Tensor input, Tensor target,
                                               Tensor weight, Tensor output,
                                               const float gamma,
                                               const float alpha) {
  // input:  tensor(num_total_anchors, num_classes)
  // output: tensor(num_total_anchors, num_classes)
  // target: tensor(num_total_anchors, ); values 0 ~ num_classes-1 give the
  //         positive class, the value num_classes marks negative and ignored samples
  int output_size = output.numel();  // = num_total_anchors * num_classes
  int num_classes = input.size(1);
  AT_ASSERTM(target.max().item<int64_t>() <= (int64_t)num_classes,
             "target label should smaller or equal than num classes");
  at::cuda::CUDAGuard device_guard(input.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "sigmoid_focal_loss_forward_cuda_kernel", [&] {
        // sigmoid_focal_loss_forward_cuda_kernel is a templated CUDA kernel;
        // the dispatch macro instantiates it for the tensor's actual scalar type
        sigmoid_focal_loss_forward_cuda_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, input.data_ptr<scalar_t>(),
                target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
                output.data_ptr<scalar_t>(), gamma, alpha, num_classes);
      });
  AT_CUDA_CHECK(cudaGetLastError());
}
// the backward launcher is analogous and omitted
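A quick illustration of that flat-memory view (plain PyTorch, my own snippet, not part of the mmcv source): the element a kernel reaches through a linear index is the same element you get from 2-D indexing, which is also how the kernel in step5 decomposes its thread index.

import torch

num_anchors, num_classes = 4, 3
x = torch.randn(num_anchors, num_classes)

index = 7                                            # a linear index into the flat buffer
n, c = index // num_classes, index % num_classes     # same decomposition as the kernel
assert x.flatten()[index] == x[n, c]

print(x.data_ptr())      # base address of the contiguous storage
assert x.is_contiguous() # which is why .contiguous() is enforced in step1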
// Below is the thread organization used for the launch; consult any CUDA
// programming reference if this is unfamiliar
#define THREADS_PER_BLOCK 512  // commonly 128, 512 or 1024 threads per block

inline int GET_BLOCKS(const int N) {
  int optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
  int max_block_num = 65000;
  return min(optimal_block_num, max_block_num);
}
// GET_BLOCKS gives the grid's x dimension, chosen so that blockDim.x * gridDim.x
// covers N; when N is so large that the block count is capped at max_block_num,
// the grid-stride loop in the kernel (CUDA_1D_KERNEL_LOOP below) picks up the
// remaining elements
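To make the launch arithmetic concrete, here is a small Python mock of GET_BLOCKS (illustrative only, the sizes are made up):

# Python mock of the launch arithmetic above
THREADS_PER_BLOCK = 512
MAX_BLOCK_NUM = 65000

def get_blocks(n):
    return min((n + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK, MAX_BLOCK_NUM)

n = 100_000 * 80                          # e.g. num_total_anchors * num_classes
blocks = get_blocks(n)                    # 15625
print(blocks * THREADS_PER_BLOCK >= n)    # True: one thread per element, the loop body runs once

print(get_blocks(50_000_000))             # 65000: capped, so each thread handles
                                          # several elements via CUDA_1D_KERNEL_LOOP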
step5: the kernel implementation
#define CUDA_1D_KERNEL_LOOP(i, n)                            \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
       i += blockDim.x * gridDim.x)
// blockDim.x * gridDim.x is the total number of threads launched, so this is a
// grid-stride loop: each thread handles every (total-thread-count)-th element

template <typename T>
__global__ void sigmoid_focal_loss_forward_cuda_kernel(
    const int nthreads, const T* input, const int64_t* target, const T* weight,
    T* output, const T gamma, const T alpha, const int num_classes) {
  // nthreads is output_size = num_total_anchors * num_classes
  // each const T* is the base address of a tensor's contiguous storage:
  // input has num_total_anchors * num_classes elements, target has num_total_anchors
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // index = blockIdx.x * blockDim.x + threadIdx.x, i.e. the global thread index;
    // it addresses one element of the (num_total_anchors, num_classes) tensor
    int n = index / num_classes;  // n: the anchor this element belongs to
    int c = index % num_classes;  // c: the class this element belongs to

    int64_t t = target[n];  // the target label of anchor n
    T flag_p = (t == c);    // 1 if this (anchor, class) entry is a positive
    T flag_n = (t != c);    // 1 if it is a negative

    // p = sigmoid(x) = 1. / (1. + expf(-x))
    T p = (T)1. / ((T)1. + expf(-input[index]));

    // (1 - p)^gamma * log(p): the positive-sample term
    T term_p = pow(((T)1. - p), gamma) * log(max(p, (T)FLT_MIN));
    // p^gamma * log(1 - p): the negative-sample term
    T term_n = pow(p, gamma) * log(max((T)1. - p, (T)FLT_MIN));

    output[index] = (T)0.;  // the result is written into the output tensor
    output[index] += -flag_p * alpha * term_p;
    output[index] += -flag_n * ((T)1. - alpha) * term_n;
    if (weight != NULL) {
      output[index] *= weight[t];
    }
  }
}
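As a sanity check, the same per-element computation can be written in a few lines of plain PyTorch (my own sketch, not mmcv code; it assumes the 'none' reduction, no weight, and the convention that target == num_classes marks a negative sample). Its output should match the kernel's output element for element, up to floating-point tolerance:

import torch
import torch.nn.functional as F

def sigmoid_focal_loss_ref(input, target, gamma=2.0, alpha=0.25):
    # input:  (num_total_anchors, num_classes) logits
    # target: (num_total_anchors,) int64, values in [0, num_classes], where
    #         num_classes marks negative / ignored samples
    num_classes = input.size(1)
    p = input.sigmoid()
    # one-hot with an extra column for the "negative" label, then drop that column;
    # flag_p plays the role of flag_p in the kernel, (1 - flag_p) of flag_n
    flag_p = F.one_hot(target, num_classes + 1)[:, :num_classes].type_as(input)
    eps = torch.finfo(input.dtype).tiny            # stands in for FLT_MIN
    term_p = (1 - p).pow(gamma) * p.clamp_min(eps).log()
    term_n = p.pow(gamma) * (1 - p).clamp_min(eps).log()
    return -flag_p * alpha * term_p - (1 - flag_p) * (1 - alpha) * term_n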
// The backward kernel follows the same pattern
template <typename T>
__global__ void sigmoid_focal_loss_backward_cuda_kernel(
    const int nthreads, const T* input, const int64_t* target, const T* weight,
    T* grad_input, const T gamma, const T alpha, const int num_classes) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int n = index / num_classes;
    int c = index % num_classes;

    int64_t t = target[n];
    T flag_p = (t == c);
    T flag_n = (t != c);

    // p = sigmoid(x) = 1. / (1. + exp(-x))
    T p = (T)1. / ((T)1. + exp(-input[index]));

    // (1 - p)^gamma * (1 - p - gamma*p*log(p)): d/dx of the positive-sample term
    T term_p = pow(((T)1. - p), gamma) *
               ((T)1. - p - (gamma * p * log(max(p, (T)FLT_MIN))));
    // p^gamma * (gamma * (1 - p) * log(1 - p) - p): d/dx of the negative-sample term
    T term_n = pow(p, gamma) *
               (gamma * ((T)1. - p) * log(max((T)1. - p, (T)FLT_MIN)) - p);

    grad_input[index] = (T)0.;
    grad_input[index] += -flag_p * alpha * term_p;
    grad_input[index] += -flag_n * ((T)1. - alpha) * term_n;
    if (weight != NULL) {
      grad_input[index] *= weight[t];
    }
  }
}
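For reference, the two closed-form terms in the backward kernel follow from the chain rule. With p = sigmoid(x) we have dp/dx = p * (1 - p), so differentiating the positive-sample part of the forward loss with respect to the logit x gives

d/dx [ -alpha * (1 - p)^gamma * log(p) ]
  = -alpha * [ -gamma * (1 - p)^(gamma-1) * log(p) + (1 - p)^gamma / p ] * p * (1 - p)
  = -alpha * (1 - p)^gamma * (1 - p - gamma * p * log(p))

which is exactly -flag_p * alpha * term_p above. The negative-sample part works out the same way:

d/dx [ -(1 - alpha) * p^gamma * log(1 - p) ]
  = -(1 - alpha) * p^gamma * (gamma * (1 - p) * log(1 - p) - p)

matching -flag_n * (1 - alpha) * term_n. The incoming grad_output is multiplied on afterwards in SigmoidFocalLossFunction.backward (grad_input *= grad_output).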