使用tensorflow model maker训练目标检测模型

一、环境配置

1.1 使用conda创建一个新的隔离环境

因为我用的是conda环境,所以又新建了一个专门tensorflow model maker的环境

# 创建环境
conda create -n tf_model_maker python=3.9
# 激活环境
conda activate tf_model_maker
# 退出当前环境
conda deactivate
# 删除环境使用
conda remove -n tf_model_maker --all

1.2 配置tensorflow model maker环境

apt -y install libportaudio2
pip install -q --use-deprecated=legacy-resolver tflite-model-maker
pip install -q pycocotools
pip install -q opencv-python-headless==4.1.2.30
pip uninstall -y tensorflow && pip install -q tensorflow==2.8.0

此处没有使用nightly版本,不知道是有什么bug,使用nightly版本有些库引用出问题了,所以换回非nightly版本

1.3 导包

import numpy as np
import os

from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector

import tensorflow as tf
assert tf.__version__.startswith('2')

tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)

执行输出:

/root/anaconda3/envs/env_tflite_model_maker/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

二、数据集整理

我使用的数据格式是coco格式的,已经处理成csv文件了,csv文件格式是:

filename,width,height,class,xmin,ymin,xmax,ymax
00232f5be5eb8a0f2c34a4a63f73d678.jpeg,683,1024,ball,224,756,511,1024
.....

目标csv数据格式:https://cloud.google.com/vision/automl/object-detection/docs/csv-format

set,path,label,xmin,ymin,,,xmax,ymax,,

  • TRAIN或者VAL或者TEST:训练数据、验证数据、测试数据标记

  • 图片文件全路径:此处必须要用全路径

  • label:标记名称

  • 图片中对象的边界框

    • 使用 2 个包含一组 x、y 坐标的顶点(如果这些点是矩形的对角点)(xmin, ymin,,,xmax,ymax,,)
    • 或使用全部 4 个顶点 (xmin,ymin,xmax,ymin,xmax,ymax,xmin,ymax)

    这些坐标必须是 0 到 1 范围内的浮点数,其中 0 表示最小 x 或 y 值,1 表示最大 x 或 y 值。

    例如,(0,0) 表示左上角,(1,1) 表示右下角;整个图片的边界框表示为 (0,0,,,1,1,,) 或 (0,0,1,0,1,1,0,1)

TRAIN,/root/xxx/images/00232f5be5eb8a0f2c34a4a63f73d678.jpg,ball,0.3279648609077599,0.73828125,,,0.7481698389458272,1.0,,
VAL,/root/xxx/images/00232f5be5eb8a0f2c34a4a63f73d678.jpg,ball,0.3279648609077599,0.73828125,,,0.7481698389458272,1.0,,
TEST,/root/xxx/images/00232f5be5eb8a0f2c34a4a63f73d678.jpg,ball,0.3279648609077599,0.73828125,,,0.7481698389458272,1.0,,

数据处理代码:

import codecs
import csv
import cv2
import os

image_path = '/root/xxx/images/'

def makeData(old_file,new_file,key):

    file = open(new_file,'w')
    with file:
        w = csv.writer(file)

        with codecs.open(old_file, encoding='utf-8-sig') as f:
            for row in csv.DictReader(f, skipinitialspace=True):
                width=float(row['width'])
                height=float(row['height'])
                label=row['class']
                xmin=float(row['xmin'])/width
                ymin=float(row['ymin'])/height
                xmax=float(row['xmax'])/width
                ymax=float(row['ymax'])/height
                filename=row['filename']
                print(filename)
                img_path = os.path.join(image_path, filename)
                
                if os.path.exists(img_path) is True:
                    name = filename.replace(".jpeg","").replace(".jpg","")
                    save_path = os.path.join(image_path, name+".jpg")
                    img = cv2.imread(img_path)
                    cv2.imwrite(save_path,img)
                    new_row=[key,save_path,label,xmin,ymin,'','',xmax,ymax,'','',]
                    print(new_row)
                    w.writerow(new_row)

我拿到的图库中有些图片是直接修改的后缀,真实格式和后缀不同,也重新处理了一下,还有些图片不存在了,也过滤了一下

makeData('/root/xxx/train.csv',
         '/root/xxx/new_train.csv',
        'TRAIN')

makeData('/root/xxx/test.csv',
         '/root/xxx/new_test.csv',
        'TEST')

makeData('/root/xxx/test.csv',
         '/root/xxx/new_vaild.csv',
        'VAL')

然后我把new_train.csv、new_test.csv、new_vaild.csv中取了部分数据,手动合并到一个名为data.csv的文件里了

train_data,validation_data,test_data = object_detector.DataLoader.from_csv('/root/xxx/data.csv')

三、准备预训练模型

由于物体检测模型只支持EfficientDet系列的模型,我试过EfficientDet-Lite2发现在手机端的速度不是很理想,高端机差不多需要100ms左右识别出来,最终选择了速度更快的EfficientDet-Lite0

Model architecture Size(MB)* Latency(ms)** Average Precision***
EfficientDet-Lite0 4.4 37 25.69%
EfficientDet-Lite1 5.8 49 30.55%
EfficientDet-Lite2 7.2 69 33.97%
EfficientDet-Lite3 11.4 116 37.70%
EfficientDet-Lite4 19.9 260 41.96%

** Size of the integer quantized models.
** Latency measured on Pixel 4 using 4 threads on CPU.
*** Average Precision is the mAP (mean Average Precision) on the COCO 2017 validation dataset.*

3.1、选择预训练模型

spec = model_spec.get('efficientdet_lite0')

此处在国内的服务器上是会提示超时报错终止,原因就是被墙了,所以要根据提示修改源码成镜像文件路径

3.2、修改源码

# 预训练模型配置文件
vim ~/anaconda3/envs/env_tflite_model_maker/lib/python3.9/site-packages/tensorflow_examples/lite/model_maker/core/task/model_spec/object_detector_spec.py

# 找到efficientdet_lite0_spec配置文件
efficientdet_lite0_spec = functools.partial(
    EfficientDetModelSpec,
    model_name='efficientdet-lite0',
    uri='https://tfhub.dev/tensorflow/efficientdet/lite0/feature-vector/1',
)
# 把uri换一下
efficientdet_lite0_spec = functools.partial(
    EfficientDetModelSpec,
    model_name='efficientdet-lite0',
    uri='https://storage.googleapis.com/tfhub-modules/tensorflow/efficientdet/lite0/feature-vector/1.tar.gz',
)    

关键是替换uri,再重新执行spec = model_spec.get('efficientdet_lite0')

四、训练模型

model = object_detector.create(train_data, model_spec=spec, batch_size=8, train_whole_model=True, validation_data=validation_data)
Epoch 1/50
540/540 [==============================] - 253s 399ms/step - det_loss: 0.6041 - cls_loss: 0.3679 - box_loss: 0.0047 - reg_l2_loss: 0.0637 - loss: 0.6678 - learning_rate: 0.0090 - gradient_norm: 4.1991 - val_det_loss: 1.2947 - val_cls_loss: 0.8470 - val_box_loss: 0.0090 - val_reg_l2_loss: 0.0645 - val_loss: 1.3592
Epoch 2/50
540/540 [==============================] - 214s 397ms/step - det_loss: 0.3937 - cls_loss: 0.2513 - box_loss: 0.0028 - reg_l2_loss: 0.0651 - loss: 0.4588 - learning_rate: 0.0100 - gradient_norm: 3.2312 - val_det_loss: 0.3262 - val_cls_loss: 0.2136 - val_box_loss: 0.0023 - val_reg_l2_loss: 0.0656 - val_loss: 0.3918
Epoch 3/50
540/540 [==============================] - 213s 394ms/step - det_loss: 0.3450 - cls_loss: 0.2250 - box_loss: 0.0024 - reg_l2_loss: 0.0660 - loss: 0.4110 - learning_rate: 0.0099 - gradient_norm: 2.8205 - val_det_loss: 0.2999 - val_cls_loss: 0.2096 - val_box_loss: 0.0018 - val_reg_l2_loss: 0.0664 - val_loss: 0.3663
。。。。。

评估模型

model.evaluate(test_data)

输出:

{'AP': 0.82879966,
 'AP50': 0.9893871,
 'AP75': 0.9637165,
 'APs': 0.50417614,
 'APm': 0.83946806,
 'APl': 0.8315978,
 'ARmax1': 0.7818135,
 'ARmax10': 0.8720247,
 'ARmax100': 0.87727976,
 'ARs': 0.7034483,
 'ARm': 0.89498526,
 'ARl': 0.87662005,
 'AP_/ball': 0.82879966}

五、导出tflite模型

model.export(export_dir='/root/xxx/tf')

会在/root/xxx/tf文件夹下生成model.tflite文件

评估模型:

model.evaluate_tflite('model.tflite', test_data)

输出

{'AP': 0.817586,
 'AP50': 0.98929125,
 'AP75': 0.95808136,
 'APs': 0.4901086,
 'APm': 0.8326331,
 'APl': 0.81800973,
 'ARmax1': 0.77594024,
 'ARmax10': 0.8460072,
 'ARmax100': 0.84688306,
 'ARs': 0.63793105,
 'ARm': 0.86342186,
 'ARl': 0.84720457,
 'AP_/ball': 0.817586}

可以看出导出tflite之后模型的识别度从0.82879966下降到了0.817586,也还算能接受

tflite模型测试:

# Imports
from tflite_support.task import vision
from tflite_support.task import core
from tflite_support.task import processor

# Initialization
base_options = core.BaseOptions(file_name='/root/xxx/tf/model.tflite')
detection_options = processor.DetectionOptions(max_results=2)
options = vision.ObjectDetectorOptions(base_options=base_options, detection_options=detection_options)
detector = vision.ObjectDetector.create_from_options(options)

# Alternatively, you can create an object detector in the following manner:
# detector = vision.ObjectDetector.create_from_file(model_path)

# Run inference
image = vision.TensorImage.create_from_file('/root/xxx/images/00232f5be5eb8a0f2c34a4a63f73d678.jpeg')
detection_result = detector.detect(image)

image = vision.TensorImage.create_from_file('/root/xxx/11.png')
detection_result = detector.detect(image)
print(detection_result)
资料

https://tensorflow.google.cn/lite/models/modify/model_maker/object_detection

你可能感兴趣的:(使用tensorflow model maker训练目标检测模型)