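Helmet Detection (YOLOv5 + TNN)

Contents

  • Installing PyTorch (GPU) on Ubuntu 18.04, step by step
    • 1. Downloading and installing Anaconda
    • 2. Installing the NVIDIA driver
    • 3. Installing CUDA 10.2
      • 3.1 Installing dependencies
      • 3.2 Installing CUDA 10.2
      • 3.3 Adding CUDA to the path
    • 4. Installing cuDNN
    • 5. Installing PyTorch
  • Training your own object detection model with YOLOv5
    • 1. Cloning the yolov5-5.0 project
    • 2. Installing dependencies
    • 3. Downloading pretrained weights
    • 4. Preparing your own dataset
    • 5. Writing your own configuration files
    • 6. Training your own model
    • 7. Detecting images with the trained model
  • Deploying YOLOv5 with the TNN inference framework on Windows
    • 1. Downloading the TNN-master project
    • 2. Converting .pt to a TNN model (done on Ubuntu 18.04)
    • 3. Converting ONNX to a TNN model
    • 4. Building TNN on Windows (CPU)
      • 4.1 Requirements
      • 4.2 Build steps
    • 5. Running the inference program in VS2019
      • 5.1 Adding paths and libraries
      • 5.2 Code
        • yolo.h
        • yolo.cpp
        • RST_HelmetDetection.h
        • RST_HelmetDetection.cpp
        • fileLoad.h
        • fileLoad.cpp
      • 5.3 Some comparisons and analysis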
      • 5.3 一些比较和分析

Installing PyTorch (GPU) on Ubuntu 18.04, step by step

1. Downloading and installing Anaconda

Download Anaconda3-2021.11-Linux-x86_64.sh (which ships with Python 3.9) from the Anaconda website [https://www.anaconda.com/products/individual]:
[Figure 1]

Open a terminal in the directory containing Anaconda3-2021.11-Linux-x86_64.sh and run:

bash Anaconda3-2021.11-Linux-x86_64.sh

Answer yes and press Enter through the prompts. After the installation succeeds, close the current terminal; a new terminal will show the (base) prompt:
[Figure 2]
Check the Python version:
[Figure 3]

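2. Installing the NVIDIA driver

Check on NVIDIA's website whether your GPU is supported: https://developer.nvidia.com/cuda-gpus
Click to view the list:
[Figure 4]
The command ubuntu-drivers devices lists the detected devices and the available drivers:
[Figure 5]
Install the recommended drivers: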

sudo ubuntu-drivers autoinstall

Reboot after the installation finishes.
After rebooting, run sudo nvidia-smi to verify the installation:
[Figure 6]
The highest CUDA version this driver supports is 11.5.
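The "CUDA Version" in the nvidia-smi header is the newest CUDA runtime the installed driver supports, so an older toolkit such as the CUDA 10.2 installed below works with it. A minimal sketch that extracts that number and compares it with a toolkit version; the header line is an example of the usual nvidia-smi layout and the helper names are made up for illustration:

```python
import re


def max_cuda_from_smi(smi_text):
    """Return the 'CUDA Version' reported by nvidia-smi as a (major, minor) tuple."""
    match = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", smi_text)
    if match is None:
        raise ValueError("no 'CUDA Version' field found")
    return int(match.group(1)), int(match.group(2))


def driver_supports(toolkit, smi_text):
    """True if a toolkit version such as '10.2' is not newer than the driver's limit."""
    major, minor = (int(part) for part in toolkit.split("."))
    return (major, minor) <= max_cuda_from_smi(smi_text)


# Example header line; on a real machine use the text printed by nvidia-smi,
# e.g. subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout
header = "| NVIDIA-SMI 495.44       Driver Version: 495.44       CUDA Version: 11.5     |"
print(max_cuda_from_smi(header))        # (11, 5)
print(driver_supports("10.2", header))  # True
```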

3. Installing CUDA 10.2

3.1 Installing dependencies

sudo apt-get install freeglut3-dev build-essential libx11-dev libxmu-dev
sudo apt-get install libxi-dev libgl1-mesa-glx libglu1-mesa libglu1-mesa-dev

If this fails with a connection error, add package sources; see the post "Ubuntu learning series: adding sources" by don’t quit:
https://blog.csdn.net/weixin_44354586/article/details/89392951

sudo su
sudo gedit /etc/apt/sources.list

Append the sources to the end of sources.list.

After saving the changes, run:

sudo apt-get update
sudo apt upgrade

3.2 Installing CUDA 10.2

Download CUDA 10.2 from NVIDIA: https://developer.nvidia.com/cuda-10.2-download-archive
Archive of earlier CUDA versions: https://developer.nvidia.com/cuda-toolkit-archive
Choose the following options:
[Figure 7]
Download CUDA 10.2 from the terminal:

wget https://developer.download.nvidia.com/compute/cuda/10.2/Prod/local_installers/cuda_10.2.89_440.33.01_linux.run

The installer is saved in the home directory (where wget was run). Open a terminal in the directory containing cuda_10.2.89_440.33.01_linux.run and run:

sudo sh cuda_10.2.89_440.33.01_linux.run

Select Continue:
[Figure 8]
Type accept:
[Figure 9]
An X marks a component that will be installed; deselect Driver (the driver was already installed in step 2) and choose Install:
[Figure 10]
Download and install the patch the same way:
[Figure 11]

3.3 Adding CUDA to the path

Open .bashrc with sudo gedit ~/.bashrc:
[Figure 12]
Append the following lines to the end of the file:

export PATH=$PATH:/usr/local/cuda-10.2/bin  # adjust the path to your CUDA version
export LD_LIBRARY_PATH=/usr/local/cuda-10.2/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}  # adjust the path to your CUDA version

Save and close .bashrc, then run in the terminal:

source ~/.bashrc
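The LD_LIBRARY_PATH line relies on the bash expansion ${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}, which produces ":" followed by the old value only when the variable is already set and non-empty. This avoids a trailing empty entry, which the dynamic loader would treat as the current directory. A small Python rendering of that rule, plus a check of the live environment in a new shell; the helper names are hypothetical:

```python
import os


def prepend_path(new_dir, existing):
    """Python version of bash's  new_dir${VAR:+:${VAR}}  expansion."""
    # ':old' is appended only when the old value is set and non-empty
    return new_dir + (":" + existing if existing else "")


def on_path(directory, value):
    """True if `directory` is one of the colon-separated entries of `value`."""
    return directory in value.split(":")


cuda_lib = "/usr/local/cuda-10.2/lib64"
print(prepend_path(cuda_lib, ""))          # /usr/local/cuda-10.2/lib64
print(prepend_path(cuda_lib, "/opt/lib"))  # /usr/local/cuda-10.2/lib64:/opt/lib

# Run from a shell where ~/.bashrc has been sourced to check the real environment:
print(on_path("/usr/local/cuda-10.2/bin", os.environ.get("PATH", "")))
print(on_path(cuda_lib, os.environ.get("LD_LIBRARY_PATH", "")))
```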

Reboot, then check that CUDA was installed correctly:

cd /usr/local/cuda-10.2/samples/1_Utilities/deviceQuery
sudo make
sudo ./deviceQuery

If the output ends with Result = PASS, the installation succeeded:
[Figure 13]

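4. Installing cuDNN

cuDNN archive: https://developer.nvidia.com/rdp/cudnn-archive
An NVIDIA developer account is required. Several cuDNN releases support CUDA 10.2; download two packages, the runtime library and the developer library:
[Figure 14]
Then run the following in the folder containing the files: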

sudo dpkg -i libcudnn8_8.0.5.39-1+cuda10.2_amd64.deb
sudo dpkg -i libcudnn8-dev_8.0.5.39-1+cuda10.2_amd64.deb
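To confirm which cuDNN version the packages installed, the version macros can be read from cudnn_version.h. The header location /usr/include/cudnn_version.h is an assumption; adjust it if your installation puts the file elsewhere. A minimal parser:

```python
import re


def cudnn_version(header_text):
    """Return 'major.minor.patch' from the CUDNN_* #defines in cudnn_version.h."""
    parts = []
    for name in ("CUDNN_MAJOR", "CUDNN_MINOR", "CUDNN_PATCHLEVEL"):
        match = re.search(r"#define\s+%s\s+(\d+)" % name, header_text)
        if match is None:
            raise ValueError(name + " not found")
        parts.append(match.group(1))
    return ".".join(parts)


sample = "#define CUDNN_MAJOR 8\n#define CUDNN_MINOR 0\n#define CUDNN_PATCHLEVEL 5\n"
print(cudnn_version(sample))  # 8.0.5

# On the machine (header path assumed):
# with open("/usr/include/cudnn_version.h") as f:
#     print(cudnn_version(f.read()))
```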

5. Installing PyTorch

PyTorch website: https://pytorch.org/get-started/locally/
[Figure]

Run in the terminal:

conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch

If it cannot connect for a long time, drop -c pytorch and run:

conda install pytorch torchvision torchaudio cudatoolkit=10.2

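[Figure 15]
If it still cannot connect, switch Anaconda to a mirror in China by adding channels (just enter these commands in the terminal).
Tsinghua University mirror: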

conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
conda config --set show_channel_urls yes

Add third-party conda channels:

conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/msys2/
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/bioconda/
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/menpo/
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/

Test whether the installation succeeded:

python
import torch
torch.__version__

Check whether GPU acceleration is available:

print(torch.cuda.is_available())

Run a small matrix operation to check that everything works:

import torch as t
x = t.rand(5,3)
y = t.rand(5,3)
if t.cuda.is_available():
    x = x.cuda()
    y = y.cuda()
    print(x+y)

If the output looks like the figure below, PyTorch with GPU support is installed:
[Figure 16]

Training your own object detection model with YOLOv5

1. Cloning the yolov5-5.0 project

[Figure 17]

2. Installing dependencies

In the yolov5-5.0 directory, run:

pip install -r requirements.txt

[Figure 18]

3. Downloading pretrained weights

3.1 Download the .pt weights from https://github.com/ultralytics/yolov5/releases/tag/v5.0 and put them in the yolov5-5.0 directory:
[Figure 19]

4. Preparing your own dataset

Prepare the folders:
yolov5-5.0/VOCdevkit/VOC2007/Annotations holds the .xml label files;
yolov5-5.0/VOCdevkit/VOC2007/JPEGImages holds the image files.

[Figure 20]
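Each image in JPEGImages needs an .xml file with the same name in Annotations; images without one are not converted. A quick pre-check, sketched under the assumption that images use .jpg, .jpeg or .png extensions (the function names are hypothetical):

```python
import os

IMAGE_EXTS = (".jpg", ".jpeg", ".png")  # assumed image extensions


def unmatched(image_files, annotation_files):
    """Compare file stems; return (images without .xml, .xml files without an image)."""
    images = {os.path.splitext(f)[0] for f in image_files if f.lower().endswith(IMAGE_EXTS)}
    xmls = {os.path.splitext(f)[0] for f in annotation_files if f.lower().endswith(".xml")}
    return sorted(images - xmls), sorted(xmls - images)


def check_voc_dirs(root="VOCdevkit/VOC2007"):
    """Run the comparison on the JPEGImages/Annotations layout described above."""
    return unmatched(os.listdir(os.path.join(root, "JPEGImages")),
                     os.listdir(os.path.join(root, "Annotations")))


print(unmatched(["00000.jpg", "00001.jpg"], ["00000.xml", "00002.xml"]))
# (['00001'], ['00002'])
```

Calling check_voc_dirs() from the yolov5-5.0 directory lists the problem files before the conversion runs.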
yolov5-5.0/prepare.py converts the VOC .xml labels into YOLO .txt labels and splits the data into a training set and a validation set:
prepare.py

import xml.etree.ElementTree as ET
import os
import random
from shutil import copyfile
 
classes = ["white","blue","red","yellow","orange","none"]

TRAIN_RATIO = 80  # percentage of images assigned to the training set
 
def clear_hidden_files(path):
    dir_list = os.listdir(path)
    for i in dir_list:
        abspath = os.path.join(os.path.abspath(path), i)
        if os.path.isfile(abspath):
            if i.startswith("._"):
                os.remove(abspath)
        else:
            clear_hidden_files(abspath)
 
def convert(size, box):
    dw = 1./size[0]
    dh = 1./size[1]
    x = (box[0] + box[1])/2.0
    y = (box[2] + box[3])/2.0
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x*dw
    w = w*dw
    y = y*dh
    h = h*dh
    return (x,y,w,h)
 
def convert_annotation(image_id):
    in_file = open('VOCdevkit/VOC2007/Annotations/%s.xml' %image_id)
    out_file = open('VOCdevkit/VOC2007/YOLOLabels/%s.txt' %image_id, 'w')
    tree=ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)
 
    for obj in root.iter('object'):
        difficult = obj.find('difficult')
        difficult = int(difficult.text) if difficult is not None else 0  # some .xml files have no <difficult> tag
        cls = obj.find('name').text
        if cls not in classes or difficult == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
        bb = convert((w,h), b)
        out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
    in_file.close()
    out_file.close()
 
wd = os.getcwd()
data_base_dir = os.path.join(wd, "VOCdevkit/")
if not os.path.isdir(data_base_dir):
    os.mkdir(data_base_dir)
work_sapce_dir = os.path.join(data_base_dir, "VOC2007/")
if not os.path.isdir(work_sapce_dir):
    os.mkdir(work_sapce_dir)
annotation_dir = os.path.join(work_sapce_dir, "Annotations/")
if not os.path.isdir(annotation_dir):
        os.mkdir(annotation_dir)
clear_hidden_files(annotation_dir)
image_dir = os.path.join(work_sapce_dir, "JPEGImages/")
if not os.path.isdir(image_dir):
        os.mkdir(image_dir)
clear_hidden_files(image_dir)
yolo_labels_dir = os.path.join(work_sapce_dir, "YOLOLabels/")
if not os.path.isdir(yolo_labels_dir):
        os.mkdir(yolo_labels_dir)
clear_hidden_files(yolo_labels_dir)
yolov5_images_dir = os.path.join(data_base_dir, "images/")
if not os.path.isdir(yolov5_images_dir):
        os.mkdir(yolov5_images_dir)
clear_hidden_files(yolov5_images_dir)
yolov5_labels_dir = os.path.join(data_base_dir, "labels/")
if not os.path.isdir(yolov5_labels_dir):
        os.mkdir(yolov5_labels_dir)
clear_hidden_files(yolov5_labels_dir)
yolov5_images_train_dir = os.path.join(yolov5_images_dir, "train/")
if not os.path.isdir(yolov5_images_train_dir):
        os.mkdir(yolov5_images_train_dir)
clear_hidden_files(yolov5_images_train_dir)
yolov5_images_test_dir = os.path.join(yolov5_images_dir, "val/")
if not os.path.isdir(yolov5_images_test_dir):
        os.mkdir(yolov5_images_test_dir)
clear_hidden_files(yolov5_images_test_dir)
yolov5_labels_train_dir = os.path.join(yolov5_labels_dir, "train/")
if not os.path.isdir(yolov5_labels_train_dir):
        os.mkdir(yolov5_labels_train_dir)
clear_hidden_files(yolov5_labels_train_dir)
yolov5_labels_test_dir = os.path.join(yolov5_labels_dir, "val/")
if not os.path.isdir(yolov5_labels_test_dir):
        os.mkdir(yolov5_labels_test_dir)
clear_hidden_files(yolov5_labels_test_dir)
 
train_file = open(os.path.join(wd, "yolov5_train.txt"), 'w')
test_file = open(os.path.join(wd, "yolov5_val.txt"), 'w')
train_file.close()
test_file.close()
train_file = open(os.path.join(wd, "yolov5_train.txt"), 'a')
test_file = open(os.path.join(wd, "yolov5_val.txt"), 'a')
list_imgs = os.listdir(image_dir)  # list image files
for img_name in list_imgs:
    image_path = os.path.join(image_dir, img_name)
    if not os.path.isfile(image_path):
        continue  # skip sub-directories
    name_without_ext = os.path.splitext(img_name)[0]
    annotation_path = os.path.join(annotation_dir, name_without_ext + '.xml')
    if not os.path.exists(annotation_path):
        continue  # images without a matching .xml are skipped
    label_name = name_without_ext + '.txt'
    label_path = os.path.join(yolo_labels_dir, label_name)
    convert_annotation(name_without_ext)  # writes VOC2007/YOLOLabels/<name>.txt
    prob = random.randint(1, 100)
    if prob <= TRAIN_RATIO:  # about TRAIN_RATIO % of the images go to train
        train_file.write(image_path + '\n')
        copyfile(image_path, yolov5_images_train_dir + img_name)
        copyfile(label_path, yolov5_labels_train_dir + label_name)
    else:  # the rest go to val
        test_file.write(image_path + '\n')
        copyfile(image_path, yolov5_images_test_dir + img_name)
        copyfile(label_path, yolov5_labels_test_dir + label_name)
train_file.close()
test_file.close()
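
The convert() function in prepare.py turns a VOC box given by its pixel corners into YOLO's (x_center, y_center, width, height), each divided by the image width or height. The standalone sketch below repeats that conversion on an in-memory XML string so it can be checked by hand: a box from (100, 50) to (300, 250) in a 640×480 image becomes 0.3125, 0.3125, 0.3125 and 0.416667. The function name and the six-decimal output format are choices made for this example; prepare.py itself writes full-precision floats.

```python
import xml.etree.ElementTree as ET

CLASSES = ["white", "blue", "red", "yellow", "orange", "none"]  # same order as prepare.py


def voc_to_yolo_lines(xml_text, classes=CLASSES):
    """Convert one VOC annotation (XML string) into YOLO label lines."""
    root = ET.fromstring(xml_text)
    w = float(root.find("size/width").text)
    h = float(root.find("size/height").text)
    lines = []
    for obj in root.iter("object"):
        name = obj.find("name").text
        difficult = obj.find("difficult")
        if name not in classes or (difficult is not None and int(difficult.text) == 1):
            continue
        box = obj.find("bndbox")
        xmin, ymin, xmax, ymax = (float(box.find(k).text) for k in ("xmin", "ymin", "xmax", "ymax"))
        # box centre and size, normalised by the image width / height
        xc, yc = (xmin + xmax) / 2 / w, (ymin + ymax) / 2 / h
        bw, bh = (xmax - xmin) / w, (ymax - ymin) / h
        lines.append("%d %.6f %.6f %.6f %.6f" % (classes.index(name), xc, yc, bw, bh))
    return lines


sample = """<annotation><size><width>640</width><height>480</height></size>
<object><name>blue</name><difficult>0</difficult>
<bndbox><xmin>100</xmin><ymin>50</ymin><xmax>300</xmax><ymax>250</ymax></bndbox></object>
</annotation>"""
print(voc_to_yolo_lines(sample))  # ['1 0.312500 0.312500 0.312500 0.416667']
```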


Running the script creates the folders:
yolov5-5.0/VOCdevkit/images: train and val subfolders holding the images
yolov5-5.0/VOCdevkit/labels: train and val subfolders holding the .txt labels

[Figure 21]
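Because prepare.py assigns each image to train or val by an independent random draw, the split is only approximately 80/20, and a rare class can end up almost entirely on one side. A small sketch for counting labelled objects per class in a labels folder (the function names are hypothetical; the class order is copied from prepare.py):

```python
import os
from collections import Counter

CLASSES = ["white", "blue", "red", "yellow", "orange", "none"]  # same order as prepare.py


def count_classes(label_lines):
    """Count objects per class id in YOLO label lines 'cls xc yc w h'."""
    counts = Counter()
    for line in label_lines:
        parts = line.split()
        if len(parts) == 5:
            counts[int(parts[0])] += 1
    return counts


def count_split(labels_dir):
    """Count classes over all .txt files in a folder such as VOCdevkit/labels/train."""
    lines = []
    for name in os.listdir(labels_dir):
        if name.endswith(".txt"):
            with open(os.path.join(labels_dir, name)) as f:
                lines.extend(f.read().splitlines())
    return count_classes(lines)


counts = count_classes(["1 0.31 0.31 0.31 0.42", "1 0.5 0.5 0.1 0.1", "5 0.2 0.2 0.1 0.1"])
print({CLASSES[k]: n for k, n in sorted(counts.items())})  # {'blue': 2, 'none': 1}
```

On the real data, compare count_split("VOCdevkit/labels/train") with count_split("VOCdevkit/labels/val").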

5. Writing your own configuration files

5.1 Write the dataset configuration file data/hat.yaml

train: /home/wyb/下载/yolov5-5.0/VOCdevkit/images/train/  # training images
val: /home/wyb/下载/yolov5-5.0/VOCdevkit/images/val/  # validation images

# Classes
nc: 6  # number of classes
names: ["white","blue","red","yellow","orange","none"]  # class names

[Figure 22]
5.2 Write the model configuration file models/yolov5s_hat.yaml
Only nc needs to be changed:
[Figure 23]
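Only nc has to change because the YOLOv5 Detect head derives its output size from it: for every anchor it predicts 4 box values, 1 objectness score and nc class scores, and each of the three output layers uses 3 anchors per grid cell. The arithmetic as a tiny sketch (the function name is made up for illustration):

```python
def head_sizes(nc, num_anchors=3):
    """Values per anchor and channels per YOLOv5 Detect output layer."""
    per_anchor = 4 + 1 + nc  # box (x, y, w, h) + objectness + nc class scores
    return per_anchor, num_anchors * per_anchor


print(head_sizes(6))   # (11, 33)  helmet model, nc = 6
print(head_sizes(80))  # (85, 255) default COCO model
```

For nc = 6 this gives 11 values per anchor, which matches m_detect_dim = 11 in yolo.h.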

6. Training your own model

Run yolov5-5.0/train.py; changing these arguments is enough to train your own model:
[Figure 24]

python train.py --weights yolov5s.pt --cfg models/yolov5s_hat.yaml --data data/hat.yaml --epochs 185 --device 0

It first prints some configuration information:
[Figure 25]
Then training starts:
[Figure 26]
With 6696 images and 185 epochs, training took a little over 2 hours. The trained model is runs/train/exp/weights/best.pt:
[Figure 27]
[Figure 28]

7. Detecting images with the trained model

Check and adjust the arguments of detect.py:
[Figure 29]
Run detect.py:

python detect.py --weights runs/train/exp/weights/best.pt --source VOCdevkit/images/train/00000.jpg

The detection results are saved under runs/detect:
[Figure 30]
Some detection results:
[Figure 31]
[Figure 32]

Deploying YOLOv5 with the TNN inference framework on Windows

1. Downloading the TNN-master project

TNN-master: https://github.com/Tencent/TNN
[Figure 33]

2. Converting .pt to a TNN model (done on Ubuntu 18.04)

2.1 Converting .pt to ONNX
Install ONNX:

pip3 install onnx

Copy yolov5-5.0/models/export.py into the yolov5-5.0 directory and check its arguments:
[Figure 34]
Run:

python export.py --weights runs/train/exp/weights/best.pt

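Output like the figure below means the conversion succeeded:
[Figure]
Visualize the ONNX model with netron:
[Figure 35]
You can change img-size in export.py to [448, 640] (height, width); with this input size the TNN deployment on Windows gives somewhat better results:
[Figure 36]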
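Each side of the export size has to be a multiple of the largest stride, 32, because the three output layers downsample the input by 8, 16 and 32. The sketch below computes the resulting grid sizes and the number of candidate boxes before filtering, which shows that [448, 640] produces fewer candidates than 640×640. The function names are hypothetical:

```python
STRIDES = (8, 16, 32)


def grid_shapes(height, width, strides=STRIDES):
    """Feature-map (rows, cols) of each YOLOv5 output layer."""
    if height % max(strides) or width % max(strides):
        raise ValueError("input sides must be multiples of %d" % max(strides))
    return [(height // s, width // s) for s in strides]


def num_candidates(height, width, num_anchors=3):
    """Total boxes predicted by the three layers before confidence filtering and NMS."""
    return num_anchors * sum(rows * cols for rows, cols in grid_shapes(height, width))


print(grid_shapes(448, 640))     # [(56, 80), (28, 40), (14, 20)]
print(num_candidates(448, 640))  # 17640
print(num_candidates(640, 640))  # 25200
```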

3. Converting ONNX to a TNN model

Following the official documentation, use the Convert2tnn Docker image.
[Figure 37]

Install Docker:

sudo apt install docker.io

Pull the image and rename it:

sudo docker pull ccr.ccs.tencentyun.com/qcloud/tnn-convert
sudo docker tag ccr.ccs.tencentyun.com/qcloud/tnn-convert tnn-convert:latest
sudo docker rmi ccr.ccs.tencentyun.com/qcloud/tnn-convert
sudo docker images

[Figure 38]

Mount the folder containing yolov5-5.0 into the Docker container at /workspace:

cd /home/wyb/下载
sudo docker run --volume=$(pwd):/workspace -it tnn-convert:latest

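Check the mounted files:
[Figure 39]
If there is no build folder, run ./build.sh first.
I recommend using your own downloaded copy of TNN-master (/home/wyb/下载/TNN-master), which is mounted at /workspace/TNN-master.
After cd /workspace/TNN-master/tools/convert2tnn and a successful build:
[Figure 40]
Convert the ONNX model to a TNN model: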

python3  converter.py  onnx2tnn /workspace/yolov5-5.0/runs/train/exp/weights/best640/best.onnx

This generates best.tnnproto and best.tnnmodel in the same directory as best.onnx; they can be inspected with TextPad and netron respectively.
[Figure 41]

4. Building TNN on Windows (CPU)

4.1 Requirements

Dependencies

  • Visual Studio (2017 or later)
  • cmake (version 3.11 or later added to PATH, or the cmake bundled with Visual Studio)
  • ninja (faster builds; can be installed with choco)

4.2 Build steps

Open x64 Native Tools Command Prompt for VS 2017/2019; to build 32-bit libraries, open x86 Native Tools Command Prompt for VS 2017/2019 instead.
Change to the scripts directory:

cd <path_to_tnn>/scripts

Run the build script.
To build without OpenVINO:

.\build_msvc_native.bat

To build with OpenVINO:

.\build_msvc.bat

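OpenVINO can only be built as a 64-bit library.

Set the OpenCV path before running. In cmd, do not put spaces around = or quotes around the value, otherwise they become part of the variable name or value:

set OpenCV_DIR=F:/opencv/build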

After a successful build:
[Figure 42]
The folder scripts/build_win is generated; it contains TNN.lib, TNN.dll and other files.

5. Running the inference program in VS2019

5.1 Adding paths and libraries

Add the include directories:

<TNN_DIR>\include
<TNN_DIR>
<OpenCV_DIR>\include
<OpenCV_DIR>\include\opencv
<OpenCV_DIR>\include\opencv2

Add the library directories:

<TNN_DIR>\scripts\build_win
<OpenCV_DIR>\x64\vc14\lib

<TNN_DIR> is the path of the TNN-master folder and <OpenCV_DIR> is the path of the opencv/build folder.
[Figure 43]

Add the libraries:

opencv_world3416.lib
TNN.lib

[Figure 44]

5.2 Code

yolo.h

#pragma once

#include <map>
#include <memory>
#include <string>
#include <vector>

#include "tnn/core/tnn.h"
#include "tnn/utils/blob_converter.h"
#include "tnn/utils/mat_utils.h"
#include "tnn/utils/dims_vector_utils.h"

namespace TNN_NS
{
    struct ObjectInfo {
        int nImage_width = 0;
        int nImage_height = 0;

        float x1 = 0;
        float y1 = 0;
        float x2 = 0;
        float y2 = 0;
        float score = 0.;
        int nClass_id = -1;
    };

    struct ImageInfo {
        ImageInfo();
        ImageInfo(std::shared_ptr<Mat>mat);
        ImageInfo(const ImageInfo &info);
        int nImage_width = 0;
        int nImage_height = 0;
        int nImage_channel = 0;

        std::shared_ptr<char> data; // 4-channel image data
    };

    struct RGBA
    {
        RGBA(int nR = 0, int nG = 0, int nB = 0, int nA = 0) : m_r(nR), m_g(nG), m_b(nB),m_a(nA) {}
        unsigned char m_r,m_g, m_b, m_a;
    };

    extern const std::string kTNNSDKDefaultName;
    class ObjectDetectorYoloInput
    {
    public:
        ObjectDetectorYoloInput(std::shared_ptr<TNN_NS::Mat> mat = nullptr);
        virtual ~ObjectDetectorYoloInput();

        bool IsEmpty();
        std::shared_ptr<TNN_NS::Mat> GetMat(std::string strName = kTNNSDKDefaultName);
        bool AddMat(std::shared_ptr<TNN_NS::Mat> mat, std::string strName);

    protected:
        std::map<std::string, std::shared_ptr<TNN_NS::Mat> > m_mat_map= {};
    };

    class ObjectDetectorYoloOutput : public ObjectDetectorYoloInput
    {
    public:
        ObjectDetectorYoloOutput(std::shared_ptr<Mat> mat = nullptr) : ObjectDetectorYoloInput(mat) {};
        virtual ~ObjectDetectorYoloOutput();

        std::vector<ObjectInfo> m_vctObject_list;
    };

    class  ObjectDetectorYoloOption
    {
    public:
        ObjectDetectorYoloOption();
        virtual ~ObjectDetectorYoloOption();

        std::string m_strProto_content = "";
        std::string m_strModel_content = "";
        InputShapesMap m_input_shapes = {};
    };

    class ObjectDetectorYolo
    {
    public:
        ObjectDetectorYolo();
        virtual ~ObjectDetectorYolo();

        virtual DimsVector GetInputShape(std::string strName = kTNNSDKDefaultName);
        virtual MatType GetOutputMatType(std::string strName = "");
        virtual Status Init(std::shared_ptr< ObjectDetectorYoloOption> option);
        virtual Status Predict(std::shared_ptr<ObjectDetectorYoloInput> input, std::shared_ptr<ObjectDetectorYoloOutput> &output);
        virtual MatConvertParam GetConvertParamForOutput(std::string strName = "");
        virtual Status GetCommandQueue(void **command_queue);
        Status Resize(std::shared_ptr<TNN_NS::Mat> src, std::shared_ptr<TNN_NS::Mat> dst);
        virtual MatConvertParam GetConvertParamForInput(std::string strName = "");
        virtual std::shared_ptr<ObjectDetectorYoloOutput> CreateSDKOutput();
        virtual Status ProcessSDKOutput(std::shared_ptr<ObjectDetectorYoloOutput> output);
        std::shared_ptr<Mat> ResizeToInputShape(std::shared_ptr<Mat> input_mat, std::string strName);

    private:
        std::vector<std::string> GetOutputNames();
        void GenerateDetectResult(std::vector<std::shared_ptr<Mat> >outputs, std::vector<ObjectInfo> &detects,
            int image_width, int image_height);
        void NMS(std::vector<ObjectInfo> &objs, std::vector<ObjectInfo> &results);
        void PostProcessMat(std::vector<std::shared_ptr<Mat> >outputs, std::vector<std::shared_ptr<Mat> > &post_mats);

    private:
        std::shared_ptr<TNN> m_net = nullptr;
        std::shared_ptr<Instance> m_instance= nullptr;
        std::shared_ptr< ObjectDetectorYoloOption> m_option = nullptr;

        float conf_thres = 0.1;
        float iou_thres = 0.25;
        // yolov5s model configurations
        std::vector<float> m_strides= { 32.f, 16.f, 8.f };
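        // default yolov5s anchors (w, h), grouped in the same order as m_strides
        std::vector<float> m_anchor_grids = { 116.f, 90.f,   156.f, 198.f,   373.f, 326.f,   // stride 32
                                              30.f, 61.f,    62.f, 45.f,     59.f, 119.f,    // stride 16
                                              10.f, 13.f,    16.f, 30.f,     33.f, 23.f };   // stride 8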
        float m_iou_threshold= 0.25f;
        int m_num_anchor = 3;
        int m_detect_dim = 11;
        int m_grid_per_input= 6;
    };
}  // namespace TNN_NS
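
yolo.h fixes the constants the decoder in yolo.cpp relies on: three output layers with strides 32, 16 and 8, three anchors per layer and 11 values per anchor. As a cross-check for the C++ post-processing, here is a rough Python equivalent of decoding one anchor (sigmoid, then x = (2σ - 0.5 + col) · stride and w = (2σ)² · anchor_w) and of the greedy NMS. Assumptions: the anchors are the default yolov5s ones, the score is objectness times the best class probability (the usual YOLOv5 convention; that part of yolo.cpp is not shown here), and boxes use continuous coordinates instead of the +1 pixel convention in ObjectDetectorYolo::NMS. The function names are illustrative.

```python
import math

# Default yolov5s anchors (w, h) in pixels, in the same order as m_strides = {32, 16, 8}.
STRIDES = [32.0, 16.0, 8.0]
ANCHORS = [[(116, 90), (156, 198), (373, 326)],  # stride 32
           [(30, 61), (62, 45), (59, 119)],      # stride 16
           [(10, 13), (16, 30), (33, 23)]]       # stride 8


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def decode_cell(raw, row, col, level, anchor):
    """Decode raw outputs [tx, ty, tw, th, obj, cls_0..] of one anchor in one grid cell.

    Returns (x1, y1, x2, y2, score, class_id) in network-input pixels.
    """
    p = [sigmoid(v) for v in raw]
    stride = STRIDES[level]
    anchor_w, anchor_h = ANCHORS[level][anchor]
    cx = (p[0] * 2 - 0.5 + col) * stride
    cy = (p[1] * 2 - 0.5 + row) * stride
    w = (p[2] * 2) ** 2 * anchor_w
    h = (p[3] * 2) ** 2 * anchor_h
    cls_scores = p[5:]
    class_id = max(range(len(cls_scores)), key=lambda k: cls_scores[k])
    score = p[4] * cls_scores[class_id]  # assumed: objectness x class probability
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2, score, class_id)


def iou(a, b):
    """Intersection over union of two (x1, y1, x2, y2, ...) boxes."""
    inter_w = min(a[2], b[2]) - max(a[0], b[0])
    inter_h = min(a[3], b[3]) - max(a[1], b[1])
    if inter_w <= 0 or inter_h <= 0:
        return 0.0
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


def nms(boxes, iou_thres=0.25):
    """Greedy class-agnostic NMS: keep a box unless it overlaps a higher-scoring kept box."""
    keep = []
    for box in sorted(boxes, key=lambda b: b[4], reverse=True):
        if all(iou(box, kept) <= iou_thres for kept in keep):
            keep.append(box)
    return keep


print(decode_cell([0.0] * 11, row=0, col=0, level=0, anchor=0))
# (-42.0, -29.0, 74.0, 61.0, 0.25, 0)
```

Like the C++ loop, nms() drops a box only when it overlaps an already kept, higher-scoring box by more than the threshold, so both produce the same set of boxes.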

yolo.cpp

#include "yolo.h"

#include <algorithm>  // std::sort
#include <cmath>      // exp, pow
#include <cstring>    // memcpy

namespace TNN_NS {
    const std::string kTNNSDKDefaultName = "TNN.sdk.default.name";

    ImageInfo::ImageInfo() : nImage_width(0), nImage_height(0), nImage_channel(0), data(nullptr) {}

    ImageInfo::ImageInfo(const ImageInfo &info) 
    {
        nImage_width = info.nImage_width;
        nImage_height = info.nImage_height;
        nImage_channel = info.nImage_channel;
        data = info.data;
    }

    ImageInfo::ImageInfo(std::shared_ptr<Mat>image)
    {
        if (image != nullptr) {
            const auto &dims = image->GetDims();
            nImage_channel = dims[1];
            nImage_height = dims[2];
            nImage_width = dims[3];
            auto count = DimsVectorUtils::Count(dims);
            data.reset(new char[count]);
            memcpy(data.get(), image->GetData(), count);
        }
    }

#pragma mark - ObjectDetectorYoloInput
    ObjectDetectorYoloInput::ObjectDetectorYoloInput(std::shared_ptr<TNN_NS::Mat> mat)
    {
        if (mat)
        {
            m_mat_map[kTNNSDKDefaultName] = mat;
        }
    }

    ObjectDetectorYoloInput::~ObjectDetectorYoloInput() {}

    bool ObjectDetectorYoloInput::IsEmpty()
    {
        if (m_mat_map.size() <= 0)
        {
            return true;
        }
        return false;
    }

    bool ObjectDetectorYoloInput::AddMat(std::shared_ptr<TNN_NS::Mat> mat, std::string name) 
    {
        if (name.empty() || !mat)
        {
            return false;
        }

        m_mat_map[name] = mat;
        return true;
    }

    std::shared_ptr<TNN_NS::Mat> ObjectDetectorYoloInput::GetMat(std::string name)
    {
        std::shared_ptr<TNN_NS::Mat> mat = nullptr;
        if (name == kTNNSDKDefaultName && m_mat_map.size() > 0)
        {
            return m_mat_map.begin()->second;
        }

        if (m_mat_map.find(name) != m_mat_map.end()) {
            mat = m_mat_map[name];
        }
        return mat;
    }

#pragma mark -  ObjectDetectorYoloOption
    ObjectDetectorYoloOption::ObjectDetectorYoloOption() {}

    ObjectDetectorYoloOption::~ObjectDetectorYoloOption() {}

#pragma mark - ObjectDetectorYolo
    ObjectDetectorYolo::ObjectDetectorYolo() {}

    ObjectDetectorYolo::~ObjectDetectorYolo() {}

    Status ObjectDetectorYolo::GetCommandQueue(void **command_queue) 
    {
        if (m_instance)
        {
            return m_instance->GetCommandQueue(command_queue);
        }
        return Status(TNNERR_INST_ERR, "instance_ GetCommandQueue return nil");
    }

    Status ObjectDetectorYolo::Resize(std::shared_ptr<TNN_NS::Mat> src, std::shared_ptr<TNN_NS::Mat> dst/*, TNNInterpType interp_type*/)
    {
        Status status = TNN_OK;
        void *command_queue = nullptr;
        status = GetCommandQueue(&command_queue);
        if (status != TNN_NS::TNN_OK)
        {
            LOGE("getCommandQueue failed with:%s\n", status.description().c_str());
            return status;
        }

        ResizeParam param;
        param.type = TNN_NS::INTERP_TYPE_LINEAR;

        auto dst_dims = dst->GetDims();
        auto src_dims = src->GetDims();
        param.scale_w = dst_dims[3] / static_cast<float>(src_dims[3]);
        param.scale_h = dst_dims[2] / static_cast<float>(src_dims[2]);

        status = MatUtils::Resize(*(src.get()), *(dst.get()), param, command_queue);
        if (status != TNN_NS::TNN_OK) {
            LOGE("resize failed with:%s\n", status.description().c_str());
        }

        return status;
    }

    TNN_NS::Status ObjectDetectorYolo::Init(std::shared_ptr< ObjectDetectorYoloOption> option)
    {
        m_option = option;
        TNN_NS::Status status;

        if (!m_net)
        {
            TNN_NS::ModelConfig config;
            config.model_type = TNN_NS::MODEL_TYPE_TNN;
            config.params = { option->m_strProto_content, option->m_strModel_content};

            auto net = std::make_shared<TNN_NS::TNN>();
            status = net->Init(config);
            if (status != TNN_NS::TNN_OK) {
                LOGE("instance.net init failed %d", (int)status);
                return status;
            }
            m_net = net;
        }
        //创建实例instance
        {
            TNN_NS::NetworkConfig network_config;
            network_config.device_type = TNN_NS::DEVICE_X86;
            network_config.cache_path = "/sdcard/";
            std::shared_ptr<TNN_NS::Instance> instance;
            instance = m_net->CreateInst(network_config, status, option->m_input_shapes);
            m_instance = instance;
        }
        return status;
    }

    DimsVector ObjectDetectorYolo::GetInputShape(std::string name)
    {
        DimsVector shape = {};
        BlobMap blob_map = {};
        if (m_instance)
        {
            m_instance->GetAllInputBlobs(blob_map);
        }

        if (kTNNSDKDefaultName == name && blob_map.size() > 0)
        {
            if (blob_map.begin()->second)
            {
                shape = blob_map.begin()->second->GetBlobDesc().dims;
            }
        }

        if (blob_map.find(name) != blob_map.end() && blob_map[name])
        {
            shape = blob_map[name]->GetBlobDesc().dims;
        }
        return shape;
    }

    MatType ObjectDetectorYolo::GetOutputMatType(std::string name)
    {
        // All outputs of this model are read back as NCHW float data.
        return NCHW_FLOAT;
    }

    std::vector<std::string> ObjectDetectorYolo::GetOutputNames()
    {
        std::vector<std::string> names;
        if (m_instance)
        {
            BlobMap blob_map;
            m_instance->GetAllOutputBlobs(blob_map);
            for (const auto &item : blob_map)
            {
                names.push_back(item.first);
            }
        }
        return names;
    }

    std::shared_ptr<Mat> ObjectDetectorYolo::ResizeToInputShape(std::shared_ptr<Mat> input_mat, std::string name)
    {
        auto target_dims =GetInputShape(name);
        auto input_height = input_mat->GetHeight();
        auto input_width = input_mat->GetWidth();
        if (target_dims.size() >= 4 &&
            (input_height != target_dims[2] || input_width != target_dims[3]))
        {
            auto target_mat = std::make_shared<TNN_NS::Mat>(input_mat->GetDeviceType(),
                input_mat->GetMatType(), target_dims);
            auto status = Resize(input_mat, target_mat);
            if (status == TNN_OK)
            {
                return target_mat;
            }
            else
            {
                LOGE("%s\n", status.description().c_str());
                return nullptr;
            }
        }
        return input_mat;
    }

    TNN_NS::MatConvertParam ObjectDetectorYolo::GetConvertParamForOutput(std::string name)
    {
        return TNN_NS::MatConvertParam();
    }

    TNN_NS::Status ObjectDetectorYolo::Predict(std::shared_ptr<ObjectDetectorYoloInput> input, std::shared_ptr<ObjectDetectorYoloOutput> &output)
    {
        Status status = TNN_OK;
        if (!input || input->IsEmpty())
        {
            status = Status(TNNERR_PARAM_ERR, "input image is empty ,please check!");
            LOGE("input image is empty ,please check!\n");
            return status;
        }
        // step 1. set input mat
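        auto input_mat = input->GetMat();
        // kTNNSDKDefaultName selects the first input blob; an empty name matches no blob,
        // so GetInputShape would return an empty shape and the resize would be skipped
        input_mat = ResizeToInputShape(input_mat, kTNNSDKDefaultName);
        RETURN_VALUE_ON_NEQ(!input_mat, false, Status(TNNERR_PARAM_ERR, "resize to input shape failed"));
        auto input_convert_param = GetConvertParamForInput();
        status = m_instance->SetInputMat(input_mat, input_convert_param);
        RETURN_ON_NEQ(status, TNN_NS::TNN_OK);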
        // step 2. Forward
        status = m_instance->ForwardAsync(nullptr);
        if (status != TNN_NS::TNN_OK)
        {
            LOGE("instance.Forward Error: %s\n", status.description().c_str());
            return status;
        }
        // step 3. get output mat
        auto input_device_type = input->GetMat()->GetDeviceType();
        output = CreateSDKOutput();
        auto output_names = GetOutputNames();

        if (output_names.size() == 1)
        {
            auto output_convert_param = GetConvertParamForOutput();
            std::shared_ptr<TNN_NS::Mat> output_mat = nullptr;
            status = m_instance->GetOutputMat(output_mat, output_convert_param, "",DEVICE_X86, GetOutputMatType());
            RETURN_ON_NEQ(status, TNN_NS::TNN_OK);
            output->AddMat(output_mat, output_names[0]);
        }
        else
        {
            for (auto name : output_names)
            {
                auto output_convert_param = GetConvertParamForOutput(name);
                std::shared_ptr<TNN_NS::Mat> output_mat = nullptr;
                status = m_instance->GetOutputMat(output_mat, output_convert_param, name, DEVICE_X86, GetOutputMatType(/*name*/));
                RETURN_ON_NEQ(status, TNN_NS::TNN_OK);
                output->AddMat(output_mat, name);
            }
        }

        return ProcessSDKOutput(output);
    }

    ObjectDetectorYoloOutput::~ObjectDetectorYoloOutput() {}
    std::shared_ptr<ObjectDetectorYoloOutput> ObjectDetectorYolo::CreateSDKOutput()
    {
        return std::make_shared<ObjectDetectorYoloOutput>();
    }

    MatConvertParam ObjectDetectorYolo::GetConvertParamForInput(std::string name)
    {
        MatConvertParam input_convert_param;
        input_convert_param.scale = { 1.0 / 255, 1.0 / 255, 1.0 / 255, 0 };
        input_convert_param.bias = { 0.0, 0.0, 0.0, 0.0 };
        return input_convert_param;
    }

    Status ObjectDetectorYolo::ProcessSDKOutput(std::shared_ptr<ObjectDetectorYoloOutput> output_)
    {
        Status status = TNN_OK;

        auto output = dynamic_cast<ObjectDetectorYoloOutput *>(output_.get());
        RETURN_VALUE_ON_NEQ(!output, false, Status(TNNERR_PARAM_ERR, "TNNSDKOutput is invalid"));

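        // Output blob names depend on the exported ONNX graph (check them in netron).
        // Their order must match m_strides = {32, 16, 8}: "419" -> 32, "405" -> 16, "output" -> 8.
        auto output_mat_0 = output->GetMat("419");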
        RETURN_VALUE_ON_NEQ(!output_mat_0, false, Status(TNNERR_PARAM_ERR, "GetMat is invalid"));
        auto output_mat_1 = output->GetMat("405");
        RETURN_VALUE_ON_NEQ(!output_mat_1, false, Status(TNNERR_PARAM_ERR, "GetMat is invalid"));
        auto output_mat_2 = output->GetMat("output");
        RETURN_VALUE_ON_NEQ(!output_mat_2, false, Status(TNNERR_PARAM_ERR, "GetMat is invalid"));

        auto input_shape = GetInputShape();
        RETURN_VALUE_ON_NEQ(input_shape.size() == 4, true,
            Status(TNNERR_PARAM_ERR, "GetInputShape is invalid"));

        std::vector<ObjectInfo> object_list;
        GenerateDetectResult({ output_mat_0, output_mat_1, output_mat_2 }, object_list, input_shape[3], input_shape[2]);
        output->m_vctObject_list = object_list;
        return status;
    }

    void ObjectDetectorYolo::NMS(std::vector<ObjectInfo> &input, std::vector<ObjectInfo> &output)
    {
        std::sort(input.begin(), input.end(), [](const ObjectInfo &a, const ObjectInfo &b) { return a.score > b.score; });
        output.clear();

        int box_num = input.size();
        std::vector<int> merged(box_num, 0);
        for (int i = 0; i < box_num; i++)
        {
            if (merged[i]) 
            {
                continue;
            }
            std::vector<ObjectInfo> buf;
            buf.push_back(input[i]);
            merged[i] = 1;

            float area0 = (input[i].y2 - input[i].y1 + 1) * (input[i].x2 - input[i].x1 + 1);

            for (int j = i + 1; j < box_num; j++) 
            {
                if (merged[j])
                {
                    continue;
                }
                float inner_x0 = input[i].x1 > input[j].x1 ? input[i].x1 : input[j].x1;
                float inner_y0 = input[i].y1 > input[j].y1 ? input[i].y1 : input[j].y1;
                float inner_x1 = input[i].x2 < input[j].x2 ? input[i].x2 : input[j].x2;
                float inner_y1 = input[i].y2 < input[j].y2 ? input[i].y2 : input[j].y2;
                float inner_h = inner_y1 - inner_y0 + 1;
                float inner_w = inner_x1 - inner_x0 + 1;
                if (inner_h <= 0 || inner_w <= 0)
                {
                    continue;
                }
                float inner_area = inner_h * inner_w;

                float area1 =(input[j].y2 - input[j].y1 + 1) * (input[j].x2 - input[j].x1 + 1);

                float score = inner_area / (area0 + area1 - inner_area);
                if (score > m_iou_threshold) 
                {
                    merged[j] = 1;
                    buf.push_back(input[j]);
                }
            }

            output.push_back(buf[0]);
        }
    }
        
    void ObjectDetectorYolo::PostProcessMat(std::vector<std::shared_ptr<Mat> >outputs, std::vector<std::shared_ptr<Mat> > &post_mats)
    {
        for (auto &output : outputs) {
            auto dims = output->GetDims();
            auto h_stride = DimsVectorUtils::Count(dims, 2);
            auto w_stride = DimsVectorUtils::Count(dims, 3);
            DimsVector permute_dims = { dims[0], dims[2], dims[3], dims[1] * dims[4] }; // batch, height, width, anchor*detect_dim
            auto mat = std::make_shared<Mat>(output->GetDeviceType(), output->GetMatType(), permute_dims);
            float *src_data = reinterpret_cast<float *>(output->GetData());
            float *dst_data = reinterpret_cast<float *>(mat->GetData());
            int out_idx = 0;
            for (int h = 0; h < permute_dims[1]; h++) {
                for (int w = 0; w < permute_dims[2]; w++) {
                    for (int s = 0; s < permute_dims[3]; s++) {
                        int anchor_idx = s / dims[4];
                        int detect_idx = s % dims[4];
                        int in_idx = anchor_idx * h_stride + h * w_stride + w * dims[4] + detect_idx;
                        dst_data[out_idx++] = 1.0f / (1.0f + exp(-src_data[in_idx]));
                    }
                }
            }
            post_mats.emplace_back(mat);
        }
    }

    void ObjectDetectorYolo::GenerateDetectResult(std::vector<std::shared_ptr<Mat> >outputs,
        std::vector<ObjectInfo> &detecs, int image_width, int image_height)
    {
        std::vector<ObjectInfo> extracted_objs;
        int blob_index = 0;

        std::vector<std::shared_ptr<Mat>> post_mats;
        PostProcessMat(outputs, post_mats);
        auto output_mats = post_mats;

        for (auto &output : output_mats)
        {
            auto dim = output->GetDims();
            if (dim[3] != m_num_anchor * m_detect_dim)
            {
                LOGE("Invalid detect output, the size of last dimension is: %d\n", dim[3]);
                return;
            }
            float *data = static_cast<float *>(output->GetData());

            int num_potential_detecs = dim[1] * dim[2] * m_num_anchor;
            for (int i = 0; i < num_potential_detecs; ++i)
            {
                float x = data[i * m_detect_dim + 0];
                float y = data[i * m_detect_dim + 1];
                float width = data[i * m_detect_dim + 2];
                float height = data[i * m_detect_dim + 3];
                float objectness = data[i * m_detect_dim + 4];

                if (objectness < conf_thres)
                {
                    continue;
                }
                //center point coord
                x = (x * 2 - 0.5 + ((i / m_num_anchor) % dim[2])) * m_strides[blob_index];
                y = (y * 2 - 0.5 + ((i / m_num_anchor) / dim[2]) % dim[1]) * m_strides[blob_index];
                width = pow((width * 2), 2) * m_anchor_grids[blob_index * m_grid_per_input + (i % m_num_anchor) * 2 + 0];
                height = pow((height * 2), 2) * m_anchor_grids[blob_index * m_grid_per_input + (i % m_num_anchor) * 2 + 1];
                // compute coords
                float x1 = x - width / 2;
                float y1 = y - height / 2;
                float x2 = x + width / 2;
                float y2 = y + height / 2;
                // compute confidence
                auto conf_start = data + i * m_detect_dim + 5;
                auto conf_end = data + (i + 1) * m_detect_dim;
                auto max_conf_iter = std::max_element(conf_start, conf_end);
                int conf_idx = static_cast<int>(std::distance(conf_start, max_conf_iter));
                float score = (*max_conf_iter) * objectness;

                ObjectInfo obj_info;
                obj_info.nImage_width = image_width;
                obj_info.nImage_height = image_height;
                obj_info.x1 = x1;
                obj_info.y1 = y1;
                obj_info.x2 = x2;
                obj_info.y2 = y2;
                obj_info.score = score;
                obj_info.nClass_id = conf_idx;

                extracted_objs.push_back(obj_info);
            }
            blob_index += 1;
        }
        NMS(extracted_objs, detecs);
    }
}//TNN_NS
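The box arithmetic in `GenerateDetectResult` can be isolated into a small function for checking. The sketch below makes two assumptions, both matching the code above. First, every raw output has already been passed through the sigmoid, as `PostProcessMat` does. Second, anchor sizes are in input-image pixels, like the entries of `m_anchor_grids`. `Box` and `decodeYoloV5` are hypothetical names and are not part of TNN. The final score is then the largest class probability multiplied by the objectness, as in the detection loop.

```cpp
#include <cassert>

// Hypothetical box type: corners in network-input pixels.
struct Box {
    float x1, y1, x2, y2;
};

// Decode one YOLOv5 prediction whose values were already passed through the sigmoid.
// sx, sy, sw, sh : sigmoid outputs for center x/y and width/height
// gx, gy         : column and row of the grid cell that produced the prediction
// stride         : downsampling factor of this output layer (e.g. 8, 16 or 32)
// anchor_w/h     : anchor size in input-image pixels
Box decodeYoloV5(float sx, float sy, float sw, float sh,
                 int gx, int gy, float stride,
                 float anchor_w, float anchor_h) {
    assert(stride > 0.0f);
    // Center: the offset inside the cell ranges over (-0.5, 1.5), then is scaled to pixels.
    const float cx = (sx * 2.0f - 0.5f + static_cast<float>(gx)) * stride;
    const float cy = (sy * 2.0f - 0.5f + static_cast<float>(gy)) * stride;
    // Size: (2 * sigmoid)^2 lets a box grow to at most 4x its anchor.
    const float w = (sw * 2.0f) * (sw * 2.0f) * anchor_w;
    const float h = (sh * 2.0f) * (sh * 2.0f) * anchor_h;
    return {cx - w / 2.0f, cy - h / 2.0f, cx + w / 2.0f, cy + h / 2.0f};
}
```

With all four sigmoid outputs at 0.5, the center falls in the middle of its cell (`(gx + 0.5) * stride`) and the box has exactly the anchor size.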

RST_HelmetDetection.h

#pragma once

#include <vector>

#include "yolo.h"  // declares TNN_NS::ObjectInfo

using namespace std;

typedef void *HELMETD_HANDLE;

#ifdef __cplusplus
extern "C" {
#endif

	HELMETD_HANDLE HelmetD_Open(const char *pszModelDir = NULL, const char *pszProtoDir = NULL);

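	// pImage points to 8-bit image data with 3 interleaved channels (RGB888). If bRGBOrder is true the channel order is RGB,
	// otherwise BGR. Each entry of vctObject_list is one detected helmet: box corners (x1, y1, x2, y2) in model-input
	// coordinates, score, and nClass_id (the helmet color class). Returns 0 on success and -1 for an invalid handle.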
	int HelmetD_Detect(HELMETD_HANDLE hHelmetD, int nWidth, int nHeight, int nChannel, unsigned char *pImage,
		bool bRGBOrder, std::vector<TNN_NS::ObjectInfo> &vctObject_list);

	void HelmetD_Close(HELMETD_HANDLE hHelmetD);

#ifdef __cplusplus
}
#endif

RST_HelmetDetection.cpp

#define _CRT_SECURE_NO_WARNINGS

#include <iostream>
#include <memory>
#include <vector>
#include <opencv2/opencv.hpp>

#include "yolo.h"
#include "fileLoad.h"
#include "RST_HelmetDetection.h"

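HELMETD_HANDLE HelmetD_Open(const char *pszModelDir, const char *pszProtoDir)
{
    // Both paths are required; constructing a std::string from NULL is undefined behavior
    if (pszModelDir == NULL || pszProtoDir == NULL)
    {
        return NULL;
    }
    auto model_content = fdLoadFile(pszModelDir);
    auto proto_content = fdLoadFile(pszProtoDir);
    if (model_content.empty() || proto_content.empty())
    {
        std::cout << "Failed to read the .tnnmodel or .tnnproto file" << std::endl;
        return NULL;
    }
    auto option = std::make_shared<TNN_NS::ObjectDetectorYoloOption>();
    option->m_strProto_content = proto_content;
    option->m_strModel_content = model_content;

    auto predictor = new TNN_NS::ObjectDetectorYolo();
    auto status = predictor->Init(option);
    if (status != TNN_NS::TNN_OK)
    {
        std::cout << "Predictor initialization failed, please check the option parameters" << std::endl;
        delete predictor;  // do not leak the detector when initialization fails
        return NULL;
    }

    return HELMETD_HANDLE(predictor);
}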


int HelmetD_Detect(HELMETD_HANDLE hHelmetD, int nWidth, int nHeight, int nChannel, unsigned char *pImage,
    bool bRGBOrder, std::vector<TNN_NS::ObjectInfo> &vctObject_list)
{
    if (hHelmetD == NULL || pImage == NULL)
    {
        return -1;
    }
    // Use the detector created by HelmetD_Open directly instead of copy-constructing a new one on every call
    auto predictor = static_cast<TNN_NS::ObjectDetectorYolo *>(hHelmetD);
    std::shared_ptr<TNN_NS::ObjectDetectorYoloOutput> YoloOutput = nullptr;

    const int image_orig_height = nHeight;
    const int image_orig_width = nWidth;
    // pImage is read as 3 interleaved channels, so the N8UC3 Mat is declared with 3 channels
    // (nChannel is not needed for the shape)
    TNN_NS::DimsVector nchw = { 1, 3, nHeight, nWidth };
    // Copy into an RGB buffer, swapping R and B for BGR input; std::vector frees it when the function returns
    std::vector<uint8_t> dataTmp(static_cast<size_t>(image_orig_width) * image_orig_height * 3);
    for (int i = 0; i < image_orig_height * image_orig_width; i++) {
        if (bRGBOrder == false)
        {
            dataTmp[i * 3 + 2] = pImage[i * 3];
            dataTmp[i * 3 + 1] = pImage[i * 3 + 1];
            dataTmp[i * 3] = pImage[i * 3 + 2];
        }
        else
        {
            dataTmp[i * 3] = pImage[i * 3];
            dataTmp[i * 3 + 1] = pImage[i * 3 + 1];
            dataTmp[i * 3 + 2] = pImage[i * 3 + 2];
        }
    }

    auto image_mat = std::make_shared<TNN_NS::Mat>(TNN_NS::DEVICE_X86, TNN_NS::N8UC3, nchw, dataTmp.data());
    // Resize to the shape of the network input named "images"
    auto resize_mat = predictor->ResizeToInputShape(image_mat, "images");
    predictor->Predict(std::make_shared<TNN_NS::ObjectDetectorYoloInput>(resize_mat), YoloOutput);
    predictor->ProcessSDKOutput(YoloOutput);

    if (YoloOutput)
    {
        vctObject_list = YoloOutput->m_vctObject_list;
    }

    return 0;
}

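void HelmetD_Close(HELMETD_HANDLE hHelmetD)
{
    // Cast back to the concrete type so the destructor runs; deleting a void * is undefined behavior
    delete static_cast<TNN_NS::ObjectDetectorYolo *>(hHelmetD);
}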
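`HelmetD_Open` and `HelmetD_Close` hide a C++ object behind a `void *` handle so the API can be exported with C linkage. The rule that makes this safe is that the handle must be cast back to its real type before `delete`: deleting through a `void *` does not call the destructor and is undefined behavior. The self-contained sketch below illustrates the pattern. `Detector`, `Det_Open`, `Det_Run` and `Det_Close` are hypothetical stand-ins for `ObjectDetectorYolo` and the `HelmetD_*` functions, not the author's code.

```cpp
#include <cassert>
#include <new>

// Hypothetical stand-in for TNN_NS::ObjectDetectorYolo.
class Detector {
public:
    static int live_count;  // number of Detector objects currently alive
    Detector() { ++live_count; }
    ~Detector() {
        assert(live_count > 0);
        --live_count;
    }
    int run(int x) const { return x * 2; }  // placeholder for real inference
};
int Detector::live_count = 0;

typedef void *DET_HANDLE;

extern "C" DET_HANDLE Det_Open() {
    // nothrow new returns nullptr instead of throwing across the C boundary
    return static_cast<DET_HANDLE>(new (std::nothrow) Detector());
}

extern "C" int Det_Run(DET_HANDLE h, int x, int *out) {
    if (h == nullptr || out == nullptr) return -1;
    // Use the shared instance directly; no per-call copy of the detector
    *out = static_cast<Detector *>(h)->run(x);
    return 0;
}

extern "C" void Det_Close(DET_HANDLE h) {
    // Cast back to the concrete type so the destructor runs (delete on nullptr is a no-op)
    delete static_cast<Detector *>(h);
}
```

Because `Det_Close` casts before deleting, `Detector::live_count` returns to zero after the handle is closed, which would not be guaranteed with `delete h;`.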

fileLoad.h

#pragma once

#include "yolo.h"

std::string fdLoadFile(std::string strPath);

void getAllFiles(std::string strPath, std::vector<std::string> &vecFilePaths, std::vector<std::string> &vecFileNames, std::string strFormat);

void retangle(int nxmin, int nxmax, int nymin, int nymax, int nimage_height, int nimage_width, TNN_NS::RGBA *image_rgba, int nr, int ng, int nb, int na);

void Rectangle(const std::string strLabel_list[], TNN_NS::ObjectInfo object, void *data_rgba, int nimage_height, int nimage_width,
	int nx0, int ny0, int nx1, int ny1, float fscale_x, float fscale_y);

fileLoad.cpp

#define _CRT_SECURE_NO_WARNINGS

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <io.h>
#include <opencv2/opencv.hpp>

#include "fileLoad.h"

std::string fdLoadFile(std::string strPath)
{
    std::ifstream file(strPath, std::ios::binary);
    if (file.is_open()) 
    {
        file.seekg(0, file.end);
        int nsize = file.tellg();
        char *content = new char[nsize];
        file.seekg(0, file.beg);
        file.read(content, nsize);
        std::string fileContent;
        fileContent.assign(content, nsize);
        delete[] content;
        file.close();
        return fileContent;
    }
    else
    {
        return "";
    }
}

void getAllFiles(std::string strPath, std::vector<std::string> &vecFilePaths, std::vector<std::string> &vecFileNames, std::string strFormat)
{
    intptr_t hFile = 0;  // search handle; intptr_t because a 32-bit long truncates it in 64-bit builds
    struct _finddata_t fileinfo;  // information about the current file
    std::string strTmp;
    if ((hFile = _findfirst(strTmp.assign(strPath).append("\\*" + strFormat).c_str(), &fileinfo)) != -1) {
        do
        {
            if ((fileinfo.attrib & _A_SUBDIR)) {  // the entry is a directory
                if (strcmp(fileinfo.name, ".") != 0 && strcmp(fileinfo.name, "..") != 0)
                {
                    getAllFiles(strTmp.assign(strPath).append("\\").append(fileinfo.name), vecFilePaths, vecFileNames, strFormat);
                }
            }
            else
            {
                vecFilePaths.push_back(strTmp.assign(strPath).append("\\").append(fileinfo.name));
                std::string fileName = fileinfo.name;
                fileName = fileName.substr(0, fileName.size() - strFormat.size() - 1);
                vecFileNames.push_back(fileName);
            }
        } while (_findnext(hFile, &fileinfo) == 0);  // find the next entry; returns 0 on success, -1 otherwise

        _findclose(hFile);
    }
}
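`getAllFiles` depends on the MSVC-only `_findfirst`/`_findnext` API. If the project can be built as C++17, `std::filesystem` supports the same recursive search portably. The sketch below works under that assumption. `listFilesWithExtension` is a hypothetical name. Unlike the original, it matches the extension exactly (for example `.jpg`, case-sensitive) instead of any name ending in `jpg`.

```cpp
#include <cassert>
#include <filesystem>
#include <string>
#include <system_error>
#include <vector>

namespace fs = std::filesystem;

// Recursively collect regular files under `dir` whose extension equals `ext` (e.g. ".jpg").
// Full paths go to `paths`, file names without the extension go to `stems`.
// Errors (missing directory, permissions) simply stop the search instead of throwing.
void listFilesWithExtension(const fs::path &dir, const std::string &ext,
                            std::vector<std::string> &paths,
                            std::vector<std::string> &stems) {
    assert(!ext.empty() && ext[0] == '.');
    std::error_code ec;
    for (fs::recursive_directory_iterator it(dir, ec), end; !ec && it != end; it.increment(ec)) {
        std::error_code file_ec;
        if (it->is_regular_file(file_ec) && it->path().extension().string() == ext) {
            paths.push_back(it->path().string());
            stems.push_back(it->path().stem().string());
        }
    }
}
```

A call such as `listFilesWithExtension("F:\\0work\\assets", ".jpg", paths, names)` would fill the same two vectors that `main.cpp` passes to `getAllFiles`.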

void retangle(int nx_min, int nx_max, int ny_min, int ny_max, int nImage_height, int nImage_width, TNN_NS::RGBA *image_rgba, int nr,int ng,int nb,int na)
{
    // Draws a 2-pixel-wide outline. The coordinates are expected to be clamped to the image already;
    // the second pixel of each edge is only written when it stays inside the image (or inside the same row).
    if (nx_max > nx_min) {
        for (int x = nx_min; x <= nx_max; x++)
        {
            // top edge
            int offset = ny_min * nImage_width + x;
            image_rgba[offset] = { nr, ng, nb, na };
            if (ny_min + 1 < nImage_height)
            {
                image_rgba[offset + nImage_width] = { nr, ng, nb, na };
            }

            // bottom edge
            offset = ny_max * nImage_width + x;
            image_rgba[offset] = { nr, ng, nb, na };
            if (ny_max >= 1)
            {
                image_rgba[offset - nImage_width] = { nr, ng, nb, na };
            }
        }
    }

    if (ny_max > ny_min) {
        for (int y = ny_min; y <= ny_max; y++) {
            // left edge
            int offset = y * nImage_width + nx_min;
            image_rgba[offset] = { nr, ng, nb, na };
            if (nx_min + 1 < nImage_width) {
                image_rgba[offset + 1] = { nr, ng, nb, na };
            }

            // right edge
            offset = y * nImage_width + nx_max;
            image_rgba[offset] = { nr, ng, nb, na };
            if (nx_max >= 1) {
                image_rgba[offset - 1] = { nr, ng, nb, na };
            }
        }
    }
}
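Drawing directly into a pixel buffer needs care at the image borders, where the second pixel of a thick edge can fall outside the image or wrap into the next row. The self-contained sketch below applies a bounds check to every pixel write, using the same outline idea. `Pixel` and `drawBoxOutline` are hypothetical names that stand in for `TNN_NS::RGBA` and `retangle`. The four channels are named generically because the buffer in `main.cpp` is BGRA while drawing and RGBA when saved.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical 4-byte pixel, laid out like the drawing buffer.
struct Pixel {
    unsigned char c0, c1, c2, c3;
};

// Set one pixel, silently ignoring coordinates outside the image.
static void setPixel(std::vector<Pixel> &img, int width, int height, int x, int y, Pixel color) {
    if (x >= 0 && x < width && y >= 0 && y < height) {
        img[static_cast<std::size_t>(y) * width + x] = color;
    }
}

// Draw a rectangle outline `thickness` pixels wide; extra lines grow inward from the box edges.
void drawBoxOutline(std::vector<Pixel> &img, int width, int height,
                    int x_min, int y_min, int x_max, int y_max,
                    Pixel color, int thickness = 2) {
    assert(img.size() == static_cast<std::size_t>(width) * height);
    for (int t = 0; t < thickness; ++t) {
        for (int x = x_min; x <= x_max; ++x) {
            setPixel(img, width, height, x, y_min + t, color);  // top edge
            setPixel(img, width, height, x, y_max - t, color);  // bottom edge
        }
        for (int y = y_min; y <= y_max; ++y) {
            setPixel(img, width, height, x_min + t, y, color);  // left edge
            setPixel(img, width, height, x_max - t, y, color);  // right edge
        }
    }
}
```

Checking every write costs a few comparisons per pixel. In exchange, a box that touches or even exceeds the image border can never corrupt memory.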

void Rectangle(const std::string strLabel_list[], TNN_NS::ObjectInfo object, void *data_rgba, int iImage_height, int iImage_width,
    int ix0, int iy0, int ix1, int iy1, float fscale_x, float fscale_y)
{
    TNN_NS::RGBA *image_rgba = (TNN_NS::RGBA *)data_rgba;

    // Scale the box from model-input coordinates to the original image and clamp it to the image bounds
    int x_min = std::min(std::max(int(std::min(ix0, ix1) * fscale_x), 0), iImage_width - 1);
    int x_max = std::min(std::max(int(std::max(ix0, ix1) * fscale_x), 0), iImage_width - 1);
    int y_min = std::min(std::max(int(std::min(iy0, iy1) * fscale_y), 0), iImage_height - 1);
    int y_max = std::min(std::max(int(std::max(iy0, iy1) * fscale_y), 0), iImage_height - 1);

    // Per-class colors, stored in the buffer's B, G, R, A order. In RGB they are:
    // white {255,255,255}, blue {65,105,255}, red {255,0,0}, yellow {255,255,0}, orange {255,97,0}, green {0,255,0}.
    // Any other class id is drawn in black. Assumes the six labels defined in main.cpp.
    static const int kColors[6][4] = {
        { 255, 255, 255, 255 }, { 255, 105, 65, 255 }, { 0, 0, 255, 255 },
        { 0, 255, 255, 255 },   { 0, 97, 255, 255 },   { 0, 255, 0, 255 } };
    static const int kBlack[4] = { 0, 0, 0, 255 };
    const bool known_class = object.nClass_id >= 0 && object.nClass_id < 6;
    const int *color = known_class ? kColors[object.nClass_id] : kBlack;

    // Label text such as "white:0.87"
    std::string labelObject = (known_class ? strLabel_list[object.nClass_id] : std::string("unknown")) + ":";
    char scoreBuf[16];
    snprintf(scoreBuf, sizeof(scoreBuf), "%.2f", object.score);
    labelObject += scoreBuf;

    cv::Mat face_frame(iImage_height, iImage_width, CV_8UC4, data_rgba);
    int x = (int)((std::min)(object.x1, object.x2) * fscale_x) + 8;
    int y = (int)((std::min)(object.y1, object.y2) * fscale_y) + 18;
    cv::Point point(x, y);

    retangle(x_min, x_max, y_min, y_max, iImage_height, iImage_width, image_rgba, color[0], color[1], color[2], color[3]);
    cv::putText(face_frame, labelObject, point, cv::FONT_HERSHEY_PLAIN, 1.0, cv::Scalar(color[0], color[1], color[2], color[3]));

    cv::imshow("object_detecting", face_frame);
    cv::waitKey(28);
}
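`Rectangle` maps each box from network-input coordinates back to the original image by multiplying with `fscale_x`/`fscale_y`, then clamps it to the image. That step can be isolated and tested on its own. The sketch assumes the scale factors are `original_size / input_size`, as computed in `main.cpp`. `ScaledBox` and `scaleAndClamp` are hypothetical names, and unlike `Rectangle` the sketch scales the float coordinates before truncating.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical integer box in original-image pixels.
struct ScaledBox {
    int x_min, y_min, x_max, y_max;
};

// Map a box from model-input space to the original image and clamp it to [0, size - 1].
// The corners may be given in any order.
ScaledBox scaleAndClamp(float x0, float y0, float x1, float y1,
                        float scale_x, float scale_y,
                        int image_width, int image_height) {
    assert(image_width > 0 && image_height > 0);
    auto clampTo = [](float v, int hi) {
        return std::min(std::max(static_cast<int>(v), 0), hi);
    };
    return {clampTo(std::min(x0, x1) * scale_x, image_width - 1),
            clampTo(std::min(y0, y1) * scale_y, image_height - 1),
            clampTo(std::max(x0, x1) * scale_x, image_width - 1),
            clampTo(std::max(y0, y1) * scale_y, image_height - 1)};
}
```

For a 1280×896 image and a 640×448 model input, both scale factors are 2, so a box from (10, 20) to (100, 200) becomes (20, 40) to (200, 400).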

main.cpp

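#define _CRT_SECURE_NO_WARNINGS
#define NOMINMAX  // stop <windows.h> from defining min/max macros that clash with std::min/std::max

#include <conio.h>
#include <cstdio>
#include <ctime>
#include <iostream>
#include <windows.h>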

#include "fileLoad.h"
#include "RST_HelmetDetection.h"

#define STB_IMAGE_IMPLEMENTATION
#include "third_party/stb/stb_image.h"
#define STB_IMAGE_RESIZE_IMPLEMENTATION
#include "third_party/stb/stb_image_resize.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
#include "third_party/stb/stb_image_write.h"


// KEY_DOWN(vk) is 1 while the virtual key vk (e.g. VK_LBUTTON) is held down; GetAsyncKeyState comes from <windows.h>
#define KEY_DOWN(VK_NONAME) ((GetAsyncKeyState(VK_NONAME) & 0x8000) ? 1 : 0)

const std::string label_list[] = { "white","blue","red","yellow","orange","none" };

int main(int argc, char **argv)
{
    const char *pszModelDir = "F:/0work/_总结/weightsw/best448/best.tnnmodel";
    const char *pszProtoDir = "F:/0work/_总结/weightsw/best448/best.tnnproto";
    // Initialize the model
    HELMETD_HANDLE hHelmetd = HelmetD_Open(pszModelDir, pszProtoDir);
    if (hHelmetd == NULL)
    {
        std::cerr << "Failed to load the TNN model.\n";
        return -1;
    }
    // Collect the paths and names of all .jpg files in the folder
    std::string strFilePath = "F:\\0work\\assets";
    std::vector<std::string> vctStrFilePaths;
    std::vector<std::string> vctStrFileNames;
    getAllFiles(strFilePath, vctStrFilePaths, vctStrFileNames, "jpg");
    // Process each .jpg file
    for (size_t i = 0; i < vctStrFilePaths.size(); i++)
    {
        const char *input_imgfn = vctStrFilePaths[i].c_str();
        int image_width, image_height, image_channel;
        unsigned char *data;
        std::vector<TNN_NS::ObjectInfo> object_list;

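        // Force 3 channels (RGB); image_channel still receives the file's original channel count
        data = stbi_load(input_imgfn, &image_width, &image_height, &image_channel, 3);
        if (!data) {
            std::cerr << "Image open failed: " << input_imgfn << "\n";
            continue;  // skip unreadable files instead of aborting the whole run
        }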

        clock_t start, end;
        start = clock();
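        // Run detection to get object_list; the buffer always has 3 channels because of the stbi_load call above
        int iTmp = HelmetD_Detect(hHelmetd, image_width, image_height, 3, data, true, object_list);
        if (iTmp)
        {
            std::cout << "model error" << std::endl;
            stbi_image_free(data);
            HelmetD_Close(hHelmetd);
            return -1;
        }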
        end = clock();
        std::cout << "detect picture time=" << double(end - start) / CLOCKS_PER_SEC << " s" << std::endl;

        // Draw the results with OpenCV; boxes are in 640x448 model-input coordinates and are scaled back to the original size
        const int image_orig_height = int(image_height);
        const int image_orig_width = int(image_width);
        const int target_height = 448;
        const int target_width = 640;
        float scale_x = image_orig_width / (float)target_width;
        float scale_y = image_orig_height / (float)target_height;
        // Convert RGB to BGRA, the channel order OpenCV assumes for a 4-channel image
        uint8_t *ifm_buf = new uint8_t[image_orig_width * image_orig_height * 4];
        for (int i = 0; i < image_orig_height * image_orig_width; i++)
        {
            ifm_buf[i * 4 + 2] = data[i * 3];
            ifm_buf[i * 4 + 1] = data[i * 3 + 1];
            ifm_buf[i * 4] = data[i * 3 + 2];
            ifm_buf[i * 4 + 3] = 255;
        }
        for (int i = 0; i < object_list.size(); i++)
        {
            auto object = object_list[i];
            Rectangle(label_list, object, (void *)ifm_buf, image_orig_height, image_orig_width, object.x1, object.y1,
                object.x2, object.y2, scale_x, scale_y);
        }
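        // Swap back to RGBA and save the annotated image as <file name>.png in the working directory
        for (int i = 0; i < image_orig_height * image_orig_width; i++)
        {
            auto tmp = ifm_buf[i * 4];
            ifm_buf[i * 4] = ifm_buf[i * 4 + 2];
            ifm_buf[i * 4 + 2] = tmp;
        }
        char buff[256];
        // %s needs a C string, hence c_str(); stbi_write_png matches the .png extension (stride = width * 4 bytes)
        snprintf(buff, sizeof(buff), "%s.png", vctStrFileNames[i].c_str());
        int success = stbi_write_png(buff, image_orig_width, image_orig_height, 4, ifm_buf, image_orig_width * 4);
        if (!success)
        {
            std::cerr << "Failed to write " << buff << "\n";
        }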
        fprintf(stdout, "Object-Detector Done.\nNumber of objects: %d\n", int(object_list.size()));

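        // Release the buffers for this image
        delete[] ifm_buf;
        ifm_buf = NULL;
        // stbi_load allocates with malloc, so its buffer must be released with stbi_image_free, not delete
        stbi_image_free(data);
        data = NULL;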

        
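        // Press any key or click the left mouse button to process the next image
        HANDLE hStdin = GetStdHandle(STD_INPUT_HANDLE);
        DWORD mode;
        GetConsoleMode(hStdin, &mode);
        // Disable Quick Edit so a mouse click does not start a text selection; this needs ENABLE_EXTENDED_FLAGS
        mode |= ENABLE_EXTENDED_FLAGS;
        mode &= ~ENABLE_QUICK_EDIT_MODE;
        SetConsoleMode(hStdin, mode);
        char ch = 0;  // stays 0 when the wait ends with a mouse click
        while (1) {
            if (KEY_DOWN(VK_LBUTTON)) { // left mouse button pressed
                break;
            }
            if (_kbhit()) { // _kbhit() returns nonzero when a key has been pressed
                ch = _getch(); // read the pressed key with _getch()
                break;
            }
            Sleep(18);
        }

        // Exit when ESC (key code 27) or 'q' is pressed
        if (ch == 27 || ch == 'q')
        {
            HelmetD_Close(hHelmetd);
            return 0;
        }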

    }//for file

    HelmetD_Close(hHelmetd);
    return 0;
}
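`main` converts the RGB buffer returned by `stbi_load` into BGRA, because OpenCV treats the channels of a `CV_8UC4` image as B, G, R, A. It then swaps the buffer back to RGBA before saving. The packing step on its own is sketched below. `rgbToBgra` is a hypothetical helper that returns a `std::vector`, so the buffer is freed automatically instead of through `new[]`/`delete[]`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Convert tightly packed RGB888 pixels to BGRA8888 with an opaque alpha channel.
std::vector<std::uint8_t> rgbToBgra(const std::uint8_t *rgb, int width, int height) {
    assert(rgb != nullptr && width > 0 && height > 0);
    const std::size_t count = static_cast<std::size_t>(width) * height;
    std::vector<std::uint8_t> bgra(count * 4);
    for (std::size_t i = 0; i < count; ++i) {
        bgra[i * 4 + 0] = rgb[i * 3 + 2];  // B
        bgra[i * 4 + 1] = rgb[i * 3 + 1];  // G
        bgra[i * 4 + 2] = rgb[i * 3 + 0];  // R
        bgra[i * 4 + 3] = 255;             // A (opaque)
    }
    return bgra;
}
```

The result can be wrapped without copying as `cv::Mat(height, width, CV_8UC4, buffer.data())`, provided the vector outlives the `cv::Mat`.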

5.3 Comparison and analysis

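1. After the model is deployed on Windows, its test results agree with those of running yolov5's detect.py in most cases (confidence scores differ by about 0.1). Running detect.py directly usually gives the higher confidence.
00000
Helmet detection (yolov5+tnn), image 45
000024
Helmet detection (yolov5+tnn), image 46

00001
Helmet detection (yolov5+tnn), image 47
00000
Helmet detection (yolov5+tnn), image 48

000002
Helmet detection (yolov5+tnn), image 49
000003
Helmet detection (yolov5+tnn), image 50
(image from the web)
Helmet detection (yolov5+tnn), image 51
(image from the web)
Helmet detection (yolov5+tnn), image 52
(image from the web)
Helmet detection (yolov5+tnn), image 53
(image from the web)

Helmet detection (yolov5+tnn), image 54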

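2. The model still produces some false detections and missed detections.
(image from the web)
Helmet detection (yolov5+tnn), image 55
3. The TNN inference results are not completely identical to the results of running yolov5's detect.py.
Helmet detection (yolov5+tnn), image 56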
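One plausible contributor to these differences is preprocessing. This is an assumption, not a measured cause. yolov5's `detect.py` resizes images with letterboxing, which keeps the aspect ratio and pads the borders. The TNN program instead resizes directly to the 640×448 input through `ResizeToInputShape` and scales boxes back with independent `scale_x`/`scale_y`, which stretches the image. The sketch below computes letterbox geometry so that the two pipelines can be compared or aligned. `LetterboxParams`, `computeLetterbox`, `unletterboxX` and `unletterboxY` are hypothetical names. The sketch pads to the full fixed target size, which corresponds to `auto=False` in yolov5's `letterbox` helper; its default `auto=True` pads only up to a stride multiple, and its exact rounding may differ slightly.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Geometry of a letterbox resize: one scale factor, the resized size and the leading padding.
struct LetterboxParams {
    float scale;            // factor applied to both axes
    int resized_w, resized_h;
    int pad_left, pad_top;  // padding before the image; the remainder goes right/bottom
};

LetterboxParams computeLetterbox(int src_w, int src_h, int dst_w, int dst_h) {
    assert(src_w > 0 && src_h > 0 && dst_w > 0 && dst_h > 0);
    const float scale = std::min(static_cast<float>(dst_w) / src_w,
                                 static_cast<float>(dst_h) / src_h);
    const int rw = static_cast<int>(std::lround(src_w * scale));
    const int rh = static_cast<int>(std::lround(src_h * scale));
    return {scale, rw, rh, (dst_w - rw) / 2, (dst_h - rh) / 2};
}

// Map a coordinate from the letterboxed model input back to the source image.
inline float unletterboxX(float x, const LetterboxParams &p) { return (x - p.pad_left) / p.scale; }
inline float unletterboxY(float y, const LetterboxParams &p) { return (y - p.pad_top) / p.scale; }
```

For a 1920×1080 photo and a 640×448 input, the image is scaled to 640×360 and 44 rows of padding are added above it. Plain resizing would instead squeeze the vertical axis by a different factor than the horizontal one.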

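4. Detection time shows no obvious relationship to the number of objects in the image. This is plausible, because the network always runs one forward pass over the same fixed-size input; only the comparatively cheap decoding and NMS steps grow with the number of candidate boxes.
Helmet detection (yolov5+tnn), image 57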
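`main.cpp` times each detection with `clock()`. With MSVC, `clock()` reports elapsed wall-clock time, but on POSIX systems it reports processor time, so `std::chrono::steady_clock` is the portable choice for latency measurements like the one above. The sketch below times a callable over several runs and returns the mean in milliseconds. `meanMilliseconds` is a hypothetical helper. The first call is excluded as a warm-up because it often includes one-off initialization cost.

```cpp
#include <cassert>
#include <chrono>

// Run `fn` once as an untimed warm-up, then `runs` more times; return the mean wall-clock time in ms.
template <typename Fn>
double meanMilliseconds(Fn &&fn, int runs) {
    assert(runs > 0);
    fn();  // warm-up call, not timed
    const auto start = std::chrono::steady_clock::now();
    for (int i = 0; i < runs; ++i) {
        fn();
    }
    const auto stop = std::chrono::steady_clock::now();
    const std::chrono::duration<double, std::milli> total = stop - start;
    return total.count() / runs;
}
```

To use it, wrap the existing detection call in a lambda, for example `meanMilliseconds([&] { HelmetD_Detect(...); }, 20)`, passing the same arguments as in `main.cpp`.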
