Yolox目标检测训练实录--1. 数据准备

参考官方给出的教程train_custom_data

YOLOX官方链接

  1. 配置环境, install yolox.
  2. 准备voc dataset, 调试train代码.
  3. 把自己的 dataset, 转换为voc格式, 调试train代码.
  4. 选择不同的模型结构, 训练最终可使用的model.
  5. 模型转换压缩, inference部署应用。

本文重点介绍dataset的处理

  1. install yolox
#服务器--pytorch环境
git clone git@github.com:Megvii-BaseDetection/YOLOX.git
cd YOLOX
pip3 install -v -e .  # or  python3 setup.py develop

windows 参考环境配置和修改代码

conda create -n yolox python=3.7 #用python 3.8也可以
conda activate yolox

#如果你切换了国内的源, 可以把后面的 -c pytorch 去掉。
conda install pytorch=1.7 torchvision cudatoolkit=10.2 -c pytorch

git clone git@github.com:Megvii-BaseDetection/YOLOX.git
cd YOLOX
pip install -r requirements.txt
python setup.py develop

安装完成, 下载pretrained model yolox_s测试

python tools/demo.py image -f exps/default/yolox_s.py -c yolox_s.pth --path assets/dog.jpg --conf 0.25 --nms 0.45 --tsize 640 --save_result --device gpu

  2. voc dataset

    download voc dataset
    作为初步验证, 仅使用了VOC2007的部分数据, 由于后面想做单类别的检测, 仅使用了voc中的car作为验证。
    Yolox目标检测训练实录--1. 数据准备_第1张图片
    8张测试数据 提取码:wxyh

使用Python把原始数据中的car类别挑选出来

import xml.etree.ElementTree as ET
import os

def newImageSets(oldSets, newSets, annDir="Annotations_new"):
    """Write the subset of an ImageSets id list whose filtered annotation exists.

    Reads one numeric image id per line from ``oldSets`` and keeps the ids for
    which ``<annDir>/<id:06d>.xml`` exists. ``selectCarAnn`` only writes an XML
    for images that contain a car, so pointing ``annDir`` at its output
    directory yields the list of car images.

    Args:
        oldSets: path of the original id list (e.g. ``ImageSets/Main/test.txt``).
        newSets: path of the id list to write; it is overwritten.
        annDir: directory holding the filtered annotations. Defaults to
            ``"Annotations_new"``, the output directory used by ``main``.
            Checking the unfiltered ``Annotations`` directory would keep every
            id, because every VOC image has an annotation there.

    Returns:
        None.
    """
    # Ids (as written in oldSets, without newline) whose filtered XML exists.
    savelist = []

    with open(oldSets, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:  # tolerate blank / trailing empty lines
                continue
            ids = int(line)
            path_i = os.path.join(annDir, '%06d.xml' % ids)
            if os.path.exists(path_i):
                print(path_i)
                savelist.append(line)

    # 'w' rather than 'a': re-running the script must not duplicate entries.
    # Each id gets its own newline, even if oldSets' last line had none.
    with open(newSets, 'w') as f1:
        f1.writelines(name + '\n' for name in savelist)

    return

def selectCarAnn(srcAnnPath, dstAnnPath, maxId=9963, className='car'):
    """Copy the VOC annotations that contain ``className``, keeping only those boxes.

    For every id in ``1..maxId`` the file ``<srcAnnPath>/<id:06d>.xml`` is
    parsed; ids without a file are skipped. If at least one ``<object>`` is
    named ``className``, every other ``<object>`` is removed and the result is
    written to ``<dstAnnPath>/<id:06d>.xml``. Otherwise nothing is written.

    Args:
        srcAnnPath: directory of the original VOC ``Annotations``.
        dstAnnPath: output directory; created if it does not exist.
        maxId: largest numeric id to try. The default 9963 matches the
            original hard-coded range (presumably the last VOC2007 id).
        className: object name to keep (default ``'car'``).

    Returns:
        The number of annotation files written (also printed).
    """
    # ElementTree.write does not create missing directories.
    os.makedirs(dstAnnPath, exist_ok=True)

    count = 0
    for idx in range(1, maxId + 1):
        srcFile = os.path.join(srcAnnPath, "%06d.xml" % idx)
        # Only part of VOC2007 may be downloaded; skip ids with no annotation.
        if not os.path.exists(srcFile):
            continue
        rootTree = ET.parse(srcFile)
        target = rootTree.getroot()

        # findall returns direct children, which is what Element.remove needs.
        objects = target.findall("object")
        names = [obj.find("name").text.strip() for obj in objects]
        if className not in names:
            continue

        count += 1
        # Drop every box that is not of the requested class.
        for obj, name in zip(objects, names):
            if name != className:
                target.remove(obj)

        rootTree.write(os.path.join(dstAnnPath, "%06d.xml" % idx))

    print(count)
    return count


def main():
    """Filter the VOC annotations down to cars and derive a new test id list."""
    # Step 1: write car-only copies of the annotations into Annotations_new.
    selectCarAnn("Annotations", "Annotations_new")
    # Step 2: build ImageSets/Main/test_new.txt from ImageSets/Main/test.txt.
    newImageSets("ImageSets/Main/test.txt", "ImageSets/Main/test_new.txt")


if __name__ == '__main__':
    main()
  3. Hand dataset
    download Hand dataset
    此数据集是用matlab打包的, 这里使用python解析.mat数据文件

    如果想查看原始数据集的标注情况, 使用以下Python代码

import scipy.io as scio
import cv2
import random
import colorsys
import os

def loadbox(data):
    """Collect the four corner points of every box in a loaded hand ``.mat``.

    Iterates ``data['boxes'][0]``. For each box, the first four entries of
    ``box[0][0]`` are taken and the first row of each is kept. The result is
    one ``[p0, p1, p2, p3]`` list per box. ``drawImg`` reads each point with
    index 0 as y and index 1 as x.
    """
    return [[box[0][0][corner][0] for corner in range(4)]
            for box in data['boxes'][0]]
def get_n_hls_colors(num):
    """Return exactly ``num`` HLS colours with evenly spaced hues.

    Colour ``k`` has hue ``k * 360 / num`` degrees (normalised to ``[0, 1)``),
    lightness drawn from ``[0.5, 0.6)`` and saturation from ``[0.9, 1.0)``.

    Args:
        num: number of colours; values below 1 give an empty list.

    Returns:
        A list of ``[h, l, s]`` lists with every component in ``[0, 1]``.
    """
    if num < 1:  # also avoids ZeroDivisionError for num == 0
        return []
    step = 360.0 / num
    hls_colors = []
    # Index the hue by k instead of accumulating a float step: the old
    # ``while i < 360`` loop could stop just below 360 and emit num + 1 colours.
    for k in range(num):
        hue = k * step
        # Saturation is drawn before lightness to keep the original random order.
        sat = 90 + random.random() * 10
        light = 50 + random.random() * 10
        hls_colors.append([hue / 360.0, light / 100.0, sat / 100.0])
    return hls_colors
def ncolors(num):
    """Convert the HLS palette from ``get_n_hls_colors(num)`` into RGB.

    Each HLS colour becomes an ``[r, g, b]`` list of ints in 0..255. Each
    channel from ``colorsys.hls_to_rgb`` is scaled by 255 and truncated.
    A ``num`` below 1 yields an empty list.
    """
    if num >= 1:
        return [
            [int(channel * 255.0) for channel in colorsys.hls_to_rgb(*hls[:3])]
            for hls in get_n_hls_colors(num)
        ]
    return []


def drawImg(boxes, img):
    """Draw the axis-aligned rectangle enclosing each 4-point box onto ``img``.

    Each box is a sequence of four ``(y, x)`` points, as produced by
    ``loadbox``. The enclosing rectangle is drawn on ``img`` with a 1-pixel
    line. Each box gets its own colour from ``ncolors``; the RGB triple is
    reversed to BGR for OpenCV.

    Args:
        boxes: sequence of boxes, each a sequence of four ``(y, x)`` points.
        img: image array drawn on by ``cv2.rectangle``; its shape is printed.

    Returns:
        0, kept for compatibility with existing callers.
    """
    print(img.shape)
    n = len(boxes)
    colors = ncolors(n)

    for idx, bb in enumerate(boxes):
        # Palette is consumed from the end backwards (first box -> colors[n-1]).
        r, g, b = colors[n - 1 - idx]
        xs = [int(pt[1]) for pt in bb[:4]]
        ys = [int(pt[0]) for pt in bb[:4]]
        cv2.rectangle(img, (min(xs), min(ys)), (max(xs), max(ys)), (b, g, r), 1)

    return 0

def process_single_pair(imgPath, matPath):
    """Display one image with its ``.mat`` hand boxes drawn in the "tt" window.

    Does not wait for a key press; ``vis_folder`` calls ``cv2.waitKey``
    afterwards.

    Returns:
        0.
    """
    image = cv2.imread(imgPath)
    annotation = scio.loadmat(matPath)
    corner_boxes = loadbox(annotation)
    drawImg(corner_boxes, image)
    cv2.imshow("tt", image)
    return 0

def vis_folder(imgFolder, matFolder):
    """Interactively browse every ``.jpg`` in ``imgFolder`` with its hand boxes.

    For each ``<name>.jpg``, the annotation ``<matFolder>/<name>.mat`` is
    loaded and drawn. Images without a matching ``.mat`` are skipped. Any key
    advances to the next image; ``q`` quits.

    Returns:
        0.
    """
    # Sorted so the browse order is reproducible (os.listdir order is arbitrary).
    for f in sorted(os.listdir(imgFolder)):
        if not f.endswith(".jpg"):
            continue
        imgPath = os.path.join(imgFolder, f)
        print(imgPath)
        matPath = os.path.join(matFolder, f[:-3] + "mat")
        if not os.path.exists(matPath):
            continue
        process_single_pair(imgPath, matPath)
        # Keep only the low byte: waitKey can report extra modifier bits on some
        # backends, which would make a plain comparison miss 'q'.
        if (cv2.waitKey(0) & 0xFF) == ord('q'):
            break

    # Close the window whether the user quit early or every image was shown.
    cv2.destroyAllWindows()
    return 0

if __name__ == '__main__':
    # Browse the hand test split (image folder, .mat annotation folder).
    vis_folder("test_dataset/test_data/images/", "test_dataset/test_data/annotations/")

hand dataset 转换为 voc格式的代码

import copy
import scipy.io as scio
import cv2
import os
import xml.etree.ElementTree as ET

class HandData2VOC:
    """Convert hand-dataset ``.mat`` box annotations into VOC-style XML files.

    A VOC annotation file (``self.xmlAnnotation``) serves as the template:
    - its folder, filename and size fields are overwritten;
    - its first ``<object>`` is renamed ``hand`` and reused for every box;
    - any further template objects are discarded.

    Images must already be present in ``self.imageDir``; they are read only to
    obtain width, height and depth.
    """

    def __init__(self):
        self.srcHandAnnotationsDir = "hand_dataset/annotations/"  # input .mat files
        self.dstHandAnnotationsDir = "voc_style/Annotations/"     # output .xml files
        self.imageDir = 'voc_style/JPEGImages/'                    # images, read for their size
        self.imageSets = 'voc_style/ImageSets/'                    # id list written by run()

        # VOC annotation used as the XML template; needs at least one <object>.
        self.xmlAnnotation = "000004.xml"

    def test(self):
        """No-op placeholder, kept so existing callers keep working."""
        return

    @staticmethod
    def _box_to_xyxy(box):
        """Return ``(xmin, ymin, xmax, ymax)`` enclosing one 4-point ``.mat`` box.

        ``box[0][0][k][0]`` is the k-th corner as ``(y, x)``. Coordinates are
        truncated to int before taking the min/max.
        """
        pts = [box[0][0][k][0] for k in range(4)]
        xs = [int(p[1]) for p in pts]
        ys = [int(p[0]) for p in pts]
        return min(xs), min(ys), max(xs), max(ys)

    @staticmethod
    def _set_bndbox(objNode, xyxy):
        """Write ``(xmin, ymin, xmax, ymax)`` into ``objNode``'s ``<bndbox>``."""
        nodeBndBox = objNode.find("bndbox")
        for tag, value in zip(("xmin", "ymin", "xmax", "ymax"), xyxy):
            nodeBndBox.find(tag).text = str(value)

    def runTransfor(self, matName):
        """Convert one ``.mat`` file to ``<dstHandAnnotationsDir>/<name>.xml``.

        Args:
            matName: file name (not path) of the annotation, e.g. ``"x.mat"``.

        Returns:
            1 if an XML file was written, 0 if the annotation has no boxes or
            the matching image could not be read.
        """
        rootTree = ET.parse(self.xmlAnnotation)
        target = rootTree.getroot()
        stem = matName[:-3]  # "name." (dot kept) so + "xml" / + "jpg" works
        savePath = os.path.join(self.dstHandAnnotationsDir, stem + "xml")

        # Folder and file name.
        target.find('folder').text = "Hand_data"
        imageName = stem + 'jpg'
        target.find("filename").text = imageName

        # Image size, taken from the actual image.
        imgPath = os.path.join(self.imageDir, imageName)
        img = cv2.imread(imgPath)
        if img is None:  # cv2.imread returns None instead of raising
            print('----------------------->: missing image ', imgPath)
            return 0
        h, w, c = img.shape

        nodeSize = target.find('size')
        nodeSize.find('width').text = str(w)
        nodeSize.find('height').text = str(h)
        nodeSize.find('depth').text = str(c)

        # Keep only the first template <object> as the prototype for every box.
        # Any other template objects would otherwise leak into the output.
        objects = target.findall("object")
        objNodei = objects[0]
        for extra in objects[1:]:
            target.remove(extra)
        objNodei.find("name").text = "hand"
        difficult = objNodei.find("difficult")
        if difficult is not None:
            # No difficulty information is read from the .mat boxes, so don't
            # inherit the template's flag.
            difficult.text = "0"

        # Boxes.
        matPath = os.path.join(self.srcHandAnnotationsDir, matName)
        data = scio.loadmat(matPath)
        boxes_data = data['boxes'][0]
        numBox = len(boxes_data)
        print("==============4. numBox: ", numBox)
        if numBox == 0:
            print('----------------------->: ', matName)
            return 0

        # The first box reuses the template object; the rest are deep copies.
        self._set_bndbox(objNodei, self._box_to_xyxy(boxes_data[0]))
        for bb in boxes_data[1:]:
            nodeObjCp = copy.deepcopy(objNodei)
            xyxy = self._box_to_xyxy(bb)
            self._set_bndbox(nodeObjCp, xyxy)
            target.append(nodeObjCp)
            print(*xyxy, sep="   ")
        # NOTE(review): boxes are not clipped to the image; confirm whether the
        # hand annotations can extend past the image border.

        rootTree.write(savePath)
        print("===========save: ", savePath)
        return 1

    def run(self):
        """Convert every ``.mat`` file and list the converted ids.

        The ids (file names without ``.mat``) are written to
        ``<imageSets>/test0.txt``.

        Returns:
            0.
        """
        # ElementTree.write / open() do not create missing directories.
        os.makedirs(self.dstHandAnnotationsDir, exist_ok=True)
        os.makedirs(self.imageSets, exist_ok=True)
        setPath = os.path.join(self.imageSets, "test0.txt")
        with open(setPath, 'w') as file:
            for mat in sorted(os.listdir(self.srcHandAnnotationsDir)):
                if not mat.endswith(".mat"):  # skip stray non-annotation files
                    continue
                if self.runTransfor(mat):
                    file.write(mat[:-4] + "\n")
        return 0

完整数据和代码 提取码:1fut

数据处理结束, 检查一下处理的效果如何:

import numpy as np
def parserXml(xmlPath, keep_difficult=True, label_idx=1):
    """Parse a VOC XML annotation into an ``(N, 5)`` array for visualisation.

    Each row is ``[xmin, ymin, xmax, ymax, label_idx]``. Coordinates are
    truncated to int and shifted by -1, presumably to turn VOC's 1-based pixel
    indices into 0-based ones.

    Args:
        xmlPath: path of the annotation file.
        keep_difficult: if False, objects marked ``<difficult>1</difficult>``
            are skipped. The default True keeps every object, matching the
            previous hard-coded behaviour.
        label_idx: value stored in the fifth column of every row.

    Returns:
        ``numpy.ndarray`` of shape ``(N, 5)`` and dtype float64; ``(0, 5)``
        when the file has no objects. The image ``(height, width)`` is printed.
    """
    target = ET.parse(xmlPath).getroot()

    rows = []
    for obj in target.iter("object"):
        difficult = obj.find("difficult")
        is_difficult = difficult is not None and int(difficult.text) == 1
        if not keep_difficult and is_difficult:
            continue
        bbox = obj.find("bndbox")
        row = [int(float(bbox.find(pt).text)) - 1
               for pt in ("xmin", "ymin", "xmax", "ymax")]
        row.append(label_idx)
        rows.append(row)

    # Build the array once instead of np.vstack per object (quadratic);
    # reshape keeps the (0, 5) shape for files without objects.
    res = np.array(rows, dtype=np.float64).reshape(-1, 5)

    size = target.find("size")
    img_info = (int(size.find("height").text), int(size.find("width").text))
    print(img_info)

    return res

def showVocStyleData(imgDir, annDir):
    """Interactively display VOC-style annotations drawn on their images.

    For every ``.xml`` in ``annDir``:
    - the image ``<imgDir>/<name>.jpg`` is loaded;
    - each box from ``parserXml`` is drawn in green;
    - the result is shown in the "test" window.

    Any key advances to the next image; Esc (key code 27) stops.
    """
    for ann in sorted(os.listdir(annDir)):
        if not ann.endswith(".xml"):  # skip stray non-annotation files
            continue
        annPath = os.path.join(annDir, ann)
        imgPath = os.path.join(imgDir, ann[:-3] + "jpg")

        img = cv2.imread(imgPath)
        if img is None:  # cv2.imread returns None instead of raising
            print("missing image: ", imgPath)
            continue
        print(img.shape)

        res = parserXml(annPath)
        for resi in res:
            print(resi)
            x0, y0, x1, y1 = (int(v) for v in resi[:4])
            cv2.rectangle(img, (x0, y0), (x1, y1), (0, 200, 0), 2)

        cv2.imshow('test', img)
        if (cv2.waitKey(0) & 0xff) == 27:  # 27 == Esc
            break

    # Close the window whether the user quit early or every image was shown.
    cv2.destroyAllWindows()
    return

你可能感兴趣的:(python,cv,tools,python,图像处理,计算机视觉)