tensorflow:不要在session中定义运算

最近在做项目时,总是会有程序崩溃的问题,系统也没有任何提示。最后通过监控系统发现是内存溢出造成的。

追查下去,发现一段类似这样的代码,在session中调用tensorflow的api进行运算:

import tensorflow as tf
X = tf.constant([[1,2,3], [3,2,4]], dtype=tf.float32)
W = tf.constant([[1,1],[2,2],[3,3]], dtype=tf.float32)
bias = tf.constant([1, 2], dtype=tf.float32)
y = tf.nn.softmax(tf.matmul(X, W) + bias)

with tf.Session() as sess:

    for i in range(10):
        print(i)
        sess.run(tf.nn.softmax(tf.matmul(X, W) + bias))

    writer = tf.compat.v1.summary.FileWriter("./graph", sess.graph)
    writer.close()

使用tensorboard查看内存泄漏的原因:

tensorflow:不要在session中定义运算_第1张图片

将计算图展开为

tensorflow:不要在session中定义运算_第2张图片

当然,这里只是展开了softmax,其他节点也可以类似展开。

可以看到,在session中定义计算节点,存在一个很大的风险,就是会在计算图中产生新的图节点,如果像我这样使用for循环运算,那么节点数会无限增加,注意不仅仅是softmax节点在增加,其他计算节点也在增加,这样的开销会越来越大,直至程序崩溃。

为了解决这个问题,我们应该使用上面定义的y的等式,在进入session前就已经将计算图定义好,在session中直接调用,而不是重新搭建。

 

你可能感兴趣的:(tensorflow学习)