python 生成动图


主要利用和matplotlibimageio这两个python库。由于matplotlib无法直接得到所绘图的RGB值,所以每次画完一帧图后,保存下来再读取得到每一帧的RGB值,最后使用imageio将所有的帧连接起来组合成一个动图。这种方法是很多生成动图的方法中较为简单的一种,但是因为每次都要保存和读取图片,所以会增加一定的程序耗时。


下面是使用pytorch写的一个简单的线性回归的例子:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import imageio

torch.manual_seed(0)

num_samples = 100

x_train = torch.linspace(0, 1, num_samples)
y_train = 0.1 * x_train + 0.2 + torch.randn(num_samples)*0.03

w = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=True)

criterion = nn.MSELoss()
optimizer = torch.optim.SGD([w,b], lr=0.01)

images = []
num_epochs = 4000
for epoch in range(num_epochs):
    
    y_pred = w * x_train + b
    loss = criterion(y_pred, y_train)
    
    optimizer.zero_grad()
    loss.backward()

    if epoch % 100 == 99:
        plt.figure()
        plt.ylim(torch.min(y_train).item(), torch.max(y_train).item())
        plt.scatter(x_train.tolist(), y_train.tolist(), marker='.')
        plt.plot(x_train.tolist(), y_pred.tolist(), color='r', linewidth=2)
        plt.title('Epoch [{}/{}], Loss: {:.6f}, \n Weight: {:.6f}, Bias: {:.6f}'
                  .format(epoch+1, num_epochs, loss.item(), w.item(), b.item()))
        plt.savefig('a.png')
        plt.close()
        
        images.append(imageio.imread('a.png'))
        
    optimizer.step()
    
imageio.mimsave('gen.gif', images, duration=0.5)
复制代码

最后生成的动图如下:

转载于:https://juejin.im/post/5bc056ee6fb9a05d330ae0cf

你可能感兴趣的:(python,人工智能,数据结构与算法)