1. API main program
1) Request parameter validation with jsonschema.
2) Exception handling.
# Approach 1: wrap the call site directly: try: ... except Exception as err: ...
# Approach 2: raise ValueError('....') inside sub-functions kept in other files, then handle the exception one level up and return a unified error message to the front end (see the sketch after this list).
3) Single/multiple processes + coroutines.
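For approach 2), here is a minimal sketch (names such as demo_app and check_device are illustrative, not from the project code): sub-functions raise ValueError, and a single Flask error handler converts any such exception into a unified JSON response for the front end.
# exception-handling sketch -- illustrative only
from flask import Flask, jsonify

demo_app = Flask(__name__)

def check_device(device):
    # Sub-function (possibly in another file) raises instead of returning error codes
    if device not in ('cpu', '0'):
        raise ValueError('unsupported device: {}'.format(device))

@demo_app.errorhandler(ValueError)
def handle_value_error(err):
    # Centralized handling: every uncaught ValueError becomes a uniform JSON reply
    return jsonify({'state_code': 501, 'feed_msg': str(err)}), 501

@demo_app.route('/check/<device>')
def check(device):
    check_device(device)  # may raise; the errorhandler above catches it
    return jsonify({'state_code': 200, 'feed_msg': 'success'})
The full main program for the service follows.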
# -*- coding:utf-8 -*-
import json
import os
import copy
import time
import torch
from gevent import pywsgi, monkey
# Monkey-patch the standard library so blocking I/O cooperates with gevent coroutines
monkey.patch_all()
from flask import Flask, request, jsonify
from flask_cors import CORS
from multiprocessing import cpu_count, Process
from jsonschema import validate, ValidationError
from detect import inference_main
from train import train_main
from utils.job_manager import kill_process_by_port, kill_process_by_name
from utils.logger import get_logger
# Logging
log_file = './logs/yolov5.log'
logger = get_logger(name='yolov5', log_file=log_file)
# Flask application
app = Flask(__name__)
CORS(app, resources=r'/*')
app.config['JSON_AS_ASCII'] = False
# HTTP request parameter validation
# Validation schema for the http://ip:port/train endpoint
schema_train = {
"type": "object",
"required": ["event_id", "event_type", "payload"],
"properties": {
"event_id": {
"type": "integer",
},
"event_type": {
"type": "string",
},
"payload": {
"type": "object",
"required": ["data_config", 'basic_hyp'],
"properties": {
"data_config": {"type": "object",
"required": ["train", "val", "nc", "names", "result_path"],
"properties": {"train": {"type": "string"},
"val": {"type": "string"},
"nc": {"type": "integer", "minimum": 1},
"names": {"type": "array"},
"result_path": {"type": "string"}
}},
"basic_hyp": {"type": "object",
"required": ["epochs", "batch-size", "workers", "img-size", "device"],
"properties": {"epochs": {"type": "integer", "minimum": 1},
"batch-size": {"type": "integer", "minimum": 1},
"workers": {"type": "integer", "minimum": 0},
"img-size": {"type": "array"},
"device": {"type": "string"}
}},
}
}
}
}
# Validation schema for the http://ip:port/inference endpoint
schema_inference = {
"type": "object",
"required": ["event_id", "event_type", "payload"],
"properties": {
"event_id": {
"type": "integer",
},
"event_type": {
"type": "string",
},
"payload": {
"type": "object",
"required": ["data_config", "basic_hyp"],
"properties": {
"data_config": {"type": "object",
"required": ["test", "result_path"],
"properties": {"test": {"type": "string"},
"result_path": {"type": "string"}
}
},
"basic_hyp": {"type": "object",
"required": ["weights", "device"],
"properties": {"weights": {"type": "string"},
"device": {"type": "string"}
}
}
}
}
}
}
# Parameter-validation decorator; a different schema can be supplied per endpoint
def json_validate(schema):
def wrapper(func):
def inner(data, *args, **kwargs):
try:
validate(data, schema)
except ValidationError as e:
logger.error("接口参数校验失败:{}!".format(e.message))
return {'error': True, 'msg': e.message}
else:
logger.info("接口参数校验通过!")
return func(data, *args, **kwargs)
return inner
return wrapper
def api_result(event_id, state_code, msg_type, msg, result):
"""
Build the unified API response.
"""
api_res = {
"event_id": event_id,
"state_code": state_code,
"feed_type": msg_type,
"feed_msg": msg,
"feed_data": result,
}
logger.info("feed_msg: {}".format(api_res))
logger.info("=====================================================\n")
return jsonify(api_res)
@app.route('/train', methods=['POST'])
def train_post():
"""
Model training endpoint.
Returns:
"""
if request.method == "POST":
try:
# 解析请求参数
@json_validate(schema=schema_train)
def api_parameters(msg_dict_copy):
logger.info("启动模型训练.......")
return msg_dict_copy
request_data = request.get_data().decode()
msg_dict = json.loads(request_data)
logger.info("request msg: {}".format(msg_dict))
            # Validate parameters
msg_dict['error'] = False
validate_msg = api_parameters(msg_dict)
            # Run training
if not validate_msg['error']:
result_path = train_main(msg_dict)
result = {"result_path": result_path}
return api_result(msg_dict['event_id'], 200, 'train', 'success', result)
else:
                return api_result(msg_dict['event_id'], 501, 'train', 'Invalid parameters, please check. Error: {}'.format(validate_msg['msg']), None)
        except Exception as e:
            logger.error(e)
            return api_result(msg_dict.get('event_id'), 500, "train", str(e), None)
else:
feed_msg = "error, request.method != POST"
return api_result("101010", 400, "train", feed_msg, None)
@app.route('/inference', methods=['POST'])
def inference_post():
"""
Model inference endpoint.
Returns:
"""
if request.method == "POST":
try:
# 解析请求参数
@json_validate(schema=schema_inference)
def api_parameters(msg_dict_copy):
logger.info("启动模型推理.......")
return msg_dict_copy
request_data = request.get_data().decode()
msg_dict = json.loads(request_data)
logger.info("request msg: {}".format(msg_dict))
            # Validate parameters
msg_dict['error'] = False
validate_msg = api_parameters(msg_dict)
            # Run inference
if not validate_msg['error']:
result_path = inference_main(msg_dict)
result = {"result_path": result_path}
return api_result(msg_dict['event_id'], 200, 'inference', 'success', result)
else:
                return api_result(msg_dict['event_id'], 501, 'inference', 'Invalid parameters, please check. Error: {}'.format(validate_msg['msg']), None)
        except Exception as e:
            logger.error(e)
            return api_result(msg_dict.get('event_id'), 500, "inference", str(e), None)
else:
feed_msg = "error, request.method != POST"
return api_result("101010", 400, "inference", feed_msg, None)
def start_app(MULTI_PROCESS=False, USE_CORES=1):
"""
Start the service.
Returns:
"""
    # Free cached GPU memory first
torch.cuda.empty_cache()
try:
logger.info("\n===============================================================================")
logger.info("deeplearn server starting...")
        # Long-running server
        if not MULTI_PROCESS:
            print('Single process + coroutines')
            server = pywsgi.WSGIServer(("0.0.0.0", 8080), app)
            logger.info("deeplearn server start success.")
            server.serve_forever()  # blocks; nothing after this line would run
            return
else:
mulserver = pywsgi.WSGIServer(('0.0.0.0', 8080), app)
mulserver.start()
def server_forever():
mulserver.start_accepting()
mulserver._stop_event.wait()
            use_cores = min(USE_CORES, cpu_count())
for i in range(use_cores):
p = Process(target=server_forever)
p.start()
            print('Multi-process + coroutines, worker processes: {}+1'.format(use_cores))
return
except Exception as err:
logger.error("exception in server: {}".format(err))
logger.error("a same service port has been started. please shut down before operation.")
try:
logger.error("{}".format(kill_process_by_port(8080)))
except Exception as err:
logger.error("exception in server: {}".format(err))
def stop_app():
"""
Stop the service.
Returns:
"""
logger.info("\n===============================================================================")
logger.warning("deeplearn server stopping...")
try:
logger.info("stop info: {}".format(kill_process_by_port(8080)))
    except Exception as err:
        logger.error("stop err: {}".format(err))
        logger.error("{}".format(kill_process_by_name("python.exe")))
if __name__ == "__main__":
# app.run(port=8080, host="0.0.0.0", )
MULTI_PROCESS = True
    # Defaults to 2 worker processes (+1 master)
    USE_CORES = int(os.getenv('USE_CORES', 2))
start_app(MULTI_PROCESS=MULTI_PROCESS, USE_CORES=USE_CORES)
# stop_app()
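For reference, a minimal client call whose body satisfies schema_train (all paths, class names, and values below are placeholders):
# sample_train_request.py -- illustrative client call
import requests

payload = {
    "event_id": 1001,
    "event_type": "train",
    "payload": {
        "data_config": {
            "train": "/data/train",
            "val": "/data/val",
            "nc": 2,
            "names": ["class_a", "class_b"],
            "result_path": "/data/results"
        },
        "basic_hyp": {
            "epochs": 100,
            "batch-size": 16,
            "workers": 4,
            "img-size": [640, 640],
            "device": "0"
        }
    }
}
resp = requests.post("http://127.0.0.1:8080/train", json=payload)
print(resp.json())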
2. Common process-management helpers
# job_manager.py
# -*- coding:utf-8 -*-
import os
import psutil
def get_all_process():
    """Map every pid to its process name."""
    pid_dict = {}
    for pid in psutil.pids():
        try:
            pid_dict[pid] = psutil.Process(pid).name()
        except Exception:
            # The process may have exited between pids() and Process(); skip it
            continue
    return pid_dict
def find_pid_by_name(name: str):
"""
根据进程名获取进程pid
Args:
name: process name
Returns: process pid
"""
pros = psutil.process_iter()
print("[" + name + "]'s pid is:")
pids = []
for pro in pros:
        if pro.name() == name:
print(pro.pid)
pids.append(pro.pid)
return pids
def find_port_by_pid(pid: int):
"""根据pid寻找该进程对应的端口"""
alist = []
# 获取当前的网络连接信息
net_con = psutil.net_connections()
for con_info in net_con:
if con_info.pid == pid:
alist.append({pid: con_info.laddr.port})
return alist
def find_pid_by_port(port: int):
"""根据端口寻找该进程对应的pid"""
pid_list = []
# 获取当前的网络连接信息
net_con = psutil.net_connections()
for con_info in net_con:
if con_info.laddr.port == port:
pid_list.append(con_info.pid)
return pid_list
def kill_process_by_pid(pid):
    # On Windows use: cmd = 'taskkill /pid ' + pid + ' /f'
    cmd = 'kill -9 ' + pid  # Linux only; see the portable psutil variant below
try:
os.system(cmd)
except Exception as e:
print(e)
def kill_process_by_name(set_name):
all_pid = get_all_process()
for pid, name in all_pid.items():
if name == set_name:
kill_process_by_pid(str(pid))
msg_str = "kill process in name: {}".format(set_name)
return msg_str
def kill_process_by_port(port):
pids = find_pid_by_port(port)
for pid in pids:
kill_process_by_pid(str(pid))
msg_str = "kill process in port: {}".format(port)
return msg_str
def clean_cmd():
kill_process_by_name("cmd.exe")
kill_process_by_name("bash.exe")
if __name__ == "__main__":
# kill_process_by_port(8010)
# kill_process_by_name("python.exe")
# kill_process_by_name("cmd.exe")
# kill_process_by_name("bash.exe")
# kill_process_by_name("myProcess")
    # print(find_pid_by_port(8080))
print(find_pid_by_name('myProcess'))
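Note that kill -9 via os.system only works on Linux. Since psutil is already a dependency, a portable variant (a sketch, not part of the project code) works on both Windows and Linux:
import psutil

def kill_process_by_pid_portable(pid: int, timeout: float = 3.0):
    """Cross-platform kill: ask politely first, force after a timeout."""
    try:
        p = psutil.Process(pid)
        p.terminate()            # SIGTERM on POSIX, TerminateProcess on Windows
        p.wait(timeout=timeout)  # give the process a moment to exit
    except psutil.TimeoutExpired:
        p.kill()                 # escalate to SIGKILL
    except psutil.NoSuchProcess:
        pass                     # already gone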
3. Logging module
# logger.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import logging
from logging.handlers import RotatingFileHandler
import functools
logger_initialized = {}
@functools.lru_cache()
def get_logger(name='root', log_file=None, log_level=logging.INFO):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified a FileHandler will also be added.
Args:
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level. Note that only the process of
rank 0 is affected, and other processes will set the level to
"Error" thus be silent most of the time.
Returns:
logging.Logger: The expected logger.
"""
logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
formatter = logging.Formatter(
'[%(asctime)s.%(msecs)03d] %(name)s %(levelname)s: %(message)s', datefmt="%Y/%m/%d %H:%M:%S")
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
if log_file is not None:
log_file_folder = os.path.split(log_file)[0]
os.makedirs(log_file_folder, exist_ok=True)
# file_handler = logging.FileHandler(log_file, 'a')
file_handler = RotatingFileHandler(filename=log_file, maxBytes=10 * 1024 * 1024, backupCount=15, encoding='utf-8')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.setLevel(log_level)
logger_initialized[name] = True
return logger
if __name__ == "__main__":
    # Logging
log_file = './logs/yolov5.log'
logger = get_logger(name='yolov5', log_file=log_file)
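Because the logger is cached (via lru_cache plus the logger_initialized dict), any other module can fetch the already-configured instance by name without handlers being added twice:
# in any other module of the project
from utils.logger import get_logger

logger = get_logger(name='yolov5')  # same instance; no duplicate handlers
logger.info("goes to stdout and ./logs/yolov5.log")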
4. Building the image with a Dockerfile
docker build -t wood_detect:test .
# Dockerfile
# Start FROM Nvidia PyTorch image https://ngc.nvidia.com/catalog/containers/nvidia:pytorch
#FROM nvcr.io/nvidia/pytorch:21.05-py3
#FROM pytorch/pytorch:1.7.0-cuda11.0-cudnn8-runtime
FROM deploy.hello.com/2020-public/yolov5_base:1.0.1
# Install linux packages
#RUN apt update && apt install -y zip htop screen libgl1-mesa-glx
## Create working directory
#RUN mkdir -p /usr/src/app
WORKDIR /usr/src/app
# Copy contents
COPY . /usr/src/app
EXPOSE 8080
# Install python dependencies
#COPY requirements.txt .
# RUN python -m pip install --upgrade pip
#RUN pip uninstall -y nvidia-tensorboard nvidia-tensorboard-plugin-dlprof
#RUN pip install --no-cache -r requirements.txt coremltools onnx gsutil -i https://pypi.douban.com/simple/
RUN pip install --no-cache -r requirements.txt -i https://pypi.douban.com/simple/
# RUN pip install --no-cache -U torch torchvision
## Set environment variables
#ENV HOME=/usr/src/app
#
ENTRYPOINT ["python","main_app.py"]
# --------------------------------------------------- Extras Below ---------------------------------------------------
# Build and Push
# t=ultralytics/yolov5:latest && sudo docker build -t $t . && sudo docker push $t
# for v in {300..303}; do t=ultralytics/coco:v$v && sudo docker build -t $t . && sudo docker push $t; done
# Pull and Run
# t=ultralytics/yolov5:latest && sudo docker pull $t && sudo docker run -it --ipc=host --gpus all $t
# Pull and Run with local directory access
# t=ultralytics/yolov5:latest && sudo docker pull $t && sudo docker run -it --ipc=host --gpus all -v "$(pwd)"/coco:/usr/src/coco $t
# Kill all
# sudo docker kill $(sudo docker ps -q)
# Kill all image-based
# sudo docker kill $(sudo docker ps -qa --filter ancestor=ultralytics/yolov5:latest)
# Bash into running container
# sudo docker exec -it 5a9b5863d93d bash
# Bash into stopped container
# id=$(sudo docker ps -qa) && sudo docker start $id && sudo docker exec -it $id bash
# Send weights to GCP
# python -c "from utils.general import *; strip_optimizer('runs/train/exp0_*/weights/best.pt', 'tmp.pt')" && gsutil cp tmp.pt gs://*.pt
# Clean up
# docker system prune -a --volumes
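After the build command at the top of this section, the image can be run directly (assuming the NVIDIA container toolkit is installed, which --gpus requires):
docker run -d --gpus all --name wood_detect -p 8080:8080 wood_detect:test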
5. Building and deploying with docker-compose.yml
docker-compose up
# Docker-compose.yml
# GPU configuration; see https://docs.docker.com/compose/gpu-support/
version: "3.8"
services:
yolov5:
build:
context: .
image: deploy.deepexi.com/2048-public/yolov5_server:alpha_v1.0
restart: always
container_name: yolov5_server
ports:
- 8080:8080
deploy:
resources:
reservations:
devices:
- driver: nvidia
device_ids: [ '0',]
capabilities: [ gpu ]
6. Rewriting the YAML hyperparameter file
# --hyp points at the data/hyp.scratch.yaml file
with open(opt.hyp) as f:
    hyp = yaml.safe_load(f)
# yolo_hype: user-supplied hyperparameter overrides (e.g., from the request payload)
if 'lr0' in yolo_hype and isinstance(yolo_hype['lr0'], float) and yolo_hype['lr0'] >= 0.0:
    hyp['lr0'] = yolo_hype['lr0']
with open(opt.hyp, mode='w') as f:
    yaml.safe_dump(hyp, f)
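The same pattern extends to several hyperparameters at once. A sketch, assuming yolo_hype is the user-supplied override dict and using an illustrative whitelist of keys from hyp.scratch.yaml:
import yaml

# Overridable keys with (expected type, minimum value) -- illustrative constraints
ALLOWED = {'lr0': (float, 0.0), 'lrf': (float, 0.0), 'momentum': (float, 0.0)}

with open(opt.hyp) as f:
    hyp = yaml.safe_load(f)
for key, (typ, minimum) in ALLOWED.items():
    if key in yolo_hype and isinstance(yolo_hype[key], typ) and yolo_hype[key] >= minimum:
        hyp[key] = yolo_hype[key]
with open(opt.hyp, mode='w') as f:
    yaml.safe_dump(hyp, f)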