Seaborn本身并不是为了统计分析而生的,seaborn中的回归图主要用于添加视觉指南,以帮助在探索性数据分析EDA中强调存在于数据集的模式。
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
sns.set(style='whitegrid', color_codes=True)
plt.rc(
"figure",
autolayout=True,
figsize=(6, 5),
titlesize=18,
titleweight='bold',
)
plt.rc(
"axes",
labelweight="bold",
labelsize="large",
titleweight="bold",
titlesize=16,
titlepad=10,
)
%config InlineBackend.figure_format = 'retina'
warnings.filterwarnings('ignore')
tips = sns.load_dataset('tips',data_home='data',cache=True)
tips.head()
total_bill | tip | sex | smoker | day | time | size | |
---|---|---|---|---|---|---|---|
0 | 16.99 | 1.01 | Female | No | Sun | Dinner | 2 |
1 | 10.34 | 1.66 | Male | No | Sun | Dinner | 3 |
2 | 21.01 | 3.50 | Male | No | Sun | Dinner | 3 |
3 | 23.68 | 3.31 | Male | No | Sun | Dinner | 2 |
4 | 24.59 | 3.61 | Female | No | Sun | Dinner | 4 |
seaborn中主要通过regplot()
和lmplot()
两个函数来显示线性回归关系,两个函数之间共享核心功能。但是,了解两者的不同之处非常重要,这样就可以快速为特定工作选择正确的工具。
在原始调用中,两个函数都绘制了两个变量x
和y
,然后拟合回归模型y~x
并绘制得到回归线和该回归的95%置信区间:
sns.regplot(x='total_bill', y='tip', data=tips)
sns.lmplot(x='total_bill', y='tip', data=tips)
可以看到,除了图形形状不同,两幅图结果是完全一致的。需要了解的一个主要区别是regplot()
接受多种格式的x
和y
变量,包括简单的numpy数组,pandas的Series对象或者DataFrame对象。而lmplot()
将data作为必须参数,x
和y
变量必须被制定为字符串。这种数据格式被称为“长格式”或“整齐”数据。
当其中一个变量是离散值时,可以拟合线性回归。但是,这种数据集生成的简单散点图通常不是最优的:
sns.lmplot(x='size', y='tip', data=tips)
一种选择是向离散值添加随机噪声(“抖动”),以使这些值分布更清晰。注意抖动仅用于散点图数据,不会影响回归线:
sns.lmplot(x='size', y='tip', data=tips, x_jitter=0.05)
第二种选择是综合每个离散箱中的观测值,以绘制集中趋势的估计值和置信区间;
sns.lmplot(x='size', y='tip', data=tips, x_estimator=np.mean)
上面使用的简单线性模型对于某些类型的数据集并不适用。Anscombe
数据集展示了一些实例,其中简单线性模型回归提供了相同的关系估计,但却有肉眼可见的差异。例如,在第一种情况下,线性回归是一个很好的模型:
anscombe = sns.load_dataset('anscombe', data_home='data', cache=True)
anscombe.head()
dataset | x | y | |
---|---|---|---|
0 | I | 10.0 | 8.04 |
1 | I | 8.0 | 6.95 |
2 | I | 13.0 | 7.58 |
3 | I | 9.0 | 8.81 |
4 | I | 11.0 | 8.33 |
sns.lmplot(x='x', y='y', data=anscombe.query("dataset=='I'"), ci=None, scatter_kws={'s':80})
第二个数据集的线性关系是相同的,但是图表清楚地表明这并不是一个好模型:
sns.lmplot(x='x', y='y', data=anscombe.query("dataset=='II'"), ci=None, scatter_kws={'s':80})
在这些存在高阶关系的情况下,regplot()
和lmplot()
可以拟合多项式回归模型来探索数据集中的简单非线性趋势:
sns.lmplot(x='x', y='y', data=anscombe.query("dataset=='II'"), order=2, ci=None, scatter_kws={'s':80})
在存在异常值的情况下,拟合稳健回归可能会很有用,该回归使用了一种不同的损失函数来降低相对较大的残差的权重
sns.lmplot(x='x', y='y', data=anscombe.query("dataset=='III'"), robust=True, ci=None, scatter_kws={'s':80})
当y
变量是二进制时,在这种情况下的解决方案是拟合逻辑回归:
tips['big_tip'] = (tips.tip / tips.total_bill) > .15
sns.lmplot(x='total_bill', y='big_tip', data=tips, logistic=True, y_jitter=0.03)
一种完全不同的方法是使用lowess smoother
拟合非参数回归。尽管它是计算密集型的,这种方法的假设最少,因此目前置信区间根本没有计算:
sns.lmplot(x='total_bill', y='tip', data=tips, lowess=True)
residplot()
函数可以用作检查简单回归模型是否适合数据集的有效工具。它拟合并删除简单的线性回归,然后绘制每个观察值的残差值。理想情况下,这些值应在y=0
周围随机散步:
sns.residplot(x='x', y='y', data=anscombe.query("dataset=='I'"), scatter_kws={'s':80})
如果残差存在结构形状,则表明简单的线性回归不合适
sns.residplot(x='x', y='y', data=anscombe.query("dataset=='II'"), scatter_kws={'s':80})
上面的图显示了探索一对变量之间关系的许多方法。然而,一个更有趣的问题是“这两个变量之间的关系如何随第三个变量的变化而变化?”这就是regplot()
和lmplot()
最大的区别所在。regplot()
总是表现单一的关系,lmplot()
把regplot()
和FacetGrid
结合,以提供一个简单的界面,显示"facet"图的线性回归,使得可以探索最多三个其他分类变量的交互。
分离关系的最佳方法是在同一轴上绘制两个级别并使用颜色来区分它们:
sns.lmplot(x='total_bill', y='tip', hue='smoker', data=tips)
再添加另一个变量,可以绘制多个facet
,其中每个级别的变量出现在网格的行或列中:
sns.lmplot(x='total_bill', y='tip', hue='smoker', col='time', data=tips)
sns.lmplot(x='total_bill', y='tip', hue='smoker', col='time', row='sex', data=tips)
之前我们注意到regplot()
和lmplot()
生成的默认图看起来相同,但却有不同的大小和形状。这是因为regplot()
是一个“轴级”函数,它绘制在特定的轴上。这意味着我们可以自己制作多面板图形并精确控制回归图的位置。如果没有明确提供轴对象,它只使用“当前活动”轴,这就是默认绘图与大多数其他 matplotlib 函数具有相同大小和形状的原因。要控制大小,就需要先创建一个图形对象。
f, ax = plt.subplots(figsize=(5,6))
sns.regplot(x='total_bill', y='tip', data=tips, ax=ax)
相比之下,lmplot()
图的大小和形状是通过size
和aspect
参数控制,这些参数适用于绘图中的每个facet
,而不是整个图形本身。
sns.lmplot(x='total_bill', y='tip', col='day', data=tips, col_wrap=2, height=3)
sns.lmplot(x='total_bill', y='tip', col='day', data=tips, aspect=.5)
joinplot()
可以通过传递kind="reg"
来显示轴上的线性回归拟合。
sns.jointplot(x='total_bill', y='tip', data=tips, kind='reg')
使用pairplot()
函数与参数kind="reg"
将regplot()
和PairGrid
结合起来,来显示数据集中变量的线性关系。请注意这与lmplot()
的不同之处。在下图中,两个轴在第三变量上的两个级别上没有显示相同的关系。
sns.pairplot(tips, x_vars=['total_bill', 'size'], y_vars=['tip'], height=5, aspect=.8, kind='reg')
和lmplot()
相同,但不同于joingplot()
,额外的分类变量调节是通过hue
参数内置在pairplot()
中的。
sns.pairplot(tips, x_vars=['total_bill', 'size'], y_vars=['tip'], hue='smoker', height=5, aspect=.8, kind='reg')