关于sklearn下class_weight参数的一点源码阅读与测试

版权声明:欢迎转载,请注明原出处 https://blog.csdn.net/go_og/article/details/81281387

一直没有很在意过sklearn的class_weight的这个参数的具体作用细节,只大致了解是是用于处理样本不均衡。后来在简书上阅读svm松弛变量的一些推导的时候,看到样本不均衡的带来的问题时候,想更深层次的看一下class_weight的具体作用方式,

svm松弛变量的简书链接:https://www.jianshu.com/p/8a499171baa9

该文中的样本不均衡的描述:

“样本偏斜是指数据集中正负类样本数量不均,比如正类样本有10000个,负类样本只有100个,这就可能使得超平面被“推向”负类(因为负类数量少,分布得不够广),影响结果的准确性。” 

随后翻开sklearn LR的源码:

我们以分类作为说明重点

在输入参数class_weight=‘balanced’的时候:

 
  1. # compute the class weights for the entire dataset y

  2. if class_weight == "balanced":

  3. class_weight = compute_class_weight(class_weight,

  4. np.arange(len(self.classes_)),

  5. y)

  6. class_weight = dict(enumerate(class_weight))

进一步阅读 compute_class_weight这个函数:

 
  1. elif class_weight == 'balanced':

  2. # Find the weight of each class as present in y.

  3. le = LabelEncoder()

  4. y_ind = le.fit_transform(y)

  5. if not all(np.in1d(classes, le.classes_)):

  6. raise ValueError("classes should have valid labels that are in y")

  7.  
  8. recip_freq = len(y) / (len(le.classes_) *

  9. np.bincount(y_ind).astype(np.float64))

  10. weight = recip_freq[le.transform(classes)]

compute_class_weight这个函数的作用是对于输入的样本,平衡类别之间的权重,下面写段测试代码测试这个函数:

 
  1. # coding:utf-8

  2.  
  3. from sklearn.utils.class_weight import compute_class_weight

  4.  
  5. class_weight = 'balanced'

  6. label = [0] * 9 + [1]*1 + [2, 2]

  7. print(label) # [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2]

  8. classes=[0, 1, 2]

  9. weight = compute_class_weight(class_weight, classes, label)

  10. print(weight) #[ 0.44444444 4. 2. ]

  11. print(.44444444 * 9) # 3.99999996

  12. print(4 * 1) # 4

  13. print(2 * 2) # 4

如上图所示,可以看到这个函数把样本的平衡后的权重乘积为4,每个类别均如此。

关于class_weight与sample_weight在损失函数上的具体计算方式:

 
  1. sample_weight *= class_weight_[le.fit_transform(y_bin)] # sample_weight 与 class_weight相乘

  2.  
  3. # Logistic loss is the negative of the log of the logistic function.

  4. out = -np.sum(sample_weight * log_logistic(yz)) + .5 * alpha * np.dot(w, w)

上述可以看出对于每个样本,计算的损失函数乘上对应的sample_weight来计算最终的损失。这样计算而来的损失函数不会因为样本不平衡而被“推向”样本量偏少的类别中。

class_weight以及sample_weight并没有进行不平衡数据的处理,比如,上下采样。详细参见SMOTE EasyEnsemble等。

--------------------- 本文来自 摸摸小松鼠宝宝 的CSDN 博客 ,全文地址请点击:https://blog.csdn.net/go_og/article/details/81281387?utm_source=copy 

你可能感兴趣的:(sklearn,class_weight,机器学习,sklearn)