Win环境YOLOv8-Seg自定义数据集分割训练

安装

在安装了PyTorch的环境下

pip install ultralytics

制作coco-seg格式数据集

实例分割数据

先用 OpenCV 的 findContours 提取轮廓,再按 YOLO 分割标签格式存储:每个轮廓占一行,内容为"类别编号 x1 y1 x2 y2 …",其中坐标分别除以图像宽、高,归一化到 0~1。

# %% Generate the YOLO-Seg dataset (coco-seg style label files) with classical CV
from enum import Enum


class obj(Enum):
    """Class ids written as the first field of every generated label line."""

    human = 0
    danger = 1
    
def fliter(blured, thresh=115, minisize=500, maxsize=2000, findMAN=False,
           man_min_area=120000, man_max_area=200000):
    """Threshold the image and return the contours that pass the area filter.

    Args:
        blured: denoised image passed to ``cv2.threshold``.
        thresh: binary threshold value.
        minisize: exclusive lower bound on contour area (object mode).
        maxsize: exclusive upper bound on contour area (object mode).
        findMAN: when True, return only the largest contour, and only if its
            area lies strictly between ``man_min_area`` and ``man_max_area``.
        man_min_area: exclusive lower area bound used when ``findMAN`` is True.
        man_max_area: exclusive upper area bound used when ``findMAN`` is True.

    Returns:
        A list of contours, each an array shaped (N, 1, 2). The list is empty
        when nothing matches or when no contour is found at all.
    """
    _, binary = cv2.threshold(blured, thresh, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    # A blank binary image yields no contours; indexing contours[0] would fail.
    if not contours:
        return []

    if findMAN:
        largest = max(contours, key=cv2.contourArea)
        if man_min_area < cv2.contourArea(largest) < man_max_area:
            return [largest]
        return []

    # The first contour returned by findContours is the containment reference.
    # NOTE(review): presumably this is the outer (human) region - confirm.
    reference = contours[0]
    contours_draw = []
    for contour in contours:
        area = cv2.contourArea(contour)
        if not (minisize < area < maxsize):
            continue
        # pointPolygonTest: >0 inside, 0 on the edge, <0 outside.
        # Plain Python floats are passed because newer OpenCV builds reject a
        # tuple of numpy integers for the point argument.
        inside = all(
            cv2.pointPolygonTest(
                reference, (float(pt[0][0]), float(pt[0][1])), measureDist=False
            ) >= 0
            for pt in contour
        )
        if inside:
            contours_draw.append(contour)
    return contours_draw

def contour2str(contour, w, h, label):
    """Serialize one contour as a single segmentation label line.

    Args:
        contour: array shaped (N, 1, 2) of pixel coordinates (x, y).
        w: image width used to normalize x.
        h: image height used to normalize y.
        label: enum member whose ``value`` is written as the class id.

    Returns:
        ``"<class> x1 y1 x2 y2 ...\\n"`` with coordinates divided by the image
        size and rounded to 6 decimals, or ``""`` for an empty contour.
    """
    if len(contour) == 0:
        return ""
    fields = [str(label.value)]
    for point in contour:
        fields.append(str(round(float(point[0, 0]) / w, 6)))
        fields.append(str(round(float(point[0, 1]) / h, 6)))
    # Exactly one space between fields (no doubled space after the class id).
    return " ".join(fields) + "\n"

def contours2str(contours, w, h, label):
    """Concatenate the label lines of all contours, in order.

    Each contour is serialized with ``contour2str`` using the same image
    size (w, h) and the same label; an empty input gives ``""``.
    """
    return "".join(contour2str(single, w, h, label) for single in contours)

def deal(img_path, label_path):
    """Generate the segmentation label file for one image.

    The image is denoised with a chain of blur filters, then thresholded
    several times: one pass looks for the single large "human" contour and
    three passes with different threshold/area windows look for "danger"
    objects. All resulting label lines are written to ``label_path``.

    Args:
        img_path: path of the image to read.
        label_path: path of the text file to write (overwritten).

    Raises:
        ValueError: if the image cannot be read by ``cv2.imread``.
    """
    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    # cv2.imread returns None instead of raising when the file is unreadable.
    if img is None:
        raise ValueError(f"cannot read image: {img_path}")
    h, w = img.shape[:2]

    # Denoise
    ksize = 5
    blured = cv2.GaussianBlur(img, (ksize, ksize), 1, 1)
    blured = cv2.boxFilter(blured, -1, (ksize, ksize), normalize=1)
    blured = cv2.bilateralFilter(blured, 2, 25, 25)
    blured = cv2.medianBlur(blured, ksize)
    blured = cv2.blur(blured, (ksize, ksize))

    parts = [contours2str(fliter(blured, thresh=80, findMAN=True), w, h, obj.human)]
    # (thresh, minisize, maxsize) for each "danger" detection pass.
    # NOTE(review): the passes are independent, so the same object may be
    # emitted by more than one of them - confirm this is intended.
    danger_passes = (
        (100, 1500, 3000),
        (115, 500, 2000),
        (140, 2000, 5000),
    )
    for thresh, minisize, maxsize in danger_passes:
        found = fliter(blured, thresh=thresh, minisize=minisize, maxsize=maxsize)
        parts.append(contours2str(found, w, h, obj.danger))

    with open(label_path, 'w') as f:
        f.write("".join(parts))

# Generate one label file per BMP image found in img_folder.
img_folder = 'imgs'
label_folder = 'labels'
os.makedirs(label_folder, exist_ok=True)
imglist = os.listdir(img_folder)
for img_name in tqdm.tqdm(imglist):
    stem, ext = os.path.splitext(img_name)
    # Skip anything that is not a BMP: otherwise the label text would be
    # written under the unchanged image file name inside label_folder.
    if ext.lower() != ".bmp":
        continue
    img_path = os.path.join(img_folder, img_name)
    label_path = os.path.join(label_folder, stem + ".txt")
    deal(img_path, label_path)
# Convert the image format (BMP -> JPG)
# NOTE(review): output goes to images/train1 while the split step reads
# images/train - presumably the folder is renamed by hand afterwards; confirm.
imgfolder = r"./datasets/THZ-dataset/images/train"
imglist = os.listdir(imgfolder)
newfolder = r"./datasets/THZ-dataset/images/train1"
# cv2.imwrite only returns False when the target folder is missing,
# so create it up front.
os.makedirs(newfolder, exist_ok=True)
for img_path in imglist:
    img = cv2.imread(os.path.join(imgfolder, img_path), cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread returns None for files it cannot decode.
        print("skip unreadable file:", img_path)
        continue
    stem, ext = os.path.splitext(img_path)
    # Only the extension is swapped; other image types keep their name.
    out_name = stem + ".jpg" if ext.lower() == ".bmp" else img_path
    if not cv2.imwrite(os.path.join(newfolder, out_name), img):
        raise OSError("failed to write " + os.path.join(newfolder, out_name))

# %% Train / validation split
import shutil

import numpy

imgfolder = r"./datasets/THZ-dataset/images/train"
imgval = r"./datasets/THZ-dataset/images/val"
labelfolder = r"./datasets/THZ-dataset/labels/train"
labelval = r"./datasets/THZ-dataset/labels/val"
val_count = 631  # number of randomly chosen images moved to the validation set

# shutil.move renames the source to the destination path when the destination
# directory does not exist, so make sure both targets are real directories.
os.makedirs(imgval, exist_ok=True)
os.makedirs(labelval, exist_ok=True)

imglist = os.listdir(imgfolder)
numpy.random.shuffle(imglist)
for img in imglist[:val_count]:
    label_name = os.path.splitext(img)[0] + ".txt"
    shutil.move(os.path.join(imgfolder, img), imgval)
    shutil.move(os.path.join(labelfolder, label_name), labelval)

数据格式(不支持BMP图像,目录结构见原文第1张图片):图片放在 images/train 和 images/val,标签放在 labels/train 和 labels/val。

数据集描述

THZ-dataset-seg.yaml

# THZ-seg dataset by Straka
# Dataset description file passed to model.train(data=...).

# Train/val/test sets as
# 1) dir: path/to/imgs,
# 2) file: path/to/imgs.txt, or
# 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: ./THZ-dataset  # dataset root dir
train: images/train  # train images (relative to 'path')
val: images/val  # val images (relative to 'path'); filled by the train/val split step
test:  # test images (optional)

# Classes
# Ids must match the class id written at the start of each label line
# (0 = obj.human, 1 = obj.danger in the label-generation script).
names:
  0: person
  1: object

# Download script/URL (optional)
download: 

YOLOv8 会自动把图片路径中的 images/train 替换为 labels/train,以查找每张图片对应的标签文件。

训练

from ultralytics import YOLO

# Load the yolov8n-seg.pt weights and train on the custom dataset.
model = YOLO("yolov8n-seg.pt")
model.train(
    data="datasets/THZ-dataset-seg.yaml",
    batch=-1,
    epochs=100,
    imgsz=640,
    workers=0,  # required on Windows, see the note after this snippet
    cos_lr=True,
    optimizer='Adam',
    device=0,
)

注意:Windows 下需要设置 workers=0 才能运行,这是 DataLoader 多进程加载的坑。预训练权重可以在官网下载。
参考链接:官网 GitHub

验证

# %% Validation
model = YOLO("runs/segment/train/weights/best.pt")
# workers=0 for the same Windows dataloader reason as in training.
metrics = model.val(workers=0)
# %% Prediction
model = YOLO("runs/best.pt")
img = cv2.imread("./THZ_imgs/20230323204429_20729_.bmp")
# cv2.imread returns None (no exception) when the file cannot be read.
if img is None:
    raise ValueError("cannot read image: ./THZ_imgs/20230323204429_20729_.bmp")
res = model.predict([img])
res_plotted = res[0].plot()
plt.figure(figsize=(16, 20))
# The plotted image is converted from BGR to RGB before being shown.
plt.imshow(cv2.cvtColor(res_plotted, cv2.COLOR_BGR2RGB))
plt.show()

你可能感兴趣的:(YOLO,计算机视觉,opencv)