近期使用PointRend模型来做项目上目标分割部分,整个项目也都完成了,现在需要进行落地,让我有点烦恼,因为一期的项目都是基于tensorflow框架来加载所有的模型,现在使用的是pytorch框架,而且现在的项目中也使用到了tensorflow模型框架,虽然后期可以改成统一使用pytorch框架下训练的模型,但是现在需要呈现出效果,所以还是想可以快速实现。原先项目都是基于django做web服务器落地提供API接口给前端访问,所以我这次也是准备这么做,但是我面临两个问题:
①:django要如何启动pytorch 已经训练好的模型,并且要在启动django时,也要把模型也启动起来,这样做的好处是后期不需要每次都要启动模型,节省访问时间;
②:django能否可以同时加载pytorch和tensorflow框架的模型,因为我做算法研究一直使用tensorflow框架,虽然现在pytorch也挺火热的。
本人在这些方面还是一个菜鸟,所以有不对的地方,希望大神们勿喷。
①:我所有的模型都是放在服务器上,基于centos7系统开发的,所以需要配置PointRend的环境,这部分大家可以去看看官网或者去网上搜一搜一些大神的博客,这个都是有的。配置好这个环境后,需要修剪一下,因为很多部分是不需要的,可以根据PointRend的内容删除不要的部分,这个不细说很简单,如果服务器内存大,也可以不用删除。这里提醒一下,项目和PointRend最好分开放,后期在项目中导入就可以了,挺方便的。
②:pytorch,去看官网吧,没什么说的,这里我提醒一下,官网都是最新的安装,我服务器安装的是cuda9.0所以上面的安装命令不适合我,如果有朋友和我是一样的,那就安装我这个版本,
pip install torch==1.4.0
pip install torchvision==0.5.0
gcc 和g++都是5.4的版本
③:安装django,就是用pip就OK了,并创建一个项目,我创建的是PointDetect,总的目录文件是:
PointDetect/
manage.py
PointDetect/
__init__.py
settings.py
urls.py
asgi.py
wsgi.py
image/
__init__.py
admin.py
apps.py
migrations/
__init__.py
models.py
tests.py
views.py
好了,不废话了,进入正题。
解决方法:
将与模型相关的文件都要拷贝到PointDetect中,最终的目录文件是:
PointDetect/
manage.py
PointDetect/
image/
configs/ #对应pointrend的configs 内容
model/ #存放训练好的模型
point_rend/ #对应pointrend的方法
src_image/ #存放一张空白图片,后面要用
detect_result.py #其他需要调用的方法
load_model.py #加载模型的方法
要加载训练好的PointRend模型,我们需要修改一下加载方式(其实弄懂了很简单,但是本人是菜鸟,没有大神指点,搞了两天才搞出来,哎······),对PointRend进行测试的时候,我们基本上把加载的方式都写好了,下面就是我测试的代码,我也是基于官网的demo来修改,这样我就可以得到我想要的box和mask了。
我的测试代码:
import os
import numpy as np
import cv2
os.getenv('/root/detectron2')#导入环境
from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.modeling import build_model
from point_rend import add_pointrend_config
import detectron2.data.transforms as T
import torch
from PIL import Image
def setup_cfg():
cfg = get_cfg()
add_pointrend_config(cfg)
cfg.merge_from_file("/root/PointDetect/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml")
cfg.MODEL.RETINANET.SCORE_THRESH_TEST = 0.5
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
cfg.MODEL.WEIGHTS = os.path.join('/root/PointDetect/model/', "model_0.005.pth")
cfg.freeze()
return cfg
class LoadModel:
def __init__(self, cfg):
self.cfg = cfg.clone()
self.model = build_model(self.cfg)
self.model.eval()
checkpointer = DetectionCheckpointer(self.model)
checkpointer.load(cfg.MODEL.WEIGHTS)
self.aug = T.ResizeShortestEdge(
[cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
)
self.input_format = cfg.INPUT.FORMAT
assert self.input_format in ["RGB", "BGR"], self.input_format
def __call__(self, original_image):
with torch.no_grad():
if self.input_format == "RGB":
original_image = original_image[:, :, ::-1]
height, width = original_image.shape[:2]
image = self.aug.get_transform(original_image).apply_image(original_image)
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
inputs = {"image": image, "height": height, "width": width}
predictions = self.model([inputs])[0]
return predictions
print('模型加载完成!')
detect_model=LoadModel(cfg=setup_cfg())
def get_mask_and_box(image):
predictor = detect_model(image)
if "instances" in predictor:
instances = predictor["instances"].to(torch.device("cpu"))
box = instances.pred_boxes if instances.has("pred_boxes") else None
if instances.has("pred_masks"):
mask = np.asarray(instances.pred_masks)
return mask,box
首先来解决我的第一个问题:使用django加载自己训练的模型。
在point_rend目录下新建了一个point_model.py,里面的代码如下:
# -*- coding: utf-8 -*-
"""
Created on Thu Aug 13 15:34:22 2020
@author: ctzn
"""
import os
os.getenv('/root/detectron2')
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.modeling import build_model
import detectron2.data.transforms as T
import torch
class LoadModel:
def __init__(self,cfg):
self.cfg = cfg.clone()
self.model = build_model(self.cfg)
self.model.eval()
checkpointer = DetectionCheckpointer(self.model)
checkpointer.load(cfg.MODEL.WEIGHTS)
self.aug = T.ResizeShortestEdge(
[cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
)
self.input_format = cfg.INPUT.FORMAT
assert self.input_format in ["RGB", "BGR"], self.input_format
def __call__(self, original_image):
with torch.no_grad():
if self.input_format == "RGB":
original_image = original_image[:, :, ::-1]
height, width = original_image.shape[:2]
image = self.aug.get_transform(original_image).apply_image(original_image)
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
inputs = {"image": image, "height": height, "width": width}
predictions = self.model([inputs])[0]
return predictions
就是我测试load模型的代码,这个也就是随着django启动而加载的代码需要调用的方法,但是需要加一下东西,要不然没有,其实通过平时测试发现一个问题,连续测试多张图像的时候,模型也就初始一次,后期预测都是调用__call__方法,所以如果单纯使用这个代码,那每次运行都是要重新加载模型的,测试的时候使用jupyter都试过了,虽然每次加载也就1秒多一点时间,但这肯定不是我想要的。
从测试过程中发现一次测试多张图像,模型也就初始化一次,那就有方法了,就在django启动的时候,使用一个空的图像访问模型,然后将模型初始化起来不就OK了嘛。
创建load_model.py,这是随着django启动而加载的代码,后面的所有的模型都是放在这里面,我这里使用的是模型池,可以在django启动的时候,将所有的模型都加载起来,代码如下:
import os
os.getenv('/root/detectron2')
from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from point_rend import add_pointrend_config,point_model
import tensorflow as tf
import queue
import os
ROOT_DIR='/root/PointDetect/model/'
or_image=read_image('/root/PointDetect/src_image/1.png', format="BGR")#对应的就是一张空白的图像,跟测试图像大小相同
class QueueObject():
def __init__(self, queue, auto_get=False):
self._queue = queue
self.object = self._queue.get() if auto_get else None
def __enter__(self):
if self.object is None:
self.object = self._queue.get()
return self.object
def __exit__(self, Type, value, traceback):
if self.object is not None:
self._queue.put(self.object)
self.object = None
def __del__(self):
if self.object is not None:
self._queue.put(self.object)
self.object = None
def setup_cfg():
cfg = get_cfg()
add_pointrend_config(cfg)
cfg.merge_from_file("/root/PointDetect/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml")
cfg.MODEL.RETINANET.SCORE_THRESH_TEST = 0.5
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
cfg.MODEL.WEIGHTS = detection_MODEL_PATH
cfg.freeze()
return cfg
def get_detect_model():
print("====================loading detect model==================")
global detect_graph
detect_graph=tf.get_default_graph()
with detect_graph.as_default():
config=setup_cfg()
model_detect=point_model.LoadModel(config)
model_detect(or_image)
return model_detect
print("模型池启用,模型加载:")
# 实例化类
detec_model = queue.Queue()
for i in range(1):
detec_model.put(get_detect_model())
detec_object = QueueObject(detec_model)
with detec_object as obj:
# print(detec_object)
pass
print("模型加载完成。")
将模型放在了一个模型池中,这样后面加载也就方便很多,所有的模型都可以放在这里,而且还可以设置多个模型,我这里就设置一个模型。
其实这里我是借鉴了kears模型加载方式和pytorch训练模型使用tensorboard查看训练loss来修改的,因此我也使用了tensorflow的detect_graph=tf.get_default_graph(),这个是tensorflow框架使用的,因为我以前的tensorflow模型不使用它,就没法预加载起来,也就是不能实现我的第一个问题后面说的将模型加载起来方便后面使用,而我将其使用在pytorch下加载模型,也是可以的,毕竟本人比较菜,也没想到居然能成功,哈哈·····
在detect_result.py导入加载模型的方法,这样每次有图片进行预测的时候,只需要访问这里就可以得到结果了,我只需要mask和box,如果你需要其他的,那就自己去修改了,通过测试demo代码去修改你想要的,这样不会报错。代码如下:
from detect_model import detec_object
def get_mask_and_box(image):
with detec_object as dbj:#一定要加这句,要不然使用不了
predictor = dbj(image)
if "instances" in predictor:
instances = predictor["instances"].to(torch.device("cpu"))
box = instances.pred_boxes if instances.has("pred_boxes") else None
if instances.has("pred_masks"):
mask = np.asarray(instances.pred_masks)
return mask,box
配置好image文件和PointDetect文件下的urls,然后运行python manage.py runserver 0.0.0.0:8000,就可以启动了
这里其实我发现一个很奇葩的事情,我刚开始没有起来一直报错,报错:
raise ImproperlyConfigured(msg.format(name=self.urlconf_name))
django.core.exceptions.ImproperlyConfigured: The included URLconf 'PointDetec.urls' does not appear to have any patterns in it. If you see valid patterns in the file then the issue is probably caused by a circular import.
我还以为是urls写错了,还查了很久,后来单步调试才知道不是,是模型没法加载,但是我通过jupyter去测试,模型都加载起来了,而且我也得到结果了,我一开始也没有怀疑是我代码的问题,我就觉得很奇怪,后来我把加载模型的代码也就是load_model.py导入setting里,居然加载起来了,如下:
"""
Django settings for PointDetect project.
Generated by 'django-admin startproject' using Django 3.0.8.
For more information on this file, see
https://docs.djangoproject.com/en/3.0/topics/settings/
For the full list of settings and their values, see
https://docs.djangoproject.com/en/3.0/ref/settings/
"""
import os
# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/3.0/howto/deployment/checklist/
# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = '%-5)_$8qi50n(ahlm_j+dea#m50a5!^t-bkq*rni+m!)i^1lep'
# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = True
ALLOWED_HOSTS='*'
#ALLOWED_HOSTS=['192.***.***.***','127.0.0.1']
# Application definition
INSTALLED_APPS = [
'django.contrib.admin',
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.messages',
'django.contrib.staticfiles',
]
MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware',
# 'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
]
ROOT_URLCONF = 'PointDetect.urls'
TEMPLATES = [
{
'BACKEND': 'django.template.backends.django.DjangoTemplates',
'DIRS': [BASE_DIR+"/templates",],
'APP_DIRS': True,
'OPTIONS': {
'context_processors': [
'django.template.context_processors.debug',
'django.template.context_processors.request',
'django.contrib.auth.context_processors.auth',
'django.contrib.messages.context_processors.messages',
],
},
},
]
WSGI_APPLICATION = 'PointDetect.wsgi.application'
import load_model #将加载方法导入到setting
然后我把import load_model这行给注释了,我想重新看看问题是出在哪里,纳尼居然都能加载起来,看不到问题了,我把服务器重启了,也没有问题了,搞笑了····,哎,先不管了,反正都加载起来了,而且也是能用的,那就OK了。
至此到这里基本上就完成了使用django加载pytorch训练好的模型了,当然我这里是根据我自己的项目来做例子写的,这里只是给大家做借鉴作用,因为我在网上查了一遍,没有看到该方面的博客,我看到很多是使用flask来加载的,是挺方便的,但是不是我想要的,因为我原先的项目中有部分模型我现在还是需要的。
这个问题其实是我自己想多了,后来我把keras训练的模型放到load_model.py中,也是没有问题,全部加载起来了,很简单的加载方法,我这里就不贴代码了。
如果有哪里写错了或者有不对的地方,希望各位大神能指正,菜鸟在这里先感谢了。