问题描述:
tensorflow2 运行facenet报错
Node ‘gradients/InceptionResnetV1/Bottleneck/BatchNorm/cond/FusedBatchNorm_1_grad/FusedBatchNormGrad’ has an _output_shapes attribute inconsistent with the GraphDef for output #3: Dimension 0 in both shapes must be equal, but are 0 and 512. Shapes are [0] and [512].
这是我下载了预训练模型,然后在我自己电脑上跑的时候,读模型的时候遇到的。
报错的是import_meta_graph这个方法
saver = tf.compat.v1.train.import_meta_graph(os.path.join(model_exp, meta_file))
这里面的地址是 20180402-114759\model-20180402-114759.meta,是个meta文件。
报错的意思是:
在TensorFlow中,当您尝试使用import_meta_graph加载.meta文件时遇到错误,可能是因为在安装的TensorFlow模块版本大于等于2.2时,存在一些节点输出形状和属性值之间的不一致性,导致了内部TensorFlow错误error_output_shape。
但是知道了意思也没用,只能使用这个方法加载meta格式的模型,我也不知道该怎么跳过这些不匹配的节点,直接加载。
这个问题最根本的原因就是facenet用的是tensorflow1,但是我头铁,就是想用tensorflow2。(此时,能屈能伸赶紧重开tensorflow1环境,重新搞,就不用解决这个问题。)
问题解决:
此时,看一眼这个预训练模型的文件夹,
meta是模型,ckpt是检查点,然后还有个pb格式的文件,这也是模型:
pb是 protocol buffer 的缩写。 TensorFlow训练模型后存成的pb文件,是一种表示模型( 神经网络 )结构的二进制文件
那直接读这个pb试试
这时看一眼这个完整的代码:
报错的是下面的else部分,而上面的if里已经提供了一种读二进制文件的方法,也就是下面这段代码(看到什么什么file知道是读文件,还有“rb”知道是二进制编码)
with gfile.FastGFile(model_exp,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
那就把这个改改,用来读这个pb模型(注意两个saver不用了,注释掉)
#saver = tf.compat.v1.train.import_meta_graph(os.path.join(model_exp, meta_file))
#saver.restore(tf.get_default_session(), os.path.join(model_exp, ckpt_file))
with tf.gfile.FastGFile(os.path.join(model_exp, model_exp+".pb"), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
此时注意啊,在graph_def.ParseFromString(f.read())这一行的f.read()很可能报一个错
这个意思是你的路径写错了(position都到68了才发现编码有问题,咋可能,正常编码问题应该position 0就报了),不是编码有问题,毕竟咱们已经改成rb来读二进制了(不读二进制的时候写r)
都改好之后,再运行一遍,模型就读出来了