Node ‘gradients/InceptionResnetV1/Bottleneck/BatchNorm/cond/FusedBatchNorm_1_grad/FusedBatchNormGrad

问题描述:
tensorflow2 运行facenet报错
Node ‘gradients/InceptionResnetV1/Bottleneck/BatchNorm/cond/FusedBatchNorm_1_grad/FusedBatchNormGrad’ has an _output_shapes attribute inconsistent with the GraphDef for output #3: Dimension 0 in both shapes must be equal, but are 0 and 512. Shapes are [0] and [512].
Node ‘gradients/InceptionResnetV1/Bottleneck/BatchNorm/cond/FusedBatchNorm_1_grad/FusedBatchNormGrad_第1张图片
这是我下载了预训练模型,然后在我自己电脑上跑的时候,读模型的时候遇到的。
报错的是import_meta_graph这个方法

saver = tf.compat.v1.train.import_meta_graph(os.path.join(model_exp, meta_file))

这里面的地址是 20180402-114759\model-20180402-114759.meta,是个meta文件。
报错的意思是:
在TensorFlow中,当您尝试使用import_meta_graph加载.meta文件时遇到错误,可能是因为在安装的TensorFlow模块版本大于等于2.2时,存在一些节点输出形状和属性值之间的不一致性,导致了内部TensorFlow错误error_output_shape。

但是知道了意思也没用,只能使用这个方法加载meta格式的模型,我也不知道该怎么跳过这些不匹配的节点,直接加载。
这个问题最根本的原因就是facenet用的是tensorflow1,但是我头铁,就是想用tensorflow2。(此时,能屈能伸赶紧重开tensorflow1环境,重新搞,就不用解决这个问题。)

问题解决:
此时,看一眼这个预训练模型的文件夹,
Node ‘gradients/InceptionResnetV1/Bottleneck/BatchNorm/cond/FusedBatchNorm_1_grad/FusedBatchNormGrad_第2张图片
meta是模型,ckpt是检查点,然后还有个pb格式的文件,这也是模型:
pb是 protocol buffer 的缩写。 TensorFlow训练模型后存成的pb文件,是一种表示模型( 神经网络 )结构的二进制文件

那直接读这个pb试试

这时看一眼这个完整的代码:
Node ‘gradients/InceptionResnetV1/Bottleneck/BatchNorm/cond/FusedBatchNorm_1_grad/FusedBatchNormGrad_第3张图片
报错的是下面的else部分,而上面的if里已经提供了一种读二进制文件的方法,也就是下面这段代码(看到什么什么file知道是读文件,还有“rb”知道是二进制编码)

with gfile.FastGFile(model_exp,'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='')

那就把这个改改,用来读这个pb模型(注意两个saver不用了,注释掉)

#saver = tf.compat.v1.train.import_meta_graph(os.path.join(model_exp, meta_file))
#saver.restore(tf.get_default_session(), os.path.join(model_exp, ckpt_file))
with tf.gfile.FastGFile(os.path.join(model_exp, model_exp+".pb"), 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='')

此时注意啊,在graph_def.ParseFromString(f.read())这一行的f.read()很可能报一个错
在这里插入图片描述
这个意思是你的路径写错了(position都到68了才发现编码有问题,咋可能,正常编码问题应该position 0就报了),不是编码有问题,毕竟咱们已经改成rb来读二进制了(不读二进制的时候写r)

都改好之后,再运行一遍,模型就读出来了

你可能感兴趣的:(python,tensorflow,tensorflow2)