yolov5 + second_classify -- 代码

因内容重要,故做此笔记,也仅做笔记。 

detect_correct.py

from yolov5 import YOLOv5


import torch
from torchvision import transforms
# import numpy as np
from PIL import Image
import cv2
import os


def detect(image_path, yolov5_model, recog_model):
    """Detect objects in one image, classify every crop, and show the result.

    The detector proposes boxes, each box is cropped out of the image and fed
    to the second-stage classifier, and the annotated image is displayed in
    an OpenCV window for five seconds.

    Args:
        image_path: Path of the image file to read with ``cv2.imread``.
        yolov5_model: Object exposing ``detect(img, conf_thres=...)`` that
            returns ``(bboxes, scores)``; ``bboxes`` must support ``.numpy()``
            and hold ``x1, y1, x2, y2`` in its first four columns.
        recog_model: Callable classifier applied to a ``1 x C x H x W`` CUDA
            tensor; its output is reduced with ``torch.max(..., 1)``.

    Returns:
        ``False`` when the image cannot be read, otherwise ``None``.
    """
    img = cv2.imread(image_path)
    if img is None:
        print('None image', image_path)
        return False
    h, w = img.shape[:2]  # cv2 arrays are laid out (height, width, channels)
    bboxes, _ = yolov5_model.detect(img, conf_thres=0.1)
    bboxes = bboxes.numpy()
    for bbox in bboxes:
        # Clamp the box to the image, then truncate to integer pixel indices.
        x1 = int(max(bbox[0], 0))
        y1 = int(max(bbox[1], 0))
        x2 = int(min(bbox[2], w))
        y2 = int(min(bbox[3], h))
        if x2 <= x1 or y2 <= y1:
            # A zero-area crop would make cv2.cvtColor fail on an empty array.
            continue
        cropped_img = Image.fromarray(cv2.cvtColor(img[y1:y2, x1:x2, :], cv2.COLOR_BGR2RGB))
        cropped_img = data_transforms(cropped_img).unsqueeze(0).cuda()
        # Pure inference: no autograd bookkeeping is needed.
        with torch.no_grad():
            outputs = recog_model(cropped_img)
        confidence, preds = torch.max(outputs, 1)
        class_idx = preds.item()
        print(class_idx)
        # Fall back to the numeric index when no display name is registered,
        # instead of raising KeyError in the middle of drawing.
        class_name = class_dict.get(class_idx, str(class_idx))
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 255), thickness=2)
        # NOTE(review): `confidence` is the raw maximum of the model output.
        # If recog_model returns logits rather than probabilities, 0.5 is not
        # a probability threshold -- confirm against the classifier head.
        if confidence.item() > 0.5:
            cv2.putText(img, class_name, (x1+5, y1+20), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), thickness=2)
    if len(bboxes) > 0:
        cv2.imshow('image', img)
        cv2.waitKey(5000)



# Maps the classifier's predicted class index to the label drawn on the image.
# NOTE(review): it is empty here, so it must be filled with the real
# index -> name pairs before detect() looks names up in it.
class_dict = dict()

# Preprocessing applied to every crop before it is passed to the classifier.
# NOTE(review): the mean/std values look like the usual ImageNet statistics;
# confirm they match how the classifier was trained.
data_transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

image_root = '/home/img_test'
yolov5_weight = '/home/weights/best.pt'
efficient_weight = '/home/models/efficientnet-b0.pth'
# SECURITY: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints that come from a trusted source.
efficient_model = torch.load(efficient_weight).cuda().eval()

yolov5 = YOLOv5(yolov5_weight)
for image_name in os.listdir(image_root):
    # Compare the extension case-insensitively so '.JPG' / '.PNG' files are
    # processed instead of being silently skipped.
    ext = os.path.splitext(image_name)[1].lower()
    if ext not in ('.jpg', '.png', '.jpeg'):
        continue
    print(image_name)
    image_path = os.path.join(image_root, image_name)
    detect(image_path, yolov5, efficient_model)

yolov5.py

import time
import torch
import numpy as np
import cv2
import os
import argparse
from tqdm import tqdm

from models.experimental import attempt_load
from utils.datasets import letterbox
from utils.general import check_img_size, non_max_suppression, scale_coords
from utils.torch_utils import select_device, time_sync


# Command-line options used by the standalone batch-export entry point at the
# bottom of this file.
parser = argparse.ArgumentParser(description='Retinaface')
parser.add_argument('--image_dir', type=str)
parser.add_argument('--target_dir', type=str)
# Only read the real command line when this file is executed as a script.
# When it is merely imported (e.g. `from yolov5 import YOLOv5`), sys.argv
# belongs to the importing program; parsing it here could abort that program
# on options this parser does not know. Parsing an empty list instead yields
# the defaults (None) and keeps `args` defined.
if __name__ == "__main__":
    args = parser.parse_args()
else:
    args = parser.parse_args([])


class YOLOv5:
    """Wrapper around a YOLOv5 checkpoint that returns class-0 boxes for one image."""

    def __init__(self, weights, imgsz=640):
        """Load the checkpoint and pick the device and precision.

        Args:
            weights: Path of the checkpoint handed to ``attempt_load``.
            imgsz: Requested inference size; it is passed through
                ``check_img_size`` together with the model's largest stride.
        """
        print('Loading YOLO from', weights)
        self.device = select_device()
        self.half = self.device.type != 'cpu'  # half precision only supported on CUDA

        # Load model
        self.model = attempt_load(weights, map_location=self.device)  # load FP32 model
        self.imgsz = check_img_size(imgsz, s=self.model.stride.max())  # check img_size
        if self.half:
            self.model.half()  # to FP16

    def detect(self, orig_img, augment=True, conf_thres=0.25):
        """Run detection on a single image and return class-0 boxes and scores.

        Args:
            orig_img: Image array as read by ``cv2.imread``; the channel axis
                is reversed below (BGR to RGB) before inference.
            augment: Forwarded to the model call as ``augment``.
            conf_thres: Confidence threshold forwarded to
                ``non_max_suppression``.

        Returns:
            ``(bboxes, scores)`` as torch tensors built from numpy arrays:
            ``bboxes`` holds the first four detection columns rescaled to
            ``orig_img`` and ``scores`` holds the fifth column.
        """
        iou_thres = 0.45

        # Padded resize
        img = letterbox(orig_img, new_shape=self.imgsz)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
        img = np.ascontiguousarray(img)
        img = torch.from_numpy(img).to(self.device)
        img = img.half() if self.half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)

        # Inference only: disable autograd so no graph is recorded.
        with torch.no_grad():
            pred = self.model(img, augment=augment)[0]

            # Apply NMS, keeping class 0 only
            pred = non_max_suppression(pred, conf_thres, iou_thres, classes=0)

            # Exactly one image is submitted, so only the first entry matters.
            det = pred[0]
            if len(det):
                # Rescale boxes from the network input size to the original image size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], orig_img.shape).round()

        results = det.cpu().numpy()
        # Column 5 is compared against 0 to keep class-0 rows.
        results = results[results[:, 5] == 0]
        bboxes, scores = results[:, :4], results[:, 4]
        return torch.from_numpy(bboxes), torch.from_numpy(scores)


if __name__ == "__main__":  # only when this file is run directly
    # Batch export: run the detector over a directory and write one text file
    # per image that has at least two detections.
    img_root = args.image_dir
    img_list_file = None
    save_dir = args.target_dir
    if img_root is None or save_dir is None:
        # Fail with a usage message instead of a TypeError further down.
        parser.error('--image_dir and --target_dir are required')

    weights = '/home/weights/yolov5x.pt'
    yolo = YOLOv5(weights)
    os.makedirs(save_dir, exist_ok=True)
    print(save_dir)

    if img_list_file is not None:
        # The first whitespace-separated token of each non-blank line is an image name.
        with open(img_list_file, 'r') as f:
            img_list = [line.split()[0] for line in f if line.strip()]
    else:
        img_list = os.listdir(img_root)

    for img_name in tqdm(img_list):
        img = cv2.imread(os.path.join(img_root, img_name))
        if img is None:
            continue
        bboxes, scores = yolo.detect(img, conf_thres=0.5)
        # Images with fewer than two detections produce no output file.
        if len(bboxes) < 2:
            continue
        save_file_name = os.path.splitext(img_name)[0]+'.txt'
        save_path = os.path.join(save_dir, save_file_name)
        # Output layout: image name, detection count, then one
        # tab-separated "x1 y1 x2 y2 score" line per detection.
        with open(save_path, 'w') as op_f:
            op_f.write(img_name+'\n')
            op_f.write(str(len(bboxes))+'\n')
            for bbox, score in zip(bboxes, scores):
                op_array = [str(int(_)) for _ in bbox] + [str(score.numpy())]
                op_f.write('\t'.join(op_array)+'\n')

其他调用文件与yolov5默认一致。

你可能感兴趣的:(求职,项目,CV-计算机视觉,pytorch,python,计算机视觉)