基本配置 | 版本号 |
---|---|
CPU | Intel® Core™ i5-8400 CPU @ 2.80GHz × 6 |
GPU | GeForce RTX 2070 SUPER/PCIe/SSE2 |
OS | Ubuntu18.04 |
openjdk | 1.8.0_242 |
python | 3.6.9 |
bazel | 0.21.0 |
gcc | 4.8.5 |
g++ | 4.8.5 |
hint:
https://github.com/tensorflow/models/tree/v1.12.0
在~/.bashrc中加入配置
export PYTHONPATH=$PYTHONPATH:pwd
:pwd
/slim
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
红框内的模型用于需要对模型进行量化处理的
https://github.com/tensorflow/tensorflow/tree/r1.13
我的另一篇博客写的很详细
https://blog.csdn.net/weixin_43056275/article/details/105124979
我的数据集文件结构和生成tfrecord的代码可供下载: https://download.csdn.net/download/weixin_43056275/12314008
我做的是二分类, 所以我的数据集只有两个类别, bad 和 good.
以下是我制作tfrecord的文件目录:
data_prepare/
pic/
train/
bad/
good/
validation
bad/
good/
src/
tfrecord.py/
data_convert.py
python data_convert.py -t pic/ \
--train-shards 2 \
--validation-shards 2 \
--num-threads 2 \
--dataset-name mydata
参数含义解释:
运行上述命令后,就可以在pic文件夹中找到5个新生成的文件,分别是训练数据 mydata_train_00000-of-00002.tfrecord、mydata_train_00001-of-00002.tfrecord,以及验证数据 mydata_validation_00000-of-00002.tfrecord、mydata_validation_00001-of-00002.tfrecord。另外,还有一个文本文件label.txt,它表示图片的内部标签(数字)到真实类别(字符串)之间的映射顺序。如在tfrecord中的标签为0,那么就对应label.txt 第一行的类别,在tfrecord的标签为1,就对应label.txt中第二行的类别,依此类推。
找到models/research/目录中的slim文件夹, 这就是要用到的TensorFlow Slim的源代码。这里简单介绍TensorFlow Slim的代码结构,见下表。
文件夹或文件名 | 用途 |
---|---|
datasets/ | 定义一些训练时使用的数据集。如果需要训练自己的数据,必须同样在datasets文件夹中进行定义,会在下面对此进行介绍 |
nets/ | 定义了一些常用的网络结构,如AlexNet、VGGl6、VGG19、Inception 系列、ResNet、MobileNet等 |
preprocessing/ | 在模型读入图片前,常常需要对图像做预处理和数据增强。这个文件夹针对不同的网络,分别定义了它们的预处理方法 |
scripts | 包含了一些训练的示例脚本 |
train_ image_classifier.py | 训练模型的入口代码 |
eval_image_classifier.py | 验证模型性能的入口代码 |
download_and _convert data.py | 下载并转换数据集格式的入口代码 |
首先,在datasets/目录下新建一个文件mydata.py,并将flowers.py文件中的内容复制到satellite.py中。接下来,需要修改以下几处内容。
因为我下载的是 mobilenet_v1_0.75_224_quant, 所以 slim/nets/mobilenet_v1.py中有两个地方要修改.
python train_image_classifier.py \
--train_dir=./train_dir \
--dataset_dir=./data_prepare/pic \
--dataset_name=mydata \
--dataset_split_name=train \
--model_name=mobilenet_v1 \
--checkpoint_path=./mobilenet_v1_0.75_224_class/mobilenet_v1_0.75_224_quant.ckpt \
--checkpoint_exclude_scopes=MobilenetV1/Logits,MobilenetV1/AuxLogits \
--max_number_of_steps=100000 \
--train_image_size=224 \
--trainable_scopes=MobilenetV1/Logits,MobilenetV1/AuxLogits \
--quantize_delay=0
python eval_image_classifier.py \
--alsologtostderr \
--checkpoint_path=./train_dir/model.ckpt-100000 \
--dataset_dir=./data_prepare/pic \
--dataset_name=mydata \
--dataset_split_name=validation \
--model_name=mobilenet_v1 \
--eval_image_size=224
python export_inference_graph.py \
--alsologtostderr \
--model_name=mobilenet_v1 \
--dataset_dir=./data_prepare/pic \ #数据集路径
--dataset_name=mydata \ #数据集名字
--image_size=224 \
--output_file=./test.pb \
--quantize=True
如果不设置好数据集路径和名字会导致 Assign requires shapes of both tensors to match. lhs shape= [1001] rhs shape= [2].
如何解决请看, https://blog.csdn.net/weixin_43056275/article/details/105405751
有两种方式
https://lutzroeder.github.io/netron/
可以在线上传模型, 得到输入输出
输入:
输出:
在./tensorflow-r1.13运行
bazel build tensorflow/tools/graph_transforms:summarize_graph
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
--in_graph=/home/ying/usb/models/models-1.12.0/research/slim/test.pb
在./tensorflow-r1.13运行
bazel build tensorflow/python/tools:freeze_graph
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=/home/ying/usb/models/models-1.12.0/research/slim/test.pb \
--input_checkpoint=/home/ying/usb/models/models-1.12.0/research/slim/train_dir/model.ckpt-100000 \
--input_binary=true --output_graph=/home/ying/usb/models/models-1.12.0/research/slim/frozen1.pb \
--output_node_names=MobilenetV1/Predictions/Reshape_1
在./tensorflow-r1.13运行
bazel build tensorflow/lite/toco:toco
bazel-bin/tensorflow/lite/toco/toco --input_file=/home/ying/usb/models/models-1.12.0/research/slim/frozen1.pb \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--output_file=/home/ying/usb/models/models-1.12.0/research/slim/frozen1.tflite \
--inference_type=FLOAT \
--input_type=FLOAT \
--input_arrays=input \
--output_arrays=MobilenetV1/Predictions/Reshape_1 \
--input_shapes=1,224,224,3
在./tensorflow-r1.13运行
bazel build tensorflow/lite/examples/label_image:label_image
bazel-bin/tensorflow/lite/examples/label_image/label_image
--tflite_model=/home/ying/usb/models/models-1.12.0/research/slim/frozen1.tflite
--input_mean=128
--input_std=128
--labels="/home/ying/usb/models/models-1.12.0/research/slim/data_prepare/pic/label.txt"
--image="/home/ying/usb/deep-learning-for-image-processing-master/deep-learning-for-image-processing-master/pytorch_learning/Test5_resnet/test1/bad15.bmp"
结果:
根据tensorflow版本不同, 参数可能也会有所不同, 请根据情况进行调整.
打开./tensorflow-r1.13/tensorflow/lite/examples/android
将tflite和txt放到./tensorflow-r1.13/tensorflow/lite/examples/android/app/src/main/assets文件下
修改之后, 在运行app时, 屏幕上只会显示最相似的一类. 不修改之前会显示多项分类, 有不同的分数, 根据需求进行修改.
以下为配置文件的更改
apply plugin: 'com.android.application'
android {
compileSdkVersion 28
buildToolsVersion '28.0.3'
defaultConfig {
applicationId "org.tensorflow.lite.demo"
minSdkVersion 21
targetSdkVersion 28
versionCode 1
versionName "1.0"
// Remove this block.
// jackOptions {
// enabled true
// }
}
lintOptions {
abortOnError false
}
buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
}
}
aaptOptions {
noCompress "tflite"
}
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
}
repositories {
maven {
url 'https://google.bintray.com/tensorflow'
}
}
// import DownloadModels task
project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets'
project.ext.TMP_DIR = project.buildDir.toString() + '/downloads'
// Download default models; if you wish to use your own models then
// place them in the "assets" directory and comment out this line.
apply from: "download-models.gradle"
dependencies {
implementation fileTree(dir: 'libs', include: ['*.jar'])
implementation 'org.tensorflow:tensorflow-lite:0.1.2-nightly'
implementation 'org.tensorflow:tensorflow-lite:2.0.0'
}
参考:
https://blog.csdn.net/qq_33200967/article/details/82773677#commentBox
https://www.cnblogs.com/Terrypython/p/10858803.html
https://zhuanlan.zhihu.com/p/44437031
https://www.imooc.com/article/details/id/28871