https://pan.baidu.com/s/1q7I3-p_vWtKOmCWyg1f89Q
GitHub - ultralytics/yolov5 at v6.0
百度网盘 请输入提取码百度网盘为您提供文件的网络备份、同步和分享服务。空间大、速度快、安全稳固,支持教育网加速,支持手机端。注册使用百度网盘即可享受免费存储空间https://pan.baidu.com/s/1SyaOF1DV0Er-hfwtHtaK8g
写文章-CSDN创作中心https://mp.csdn.net/mp_blog/creation/editor/124817874
注意:在自己的笔记本上面训练,训练好之后将best.wts文件复制到nano中的tensorrtx-master/yolov5/中
具体制作方法按照B站 炮哥带你学 当中进行环境部署和数据集制作
1)hat.yaml用于加载数据集的路径
该文件中需要修改 (nc、names) nc为自己要识别的类别个数,4表示识别4类
names当中为具体识别的类别
train.py中修改 主要修改前三个红框里面的路径,根据自己的需要修改epoch中训练次数
将训练好的模型文件best.pt(在runs/train/expn中,expn为train中最后一个文件夹)复制到yolov5 V6.0的一级目录下
新建一个gen_wts.py文件
import sys
import argparse
import os
import struct
import torch
from utils.torch_utils import select_device
def parse_args():
parser = argparse.ArgumentParser(description='Convert .pt file to .wts')
parser.add_argument('-w', '--weights', required=True, help='Input weights (.pt) file path (required)')
parser.add_argument('-o', '--output', help='Output (.wts) file path (optional)')
args = parser.parse_args()
if not os.path.isfile(args.weights):
raise SystemExit('Invalid input file')
if not args.output:
args.output = os.path.splitext(args.weights)[0] + '.wts'
elif os.path.isdir(args.output):
args.output = os.path.join(
args.output,
os.path.splitext(os.path.basename(args.weights))[0] + '.wts')
return args.weights, args.output
pt_file, wts_file = parse_args()
# Initialize
device = select_device('cpu')
# Load model
model = torch.load(pt_file, map_location=device) # load to FP32
model = model['ema' if model.get('ema') else 'model'].float()
# update anchor_grid info
anchor_grid = model.model[-1].anchors * model.model[-1].stride[...,None,None]
# model.model[-1].anchor_grid = anchor_grid
delattr(model.model[-1], 'anchor_grid') # model.model[-1] is detect layer
model.model[-1].register_buffer("anchor_grid",anchor_grid) #The parameters are saved in the OrderDict through the "register_buffer" method, and then saved to the weight.
model.to(device).eval()
with open(wts_file, 'w') as f:
f.write('{}\n'.format(len(model.state_dict().keys())))
for k, v in model.state_dict().items():
vr = v.reshape(-1).cpu().numpy()
f.write('{} {} '.format(k, len(vr)))
for vv in vr:
f.write(' ')
f.write(struct.pack('>f' ,float(vv)).hex())
f.write('\n')
使用该命令运行: python gen_wts.py -w best.pt -o best.wts gen_wts.py程序(即可生成下图best.wts文件)
1.将下载的tensorrtx-master复制到nano中
2.将自己电脑生成的best.wts文件复制到nano中的tensorrtx-master/yolov5/中
3.在yololayer.hz中修改
4.将原来的build文件删除,在tensorrtx-master路径下打开终端
输入
mkdir build
cd build
将生成的wts文件复制到build下
cmake ..
make -j4
sudo ./yolov5 -s best.wts best.engine n #生成engine文件
5.运行tensorrt-led.py即可使用CSI采集图像进行识别