【Tensorflow2.1】

0 涉及全部代码和tf版本(2.1.0)

import tensorflow as tf
print(tf.__version__)

def fun(x1, x2):
    return 8 * x1 + 6 * x2

lr = 0.01
x1 = tf.Variable(1.0)
x2 = tf.Variable(1.0)
op = tf.keras.optimizers.SGD(0.01)

with tf.GradientTape() as t:
    t.watch([x1, x2])
    y = fun(x1, x2)

grads = t.gradient(y, [x1, x2])
op.apply_gradients([(grads[i],x) for i, x in enumerate([x1, x2])])
print([x1, x2])

1 变量求导数

import tensorflow as tf
print(tf.__version__)

def fun(x1, x2):
    return 8 * x1 + 6 * x2

x1 = tf.Variable(1.0)
x2 = tf.Variable(1.0)

with tf.GradientTape() as t:
    y = fun(x1, x2)

print(t.gradient(y,x1))
#这种情况会报错,gradient这个对象在on non-persistent tapes情况下只能被使用一次
print(t.gradient(y,x2)) 

【Tensorflow2.1】_第1张图片
修改代码后

import tensorflow as tf
print(tf.__version__)

def fun(x1, x2):
    return 8 * x1 + 6 * x2

x1 = tf.Variable(1.0)
x2 = tf.Variable(1.0)

with tf.GradientTape(persistent = True) as t:
    y = fun(x1, x2)

print(t.gradient(y,x1))
print(t.gradient(y,x2))

可得
在这里插入图片描述
【注意】使用了persistent = True后需要自己手动释放资源

del t

或者使用list形式对多个变量求导

import tensorflow as tf
print(tf.__version__)

def fun(x1, x2):
    return 8 * x1 + 6 * x2

x1 = tf.Variable(1.0)
x2 = tf.Variable(1.0)

with tf.GradientTape() as t:
    y = fun(x1, x2)

print(t.gradient(y,[x1,x2]))

得到的结果也是一个list
在这里插入图片描述

2 利用求到的导数更新参数

  • 使用assign函数赋值
import tensorflow as tf
print(tf.__version__)

def fun(x1, x2):
    return 8 * x1 + 6 * x2

lr = 0.01
x1 = tf.Variable(1.0)
x2 = tf.Variable(1.0)

with tf.GradientTape() as t:
    t.watch([x1, x2])
    y = fun(x1, x2)

grads = t.gradient(y, [x1, x2])
x1.assign_sub(lr * grads[0])
print(x1)

在这里插入图片描述

  • 使用keras优化器自动赋值
import tensorflow as tf
print(tf.__version__)

def fun(x1, x2):
    return 8 * x1 + 6 * x2

lr = 0.01
x1 = tf.Variable(1.0)
x2 = tf.Variable(1.0)
op = tf.keras.optimizers.SGD(0.01)

with tf.GradientTape() as t:
    t.watch([x1, x2])
    y = fun(x1, x2)

grads = t.gradient(y, [x1, x2])
op.apply_gradients([(grads[i],x) for i, x in enumerate([x1, x2])])
print([x1, x2])

在这里插入图片描述

你可能感兴趣的:(深度学习,tensorflow)