pytorch——yolov3学习之模型训练

pytorch——yolov3学习之模型训练

代码来自:https://github.com/eriklindernoren/PyTorch-YOLOv3

最近在学习yolov3,学长布置了任务,让用yolov3来训练百度识虫的数据集,写这篇文章总结一下最近的学习心得,写一写我在学习过程中踩的坑

数据集地址:https://aistudio.baidu.com/aistudio/datasetdetail/19638
(该数据集使用的是VOC存储方式,由2183张jpeg格式图片构成,其中训练集1693张,验证集245,测试集245张,共有六种昆虫,label数据为xml格式)
pytorch——yolov3学习之模型训练_第1张图片

  • 参数设置

既然要训练模型就要从train.py以及datasets.py入手,首先看train.py里面的设置部分

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=10, help="number of epochs")
    parser.add_argument("--batch_size", type=int, default=1, help="size of each image batch")
    parser.add_argument("--gradient_accumulations", type=int, default=2, help="number of gradient accums before step")
    parser.add_argument("--model_def", type=str, default="config/yolov3.cfg", help="path to model definition file")
    parser.add_argument("--data_config", type=str, default="config/my_data.data", help="path to data config file")
    parser.add_argument("--pretrained_weights", type=str, help="if specified starts from checkpoint model")
    parser.add_argument("--n_cpu", type=int, default=4, help="number of cpu threads to use during batch generation")
    parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
    parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model weights")
    parser.add_argument("--evaluation_interval", type=int, default=1, help="interval evaluations on validation set")
    parser.add_argument("--compute_map", default=False, help="if True computes mAP every tenth batch")
    parser.add_argument("--multiscale_training", default=True, help="allow for multi-scale training")
    opt = parser.parse_args()
    print(opt)

上面的代码,一部分是对网络参数的调整,这部分我们先不管,要训练我们拿到的数据集,先看看data_config,可见该指向的是一个.data文件,让我们看看这个文件里面是啥(如下图)

在这里插入图片描述
上图中,
classes是数据集一共有几类目标,该昆虫数据集共有六种昆虫,所以我们将其设为6
train和valid分别指向一个txt文件(这两个文件一会详细介绍)
names里存的是每一种昆虫的学名即label的名称

  • txt文件

我们来详细地看下两个txt文件,这两个txt文件里面保存的是每一张图片的路径,一行一张,这个文件数据集里是没有的,我们要自己读取文件路径然后将其写入txt文件中,下图为train.txt的截取,valid处理与之相同

pytorch——yolov3学习之模型训练_第2张图片
图片数据的准备已经完成,下面我们开始着手准备图片的label,该数据集的label为xml格式文件,放置在train/annotation/xmls文件夹下,每一个xml文件对应一张图片,用下面的代码把每个xml中我们需要的信息拿出来,放到一个txt文件内,每个txt文件名字与图片的名称一一对应

import xml.etree.ElementTree as ET
import os
VOC_CLASSES = ('leconte','boerner','linnaeus','armandi','coleoptera','acuminatus')   # label的值分别为:0~5

VOC_ROOT = "insects/val/annotations/xmls"


class VOCAnnotationTransform(object):
    """
    把VOC的annotation中bbox的坐标转化为归一化的值;
    将类别转化为用索引来表示的字典形式;
    Args:
        class_to_ind: (dict)类别的索引字典
        keep_difficult: 是否保留difficult=1的物体
    """
    def __init__(self, class_to_ind=None, keep_difficult=False):
        self.class_to_ind = class_to_ind or dict(
                zip(VOC_CLASSES, range(len(VOC_CLASSES))))
        self.keep_difficult = keep_difficult

    def __call__(self, target, width, height):
        res = []
        for obj in target.iter('object'):
            # 判断difficult
            difficult = int(obj.find('difficult').text) == 1
            if not self.keep_difficult and difficult:
                continue

            # 读取xml中所需的信息
            name = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')
            # bbox的表示
            pts = ['xmin', 'ymin', 'xmax', 'ymax']
            bndbox = []
            label_idx = self.class_to_ind[name]

            bndbox.append(label_idx)
            for i, pt in enumerate(pts):
                cur_pt = int(bbox.find(pt).text) - 1
                # print(cur_pt)
                # 归一化,x/w, y/h
                cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height

                bndbox.append(cur_pt)

            res += [bndbox]
        return res

if __name__ == "__main__":
    vocan = VOCAnnotationTransform()

    list_path = os.listdir(VOC_ROOT)
    for path in list_path:
        target = ET.parse("./insects/val/annotations/xmls/"+path).getroot()
        for i in target.iter('size'):
            width = int(i.find('width').text.lower().strip())
            height = int(i.find('height').text.lower().strip())
            res = vocan(target, width, height)
            name = path[:-4]
            the_path = 'insects/val/annotations/labels/'+name+'.txt'
            with open(the_path,'w') as txt:
                for ii in res:
                    for yy in ii:
                        txt.write(str(yy))
                        txt.write(' ')
                    txt.write('\r\n')

处理完成后在train中的opt里设置好文件路径即可开始训练

  • 坑1

dataset里面label_files要设置好,因为程序是靠图片文件的路径来寻找label文件的路径的,所以要把txt文件和图片放在同一个文件夹里,然后对dataset中的path.replace()方法做一些修改

class ListDataset(Dataset):
    def __init__(self, list_path, img_size=416, augment=True, multiscale=True, normalized_labels=True):
        with open(list_path, "r") as file:
            self.img_files = file.readlines()

        self.label_files = [
            path.replace(".jpeg", ".txt")  #通过这个代码把图片文件的地址改为txt文件的地址,原文代码里面是jpg,这里要改成jpeg
            for path in self.img_files
        ]

你可能感兴趣的:(python,深度学习,机器学习,pytorch)