Training and testing torchvision's built-in Faster R-CNN model in PyTorch (Pascal VOC and COCO datasets)

Reference project: https://github.com/lpuglia/torchvision_voc
References:
[1] https://github.com/pytorch/vision/issues/1116
[2] https://pytorch.org/docs/stable/_modules/torchvision/models/detection/faster_rcnn.html
[3] https://pytorch.org/tutorials/beginner/data_loading_tutorial.html


Update 2020-08-16
Training the torchvision Faster R-CNN on custom datasets in COCO format is now supported; the code is hosted at https://github.com/ouening/torchvision-FasterRCNN. The modifications described below were made to support the Pascal VOC format; the repository has since been updated to also support the COCO format, and a PASCAL VOC to COCO conversion script is provided. The advantage of the COCO format is that pycocotools can be used for evaluation, which offers a richer set of metrics.

Project repository: https://github.com/ouening/MLPractice
Project file structure:
(screenshot: project file structure)
The original project only provides training code for the standard Pascal VOC and COCO datasets. To load custom datasets in Pascal VOC format, a few functions have to be added. The additions and other changes are:
① voc_eval.py: custom_voc_eval() (strictly redundant; the default voc_eval() would also work), _write_custom_voc_results_file() and _do_python_eval_custom_voc()
② engine.py: custom_voc_evaluate()
③ train.py: two new argument options, --train-data-path and --test-data-path, for setting the custom dataset paths:

  • parser.add_argument('--train-data-path', help='train dataset path for custom voc dataset')
  • parser.add_argument('--test-data-path', help='test dataset path for custom voc dataset')

④ voc_utils.py: class ConvertCustomVOCtoCOCO, get_custom_voc(), and class VOCCustomData

The two new classes, ConvertCustomVOCtoCOCO and VOCCustomData, are the key pieces for loading a custom dataset. They are based on the original project's ConvertVOCtoCOCO class and the official PyTorch VOCDetection class, and are described in detail in the sections below.

1. Dataset and file modifications

The training data contains 2 object classes (lost and normal); together with the implicit __background__ class this gives num_classes = 3.

1.1 Changes to voc_utils.py

1.1.1 Add the class VOCCustomData

class VOCCustomData(torchvision.datasets.vision.VisionDataset):
    """`Pascal VOC `_ Detection Dataset.

    Args:
        root (string): Root directory of the custom VOC Dataset which includes directories
            Annotations and JPEGImages

        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(self,
                 root,
                 transform=None,
                 target_transform=None,
                 transforms=None):
        super(VOCCustomData, self).__init__(root, transforms, transform, target_transform)
        self.root = root
        self._transforms = transforms

        voc_root = self.root
        self.image_dir = os.path.join(voc_root, 'JPEGImages')
        self.annotation_dir = os.path.join(voc_root, 'Annotations')

        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' Please verify the correct Dataset!')
        file_names = []

        for imgs in os.listdir(self.image_dir):
            file_names.append(os.path.splitext(imgs)[0])  # file name without extension

        images_file = pd.DataFrame(file_names, index=None)
        # save the image list: file names only, without extension or directory
        images_file.to_csv(voc_root+'/imagesetfile.txt', header=False, index=False)

        self.images = [os.path.join(self.image_dir, x + ".jpg") for x in file_names]
        self.annotations = [os.path.join(self.annotation_dir, x + ".xml") for x in file_names]
        assert (len(self.images) == len(self.annotations))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a dictionary of the XML tree.
        """
        img = Image.open(self.images[index]).convert('RGB')
        
        target = self.parse_voc_xml(
            ET.parse(self.annotations[index]).getroot())
        
        target = dict(image_id=index, annotations=target['annotation'])

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.images)

    def parse_voc_xml(self, node):
        voc_dict = {}
        children = list(node)
        if children:
            def_dic = collections.defaultdict(list)
            for dc in map(self.parse_voc_xml, children):
                for ind, v in dc.items():
                    def_dic[ind].append(v)
            voc_dict = {
                node.tag:
                    {ind: v[0] if len(v) == 1 else v
                     for ind, v in def_dic.items()}
            }
        if node.text:
            text = node.text.strip()
            if not children:
                voc_dict[node.tag] = text
        return voc_dict
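For orientation, parse_voc_xml returns a nested dict that mirrors the XML tree. For a hypothetical single-object annotation the result looks roughly like this; note that every leaf value is a string, and that object is a plain dict rather than a list when the image contains only one object, which is exactly the case ConvertCustomVOCtoCOCO below has to normalize:

{'annotation': {'folder': 'JPEGImages',
                'filename': '000123.jpg',
                'size': {'width': '500', 'height': '375', 'depth': '3'},
                'object': {'name': 'lost',
                           'difficult': '0',
                           'bndbox': {'xmin': '48', 'ymin': '240',
                                      'xmax': '195', 'ymax': '371'}}}}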

1.1.2 Add the class ConvertCustomVOCtoCOCO

class ConvertCustomVOCtoCOCO(object):

    CLASSES = (
        "__background__", "lost", "normal"
    )
    def __call__(self, image, target):
        # return image, target
        anno = target['annotations']
        filename = anno["filename"].split('.')[0]
        h, w = anno['size']['height'], anno['size']['width']
        boxes = []
        classes = []
        ishard = []
        objects = anno['object']
        if not isinstance(objects, list):
            objects = [objects]
        for obj in objects:
            bbox = obj['bndbox']
            bbox = [int(bbox[n]) - 1 for n in ['xmin', 'ymin', 'xmax', 'ymax']]
            boxes.append(bbox)
            classes.append(self.CLASSES.index(obj['name']))
            ishard.append(int(obj['difficult']))

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        classes = torch.as_tensor(classes)
        ishard = torch.as_tensor(ishard)

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        target["ishard"] = ishard
        target['name'] = torch.tensor([ord(i) for i in list(filename)], dtype=torch.int8) # encode the filename as int8 ASCII codes (decoded in custom_voc_evaluate)

        return image, target
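The last line deserves a note: the filename is encoded as an int8 tensor of ASCII codes because tensors pass cleanly through the default collate function and through utils.all_gather during distributed evaluation, whereas plain strings would not; it is decoded back in custom_voc_evaluate (section 1.3.1). A minimal round trip, assuming ASCII-only filenames (int8 would overflow on anything else):

import torch

name = '000123'  # hypothetical image stem
encoded = torch.tensor([ord(c) for c in name], dtype=torch.int8)
decoded = ''.join(chr(i) for i in encoded.tolist())
assert decoded == name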

def get_custom_voc(root, transforms):
    t = [ConvertCustomVOCtoCOCO()]

    if transforms is not None:
        t.append(transforms)
    transforms = T.Compose(t)

    dataset = VOCCustomData(root=root,transforms=transforms)

    return dataset
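A quick sanity check of the wrapper might look like this; the path is a placeholder, and get_transform is the helper already defined in train.py:

from train import get_transform
from voc_utils import get_custom_voc

dataset = get_custom_voc('/data/to/train', get_transform(train=False))
img, target = dataset[0]
print(img.shape)         # e.g. torch.Size([3, H, W]) after ToTensor
print(target['boxes'])   # float tensor [num_objects, 4], (xmin, ymin, xmax, ymax)
print(target['labels'])  # int tensor of indices into CLASSES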

1.2 Changes to voc_eval.py

1.2.1 Add the function custom_voc_eval()

def custom_voc_eval(classname,
             detpath,
             imagesetfile,
             annopath='',
             ovthresh=0.5,
             use_07_metric=False):
    """rec, prec, ap = voc_eval(detpath,
                                annopath,
                                imagesetfile,
                                classname,
                                [ovthresh],
                                [use_07_metric])

    Top level function that does the PASCAL VOC evaluation.

    detpath: Path to detections
        detpath.format(classname) should produce the detection results file.
    annopath: Path to annotations (the xml annotation files, normally under Annotations/)
        annopath.format(imagename) should be the xml annotations file.
    imagesetfile: Text file containing the list of images, one image name per line.
    classname: Category name (duh)
    [ovthresh]: Overlap threshold (default = 0.5)
    [use_07_metric]: Whether to use VOC07's 11 point AP computation
        (default False)
    """
    # assumes detections are in detpath.format(classname)
    # assumes annotations are in annopath.format(imagename)
    # assumes imagesetfile is a text file with each line an image name

    recs = {}
    # read list of images
    with open(imagesetfile, 'r') as f:
        lines = f.readlines()
        imagenames = [x.strip() for x in lines]

        # load annotations
        for i, imagename in enumerate(imagenames):
            recs[imagename] = parse_rec(annopath.format(imagename))

    # extract gt objects for this class
    class_recs = {}
    npos = 0
    for imagename in imagenames:
        R = [obj for obj in recs[imagename] if obj['name'] == classname]
        bbox = np.array([x['bbox'] for x in R])
        difficult = np.array([x['difficult'] for x in R]).astype(bool)  # np.bool is deprecated/removed in newer NumPy
        det = [False] * len(R)
        npos = npos + sum(~difficult)
        class_recs[imagename] = {'bbox': bbox,
                                 'difficult': difficult,
                                 'det': det}

    # read dets
    detfile = detpath.format(classname)
    with open(detfile, 'r') as f:
        lines = f.readlines()

    splitlines = [x.strip().split(' ') for x in lines]
    image_ids = [x[0] for x in splitlines]
    confidence = np.array([float(x[1]) for x in splitlines])
    BB = np.array([[float(z) for z in x[2:]] for x in splitlines])

    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)

    if BB.shape[0] > 0:
      # sort by confidence
      sorted_ind = np.argsort(-confidence)
      sorted_scores = np.sort(-confidence)
      BB = BB[sorted_ind, :]
      image_ids = [image_ids[x] for x in sorted_ind]

      # go down dets and mark TPs and FPs
      for d in range(nd):
        R = class_recs[image_ids[d]]
        bb = BB[d, :].astype(float)
        ovmax = -np.inf
        BBGT = R['bbox'].astype(float)

        if BBGT.size > 0:
            # compute overlaps
            # intersection
            ixmin = np.maximum(BBGT[:, 0], bb[0])
            iymin = np.maximum(BBGT[:, 1], bb[1])
            ixmax = np.minimum(BBGT[:, 2], bb[2])
            iymax = np.minimum(BBGT[:, 3], bb[3])
            iw = np.maximum(ixmax - ixmin + 1., 0.)
            ih = np.maximum(iymax - iymin + 1., 0.)
            inters = iw * ih

            # union
            uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
                 (BBGT[:, 2] - BBGT[:, 0] + 1.) *
                 (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)

            overlaps = inters / uni
            ovmax = np.max(overlaps)
            jmax = np.argmax(overlaps)

        if ovmax > ovthresh:
            if not R['difficult'][jmax]:
                if not R['det'][jmax]:
                    tp[d] = 1.
                    R['det'][jmax] = 1
                else:
                    fp[d] = 1.
        else:
            fp[d] = 1.

    # compute precision recall
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    rec = tp / float(npos)
    # avoid divide by zero in case the first detection matches a difficult
    # ground truth
    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    ap = voc_ap(rec, prec, use_07_metric)

    return rec, prec, ap
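custom_voc_eval delegates the final AP computation to voc_ap, which is not reproduced above. For reference, the de-facto standard implementation from the VOC devkit / py-faster-rcnn, which the project's voc_eval.py is expected to contain, looks like this:

def voc_ap(rec, prec, use_07_metric=False):
    # Compute AP from precision/recall arrays.
    if use_07_metric:
        # VOC07 11-point metric: average the max precision at recall >= t
        ap = 0.
        for t in np.arange(0., 1.1, 0.1):
            if np.sum(rec >= t) == 0:
                p = 0
            else:
                p = np.max(prec[rec >= t])
            ap = ap + p / 11.
    else:
        # "correct" AP: area under the precision-recall envelope
        mrec = np.concatenate(([0.], rec, [1.]))
        mpre = np.concatenate(([0.], prec, [0.]))
        # make the precision envelope monotonically decreasing
        for i in range(mpre.size - 1, 0, -1):
            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
        # indices where recall changes value
        i = np.where(mrec[1:] != mrec[:-1])[0]
        # sum of (delta recall) * precision
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap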

1.2.2 Add the function _write_custom_voc_results_file()

def _write_custom_voc_results_file(data_loader, all_boxes, image_index, root, classes, thresh=0.3):
    if os.path.exists('/tmp/results'):
        shutil.rmtree('/tmp/results')
    os.makedirs('/tmp/results')
    print('Writing results file', end='\r')

    os.makedirs("output", exist_ok=True)    # output directory for the rendered detection images
    # bounding-box colors, one per class
    colors = [(255,0,0),(0,255,0),(0,0,255)]

    for cls_ind, cls in enumerate(classes):
        # DistributedSampler happens to clone the inputs to make the task
        # lengths even among the nodes:
        # https://github.com/pytorch/pytorch/issues/22584
        # Boxes can be duplicated in the process since the same image can be
        # evaluated more than once; multiple boxes in the same location lower
        # the final mAP, so repeated image_index entries are discarded below
        # thanks to the sorting
        new_image_index, all_boxes[cls_ind] = zip(*sorted(zip(image_index,
                                 all_boxes[cls_ind]), key=lambda x: x[0]))
        if cls == '__background__':
            continue
        images_dir = data_loader.dataset.image_dir
        filename = '/tmp/results/det_test_{:s}.txt'.format(cls)
        

        with open(filename, 'wt') as f:
            prev_index = ''
            for im_ind, index in enumerate(new_image_index):
                # check for repeated input and discard
                if prev_index == index: continue
                prev_index = index

                # read the image with OpenCV
                img = cv2.imread(os.path.join(images_dir, index+'.jpg'))
                h, w, _ = img.shape

                dets = all_boxes[cls_ind][im_ind]
                if dets == []:
                    continue
                dets = dets[0]

                # the VOCdevkit expects 1-based indices
                for k in range(dets.shape[0]):
                    f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
                            format(index, dets[k, -1],
                                   dets[k, 0] + 1, dets[k, 1] + 1,
                                   dets[k, 2] + 1, dets[k, 3] + 1))
                    if dets[k, -1] < thresh:
                        continue
                    # print("\t+ Label: %s, Conf: %.5f" % (cls, dets[k, -1]))
                    # OpenCV drawing functions need integer pixel coordinates
                    x1, y1, x2, y2 = (int(v) for v in dets[k, :4])

                    color = colors[cls_ind]
                    thick = int((h + w) / 300)
                    cv2.rectangle(img, (x1, y1), (x2, y2), color, thick)
                    mess = '%s: %.3f' % (cls, dets[k, -1])
                    cv2.putText(img, mess, (x1, y1 - 12),
                                0, 1e-3 * h, color, thick // 3)
                
                # use a separate name to avoid shadowing the results-file path above
                out_name = index
                cv2.imwrite(f"output/output-{out_name}.png", img)
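Each det_test_<class>.txt written above follows the VOCdevkit layout, one detection per line, image_id score xmin ymin xmax ymax with 1-based pixel coordinates; the values below are purely illustrative:

000123 0.912 49.0 241.0 196.0 372.0
000456 0.431 11.0 23.0 88.0 96.0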

1.2.3 Add the function _do_python_eval_custom_voc()

def _do_python_eval_custom_voc(data_loader, use_07_metric=True):

    imagesetfile = os.path.join(data_loader.dataset.root,'imagesetfile.txt')
    annopath = os.path.join(data_loader.dataset.annotation_dir,'{:s}.xml')

    classes = data_loader.dataset._transforms.transforms[0].CLASSES
    aps = []
    fig = plt.figure()

    for cls in classes:
        if cls == '__background__':    
            continue    
        filename = '/tmp/results/det_test_{:s}.txt'.format(cls)    
        rec, prec, ap = custom_voc_eval(cls, filename, imagesetfile, annopath,
                            ovthresh=0.5, use_07_metric=use_07_metric)    
        print('+ Class {} - AP: {}'.format(cls, ap))
        plt.plot(rec, prec, label='{}'.format(cls))
        aps += [ap]
    plt.xlabel('recall')
    plt.ylabel('precision')
    plt.legend()
    plt.show()
    print('Mean AP = {:.4f}        '.format(np.mean(aps)))

1.3 Changes to engine.py

1.3.1 Add the function custom_voc_evaluate()

@torch.no_grad()
def custom_voc_evaluate(model, data_loader, device):
    n_threads = torch.get_num_threads()
    torch.set_num_threads(1)
    cpu_device = torch.device("cpu")
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    all_boxes = [[] for i in range(21)]  # 21 is the stock VOC size; only the first len(CLASSES) slots are used here
    image_index = []
    for image, targets in metric_logger.log_every(data_loader, 100, header):
        image = list(img.to(device) for img in image)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        torch.cuda.synchronize()
        model_time = time.time()
        outputs = model(image)

        name = ''.join([chr(i) for i in targets[0]['name'].tolist()])
        image_index.append(name)

        outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]

        image_boxes = [[] for i in range(3)] # NOTE: must equal the number of classes (here 3 = 2 classes + background)
        for o in outputs:
            for i in range(o['boxes'].shape[0]):
                image_boxes[o['labels'][i]].extend([
                    torch.cat([o['boxes'][i],o['scores'][i].unsqueeze(0)], dim=0)
                ])

        # make sure that all_boxes gets an empty entry when there are
        # no boxes in image_boxes for that class
        for i in range(3):
            if image_boxes[i] != []:
                all_boxes[i].append([torch.stack(image_boxes[i])])
            else:
                all_boxes[i].append([])

        model_time = time.time() - model_time

    metric_logger.synchronize_between_processes()

    all_boxes_gathered = utils.all_gather(all_boxes)
    image_index_gathered = utils.all_gather(image_index)
    
    # results from all processes are gathered here
    if utils.is_main_process():
        all_boxes = [[] for i in range(21)]
        for abgs in all_boxes_gathered:
            for ab,abg in zip(all_boxes,abgs):
                ab += abg
        image_index = []
        for iig in image_index_gathered:
            image_index+=iig

        _write_custom_voc_results_file(data_loader, all_boxes,image_index, data_loader.dataset.root, 
                                data_loader.dataset._transforms.transforms[0].CLASSES,)
        _do_python_eval_custom_voc(data_loader)
    torch.set_num_threads(n_threads)

1.4 Changes to train.py

1.4.1 Modify the get_dataset function

def get_dataset(name, image_set, transform, data_path):
    paths = {
        "coco": (data_path, get_coco, 91),
        "coco_kp": (data_path, get_coco_kp, 2),
        "voc": (data_path, get_voc, 21),
        "custom_voc": (data_path, get_custom_voc, 3)
    }
    p, ds_fn, num_classes = paths[name]

    if name=='custom_voc':  # load a custom Pascal VOC-format dataset
        ds = ds_fn(p, transforms=transform)
        return ds, num_classes
    else:    
        ds = ds_fn(p, image_set=image_set, transforms=transform)
        return ds, num_classes
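With this change, loading the custom dataset reduces to the following (the path is a placeholder):

dataset, num_classes = get_dataset('custom_voc', None,
                                   get_transform(train=True), '/data/to/train')
# num_classes == 3: __background__, lost, normal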

1.4.2 Modify the main function

def main(args):
    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # Data loading code
    print("Loading data")

    # support custom Pascal VOC-format datasets: set --dataset to custom_voc
    if args.dataset=='custom_voc':
        # a custom Pascal VOC dataset does not need the image_set argument, so pass None
        dataset, num_classes = get_dataset(args.dataset, None, get_transform(train=True), args.train_data_path)
        dataset_test, _ = get_dataset(args.dataset, None, get_transform(train=False), args.test_data_path)
    else :
        dataset, num_classes = get_dataset(args.dataset, "train" if args.dataset=='coco' else 'trainval', 
            get_transform(train=True), args.data_path)
        dataset_test, _ = get_dataset(args.dataset, "test" if args.dataset=='coco' else 'val', 
                    get_transform(train=False), args.data_path)

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    if args.aspect_ratio_group_factor >= 0:
        group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
    else:
        train_batch_sampler = torch.utils.data.BatchSampler(
            train_sampler, args.batch_size, drop_last=True)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_batch_sampler, num_workers=args.workers,
        collate_fn=utils.collate_fn)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1,
        sampler=test_sampler, num_workers=args.workers,
        collate_fn=utils.collate_fn)
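    # Note: detection models consume lists of variably-sized images, so a batch
    # cannot be stacked into a single tensor; utils.collate_fn in the torchvision
    # reference code simply transposes the batch:
    #     def collate_fn(batch):
    #         return tuple(zip(*batch))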

    print("Creating model")
    model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes,
                                                              pretrained=args.pretrained)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(
        params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])  # resuming training also needs the optimizer and LR schedule, not just the model
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

    # For test-only runs the checkpoint is also passed via --resume. The original author only
    # mentioned --resume for resuming training, but per the official docs it can be used for
    # inference as well: https://pytorch.org/tutorials/beginner/saving_loading_models.html
    if args.test_only:
        if not args.resume:
            raise Exception('A checkpoint is required for inference; pass it via --resume!')
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
            model_without_ddp.load_state_dict(checkpoint['model'])

            if 'coco' == args.dataset:
                coco_evaluate(model_without_ddp, data_loader_test, device=device)
            elif 'voc' == args.dataset:
                voc_evaluate(model_without_ddp, data_loader_test, device=device)
            elif 'custom_voc' == args.dataset:
                custom_voc_evaluate(model_without_ddp, data_loader_test, device=device)
            else:
                print(f'No evaluation method available for the dataset {args.dataset}')
            # evaluate(model, data_loader_test, device=device)
            return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
        lr_scheduler.step()
        if args.output_dir:
            # model.save('./checkpoints/model_{}_{}.pth'.format(args.dataset, epoch))
            utils.save_on_master({
                'model': model_without_ddp.state_dict(), # save only the weights (state_dict), not the full module
                # 'model': model_without_ddp, # alternatively, store the whole module
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'args': args},
                os.path.join(args.output_dir, 'model_{}_{}.pth'.format(args.dataset, epoch)))

        # evaluate after every epoch
        if  args.dataset=='coco':
            coco_evaluate(model, data_loader_test, device=device)
        elif 'voc'==args.dataset:
            voc_evaluate(model, data_loader_test, device=device)
        elif 'custom_voc' == args.dataset:
            custom_voc_evaluate(model, data_loader_test, device=device)
        else:
            print(f'No evaluation method available for the dataset {args.dataset}')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(
        description=__doc__)

    parser.add_argument('--data-path', default='./', help='dataset path used for coco and voc (default: "./")')
    parser.add_argument('--train-data-path',  help='train dataset path for custom voc dataset')
    parser.add_argument('--test-data-path',  help='test dataset path for custom voc dataset')
    parser.add_argument('--dataset', default='coco',
                        help='dataset type; options are "coco", "voc", "coco_kp" and "custom_voc" (default: "coco")')
    parser.add_argument('--model', default='fasterrcnn_resnet50_fpn', help='model, default="fasterrcnn_resnet50_fpn"')
    parser.add_argument('--device', default='cuda', help='device, default is cuda')
    parser.add_argument('-b', '--batch-size', default=2, type=int,
                        help='images per batch (default: 2)')
    parser.add_argument('--epochs', default=13, type=int, metavar='N',
                        help='number of total epochs to run (default: 13)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--lr', default=0.02, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('--lr-step-size', default=8, type=int, help='decrease lr every step-size epochs')
    parser.add_argument('--lr-steps', default=[8, 11], nargs='+', type=int, help='epochs at which to decrease lr (MultiStepLR milestones)')
    parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
    parser.add_argument('--print-freq', default=20, type=int, help='print frequency')
    parser.add_argument('--output-dir', default='./', help='path where to save checkpoints (default: "./")')
    parser.add_argument('--resume', default='', help='resume from checkpoint (default: "")')
    parser.add_argument('--aspect-ratio-group-factor', default=0, type=int)
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--pretrained",
        dest="pretrained",
        help="Use pre-trained models from the modelzoo",
        action="store_true",
    )

    # distributed training parameters
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')

    args = parser.parse_args()

    if args.output_dir:
        utils.mkdir(args.output_dir)

    main(args)

2. Loading the dataset

Loading the dataset is the first step of any machine-learning task, and it involves a fair amount of file handling, so fluency with Python file operations is required. For classification tasks dataset loading is relatively straightforward, but for object detection it involves considerably more detail. Looking at the official PyTorch implementation for the Pascal VOC dataset, you can see that the XML annotation files have to be parsed and the parsed content stored in a dictionary. Besides the classic Pascal VOC format, other common dataset formats are the COCO and YOLO formats; different models expect different formats, so converting between them is itself an important (and error-prone) part of practical machine learning. If the dataset cannot be loaded correctly, there is no point in discussing network training at all.
The original project https://github.com/lpuglia/torchvision_voc only implements training and evaluation on the standard VOC and COCO datasets; loading custom datasets in Pascal VOC format has to be implemented separately. Section 1 described the changes to each file; see the source code for the full details.
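As a concrete illustration of that XML-parsing step (and of the parse_rec helper that custom_voc_eval calls at evaluation time), a minimal devkit-style parser, which the project's voc_eval.py is expected to contain, looks like this; it returns one dict per annotated object:

import xml.etree.ElementTree as ET

def parse_rec(filename):
    # Parse one PASCAL VOC xml annotation file into a list of object dicts.
    tree = ET.parse(filename)
    objects = []
    for obj in tree.findall('object'):
        obj_struct = {}
        obj_struct['name'] = obj.find('name').text
        obj_struct['difficult'] = int(obj.find('difficult').text)
        bbox = obj.find('bndbox')
        obj_struct['bbox'] = [int(bbox.find('xmin').text),
                              int(bbox.find('ymin').text),
                              int(bbox.find('xmax').text),
                              int(bbox.find('ymax').text)]
        objects.append(obj_struct)
    return objects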

3. Training

$ python3 train.py --dataset custom_voc --train-data-path /data/to/train --test-data-path /data/to/test -b 2 --output-dir ./checkpoints

During training, the model is evaluated on the test set after every epoch; the per-class AP and mean AP are printed and a PR curve is plotted.

4. Model evaluation and testing

The test set is already evaluated automatically during training, but the inference/evaluation step can also be run on its own. In a Linux terminal, run:

$ python3 train.py --dataset custom_voc --test-only --train-data-path /data/to/train --test-data-path /data/to/test --resume model_custom_voc_11.pth

Note that --resume here loads the checkpoint for inference. The output is:

Test: Total time: 0:00:10 (0.0506 s / it)
+ Class lost - AP: 0.8858719783518023
+ Class normal - AP: 0.887533003893689
Mean AP = 0.8867 

5. Notes

The custom dataset paths used in this post are local to my machine; they are kept as-is so I can reproduce the experiments quickly later, and they contain nothing sensitive. Currently only custom datasets in Pascal VOC format are supported, with the following directory structure:
(screenshot: dataset directory layout) The folder names and their contents must match exactly, because the dataset-loading code in voc_utils.py assumes them:
(screenshot: dataset-loading code in voc_utils.py) Another thing to note is that when using your own dataset you also have to edit the class ConvertCustomVOCtoCOCO in voc_utils.py:
(screenshot: the CLASSES tuple in ConvertCustomVOCtoCOCO)
The highlighted part must be adapted to your own dataset. My dataset has only two classes, lost and normal (the __background__ entry stays unchanged). This follows the original project's approach and is not very flexible; it can be improved later if the need arises.
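Concretely, the highlighted part is the CLASSES tuple at the top of ConvertCustomVOCtoCOCO (section 1.1.2); for this two-class dataset it reads:

CLASSES = (
    "__background__", "lost", "normal"
)

Remember that the hard-coded class counts in engine.py (the range(3) in custom_voc_evaluate, section 1.3.1) must be kept in sync with len(CLASSES).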
