代码来自:https://github.com/eriklindernoren/PyTorch-YOLOv3
最近在学习yolov3,学长布置了任务,让用yolov3来训练百度识虫的数据集,写这篇文章总结一下最近的学习心得,写一写我在学习过程中踩的坑
数据集地址:https://aistudio.baidu.com/aistudio/datasetdetail/19638
(该数据集使用的是VOC存储方式,由2183张jpeg格式图片构成,其中训练集1693张,验证集245,测试集245张,共有六种昆虫,label数据为xml格式)
既然要训练模型就要从train.py以及datasets.py入手,首先看train.py里面的设置部分
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=10, help="number of epochs")
parser.add_argument("--batch_size", type=int, default=1, help="size of each image batch")
parser.add_argument("--gradient_accumulations", type=int, default=2, help="number of gradient accums before step")
parser.add_argument("--model_def", type=str, default="config/yolov3.cfg", help="path to model definition file")
parser.add_argument("--data_config", type=str, default="config/my_data.data", help="path to data config file")
parser.add_argument("--pretrained_weights", type=str, help="if specified starts from checkpoint model")
parser.add_argument("--n_cpu", type=int, default=4, help="number of cpu threads to use during batch generation")
parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model weights")
parser.add_argument("--evaluation_interval", type=int, default=1, help="interval evaluations on validation set")
parser.add_argument("--compute_map", default=False, help="if True computes mAP every tenth batch")
parser.add_argument("--multiscale_training", default=True, help="allow for multi-scale training")
opt = parser.parse_args()
print(opt)
上面的代码,一部分是对网络参数的调整,这部分我们先不管,要训练我们拿到的数据集,先看看data_config,可见该指向的是一个.data文件,让我们看看这个文件里面是啥(如下图)
上图中,
classes是数据集一共有几类目标,该昆虫数据集共有六种昆虫,所以我们将其设为6
train和valid分别指向一个txt文件(这两个文件一会详细介绍)
names里存的是每一种昆虫的学名即label的名称
我们来详细地看下两个txt文件,这两个txt文件里面保存的是每一张图片的路径,一行一张,这个文件数据集里是没有的,我们要自己读取文件路径然后将其写入txt文件中,下图为train.txt的截取,valid处理与之相同
图片数据的准备已经完成,下面我们开始着手准备图片的label,该数据集的label为xml格式文件,放置在train/annotation/xmls文件夹下,每一个xml文件对应一张图片,用下面的代码把每个xml中我们需要的信息拿出来,放到一个txt文件内,每个txt文件名字与图片的名称一一对应
import xml.etree.ElementTree as ET
import os
VOC_CLASSES = ('leconte','boerner','linnaeus','armandi','coleoptera','acuminatus') # label的值分别为:0~5
VOC_ROOT = "insects/val/annotations/xmls"
class VOCAnnotationTransform(object):
"""
把VOC的annotation中bbox的坐标转化为归一化的值;
将类别转化为用索引来表示的字典形式;
Args:
class_to_ind: (dict)类别的索引字典
keep_difficult: 是否保留difficult=1的物体
"""
def __init__(self, class_to_ind=None, keep_difficult=False):
self.class_to_ind = class_to_ind or dict(
zip(VOC_CLASSES, range(len(VOC_CLASSES))))
self.keep_difficult = keep_difficult
def __call__(self, target, width, height):
res = []
for obj in target.iter('object'):
# 判断difficult
difficult = int(obj.find('difficult').text) == 1
if not self.keep_difficult and difficult:
continue
# 读取xml中所需的信息
name = obj.find('name').text.lower().strip()
bbox = obj.find('bndbox')
# bbox的表示
pts = ['xmin', 'ymin', 'xmax', 'ymax']
bndbox = []
label_idx = self.class_to_ind[name]
bndbox.append(label_idx)
for i, pt in enumerate(pts):
cur_pt = int(bbox.find(pt).text) - 1
# print(cur_pt)
# 归一化,x/w, y/h
cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
bndbox.append(cur_pt)
res += [bndbox]
return res
if __name__ == "__main__":
vocan = VOCAnnotationTransform()
list_path = os.listdir(VOC_ROOT)
for path in list_path:
target = ET.parse("./insects/val/annotations/xmls/"+path).getroot()
for i in target.iter('size'):
width = int(i.find('width').text.lower().strip())
height = int(i.find('height').text.lower().strip())
res = vocan(target, width, height)
name = path[:-4]
the_path = 'insects/val/annotations/labels/'+name+'.txt'
with open(the_path,'w') as txt:
for ii in res:
for yy in ii:
txt.write(str(yy))
txt.write(' ')
txt.write('\r\n')
处理完成后在train中的opt里设置好文件路径即可开始训练
dataset里面label_files要设置好,因为程序是靠图片文件的路径来寻找label文件的路径的,所以要把txt文件和图片放在同一个文件夹里,然后对dataset中的path.replace()方法做一些修改
class ListDataset(Dataset):
def __init__(self, list_path, img_size=416, augment=True, multiscale=True, normalized_labels=True):
with open(list_path, "r") as file:
self.img_files = file.readlines()
self.label_files = [
path.replace(".jpeg", ".txt") #通过这个代码把图片文件的地址改为txt文件的地址,原文代码里面是jpg,这里要改成jpeg
for path in self.img_files
]