博主在根据官网配置图像分类迁移学习时,由于没有设置,程序执行如下语句时
model = image_classifier.create(train_data)
会因为模型下载超时而报错:urllib.error.URLError:
博主debug看了下,在/home/sxhlvye/anaconda3/envs/testTF/lib/python3.6/site-packages/tensorflow_examples/lite/model_maker/core/task/model_spec/_init_.py文件里(结合自己的路径)看到了预先设定好的模型配置
切换到/home/sxhlvye/anaconda3/envs/testTF/lib/python3.6/site-packages/tensorflow_examples/lite/model_maker/core/task/model_spec/image_spec.py文件,可以看到每个模型的下载路径
在不指定模型路径情况相下,系统默认使用的是efficientnet_lite0模型,对应路径是
https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2
博主想直接到tensorflow hub网站上去下载
TensorFlow Hubhttps://tensorflow.google.cn/hub/但点击'查看模型'没有反应
可看到图中示例可以直接通过tensorflow_hub.KerasLayer函数通过路径来加载模型,其实再深入debug,你会发现上面的mage_classifier.create()里面其实也调用了KerasLayer函数
为了网页浏览模型,可以访问如下网址:
TensorFlow Hubhttps://hub.tensorflow.google.cn如下页面中可以根据条件去筛选
模型下载不了解决方法
输入上面的网址https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2,发现没有反应,可以对链接进行如下的更改即可。
(1)https://tfhub.dev修改为https://storage.googleapis.com/tfhub-modules
(2) 2修改为2.tar.gz
修改后的访问网址应该为:https://storage.googleapis.com/tfhub-modules/tensorflow/efficientnet/lite0/feature-vector/2.tar.gz
代码修改为如下:
import os
import numpy as np
import tensorflow as tf
assert tf.__version__.startswith('2')
from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader
import matplotlib.pyplot as plt
from tensorflow_examples.lite.model_maker.core.task import model_spec as ms
import tensorflow_hub as hub
image_path = tf.keras.utils.get_file(
'flower_photos.tgz',
'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')
data = DataLoader.from_folder(image_path)
train_data, test_data = data.split(0.9)
print(train_data.size)
print(test_data.size)
#inception_v3_spec = image_classifier.ModelSpec(uri='/home/sxhlvye/efficientnet_lite0_feature-vector_2')
inception_v3_spec = image_classifier.ModelSpec(uri='https://storage.googleapis.com/tfhub-modules/tensorflow/efficientnet/lite0/feature-vector/2.tar.gz')
inception_v3_spec.input_image_shape = [240, 240]
model = image_classifier.create(train_data, model_spec=inception_v3_spec)
print("ok")
运行部分结果如下:
=================================================================
Total params: 3,419,429
Trainable params: 6,405
Non-trainable params: 3,413,024
_________________________________________________________________
None
/home/sxhlvye/anaconda3/envs/testTF/lib/python3.6/site-packages/keras/optimizer_v2/optimizer_v2.py:356: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
"The `lr` argument is deprecated, use `learning_rate` instead.")
Epoch 1/5
2022-04-17 20:08:02.362645: I tensorflow/stream_executor/cuda/cuda_dnn.cc:369] Loaded cuDNN version 8100
103/103 [==============================] - 9s 49ms/step - loss: 0.8722 - accuracy: 0.7630
Epoch 2/5
103/103 [==============================] - 5s 50ms/step - loss: 0.6609 - accuracy: 0.8941
Epoch 3/5
103/103 [==============================] - 5s 50ms/step - loss: 0.6217 - accuracy: 0.9181
Epoch 4/5
103/103 [==============================] - 5s 51ms/step - loss: 0.6085 - accuracy: 0.9190
Epoch 5/5
103/103 [==============================] - 5s 52ms/step - loss: 0.5915 - accuracy: 0.9354
ok
上面图片会默认下载到如下路径(结合自己的博客)
训练自己数据集用于分类的时候,就可以借鉴此目录结构。
从 TF Hub 缓存下载的模型 | TensorFlow Hub
可以看到下载的模型保存的位置
博主把上面文件夹中的内容拷贝别的一个位置
代码中路径设定为本地地址(结合自己的路径),程序可以正常运行
inception_v3_spec = image_classifier.ModelSpec(uri='/home/sxhlvye/efficientnet_lite0_feature-vector_2')