matplotlib动态绘制训练进度【训练精度、训练损失、测试精度】

文章目录

  • 一、注意事项
  • 二、实现


一、注意事项

  1. plt.legend() 是展示图例的,就是 ax.plot() 里面的label值,plt.legend() 必须在 ax.plot() 之后调用才生效,并且因为 plt.legend() 每调用一次就会生成一次图例,所以仅在首次绘制时调用即可。
  2. plt.pause(0.001) 不可省略
  3. 绘制时 x、y 数组的长度必须相等且值一一对应,长度可逐步增加,但是不要出现 None 值,否则绘制的线条会出现断点

二、实现

import matplotlib.pyplot as plt
import time
from matplotlib_inline import backend_inline


class TrainVision(object):
    def __init__(self):

        # svg模式
        backend_inline.set_matplotlib_formats('svg')

        # 用于显示正常中文标签
        plt.rcParams['font.sans-serif'] = ['SimHei']

        # 在 1*1 的画布 fig 上创建图纸 ax
        self.fig, self.ax = plt.subplots(1, 1, figsize=(3.5, 2.5))
        # 展示网格线
        self.ax.grid()
        # x轴标签
        self.ax.set_xlabel('epochs')
        # y轴标签
        self.ax.set_ylabel('acc & loss')
        # epoch、训练精度、训练损失、测试精度 的累加数组
        # y是精度或者损失值,x是y对应的epoch
        self.train_acc = {'x': [], 'y': []}
        self.train_loss = {'x': [], 'y': []}
        self.test_acc = {'x': [], 'y': []}
        # 是否加载过图例的标记
        self.is_init_legend = False

    def draw(self, epoch_x, train_acc_y1, train_loss_y2, test_acc_y3):

        # 加入位置信息数组
        if epoch_x:
            if train_acc_y1:	# 训练精度 为 None 时不加入数组
                self.train_acc['x'].append(epoch_x)
                self.train_acc['y'].append(train_acc_y1)
            if train_loss_y2:	# 训练损失 为 None 时不加入数组
                self.train_loss['x'].append(epoch_x)
                self.train_loss['y'].append(train_loss_y2)
            if test_acc_y3:		# 测试精度 为 None 时不加入数组
                self.test_acc['x'].append(epoch_x)
                self.test_acc['y'].append(test_acc_y3)

        # 绘制
        self.ax.plot(self.train_acc['x'], self.train_acc['y'], color='blue', label='train loss')
        self.ax.plot(self.train_loss['x'], self.train_loss['y'], color='red', label='train acc')
        self.ax.plot(self.test_acc['x'], self.test_acc['y'], color='green', label='test acc')

        # 图例仅在首次加载时创建
        if not self.is_init_legend:
            plt.legend()
            self.is_init_legend = True

        plt.draw()
        plt.pause(0.001)


if __name__ == '__main__':

    # 初始化
    vt = TrainVision()

    # 随便写的 训练精度、训练损失 和 测试精度
    x = [a for a in range(50)]
    train_acc_list = [a**2 for a in x]
    train_loss_list = [a**2+a*2 for a in x]
    test_acc_list = [a**2+a*4 for a in x]

    # 模拟训练
    for epochs, (train_acc, train_loss, test_acc) in enumerate(zip(train_acc_list, train_loss_list, test_acc_list)):

        print(epochs)

        # 训练耗时
        time.sleep(0.3)

        # 每个 epoch 绘制 训练精度 和 训练损失
        vt.draw(epochs, train_acc, train_loss, None)

        # 每5个 epoch 绘制 测试精度
        if (epochs+1) % 5 == 0:
            vt.draw(epochs, None, None, test_acc)

跑一下康康

你可能感兴趣的:(Python,NLP,CV,matplotlib,python,深度学习)