【系列学习】5.2 通过tfHub加载模型—— tf工程化项目实战

使用TF-hub库微调模型评估年龄

【下载模型】

  1. 打开 https://tfhub.dev/ 搜索对应模型
  2. 复制模型链接
  3. 自动下载:将复制的链接填入代码调用tf-hub库的参数里
  4. or手动下载:将域名换成
https://storage.googleapis.com/tfhub-modules 

并在尾部加上.tar.gz

【直接用tfhub模型预测】
1.加载模型

import tensorflow_hub as hub

module_spec = hub.load_module_spec(url) //url为上述在tfhub上copy的模型链接
module = hub.Module(module_spec)


// 对比之前从slim加载模型的方式:
// 直接使用 pnasnet
from nets.nasnet import pnasnet

//对比本地微调后的模型加载方式:
// MyNASNetModel写在model.py里定义好了
mymode = MyNASNetModel()
mymode.build_model('test', test_dir)

2.获取模型输入图片尺寸

height, width = hub.get_expected_image_size(module_spec)

// VS slim预训练模型方式:
from nets.nasnet import nasnet
image_size = nasnet.build_nasnet_large.default_image_size

3.准备预存变量

sample_images = ['test1.jpg']
input_imgs = tf.placeholder(tf.float32,[None, image_size, image_size, 3])
images = 2 * (input_imgs / 255.0) - 1.0

同样的套路 并无区别

  1. 构造logits和预测结果的op
logits = module(images)
y = tf.argmax(logits, axis=1)

// VS slim预训练模型logits和预测结果op构造方式:
arg_scope = pnasnet.pnasnet_large_arg_scope()
with slim.arg_scope(arg_scope):
    logits, end_points = pnasnet.build_pnasnet_large(x1, num_classes = 1001, is_training=False)
    prob = end_points['Predictions']
    y = tf.argmax(prob, axis = 1)

5.跑session

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(tf.tables_initializer())
 def preimg(img):
      ch = 3
      if img.mode == 'RGBA':
          ch = 4
      imgnp = np.asarray(img.resize((image_size, image_size)),dtype=np.float32).reshape(image_size,image_size,ch)
      return imgnp[:,:,:3]
   
  batch_img = [preimg(Image.open(imgfilename)) for imgfilename in sample_images]
  org_img = [Image.open(imgfilename) for imgfilename in sample_images]
  y, img_norm = sess.run([y,images], feed_dict={input_imgs: batch_img})

【微调tf-Hub中的模型】
这部分提供了tf提供的微调代码:
retrain.py

直接传递对应参数即可获得微调后的模型,以pb文件存在(冻结图文件)。
这个微调部分没有太大变化,感兴趣可以看对应的py代码,下面讲下怎么根据冻结图文件预测。

【用冻结图模型预测】

  1. 找到模型中的输入输出节点
    可以直接使用 print(placeholder.name) 和 print(final_result.name)打印输入输出节点

  2. 预测相关代码


// 加载冻结图模型 pd_path 模型所在路径
tf.reset_default_graph()
PATH_TO_CKPT = pd_path + '/output_graph.pd'
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
  serialized_graph = fid.read()
  od_graph_def.ParseFromString(serialized_graph)
  tf.import_graph_def(od_graph_def, name='')

classification_graph = tf.get_default_graph()

// 定义输入尺寸
height, width = 224, 224

// 跑session
with tf.Session(graph=classification_graph) as sess:
  result = classification_graph.get_tensor_by_name('final_result:0')
  input_imgs = classification_graph.get_tensor_by_name('Placeholder:0')
  y = tf.argmax(result, axis=1)

  // 图片处理
  def preimg(img):
    .......
  batchImg = [preimg(Image.open(imgfilename)) for imgfilename in sample_images]
  
  // 跑
  yv = sess.run(y, feed_dict={input_imgs: batchImg})

完毕~

【总结】
使用tf.slim进行微调比较灵活,使用tf.hub微调只能微调最后的输出层,不支持整体联调。
除了这两种,还可以基于tf.keras微调等。
微调时,选择预训练模型的策略是:
样本充足首选精度最高的模型,不足选择ResNet。
ResNet在ImgNet上泛化能力是最强的。

你可能感兴趣的:(【系列学习】5.2 通过tfHub加载模型—— tf工程化项目实战)