最近经常需要做数据对比，比如多个模型之间的横向对比。
对于 Loss、PRF（Precision/Recall/F1）、Hits 之类的数据，有时需要使用对数坐标（log），有时需要去掉前面几个值，
还要处理数据长度不对齐、记录文件格式不一致等诸多问题。总之，主要需求是鲁棒性。
于是简单写了几个画图（折线图）的函数，目前代码还比较乱，之后有时间再优化。
暂时先记录在这里，便于在不同机器上获取使用，也方便后续的优化和更新。
# Usage summary of the two main entry points (kept as comments: executed as
# code, these calls would raise NameError because `y` / `y_list` are undefined
# and the functions are only defined further down).
#
# single series:
#   draw(y, y_label='Data', title='undefined title', semilog=False)
# multiple series:
#   drawx(y_list, label_list=None, title='undefined title',
#         start=0, end=None, x_list=None, x_label=None, semilog=False, axh=0)
import matplotlib.pyplot as plt
# Star import kept as-is; presumably it is what provides the `mpl` name used
# below (nothing else in this file imports `mpl`). NOTE(review): a star import
# pulls many names into the module namespace — consider an explicit import.
from pylab import *  # support Chinese text (see the rcParams line below)
import warnings
# Ignore ALL warnings process-wide (no category/module filter is given), not
# only matplotlib's. NOTE(review): consider narrowing this filter.
warnings.filterwarnings('ignore')
# Use the SimHei font for sans-serif text so Chinese labels/titles can render;
# assumes SimHei is installed on the machine — TODO confirm per environment.
mpl.rcParams['font.sans-serif'] = ['SimHei']
# Leftover manual-plot snippets, kept for reference only:
#plt.plot(x, y, 'ro-')
#plt.plot(x, y1, 'bo-')
#pl.xlim(-1, 11) # limit the x-axis range
#pl.ylim(-1, 110) # limit the y-axis range
def focus(filename, label='[Valid_F1_G]', pos=1, delim='\t'):
    """Extract one field from every log line that starts with ``label``.

    Intended for tab-separated logs such as
    ``[Valid_F1_G]<TAB>0.667<TAB>0.667<TAB>0.667``: with the defaults this
    returns the first value after the tag for each matching line, in file
    order.

    Args:
        filename: path of the log file to read.
        label: prefix that selects the lines of interest.
        pos: index of the field to return after splitting on ``delim``
            (0 is the tag itself).
        delim: field separator.

    Returns:
        A list of the selected fields as whitespace-stripped strings (they are
        not converted to numbers). A real list is returned, rather than a lazy
        ``map`` object, so the result can be indexed and re-iterated on
        Python 3 as it could on Python 2.

    Raises:
        IndexError: if a matching line has fewer than ``pos + 1`` fields.
    """
    # `with` guarantees the file handle is closed, even if a line is malformed.
    with open(filename, 'r') as handle:
        return [line.split(delim)[pos].strip()
                for line in handle
                if line.startswith(label)]
def draw(y, y_label='Data', title='undefined title', semilog=False):
    """Plot a single series against 1-based epoch numbers and show it.

    Draws on the current figure (no new figure is created here).

    Args:
        y: sequence of values; the value at index i is drawn at x = i + 1.
        y_label: legend label for the series.
        title: inserted into the figure title "Brief Figure for <title>".
        semilog: if true, switch the y axis to a log scale.
    """
    epochs = range(1, len(y) + 1)
    point_style = dict(marker='o', mec='r', mfc='w', label=y_label)
    plt.plot(epochs, y, **point_style)
    plt.legend()  # the legend is only rendered once explicitly requested

    # Layout: a little padding around the data and room for the x label.
    plt.margins(0.05)
    plt.subplots_adjust(bottom=0.15)

    plt.title("Brief Figure for {}".format(title))
    plt.xlabel(u"Epoch")
    if semilog:
        plt.semilogy()
    plt.show()
def draw2(y1, y2,
          y1_label='Data1', y2_label='Data2',
          semilog=False, title='undefined title', start=0, end=9999, axh=0):
    """Plot two series on one wide figure and show it.

    Both series share the x values 1, 2, ...; the x range covers the longer
    series but is capped at ``end`` points, and each series is drawn only as
    far as its own data reaches, beginning at index ``start``.

    Args:
        y1, y2: sequences of values to compare.
        y1_label, y2_label: legend labels for the two series.
        semilog: if true, switch the y axis to a log scale.
        title: inserted into the figure title "Brief Figure for <title>".
        start: number of leading points to skip in both series.
        end: maximum number of points considered per series.
        axh: if non-zero, draw a horizontal line at this y value.
    """
    length = min(end, max(len(y1), len(y2)))
    epochs = range(1, length + 1)
    plt.figure(figsize=(15, 5))

    curves = (
        (y1, dict(marker='o', mec='r', mfc='w', label=y1_label)),
        (y2, dict(marker='X', mfc='w', ms=8, label=y2_label)),
    )
    for values, style in curves:
        # Never go past this series' own data, even if the other is longer.
        stop = min(length, len(values))
        plt.plot(epochs[start:stop], values[start:stop], **style)

    plt.legend()  # the legend is only rendered once explicitly requested
    plt.margins(0.05)
    plt.subplots_adjust(bottom=0.15)
    plt.title("Brief Figure for {}".format(title))
    plt.xlabel(u"Epoch")
    if axh:
        plt.axhline(axh)
    if semilog:
        plt.semilogy()
    plt.show()
def draw4(y1, y2, y3, y4,
          y1_label, y2_label, y3_label, y4_label,
          title='title', start=0, semilog=False, axh=0):
    """Plot four series on one wide figure and show it.

    Every series is truncated to the length of the shortest one, so all
    curves span the same x values 1, 2, ..., skipping the first ``start``.

    Args:
        y1, y2, y3, y4: sequences of values to compare.
        y1_label, y2_label, y3_label, y4_label: legend labels, in order.
        title: inserted into the figure title "Brief Figure for <title>".
        start: number of leading points to skip in every series.
        semilog: if true, switch the y axis to a log scale.
        axh: if non-zero, draw a horizontal line at this y value.
    """
    series = (y1, y2, y3, y4)
    shortest = min(len(values) for values in series)
    epochs = range(1, shortest + 1)
    styles = (
        dict(marker='o', mec='r', mfc='w', label=y1_label),
        dict(marker='X', mfc='w', ms=8, label=y2_label),
        dict(marker='*', mfc='w', mec='b', label=y3_label),
        dict(marker='.', mfc='w', label=y4_label),
    )

    plt.figure(figsize=(15, 5))
    for values, style in zip(series, styles):
        plt.plot(epochs[start:], values[start:shortest], **style)

    plt.legend()  # the legend is only rendered once explicitly requested
    plt.margins(0.05)
    plt.subplots_adjust(bottom=0.15)
    plt.title("Brief Figure for {}".format(title))
    plt.xlabel(u"Epoch")
    if axh:
        plt.axhline(axh)
    if semilog:
        plt.semilogy()
    plt.grid(False)
    plt.show()
def drawx(y_list, label_list=None, title='undefined title',
          start=0, end=None, x_list=None, x_label=None, semilog=False, axh=0):
    """Plot several (possibly ragged) series on one wide figure and show it.

    The number of points considered is the length of the longest series,
    optionally capped by ``end`` and, when custom x values are given, by
    ``len(x_list)``. Each series is drawn only as far as its own data reaches,
    beginning at index ``start``.

    Args:
        y_list: sequence of series (each a sliceable sequence of values).
        label_list: legend labels paired with ``y_list`` in order; defaults to
            ``Data_0``, ``Data_1``, ... If the two differ in length, only the
            paired prefix is drawn.
        title: inserted into the figure title "Brief Figure for <title>".
        start: number of leading points to skip in every series.
        end: optional cap on the number of points; a falsy value (None or 0)
            means no cap.
        x_list: optional x values shared by all series; defaults to 1, 2, ...
        x_label: x-axis label; defaults to "Epoch".
        semilog: if true, switch the y axis to a log scale.
        axh: if non-zero, draw a horizontal line at this y value (so a line at
            y=0 cannot be requested).

    Raises:
        ValueError: if ``y_list`` is empty.
    """
    if len(y_list) == 0:
        raise ValueError("y_list must contain at least one series")

    length = max(len(y) for y in y_list)
    if end:
        length = min(end, length)
    if x_list is None:
        x_list = range(1, length + 1)
    else:
        # Never slice y further than there are x values; otherwise plt.plot
        # receives x and y of different lengths and raises.
        length = min(length, len(x_list))
        x_list = x_list[:length]
    if label_list is None:
        label_list = [u"Data_{}".format(i) for i in range(len(y_list))]

    plt.figure(figsize=(15, 5))
    # Pair series and labels with zip instead of indexing into map() results:
    # on Python 3 map() is a lazy, non-subscriptable iterator.
    for y, label in zip(y_list, label_list):
        stop = min(length, len(y))
        plt.plot(x_list[start:stop], y[start:stop], marker='o', label=label)

    plt.legend()  # the legend is only rendered once explicitly requested
    plt.margins(0.05)
    plt.subplots_adjust(bottom=0.15)
    plt.xlabel(u"Epoch" if x_label is None else x_label)
    plt.title(u"Brief Figure for {}".format(title))
    if axh:  # add a horizontal reference line across the axes
        plt.axhline(axh)
    if semilog:
        plt.semilogy()
    plt.grid(False)
    plt.show()
```