keras fit函数的class_weight参数

参考 https://datascience.stackexchange.com/questions/13490/how-to-set-class-weights-for-imbalanced-classes-in-keras

注意class_weight 一定是一个字典, 不然虽然不会报错,但是是没有效果的,loss完全不会变。

sklearn 直接提供了一个函数来计算类别权重:

# 计算类别权重
my_class_weight = class_weight.compute_class_weight('balanced'
                                      ,np.unique(train_Y)
                                      ,train_Y).tolist()
# 需要转成字典
class_weight_dict = dict(zip([x for x in np.unique(train_Y)], my_class_weight))

但是返回值也需要转换为字典。

你可能感兴趣的:(Keras)