如何使用YOLOv8在AI-TOD数据集上进行遥感目标检测,从安装依赖项、准备数据集、配置YOLOv8、训练和评估模型以及构建GUI应用程序展示检测
遥感目标检测,AI-TOD数据集aitod,训练集11214张,测试集集14018,验证集2804张,共计28036张。
yolo格式。
标签有8类,分别是 [‘airplane’, ‘bridge’, ‘storage-tank’, ‘ship’, ‘swimming-pool’, ‘vehicle’, ‘person’, ‘wind-mill’]
基于AI-TOD数据集进行遥感目标检测的训练和评估,、使用YOLOv8模型。下面将详细介绍从环境配置、数据准备、模型训练、评估到构建简单GUI应用程序的完整流程。
首先确保你的环境中已经安装了必要的库。这里我们主要需要PyTorch和Ultralytics的YOLOv8实现。
pip install torch torchvision torchaudio # 根据需要选择适合的版本
pip install ultralytics opencv-python
假设你已经下载并解压了AI-TOD数据集,并且它包含了YOLO格式的标注文件。确保数据集结构如下:
aitod/
├── images/
│ ├── train/
│ ├── val/
│ └── test/
└── labels/
├── train/
├── val/
└── test/
创建一个data.yaml
文件来定义数据集路径和类别信息。
train: ./aitod/images/train/
val: ./aitod/images/val/
nc: 8 # 类别数量
names: ['airplane', 'bridge', 'storage-tank', 'ship', 'swimming-pool', 'vehicle', 'person', 'wind-mill']
你可以选择加载YOLOv8的预训练模型开始训练,或者根据自己的需求定义模型架构。
from ultralytics import YOLO
# 加载预训练的YOLOv8n模型
model = YOLO('yolov8n.yaml') # 或者 'yolov8n.pt' 如果你想从预训练权重开始
接下来,使用定义好的data.yaml
文件开始训练。
# 开始训练
results = model.train(data='path/to/data.yaml', epochs=100, imgsz=640, batch=16)
epochs
: 设置训练轮数。imgsz
: 输入图像的尺寸。batch
: 每批次图像的数量。训练完成后,可以在验证集上评估模型性能。
# 在验证集上评估模型性能
metrics = model.val()
print(f"Validation mAP: {metrics.box.map}")
提供一个简单的PyQt5 GUI示例,用于展示YOLOv8的检测结果。
首先,确保安装了PyQt5。
pip install PyQt5
```
然后,编写以下代码来构建GUI。
```python
import sys
from PyQt5.QtWidgets import QApplication, QLabel, QVBoxLayout, QWidget, QPushButton, QFileDialog
from PyQt5.QtGui import QPixmap, QImage
import cv2
import numpy as np
from ultralytics import YOLO
class AppDemo(QWidget):
def __init__(self):
super().__init__()
self.setWindowTitle('YOLOv8 Remote Sensing Object Detection')
self.setGeometry(100, 100, 800, 600)
self.image_label = QLabel(self)
self.button = QPushButton("Load Image", self)
self.button.clicked.connect(self.load_image)
vbox = QVBoxLayout()
vbox.addWidget(self.image_label)
vbox.addWidget(self.button)
self.setLayout(vbox)
self.model = YOLO('path/to/your/best.pt') # 使用训练好的模型
def load_image(self):
fname, _ = QFileDialog.getOpenFileName(self, 'Open file', 'c:\\', "Image files (*.jpg *.png)")
if fname:
self.show_image(fname)
def show_image(self, image_path):
results = self.model.predict(source=image_path)
img = cv2.imread(image_path)
for r in results:
for box in r.boxes.xyxy:
x1, y1, x2, y2 = map(int, box)
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
height, width, channel = img.shape
bytes_per_line = 3 * width
q_img = QImage(img.data, width, height, bytes_per_line, QImage.Format_RGB888).rgbSwapped()
pixmap = QPixmap.fromImage(q_img)
self.image_label.setPixmap(pixmap)
if __name__ == '__main__':
app = QApplication(sys.argv)
demo = AppDemo()
demo.show()
sys.exit(app.exec_())