1. Environment Setup
1.1 Create a new isolated environment with conda
Since I work with conda, I created a dedicated environment for TFLite Model Maker:
# Create the environment
conda create -n env_tflite_model_maker python=3.9
# Activate it
conda activate env_tflite_model_maker
# Leave the current environment
conda deactivate
# Delete the environment when no longer needed
conda remove -n env_tflite_model_maker --all
1.2 Install the TFLite Model Maker dependencies
apt -y install libportaudio2
pip install -q --use-deprecated=legacy-resolver tflite-model-maker
pip install -q pycocotools
pip install -q opencv-python-headless==4.1.2.30
pip uninstall -y tensorflow && pip install -q tensorflow==2.8.0
I did not use the nightly build here: with tf-nightly some library imports broke (I never tracked down the exact bug), so I fell back to the stable release pinned above.
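A quick sanity check, using only the standard library, to confirm the pinned versions actually landed in the environment:

# Verify the pinned package versions inside the activated environment
from importlib.metadata import version
import tensorflow as tf
print(version('tflite-model-maker'))
print(tf.__version__)  # expected: 2.8.0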
1.3 Imports
import numpy as np
import os
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector
import tensorflow as tf
assert tf.__version__.startswith('2')
tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)
Output:
/root/anaconda3/envs/env_tflite_model_maker/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
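This warning only concerns progress bars in Jupyter and does not affect training; following the advice in the message itself, it can usually be silenced by updating jupyter and ipywidgets:

pip install -q --upgrade jupyter ipywidgets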
2. Preparing the Dataset
My annotations started out in COCO format and had already been flattened into a CSV file with this layout:
filename,width,height,class,xmin,ymin,xmax,ymax
00232f5be5eb8a0f2c34a4a63f73d678.jpeg,683,1024,ball,224,756,511,1024
.....
The target CSV format is documented at https://cloud.google.com/vision/automl/object-detection/docs/csv-format:
set,path,label,xmin,ymin,,,xmax,ymax,,
- set: TRAIN, VAL, or TEST, marking the row as training, validation, or test data.
- path: the full path to the image file; a relative path will not work here.
- label: the label name.
- Bounding box of the object in the image, written either as 2 vertices carrying x,y coordinates when they are diagonal corners of the rectangle (xmin,ymin,,,xmax,ymax,,), or as all 4 vertices (xmin,ymin,xmax,ymin,xmax,ymax,xmin,ymax). The coordinates must be floats between 0 and 1, where 0 is the minimum and 1 the maximum x or y value. For example, (0,0) is the top-left corner and (1,1) the bottom-right; the bounding box of the entire image is written (0,0,,,1,1,,) or (0,0,1,0,1,1,0,1).

Example rows:
TRAIN,/root/xxx/images/00232f5be5eb8a0f2c34a4a63f73d678.jpg,ball,0.3279648609077599,0.73828125,,,0.7481698389458272,1.0,,
VAL,/root/xxx/images/00232f5be5eb8a0f2c34a4a63f73d678.jpg,ball,0.3279648609077599,0.73828125,,,0.7481698389458272,1.0,,
TEST,/root/xxx/images/00232f5be5eb8a0f2c34a4a63f73d678.jpg,ball,0.3279648609077599,0.73828125,,,0.7481698389458272,1.0,,
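The normalized values in these rows are just the pixel coordinates from the source CSV divided by the image dimensions; a quick check against the 683x1024 sample image above:

# Normalize the pixel-space box from the source CSV (683x1024 image)
width, height = 683, 1024
xmin, ymin, xmax, ymax = 224, 756, 511, 1024
print(xmin / width, ymin / height, xmax / width, ymax / height)
# -> 0.3279648609077599 0.73828125 0.7481698389458272 1.0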
Data-processing code:
import codecs
import csv
import os

import cv2

image_path = '/root/xxx/images/'

def makeData(old_file, new_file, key):
    # key is the split marker for the target CSV: 'TRAIN', 'VAL' or 'TEST'
    with open(new_file, 'w', newline='') as file:
        w = csv.writer(file)
        with codecs.open(old_file, encoding='utf-8-sig') as f:
            for row in csv.DictReader(f, skipinitialspace=True):
                width = float(row['width'])
                height = float(row['height'])
                label = row['class']
                # Normalize pixel coordinates to the required 0..1 range
                xmin = float(row['xmin']) / width
                ymin = float(row['ymin']) / height
                xmax = float(row['xmax']) / width
                ymax = float(row['ymax']) / height
                filename = row['filename']
                print(filename)
                img_path = os.path.join(image_path, filename)
                # Skip annotations whose image file no longer exists
                if os.path.exists(img_path):
                    # Re-encode as a real .jpg: some source files only had
                    # their extension renamed, so suffix and actual format
                    # did not match
                    name = filename.replace(".jpeg", "").replace(".jpg", "")
                    save_path = os.path.join(image_path, name + ".jpg")
                    img = cv2.imread(img_path)
                    cv2.imwrite(save_path, img)
                    new_row = [key, save_path, label, xmin, ymin, '', '',
                               xmax, ymax, '', '']
                    print(new_row)
                    w.writerow(new_row)
Some images in the library I received had merely had their extension renamed, so the real format did not match the suffix; the code above re-encodes those as genuine .jpg files, and annotations whose image files no longer exist are filtered out at the same time.
makeData('/root/xxx/train.csv',
         '/root/xxx/new_train.csv',
         'TRAIN')
makeData('/root/xxx/test.csv',
         '/root/xxx/new_test.csv',
         'TEST')
makeData('/root/xxx/test.csv',
         '/root/xxx/new_valid.csv',
         'VAL')
Then I took a subset of rows from new_train.csv, new_test.csv, and new_valid.csv and manually merged them into a single file named data.csv.
train_data,validation_data,test_data = object_detector.DataLoader.from_csv('/root/xxx/data.csv')
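A quick sanity check that the TRAIN/VAL/TEST markers were picked up as intended (assuming the DataLoader exposes the size property from Model Maker's base DataLoader class):

# Confirm each split got the expected number of samples
print('train:', train_data.size)
print('validation:', validation_data.size)
print('test:', test_data.size)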
3. Preparing the Pretrained Model
The object detection task in Model Maker only supports the EfficientDet-Lite family. I tried EfficientDet-Lite2, but on-device speed was not ideal: even high-end phones took roughly 100 ms per detection, so I settled on the faster EfficientDet-Lite0.
Model architecture | Size(MB)* | Latency(ms)** | Average Precision*** |
---|---|---|---|
EfficientDet-Lite0 | 4.4 | 37 | 25.69% |
EfficientDet-Lite1 | 5.8 | 49 | 30.55% |
EfficientDet-Lite2 | 7.2 | 69 | 33.97% |
EfficientDet-Lite3 | 11.4 | 116 | 37.70% |
EfficientDet-Lite4 | 19.9 | 260 | 41.96% |
* Size of the integer quantized models.
** Latency measured on Pixel 4 using 4 threads on CPU.
*** Average Precision is the mAP (mean Average Precision) on the COCO 2017 validation dataset.
3.1 Selecting the pretrained model
spec = model_spec.get('efficientdet_lite0')
On a server inside mainland China this call fails with a timeout, because tfhub.dev is blocked; the fix is to patch the download URI in the Model Maker source to the storage.googleapis.com mirror, as described below.
3.2 Patching the source
# Pretrained model spec file
vim ~/anaconda3/envs/env_tflite_model_maker/lib/python3.9/site-packages/tensorflow_examples/lite/model_maker/core/task/model_spec/object_detector_spec.py
# Locate the efficientdet_lite0_spec definition
efficientdet_lite0_spec = functools.partial(
    EfficientDetModelSpec,
    model_name='efficientdet-lite0',
    uri='https://tfhub.dev/tensorflow/efficientdet/lite0/feature-vector/1',
)
# Replace the uri with the mirror
efficientdet_lite0_spec = functools.partial(
    EfficientDetModelSpec,
    model_name='efficientdet-lite0',
    uri='https://storage.googleapis.com/tfhub-modules/tensorflow/efficientdet/lite0/feature-vector/1.tar.gz',
)
The only change needed is the uri. Rerun spec = model_spec.get('efficientdet_lite0') afterwards and the model downloads successfully.
4. Training the Model
model = object_detector.create(train_data, model_spec=spec, batch_size=8, train_whole_model=True, validation_data=validation_data)
Epoch 1/50
540/540 [==============================] - 253s 399ms/step - det_loss: 0.6041 - cls_loss: 0.3679 - box_loss: 0.0047 - reg_l2_loss: 0.0637 - loss: 0.6678 - learning_rate: 0.0090 - gradient_norm: 4.1991 - val_det_loss: 1.2947 - val_cls_loss: 0.8470 - val_box_loss: 0.0090 - val_reg_l2_loss: 0.0645 - val_loss: 1.3592
Epoch 2/50
540/540 [==============================] - 214s 397ms/step - det_loss: 0.3937 - cls_loss: 0.2513 - box_loss: 0.0028 - reg_l2_loss: 0.0651 - loss: 0.4588 - learning_rate: 0.0100 - gradient_norm: 3.2312 - val_det_loss: 0.3262 - val_cls_loss: 0.2136 - val_box_loss: 0.0023 - val_reg_l2_loss: 0.0656 - val_loss: 0.3918
Epoch 3/50
540/540 [==============================] - 213s 394ms/step - det_loss: 0.3450 - cls_loss: 0.2250 - box_loss: 0.0024 - reg_l2_loss: 0.0660 - loss: 0.4110 - learning_rate: 0.0099 - gradient_norm: 2.8205 - val_det_loss: 0.2999 - val_cls_loss: 0.2096 - val_box_loss: 0.0018 - val_reg_l2_loss: 0.0664 - val_loss: 0.3663
...
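The create call above relies on the spec's defaults for anything not passed explicitly; the 50 epochs in the log come from the EfficientDet spec. For a quicker experiment the epoch count can be overridden, since epochs is a documented parameter of object_detector.create:

# Same training run, but with an explicit, shorter epoch budget
model = object_detector.create(train_data,
                               model_spec=spec,
                               epochs=20,
                               batch_size=8,
                               train_whole_model=True,
                               validation_data=validation_data)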
Evaluate the model:
model.evaluate(test_data)
Output:
{'AP': 0.82879966,
'AP50': 0.9893871,
'AP75': 0.9637165,
'APs': 0.50417614,
'APm': 0.83946806,
'APl': 0.8315978,
'ARmax1': 0.7818135,
'ARmax10': 0.8720247,
'ARmax100': 0.87727976,
'ARs': 0.7034483,
'ARm': 0.89498526,
'ARl': 0.87662005,
'AP_/ball': 0.82879966}
5. Exporting the TFLite Model
model.export(export_dir='/root/xxx/tf')
This writes a model.tflite file into the /root/xxx/tf directory.
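By default the export produces a full-integer-quantized model. The QuantizationConfig and ExportFormat imports from step 1.3 come into play if you want something else; for instance, following the Model Maker documentation, a float16 variant plus the label file:

# Export a float16-quantized model instead of the default int8 one,
# along with the label file
config = QuantizationConfig.for_float16()
model.export(export_dir='/root/xxx/tf',
             tflite_filename='model_fp16.tflite',
             quantization_config=config,
             export_format=[ExportFormat.TFLITE, ExportFormat.LABEL])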
Evaluate the exported model:
model.evaluate_tflite('model.tflite', test_data)
Output:
{'AP': 0.817586,
'AP50': 0.98929125,
'AP75': 0.95808136,
'APs': 0.4901086,
'APm': 0.8326331,
'APl': 0.81800973,
'ARmax1': 0.77594024,
'ARmax10': 0.8460072,
'ARmax100': 0.84688306,
'ARs': 0.63793105,
'ARm': 0.86342186,
'ARl': 0.84720457,
'AP_/ball': 0.817586}
After export the detection AP drops from 0.82879966 to 0.817586, a small loss introduced by the default integer quantization, which is still acceptable.
Testing the tflite model:
# Imports
from tflite_support.task import vision
from tflite_support.task import core
from tflite_support.task import processor
# Initialization
base_options = core.BaseOptions(file_name='/root/xxx/tf/model.tflite')
detection_options = processor.DetectionOptions(max_results=2)
options = vision.ObjectDetectorOptions(base_options=base_options, detection_options=detection_options)
detector = vision.ObjectDetector.create_from_options(options)
# Alternatively, you can create an object detector in the following manner:
# detector = vision.ObjectDetector.create_from_file(model_path)
# Run inference on two test images; print both results
# (the original version silently discarded the first one)
image = vision.TensorImage.create_from_file('/root/xxx/images/00232f5be5eb8a0f2c34a4a63f73d678.jpeg')
detection_result = detector.detect(image)
print(detection_result)

image = vision.TensorImage.create_from_file('/root/xxx/11.png')
detection_result = detector.detect(image)
print(detection_result)
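For a more readable dump than the raw DetectionResult repr, the result can be unpacked via the bounding_box and categories fields of the tflite_support task API:

# Print each detection as "label: score @ box"
for detection in detection_result.detections:
    box = detection.bounding_box
    category = detection.categories[0]
    print(f'{category.category_name}: {category.score:.2f} at '
          f'x={box.origin_x}, y={box.origin_y}, w={box.width}, h={box.height}')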
References
https://tensorflow.google.cn/lite/models/modify/model_maker/object_detection