【系列学习】6.1 用tf编写训练模型的程序—— tf工程化项目实战

1.几个概念

【静态图概念】
tensorflow基础知识,使用分为模型构建(构造op)和模型运行(sess.run)。

【动态图】
tf1.11后动态图已经比较常用了,这边不过多介绍了,相比较tf的动态图,应该还是pytorch大家用的比较多。

【Estimators API】
tf提供的API,封装了一些训练模型/测试准确率/生成预测的方法。
Estimator是以tf.layers接口上构建的,框架分为3个主要部分:

  1. 输入函数:主要由tf.data.Dataset构成
  2. 模型函数:tf.layers / tf.metrics 训练/测试/监控参数
  3. 估算器:粘合各部分

该框架还提供了一些封装好的常用模型:LinearRegressor / LinearClassifier / DNNRegressor
/DNNClassifier

该框架对模型训练使用做了高度集成,缺点是开发模型过程中无法精确控制细节。适用于对成熟模型进行训练使用的场景。

【其他接口】

  1. tf.layers接口
    类似于tf.slim的API。tf.layers常用于动态图,TF-slim常用于静态图。在tf2版本中,可直接使用tf.keras接口。
  2. tf.keras接口
    把keras封装进tf的接口。
    keras学习地址
  3. tf.js接口
    学习地址
  4. TFLearn框架
    类似tf的框架

【分配资源】

  1. 为整个程序指定GPU
    通过设置 CUDA_VISIBLE_DEVICES实现。
    可以在命令端 python命令前直接加上 CUDA_VISIBLE_DEVICES=1 (举例)实现,也可以在程序中设置。
    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

  2. 为整个程序指定所占的GPU显存
    通过构建tf.ConfigProto实现

  3. 为不同的OP指定GPU
    在代码前使用tf.device语句:
    with tf.device('/cpu:0'):

  4. 使用分布策略
    主要策略有 MirroredStrategy / CollectiveAllReduceStrategy / ParameterServerStrategy 等策略。使用方式比较简单,实例化一个分布策略对象,作为参数传入训练模型。

例如:

distribution = tf.contrib.distribute.MirriredStrategy()
model.compile(loss='mean_squared_error', optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.2), distribute=distribution)

在使用多机训练时,需指定网络中的角色关系:

【其他训练技巧】

  1. tfdbg
  2. training_hooks
  3. tf1 移植到 tf2
    ....

介绍完这些概念之后,开始用静态图训练一个具有保存检查点功能的回归模型~

2.训练模型

【如何生成checkpoint文件】

  1. 实例化一个saver对象
    常用参数有:
    var_list: 要保存的变量
    max_to_keep: 最多保留的数量
    keep_checkpoint_every_n_hours: 间隔保存时间
saver = tf.train.Saver(tf.global_variables(), max_to_keep)
  1. 在session里,调用saver对象的save保存checkpoint文件
saver.save(sess, savedir + 'xxxmodel.cpkt', global_step=epoch)

【如何载入checkpoint文件】
使用tf.train.latest_checkpoint 找最近的checkpoint文件,在使用saver.restore载入checkpoint文件。

    kpt = tf.train.latest_checkpoint(save_path)
    if kpt!=None:
      saver.restore(sess, kpt)

在代码中的运用可参考5.1中model.py的load_cpk函数:
begin 0时实例化saver,为其他时(train时实例化model会begin不传调用一次来初始化,sess.run里还会调用一次来加载恢复checkpoint模型)

  def load_cpk(self, global_step, sess, begin=0, saver=None, save_path = None):
    if begin == 0:
      save_path = r'./train_nasnet'
      if not os.path.exists(save_path):
        print("no model path")
      saver = tf.train.Saver(max_to_keep=1)
      return saver, save_path
    else:
      kpt = tf.train.latest_checkpoint(save_path)
      print("load model", kpt)
      startepo = 0
      if kpt!=None:
        saver.restore(sess, kpt)
        ind = kpt.find("-")
        startepo = int(kpt[ind+1:])
        print("global_step=", global_step.eval(),startepo)
      return startepo

。。。未完待续。。。

你可能感兴趣的:(【系列学习】6.1 用tf编写训练模型的程序—— tf工程化项目实战)