参考链接:https://blog.csdn.net/sinat_36256646/article/details/81002809
https://zhuanlan.zhihu.com/p/31323002
pyplot里面自带interactive mode可以用来画动态图。在这中模式下,pyplot.plot()可以立马在画布上画出图像,而不需要pyplot.show()函数。同时,它还有清除画布上图像的方法pyplot.clf。这样我们就可以通过“画图--清除--画图”这个循环的过程,来体现动态的过程。特别要注意的是,在每一次循环后要加上pyplot.pause()函数,否则最后出来的只有一张图。大家可以参考第一个链接。
下面我们通过一个tensorflow中的线性回归?来看看:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
## parameters
learning_rate = 0.8
training_epochs = 2000
display_step = 50
## create data
train_X = np.random.rand(100).astype(np.float32)
train_X.sort()
train_Y = (train_X-0.5)**2 + 0.3
n_samples = train_X.shape[0]
X = tf.placeholder("float")
Y = tf.placeholder("float")
w = tf.Variable(np.random.randn(), name='weight')
b = tf.Variable(np.random.randn(), name='bias')
pred = tf.add(tf.multiply((X-0.5)**2,w), b)
cost = tf.reduce_sum(tf.pow(pred-Y, 2))/(2*n_samples)
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
plt.figure(num=3)
for epoch in range(training_epochs):
for (x, y) in zip(train_X, train_Y):
sess.run(optimizer, feed_dict={X:x, Y:y})
if (epoch+1) % display_step == 0:
# plt.figure(num=3)
plt.ion()
plt.cla()
c = sess.run(cost, feed_dict={X:train_X, Y:train_Y})
print("Epoch:", '%04d'%(epoch+1), "cost=", "{:.9f}".format(c),\
"W=", sess.run(w), 'b=', sess.run(b))
### draw the picture
plt.plot(train_X, train_Y, 'ro', label='Orginal data')
plt.plot(train_X, sess.run(w)*(train_X-0.5)**2+sess.run(b), label='Fitted line')
plt.legend()
plt.pause(0.05)
plt.ioff()
# plt.show()
在命令行的模式下运行这段代码就可以出现动态拟合的图像。但是,这个在jupyter notebool中运行出来的结果是很多张图片,并不是在一张图片中。这个时候就需要导入display模块,并添加%matplotlib inline。
下面是修改了以后的代码,添加的部分都备注了新添加:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
## 新添加
from IPython import display
% matplotlib inline
## parameters
learning_rate = 0.8
training_epochs = 2000
display_step = 50
## create data
train_X = np.random.rand(100).astype(np.float32)
train_X.sort()
train_Y = (train_X-0.5)**2 + 0.3
n_samples = train_X.shape[0]
X = tf.placeholder("float")
Y = tf.placeholder("float")
w = tf.Variable(np.random.randn(), name='weight')
b = tf.Variable(np.random.randn(), name='bias')
pred = tf.add(tf.multiply((X-0.5)**2,w), b)
cost = tf.reduce_sum(tf.pow(pred-Y, 2))/(2*n_samples)
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
plt.figure(num=3)
for epoch in range(training_epochs):
for (x, y) in zip(train_X, train_Y):
sess.run(optimizer, feed_dict={X:x, Y:y})
if (epoch+1) % display_step == 0:
# plt.figure(num=3)
plt.ion()
plt.cla()
c = sess.run(cost, feed_dict={X:train_X, Y:train_Y})
print("Epoch:", '%04d'%(epoch+1), "cost=", "{:.9f}".format(c),\
"W=", sess.run(w), 'b=', sess.run(b))
### draw the picture
plt.plot(train_X, train_Y, 'ro', label='Orginal data')
plt.plot(train_X, sess.run(w)*(train_X-0.5)**2+sess.run(b), label='Fitted line')
plt.legend()
plt.pause(0.05)
#### 新添加
display.clear_output(wait=True)
plt.ioff()
# plt.show()
最后运行的结果就是下面蓝色的线一直在变化来拟合红色的点。等搞清楚如何生成gif图,再上传动态图片。
2018.12.31 最后☝️个,NICE~