主要讲述绘制强化学习结果时遇到的seaborn操作。因此,本文主要讲述Lineplot的用法,以及图片的相关设置
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
# 单线绘制
data = pd.DataFrame({ "timestep":np.arange(100), "sr": np.arange(100)/100+0.1*np.random.randn(100)})
sns.lineplot(x="timestep",y="sr",data=data)
plt.show()
# 含有方差的单线绘制
data1 = np.arange(100)/100+0.1*np.random.randn(100)
data2 = np.arange(100)/100+0.1*np.random.randn(100)
data3 = np.arange(100)/100+0.1*np.random.randn(100)
data = pd.DataFrame({ "timestep":np.concatenate([np.arange(100) for i in range(3)],-1),
"sr": np.concatenate([data1,data2,data3])})
sns.lineplot(x="timestep",y="sr",data=data)
plt.show()
也就是说, 对于lineplot的输出,都是一维的向量。 对于多次的结果,我们需要将他们扁平化为一维数据,然后将数据的timesteps标注好。
import seaborn as sns
sns.lineplot(x, # 横轴的标签, RL中一般为 timesteps, episodes, epochs
y, # 纵轴的标签
data, # DataFrame object, 从此处调用数据
style, # 相同style标签的线 会有相同的颜色和线形表示
ci, # 表示是否设置方差 None, "sd" 和 默认的置信区间
hue, # 通过hue 可以设置多条线,并且设置legend
)
更具体的lineplot用法可以参照数据可视化(3)-Seaborn系列 | 折线图lineplot()
因为seaborn.lineplot (也就是relplot(kind=“line”)) 仅支持pandas.DataFrame类型,因此,在对含有方差的线条,以及多条含有方差的线条进行绘制之前,需要将数据配置进DataFrame中。此处参考了使用seaborn绘制强化学习中的图片。
此处重点讲述 如何在一个图中画出多条线来
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
return1 = np.array([(np.arange(100)/100+0.1*np.random.randn(100)).clip(0,1) for i in range(3)])
return2 = np.array([ (1- np.exp(-np.arange(100)/100)+0.1*np.random.randn(100)).clip(0,1) for i in range(3)])
data = [return1,return2]
label = ["algo1","algo2"]
df = []
# ------------------- 重点 ----------------------------------
for i in range(2):
# 此处 melt是将 (3,100)的结果压缩成 (300,)的结果,同时index标签是episodes, value的标签设置为 return
df_elem = pd.DataFrame(data[i]).melt(var_name="episodes",value_name="return")
# 将数据添加一行标签,都设置为 algoi
df_elem["algo"]=label[i]
df.append(df_elem)
# 将 list中的DF 合并成一个df, 之后df 有 episodes, return, algo 三个属性
# episodes用于区分第i episode,表示步长
# return表示值
# algo表示数据的类别,用于分成几条线,也用于区分线的类型
df = pd.concat(df)
# ------------------- 重点 ----------------------------------
# df.head()
sns.lineplot(x="episodes",y="return",data=df,
hue="algo", # 用于将数据分成几条线
style="algo", #用于表示数据的线形和颜色
marker=False, # 可以设置中心点是否显示
)
plt.show()
前两部分保证可以正确的画出想要的曲线,该部分总结如何设置图像细节。
图形风格控制 分为两部分, 一部分是背景控制,另一方面是样式控制。
# 在jupyter notebook中需要在最开始设置 :
sns.set()
# 1. 背景控制
# context = paper, talk, poster, notebook
#设置字体大小,边线
# rc = {"lines.linewidth": 2.5}
sns.set_context(context=None, font_scale=1, rc=None)
# 2. 样式控制
sns.set_style(style)
# style = darkgrid,whitegrid,dark,white,ticks
# 默认样式
sns.axes_style()
{'axes.axisbelow': True,
'axes.edgecolor': 'white',
'axes.facecolor': '#EAEAF2',
'axes.grid': True,
'axes.labelcolor': '.15',
'axes.linewidth': 0.0,
'figure.facecolor': 'white',
'font.family': ['sans-serif'],
'font.sans-serif': ['Arial',
'Liberation Sans',
'Bitstream Vera Sans',
'sans-serif'],
'grid.color': 'white',
'grid.linestyle': '-',
'image.cmap': 'Greys',
'legend.frameon': False,
'legend.numpoints': 1,
'legend.scatterpoints': 1,
'lines.solid_capstyle': 'round',
'text.color': '.15',
'xtick.color': '.15',
'xtick.direction': 'out',
'xtick.major.size': 0.0,
'xtick.minor.size': 0.0,
'ytick.color': '.15',
'ytick.direction': 'out',
'ytick.major.size': 0.0,
'ytick.minor.size': 0.0}
# 设置样式
sns.set_style({"axes.facecolor": ".9"})
sns.despine()
# 4. 颜色设置
# 可以通过 color_palette(RGB) 或者 hls_palette(HLS) 设置颜色
sns.palplot(sns.hls_palette(n_colors=8))
# 具体见下一节
# 设置图片大小
plt.figure(dpi=300, figsize=(6, 4))
# 设置标签
plt.xlabel("x label")
ax.set_ylabel('Y Label',fontsize=15, color='r')
# 设置刻度
plt.xticks( rotation=60) # rotation: degree
ax.tick_params(axis='y',labelsize=8)
plt.title() # 用于设置标题
# legend 设置
legend = ax.legend() # ax 从 lineplot获取
legend.texts[0].set_text("Whatever else") # 设置图例的名字
## 移除图例的标题
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles=handles[1:], labels=labels[1:])
## 移除图例
sns.lineplot(legend=False)
plt.legend(labels=[] , loc="upper right", bbox_to_anchor = (1,1), ncol=1)
# 自己设置图例 , bbox_to_anchor 可以将图例放在图外
# 保存图片
plt.savefig(name, dpi=400, bbox_inches='tight')
参考文章
数据可视化Seaborn从零开始学习教程(二) 颜色调控篇
Seaborn(sns)官方文档学习笔记(第二章 斑驳陆离的调色板)
sns.color_palette() # 接受seaborn, matplotlib中的颜色名称, 也接受RGB,HEX颜色代码
sns.set_palette() # 用于设置调色板
调色板分为 分类色板, 连续色板, 离散色板。
对于没有相关性的数据,我们一般就使用分类色板。
# 当前色板
current_palette = sns.color_palette()
sns.palplot(current_palette)
# 有六种主题
# themes = ['deep', 'muted', 'pastel', 'bright', 'dark', 'colorblind']
sns.color_palette(every_theme) # 都可以查看结果
当需要的颜色超过想要的,或者想要修改亮度,饱和度
sns.palplot(sns.hls_palette(8, l=0.3,s=0.8))
sns.palplot(sns.color_palette("husl",8))
另外,也可以从Color Brewer中调用颜色表 (循环的颜色表)
sns.palplot(sns.color_palette("Paired"))
sns.palplot(sns.color_palette("Set2", 10))
在jupyter notebook中可以进行交互
sns.choose_colorbrewer_palette("qualitative")