TensorFlow: tf.assign() explained

Reposted from: http://www.soaringroad.com/?p=194

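tf.assign() is very easy to misunderstand. Unless you have a thorough grasp of TensorFlow's graph and op concepts, it is easy to end up computing the wrong result. Let's start with the source code: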

def assign(ref, value, validate_shape=None, use_locking=None, name=None):
  """Update 'ref' by assigning 'value' to it.

  This operation outputs a Tensor that holds the new value of 'ref' after
    the value has been assigned. This makes it easier to chain operations
    that need to use the reset value.

  Args:
    ref: A mutable `Tensor`.
      Should be from a `Variable` node. May be uninitialized.
    value: A `Tensor`. Must have the same type as `ref`.
      The value to be assigned to the variable.
    validate_shape: An optional `bool`. Defaults to `True`.
      If true, the operation will validate that the shape
      of 'value' matches the shape of the Tensor being assigned to.  If false,
      'ref' will take on the shape of 'value'.
    use_locking: An optional `bool`. Defaults to `True`.
      If True, the assignment will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` that will hold the new value of 'ref' after
      the assignment has completed.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign(
        ref, value, use_locking=use_locking, name=name,
        validate_shape=validate_shape)
  return ref.assign(value)

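Note the return value: "A `Tensor` that will hold the new value of 'ref' after the assignment has completed." In other words, the new value is written and returned only when the assign op is actually run. The two examples below make this clear: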

# -*- coding: utf-8 -*-
import tensorflow as tf


def test_1():
    a = tf.Variable([10, 20])
    b = tf.assign(a, [20, 30])
    c = a + [10, 20]
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print("test_1 run a : ", sess.run(a)) # => [10 20]
        print("test_1 run c : ", sess.run(c)) # => [10 20] + [10 20] = [20 40]; b has not been run, so a is still [10 20]
        print("test_1 run b : ", sess.run(b)) # => [20 30]; running b performs the assignment to a
        print("test_1 run a again : ", sess.run(a)) # => [20 30]; b has been run, so a is now [20 30]
        print("test_1 run c again : ", sess.run(c)) # => [20 30] + [10 20] = [30 50]; a is now [20 30], so c is [30 50]


def test_2():
    a = tf.Variable([10, 20])
    b = tf.assign(a, [20, 30])
    c = b + [10, 20]
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(a)) # => [10 20] 
        print(sess.run(c)) # => [30 50]; c depends on b, so running c also runs b (the assign)
        print(sess.run(a)) # => [20 30]


def main():
    test_1()
    test_2()


if __name__ == '__main__':
    main()

Once both tests above make sense to you, you really understand how assign works.
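The behaviour in both tests can be imitated without TensorFlow. The sketch below is a toy model of deferred execution, not TensorFlow's actual implementation. The names `Variable`, `Assign`, `Add`, `run`, `demo_test_1` and `demo_test_2` are made up for illustration. It assumes only that running a node also runs every node it depends on, which is what the two tests demonstrate.

```python
# Toy model of deferred (graph-style) execution; NOT TensorFlow's real code.
# Every class and function name here is hypothetical.


class Variable:
    """Holds mutable state; evaluating it just reads the current value."""

    def __init__(self, value):
        self.value = list(value)

    def evaluate(self):
        return list(self.value)


class Assign:
    """Writes new_value into ref, but only when the node is evaluated."""

    def __init__(self, ref, new_value):
        self.ref = ref
        self.new_value = list(new_value)

    def evaluate(self):
        self.ref.value = list(self.new_value)
        return list(self.ref.value)


class Add:
    """Element-wise sum of an upstream node and a constant list."""

    def __init__(self, node, constant):
        self.node = node
        self.constant = list(constant)

    def evaluate(self):
        # Evaluating Add first evaluates its input, side effects included.
        upstream = self.node.evaluate()
        return [x + y for x, y in zip(upstream, self.constant)]


def run(node):
    """Stand-in for sess.run(): evaluate a node and what it depends on."""
    return node.evaluate()


def demo_test_1():
    a = Variable([10, 20])
    b = Assign(a, [20, 30])  # building b does not touch a
    c = Add(a, [10, 20])     # c depends on a only, never on b
    return [run(a), run(c), run(b), run(a), run(c)]


def demo_test_2():
    a = Variable([10, 20])
    b = Assign(a, [20, 30])
    c = Add(b, [10, 20])     # c depends on b, so running c runs the assign
    return [run(a), run(c), run(a)]


if __name__ == "__main__":
    print(demo_test_1())
    print(demo_test_2())
```

In `demo_test_1` the assignment happens only at the explicit `run(b)`. In `demo_test_2` it happens as a side effect of `run(c)`, because `c` was built from `b`.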

Summary

If the assign op is not run, the value of ref is not updated.

The same is true of assign_add and assign_sub.

assign_add (adds, then assigns the result to x, e.g. x = x + 1 or x += 1)

assign_sub (subtracts, then assigns the result to x, e.g. x = x - 1 or x -= 1)
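The same point can be sketched for the add and subtract variants. This is again a toy model with hypothetical names (`Variable`, `AssignAdd`, `AssignSub`, `demo`), not TensorFlow's API. It illustrates two things: building the op changes nothing, and each run of the op applies the update once more.

```python
# Toy model only; class and function names are hypothetical, not TensorFlow's.


class Variable:
    """Plain container for a mutable number."""

    def __init__(self, value):
        self.value = value


class AssignAdd:
    """Deferred x += delta: nothing happens until evaluate() is called."""

    def __init__(self, ref, delta):
        self.ref = ref
        self.delta = delta

    def evaluate(self):
        self.ref.value += self.delta
        return self.ref.value


class AssignSub:
    """Deferred x -= delta: nothing happens until evaluate() is called."""

    def __init__(self, ref, delta):
        self.ref = ref
        self.delta = delta

    def evaluate(self):
        self.ref.value -= self.delta
        return self.ref.value


def demo():
    x = Variable(5)
    inc = AssignAdd(x, 1)   # building the op leaves x at 5
    dec = AssignSub(x, 2)   # likewise, x is still 5
    history = [x.value]               # value before any op has run
    history.append(inc.evaluate())    # first run of inc
    history.append(inc.evaluate())    # second run applies the update again
    history.append(dec.evaluate())    # one run of dec
    return history


if __name__ == "__main__":
    print(demo())  # [5, 6, 7, 5]
```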
