为了效率!!!TensorFLow2.0静态图与eager模式(动态图)

总结自此博客

为了效率!!!eager与静态图转换的那些坑

eager模式: 就是动态图
调用方法: tensorflow2.0默认eager模式
优点 Python写法,方便调试
缺点 比自动图速度慢
静 态 图 模 式 :
调用方法: 需要在函数上方加上@tf.function装饰
优点 tf写法,速度快
缺点 不如python写法直观
转换方式 在头上加@tf.function

第一个大坑

#调用静态图的例子
@tf.function
def f():
    a = tf.constant([[10,10],[11.,1.]])
    x = tf.constant([[1.,0.],[0.,1.]])
    b = tf.Variable(12.)
    y = tf.matmul(a, x) + b
    print("PRINT: ", y)
    tf.print("TF-PRINT: ", y)
    return y

f()
#此时执行会报错

这为啥报错?
当然可能是因为你写错了
答:为啥?因为tf.Variable在计算图中是一个持续存在的节点,不受作用域的影响,函数结束没销毁啊
解决的一般的方法是把tf.Variable当做类属性来调用

#此时就可以行的通了
class F():
    def __init__(self):
        self._b = None

    @tf.function
    def __call__(self):
        a = tf.constant([[10, 10], [11., 1.]])
        x = tf.constant([[1., 0.], [0., 1.]])
        if self._b is None:
            self._b = tf.Variable(12.)
        y = tf.matmul(a, x) + self._b
        print("PRINT: ", y)
        tf.print("TF-PRINT: ", y)
        return y

f = F()
f()


#或者
@tf.function
def f(b):
    a = tf.constant([[10,10],[11.,1.]])
    x = tf.constant([[1.,0.],[0.,1.]])
    y = tf.matmul(a, x) + b
    print("PRINT: ", y)
    tf.print("TF-PRINT: ", y)
    return y

b = tf.Variable(12.)
f(b)
其余大坑
坑1 静态图尽量使用tf.Tensor做参数,tensorflow会根据Python原生数据类型的值不同(这里不受类型影响,浮点数1.0等于int型1),而重复创建图,导致适得其反,速度反而变慢

比如

@tf.function
def g(x):
  return x

start = time.time()
for i in tf.range(1000):
  g(i)
end = time.time()

print("tf.Tensor time elapsed: ", (end-start))

start = time.time()
for i in range(1000):
  g(i)
end = time.time()

print("Native type time elapsed: ", (end-start))


#结果为
tf.Tensor time elapsed:  0.41594886779785156
Native type time elapsed:  5.189513444900513
#静态图模式速度反而变慢了,所以注意,静态图使用tf.range,而不是python原生的range
其余大坑
坑2 静态图tf.Tensor比值比的是tensor的hash值,而不是原本的值
def outPy(x,y):
    if x<y:
        print(1)
@tf.function
def outTf(x,y):
    if x<y:
        print(1)

为了效率!!!TensorFLow2.0静态图与eager模式(动态图)_第1张图片

所以,到底应该如何写函数啊?
答:当然是使用tf自带的数学方法

tf自带数学比大小函数
大于 tf.math.greater(a, b)
等于 tf.math.equal(a, b)
小于 tf.math.less(a, b)

你可能感兴趣的:(tensorflow)