香烟小草屋 模型引入 代码逻辑

项目根目录:/nfs/xs/Codes/retail_detect

1. 起点:运行 flask 网页程序

$ python run.py 

run.py 定义了程序监听端口 5050,网页会运行在 http://10.214.211.207:5050/

from app import app

app.run(
    host='0.0.0.0',
    port=5050,
    debug=True
)

2. 导入模型 API

app/__init__.pydetect 包引入模型推理 API 接口

from flask import Flask
from detect.demo import load_detection_model
from detect.demo_classifier import load_classification_model

app = Flask(__name__)

app.config.from_object('config')

from app import views

# init dababase in model.py

# load models
import os

# 指定模型推理 GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# 调试网页时,可注释掉模型导入部分,节省时间
# load_detection_model()
# load_classification_model()

3. 模型 API 使用

  • app/views.pydetect_and_count() 函数调用 count_objs_with_detect() 方法,该方法从 count.count_simple 导入
from count.count_simple import count_objs_with_detect

# 103 行
cat_num, cigar_cat_num = count_objs_with_detect(
    in_video,
    thumbnail_basepath=thumbnail_basepath,
    out_video=out_video  # save out video isn't necessary
)

  • count/count_simple.py 采用 跳帧检测 + 时空约束 对香烟进行计数
    • 代码第 104 行,调用 faster_rcnn_detect() 对图片中 条烟/盒烟 进行 检测定位,该方法从 detect.demo 导入
    • 代码第 116 行,merge_two_obj_dict(),按照 时空约束当前帧检测结果已有全部检测结果 进行 归并并进行物体细分类,更新各类物体检测数目,该方法从 count.utils_cnt 导入
from detect.demo import faster_rcnn_detect

# 104 行
frame_track_boxes, frame_box_labels = faster_rcnn_detect(frame_s, score_thre=0.8)  # x1,y1,x2,y2

# 116 行
total_obj_dict = merge_two_obj_dict(total_obj_dict, frame_obj_dict,
                                    begin_dt, video_fps,
                                    thumbnail_basepath)

① 调用 粗检测器 faster_rcnn_detect()

detect/demo.py 定义 2 个函数:

  • load_detection_model() 加载模型,在 app/__init__.py 导入
  • faster_rcnn_detect() 模型推理函数
@torch.no_grad()
def faster_rcnn_detect(img, score_thre=0.7):
    # img: ``PIL Image`` or ``numpy.ndarray``, cv bgr will convert to rgb
    img_tensor = F.to_tensor(img)

    detection = model([img_tensor.to(device)])[0]  # only 1 img

    # parse detection
    boxes = detection['boxes'].cpu().numpy()
    labels = detection['labels'].cpu().numpy()  # label idxs
    scores = detection['scores'].cpu().numpy()
    keep_idxs = np.where(scores > score_thre)[0]

    return boxes[keep_idxs], labels[keep_idxs]

② 调用 细分类器 merge_two_obj_dict()

count/utils_cnt.py 定义 3 个函数:

  • merge_two_obj_list():归并 某一类 的检测结果,2个 list
  • merge_two_obj_dict():归并 所有类 的检测结果,2个 dict
  • find_obj_in_list():在 已计数物体中 寻找 当前帧检测到的物体,决定更新还是添加 1 个新物体

merge_two_obj_list()

  • 30 行 调用 classify() 对粗检测结果 (检测存储的图片缩略图) 进行细分类
from detect.demo_classifier import classify

# 30 行
sub_label = cigar_5_names[classify(new_obj.thumbnail)]  # todo 更改为 10 类
  • 42 行 调用 insert_new_obj_todb() 将细分类结果存入数据库
from detect.demo_classifier import classify

# 42 行
insert_new_obj_todb(new_obj, thumbnail_path)

4. 检测模型 detect 目录 放置方式

  • net:模型定义文件,如 faster_rcnn.py, resnet_classifier.py
  • ckpt:训练好的模型
  • API 接口:如 demo.py,demo_classifier.py

5. 全局变量 所有类名

count/cats.py 在此定义 所有类别名称,包括 大类/小类

# super class
cat_names = [
    'bg', 'cigar_A', 'cigar_a'
]

# sub class
# todo: add 10 class names
cigar_A_10_names = [
    'huanghelou_A', 'zhonghua_A', 'jiaozi_E', ...
]

cigar_a_10_names = [
    'yunyan_a', 'liqun_a', ...
]

你可能感兴趣的:(香烟小草屋 模型引入 代码逻辑)