TensorFlow Tutorial (Part 2)

1. References

TensorFlow 2.x Getting-Started Tutorial
A Concise Handbook of TensorFlow 2 (简单粗暴 TensorFlow 2)
tensorflow python deploy
TensorFlow Basics: Tensor, Shape, Type, Sessions & Operators

2. TensorFlow Version Compatibility

TensorFlow version compatibility table
RTX 3060 cannot run TensorFlow 1.x
Setting up a TensorFlow deep-learning environment on an RTX 3060: a record of pitfalls

2.1 Important Notes

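  1. TensorFlow 1.x only runs on CUDA 10.0 and earlier;
  2. GeForce RTX 30-series GPUs are supported by CUDA 11.1 and later, and only TensorFlow 2.4 and later support CUDA 11;
  3. Therefore an RTX 3060 needs CUDA 11 or later, which in turn means TensorFlow 2.4 or later: an RTX 3060 cannot run TensorFlow 1.x.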

2.2 Inference: TF1 vs. TF2

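I expected TF2 to differ substantially from TF1, but in practice migrating inference code is not that complicated: in many cases, changing how TensorFlow is imported is enough to run TF1-style code under TF2.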

2.2.1 Method 1

# Replace the import
import tensorflow as tf
# with
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
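
A minimal sketch of TF1-style code (placeholder + Session) running under TF 2.x with the Method 1 import. The placeholder name `x`, the constant weights and the input values are made up for illustration.

```python
import numpy as np
import tensorflow.compat.v1 as tf

# Switch TF 2.x into TF1 behavior: graph mode, placeholders and Sessions.
# Call this once at program start, before any tensors are created.
tf.disable_v2_behavior()

# Build a static graph: y = x @ w
x = tf.placeholder(tf.float32, shape=(None, 3), name="x")
w = tf.constant([[1.0], [2.0], [3.0]], name="w")
y = tf.matmul(x, w, name="y")

# Nothing is computed until the graph is run inside a Session
with tf.Session() as sess:
    out = sess.run(y, feed_dict={x: np.array([[1.0, 1.0, 1.0]], dtype=np.float32)})

print(out)  # [[6.]]
```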

2.2.2 Method 2

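Eager execution is off by default in TensorFlow 1.x and can be enabled manually with tf.enable_eager_execution(); in TensorFlow 2.x it is on by default and can be disabled manually with tf.compat.v1.disable_eager_execution().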

# Replace the import
import tensorflow as tf
# with
import tensorflow as tf
tf.compat.v1.disable_eager_execution()

# Replace other TF1 APIs
x = tf.placeholder(tf.float32, shape=(1024, 1024))
# with
x = tf.compat.v1.placeholder(tf.float32, shape=(1024, 1024))

with tf.Session() as sess:
# with
with tf.compat.v1.Session() as sess:
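
The same idea with Method 2, as a small runnable sketch: keep `import tensorflow as tf` and reach the TF1 APIs through `tf.compat.v1`. `tf.executing_eagerly()` reports whether eager mode is active; the shapes and values are illustrative only.

```python
import numpy as np
import tensorflow as tf

print(tf.executing_eagerly())  # True by default in TF 2.x
tf.compat.v1.disable_eager_execution()  # must run before any graph is built
print(tf.executing_eagerly())  # False

x = tf.compat.v1.placeholder(tf.float32, shape=(2, 2))
total = tf.reduce_sum(x)

with tf.compat.v1.Session() as sess:
    result = sess.run(total, feed_dict={x: np.ones((2, 2), dtype=np.float32)})

print(result)  # 4.0
```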

2.3 The compat Module

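compat is the module TF 2.x provides specifically for compatibility with TF 1.x code. TF 2.x uses dynamic computation graphs (eager execution) by default and recommends the high-level tf.keras API; in my experience, a model written in Keras can be even more concise than one written in PyTorch. Core TensorFlow, however, exposes more low-level configuration than tf.keras. You can build a model with low-level APIs such as tf.function and customize every aspect of it, or you can assemble it from tf.keras layers like building blocks, in which case you do not need to know how the underlying architecture is built and can focus on the overall design.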

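The official TF 2.x tutorials are mostly based on the Keras API. The same functionality can often be implemented with different APIs, but combining those APIs carelessly causes problems, typically when tf.keras and the low-level APIs are mixed up.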
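
To make the two styles concrete, here is a hedged sketch of the same linear layer written once with low-level APIs (`tf.Variable` + `tf.function`) and once as a tf.keras model. The layer size and input values are arbitrary; the Keras layer is given the same weights so the outputs can be compared.

```python
import numpy as np
import tensorflow as tf

# Low-level style: explicit variables plus a tf.function
w = tf.Variable(tf.random.normal((3, 1)), name="w")
b = tf.Variable(tf.zeros((1,)), name="b")

@tf.function
def low_level_model(x):
    return tf.matmul(x, w) + b

# tf.keras style: stack layers like building blocks
dense = tf.keras.layers.Dense(1)
keras_model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), dense])

x = tf.constant([[1.0, 2.0, 3.0]])
keras_model(x)  # make sure the Dense layer has created its kernel and bias
dense.set_weights([w.numpy(), b.numpy()])  # share the same weights for comparison

low_out = low_level_model(x).numpy()
keras_out = keras_model(x).numpy()
print(low_out, keras_out)  # the two outputs should match
```

The low-level version makes every variable and operation explicit; the Keras version hides weight creation inside `Dense`, which is convenient until you need to mix it with hand-written graph code.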

3. Related Topics

3.1 TensorFlow Examples

TensorFlow-Examples

3.2 Building Models in TensorFlow

[Tools] TensorFlow 2.x (1): Three Ways to Build a Model

3.3 TensorFlow Low-Level API

TensorFlow 2.0 Tutorial: Training with the Low-Level API (without tf.keras)

3.4 Image Preprocessing

3.4.1 With PIL and OpenCV

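from PIL import Image
import numpy as np
import tensorflow as tf

# Read the image (item_imgpath: path to the image file)
pil_image = Image.open(item_imgpath)
# Simple image processing
image_np = np.array(pil_image)  # PIL -> NumPy
input_tensor = tf.convert_to_tensor(image_np)  # NumPy -> Tensor
input_tensor = input_tensor[tf.newaxis, ...]  # add a batch dimension: 3-D -> 4-D
image_np = input_tensor.numpy()  # Tensor -> NumPy; in TF1 graph mode use input_tensor.eval() inside a Session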

3.4.2 With TensorFlow

import tensorflow as tf

# Image path
image_path = './1.jpg'

# Read the raw file bytes
img = tf.io.read_file(image_path)
# Decode the JPEG; original shape (450, 600, 3)
img = tf.image.decode_jpeg(img, channels=3)
# Resize to (299, 299, 3); the result is float32
img = tf.image.resize(img, (299, 299))
# Add a batch dimension, 3-D -> 4-D: (1, 299, 299, 3)
img = tf.expand_dims(img, axis=0)
# Cast back to uint8
img = tf.cast(img, dtype=tf.uint8)
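
The same pipeline wrapped in a function, as a self-contained sketch. Instead of `./1.jpg` it writes a random 450x600 JPEG to a temporary file; the function name `load_image` and the synthetic image are assumptions for illustration.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

def load_image(image_path, size=(299, 299)):
    """Hypothetical helper: read a JPEG and return a uint8 batch of shape (1, H, W, 3)."""
    img = tf.io.read_file(image_path)
    img = tf.image.decode_jpeg(img, channels=3)  # (H, W, 3), uint8
    img = tf.image.resize(img, size)             # (299, 299, 3), float32
    img = tf.expand_dims(img, axis=0)            # (1, 299, 299, 3)
    return tf.cast(img, dtype=tf.uint8)

# Write a synthetic 450x600 JPEG so the example does not depend on ./1.jpg
path = os.path.join(tempfile.mkdtemp(), "synthetic.jpg")
pixels = tf.constant(np.random.randint(0, 256, (450, 600, 3), dtype=np.uint8))
tf.io.write_file(path, tf.io.encode_jpeg(pixels))

batch = load_image(path)
print(batch.shape, batch.dtype)  # expected: shape (1, 299, 299, 3), dtype uint8
```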

4. TensorBoard Tutorial

4.1 Generate Log Files

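import tensorflow.compat.v1 as tf  # TF1-style API (also available in TF 1.15)
tf.disable_v2_behavior()


def print_pb_info_v1(pb_path):
    """
    Write the structure of a .pb model to a TensorBoard log for visualization
    :param pb_path: path to the .pb (GraphDef) model file
    :return:
    """
    tf.reset_default_graph()  # reset the default graph
    output_graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
    with tf.Session() as sess:
        # import the parsed GraphDef into the Session's (default) graph
        tf.import_graph_def(output_graph_def, name="")
        # write an event file to the log_pb folder; open it in TensorBoard to view the model
        writer = tf.summary.FileWriter('log_pb/', sess.graph)
        writer.close()  # flush the event file to disk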
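
The function above relies on TF1 APIs (default graph, `tf.summary.FileWriter`). For TF 2.x code written with `tf.function`, one alternative is to record the traced graph with `tf.summary.trace_on` / `tf.summary.trace_export`. The toy function and the temporary log directory are assumptions; point `tensorboard --logdir` at the printed directory to view the graph.

```python
import os
import tempfile

import tensorflow as tf

@tf.function
def tiny_model(x):
    # Toy computation; any tf.function would do
    return tf.nn.relu(tf.matmul(x, tf.ones((3, 2))))

log_dir = tempfile.mkdtemp(prefix="log_tf2_")
writer = tf.summary.create_file_writer(log_dir)

tf.summary.trace_on(graph=True)  # start recording graphs of newly traced functions
tiny_model(tf.ones((1, 3)))      # the first call traces the function into a graph
with writer.as_default():
    tf.summary.trace_export(name="tiny_model_graph", step=0)
writer.flush()

event_files = [f for f in os.listdir(log_dir) if f.startswith("events.out.tfevents")]
print(log_dir, event_files)
```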

4.2 Launch TensorBoard

tensorboard --logdir="/PATH/TO/log_pb"

4.3 View TensorBoard

Open http://localhost:6006/ in a browser (6006 is TensorBoard's default port).

5. TensorFlow Model Formats

A Roundup of TensorFlow Model Saving Methods
Distinguishing TensorFlow model storage formats: checkpoint, GraphDef, SavedModel, frozen model
TensorFlow model saving (3): saving and loading SavedModel-format models in TensorFlow 1.x

5.1 GraphDef

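GraphDef is TensorFlow's serialized graph structure. The computation graph is saved in Protocol Buffers format (.pb). A .pb file can hold only the graph structure, or the structure plus the weights (stored as constants).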
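
A minimal sketch of what "serialized graph structure" means in practice: trace a small `tf.function`, take its GraphDef, serialize it to the bytes a .pb file would contain, and parse those bytes back. The function `add_one` is a made-up example with no variables, so the GraphDef holds structure only.

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float32, name="x")])
def add_one(x):
    return tf.add(x, 1.0, name="add_one")

# Extract the GraphDef from the traced function
graph_def = add_one.get_concrete_function().graph.as_graph_def()
serialized = graph_def.SerializeToString()  # the bytes a .pb file would contain

# Parse the bytes back into a GraphDef, as done when loading a .pb file
restored = tf.compat.v1.GraphDef()
restored.ParseFromString(serialized)
for node in restored.node:
    print(node.name, node.op)
```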

5.2 SignatureDef

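Defines the names and attributes (dtype, shape) of a graph's input and output nodes. SignatureDefs are stored in the MetaGraphDef inside saved_model.pb; the .index file (variables/variables.index) belongs to the checkpoint that holds the weights.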

5.3 tf.saved_model

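tf.saved_model.save exports a model (for example eager code traced with tf.function) as a directory containing saved_model.pb (the serialized graphs and SignatureDefs) and variables/ (the weights, stored as a checkpoint). Keras models saved in this format may also include keras_metadata.pb, which holds Keras-specific model metadata.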
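
A hedged sketch that saves a toy `tf.Module` as a SavedModel, lists the resulting files, and reads back the serving SignatureDef. The class `Scale`, its weight and the export path are assumptions for illustration; newer TF versions may write extra files (for example a fingerprint) next to saved_model.pb.

```python
import os
import tempfile

import tensorflow as tf

class Scale(tf.Module):
    """Hypothetical toy model: y = x * w."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0, name="w")

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32, name="x")])
    def __call__(self, x):
        return {"y": x * self.w}

export_dir = os.path.join(tempfile.mkdtemp(), "scale_model")
module = Scale()
tf.saved_model.save(module, export_dir, signatures={"serving_default": module.__call__})

print(sorted(os.listdir(export_dir)))  # includes saved_model.pb and variables
print(sorted(os.listdir(os.path.join(export_dir, "variables"))))  # variables.index, variables.data-...

# The SignatureDef describes the named inputs/outputs of the exported function
loaded = tf.saved_model.load(export_dir)
infer = loaded.signatures["serving_default"]
print(infer.structured_input_signature)
print(infer.structured_outputs)
result = infer(x=tf.constant([1.0, 3.0]))["y"].numpy()
print(result)  # [2. 6.]
```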

5.4 freeze_graph

[Deep Learning] freeze_graph for TensorFlow Models

This function converts the variables to constants and saves the graph and the weights together in a single static graph (.pb).
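
freeze_graph is a TF1 tool. In TF 2.x a comparable result can be obtained with `convert_variables_to_constants_v2`; note that it lives in a non-public module path (`tensorflow.python.framework.convert_to_constants`), so it may change between versions. The toy `Scale` module is hypothetical. After freezing, the variable reads are replaced by constants, so no `ReadVariableOp` should remain in the graph.

```python
import tempfile

import tensorflow as tf
# Non-public import path; treat as an assumption that may change across TF versions
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)

class Scale(tf.Module):
    """Hypothetical toy model with one trainable weight."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0, name="w")

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32, name="x")])
    def __call__(self, x):
        return x * self.w

model = Scale()
concrete_fn = model.__call__.get_concrete_function()

# Replace variable reads with constants holding the current weight values
frozen_fn = convert_variables_to_constants_v2(concrete_fn)
frozen_graph_def = frozen_fn.graph.as_graph_def()

op_types = sorted({node.op for node in frozen_graph_def.node})
print(op_types)

# Write a single self-contained .pb (graph + weights)
pb_path = tf.io.write_graph(frozen_graph_def, tempfile.mkdtemp(), "frozen.pb", as_text=False)
print(pb_path)
```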

6. Possible Problems

6.1 TensorFlow version too old: reading a .pb file fails

Cause and solution of the TensorFlow error NodeDef mentions attr 'xxx' not in Op

Traceback (most recent call last):
  File "F:\360Downloads\Anaconda3\envs\tf15\lib\site-packages\tensorflow_core\python\framework\importer.py", line 501, in _import_graph_def_internal
    graph._c_graph, serialized, options)  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: NodeDef mentions attr 'exponential_avg_factor' not in Op<name=FusedBatchNormV3; signature=x:T, scale:U, offset:U, mean:U, variance:U -> y:T, batch_mean:U, batch_variance:U, reserve_space_1:U, reserve_space_2:U, reserve_space_3:U; attr=T:type,allowed=[DT_HALF, DT_BFLOAT16, DT_FLOAT]; attr=U:type,allowed=[DT_FLOAT]; attr=epsilon:float,default=0.0001; attr=data_format:string,default="NHWC",allowed=["NHWC", "NCHW"]; attr=is_training:bool,default=true>; NodeDef: {{node sequential/custom_layer/batch_normalization_56/FusedBatchNormV3}}. (Check whether your GraphDef-interpreting binary is up to date with your GraphDef-generating binary.).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "F:/MyDocumentes/PythonProjects/202202/20220228/tensorflow_demo.py", line 132, in <module>
    get_ops_name(pb_path=r'C:\Users\Seeking\Desktop\model.pb')
  File "F:/MyDocumentes/PythonProjects/202202/20220228/tensorflow_demo.py", line 117, in get_ops_name
    tf.import_graph_def(graph_def, name='')
  File "F:\360Downloads\Anaconda3\envs\tf15\lib\site-packages\tensorflow_core\python\util\deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "F:\360Downloads\Anaconda3\envs\tf15\lib\site-packages\tensorflow_core\python\framework\importer.py", line 405, in import_graph_def
    producer_op_list=producer_op_list)
  File "F:\360Downloads\Anaconda3\envs\tf15\lib\site-packages\tensorflow_core\python\framework\importer.py", line 505, in _import_graph_def_internal
    raise ValueError(str(e))
ValueError: NodeDef mentions attr 'exponential_avg_factor' not in Op<name=FusedBatchNormV3; signature=x:T, scale:U, offset:U, mean:U, variance:U -> y:T, batch_mean:U, batch_variance:U, reserve_space_1:U, reserve_space_2:U, reserve_space_3:U; attr=T:type,allowed=[DT_HALF, DT_BFLOAT16, DT_FLOAT]; attr=U:type,allowed=[DT_FLOAT]; attr=epsilon:float,default=0.0001; attr=data_format:string,default="NHWC",allowed=["NHWC", "NCHW"]; attr=is_training:bool,default=true>; NodeDef: {{node sequential/custom_layer/batch_normalization_56/FusedBatchNormV3}}. (Check whether your GraphDef-interpreting binary is up to date with your GraphDef-generating binary.).
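Cause:
The installed TensorFlow was 1.15, but the .pb file was saved by TensorFlow 2.x. A lower TensorFlow version cannot read a .pb file generated by a higher version: here, the FusedBatchNormV3 op in 1.15 has no exponential_avg_factor attribute.

Option 1:
Upgrade TensorFlow.

Option 2 (recommended):
See Section 2 (TensorFlow version compatibility) above: run the code under TensorFlow 2.x with the tf.compat.v1 imports.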
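
When this error appears, it can help to check which TensorFlow produced the .pb. A GraphDef records a producer version in `versions.producer`, and the running library exposes its own `tf.version.GRAPH_DEF_VERSION`; a producer number larger than the runtime's is a strong hint that the file came from a newer TensorFlow, not proof. The helper names and the hand-built GraphDef are hypothetical; for a real file, parse it with `GraphDef.ParseFromString` first.

```python
import tensorflow as tf

def graph_version_report(graph_def):
    """Hypothetical helper: compare the GraphDef producer version with the running TF."""
    return {
        "pb_producer_version": graph_def.versions.producer,
        "runtime_graph_def_version": tf.version.GRAPH_DEF_VERSION,
        "runtime_tf_version": tf.version.VERSION,
    }

def nodes_with_attr(graph_def, attr_name):
    """Hypothetical helper: names of nodes that carry a given attribute."""
    return [node.name for node in graph_def.node if attr_name in node.attr]

# Hand-built GraphDef standing in for a .pb exported by a newer TensorFlow
graph_def = tf.compat.v1.GraphDef()
graph_def.versions.producer = 10**6  # deliberately larger than any real runtime version
node = graph_def.node.add()
node.name = "batch_normalization_56/FusedBatchNormV3"
node.op = "FusedBatchNormV3"
node.attr["exponential_avg_factor"].f = 1.0

report = graph_version_report(graph_def)
print(report)
print(nodes_with_attr(graph_def, "exponential_avg_factor"))
if report["pb_producer_version"] > report["runtime_graph_def_version"]:
    print("The .pb was probably produced by a newer TensorFlow than the one running now.")
```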

6.2 Fed inputs do not match the model's input and output nodes

tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'import/istraining' with dtype bool
	 [[node import/istraining (defined at \360Downloads\Anaconda3\envs\tf15\lib\site-packages\tensorflow_core\python\framework\ops.py:1748) ]]

Original stack trace for 'import/istraining':
  File "/MyDocumentes/PythonProjects/202202/20220228/site_infer.py", line 84, in <module>
    return_tensors = read_pb_return_tensors(graph, pb_file, return_elements)
  File "/MyDocumentes/PythonProjects/202202/20220228/site_infer.py", line 70, in read_pb_return_tensors
    return_elements = tf.import_graph_def(frozen_graph_def, return_elements=return_elements)  # 获取输入和输出的节点
  File "\360Downloads\Anaconda3\envs\tf15\lib\site-packages\tensorflow_core\python\util\deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "\360Downloads\Anaconda3\envs\tf15\lib\site-packages\tensorflow_core\python\framework\importer.py", line 405, in import_graph_def
    producer_op_list=producer_op_list)
  File "\360Downloads\Anaconda3\envs\tf15\lib\site-packages\tensorflow_core\python\framework\importer.py", line 517, in _import_graph_def_internal
    _ProcessNewOps(graph)
  File "\360Downloads\Anaconda3\envs\tf15\lib\site-packages\tensorflow_core\python\framework\importer.py", line 243, in _ProcessNewOps
    for new_op in graph._add_new_tf_operations(compute_devices=False):  # pylint: disable=protected-access
  File "\360Downloads\Anaconda3\envs\tf15\lib\site-packages\tensorflow_core\python\framework\ops.py", line 3561, in _add_new_tf_operations
    for c_op in c_api_util.new_tf_operations(self)
  File "\360Downloads\Anaconda3\envs\tf15\lib\site-packages\tensorflow_core\python\framework\ops.py", line 3561, in <listcomp>
    for c_op in c_api_util.new_tf_operations(self)
  File "\360Downloads\Anaconda3\envs\tf15\lib\site-packages\tensorflow_core\python\framework\ops.py", line 3451, in _create_op_from_tf_operation
    ret = Operation(c_op, self)
  File "\360Downloads\Anaconda3\envs\tf15\lib\site-packages\tensorflow_core\python\framework\ops.py", line 1748, in __init__
    self._traceback = tf_stack.extract_stack()
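Cause:
The inference code does not match the model's input/output nodes: the model has more input nodes than were fed.
For example, if the input/output nodes are ["inputs:0", "istraining:0", "outputs:0"]:
[Wrong] inference: output = sess.run([return_tensors[1]], feed_dict={return_tensors[0]: image_data})
[Correct] inference: output = sess.run([return_tensors[2]], feed_dict={return_tensors[0]: image_data, return_tensors[1]: False})

Tip:
Print the operations returned by graph.get_operations() to check the number and names of the input and output nodes.

"""print operations"""
# graph: the tf.Graph the .pb was imported into (tf.Graph() would create a new, empty graph)
for op in graph.get_operations():
    print(op.name)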
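
To see the failure and the fix end to end, here is a self-contained sketch that builds a toy graph with the same three node names (`inputs`, `istraining`, `outputs`), imports it with `return_elements` roughly the way `read_pb_return_tensors` in the traceback appears to, lists its operations, and runs it with and without the istraining feed. The toy computation and input values are made up; a real .pb would be parsed from disk instead.

```python
import numpy as np
import tensorflow as tf

v1 = tf.compat.v1

# Build a toy graph with the same node names as the example above
with tf.Graph().as_default() as build_graph:
    inputs = v1.placeholder(tf.float32, shape=(None, 2), name="inputs")
    istraining = v1.placeholder(tf.bool, shape=(), name="istraining")
    # Stand-in for train/inference behavior: scale by 0.5 when training, 1.0 otherwise
    scale = 1.0 - 0.5 * tf.cast(istraining, tf.float32)
    tf.identity(inputs * scale, name="outputs")
graph_def = build_graph.as_graph_def()

# Import into a fresh graph; node names get the default "import/" prefix
graph = tf.Graph()
with graph.as_default():
    return_tensors = v1.import_graph_def(
        graph_def, return_elements=["inputs:0", "istraining:0", "outputs:0"])

# Print every operation to discover the real input/output node names
for op in graph.get_operations():
    print(op.name)

image_data = np.ones((1, 2), dtype=np.float32)
with v1.Session(graph=graph) as sess:
    try:
        # Wrong: istraining is not fed, which reproduces the error above
        sess.run(return_tensors[2], feed_dict={return_tensors[0]: image_data})
        missing_feed_error = None
    except tf.errors.InvalidArgumentError as err:
        missing_feed_error = err.message
    # Correct: feed every placeholder the output depends on
    output = sess.run(return_tensors[2],
                      feed_dict={return_tensors[0]: image_data, return_tensors[1]: False})
print(output)  # [[1. 1.]]
```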

6.3 TensorFlow version issue

File "F:/MyDocumentes/PythonProjects/202202/20220228/tensorflow_demo.py", line 42, in soft_demo
    ndarray_data = sess.run(sm)
  File "F:\360Downloads\Anaconda3\envs\tf22py37\lib\site-packages\tensorflow\python\client\session.py", line 971, in run
    run_metadata_ptr)
  File "F:\360Downloads\Anaconda3\envs\tf22py37\lib\site-packages\tensorflow\python\client\session.py", line 1120, in _run
    raise RuntimeError('The Session graph is empty. Add operations to the '
RuntimeError: The Session graph is empty. Add operations to the graph before calling run().
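Cause:
The code was written for TF1 (graph + Session) but run under TF2, where eager execution is on by default: ops execute immediately and are never added to the graph the Session runs, so the Session graph is empty.

Solution:
import tensorflow as tf
# replace with
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

or
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()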
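
A small sketch that reproduces the error under TF 2.x and shows another way around it: instead of switching eager execution off globally, build the ops inside an explicit `tf.Graph()` and give that graph to the Session. This is an alternative to the fix above, not the same thing; the softmax input values are arbitrary.

```python
import tensorflow as tf

logits = [1.0, 2.0, 3.0]

# Reproduce the error: under TF2 the op runs eagerly, so the Session's graph stays empty
eager_sm = tf.nn.softmax(logits)
try:
    with tf.compat.v1.Session() as sess:
        sess.run(eager_sm)
    error_message = None
except RuntimeError as err:
    error_message = str(err)
print(error_message)

# Build the op inside an explicit graph so the Session has something to run
graph = tf.Graph()
with graph.as_default():
    graph_sm = tf.nn.softmax(logits)
with tf.compat.v1.Session(graph=graph) as sess:
    ndarray_data = sess.run(graph_sm)
print(ndarray_data)
```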

6.4 Failed to read a .pb computation graph

Use tf.gfile.GFile.
Traceback (most recent call last):
  File "F:/MyDocumentes/PythonProjects/utils_tools/tensorflow_utils.py", line 139, in <module>
    get_ops_name(pb_path=r'G:\ModelZoo\yolov3模型\yolov3\yolov3.pb')
  File "F:/MyDocumentes/PythonProjects/utils_tools/tensorflow_utils.py", line 38, in get_ops_name
    graph_def.ParseFromString(f.read())
google.protobuf.message.DecodeError: Error parsing message with type 'tensorflow.GraphDef'
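Cause:
SavedModel format:
/home/jitao/raccoon_dataset-master/train/saved_model/saved_model.pb

Frozen TensorFlow graph (GraphDef) format:
/home/jitao/raccoon_dataset-master/train/frozen_inference_graph.pb

Reading the computation graph with GraphDef.ParseFromString requires the TensorFlow Graph (GraphDef) format. I mistakenly chose the SavedModel file, whose saved_model.pb contains a SavedModel message rather than a GraphDef, which caused the error.

Solution:
After switching to the second path, the problem was solved.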
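
A hedged sketch of a loader that picks the right method based on the file layout: a file named saved_model.pb is treated as part of a SavedModel directory and loaded with `tf.saved_model.load`, anything else is parsed as a GraphDef. The helper `load_pb` and the filename rule are assumptions for illustration; the demo creates one file of each kind in a temporary directory.

```python
import os
import tempfile

import tensorflow as tf

def load_pb(pb_path):
    """Hypothetical helper: choose a loader based on how the .pb file is laid out."""
    if os.path.basename(pb_path) == "saved_model.pb":
        # A SavedModel must be loaded as a directory, not parsed as a GraphDef
        return "saved_model", tf.saved_model.load(os.path.dirname(pb_path))
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    return "graph_def", graph_def

work_dir = tempfile.mkdtemp()

# 1) A SavedModel directory (contains saved_model.pb and variables/)
module = tf.Module()
module.v = tf.Variable(1.0)
tf.saved_model.save(module, os.path.join(work_dir, "saved_model"))

# 2) A plain GraphDef .pb, similar in kind to frozen_inference_graph.pb
@tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
def double(x):
    return x * 2.0

graph_pb = tf.io.write_graph(
    double.get_concrete_function().graph.as_graph_def(), work_dir, "graph.pb", as_text=False)

kind_a, _ = load_pb(os.path.join(work_dir, "saved_model", "saved_model.pb"))
kind_b, graph_def = load_pb(graph_pb)
print(kind_a, kind_b, len(graph_def.node))
```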
