项目根目录:/nfs/xs/Codes/retail_detect
1. 起点:运行 flask 网页程序
$ python run.py
run.py
定义了程序监听端口 5050,网页会运行在 http://10.214.211.207:5050/
from app import app
app.run(
host='0.0.0.0',
port=5050,
debug=True
)
2. 导入模型 API
app/__init__.py
从 detect
包引入模型推理 API 接口
from flask import Flask
from detect.demo import load_detection_model
from detect.demo_classifier import load_classification_model
app = Flask(__name__)
app.config.from_object('config')
from app import views
# init dababase in model.py
# load models
import os
# 指定模型推理 GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# 调试网页时,可注释掉模型导入部分,节省时间
# load_detection_model()
# load_classification_model()
3. 模型 API 使用
-
app/views.py
中detect_and_count()
函数调用count_objs_with_detect()
方法,该方法从count.count_simple
导入
from count.count_simple import count_objs_with_detect
# 103 行
cat_num, cigar_cat_num = count_objs_with_detect(
in_video,
thumbnail_basepath=thumbnail_basepath,
out_video=out_video # save out video isn't necessary
)
-
count/count_simple.py
采用 跳帧检测 + 时空约束 对香烟进行计数- 代码第 104 行,调用
faster_rcnn_detect()
对图片中 条烟/盒烟 进行 检测定位,该方法从detect.demo
导入 - 代码第 116 行,
merge_two_obj_dict()
,按照 时空约束 对 当前帧检测结果 和 已有全部检测结果 进行 归并,并进行物体细分类,更新各类物体检测数目,该方法从count.utils_cnt
导入
- 代码第 104 行,调用
from detect.demo import faster_rcnn_detect
# 104 行
frame_track_boxes, frame_box_labels = faster_rcnn_detect(frame_s, score_thre=0.8) # x1,y1,x2,y2
# 116 行
total_obj_dict = merge_two_obj_dict(total_obj_dict, frame_obj_dict,
begin_dt, video_fps,
thumbnail_basepath)
① 调用 粗检测器 faster_rcnn_detect()
detect/demo.py
定义 2 个函数:
-
load_detection_model()
加载模型,在app/__init__.py
导入 -
faster_rcnn_detect()
模型推理函数
@torch.no_grad()
def faster_rcnn_detect(img, score_thre=0.7):
# img: ``PIL Image`` or ``numpy.ndarray``, cv bgr will convert to rgb
img_tensor = F.to_tensor(img)
detection = model([img_tensor.to(device)])[0] # only 1 img
# parse detection
boxes = detection['boxes'].cpu().numpy()
labels = detection['labels'].cpu().numpy() # label idxs
scores = detection['scores'].cpu().numpy()
keep_idxs = np.where(scores > score_thre)[0]
return boxes[keep_idxs], labels[keep_idxs]
② 调用 细分类器 merge_two_obj_dict()
count/utils_cnt.py
定义 3 个函数:
-
merge_two_obj_list()
:归并 某一类 的检测结果,2个list
-
merge_two_obj_dict()
:归并 所有类 的检测结果,2个dict
-
find_obj_in_list()
:在 已计数物体中 寻找 当前帧检测到的物体,决定更新还是添加 1 个新物体
merge_two_obj_list()
- 30 行 调用
classify()
对粗检测结果 (检测存储的图片缩略图) 进行细分类
from detect.demo_classifier import classify
# 30 行
sub_label = cigar_5_names[classify(new_obj.thumbnail)] # todo 更改为 10 类
- 42 行 调用
insert_new_obj_todb()
将细分类结果存入数据库
from detect.demo_classifier import classify
# 42 行
insert_new_obj_todb(new_obj, thumbnail_path)
4. 检测模型 detect
目录 放置方式
-
net
:模型定义文件,如faster_rcnn.py
,resnet_classifier.py
-
ckpt
:训练好的模型 - API 接口:如
demo.py
,demo_classifier.py
5. 全局变量 所有类名
count/cats.py
在此定义 所有类别名称,包括 大类/小类
# super class
cat_names = [
'bg', 'cigar_A', 'cigar_a'
]
# sub class
# todo: add 10 class names
cigar_A_10_names = [
'huanghelou_A', 'zhonghua_A', 'jiaozi_E', ...
]
cigar_a_10_names = [
'yunyan_a', 'liqun_a', ...
]