Contents
Preface
1. Capturing digit images with OpenCV
2. Labeling the digit positions in the images
3. Training yolov4-tiny
4. Recognizing the digits on the Jetson Nano
Ah, the four-day, three-night electronics design contest is finally over! Our team built two medicine-delivery carts; pictures of the finished builds first!
Because time ran short, only one of the carts has a reasonably solid build; the other one is a bit rough... Still, both of them can complete the task. Now let me share how we got it done. For the training and recognition part you can also refer to my earlier post on the training steps for household-garbage classification on the Smart+ track of the 7th National Undergraduate Engineering Training Competition (Win10 + yolov4-tiny).
I wrote an OpenCV photo-capture script in Python. The code is as follows:
import cv2
import os

print("=============================================")
print("= Hotkeys (use them in the camera window):  =")
print("=   z: change the save directory            =")
print("=   x: take a picture                       =")
print("=   q: quit                                 =")
print("=============================================")
print()

class_name = input("Please enter a save directory: ")
while os.path.exists(class_name):
    class_name = input("That directory already exists! Please enter a save directory: ")
os.mkdir(class_name)

index = 1
cap = cv2.VideoCapture(0)
width = 640
height = 480
w = 360
cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
crop_w_start = (width - w) // 2
crop_h_start = (height - w) // 2
print(width, height)

while True:
    # grab a frame and show it
    ret, frame = cap.read()
    # frame = frame[crop_h_start:crop_h_start+w, crop_w_start:crop_w_start+w]
    # frame = cv2.flip(frame, 1, dst=None)
    cv2.imshow("capture", frame)

    key = cv2.waitKey(1) & 0xFF   # renamed from "input" so the built-in input() stays usable
    if key == ord('z'):
        class_name = input("Please enter a save directory: ")
        while os.path.exists(class_name):
            class_name = input("That directory already exists! Please enter a save directory: ")
        os.mkdir(class_name)
    elif key == ord('x'):
        # resize to 224x224 before saving
        cv2.imwrite("%s/%d.jpg" % (class_name, index),
                    cv2.resize(frame, (224, 224), interpolation=cv2.INTER_AREA))
        print("%s: %d images" % (class_name, index))
        index += 1
    if key == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()
First mount the USB camera on the cart, then put the cart on the course and take photos: 100 to 200 images of 224*224 for each digit, saved into its own folder.
Then annotate the photos with the labeling tool labelImg, which generates an XML label file recording each digit's coordinates in the image and its object name:
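labelImg saves one XML file per image in PASCAL VOC format. A trimmed example of the structure is shown below; the file name, image size and box coordinates here are made-up values, only the layout matters:

<annotation>
    <folder>JPEGImages</folder>
    <filename>1.jpg</filename>
    <size>
        <width>224</width>
        <height>224</height>
        <depth>3</depth>
    </size>
    <object>
        <name>one</name>
        <difficult>0</difficult>
        <bndbox>
            <xmin>60</xmin>
            <ymin>48</ymin>
            <xmax>170</xmax>
            <ymax>180</ymax>
        </bndbox>
    </object>
</annotation>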
The training framework is the lightweight yolov4-tiny, trained on a Windows 10 PC. Create a JPEGImages folder for the digit images captured by the OpenCV camera script and an Annotations folder for the labeled XML files (both under VOCdevkit/VOC2021, which is the layout the scripts below assume; txt.py also expects VOCdevkit/VOC2021/ImageSets/Main to exist).
Then write a Python script txt.py that randomly splits the samples into training and validation/test lists. The code is as follows:
import os
import random

# fraction of the samples used for train+val (the rest goes to test),
# and the fraction of train+val used for training (the rest goes to val)
trainval_percent = 0.5
train_percent = 0.5
xmlfilepath = 'VOCdevkit/VOC2021/Annotations'
txtsavepath = 'VOCdevkit/VOC2021/ImageSets/Main'
total_xml = os.listdir(xmlfilepath)

num = len(total_xml)
indices = range(num)   # renamed from "list" to avoid shadowing the built-in
tv = int(num * trainval_percent)
ptr = int(tv * train_percent)
trainval = random.sample(indices, tv)
train = random.sample(trainval, ptr)

ftrainval = open(txtsavepath + '/trainval.txt', 'w')
ftest = open(txtsavepath + '/test.txt', 'w')
ftrain = open(txtsavepath + '/train.txt', 'w')
fval = open(txtsavepath + '/val.txt', 'w')

for i in indices:
    name = total_xml[i][:-4] + '\n'   # strip the ".xml" extension
    if i in trainval:
        ftrainval.write(name)
        if i in train:
            ftrain.write(name)
        else:
            fval.write(name)
    else:
        ftest.write(name)

ftrainval.close()
ftrain.close()
fval.close()
ftest.close()
Then write a Python script voc_label.py that pulls the object names and box coordinates out of the XML label files and writes them into one txt file per image. The code is as follows:
import xml.etree.ElementTree as ET
import os
from os import getcwd

sets = [('2021', 'train'), ('2021', 'val'), ('2021', 'test')]
classes = ["one", "two", "three", "four", "five", "six", "seven", "eight"]


def convert(size, box):
    # convert a VOC box (xmin, xmax, ymin, ymax) in pixels into the YOLO format
    # (x_center, y_center, width, height) normalized by the image size
    dw = 1. / (size[0])
    dh = 1. / (size[1])
    x = (box[0] + box[1]) / 2.0 - 1
    y = (box[2] + box[3]) / 2.0 - 1
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    return (x, y, w, h)


def convert_annotation(year, image_id):
    in_file = open('VOCdevkit/VOC%s/Annotations/%s.xml' % (year, image_id))
    out_file = open('VOCdevkit/VOC%s/labels/%s.txt' % (year, image_id), 'w')
    tree = ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)
    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
             float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
        bb = convert((w, h), b)
        out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')


wd = getcwd()
for year, image_set in sets:
    if not os.path.exists('VOCdevkit/VOC%s/labels/' % (year)):
        os.makedirs('VOCdevkit/VOC%s/labels/' % (year))
    image_ids = open('VOCdevkit/VOC%s/ImageSets/Main/%s.txt' % (year, image_set)).read().strip().split()
    list_file = open('%s_%s.txt' % (year, image_set), 'w')
    for image_id in image_ids:
        list_file.write('%s/VOCdevkit/VOC%s/JPEGImages/%s.jpg\n' % (wd, year, image_id))
        convert_annotation(year, image_id)
    list_file.close()

# build the final image lists that the .data file points to
os.system("cat 2021_train.txt 2021_val.txt > train.txt")
os.system("cat 2021_train.txt 2021_val.txt 2021_test.txt > train.all.txt")
Create a coco.name file that records the names of the classes to be recognized:
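The file simply lists one class name per line, in the same order as the classes list in voc_label.py, so for this project it would be:

one
two
three
four
five
six
seven
eight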
Create a coco.data file that records the number of classes, the path of the training image list, the path of the test image list, the path of the class-name file, and the folder where the generated weight files are saved:
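A darknet .data file has five entries; with the file names implied by the scripts and the training command below it would look roughly like this (the exact paths are assumptions):

classes = 8
train  = train.txt
valid  = 2021_test.txt
names  = object/coco.name
backup = object/backup/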
Create the yolov4-tiny-train cfg file, which records the training image size, the maximum number of training iterations, and so on:
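The cfg values that usually have to be touched for an 8-class yolov4-tiny are sketched below; the concrete numbers are common defaults and rules of thumb from the darknet documentation, not necessarily the ones used on this cart:

[net]
width=416
height=416
batch=64
subdivisions=16
max_batches=16000      # rule of thumb: classes*2000, and at least 6000
steps=12800,14400      # 80% and 90% of max_batches

# in BOTH [yolo] layers:
classes=8
# and in the [convolutional] layer directly before each [yolo]:
filters=39             # (classes + 5) * 3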
Once everything is ready, start training with darknet.exe:
darknet detector train object/coco.data object/yolov4-tiny-train.cfg yolov4-tiny.conv.29
(yolov4-tiny.conv.29 is the pre-trained backbone weight file for yolov4-tiny.)
After about 6500 iterations the loss had dropped to around 0.1, so training was stopped and the weight files were obtained:
Write a Python script testimage.py that loads the weight file, recognizes the digits, post-processes the detections, and sends the result to the STM32F4 MCU over the serial port. The recognition flow: the digit recognized the first time is stored as the target variable target. When that digit is recognized a second time, its position in the frame is checked (say it is on the left side); if the target is the near or middle ward, the character 'l' or 'r' is sent to the MCU over serial ('l' means the target is on the left, 'r' means it is on the right); if it is the far ward, the script waits for the third recognition of the target digit and only then sends 'l' or 'r'. The code is as follows:
# testimage.py (excerpt): the recognition thread. The full script also creates the
# serial port object serial_port (pyserial) and configures GPIO pins 13 and 15 as
# outputs; image_detection() is a helper modeled on the one in darknet_images.py.
# A sketch of those omitted parts is given after this listing.
import time

import cv2
import darknet
import Jetson.GPIO as GPIO   # assumed: the standard GPIO library on the Jetson Nano


def camera_thread():
    network, class_names, class_colors = darknet.load_network('object2/yolov4-tiny-test.cfg',
                                                              'object2/coco.data',
                                                              'object2/backup/yolov4-tiny4.weights',
                                                              batch_size=1)
    cap = cv2.VideoCapture(0)
    width = 640
    height = 480
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
    w = 480
    crop_w_start = (width - w) // 2
    crop_h_start = (height - w) // 2
    while True:
        numbers = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight']
        number = [['number', 0], ['number', 0], ['number', 0], ['number', 0],
                  ['number', 0], ['number', 0], ['number', 0], ['number', 0]]
        x = [0.0, 0.0, 0.0, 0.0]
        y = [0.0, 0.0, 0.0, 0.0]
        worth = 0.0   # running sum of detection confidences (percent)
        s = 1         # running sum of box areas, a rough distance estimate
        flag = 0      # 0: no target yet, 1: waiting for 2nd sighting, 2: waiting for 3rd, 3: done
        counts = 0
        print("start")
        while True:
            key = cv2.waitKey(1) & 0xFF   # renamed from "input" to avoid shadowing the built-in
            if key == ord('x'):
                print("stop")
                break
            ret, frame = cap.read()
            # frame = frame[crop_h_start:crop_h_start+w+60, crop_w_start:crop_w_start+w+60]
            image, detections = image_detection(frame, network, class_names, class_colors, 0.25)
            # darknet.print_detections(detections, True)
            cv2.imshow('Inference', image)
            counts = len(detections)
            # print("detected", counts, "digits")
            for i in range(0, counts, 1):
                # each detection is (label, confidence, (cx, cy, w, h))
                number[i][0] = detections[i][0]
                number[i][1] = int(detections[i][2][0])   # sort key: box centre x
                # print(numbers.index(number[i][0]))
                s = int(s + (detections[i][2][2] * detections[i][2][3]))
                worth = worth + float(detections[i][1])
            if s > counts and counts != 0:
                # serial_port.write(('d' + (str(s / counts))).encode())
                # print('distance:', s / counts)
                s = 0
            else:
                s = 0
            if worth > counts and counts >= 0:
                if worth / counts >= 70:   # only act when the average confidence is at least 70%
                    number_sort = sorted(number, key=lambda x: x[1])   # sort detections left to right
                    # serial_port.write(('n' + (str(counts))).encode())
                    for i in range(0, counts, 1):
                        print(numbers.index(number_sort[7 - i][0]), number_sort[7 - i][1])
                        # serial_port.write((str(numbers.index(number_sort[7 - i][0]))).encode())
                        if flag == 1 and target == numbers.index(number_sort[7 - i][0]):
                            # second sighting of the target digit: decide left or right
                            if (counts == 4 and (i == 0 or i == 1)) or (counts == 2 and (i == 0)):
                                GPIO.output(13, GPIO.LOW)
                                print("right")
                                serial_port.write(str(target).encode())
                                serial_port.write('r'.encode())
                                while True:
                                    if serial_port.inWaiting() > 0:
                                        data = serial_port.read()
                                        # print(data)
                                        if data == b'o':   # read() returns bytes, so compare with bytes
                                            flag = 2
                                        break
                                    key = cv2.waitKey(1) & 0xFF
                                    if key == ord('q'):
                                        break
                                # time.sleep(5)
                            elif (counts == 4 and (i == 2 or i == 3)) or (counts == 2 and (i == 1)):
                                GPIO.output(15, GPIO.LOW)
                                print("left")
                                serial_port.write('l'.encode())
                                while True:
                                    if serial_port.inWaiting() > 0:
                                        data = serial_port.read()
                                        # print(data)
                                        if data == b'o':
                                            flag = 2
                                        elif data == b'k':
                                            flag = 0
                                        break
                                    key = cv2.waitKey(1) & 0xFF
                                    if key == ord('q'):
                                        break
                                print(i)
                        elif flag == 2 and target == numbers.index(number_sort[7 - i][0]):
                            # third sighting of the target digit (far ward): decide left or right
                            if (counts == 4 and (i == 0 or i == 1)) or (counts == 2 and (i == 0)):
                                GPIO.output(13, GPIO.LOW)
                                print("right")
                                serial_port.write('r'.encode())
                                while True:
                                    if serial_port.inWaiting() > 0:
                                        data = serial_port.read()
                                        # print(data)
                                        if data == b'o':
                                            flag = 3
                                        elif data == b'k':
                                            flag = 0
                                        break
                                    key = cv2.waitKey(1) & 0xFF
                                    if key == ord('q'):
                                        break
                                # time.sleep(5)
                            elif (counts == 4 and (i == 2 or i == 3)) or (counts == 2 and (i == 1)):
                                GPIO.output(15, GPIO.LOW)
                                print("left")
                                serial_port.write('l'.encode())
                                while True:
                                    if serial_port.inWaiting() > 0:
                                        data = serial_port.read()
                                        # print(data)
                                        if data == b'o':
                                            flag = 3
                                        break
                                    key = cv2.waitKey(1) & 0xFF
                                    if key == ord('q'):
                                        break
                                print(i)
                        elif flag == 0 and counts == 1:
                            # first sighting: remember the digit as the target ward number
                            target = numbers.index(number[0][0])
                            print("target:", target)
                            if target == 1:
                                serial_port.write('1'.encode())
                                while True:
                                    if serial_port.inWaiting() > 0:
                                        data = serial_port.read()
                                        # print(data)
                                        if data == b'o':
                                            flag = 3
                                        break
                                    key = cv2.waitKey(1) & 0xFF
                                    if key == ord('q'):
                                        break
                            elif target == 2:
                                serial_port.write('2'.encode())
                                while True:
                                    if serial_port.inWaiting() > 0:
                                        data = serial_port.read()
                                        # print(data)
                                        if data == b'o':
                                            flag = 3
                                        break
                                    key = cv2.waitKey(1) & 0xFF
                                    if key == ord('q'):
                                        break
                            else:
                                serial_port.write('t'.encode())
                                serial_port.write('t'.encode())
                                serial_port.write('t'.encode())
                                # time.sleep(1)
                                flag = 1
                            # serial_port.write('o'.encode())
                    print("counts:", counts)
                    number = [['number', 0], ['number', 0], ['number', 0], ['number', 0],
                              ['number', 0], ['number', 0], ['number', 0], ['number', 0]]
                    worth = 0
                else:
                    worth = 0
            k = cv2.waitKey(1)
            time.sleep(0.05)
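For completeness, here is a minimal sketch (not the competition code) of how the rest of testimage.py could wire this together, continuing from the imports above. The UART device, baud rate, pin numbers and the image_detection() helper are assumptions; the helper is modeled on the one in darknet_images.py from the darknet repo:

import threading

import serial

# hypothetical setup; match the UART device and baud rate to the STM32F4 wiring
serial_port = serial.Serial('/dev/ttyTHS1', 115200, timeout=0.1)

GPIO.setmode(GPIO.BOARD)                      # pins 13/15 drive the indicator outputs used above
GPIO.setup(13, GPIO.OUT, initial=GPIO.HIGH)
GPIO.setup(15, GPIO.OUT, initial=GPIO.HIGH)


def image_detection(frame, network, class_names, class_colors, thresh):
    # same idea as the darknet_images.py helper: resize the frame to the network input,
    # run detection, and draw the boxes on the resized frame
    net_w = darknet.network_width(network)
    net_h = darknet.network_height(network)
    darknet_image = darknet.make_image(net_w, net_h, 3)
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame_resized = cv2.resize(frame_rgb, (net_w, net_h), interpolation=cv2.INTER_LINEAR)
    darknet.copy_image_from_bytes(darknet_image, frame_resized.tobytes())
    detections = darknet.detect_image(network, class_names, darknet_image, thresh=thresh)
    darknet.free_image(darknet_image)
    image = darknet.draw_boxes(detections, frame_resized, class_colors)
    return cv2.cvtColor(image, cv2.COLOR_RGB2BGR), detections


if __name__ == '__main__':
    t = threading.Thread(target=camera_thread)
    t.start()
    t.join()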