在安装了PyTorch的环境下
pip install ultralytics
先用 OpenCV 的 findContours 提取轮廓,再按 YOLO 分割标签格式存储(每行:类别编号 x1 y1 x2 y2 …,坐标分别除以图像宽、高做归一化)。
# %% Generate the YOLO-Seg dataset with classical CV; labels are written in the YOLO segmentation txt format
class obj(Enum):
    """Class ids written as the first field of every label line."""

    human = 0   # id used for the single large body contour
    danger = 1  # id used for the smaller object contours
def fliter(blured, thresh=115, minisize=500, maxsize=2000, findMAN = False):
    """Binarise a denoised image and return the contours of interest.

    Args:
        blured: Image handed to ``cv2.threshold``. Presumably single-channel
            8-bit, because ``cv2.findContours`` runs on the thresholded
            result -- TODO confirm against the caller.
        thresh: Binarisation threshold; pixels above it become 255.
        minisize: Exclusive lower bound on contour area (object mode only).
        maxsize: Exclusive upper bound on contour area (object mode only).
        findMAN: When true, ``minisize``/``maxsize`` are ignored and only the
            largest contour is considered; it is returned if its area lies
            strictly between 120000 and 200000.

    Returns:
        A list of contours, each an array of shape ``(N, 1, 2)``. The list is
        empty when nothing qualifies or when no contour is found at all.
    """
    _, binary = cv2.threshold(blured, thresh, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    # A blank binary image yields no contours; indexing contours[0] below
    # would otherwise raise IndexError.
    if not contours:
        return []

    if findMAN:
        largest = max(contours, key=cv2.contourArea)
        if 120000 < cv2.contourArea(largest) < 200000:
            return [largest]
        return []

    # NOTE(review): the reference polygon is the first contour in the order
    # findContours returned them, not the largest one -- confirm this is the
    # intended enclosing region.
    reference = contours[0]
    contours_draw = []
    for contour in contours:
        if not minisize < cv2.contourArea(contour) < maxsize:
            continue
        # pointPolygonTest: +1 inside, 0 on the edge, -1 outside.
        # The point is passed as plain floats because newer OpenCV Python
        # bindings reject NumPy integer scalars for this argument.
        # all() stops at the first vertex that falls outside the reference.
        if all(
            cv2.pointPolygonTest(
                reference, (float(pt[0][0]), float(pt[0][1])), measureDist=False
            ) >= 0
            for pt in contour
        ):
            contours_draw.append(contour)
    return contours_draw
def contour2str(contour, w, h, label):
    """Format one contour as a single label line.

    The line is ``<class id> x1 y1 x2 y2 ...`` followed by a newline, with
    every field separated by exactly one space. Each x is divided by ``w`` and
    each y by ``h``, then rounded to 6 decimal places.

    Args:
        contour: Sequence of points indexed as ``point[0, 0]`` (x) and
            ``point[0, 1]`` (y), i.e. an array of shape ``(N, 1, 2)``.
        w: Image width used to normalise x.
        h: Image height used to normalise y.
        label: Object whose ``value`` attribute is the class id.

    Returns:
        The label line, or an empty string for an empty contour.
    """
    if len(contour) == 0:
        return ""
    fields = [str(label.value)]
    for point in contour:
        fields.append(str(round(float(point[0, 0]) / w, 6)))
        fields.append(str(round(float(point[0, 1]) / h, 6)))
    # join() keeps a single separator between the class id and the first
    # coordinate and avoids repeated string concatenation.
    return " ".join(fields) + "\n"
def contours2str(contours, w, h, label):
    """Format several contours of one class as consecutive label lines.

    Args:
        contours: Iterable of contours, each passed to ``contour2str``.
        w: Image width forwarded to ``contour2str``.
        h: Image height forwarded to ``contour2str``.
        label: Class label forwarded to ``contour2str``.

    Returns:
        The concatenated label lines; an empty string when ``contours`` is
        empty.
    """
    return "".join(contour2str(contour, w, h, label) for contour in contours)
def deal(img_path, label_path):
    """Auto-label one image and write its label file.

    The image is denoised with a fixed chain of filters, then thresholded
    several times: once to find the body contour (``obj.human``) and three
    times with different threshold/area windows to find objects
    (``obj.danger``). All resulting lines are written to ``label_path``.

    Args:
        img_path: Path of the image to read.
        label_path: Path of the text file to write (overwritten).

    Raises:
        OSError: If the image cannot be read.
    """
    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise OSError(f"cannot read image: {img_path}")
    h, w = img.shape[:2]

    # Denoise.
    # NOTE(review): the pipeline presumably expects a single-channel image,
    # since the blurred result is thresholded and fed to findContours -- confirm.
    ksize = 5
    # sigmaY is passed by keyword: the fourth positional parameter of
    # GaussianBlur is `dst`, not sigmaY. The result is unchanged because an
    # unset sigmaY defaults to sigmaX.
    blured = cv2.GaussianBlur(img, (ksize, ksize), sigmaX=1, sigmaY=1)
    blured = cv2.boxFilter(blured, -1, (ksize, ksize), normalize=1)
    blured = cv2.bilateralFilter(blured, 2, 25, 25)
    blured = cv2.medianBlur(blured, ksize)
    blured = cv2.blur(blured, (ksize, ksize))

    # (class, fliter keyword arguments) for each pass, in output order.
    passes = [
        (obj.human, dict(thresh=80, findMAN=True)),
        (obj.danger, dict(thresh=100, minisize=1500, maxsize=3000)),
        (obj.danger, dict(thresh=115, minisize=500, maxsize=2000)),
        (obj.danger, dict(thresh=140, minisize=2000, maxsize=5000)),
    ]
    label = "".join(
        contours2str(fliter(blured, **kwargs), w, h, cls) for cls, kwargs in passes
    )

    with open(label_path, 'w', encoding="utf-8") as f:
        f.write(label)
# Label every file in img_folder and write one .txt per image into label_folder.
img_folder = 'imgs'
label_folder = 'labels'
os.makedirs(label_folder, exist_ok=True)
imglist = os.listdir(img_folder)
for img_name in tqdm.tqdm(imglist):
    img_path = os.path.join(img_folder, img_name)
    # splitext swaps the extension whatever it is; replace(".bmp", ".txt")
    # left names such as "X.BMP" unchanged, so the label was written under
    # the image's own file name.
    label_path = os.path.join(label_folder, os.path.splitext(img_name)[0] + ".txt")
    deal(img_path, label_path)
# Convert the image format: re-encode every training image as .jpg in a new folder.
imgfolder = r"./datasets/THZ-dataset/images/train"
imglist = os.listdir(imgfolder)
newfolder = r"./datasets/THZ-dataset/images/train1"
# cv2.imwrite does not create missing directories; it just returns False.
os.makedirs(newfolder, exist_ok=True)
for img_path in imglist:
    img = cv2.imread(os.path.join(imgfolder, img_path), cv2.IMREAD_UNCHANGED)
    if img is None:
        raise OSError(f"cannot read image: {os.path.join(imgfolder, img_path)}")
    jpg_path = os.path.join(newfolder, os.path.splitext(img_path)[0] + ".jpg")
    if not cv2.imwrite(jpg_path, img):
        raise OSError(f"cannot write image: {jpg_path}")
# %% Train/validation split
import numpy
import shutil

imgfolder = r"./datasets/THZ-dataset/images/train"
imgval = r"./datasets/THZ-dataset/images/val"
labelfolder = r"./datasets/THZ-dataset/labels/train"
labelval = r"./datasets/THZ-dataset/labels/val"
val_count = 631  # number of randomly chosen images moved to the validation set

# The destinations must exist as directories: otherwise shutil.move renames
# each file to a file called "val", overwriting the previous one.
os.makedirs(imgval, exist_ok=True)
os.makedirs(labelval, exist_ok=True)

imglist = os.listdir(imgfolder)
numpy.random.shuffle(imglist)
for img in imglist[:val_count]:
    shutil.move(os.path.join(imgfolder, img), imgval)
    # The label file shares the image's base name, whatever the image extension.
    shutil.move(os.path.join(labelfolder, os.path.splitext(img)[0] + ".txt"), labelval)
数据集配置文件 THZ-dataset-seg.yaml:
# THZ-seg dataset by Straka
# Train/val/test sets can each be given as
#   1) dir: path/to/imgs,
#   2) file: path/to/imgs.txt, or
#   3) list: [path/to/imgs1, path/to/imgs2, ..]
path: ./THZ-dataset # dataset root dir
train: images/train # train images (relative to 'path')
val: images/val # val images (relative to 'path')
test: # test images (optional)
# Classes (the ids match the first field of each label line)
names:
  0: person
  1: object
# Download script/URL (optional)
download:
YOLOv8 会自动把图片路径中的 images/train 替换为 labels/train,以查找对应的标签文件。
from ultralytics import YOLO
# Fine-tune the yolov8n-seg checkpoint on the THZ dataset.
train_args = dict(
    data="datasets/THZ-dataset-seg.yaml",
    batch=-1,
    epochs=100,
    imgsz=640,
    workers=0,  # the accompanying note says Windows only runs with workers=0
    cos_lr=True,
    optimizer='Adam',
    device=0,
)
model = YOLO("yolov8n-seg.pt")
model.train(**train_args)
注意:Windows 下需要设置 workers=0 才能运行(DataLoader 多进程加载的坑)。预训练权重可以在官网下载。
参考链接:官网(https://docs.ultralytics.com)、GitHub(https://github.com/ultralytics/ultralytics)
# %% Validation
# Load the best checkpoint from the training run and evaluate it.
best_weights = "runs/segment/train/weights/best.pt"
model = YOLO(best_weights)
metrics = model.val(workers=0)
# %% Prediction
model = YOLO("runs/best.pt")
img = cv2.imread("./THZ_imgs/20230323204429_20729_.bmp")
# cv2.imread returns None on failure; stop here with a clear message
# instead of passing None on to predict().
if img is None:
    raise OSError("cannot read image: ./THZ_imgs/20230323204429_20729_.bmp")
res = model.predict([img])
res_plotted = res[0].plot()
plt.figure(figsize=(16, 20))
# The plotted array is converted from BGR to RGB before matplotlib displays it.
plt.imshow(cv2.cvtColor(res_plotted, cv2.COLOR_BGR2RGB))
plt.show()