pytorch-retinanet训练自己的数据集

1.VOC数据集标注

https://blog.csdn.net/qq_38082979/article/details/102868269

2.VOC2csv

得到annotations.csv

classes.csv

val.csv(格式与annotations.csv相同;下面的脚本只生成annotations.csv和classes.csv,val.csv需要用验证集的标注另行生成)

# -*- coding:utf-8 -*-
 
import csv
import glob
import os
import sys
import xml.etree.ElementTree as ET
 
class PascalVOC2CSV(object):
    """Convert Pascal VOC xml annotations into two CSV files.

    Instantiating the class runs the whole conversion: every xml file is
    parsed, one ``[image_path, x1, y1, x2, y2, class_name]`` row is collected
    per ``<object>``, and the rows plus a ``class_name,index`` table are
    written to ``ann_path`` and ``classes_path``.
    """

    def __init__(self, xml=None, ann_path='./annotations.csv',
                 classes_path='./classes.csv',
                 image_dir='/data/VOCdevkit/VOC2007/JPEGImages'):
        '''
        :param xml: list of paths to the Pascal VOC xml files to convert
        :param ann_path: output path of the annotations CSV
        :param classes_path: output path of the classes CSV
        :param image_dir: directory joined in front of every ``<filename>``
            to build the image path written to the annotations CSV
        '''
        # A None default avoids sharing one list object between instances.
        self.xml = [] if xml is None else xml
        self.ann_path = ann_path
        self.classes_path = classes_path
        self.image_dir = image_dir
        self.label = []
        self.annotations = []

        self.data_transfer()
        self.write_file()

    def data_transfer(self):
        """Parse every xml file into ``self.annotations`` / ``self.label``.

        Files that cannot be read or parsed are skipped as a whole (the
        conversion stays best-effort), but each skip is reported on stderr
        instead of being swallowed silently.
        """
        total = len(self.xml)
        for num, xml_file in enumerate(self.xml):
            # Progress output, rewritten in place on a single line.
            sys.stdout.write('\r>> Converting image %d/%d' % (num + 1, total))
            sys.stdout.flush()

            try:
                rows = self._parse_xml(xml_file)
            except (OSError, ET.ParseError, ValueError) as err:
                sys.stderr.write('\nSkipping %s: %s\n' % (xml_file, err))
                continue

            for row in rows:
                class_name = row[-1]
                if class_name not in self.label:
                    self.label.append(class_name)
                self.annotations.append(row)

        sys.stdout.write('\n')
        sys.stdout.flush()

    def _parse_xml(self, xml_file):
        """Return the annotation rows found in one Pascal VOC xml file.

        :raises OSError: the file cannot be opened
        :raises xml.etree.ElementTree.ParseError: the file is not valid xml
        :raises ValueError: a required element is missing or a coordinate
            is not an integer
        """
        root = ET.parse(xml_file).getroot()

        # filen_ame / supercategory are still set on the instance because
        # earlier versions of this class exposed them as attributes.
        self.filen_ame = self._required_text(root, 'filename')
        image_path = os.path.join(self.image_dir, self.filen_ame)

        rows = []
        for obj in root.findall('object'):
            # Class name
            self.supercategory = self._required_text(obj, 'name')

            # Bounding box
            box = obj.find('bndbox')
            if box is None:
                raise ValueError('missing <bndbox> element')
            x1, y1, x2, y2 = (
                int(self._required_text(box, tag))
                for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            )

            rows.append([image_path, x1, y1, x2, y2, self.supercategory])
        return rows

    @staticmethod
    def _required_text(node, tag):
        """Return the stripped text of child ``tag``; raise ValueError if absent."""
        child = node.find(tag)
        text = None if child is None else child.text
        if text is None or not text.strip():
            raise ValueError('missing <%s> element' % tag)
        return text.strip()

    def write_file(self):
        """Write the collected rows and the class table to their CSV files."""
        with open(self.ann_path, 'w', newline='') as fp:
            csv.writer(fp, dialect='excel').writerows(self.annotations)

        # Class indices follow the alphabetical order of the class names.
        class_rows = [[name, num] for num, name in enumerate(sorted(self.label))]
        with open(self.classes_path, 'w', newline='') as fp:
            csv.writer(fp, dialect='excel').writerows(class_rows)
 
 
# Collect every Pascal VOC annotation under ./Annotations and convert them;
# constructing PascalVOC2CSV runs the conversion and writes the CSV files.
xml_file = glob.glob('./Annotations/*.xml')
PascalVOC2CSV(xml=xml_file)

3.pytorch-retinanet

https://github.com/yhenon/pytorch-retinanet

4.pytorch1.1修改nms

https://github.com/huaifeng1993/NMS

model文件修改(把其中的nms函数改成下面这样):

def nms(dets, thresh):
    """Run non-maximum suppression by delegating to ``gpu_nms``.

    :param dets: detections; must provide ``.cpu()`` (the original note says
        a tensor is expected). The column layout is not visible here --
        TODO confirm against ``gpu_nms``.
    :param thresh: forwarded unchanged to ``gpu_nms``.
    :return: whatever ``gpu_nms`` returns.
    """
    # gpu_nms is handed a NumPy array, so bring the data to the CPU first.
    host_dets = np.array(dets.cpu())
    # This call stands in for the earlier ``pth_nms(dets, thresh)``.
    return gpu_nms(host_dets, thresh)

 

你可能感兴趣的:(pytorch-retinanet训练自己的数据集)