使用jwyang的pytorch框架代码,由于主分支支持的pytorch版本为0.4,年代久远,可能出现许多问题,这边选择pytorch-1.0分支。
创建文件夹data和子文件夹pretrained_model、VOCdevkit2007,用于存放预训练模型和数据集
cd faster-rcnn.pytorch-pytorch-1.0 && mkdir data
cd data && mkdir pretrained_model VOCdevkit2007
下载官方在caffe上的训练的预训练模型,VGG和ResNet选一即可,本文以ResNet为例,VT Server服务商国内可以裸连下载。下载完成后,将该模型放置到./data/pretrained_model
下载VOC数据集,三个压缩包解压后都会落在同一个VOCdevkit目录下,将该目录中的内容放置到./data/VOCdevkit2007(最终应存在./data/VOCdevkit2007/VOC2007)
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCdevkit_08-Jun-2007.tar
tar xvf VOCtrainval_06-Nov-2007.tar
tar xvf VOCtest_06-Nov-2007.tar
tar xvf VOCdevkit_08-Jun-2007.tar
创建虚拟环境,完成环境配置:pytorch版本1.0.0,torchvision版本0.2.1,scipy版本1.2.1,其余依赖按requirements.txt的要求默认安装即可,使用pip安装时建议使用豆瓣源。
conda create -n fr python=3.6
conda activate fr
pip install torch==1.0.0 torchvision==0.2.1 scipy==1.2.1 -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com
使用最新的官方coco api替换原代码的coco api
git clone https://github.com/cocodataset/cocoapi.git
cd cocoapi/PythonAPI
make
make时注意在创建的fr环境下进行,否则可能导致失败,make完成后,用~/cocoapi/PythonAPI/pycocotools 替换~/faster-rcnn.pytorch-pytorch-1.0/lib/pycocotools
cd ~/faster-rcnn.pytorch-pytorch-1.0/lib
python setup.py build develop
进行训练,使用GPU0,网络采用ResNet101,batch size设置为4(8G显存下可行),线程设置为2,运行100 epochs
CUDA_VISIBLE_DEVICES=0 python trainval_net.py --dataset pascal_voc --net res101 --bs 4 --nw 2 --epochs 100 --cuda
使用GPU0,网络采用ResNet101,载入faster_rcnn_1_20_470.pth模型
CUDA_VISIBLE_DEVICES=0 python test_net.py --dataset pascal_voc --net res101 --checksession 1 --checkepoch 20 --checkpoint 470 --cuda
在./images存放需要画框的图片,输入指令运行,参数与上类似
CUDA_VISIBLE_DEVICES=0 python demo.py --net res101 \
--checksession 1 --checkepoch 100 --checkpoint 470 \
--cuda --load_dir ./models
所要做的就是将自己的数据集替换VOC2007,由于我的数据集是yolo格式,因此要先进行转换
from xml.dom.minidom import Document
import os
import cv2
from tqdm import tqdm
# def makexml(txtPath, xmlPath, picPath): # txt所在文件夹路径,xml文件保存路径,图片所在文件夹路径
def makexml(picPath, txtPath, xmlPath):
    """Convert YOLO-format txt annotations into PASCAL VOC-style xml files.

    For every ``.txt`` label file in ``txtPath`` the ``.jpg`` image with the
    same stem is read from ``picPath`` to obtain its size, and an xml file
    with that stem is written to ``xmlPath``.

    Args:
        picPath: Directory containing the ``.jpg`` images.
        txtPath: Directory containing the YOLO ``.txt`` label files.
        xmlPath: Directory the ``.xml`` files are written to (created if
            it does not exist).

    Raises:
        FileNotFoundError: If the image belonging to a label file cannot
            be read.
        KeyError: If a label line uses a class id missing from ``dic``.
        ValueError: If a label line does not hold four numeric box values.
    """
    # Maps the YOLO class id (first column of a label line) to the class
    # name written into the xml; adapt this table to your own dataset.
    dic = {'0': "holothurian",
           '1': "echinus",
           '2': 'scallop',
           '3': 'starfish',
           '4': 'waterweeds'
           }

    def add_text_element(doc, parent, tag, text):
        """Append ``<tag>text</tag>`` to ``parent``."""
        element = doc.createElement(tag)
        element.appendChild(doc.createTextNode(str(text)))
        parent.appendChild(element)

    def clamp(value, upper):
        """Limit a 1-based pixel coordinate to the range [1, upper]."""
        return min(max(value, 1), upper)

    if xmlPath:
        os.makedirs(xmlPath, exist_ok=True)

    # Only label files are converted; sorting keeps runs reproducible.
    label_files = sorted(
        name for name in os.listdir(txtPath) if name.lower().endswith('.txt')
    )
    for name in tqdm(label_files):
        stem = os.path.splitext(name)[0]

        image_file = os.path.join(picPath, stem + ".jpg")
        img = cv2.imread(image_file)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(
                "cannot read image %r for label file %r" % (image_file, name)
            )
        Pheight, Pwidth, Pdepth = img.shape

        with open(os.path.join(txtPath, name)) as txtFile:
            txtList = txtFile.readlines()

        xmlBuilder = Document()
        annotation = xmlBuilder.createElement("annotation")
        xmlBuilder.appendChild(annotation)

        add_text_element(xmlBuilder, annotation, "folder",
                         "driving_annotation_dataset")
        add_text_element(xmlBuilder, annotation, "filename", stem + ".jpg")

        size = xmlBuilder.createElement("size")
        add_text_element(xmlBuilder, size, "width", Pwidth)
        add_text_element(xmlBuilder, size, "height", Pheight)
        add_text_element(xmlBuilder, size, "depth", Pdepth)
        annotation.appendChild(size)

        for line in txtList:
            oneline = line.split()
            if not oneline:
                continue  # tolerate blank lines in the label file
            # YOLO stores the box as normalised centre x/y and width/height.
            cx, cy, bw, bh = (float(value) for value in oneline[1:5])

            obj = xmlBuilder.createElement("object")
            add_text_element(xmlBuilder, obj, "name", dic[oneline[0]])
            add_text_element(xmlBuilder, obj, "pose", "Unspecified")
            add_text_element(xmlBuilder, obj, "truncated", "0")
            add_text_element(xmlBuilder, obj, "difficult", "0")

            # The "+ 1" makes the pixel coordinates 1-based; the clamp keeps
            # boxes touching the border inside [1, width] x [1, height].
            xmin = clamp(int((cx * Pwidth + 1) - bw * 0.5 * Pwidth), Pwidth)
            ymin = clamp(int((cy * Pheight + 1) - bh * 0.5 * Pheight), Pheight)
            xmax = clamp(int((cx * Pwidth + 1) + bw * 0.5 * Pwidth), Pwidth)
            ymax = clamp(int((cy * Pheight + 1) + bh * 0.5 * Pheight), Pheight)

            bndbox = xmlBuilder.createElement("bndbox")
            add_text_element(xmlBuilder, bndbox, "xmin", xmin)
            add_text_element(xmlBuilder, bndbox, "ymin", ymin)
            add_text_element(xmlBuilder, bndbox, "xmax", xmax)
            add_text_element(xmlBuilder, bndbox, "ymax", ymax)
            obj.appendChild(bndbox)
            annotation.appendChild(obj)

        # Open the file as utf-8 so its bytes match the encoding announced
        # in the xml declaration that writexml emits.
        xml_file = os.path.join(xmlPath, stem + ".xml")
        with open(xml_file, 'w', encoding='utf-8') as f:
            xmlBuilder.writexml(f, indent='\t', newl='\n', addindent='\t',
                                encoding='utf-8')
if __name__ == "__main__":
    # Before running, adapt the class table `dic` inside makexml.
    # Every directory below must keep its trailing "/".
    picPath = "./pic/train/"  # directory containing the images
    txtPath = "./labels/train/"  # directory containing the YOLO txt labels
    xmlPath = "./xml/train/"  # directory the xml files are written to
    makexml(picPath=picPath, txtPath=txtPath, xmlPath=xmlPath)
import glob
# Image directories of each split (keep the trailing "/": the paths are
# concatenated with file patterns and file names as plain strings).
train_image_path = "./pic/train/"
# NOTE(review): trainval and test share the same directory here; confirm
# that the train+val index is really meant to list the test images.
trainvalid_image_path = test_image_path = "./pic/test/"
# Directory the generated index txt files are written to.
txt_path = "./txt/"
def generate_train_and_val(image_path, txt_file):
    """Write the stem of every ``.jpg`` in ``image_path`` to ``txt_file``.

    One stem (file name without directory and extension) is written per
    line, in sorted order so that repeated runs give the same index file.

    Args:
        image_path: Directory that is searched for ``*.jpg`` files.
        txt_file: Path of the index file to write; its parent directory
            is created if it does not exist.
    """
    import os  # local import keeps this snippet usable on its own

    out_dir = os.path.dirname(txt_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # basename/splitext handle both "/" and "\\" separators and keep dots
    # that are part of the stem (e.g. "img.v2.jpg" -> "img.v2").
    stems = sorted(
        os.path.splitext(os.path.basename(jpg_file))[0]
        for jpg_file in glob.glob(os.path.join(image_path, '*.jpg'))
    )
    with open(txt_file, 'w') as tf:
        tf.writelines(stem + '\n' for stem in stems)
# Write one index file per split into txt_path.
generate_train_and_val(train_image_path, txt_path + 'train.txt')
# The VOC ImageSets/Main index for the train+val split is named
# "trainval.txt"; a file called "trainvalid.txt" would not be picked up.
generate_train_and_val(trainvalid_image_path, txt_path + 'trainval.txt')
generate_train_and_val(test_image_path, txt_path + 'test.txt')
VOC2007数据集格式
–VOC2007
-------------Annotations
--------------ImageSets
-------------------------Main
----------------------------trainval.txt
----------------------------train.txt
----------------------------test.txt
-------------JPEGImages
Annotations存放所有的XML文件,JPEGImages存放所有图片,./ImageSets/Main存放索引TXT文件,将自己的数据集按此替换即可
# before: the class tuple lists the stock PASCAL VOC object classes
self._classes = ('__background__',  # always index 0
                 'aeroplane', 'bicycle', 'bird', 'boat',
                 'bottle', 'bus', 'car', 'cat', 'chair',
                 'cow', 'diningtable', 'dog', 'horse',
                 'motorbike', 'person', 'pottedplant',
                 'sheep', 'sofa', 'train', 'tvmonitor')
# after: the class tuple lists the five classes of the custom dataset;
# the names must equal the <name> values written into the xml annotations
self._classes = ('__background__',  # always index 0
                 'holothurian',
                 'echinus',
                 'scallop',
                 'starfish',
                 'waterweeds')
# before: the class name read from the xml is lower-cased before the lookup
cls = self._class_to_ind[obj.find('name').text.lower().strip()]
# after: the name is looked up as written (only surrounding whitespace is
# stripped), so class names containing upper-case letters can be found
cls = self._class_to_ind[obj.find('name').text.strip()]
同时将pascal_classes的类别进行相应更改。
#before
def append_flipped_images(self):
    """Append a horizontally mirrored copy of every roidb entry.

    For each of the first ``num_images`` entries the x coordinates of the
    boxes are mirrored against the image width and a new entry marked
    ``'flipped': True`` is appended to ``self.roidb``. The image index
    list is doubled afterwards.

    Raises:
        AssertionError: If a mirrored box ends up with x2 < x1.
    """
    num_images = self.num_images
    widths = self._get_widths()
    for i in range(num_images):
        boxes = self.roidb[i]['boxes'].copy()
        # Keep the original x coordinates: both columns are overwritten below.
        oldx1 = boxes[:, 0].copy()
        oldx2 = boxes[:, 2].copy()
        # Mirroring swaps the roles of the left and right edges.
        boxes[:, 0] = widths[i] - oldx2 - 1
        boxes[:, 2] = widths[i] - oldx1 - 1
        assert (boxes[:, 2] >= boxes[:, 0]).all()
        # gt_overlaps and gt_classes are shared with the original entry.
        entry = {'boxes': boxes,
                 'gt_overlaps': self.roidb[i]['gt_overlaps'],
                 'gt_classes': self.roidb[i]['gt_classes'],
                 'flipped': True}
        self.roidb.append(entry)
    self._image_index = self._image_index * 2
#after
def append_flipped_images(self):
    """Append a horizontally mirrored copy of every roidb entry.

    For each of the first ``num_images`` entries the x coordinates of the
    boxes are mirrored against the image width and a new entry marked
    ``'flipped': True`` is appended to ``self.roidb``. The image index
    list is doubled afterwards.

    Mirrored boxes whose left edge ends up right of their right edge
    get the left edge reset to 0 instead of aborting the run.
    """
    num_images = self.num_images
    widths = self._get_widths()
    for i in range(num_images):
        boxes = self.roidb[i]['boxes'].copy()
        # Keep the original x coordinates: both columns are overwritten below.
        oldx1 = boxes[:, 0].copy()
        oldx2 = boxes[:, 2].copy()
        # Mirroring swaps the roles of the left and right edges.
        boxes[:, 0] = widths[i] - oldx2 - 1
        boxes[:, 2] = widths[i] - oldx1 - 1
        # ---------add--------------
        # Workaround for inverted boxes. It has to run AFTER the mirroring:
        # applied before it, the change is overwritten by the two
        # assignments above and the assertion below still fails.
        inverted = boxes[:, 2] < boxes[:, 0]
        boxes[inverted, 0] = 0
        # ---------end--------------
        assert (boxes[:, 2] >= boxes[:, 0]).all()
        # gt_overlaps and gt_classes are shared with the original entry.
        entry = {'boxes': boxes,
                 'gt_overlaps': self.roidb[i]['gt_overlaps'],
                 'gt_classes': self.roidb[i]['gt_classes'],
                 'flipped': True}
        self.roidb.append(entry)
    self._image_index = self._image_index * 2
#before
# Read the bounding box and the "difficult" flag of every annotated object.
for ix, obj in enumerate(objs):
    bbox = obj.find('bndbox')
    # Make pixel indexes 0-based
    x1 = float(bbox.find('xmin').text) - 1
    y1 = float(bbox.find('ymin').text) - 1
    x2 = float(bbox.find('xmax').text) - 1
    y2 = float(bbox.find('ymax').text) - 1
    # A missing <difficult> tag counts as "not difficult" (0).
    diffc = obj.find('difficult')
    difficult = 0 if diffc == None else int(diffc.text)
    ishards[ix] = difficult
#after
# Read the bounding box, class and "difficult" flag of every annotated object.
for ix, obj in enumerate(objs):
    bbox = obj.find('bndbox')
    # ------------change------------------
    # The xml coordinates are used exactly as written: the "- 1" shift to
    # 0-based pixel indexes has been dropped in this version.
    x1 = float(bbox.find('xmin').text)
    y1 = float(bbox.find('ymin').text)
    x2 = float(bbox.find('xmax').text)
    y2 = float(bbox.find('ymax').text)
    # NOTE(review): this lookup still lower-cases the name, unlike the
    # "after" variant of the same lookup shown earlier; the two only agree
    # while all class names are lower-case. Confirm which one is intended.
    cls = self._class_to_ind[obj.find('name').text.lower().strip()]
    # --------------end-------------------
    # A missing <difficult> tag counts as "not difficult" (0).
    diffc = obj.find('difficult')
    difficult = 0 if diffc == None else int(diffc.text)
    ishards[ix] = difficult
最后请检查是否删除了./data/cache,然后开跑。