如何在TensorFlow2.X中使用自定义训练循环的情况下在TensorBoard中绘制网络结构图(计算图)

遇到的问题

很多小伙伴在使用TensorFlow2.x的时候会进行自定义的循环,也就是自己采用for循环来逐个Epoch循环;同时又想将此时的网络图绘制在TensorBoard中。这个时候问题就出现了:TensorBoard在2.0以后的版本中的的网络图是默认在model.fit之中自动绘制的;

# 使用fit函数的时候会自动绘制网络计算图
model.fit(trrain_dataset, epoch=10, ......)

倘若想要自定义训练循环则又需要手动绘制网络图。

# 自定义寻来你循环的时候,TensorFlow不会帮助我们绘制网络计算图
for epooch in range(1, EPOCHS):
	SDG...
	LOSS...
	Record...

而网络上关于TensorFlow2.x绘制网络图的说明是少之又少,于是我决定写这篇博客来帮助大家来实现网络图的绘制。

如何在TensorFlow2.x中自己绘制网络图

直接给大家展示代码

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras.datasets import mnist
from tensorflow.python.ops import summary_ops_v2  # 需要引入这个模块

logs_dir='你的自定义的日志目录'

# 你创建的模型
class ClassModel(tf.keras.Model):
    def __init__(self, ...):
        super(ClassModel, self).__init__()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(self.num_classes, activation='softmax')
        ... # 其他操作
        
    @tf.function # 需要使用tf.function
    def call(self, inputs):
        inputs = self.d1(inputs)
        output = self.d2(inputs)
        return output

# inputs可以是符合你输入数据形状的输入数据
inputs=training_dataset  
model=ClassModel()

# 开始创建网络计算图
graph_writer = tf.summary.create_file_writer(logdir=logs_dir)
with graph_writer.as_default():
    graph=model.call.get_concrete_function(inputs).graph
    summary_ops_v2.graph(graph.as_graph_def())
graph_writer.close()

通过这个流程,就可以构建出你的网络模型图了。
在这个过程中,有几点注意事项

  1. from tensorflow.python.ops import summary_ops_v2 需要引入这个模块
  2. 自定义模型中的call需要使用tf.function注解标注
  3. inputs可以为任何符合网络输入形状的数据,比如我的网络输入为(None, 32, 32, 3),那么我就可以令inputs=tf.ones((64, 32, 32, 3)),也就是说可以使用该数据跑通这个模型即可
  4. 使用tf.summary的FileWriter来进行绘制

绘制结果可以在TensorBoard的URL之中查看:
如何在TensorFlow2.X中使用自定义训练循环的情况下在TensorBoard中绘制网络结构图(计算图)_第1张图片

总结

其实这也是笔者找了很多文档都没发现,然后自己研究出来的方法。希望可以帮到大家。如果大家有任何问题,可以添加笔者QQ进行讨论:1574143668.
请大家在学习与工作的过程中不要忘记互联网创立的初衷——分享。

你可能感兴趣的:(机器学习,Python,TensorFlow)