focal loss理解

(1) focal loss

<1> focal loss的函数形式为:

                   (1)

 其中,zk为softmax的输入,f(zk)为softmax的输出,-log(f(zk))为softmaxloss, alpha和gamma为focal loss超参。

<2> focal loss对其输入zj求导:

根据链式法则有:(2)

下面分别对(2)式中的两项求导:

第一项:

                     focal loss理解_第1张图片(3)

第二项:

                    focal loss理解_第2张图片(4)

(3)(4)项合并有(加上(1)式中的负号):

 (5)

(5)式中黑色框为softmax的loss反传公式。

 (2) focal_loss_layer.cpp中的Forward_cpu函数:

 1 Dtype loss = 0;
 2   for (int i = 0; i < outer_num_; ++i) {
 3     for (int j = 0; j < inner_num_; j++) {
 4       const int label_value = static_cast<int>(label[i * inner_num_ + j]);
 5       if (has_ignore_label_ && label_value == ignore_label_) {
 6         continue;
 7       }
 8       DCHECK_GE(label_value, 0);
 9       DCHECK_LT(label_value, channels);
10       const int index = i * dim + label_value * inner_num_ + j;
//power_prob_ blob就是公式(1)中的第一项,log_prob_是第二项 14 loss -= power_prob_data[index] * log_prob_data[index]; 15 ++count; 16 } 17 }

 

 (3) focal_loss_layer.cpp中的Backward_cpu函数:

 1     for (int i = 0; i < outer_num_; ++i) {
 2       for (int j = 0; j < inner_num_; ++j) {
 3         // label
 4         const int label_value = static_cast<int>(label[i * inner_num_ + j]);
 5         
 6         // ignore label
 7         if (has_ignore_label_ && label_value == ignore_label_) {
 8           for (int c = 0; c < channels; ++c) {
 9             bottom_diff[i * dim + c * inner_num_ + j] = 0;
10           }
11           continue;
12         }
//对于每个样本的channel,ind_i是label索引对应的channel中元素的值
13 int ind_i = i * dim + label_value * inner_num_ + j;
//grad就是(5)式中的第一项,prob_ blob为softmax的输出,log_prob_为对其求log 14 Dtype grad = 0 - gamma_ * (power_prob_data[ind_i] / std::max(1 - prob_data[ind_i], eps)) 15 * log_prob_data[ind_i] * prob_data[ind_i] 16 + power_prob_data[ind_i]; 17 for (int c = 0; c < channels; ++c) { 18 int ind_j = i * dim + c * inner_num_ + j; 19 if(c == label_value) { 20 CHECK_EQ(ind_i, ind_j);
//对应公式(5) 21 bottom_diff[ind_j] = grad * (prob_data[ind_i] - 1); 22 } else { 23 bottom_diff[ind_j] = grad * prob_data[ind_j]; 24 } 25 } 26 ++count; 27 } 28 }

 

你可能感兴趣的:(focal loss理解)