Reposted from: http://www.soaringroad.com/?p=194
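tf.assign() is very easy to misunderstand. Unless you have a thorough grasp of TensorFlow's graph and op concepts, it is easy to end up with wrong results. Let's look at the source code first: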
def assign(ref, value, validate_shape=None, use_locking=None, name=None):
    """Update 'ref' by assigning 'value' to it.

    This operation outputs a Tensor that holds the new value of 'ref' after
    the value has been assigned. This makes it easier to chain operations
    that need to use the reset value.

    Args:
      ref: A mutable `Tensor`.
        Should be from a `Variable` node. May be uninitialized.
      value: A `Tensor`. Must have the same type as `ref`.
        The value to be assigned to the variable.
      validate_shape: An optional `bool`. Defaults to `True`.
        If true, the operation will validate that the shape
        of 'value' matches the shape of the Tensor being assigned to. If false,
        'ref' will take on the shape of 'value'.
      use_locking: An optional `bool`. Defaults to `True`.
        If True, the assignment will be protected by a lock;
        otherwise the behavior is undefined, but may exhibit less contention.
      name: A name for the operation (optional).

    Returns:
      A `Tensor` that will hold the new value of 'ref' after
      the assignment has completed.
    """
    if ref.dtype._is_ref_dtype:
        return gen_state_ops.assign(
            ref, value, use_locking=use_locking, name=name,
            validate_shape=validate_shape)
    return ref.assign(value)
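Note the description of the return value: A `Tensor` that will hold the new value of 'ref' after the assignment has completed. In other words, the new value is only produced once the assign op has actually been executed. The two examples below make this clear: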
# -*- coding: utf-8 -*-
# Written against the TensorFlow 1.x API (tf.Session, tf.assign).
import tensorflow as tf


def test_1():
    a = tf.Variable([10, 20])
    b = tf.assign(a, [20, 30])
    c = a + [10, 20]
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print("test_1 run a : ", sess.run(a))  # => [10 20]
        print("test_1 run c : ", sess.run(c))  # => [10 20] + [10 20] = [20 40]; b has not been run, so a is still [10 20]
        print("test_1 run b : ", sess.run(b))  # => [20 30]; running b assigns to a (ref: a)
        print("test_1 run a again : ", sess.run(a))  # => [20 30]; b has been run, so a is now [20 30]
        print("test_1 run c again : ", sess.run(c))  # => [20 30] + [10 20] = [30 50]; b has been run, so a is [20 30]


def test_2():
    a = tf.Variable([10, 20])
    b = tf.assign(a, [20, 30])
    c = b + [10, 20]
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(a))  # => [10 20]
        print(sess.run(c))  # => [30 50]; c depends on b, so running c also runs b
        print(sess.run(a))  # => [20 30]


def main():
    test_1()
    test_2()


if __name__ == '__main__':
    main()
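If you fully understand the two tests above, you really do understand how assign works:
As long as the assign op has not been executed, the value of ref is not updated.
The same holds for assign_add and assign_sub:
assign_add adds a value to x and assigns the result back to x, e.g. x = x + 1 (x += 1)
assign_sub subtracts a value from x and assigns the result back to x, e.g. x = x - 1 (x -= 1)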
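To make the "nothing happens until the op is run" rule easier to see without a TensorFlow installation, here is a minimal pure-Python sketch of the same idea. It is only a toy model and not TensorFlow code: the classes `Variable`, `Assign`, `AssignAdd` and `Add`, their `run()` methods, and the `demo()` function are all hypothetical names made up for this illustration. It mirrors test_1 and test_2 roughly: an op that depends only on the variable sees the old value, while an op that depends on the assign op triggers the assignment first.

```python
# Toy model of deferred execution. This is NOT TensorFlow code; every class
# and function here is hypothetical and exists only for illustration.


class Variable:
    """A mutable cell holding a list of numbers."""

    def __init__(self, value):
        self.value = list(value)

    def run(self):
        # Reading a variable has no side effects.
        return list(self.value)


class Assign:
    """An op that overwrites `ref`, but only when it is run."""

    def __init__(self, ref, value):
        self.ref = ref
        self.new_value = list(value)

    def run(self):
        self.ref.value = list(self.new_value)
        return list(self.ref.value)


class AssignAdd:
    """An op that adds `delta` to `ref` in place each time it is run."""

    def __init__(self, ref, delta):
        self.ref = ref
        self.delta = list(delta)

    def run(self):
        self.ref.value = [x + d for x, d in zip(self.ref.value, self.delta)]
        return list(self.ref.value)


class Add:
    """An op that runs its input node, then adds a constant elementwise."""

    def __init__(self, node, const):
        self.node = node
        self.const = list(const)

    def run(self):
        return [x + c for x, c in zip(self.node.run(), self.const)]


def demo():
    a = Variable([10, 20])
    b = Assign(a, [20, 30])      # building the op does not touch `a`
    c_from_a = Add(a, [10, 20])  # depends on the variable only
    c_from_b = Add(b, [10, 20])  # depends on the assign op
    inc = AssignAdd(a, [1, 1])   # also deferred until run

    results = []
    results.append(a.run())         # [10, 20]: nothing has been run yet
    results.append(c_from_a.run())  # [20, 40]: `b` was never run
    results.append(c_from_b.run())  # [30, 50]: running this also runs `b`
    results.append(a.run())         # [20, 30]: the assignment has happened
    results.append(inc.run())       # [21, 31]
    results.append(inc.run())       # [22, 32]: each run adds again
    return results


if __name__ == "__main__":
    for row in demo():
        print(row)
```

The last two results show one more consequence of the same rule: an update op such as the hypothetical `AssignAdd` takes effect once per run, so running it twice applies the increment twice.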