【Matplotlib】在Jupyter交互页面中绘制折线图对比(自用函数)

0x00 前言

最近数据对比的任务比较常见,比如好些模型的横向对比,
对于 Loss、PRF、Hits 之类的数据,有时需要取对数(log),有时需要去掉前面几个值,
还要考虑数据不对齐、记录文件格式不一致等诸多问题,总之首要需求是鲁棒性。
于是简单写了几个画折线图的函数,代码暂时还比较乱,后续有时间再优化,
现在先记录在这里,便于在不同机器上获取使用,也方便后续优化更新。

0x01 用法

# single
draw(y, y_label='Data', title='undefined title', semilog=False)

# multiple
drawx(y_list, label_list=None, title='undefined title', 
          start=0, end=None, x_list=None, x_label=None, semilog=False, axh=0)

0x02 Source Code


import matplotlib.pyplot as plt
from pylab import *  # 支持中文
import warnings
# Silence warnings so notebook output stays readable (presumably aimed at
# matplotlib font/glyph warnings). NOTE: the filter is process-wide and also
# hides warnings raised by every other library.
warnings.filterwarnings('ignore')
# Use SimHei so Chinese labels and titles render instead of empty boxes.
mpl.rcParams['font.sans-serif'] = ['SimHei']
# SimHei lacks a glyph for the Unicode minus sign (U+2212) that matplotlib
# uses by default; switch to ASCII '-' so negative tick labels (e.g. the
# exponents on semilog axes) are not drawn as boxes.
mpl.rcParams['axes.unicode_minus'] = False

# Handy one-offs when tweaking a figure by hand:
#   plt.xlim(-1, 11)   # pin the x-axis range
#   plt.ylim(-1, 110)  # pin the y-axis range

def focus(filename, label='[Valid_F1_G]', pos=1, delim='\t', cast=None):
    r"""Pull one field out of every log line that starts with ``label``.

    Meant for logs like '[Valid_F1_G]\t0.667\t0.667\t0.667'.

    Args:
        filename: Path of the log file to scan.
        label: Prefix that selects the lines of interest.
        pos: Index of the field to keep after splitting a line on ``delim``
            (index 0 is the part that contains the label).
        delim: Field separator.
        cast: Optional callable (e.g. ``float``) applied to every extracted
            field; by default the stripped strings are returned unchanged.

    Returns:
        A list with one entry per matching line, in file order.

    Raises:
        IndexError: If a matching line has no field at ``pos``.
    """
    # Read eagerly inside ``with`` so the file is closed before returning.
    # A lazy filter/map pipeline would keep the handle open, and on Python 3
    # it yields a one-shot iterator that the plotting helpers cannot len().
    with open(filename, 'r') as handle:
        fields = [line.split(delim)[pos].strip()
                  for line in handle if line.startswith(label)]
    if cast is not None:
        fields = [cast(field) for field in fields]
    return fields


def draw(y, y_label='Data', title='undefined title', semilog=False):
    """Plot one series against 1-based epoch numbers and display it.

    Args:
        y: Sized sequence of values, one per epoch.
        y_label: Legend entry for the series.
        title: Appended to the "Brief Figure for ..." plot title.
        semilog: When true, switch the y axis to a logarithmic scale.
    """
    epochs = range(1, len(y) + 1)
    point_style = dict(marker='o', mec='r', mfc='w')  # white faces, red edges
    plt.plot(epochs, y, label=y_label, **point_style)
    plt.legend()  # make the legend entry visible
    plt.margins(0.05)
    plt.subplots_adjust(bottom=0.15)
    plt.xlabel(u"Epoch")
    plt.title("Brief Figure for {}".format(title))
    if semilog:
        plt.semilogy()
    plt.show()

    
def draw2(y1, y2,
          y1_label='Data1', y2_label='Data2',
          semilog=False, title='undefined title', start=0, end=9999, axh=0):
    """Overlay two series on a wide figure and display it.

    The series may differ in length; each one is drawn only up to its own
    last point (and never beyond ``end`` points).

    Args:
        y1, y2: Sized, sliceable series of per-epoch values.
        y1_label, y2_label: Legend entries for the two series.
        semilog: When true, switch the y axis to a logarithmic scale.
        title: Appended to the "Brief Figure for ..." plot title.
        start: Index of the first point drawn in each series.
        end: Upper bound on the number of points considered.
        axh: If non-zero, draw a horizontal reference line at this y value.
    """
    span = min(end, max(len(y1), len(y2)))
    epochs = range(1, span + 1)
    plt.figure(figsize=(15, 5))
    curves = (
        (y1, y1_label, dict(marker='o', mec='r', mfc='w')),
        (y2, y2_label, dict(marker='X', mfc='w', ms=8)),
    )
    for values, name, look in curves:
        # Stop at whichever comes first: the series' end or the shared span.
        stop = min(span, len(values))
        plt.plot(epochs[start:stop], values[start:stop], label=name, **look)
    plt.legend()  # make the legend entries visible
    plt.margins(0.05)
    plt.subplots_adjust(bottom=0.15)
    plt.xlabel(u"Epoch")
    plt.title("Brief Figure for {}".format(title))
    if axh:
        plt.axhline(axh)
    if semilog:
        plt.semilogy()
    plt.show()

    
def draw4(y1, y2, y3, y4,
          y1_label, y2_label, y3_label, y4_label,
          title = 'title', start=0, semilog=False, axh=0):
    """Overlay four series on a wide figure and display it.

    All four series are truncated to the length of the shortest one, so the
    curves always share the same epoch range.

    Args:
        y1, y2, y3, y4: Sized, sliceable series of per-epoch values.
        y1_label, y2_label, y3_label, y4_label: Legend entries, in order.
        title: Appended to the "Brief Figure for ..." plot title.
        start: Index of the first point drawn in each series.
        semilog: When true, switch the y axis to a logarithmic scale.
        axh: If non-zero, draw a horizontal reference line at this y value.
    """
    series = (y1, y2, y3, y4)
    labels = (y1_label, y2_label, y3_label, y4_label)
    styles = (
        dict(marker='o', mec='r', mfc='w'),
        dict(marker='X', mfc='w', ms=8),
        dict(marker='*', mfc='w', mec='b'),
        dict(marker='.', mfc='w'),
    )
    shortest = min(len(values) for values in series)
    epochs = range(1, shortest + 1)
    plt.figure(figsize=(15, 5))
    for values, name, look in zip(series, labels, styles):
        plt.plot(epochs[start:], values[start:shortest], label=name, **look)
    plt.legend()  # make the legend entries visible
    plt.margins(0.05)
    plt.subplots_adjust(bottom=0.15)
    plt.xlabel(u"Epoch")
    plt.title("Brief Figure for {}".format(title))
    if axh:
        plt.axhline(axh)
    if semilog:
        plt.semilogy()
    plt.grid(False)
    plt.show()

    
def drawx(y_list, label_list=None, title='undefined title',
          start=0, end=None, x_list=None, x_label=None, semilog=False, axh=0):
    """Plot several series on one wide figure so they can be compared.

    Series may have different lengths: each one is drawn only as far as its
    own data (and the available x values) reach, so unaligned logs can be
    compared directly.

    Args:
        y_list: Sequence of sized, sliceable series (lists, arrays, ...).
        label_list: Legend labels matched to ``y_list`` by position. When
            None, labels ``Data_0``, ``Data_1``, ... are generated. Only as
            many series as there are labels are drawn.
        title: Appended to the "Brief Figure for ..." plot title.
        start: Index of the first point drawn in every series (e.g. to skip
            the first few epochs).
        end: If truthy, caps the number of points considered per series;
            None (or 0) means no cap.
        x_list: Sliceable x values shared by all series; defaults to 1-based
            epoch numbers.
        x_label: x-axis label; defaults to "Epoch".
        semilog: When true, switch the y axis to a logarithmic scale.
        axh: If non-zero, draw a horizontal reference line at this y value
            (so a line at exactly 0 cannot be requested).

    Raises:
        ValueError: If ``y_list`` is empty (there is no longest series).
    """
    longest = max(len(y) for y in y_list)
    length = min(end, longest) if end else longest
    x_list = range(1, length + 1) if x_list is None else x_list[:length]
    # Build real lists: on Python 3 ``map`` returns a one-shot iterator that
    # does not support ``y_length[idx]``. Capping at len(x_list) keeps x and
    # y slices the same size when a caller passes a short custom x_list.
    y_length = [min(length, len(y), len(x_list)) for y in y_list]
    plt.figure(figsize=(15, 5))
    if label_list is None:
        label_list = [u"Data_{}".format(i) for i in range(len(y_list))]
    for idx, label in enumerate(label_list):
        stop = y_length[idx]
        plt.plot(
            x_list[start:stop],
            y_list[idx][start:stop],
            marker='o',
            label=label)
    plt.legend()  # let legends work
    plt.margins(0.05)
    plt.subplots_adjust(bottom=0.15)
    plt.xlabel(u"Epoch" if x_label is None else x_label)
    plt.title(u"Brief Figure for {}".format(title))
    if axh:  # Add a horizontal line across the axis.
        plt.axhline(axh)
    if semilog:
        plt.semilogy()
    plt.grid(False)
    plt.show()
    ```

你可能感兴趣的:(模板记忆,DIY)