【YOLO】yolov5的训练策略3 -- 图像加权image weights

目录

  • 一、什么是训练图像加权
  • 二、代码

一、什么是训练图像加权

根据样本种类分布使用图像调用频率不同的方法解决。
1、读取训练样本中的GT,保存为一个列表;
2、计算训练样本列表中不同类别个数,然后给每个类别按相应目标框数的倒数赋值,数目越多的种类权重越小,形成按种类的分布直方图;
3、对于训练数据列表,训练时按照类别权重筛选出每类的图像作为训练数据。使用random.choice(population, weights=None, *, cum_weights=None, k=1)更改训练图像索引,可达到样本均衡的效果。

二、代码

计算类别的权重:

def labels_to_class_weights(labels, nc=80):
    """
    计算类别权重
    Get class weights (inverse frequency) from training labels
    输出入: 
        labels -- 真实的标签列表 [class xywh]
        nc -- 类别数
    """
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    classes = labels[:, 0].astype(int)  # labels = [class xywh]
    # 直方图统计
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels)  - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start
    
    # 统计的倒数
    weights[weights == 0] = 1   # replace empty bins with 1
    weights = 1 / weights       # number of targets per class
    # 归一化为0~1,就是权重
    weights /= weights.sum()    # normalize
    return torch.from_numpy(weights).float()

计算图片权重

def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    """
    计算图片的权重
    # Produces image weights based on class_weights and image contents
    # Usage: index = random.choices(range(n), weights=image_weights, k=1)  # weighted image sample
    """
    class_counts = []
    for x in labels:
        class_bin = np.bincount(x[:, 0].astype(int), minlength=nc)
        class_counts.append(class_bin)
    class_counts = np.array(class_counts)
    # 类别权重 * 每张图片中类别数量, 然后在sum
    return (class_weights.reshape(1, nc) * class_counts).sum(1)

更新数据集中的图片索引, 使得权重高的图片出现概率大:

# epoch ------------------------------------------------------------------
for epoch in range(self.start_epoch, self.epochs):
    # Update image weights 在每次迭代前,更新数据集的indices,使得权重高的图片出现概率大
    if self.opt["image_weights"]:
        cw = class_weights * (1 - maps) ** 2 / self.nc  # class weights
        iw = labels_to_image_weights(dataset.labels, nc=self.nc, class_weights=cw)      # image weights
        dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)     # rand weighted idx

点赞、收藏、关注、评论

你可能感兴趣的:(深度学习从入门到实战,YOLO)