yolov5只训练数据集中的某几个类别

文章目录

  • 前言
  • 一、直接修改数据集标签
  • 二、修改加载labels的代码
    • 1.train
    • 2.create_dataloader
    • 3.LoadImagesAndLabels
    • 4.cache_labels
    • 5.verify_image_label
  • 总结


前言

提示:在训练网络过程中,我们找到的公开数据集可能有很多分类,但是我们的检测任务又不需要那么多,或者说是对自己的训练集做一个取舍:

例如:一个训练集有猫和狗,但是我不想训练猫了,只想训练狗,所以就只加载狗的标签。


基本思路:只训练某几类标签的话,那就需要修改dataset中的labels,本文提供两种思路

一、直接修改数据集标签

通过直接修改数据集标签(*.txt)来删去某种类别的数据。

这种方法很直接,但是也意味着你多了一个整个数据集文件,虽然内存不大,但是感觉比较呆。

二、修改加载labels的代码

数据集labels在加载进dataloader过程中本身就有某些处理过程(如检验是否为空),我们可以在上面加些筛选条件就可以做到过滤效果。

1.train

在train.py文件下找到加载数据集的代码,如:

# Trainloader
train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
                                          hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK,
                                          workers=workers, image_weights=opt.image_weights, quad=opt.quad,
                                          prefix=colorstr('train: '))

然后我们进入create_dataloader继续跟踪

2.create_dataloader

找到加载数据集LoadImagesAndLabels:

dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augment images
                                      hyp=hyp,  # augmentation hyperparameters
                                      rect=rect,  # rectangular training
                                      cache_images=cache,
                                      single_cls=single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      image_weights=image_weights,
                                      prefix=prefix)

3.LoadImagesAndLabels

其中,下面这一段代码是加载cache缓存文件,这里不细说,就把它简单看成数据集文件。如果cache已存在,就直接加载,不存在才创建(所以我们要在数据集的文件夹下把cache文件删掉!把cache文件删掉!!把cache文件删掉!!!),我们需要进入创建部分cache_labels
在这里插入图片描述

try:
   cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict
   assert cache['version'] == self.cache_version  # same version
   assert cache['hash'] == get_hash(self.label_files + self.img_files)  # same hash
except:
    cache, exists = self.cache_labels(cache_path, prefix), False  # cache

4.cache_labels

这个部分就是处理数据集的信息统计(如是否为空等),其中一段遍历整个数据集的代码

pbar = tqdm(pool.imap(verify_image_label, zip(self.img_files, self.label_files, repeat(prefix))),
                        desc=desc, total=len(self.img_files))

这段代码含义大致就是将img_files, label_files, prefix打包丢进verify_image_label函数中处理后返回

5.verify_image_label

这段函数就是我们的最终目标了,这里面有加载图片,标签的功能,还可以进行一定筛选,我们就从这里修改。找到加载labels的代码:

withopen(lb_file) as f:
	l = [x.split() for x in f.read().strip().splitlines() if len(x)]

这段代码就是将labels的内容加载进列表l中,如这里有个label文件
yolov5只训练数据集中的某几个类别_第1张图片
有类别6、7,通过代码加载进去就是
yolov5只训练数据集中的某几个类别_第2张图片
list L 中有两个list,代表两个目标,每个list第一位就是类别。这个时候效果就很明显了,如果我们不想要类别6,我们只需要修改成

withopen(lb_file) as f:
	l = [x.split() for x in f.read().strip().splitlines() if len(x) and x[0]!='6']

就行了,最后效果为

思路就是这样,还有些其他的修改方法根据自己的需要再操作,内核就是对list的处理而已,基本功。

总结

上面都是我在做项目过程中遇到的问题,而且在csdn上没找到详细的解答,于是自己动手解决并分享。

你可能感兴趣的:(python,深度学习,pytorch)