Tensorflow2.0中function(是1.0版本的Graph的推荐替代)的相关知识介绍

在 Tensorflow无人车使用移动端的SSD(单发多框检测)来识别物体及Graph的认识 中我们对Graph这个计算图有了一定的了解,也知道了它具备的优点:性能做了提升,可以并行处理以及由于它是一种数据结构,可以在非Python环境中进行交互。

我们先来看下自己的tensorflow的版本: 

print(tf.__version__) # 2.11.0 

目前基本上都是2.0以上,不过这个Session的用法在tensorflow2.0版本之后就没有了,所以大家在上一篇文章看到的是我使用的兼容1.0版本的用法:tf.compat.v1.Session(graph=g1)
如果是直接去调用的话:tf.compat.v1.Session(graph=g1) 就会报下面这样的错误:

AttributeError: module 'tensorflow' has no attribute 'Session'

于是到了2.0版本之后,我们使用function来代替Graph! 

1、tf.function声明

 我们具体来看下,2.0及以上版本中的tf.function是如何使用的。

import tensorflow as tf

# 常规函数
def f(x, w, b):
    y = tf.matmul(x, w) # 矩阵乘法就是dot点积运算
    return y + b

tf_f = tf.function(f)

print(type(f),type(tf_f))
# 
# 

 可以看到这里的tf.function和平时定义的def的这个函数类型是不一样,def的类型就是function,而tf.function(函数名)得到的类型是
tensorflow.python.eager.polymorphic_function.polymorphic_function.Function

eager:渴望的,急切的,这里就是一种即时模型的意思,polymorphic:意思来看是,多形态的,多态的。 

函数定义好了之后,我们来看下是如何调用并计算的

c1 = tf.constant([[1.0, 2.0]])
c2 = tf.constant([[2.0], [3.0]])
c3 = tf.constant(4.0)

f_value = f(c1, c2, c3).numpy()
tf_value = tf_f(c1, c2, c3).numpy()
print(f_value,tf_value)#[[12.]] [[12.]]

得到的结果是一样的,那我们引入这个tf.function的作用是什么呢?接着往下看

2、@tf.function装饰器

上面我们可以看到 tf.function 的类型,虽然也是函数,但跟常规函数还是有很大区别,因为我们的目的是能够代替Graph,而使用计算图的目的又是为了性能的提升,所以应该知道这个函数所要表达的意思了吧,在这里我们可以使用 @tf.function 装饰器,就可以让这种即时执行模式的控制流转换成计算图的方式了。

其中matmulMatrixMultiple的缩写,矩阵乘法的意思,也就是在numpy中的dot点积运算的用法(行乘以列的和)
实际上,这个tf.function可能封装多个tf.graph,所以这两种不同的函数表达,在性能和部署上存在很大的不同。

import tensorflow as tf

def inner_function(x, w, b):
    x = tf.matmul(x, w)
    return x + b

# 使用装饰器来定义函数
@tf.function
def outer_function(x):
    w = tf.constant([[2.0], [3.0]])
    b = tf.constant(4.0)
    return inner_function(x, w, b)

# 创建一个Graph计算图,里面包含inner_function和outer_function
r1 = outer_function(tf.constant([[1.0, 2.0]])).numpy()
r2 = outer_function(tf.constant([[1.0, 2.0],[3.0, 4.0],[5.0, 6.0]])).numpy()
print(r1)
print(r2)

'''
[[12.]]

[[12.]
 [22.]
 [32.]]
'''

这里使用了一个@tf.function装饰器来声明这个函数为多态函数,我们来打印看下它的具体特征:

print(outer_function.pretty_printed_concrete_signatures())
'''
outer_function(x)
  Args:
    x: float32 Tensor, shape=(1, 2)
  Returns:
    float32 Tensor, shape=(1, 1)

outer_function(x)
  Args:
    x: float32 Tensor, shape=(3, 2)
  Returns:
    float32 Tensor, shape=(3, 1)
'''

我使用了两种形状的输入,这里也对应出现两种形式的计算图。这种多态的作用是可以用来提升性能,因为可以判断输入的类型(以及形状),如果是一样的形状,同类型的就不需要新建计算图,我们接着来看下

r3 = outer_function(tf.constant([[11.0, 22.0],[3.0, 4.0],[5.0, 6.0]])).numpy()
print(outer_function.pretty_printed_concrete_signatures())

'''
outer_function(x)
  Args:
    x: float32 Tensor, shape=(1, 2)
  Returns:
    float32 Tensor, shape=(1, 1)

outer_function(x)
  Args:
    x: float32 Tensor, shape=(3, 2)
  Returns:
    float32 Tensor, shape=(3, 1)
'''

可以看到结果是一样的,对于这样的输入,因为r3跟r2的类型形状是一样的,所以r3可以使用r2的,那么从另一角度可以理解为缓存,当数据类型或形状不一致的时候才会创建新的计算图。

r4 = outer_function([[11.0, 22.0],[3.0, 4.0],[5.0, 6.0]]).numpy()
print(outer_function.pretty_printed_concrete_signatures())

'''
outer_function(x)
  Args:
    x: float32 Tensor, shape=(1, 2)
  Returns:
    float32 Tensor, shape=(1, 1)

outer_function(x)
  Args:
    x: float32 Tensor, shape=(3, 2)
  Returns:
    float32 Tensor, shape=(3, 1)

outer_function(x=[[11.0, 22.0], [3.0, 4.0], [5.0, 6.0]])
  Returns:
    float32 Tensor, shape=(3, 1)
'''

这里的r4虽然输出是跟r3一样,不过这里的输入类型不一致,所以还是会新建一个。

3、tf.autograph

现在我们又回到最开始的tf.function,它的本质其实是对原函数做了转换,函数体做了新的变化处理。依然是上面的示例,我们查看下它的本质:

import tensorflow as tf

# 常规函数
def f(x, w, b):
    y = tf.matmul(x, w) # 矩阵乘法就是dot点积运算
    return y + b

tf_f = tf.function(f)
w = tf.constant([[2.0], [3.0]])
b = tf.constant(4.0)
print(tf_f(tf.constant([[1.0, 2.0]]),w,b).numpy()) # [[12.]]
print(tf.autograph.to_code(f))
'''
def tf__f(x, w, b):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()
        y = ag__.converted_call(ag__.ld(tf).matmul, (ag__.ld(x), ag__.ld(w)), None, fscope)
        try:
            do_return = True
            retval_ = (ag__.ld(y) + ag__.ld(b))
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)
'''

可以看到这里是将原函数 转成 tf__f 函数,其函数体是做了另外的处理,里面结构也是很类似的。我们打印 tf__f 这个graph计算图的详情看下:

print(tf_f.get_concrete_function(tf.constant([[1.0, 2.0]]),w,b).graph.as_graph_def())

'''
node {
  name: "x"
  op: "Placeholder"
  attr {
    key: "_user_specified_name"
    value {
      s: "x"
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 1
        }
        dim {
          size: 2
        }
      }
    }
  }
}
node {
  name: "w"
  op: "Placeholder"
  attr {
    key: "_user_specified_name"
    value {
      s: "w"
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 2
        }
        dim {
          size: 1
        }
      }
    }
  }
}
node {
  name: "b"
  op: "Placeholder"
  attr {
    key: "_user_specified_name"
    value {
      s: "b"
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
}
node {
  name: "MatMul"
  op: "MatMul"
  input: "x"
  input: "w"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "transpose_a"
    value {
      b: false
    }
  }
  attr {
    key: "transpose_b"
    value {
      b: false
    }
  }
}
node {
  name: "add"
  op: "AddV2"
  input: "MatMul"
  input: "b"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Identity"
  op: "Identity"
  input: "add"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
versions {
  producer: 1286
}
'''

可以看到每个节点里面是名称、属性、类型、操作等详细的说明,这也印证了最开始说的这个Graph计算图是一种数据结构,数据类型是

4、追踪

graph里的Tracing也是其一个特性,这里的print不在追踪范围内,所以虽然调用了三次,结果只输出一次!

import tensorflow as tf
#tf.config.run_functions_eagerly(False)
@tf.function
def mse(y_true, y_pred):
    print("计算均方误差")
    tf.print("均方误差")
    sq_diff = tf.pow(y_true - y_pred, 2)
    return tf.reduce_mean(sq_diff)
  
tf.config.run_functions_eagerly(False)

y_true = tf.random.uniform([5], maxval=10, dtype=tf.int32)
y_pred = tf.random.uniform([5], maxval=10, dtype=tf.int32)

print(y_true,y_pred)
mse_r = mse(y_true, y_pred)
mse_r = mse(y_true, y_pred)
mse_r = mse(y_true, y_pred)

print(mse_r.numpy())
'''
tf.Tensor([8 0 6 3 9], shape=(5,), dtype=int32) tf.Tensor([9 5 9 3 9], shape=(5,), dtype=int32)
计算均方误差
均方误差
均方误差
均方误差
7
'''

其中的tf.print是可以追踪的,所以每次的调用都会输出。

5、非严格执行

tf.graph计算图只关心需要的操作,其余的不会关心,就算是错误的情况也不处理。

def t(x):
    tf.gather(x, [3])
    return x

try:
    print(t(tf.constant([0.0])))
except tf.errors.InvalidArgumentError as e:
    print(f'{type(e).__name__}: {e}')

gather用法: 

help(tf.gather)
gather_v2(params, indices, validate_indices=None, axis=None, batch_dims=0, name=None)


params = tf.constant(['p0', 'p1', 'p2', 'p3', 'p4', 'p5'])
params[3].numpy() # b'p3'
indices = [2, 0, 2, 5]
tf.gather(params, indices).numpy() #array([b'p2', b'p0', b'p2', b'p5'], dtype=object)

这里的tf.gather调用,我们可以很明显知道,在输入是[0]的情况,[3]索引的数据是不存在的,所以会报错:

InvalidArgumentError: {{function_node __wrapped__GatherV2_device_/job:localhost/replica:0/task:0/device:CPU:0}} indices[0] = 3 is not in [0, 1) [Op:GatherV2]

而当我们使用@tf.function装饰器来装饰这个函数的时候,会发现即便有错误也不会执行。 

@tf.function
def t(x):
    tf.gather(x, [3])
    return x

print(t(tf.constant([0.0])))#tf.Tensor([0.], shape=(1,), dtype=float32)

这也再次说明了计算图只关心流程图,里面的具体计算不会去验证。其中tf.gather的用法就是在指定维度抽取数据,这个用法在很多情况使用起来特别有用,我们再来看一个示例:

import tensorflow as tf
a = tf.range(8)
a = tf.reshape(a, [4,2])
print(a)
print(tf.gather(a, [3,1,0], axis=0))
'''
tf.Tensor(
[[0 1]
 [2 3]
 [4 5]
 [6 7]], shape=(4, 2), dtype=int32)
tf.Tensor(
[[6 7]
 [2 3]
 [0 1]], shape=(3, 2), dtype=int32)
'''

你可能感兴趣的:(Python,tf.function,function装饰器,tf.autograph,as_graph_def,tf.gather)