使用TF-hub库微调模型评估年龄
【下载模型】
- 打开 https://tfhub.dev/ 搜索对应模型
- 复制模型链接
- 自动下载:将复制的链接填入代码调用tf-hub库的参数里
- or手动下载:将域名换成
https://storage.googleapis.com/tfhub-modules
并在尾部加上.tar.gz
【直接用tfhub模型预测】
1.加载模型
import tensorflow_hub as hub
module_spec = hub.load_module_spec(url) //url为上述在tfhub上copy的模型链接
module = hub.Module(module_spec)
// 对比之前从slim加载模型的方式:
// 直接使用 pnasnet
from nets.nasnet import pnasnet
//对比本地微调后的模型加载方式:
// MyNASNetModel写在model.py里定义好了
mymode = MyNASNetModel()
mymode.build_model('test', test_dir)
2.获取模型输入图片尺寸
height, width = hub.get_expected_image_size(module_spec)
// VS slim预训练模型方式:
from nets.nasnet import nasnet
image_size = nasnet.build_nasnet_large.default_image_size
3.准备预存变量
sample_images = ['test1.jpg']
input_imgs = tf.placeholder(tf.float32,[None, image_size, image_size, 3])
images = 2 * (input_imgs / 255.0) - 1.0
同样的套路 并无区别
- 构造logits和预测结果的op
logits = module(images)
y = tf.argmax(logits, axis=1)
// VS slim预训练模型logits和预测结果op构造方式:
arg_scope = pnasnet.pnasnet_large_arg_scope()
with slim.arg_scope(arg_scope):
logits, end_points = pnasnet.build_pnasnet_large(x1, num_classes = 1001, is_training=False)
prob = end_points['Predictions']
y = tf.argmax(prob, axis = 1)
5.跑session
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.tables_initializer())
def preimg(img):
ch = 3
if img.mode == 'RGBA':
ch = 4
imgnp = np.asarray(img.resize((image_size, image_size)),dtype=np.float32).reshape(image_size,image_size,ch)
return imgnp[:,:,:3]
batch_img = [preimg(Image.open(imgfilename)) for imgfilename in sample_images]
org_img = [Image.open(imgfilename) for imgfilename in sample_images]
y, img_norm = sess.run([y,images], feed_dict={input_imgs: batch_img})
【微调tf-Hub中的模型】
这部分提供了tf提供的微调代码:
retrain.py
直接传递对应参数即可获得微调后的模型,以pb文件存在(冻结图文件)。
这个微调部分没有太大变化,感兴趣可以看对应的py代码,下面讲下怎么根据冻结图文件预测。
【用冻结图模型预测】
找到模型中的输入输出节点
可以直接使用 print(placeholder.name) 和 print(final_result.name)打印输入输出节点预测相关代码
// 加载冻结图模型 pd_path 模型所在路径
tf.reset_default_graph()
PATH_TO_CKPT = pd_path + '/output_graph.pd'
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
classification_graph = tf.get_default_graph()
// 定义输入尺寸
height, width = 224, 224
// 跑session
with tf.Session(graph=classification_graph) as sess:
result = classification_graph.get_tensor_by_name('final_result:0')
input_imgs = classification_graph.get_tensor_by_name('Placeholder:0')
y = tf.argmax(result, axis=1)
// 图片处理
def preimg(img):
.......
batchImg = [preimg(Image.open(imgfilename)) for imgfilename in sample_images]
// 跑
yv = sess.run(y, feed_dict={input_imgs: batchImg})
完毕~
【总结】
使用tf.slim进行微调比较灵活,使用tf.hub微调只能微调最后的输出层,不支持整体联调。
除了这两种,还可以基于tf.keras微调等。
微调时,选择预训练模型的策略是:
样本充足首选精度最高的模型,不足选择ResNet。
ResNet在ImgNet上泛化能力是最强的。