通过TenSorRT转换后的engine引擎文件进行验证的脚本

YOLOv8算法验证pt文件的精度脚本一般都很常见,工程项目里面一般会有

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO('/best.pt')   #权重文件路径
    model.val(data='/data.yaml',  #yaml文件路径
              split='val',
              imgsz=640,
              batch=1,
              # rect=False,
              # save_json=True, # if you need to cal coco metric
              project='runs/val',
              name='exp',
              )

一般结果是
在这里插入图片描述

但是在Jetson上面通过engine引擎文件去验证精确度的文件很少有,下面将进行分享
可以将其命名为validate_model.py

import os
from ultralytics import YOLO
import ctypes
import tensorrt as trt
import yaml
import psutil
import shutil

def print_memory_usage():
    process = psutil.Process(os.getpid())
    print(f"Memory usage: {process.memory_info().rss / (1024 ** 2):.2f} MB")

# 显式加载插件库
plugin_library = "/usr/lib/aarch64-linux-gnu/libnvinfer_plugin.so"
ctypes.CDLL(plugin_library)

# 初始化 TensorRT 插件
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(TRT_LOGGER, "")

def main():
    # 配置路径
    model_path = "/best.engine"#生成的engine文件路径
    data_path = "/data.yaml" #yaml文件所在路径
    img_size = 640  # 这个看生成的engine文件设定的图片分辨率,降低分辨率可以以减少内存
    chunk_size = 300  # 每批次验证图片数量,内存过小可以设置更小,1+
    #下面三个文件路径可以放datasets文件下
    temp_image_dir = "/temp_val/images"#临时存储的验证集图片路径
    temp_label_dir = "/temp_val/labels"#临时存储的验证集标签路径
    temp_data_path = "/temp_data.yaml"#临时生成的yaml文件所在路径

    # 检查引擎文件是否存在
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")

    # 加载 YOLO 模型
    try:
        model = YOLO(model_path)
        print("Model loaded successfully!")
        print_memory_usage()
    except Exception as e:
        print(f"Failed to load model: {e}")
        return

    # 读取数据集
    try:
        with open(data_path, 'r') as f:
            data_config = yaml.safe_load(f)
        val_images_dir = data_config['val']
        val_labels_dir = "/datasets/labels/val"#验证集标签所在路径
        image_list = [os.path.join(val_images_dir, img) for img in os.listdir(val_images_dir) if img.endswith(('.JPG', '.png'))]
    except Exception as e:
        print(f"Failed to load dataset: {e}")
        return

    # 分批次验证
    os.makedirs(temp_image_dir, exist_ok=True)
    os.makedirs(temp_label_dir, exist_ok=True)
    for i in range(0, len(image_list), chunk_size):
        subset_images = image_list[i:i + chunk_size]

        # 拷贝分批次图片和对应标签到临时目录
        for img_path in subset_images:
            img_name = os.path.basename(img_path)
            label_name = img_name.replace('.JPG', '.txt').replace('.png', '.txt')
            label_path = os.path.join(val_labels_dir, label_name)

            # 拷贝图片
            symlink_image_path = os.path.join(temp_image_dir, img_name)
            if not os.path.exists(symlink_image_path):
                os.symlink(img_path, symlink_image_path)

            # 拷贝标签
            symlink_label_path = os.path.join(temp_label_dir, label_name)
            if os.path.exists(label_path) and not os.path.exists(symlink_label_path):
                os.symlink(label_path, symlink_label_path)

        # 创建临时 `temp_data.yaml`
        temp_data_config = {
            'path': data_config['path'],
            'train': '',  # 验证阶段不需要训练数据
            'val': temp_image_dir,
            'names': data_config['names']
        }
        with open(temp_data_path, 'w') as f:
            yaml.dump(temp_data_config, f)

        # 运行验证
        try:
            print(f"Processing batch {i // chunk_size} with {len(subset_images)} images...")
            results = model.val(
                data=temp_data_path,
                imgsz=img_size,
                batch=1,
                save_json=False,
                project='runs/val',
                name=f'exp_batch_{i // chunk_size}',
            )
            # 打印详细的 mAP 结果
            print(f"Validation results for batch {i // chunk_size}:")
            print(results.pandas())  # 改为打印 pandas 数据
            print_memory_usage()
        except Exception as e:
            print(f"Validation failed for batch {i // chunk_size}: {e}")

        # 清理临时图片和标签目录
        shutil.rmtree(temp_image_dir)
        shutil.rmtree(temp_label_dir)
        os.makedirs(temp_image_dir, exist_ok=True)
        os.makedirs(temp_label_dir, exist_ok=True)

if __name__ == "__main__":
    main()

结果如下
通过TenSorRT转换后的engine引擎文件进行验证的脚本_第1张图片

你可能感兴趣的:(深度学习-硬件篇,嵌入式硬件,mcu,python)