近日,在使用Siamese网络实现西储大学轴承数据故障诊断中,测试的过程出现了
1 ValueError: The passed save_path is not a valid checkpoint
的错误。错误是由于在测试的过程中导入checkpoint时,传入的save_path是无效的,或者是说,传入的save_path在给定的路径中没有找到对应的文件。
网上关于该问题的解决方案主要包含两个方面:
- checkpoint路径应该使用相对路径;
- 路径字符不要太长
但均没有从本质上解决遇到的问题,也没有从源头讲明白bug出现的缘由。
Tensorflow会将模型保存生成四个文件,如下图所示。
- 图a的情况是模型保存时,仅传入了地址,而地址中不包含文件的名称。
如第6行代码所示,传入的save_path中只包含要保存checkpoint的路径,未声明保存文件的名称。
在这种情况下,checkpoint_dir可以直接作为路径传入模型恢复save.restore()的函数中。
1 ... 2 TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now()) 3 checkpoint_dir = os.path.join(".\\checkpoint\\%s" % (TIMESTAMP)) 4 save = tf.train.Saver() 5 ... 6 save_path = save.save(sess,save_path=checkpoint_dir+"\\")
- 图b的情况是模型保存时,地址中添加了需要保存文件的名称filename,并且在save声明时,使用了max_to_keep=1的设置,即保存的文件名称中,在XXX.ckpt后包含 "-1" 的名称,其表示当前保存模型的训练代数。
在这种情况下,使用当前的checkpoint_dir作为模型恢复saver.restore()函数中的路径,将会报错。
1 ValueError: The passed save_path is not a valid checkpoint
1 ... 2 TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.datetime.now()) 3 checkpoint_dir = ".\\checkpoint\\%s" % TIMESTAMP 4 filename = "few_shot_learning_fault_diagnosis" 5 checkpoint_dir = os.path.join(checkpoint_dir, filename+".ckpt") 6 save = tf.train.Saver(max_to_keep=1) 7 ... 8 saver.save(sess, diag_obj.checkpoint_dir, global_step=step)
总结: 在编写时,如果使用的是save = tf.train.Saver() 使用了max_to_keep=1的设置,并且在模型训练保存的过程中,是每训练一代保存一次。 此时,checkpoint_dir将不再适用于save.restore(sess, checkpoint_dir)中的checkpoint_dir。因为从图b中可以看到,其包含-1(训练代数的后缀)。如果仍将checkpoint_dir作为模型参数读入的地址传入save.restore()中,将会报
1 ValueError: The passed save_path is not a valid checkpoint
的错误。
【解决方法:】
使用tf.train.latest_checkpoint()函数,将不包含文件名称的路径传入函数中,获取到文件的路径module_file,并将其传入saver.restore()中,便可以解决上述问题。
1 ... 2 module_file = tf.train.latest_checkpoint(diag_obj.save_path) 3 saver.restore(sess, module_file)
module_file获取到的结果如下所示, 其包含训练代数的信息,这也是为什么直接使用原始的checkpoint_dir 会报错的原因。
1 module_dir: few_shot_learning_fault_diagnosis.ckpt-1
该错误的原因是由于每一代保存一次而造成的设置的保存文件名称与实际保存文件名称不一致。