YOLOv2训练自己的数据集(识别海参)

检测或者训练有问题的可发邮件咨询小编,小编尽力解答[email protected]
好吧,正式开始跑程序(其实官网都有)
这篇文章是训练YOLO v2过程中的经验总结,我使用YOLO v2训练一组自己的数据,训练后的model,在阈值为.25的情况下,Recall值是95.54%,Precision 是97.27%。
需要注意的是,这一训练过程可能只对我自己的训练集有效,因为我是根据我这一训练集的特征来对YOLO代码进行修改,可能对你的数据集并不适用,所以仅供参考。
我的数据集
批量改名首先准备好自己的数据集,最好固定格式,此处以VOC为例,采用jpg格式的图像,在名字上最好使用像VOC一样类似000001.jpg、000002.jpg这样。可参照下面示例python代码

# -*- coding:utf8 -*-
#!/usr/bin/python2.7
import os

class BatchRename():
    '''
    批量重命名文件夹中的图片文件

    '''
    def __init__(self):
        self.path = '/home/xiaorun/label/haishen/'

    def rename(self):
        filelist = os.listdir(self.path)
        total_num = len(filelist)
        i = 1
	
        for item in filelist:
            if item.endswith('.jpg'):
                src = os.path.join(os.path.abspath(self.path), item)
		str1=str(i)
                dst = os.path.join(os.path.abspath(self.path), str1.zfill(6) + '.jpg')
                try:
                    os.rename(src, dst)
                    print 'converting %s to %s ...' % (src, dst)
                    i = i + 1
                except:
                    continue
        print 'total %d to rename & converted %d jpgs' % (total_num, i)

if __name__ == '__main__':
    demo = BatchRename()
    demo.rename()

读取某文件夹下的所有图像然后统一命名,用了opencv所以顺便还可以改格式。
准备好了自己的图像后,需要按VOC数据集的结构放置图像文件。VOC的结构如下

--VOC  
    --Annotations  (XML文件)
    --ImageSets    (txt文件)
      --Main    
    --JPEGImages    (图片)

这里面用到的文件夹是Annotation、ImageSets和JPEGImages。其中文件夹Annotation中主要存放xml文件,每一个xml对应一张图像,并且每个xml中存放的是标记的各个目标的位置和类别信息,命名通常与对应的原始图像一样;而ImageSets我们只需要用到Main文件夹,这里面存放的是一些文本文件,通常为train.txt、test.txt等,该文本文件里面的内容是需要用来训练或测试的图像的名字(无后缀无路径);JPEGImages文件夹中放我们已按统一规则命名好的原始图像。
因此,首先
1.新建文件夹VOC2007(通常命名为这个,也可以用其他命名,但一定是名字+年份,例如MYDATA2016,无论叫什么后面都需要改相关代码匹配这里,本例中以VOC2007为例)
2.在VOC2007文件夹下新建三个文件夹Annotation、ImageSets和JPEGImages,并把准备好的自己的原始图像放在JPEGImages文件夹下
3.在ImageSets文件夹中,新建三个空文件夹Layout、Main、Segmentation,然后把写了训练或测试的图像的名字的文本拷到Main文件夹下,按目的命名,我这里所有图像用来训练,故而Main文件夹下只有train.txt文件。上面说的小代码运行后会生成该文件,把它拷进去即可。
图像标注
2.标记图像目标区域
python代码直接运行之后开始标注, 因为做的是目标检测,所以接下来需要标记原始图像中的目标区域。相关方法和工具有很多,这里需用标注工具,相关用法也有说明,基本就是框住目标区域然后双击类别,标记完整张图像后点击保存即可。操作界面如下:
YOLOv2训练自己的数据集(识别海参)_第1张图片

# -*- coding:utf-8 -*-
# -------------------------------------------------------------------------------
# Name:        Object bounding box label tool
# Purpose:     Label object bboxes for ImageNet Detection data
# Author:      Qiushi
# Created:     06/06/2014

#
# -------------------------------------------------------------------------------
from __future__ import division
from Tkinter import *
import tkMessageBox
from PIL import Image, ImageTk
import os
import glob
import random

w0 = 1;  # 图片原始宽度
h0 = 1;  # 图片原始高度

# colors for the bboxes
COLORS = ['red', 'blue', 'yellow', 'pink', 'cyan', 'green', 'black']
# image sizes for the examples
SIZE = 256, 256

# 指定缩放后的图像大小
DEST_SIZE = 500, 500


class LabelTool():
    def __init__(self, master):
        # set up the main frame
        self.parent = master
        self.parent.title("LabelTool")
        self.frame = Frame(self.parent)
        self.frame.pack(fill=BOTH, expand=1)
        self.parent.resizable(width=TRUE, height=TRUE)

        # initialize global state
        self.imageDir = ''
        self.imageList = []
        self.egDir = ''
        self.egList = []
        self.outDir = ''
        self.cur = 0
        self.total = 0
        self.category = 0
        self.imagename = ''
        self.labelfilename = ''
        self.tkimg = None

        # initialize mouse state
        self.STATE = {}
        self.STATE['click'] = 0
        self.STATE['x'], self.STATE['y'] = 0, 0

        # reference to bbox
        self.bboxIdList = []
        self.bboxId = None
        self.bboxList = []
        self.hl = None
        self.vl = None

        # ----------------- GUI stuff ---------------------
        # dir entry & load
        self.label = Label(self.frame, text="Image Dir:")
        self.label.grid(row=0, column=0, sticky=E)
        self.entry = Entry(self.frame)
        self.entry.grid(row=0, column=1, sticky=W + E)
        self.ldBtn = Button(self.frame, text="Load", command=self.loadDir)
        self.ldBtn.grid(row=0, column=2, sticky=W + E)

        # main panel for labeling
        self.mainPanel = Canvas(self.frame, cursor='tcross')
        self.mainPanel.bind("", self.mouseClick)
        self.mainPanel.bind("", self.mouseMove)
        self.parent.bind("", self.cancelBBox)  # press  to cancel current bbox
        self.parent.bind("s", self.cancelBBox)
        self.parent.bind("a", self.prevImage)  # press 'a' to go backforward
        self.parent.bind("d", self.nextImage)  # press 'd' to go forward
        self.mainPanel.grid(row=1, column=1, rowspan=4, sticky=W + N)

        # showing bbox info & delete bbox
        self.lb1 = Label(self.frame, text='Bounding boxes:')
        self.lb1.grid(row=1, column=2, sticky=W + N)

        self.listbox = Listbox(self.frame, width=28, height=12)
        self.listbox.grid(row=2, column=2, sticky=N)

        self.btnDel = Button(self.frame, text='Delete', command=self.delBBox)
        self.btnDel.grid(row=3, column=2, sticky=W + E + N)
        self.btnClear = Button(self.frame, text='ClearAll', command=self.clearBBox)
        self.btnClear.grid(row=4, column=2, sticky=W + E + N)

        # control panel for image navigation
        self.ctrPanel = Frame(self.frame)
        self.ctrPanel.grid(row=5, column=1, columnspan=2, sticky=W + E)
        self.prevBtn = Button(self.ctrPanel, text='<< Prev', width=10, command=self.prevImage)
        self.prevBtn.pack(side=LEFT, padx=5, pady=3)
        self.nextBtn = Button(self.ctrPanel, text='Next >>', width=10, command=self.nextImage)
        self.nextBtn.pack(side=LEFT, padx=5, pady=3)
        self.progLabel = Label(self.ctrPanel, text="Progress:     /    ")
        self.progLabel.pack(side=LEFT, padx=5)
        self.tmpLabel = Label(self.ctrPanel, text="Go to Image No.")
        self.tmpLabel.pack(side=LEFT, padx=5)
        self.idxEntry = Entry(self.ctrPanel, width=5)
        self.idxEntry.pack(side=LEFT)
        self.goBtn = Button(self.ctrPanel, text='Go', command=self.gotoImage)
        self.goBtn.pack(side=LEFT)

        # example pannel for illustration
        self.egPanel = Frame(self.frame, border=10)
        self.egPanel.grid(row=1, column=0, rowspan=5, sticky=N)
        self.tmpLabel2 = Label(self.egPanel, text="Examples:")
        self.tmpLabel2.pack(side=TOP, pady=5)

        self.egLabels = []
        for i in range(3):
            self.egLabels.append(Label(self.egPanel))
            self.egLabels[-1].pack(side=TOP)

        # display mouse position
        self.disp = Label(self.ctrPanel, text='')
        self.disp.pack(side=RIGHT)

        self.frame.columnconfigure(1, weight=1)
        self.frame.rowconfigure(4, weight=1)

        # for debugging

    ##        self.setImage()
    ##        self.loadDir()


    def loadDir(self, dbg=False):
        if not dbg:
            s = self.entry.get()
            self.parent.focus()
            self.category = int(s)
        else:
            s = r'D:\workspace\python\labelGUI'
        ##        if not os.path.isdir(s):
        ##            tkMessageBox.showerror("Error!", message = "The specified dir doesn't exist!")
        ##            return
        # get image list

        print 'self.category =%d' % (self.category)

#        self.imageDir = os.path.join(r'./images', '%03d' % (self.category))
        a = '/home/xiaorun/label/JPEGImages/'
        b = '%06d' % (self.category)
        self.imageDir = a + b
        print(self.imageDir)

        self.imageList = glob.glob(self.imageDir+'*.jpg')
        if len(self.imageList) == 0:
            print 'No .jpg images found in the specified dir!'
            return
        else:
            print 'num=%d' % (len(self.imageList))

        # default to the 1st image in the collection
        self.cur = 1
        self.total = len(self.imageList)

        # set up output dir
        self.outDir = os.path.join(r'./labels', '%03d' % (self.category))
        if not os.path.exists(self.outDir):
            os.mkdir(self.outDir)

        # load example bboxes
        self.egDir = os.path.join(r'./Examples', '%06d' % (self.category))
        # if not os.path.exists(self.egDir):
        #   return

        filelist = glob.glob(os.path.join(self.egDir, '*.jpg'))
        self.tmp = []
        self.egList = []
        random.shuffle(filelist)
        for (i, f) in enumerate(filelist):
            if i == 3:
                break
            im = Image.open(f)
            r = min(SIZE[0] / im.size[0], SIZE[1] / im.size[1])
            new_size = int(r * im.size[0]), int(r * im.size[1])
            self.tmp.append(im.resize(new_size, Image.ANTIALIAS))
            self.egList.append(ImageTk.PhotoImage(self.tmp[-1]))
            self.egLabels[i].config(image=self.egList[-1], width=SIZE[0], height=SIZE[1])

        self.loadImage()
        print '%d images loaded from %s' % (self.total, s)

    def loadImage(self):
        # load image
        imagepath = self.imageList[self.cur - 1]
        pil_image = Image.open(imagepath)

        # get the size of the image
        # 获取图像的原始大小
        global w0, h0
        w0, h0 = pil_image.size

        # 缩放到指定大小
        pil_image = pil_image.resize((DEST_SIZE[0], DEST_SIZE[1]), Image.ANTIALIAS)

        # pil_image = imgresize(w, h, w_box, h_box, pil_image)
        self.img = pil_image

        self.tkimg = ImageTk.PhotoImage(pil_image)

        self.mainPanel.config(width=max(self.tkimg.width(), 400), height=max(self.tkimg.height(), 400))
        self.mainPanel.create_image(0, 0, image=self.tkimg, anchor=NW)
        self.progLabel.config(text="%04d/%04d" % (self.cur, self.total))

        # load labels
        self.clearBBox()
        self.imagename = os.path.split(imagepath)[-1].split('.')[0]
        labelname = self.imagename + '.txt'
        self.labelfilename = os.path.join(self.outDir, labelname)
        bbox_cnt = 0
        if os.path.exists(self.labelfilename):
            with open(self.labelfilename) as f:
                for (i, line) in enumerate(f):
                    if i == 0:
                        bbox_cnt = int(line.strip())
                        continue
                    print line
                    tmp = [(t.strip()) for t in line.split()]

                    print "********************"
                    print DEST_SIZE
                    # tmp = (0.1, 0.3, 0.5, 0.5)
                    print "tmp[0,1,2,3]===%.2f, %.2f, %.2f, %.2f" % (
                    float(tmp[0]), float(tmp[1]), float(tmp[2]), float(tmp[3]))
                    # print "%.2f,%.2f,%.2f,%.2f" %(tmp[0] tmp[1] tmp[2] tmp[3] )

                    print "********************"

                    # tx = (10, 20, 30, 40)
                    # self.bboxList.append(tuple(tx))
                    self.bboxList.append(tuple(tmp))
                    tmp[0] = float(tmp[0])
                    tmp[1] = float(tmp[1])
                    tmp[2] = float(tmp[2])
                    tmp[3] = float(tmp[3])

                    tx0 = int(tmp[0] * DEST_SIZE[0])
                    ty0 = int(tmp[1] * DEST_SIZE[1])

                    tx1 = int(tmp[2] * DEST_SIZE[0])
                    ty1 = int(tmp[3] * DEST_SIZE[1])
                    print "tx0, ty0, tx1, ty1"
                    print tx0, ty0, tx1, ty1

                    tmpId = self.mainPanel.create_rectangle(tx0, ty0, tx1, ty1, \
                                                            width=2, \
                                                            outline=COLORS[(len(self.bboxList) - 1) % len(COLORS)])

                    self.bboxIdList.append(tmpId)
                    self.listbox.insert(END, '(%.2f,%.2f)-(%.2f,%.2f)' % (tmp[0], tmp[1], tmp[2], tmp[3]))

                    # self.listbox.insert(END, '(%d, %d) -> (%d, %d)' %(tmp[0], tmp[1], tmp[2], tmp[3]))
                    self.listbox.itemconfig(len(self.bboxIdList) - 1,
                                            fg=COLORS[(len(self.bboxIdList) - 1) % len(COLORS)])

    def saveImage(self):
        # print "-----1--self.bboxList---------"
        print self.bboxList
        # print "-----2--self.bboxList---------"

        with open(self.labelfilename, 'w') as f:
            f.write('%d\n' % len(self.bboxList))
            for bbox in self.bboxList:
                f.write(' '.join(map(str, bbox)) + '\n')
        print 'Image No. %d saved' % (self.cur)

    def mouseClick(self, event):
        if self.STATE['click'] == 0:
            self.STATE['x'], self.STATE['y'] = event.x, event.y
        else:
            x1, x2 = min(self.STATE['x'], event.x), max(self.STATE['x'], event.x)
            y1, y2 = min(self.STATE['y'], event.y), max(self.STATE['y'], event.y)

            x1, x2 = x1 / DEST_SIZE[0], x2 / DEST_SIZE[0];
            y1, y2 = y1 / DEST_SIZE[1], y2 / DEST_SIZE[1];

            self.bboxList.append((x1, y1, x2, y2))
            self.bboxIdList.append(self.bboxId)
            self.bboxId = None
            self.listbox.insert(END, '(%.2f, %.2f)-(%.2f, %.2f)' % (x1, y1, x2, y2))
            self.listbox.itemconfig(len(self.bboxIdList) - 1, fg=COLORS[(len(self.bboxIdList) - 1) % len(COLORS)])
        self.STATE['click'] = 1 - self.STATE['click']

    def mouseMove(self, event):
        self.disp.config(text='x: %.2f, y: %.2f' % (event.x / DEST_SIZE[0], event.y / DEST_SIZE[1]))
        if self.tkimg:
            if self.hl:
                self.mainPanel.delete(self.hl)
            self.hl = self.mainPanel.create_line(0, event.y, self.tkimg.width(), event.y, width=2)
            if self.vl:
                self.mainPanel.delete(self.vl)
            self.vl = self.mainPanel.create_line(event.x, 0, event.x, self.tkimg.height(), width=2)
        if 1 == self.STATE['click']:
            if self.bboxId:
                self.mainPanel.delete(self.bboxId)
            self.bboxId = self.mainPanel.create_rectangle(self.STATE['x'], self.STATE['y'], \
                                                          event.x, event.y, \
                                                          width=2, \
                                                          outline=COLORS[len(self.bboxList) % len(COLORS)])

    def cancelBBox(self, event):
        if 1 == self.STATE['click']:
            if self.bboxId:
                self.mainPanel.delete(self.bboxId)
                self.bboxId = None
                self.STATE['click'] = 0

    def delBBox(self):
        sel = self.listbox.curselection()
        if len(sel) != 1:
            return
        idx = int(sel[0])
        self.mainPanel.delete(self.bboxIdList[idx])
        self.bboxIdList.pop(idx)
        self.bboxList.pop(idx)
        self.listbox.delete(idx)

    def clearBBox(self):
        for idx in range(len(self.bboxIdList)):
            self.mainPanel.delete(self.bboxIdList[idx])
        self.listbox.delete(0, len(self.bboxList))
        self.bboxIdList = []
        self.bboxList = []

    def prevImage(self, event=None):
        self.saveImage()
        if self.cur > 1:
            self.cur -= 1
            self.loadImage()

    def nextImage(self, event=None):
        self.saveImage()
        if self.cur < self.total:
            self.cur += 1
            self.loadImage()

    def gotoImage(self):
        idx = int(self.idxEntry.get())
        if 1 <= idx and idx <= self.total:
            self.saveImage()
            self.cur = idx
            self.loadImage()

            ##    def setImage(self, imagepath = r'test2.png'):
            ##        self.img = Image.open(imagepath)
            ##        self.tkimg = ImageTk.PhotoImage(self.img)
            ##        self.mainPanel.config(width = self.tkimg.width())
            ##        self.mainPanel.config(height = self.tkimg.height())
            ##        self.mainPanel.create_image(0, 0, image = self.tkimg, anchor=NW)

    def imgresize(w, h, w_box, h_box, pil_image):
        '''
        resize a pil_image object so it will fit into
        a box of size w_box times h_box, but retain aspect ratio
        '''
        f1 = 1.0 * w_box / w  # 1.0 forces float division in Python2
        f2 = 1.0 * h_box / h
        factor = min([f1, f2])
        # print(f1, f2, factor) # test
        # use best down-sizing filter
        width = int(w * factor)
        height = int(h * factor)
        return pil_image.resize((width, height), Image.ANTIALIAS)


if __name__ == '__main__':
    root = Tk()
    tool = LabelTool(root)
    root.mainloop()

接下来我们把标注文件改成.XML 文件才能训练,不多说直接上代码

import glob

s1="""    
        {0}
        Unspecified
        0
        0
        
            {1}
            {2}
            {3}
            {4}
        
    """

s2="""
    VOC2007
    {0}
    
        My Database
        VOC2007
        flickr
        NULL
    
    
        NULL
        J
    
    
        256
        256
        3
    
    0
    
        {1}
        Unspecified
        0
        0
        
            {2}
            {3}
            {4}
            {5}
        
    {6}

"""

alltext=glob.glob('labels/*')
print alltext
for textlist in alltext:
    textlist = glob.glob(textlist+'/*.txt') 
    for text_ in textlist:
        flabel = open(text_, 'r')
        lb = flabel.readlines()
        flabel.close()
        ob2 = ""
        if len(lb)<2:
            continue  # no annotation
        x1=2
        x2=lb[1].split(' ')
        x3 = [int(float(i) * 256) for i in x2]
        if len(lb)>2:  # extra annotation
            for i in range(2,len(lb)):
                y2 = lb[i].split(' ')
                y3 = [int(float(i) * 256) for i in y2]
                ob2+='\n' + s1.format(x1,y3[0],y3[1],y3[2],y3[3])
        imgname=('%06d' % (int(float(text_[13:-4]))))+'.jpg'
        savename='Annotations\\'+str('%06d' % (int(text_[13:-4])))+'.xml'
        f = open(savename, 'w')
        ob1=s2.format(imgname, x1, x3[0],x3[1],x3[2],x3[3],  ob2)
        f.write(ob1)
        f.close()

YOLOv2训练自己的数据集(识别海参)_第2张图片
每个xml文件是这样的画风是这样的



	JPEGImages
	00000
	/home/kinglch/VOC2007/JPEGImages/00000.jpg
	
		Unknown
	
	
		704
		576
		3
	
	0
	
		person
		Unspecified
		0
		0
		
			73
			139
			142
			247
		
	
	
		person
		Unspecified
		0
		0
		
			180
			65
			209
			151
		
	
	
		person
		Unspecified
		0
		0
		
			152
			70
			181
			144
		
	

最后我们生成txt文件,Python脚本

import os
import random

trainval_percent = 0.66
train_percent = 0.95
xmlfilepath = 'Annotations'
txtsavepath = 'ImageSets\Main'
total_xml = os.listdir(xmlfilepath)

num=len(total_xml)
list=range(num)
tv=int(num*trainval_percent)
tr=int(tv*train_percent)
trainval= random.sample(list,tv)
train=random.sample(trainval,tr)

ftrainval = open('ImageSets/Main/trainval.txt', 'w')
ftest = open('ImageSets/Main/test.txt', 'w')
ftrain = open('ImageSets/Main/train.txt', 'w')
fval = open('ImageSets/Main/val.txt', 'w')

for i  in list:
    name=total_xml[i][:-4]+'\n'
    if i in trainval:
        ftrainval.write(name)
        if i in train:
            ftrain.write(name)
        else:
            fval.write(name)
    else:
        ftest.write(name)

ftrainval.close()
ftrain.close()
fval.close()
ftest .close()

2.用YOLOv2训练
按darknet的说明编译好后,接下来在darknet-master/scripts文件夹中新建文件夹VOCdevkit,然后将整个VOC2007文件夹都拷到VOCdevkit文件夹下。
然后,需要利用scripts文件夹中的voc_label.py文件生成一系列训练文件和label,具体操作如下:
首先需要修改voc_label.py中的代码,这里主要修改数据集名,以及类别信息,我的是VOC2007,并且所有样本用来训练,没有val或test,并且只检测海参,故只有一类目标,因此按如下设置

import xml.etree.ElementTree as ET
import pickle
import os
from os import listdir, getcwd
from os.path import join

#sets=[('2012', 'train'), ('2012', 'val'), ('2007', 'train'), ('2007', 'val'), ('2007', 'test')]

#classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]

sets=[('2007', 'train')]
classes = [ "haishen"]


def convert(size, box):
    dw = 1./size[0]
    dh = 1./size[1]
    x = (box[0] + box[1])/2.0
    y = (box[2] + box[3])/2.0
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x*dw
    w = w*dw
    y = y*dh
    h = h*dh
    return (x,y,w,h)

def convert_annotation(year, image_id):
    in_file = open('VOCdevkit/VOC%s/Annotations/%s.xml'%(year, image_id))  #(如果使用的不是VOC而是自设置数据集名字,则这里需要修改)
    out_file = open('VOCdevkit/VOC%s/labels/%s.txt'%(year, image_id), 'w')  #(同上)
    tree=ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
        bb = convert((w,h), b)
        out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')

wd = getcwd()

for year, image_set in sets:
    if not os.path.exists('VOCdevkit/VOC%s/labels/'%(year)):
        os.makedirs('VOCdevkit/VOC%s/labels/'%(year))
    image_ids = open('VOCdevkit/VOC%s/ImageSets/Main/%s.txt'%(year, image_set)).read().strip().split()
    list_file = open('%s_%s.txt'%(year, image_set), 'w')
    for image_id in image_ids:
        list_file.write('%s/VOCdevkit/VOC%s/JPEGImages/%s.jpg\n'%(wd, year, image_id))
        convert_annotation(year, image_id)
    list_file.close()

修改好后在该目录下运行命令:python voc_label.py,之后则在文件夹scripts\VOCdevkit\VOC2007下生成了文件夹lable,该文件夹下的画风是这样的
YOLOv2训练自己的数据集(识别海参)_第3张图片

这里包含了类别和对应归一化后的位置(i guess,如有错请指正)。同时在scripts\下应该也生成了train_2007.txt这个文件,里面包含了所有训练样本的绝对路径。
2.配置文件修改
做好了上述准备,就可以根据不同的网络设置(cfg文件)来训练了。在文件夹cfg中有很多cfg文件,应该跟caffe中的prototxt文件是一个意思。这里以tiny-yolo-voc.cfg为例,该网络是yolo-voc的简版,相对速度会快些。主要修改参数如下

.  
.  
.  
[convolutional]  
size=1  
stride=1  
pad=1  
filters=30  //修改最后一层卷积层核参数个数,计算公式是依旧自己数据的类别数filter=num×(classes + coords + 1)=5×(1+4+1)=30  
activation=linear  
  
[region]  
anchors = 1.08,1.19,  3.42,4.41,  6.63,11.38,  9.42,5.11,  16.62,10.52  
bias_match=1  
classes=1  //类别数,本例为1类  
coords=4  
num=5  
softmax=1  
jitter=.2  
rescore=1  
  
object_scale=5  
noobject_scale=1  
class_scale=1  
coord_scale=1  
  
absolute=1  
thresh = .6  
random=1 

另外也可根据需要修改learning_rate、max_batches等参数。这里歪个楼吐槽一下其他网络配置,一开始是想用tiny.cfg来训练的官网作者说它够小也够快,但是它的网络配置最后几层是这样的画风:

[convolutional]  
filters=1000  
size=1  
stride=1  
pad=1  
activation=linear  
  
[avgpool]  
  
[softmax]  
groups=1  
  
[cost]  
type=sse

修改好了cfg文件之后,就需要修改两个文件,首先是data文件下的voc.names。打开voc.names文件可以看到有20类的名称,本例中只有一类,检测海参,因此将原来所有内容清空,仅写上person并保存。名字仍然用这个名字,如果喜欢用其他名字则请按一开始制作自己数据集的时候的名字来修改。

  接着需要修改cfg文件夹中的voc.data文件。也是按自己需求修改,我的修改之后是这样的画风:
    classes= 1  //类别数  
    train  = /home/xiao_run/darknet-master/scripts/2007_train.txt  //训练样本的绝对路径文件,也就是上文2.1中最后生成的  
    //valid  = /home/pjreddie/data/voc/2007_test.txt  //本例未用到  
    names = data/voc.names  //上一步修改的voc.names文件  
    backup = /home/xiao_run/darknet-master/results/  //指示训练后生成的权重放在哪  

修改后按原名保存最好,接下来就可以训练了。
3.运行训练

  上面完成了就可以命令训练了,可以在官网上找到一些预训练的模型作为参数初始值,也可以直接训练,训练命令为
$./darknet detector train ./cfg/voc.data cfg/tiny-yolo-voc.cfg 

如果用官网的预训练模型darknet.conv.weights做初始化,则训练命令为

$./darknet detector train ./cfg/voc.data .cfg/tiny-yolo-voc.cfg darknet.conv.weights  

训练过程中会根据迭代次数保存训练的权重模型,然后就可以拿来测试了,测试的命令同理:

./darknet detector test cfg/voc.data cfg/tiny-yolo-voc.cfg results/tiny-yolo-voc_6000.weights data/images.jpg

你可能感兴趣的:(python-opencv,图像处理)