facenet加载pretrained_model报错问题解决

facenet加载pretrained_model报错问题解决

问题

facenet github地址:https://github.com/davidsandberg/facenet
使用facenet自己训练模型中断后,继续运行程序:

// 继续训练
python3 src/train_softmax.py --pretrained_models models/20190303-192600/ --其他参数不写了

报错。
注:models/20190303-192600是上次自己训练保存模型的文件夹

原因

Facenet程序调用了tf.train.Saver类,加载预训练模型时使用了Saver.restore(sess, pretrained_model)方法,该方法传入的第二个变量应该为models/20190303-192600/model-20190303-192600.ckpt-6的字符串。
对比models/20190303-192600文件夹下的内容,发现正确输入的参数并不指向预训练模型文件夹下任何一个文件。
models/20190303-192600文件夹下的内容:

zzd@zzd-ubuntu-K80:20190303-192600$ ls
checkpoint
model-20190303-192600.ckpt-4.data-00000-of-00001
model-20190303-192600.ckpt-4.index
model-20190303-192600.ckpt-5.data-00000-of-00001
model-20190303-192600.ckpt-5.index
model-20190303-192600.ckpt-6.data-00000-of-00001
model-20190303-192600.ckpt-6.index
model-20190303-192600.meta

解决办法

方法一:
继续运行程序时输入正确的参数即可:

// 继续训练
python3 src/train_softmax.py --pretrained_models models/20190303-192600/model-20190303-192600.ckpt-6 --其他参数不写了

方法二:
修改:src/train_softmax.py 在第201行后添加一行并修改原202行,修改后为:

200             if pretrained_model:
201                 print('Restoring pretrained model: %s' % pretrained_model)
202                 ckpt = tf.train.get_checkpoint_state(os.path.dirname(pretrained_model))
203                 saver.restore(sess, ckpt.model_checkpoint_path))

该方法能实现继续运行程序时,只需输入预训练模型所在文件夹的路径:

// 继续训练
python3 src/train_softmax.py --pretrained_models models/20190303-192600/ --其他参数不写了

结束。

你可能感兴趣的:(facenet加载pretrained_model报错问题解决)