YOLOX official link (https://github.com/Megvii-BaseDetection/YOLOX)
# Server -- PyTorch environment
git clone git@github.com:Megvii-BaseDetection/YOLOX.git
cd YOLOX
pip3 install -v -e . # or python3 setup.py develop
Windows: reference environment setup and code changes
conda create -n yolox python=3.7  # python 3.8 also works
conda activate yolox
# If you have switched to a domestic mirror, you can drop the trailing -c pytorch.
conda install pytorch=1.7 torchvision cudatoolkit=10.2 -c pytorch
git clone git@github.com:Megvii-BaseDetection/YOLOX.git
cd YOLOX
pip install -r requirements.txt
python setup.py develop
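Before running the demo, a quick sanity check (not part of the YOLOX repo, just a minimal snippet) confirms that PyTorch and CUDA are wired up correctly:
# Minimal environment check for the install above
import torch

print(torch.__version__)           # expect something like 1.7.x
print(torch.cuda.is_available())   # True means cudatoolkit 10.2 matches the driver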
Installation complete. Download the pretrained model yolox_s and run the test:
python tools/demo.py image -f exps/default/yolox_s.py -c yolox_s.pth --path assets/dog.jpg --conf 0.25 --nms 0.45 --tsize 640 --save_result --device gpu
VOC dataset
Download the VOC dataset
As a first verification I only use part of VOC2007. Since I later want to do single-class detection, only the car class from VOC is used for this check.
8 test images (extraction code: wxyh)
Use Python to pick out the car class from the original data:
import xml.etree.ElementTree as ET
import os


def newImageSets(oldSets, newSets):
    # Save the ids whose annotation file (containing a car) exists
    savelist = []
    with open(oldSets, 'r') as f:
        for line in f.readlines():
            ids = int(line)
            path_i = 'Annotations/%06d.xml' % ids
            if os.path.exists(path_i):
                print(path_i)
                savelist.append(line)
    with open(newSets, 'a') as f1:
        for id in savelist:
            f1.write(id)
    return


def selectCarAnn(srcAnnPath, dstAnnPath):
    srcPath = os.path.join(srcAnnPath, "%06d.xml")
    dstPath = os.path.join(dstAnnPath, "%06d.xml")
    count = 0
    # Iterate over all annotation files
    for id in range(1, 9964):
        _path = srcPath % id
        rootTree = ET.parse(_path)
        target = rootTree.getroot()
        # Check whether this annotation file contains a car
        carFlag = False
        for obj in target.iter("object"):
            name = obj.find("name").text.strip()
            if name == 'car':
                carFlag = True
                # print("name: ", _path, "--", name)
        # If there is a car, remove every non-car bounding box
        if carFlag:
            count += 1
            # print(count)
            # Collect the non-car objects to be removed
            rm_list = []
            for obj in target.iter("object"):
                name = obj.find("name").text.strip()
                if name != 'car':
                    rm_list.append(obj)
            for o in rm_list:
                target.remove(o)
            rootTree.write(dstPath % id)
    print(count)
    return


def main():
    selectCarAnn("Annotations", "Annotations_new")
    newImageSets("ImageSets/Main/test.txt", "ImageSets/Main/test_new.txt")
    return


if __name__ == '__main__':
    main()
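As a quick sanity check on the filtering, the sketch below (it only assumes the Annotations_new folder written by selectCarAnn above) counts how many files and car boxes came out:
# Count the car-only annotation files and the boxes they contain
# (assumes the Annotations_new folder produced by selectCarAnn above).
import os
import xml.etree.ElementTree as ET

num_files = 0
num_boxes = 0
for name in os.listdir("Annotations_new"):
    if not name.endswith(".xml"):
        continue
    num_files += 1
    root = ET.parse(os.path.join("Annotations_new", name)).getroot()
    num_boxes += sum(1 for _ in root.iter("object"))
print(num_files, "files,", num_boxes, "car boxes")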
Hand dataset
Download the Hand dataset
This dataset was packaged with MATLAB, so the .mat annotation files are parsed with Python here.
To visualize the annotations of the original dataset, use the following Python code:
import scipy.io as scio
import cv2
import random
import colorsys
import os


def loadbox(data):
    # Each entry in data['boxes'][0] is one hand; it stores the four corner
    # points of a (possibly rotated) box, each point in (y, x) order.
    out = []
    for box in data['boxes'][0]:
        p0 = box[0][0][0]
        p1 = box[0][0][1]
        p2 = box[0][0][2]
        p3 = box[0][0][3]
        res = []
        res.append(p0[0])
        res.append(p1[0])
        res.append(p2[0])
        res.append(p3[0])
        out.append(res)
    return out


def get_n_hls_colors(num):
    # Evenly spaced hues with slightly randomized saturation / lightness
    hls_colors = []
    i = 0
    step = 360.0 / num
    while i < 360:
        h = i
        s = 90 + random.random() * 10
        l = 50 + random.random() * 10
        _hlsc = [h / 360.0, l / 100.0, s / 100.0]
        hls_colors.append(_hlsc)
        i += step
    return hls_colors


def ncolors(num):
    # Generate `num` visually distinct RGB colors
    rgb_colors = []
    if num < 1:
        return rgb_colors
    hls_colors = get_n_hls_colors(num)
    for hlsc in hls_colors:
        _r, _g, _b = colorsys.hls_to_rgb(hlsc[0], hlsc[1], hlsc[2])
        r, g, b = [int(x * 255.0) for x in (_r, _g, _b)]
        rgb_colors.append([r, g, b])
    return rgb_colors


def drawImg(boxes, img):
    print(img.shape)
    n = len(boxes)
    colors = ncolors(n)
    for bb in boxes:
        n -= 1
        color = (colors[n][2], colors[n][1], colors[n][0])
        # Corner points are stored as (y, x); swap indices for OpenCV's (x, y)
        pt0 = (int(bb[0][1]), int(bb[0][0]))
        pt1 = (int(bb[1][1]), int(bb[1][0]))
        pt2 = (int(bb[2][1]), int(bb[2][0]))
        pt3 = (int(bb[3][1]), int(bb[3][0]))
        # Axis-aligned bounding box enclosing the four corners
        min_x = min(int(bb[0][1]), int(bb[1][1]), int(bb[2][1]), int(bb[3][1]))
        max_x = max(int(bb[0][1]), int(bb[1][1]), int(bb[2][1]), int(bb[3][1]))
        min_y = min(int(bb[0][0]), int(bb[1][0]), int(bb[2][0]), int(bb[3][0]))
        max_y = max(int(bb[0][0]), int(bb[1][0]), int(bb[2][0]), int(bb[3][0]))
        # cv2.line(img, pt0, pt1, color, 3)
        # cv2.line(img, pt0, pt3, color, 3)
        # cv2.line(img, pt1, pt2, color, 3)
        # cv2.line(img, pt2, pt3, color, 3)
        cv2.rectangle(img, (min_x, min_y), (max_x, max_y), color, 1)
    return 0


def process_single_pair(imgPath, matPath):
    img = cv2.imread(imgPath)
    data = scio.loadmat(matPath)
    drawImg(loadbox(data), img)
    cv2.imshow("tt", img)
    return 0


def vis_folder(imgFolder, matFolder):
    files = os.listdir(imgFolder)
    for f in files:
        if f.endswith(".jpg"):
            imgPath = os.path.join(imgFolder, f)
            print(imgPath)
            matPath = os.path.join(matFolder, f[:-3] + "mat")
            if os.path.exists(imgPath) and os.path.exists(matPath):
                process_single_pair(imgPath, matPath)
                if cv2.waitKey(0) == ord('q'):
                    cv2.destroyAllWindows()
                    break
    return 0


if __name__ == '__main__':
    vis_folder("test_dataset/test_data/images/", "test_dataset/test_data/annotations/")
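If you want to peek at the raw .mat layout before running the visualizer, a minimal sketch looks like this; the file name below is only a placeholder for any annotation file in the folder:
# Inspect one annotation file to see how the boxes are stored
# ("some_file.mat" is a placeholder; use any .mat from the annotations folder).
import scipy.io as scio

data = scio.loadmat("test_dataset/test_data/annotations/some_file.mat")
print(data.keys())           # contains the 'boxes' entry used above
print(data['boxes'].shape)   # (1, number_of_hands)
for box in data['boxes'][0]:
    # the first four fields of each box are its corner points, stored as (y, x)
    print([box[0][0][i][0] for i in range(4)])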
Code to convert the hand dataset to VOC format:
import copy
import scipy.io as scio
import cv2
import os
import xml.etree.ElementTree as ET


class HandData2VOC():
    def __init__(self):
        self.srcHandAnnotationsDir = "hand_dataset/annotations/"
        self.dstHandAnnotationsDir = "voc_style/Annotations/"
        self.imageDir = 'voc_style/JPEGImages/'
        self.imageSets = 'voc_style/ImageSets/'
        # A VOC annotation file with a single object, used as the XML template
        self.xmlAnnotation = "000004.xml"

    def test(self):
        return

    def runTransfor(self, matName):
        rootTree = ET.parse(self.xmlAnnotation)
        target = rootTree.getroot()
        saveDir = os.path.join(self.dstHandAnnotationsDir, matName[:-3] + "xml")
        # print("==============1. saveDir: ", saveDir)
        ### change folder name && filename
        nodeFolder = target.find('folder')
        nodeFolder.text = "Hand_data"
        imageName = matName[:-3] + 'jpg'
        nodeFileName = target.find("filename")
        nodeFileName.text = imageName
        ######################## Get Image ##############################
        imgDir = os.path.join(self.imageDir, imageName)
        # print("==============2. imgDir: ", imgDir)
        img = cv2.imread(imgDir)
        h, w, c = img.shape
        nodeSize = target.find('size')
        nodewidth = nodeSize.find('width')
        nodewidth.text = str(w)
        nodeheight = nodeSize.find("height")
        nodeheight.text = str(h)
        nodec = nodeSize.find("depth")
        nodec.text = str(c)
        objNodei = target.find("object")
        objNodei.find("name").text = "hand"
        ######################## Get Boxes ###############################
        matDir = os.path.join(self.srcHandAnnotationsDir, matName)
        data = scio.loadmat(matDir)
        # print("==============3. matDir: ", matDir)
        boxes_data = data['boxes'][0]
        numBox = len(boxes_data)
        print("==============4. numBox: ", numBox)
        if numBox == 0:
            print('----------------------->: ', matName)
            return 0
        if numBox > 0:
            # Fill the template's existing <object> with the first box
            box = boxes_data[0]
            p0 = box[0][0][0][0]
            p1 = box[0][0][1][0]
            p2 = box[0][0][2][0]
            p3 = box[0][0][3][0]
            # Corner points are (y, x); convert to an axis-aligned box
            min_x = min(int(p0[1]), int(p1[1]), int(p2[1]), int(p3[1]))
            max_x = max(int(p0[1]), int(p1[1]), int(p2[1]), int(p3[1]))
            min_y = min(int(p0[0]), int(p1[0]), int(p2[0]), int(p3[0]))
            max_y = max(int(p0[0]), int(p1[0]), int(p2[0]), int(p3[0]))
            # objNodei = target.find("object")
            nodeBndBox = objNodei.find("bndbox")
            nodeBndBox.find("xmin").text = str(min_x)
            nodeBndBox.find("ymin").text = str(min_y)
            nodeBndBox.find("xmax").text = str(max_x)
            nodeBndBox.find("ymax").text = str(max_y)
            # return 1
        # Every additional box gets a deep copy of the template <object>
        for bb in data['boxes'][0][1:]:
            nodeObjCp = copy.deepcopy(objNodei)
            p_0 = bb[0][0][0][0]
            p_1 = bb[0][0][1][0]
            p_2 = bb[0][0][2][0]
            p_3 = bb[0][0][3][0]
            xmin = min(int(p_0[1]), int(p_1[1]), int(p_2[1]), int(p_3[1]))
            xmax = max(int(p_0[1]), int(p_1[1]), int(p_2[1]), int(p_3[1]))
            ymin = min(int(p_0[0]), int(p_1[0]), int(p_2[0]), int(p_3[0]))
            ymax = max(int(p_0[0]), int(p_1[0]), int(p_2[0]), int(p_3[0]))
            nodeBndBoxi = nodeObjCp.find("bndbox")
            nodeBndBoxi.find("xmin").text = str(xmin)
            nodeBndBoxi.find("ymin").text = str(ymin)
            nodeBndBoxi.find("xmax").text = str(xmax)
            nodeBndBoxi.find("ymax").text = str(ymax)
            target.append(nodeObjCp)
            print(xmin, " ", ymin, " ", xmax, " ", ymax)
        rootTree.write(saveDir)
        print("===========save: ", saveDir)
        return 1

    def run(self):
        annotations = os.listdir(self.srcHandAnnotationsDir)
        # imageSet = "[]"
        setDir = os.path.join(self.imageSets, "test0.txt")
        with open(setDir, 'w') as file:
            for mat in annotations:
                # print(mat)
                if self.runTransfor(mat):
                    file.write(mat[:-4] + "\n")
        return 0
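A minimal way to run the conversion (assuming the voc_style/Annotations, voc_style/JPEGImages and voc_style/ImageSets folders already exist and the single-object template 000004.xml is in the working directory):
# Convert every .mat file under hand_dataset/annotations/ to a VOC xml
# (the voc_style/* folders and the 000004.xml template must exist beforehand).
if __name__ == '__main__':
    converter = HandData2VOC()
    converter.run()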
Full data and code (extraction code: 1fut)
Data processing done. Now check how the processed result looks:
import os
import cv2
import numpy as np
import xml.etree.ElementTree as ET


def parserXml(xmlPath):
    target = ET.parse(xmlPath).getroot()
    res = np.empty((0, 5))
    keep_difficult = True  # keep boxes marked as difficult
    for obj in target.iter("object"):
        difficult = obj.find("difficult")
        if difficult is not None:
            difficult = int(difficult.text) == 1
        else:
            difficult = False
        if not keep_difficult and difficult:
            continue
        name = obj.find("name").text.strip()
        bbox = obj.find("bndbox")
        pts = ["xmin", "ymin", "xmax", "ymax"]
        bndbox = []
        for i, pt in enumerate(pts):
            cur_pt = int(float(bbox.find(pt).text)) - 1
            # scale height or width
            # cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
            bndbox.append(cur_pt)
        label_idx = 1  # single class (hand), hard-coded label index
        bndbox.append(label_idx)
        res = np.vstack((res, bndbox))  # [xmin, ymin, xmax, ymax, label_ind]
    width = int(target.find("size").find("width").text)
    height = int(target.find("size").find("height").text)
    img_info = (height, width)
    print(img_info)
    return res


def showVocStyleData(imgDir, annDir):
    for ann in os.listdir(annDir):
        annPath = os.path.join(annDir, ann)
        imgPath = os.path.join(imgDir, ann[:-3] + "jpg")
        img = cv2.imread(imgPath)
        print(img.shape)
        res = parserXml(annPath)
        # print(res)
        for i in range(res.shape[0]):
            resi = res[i]
            print(resi)
            cv2.rectangle(img, (int(resi[0]), int(resi[1])), (int(resi[2]), int(resi[3])), (0, 200, 0), 2)
        cv2.imshow('test', img)
        if cv2.waitKey(0) & 0xff == 27:
            cv2.destroyAllWindows()
            break
    return
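To run the check, point it at the folders written by the converter (paths taken from the HandData2VOC defaults above):
# Overlay the converted VOC boxes on the corresponding images
if __name__ == '__main__':
    showVocStyleData("voc_style/JPEGImages/", "voc_style/Annotations/")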