【炼丹随记】样本不均衡时,class_weight 的计算

语义分割时,如果样本比例失衡,就需要设置 class_weight来平衡损失,那么该如何计算呢?
直观的想到是,先获取图片的每个类别的像素点的个数之间的比例,然后用1去除以。比如:
class1 : class2 : class3 = 100 : 10 : 1,那么 weight1 : weight2 : weight3 = 1:10:100。但这个比值偏差太大,放到loss中训练并不能得到一个好的结果。

OK的操作:
一个比较不错的计算方式:获取每个类别的像素的总值,数值大的类别应该有偏小的权重(注意这里是偏小,小一点点,不能说是小的离谱)。所以需要一个递减且 当自变量增长很大时因变量依然递减 且递减缓慢的函数,于是使用 1/log(x),在变量 x 非常大时,1/log(x) 符合递减且递减的很缓慢的特征。这样就可以得到合适的权重用于神经网络的训练。
方法来源于【https://github.com/openseg-group/OCNet.pytorch/issues/14】。

这个过程我们需要知道的信息:

  • 类别数为 class_num
  • 每个类别的的像素点的个数为 pixel_count:分别统计训练集中所有不同类别的像素个数
    计算的代码为:
import numpy as np

"""========权重的计算============"""
def get_weight(class_num, pixel_count):

   W = 1 / np.log(pixel_count)

   #除以np.sum(W),是将权重归一化,让每个类别的权重相加为1。
   #乘以class_num,是为了让权重中的每个类别的权重值接近1,使网络在一个正常的水平上进行训练。
   W = class_num * W/np.sum(W)  
   
   return W

if __name__ == "__main__":

   """========测试============"""
   base = 5000
   pixel_count = np.array([100, 10, 1])*base  
   W = get_weight(3, pixel_count)

   print(W)  # [0.79925325 0.96934441 1.23140234]

此时,class1 : class2 : class3 = 100 : 10 : 1,它们的权重设置为weight1 : weight2 : weight3 = 0.79925325 : 0.96934441 : 1.23140234。

那我们对图片完整的操作呢?
这里同时统计图片的均值、方差、权重。需要注意:

  • opencv读取出来的是 BGR,注意神经网络中的均值方差是 BGR or RGB ?
  • 我们准备数据时,如果图片尺寸大小不一,计算权重时,应将图片处理成同尺寸,然后再进行计算。我这里图像的尺寸不同但比例相同,所以使用了简单的resize。resize时要使用最近邻插值,不要给标签带来新的数值
  • 读取label.png时,确保是正确读取单通道图片,并且像素正确


代码如下:

from random import shuffle
import numpy as np
import os
import cv2

def get_weight(class_num, pixel_count):
   W = 1 / np.log(pixel_count)
   W = class_num * W / np.sum(W)
   return W

def get_MeanStdWeight(class_num=12, size=(640,360)):

   image_path = "../datasets/data/train/"
   label_path = "../datasets/label/train/"
   
   namelist = os.listdir(image_path)
   """========如果提供的是txt文本,保存的训练集中的namelist=============="""
   # file_name = "../datasets/train.txt"
   # with open(file_name,"r") as f:
   #     namelist = f.readlines()
   #     namelist = [file[:-1].split(",") for file in namelist]
   """==============================================================="""

   MEAN = []
   STD = []
   pixel_count = np.zeros((class_num,1))

   for i in range(len(namelist)):
       print(i, os.path.join(image_path, namelist[i]))

       image = cv2.imread(os.path.join(image_path, namelist[i]))[:,:,::-1]
       image = cv2.resize(image, size, interpolation=cv2.INTER_NEAREST)
       print(image.shape)

       mean = np.mean(image, axis=(0,1))
       std = np.std(image, axis=(0,1))
       MEAN.append(mean)
       STD.append(std)

       label = cv2.imread(os.path.join(label_path, namelist[i]), 0)
       label = cv2.resize(label, size, cv2.INTER_LINEAR)

       label_uni = np.unique(label)
       for m in label_uni:
           pixel_count[m] += np.sum(label == m)


   MEAN = np.mean(MEAN, axis=0) / 255.0
   STD = np.mean(STD, axis=0) / 255.0

   weight = get_weight(class_num, pixel_count.T)
   print(MEAN)
   print(STD)
   print(weight)

   return MEAN, STD, weight

你可能感兴趣的:(cnn知识,深度学习)