提示:在训练网络过程中,我们找到的公开数据集可能有很多分类,但是我们的检测任务又不需要那么多,或者说是对自己的训练集做一个取舍:
例如:一个训练集有猫和狗,但是我不想训练猫了,只想训练狗,所以就只加载狗的标签。
基本思路:只训练某几类标签的话,那就需要修改dataset中的labels,本文提供两种思路
通过直接修改数据集标签(*.txt)来删去某种类别的数据。
这种方法很直接,但是也意味着你多了一个整个数据集文件,虽然内存不大,但是感觉比较呆。
数据集labels在加载进dataloader过程中本身就有某些处理过程(如检验是否为空),我们可以在上面加些筛选条件就可以做到过滤效果。
在train.py文件下找到加载数据集的代码,如:
# Trainloader
train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK,
workers=workers, image_weights=opt.image_weights, quad=opt.quad,
prefix=colorstr('train: '))
然后我们进入create_dataloader继续跟踪
找到加载数据集LoadImagesAndLabels:
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
augment=augment, # augment images
hyp=hyp, # augmentation hyperparameters
rect=rect, # rectangular training
cache_images=cache,
single_cls=single_cls,
stride=int(stride),
pad=pad,
image_weights=image_weights,
prefix=prefix)
其中,下面这一段代码是加载cache缓存文件,这里不细说,就把它简单看成数据集文件。如果cache已存在,就直接加载,不存在才创建(所以我们要在数据集的文件夹下把cache文件删掉!把cache文件删掉!!把cache文件删掉!!!),我们需要进入创建部分cache_labels
try:
cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
assert cache['version'] == self.cache_version # same version
assert cache['hash'] == get_hash(self.label_files + self.img_files) # same hash
except:
cache, exists = self.cache_labels(cache_path, prefix), False # cache
这个部分就是处理数据集的信息统计(如是否为空等),其中一段遍历整个数据集的代码
pbar = tqdm(pool.imap(verify_image_label, zip(self.img_files, self.label_files, repeat(prefix))),
desc=desc, total=len(self.img_files))
这段代码含义大致就是将img_files, label_files, prefix打包丢进verify_image_label函数中处理后返回
这段函数就是我们的最终目标了,这里面有加载图片,标签的功能,还可以进行一定筛选,我们就从这里修改。找到加载labels的代码:
withopen(lb_file) as f:
l = [x.split() for x in f.read().strip().splitlines() if len(x)]
这段代码就是将labels的内容加载进列表l中,如这里有个label文件
有类别6、7,通过代码加载进去就是
list L 中有两个list,代表两个目标,每个list第一位就是类别。这个时候效果就很明显了,如果我们不想要类别6,我们只需要修改成
withopen(lb_file) as f:
l = [x.split() for x in f.read().strip().splitlines() if len(x) and x[0]!='6']
就行了,最后效果为
思路就是这样,还有些其他的修改方法根据自己的需要再操作,内核就是对list的处理而已,基本功。
上面都是我在做项目过程中遇到的问题,而且在csdn上没找到详细的解答,于是自己动手解决并分享。