Matplotlib是Python数据可视化工具包,IPython为Matplotlib专门提供了特殊的交互模式。如果要在IPython控制台使用Matplotlib,可以使用ipython–matplotlib命令来启动IPython控制台程序;如果要在IPython notebook里使用Matplotlib,则在notebook的开始位置插入%matplotlib inline魔术命令即可。IPython的Matplotlib模式有两个优点,一是提供了非阻塞的画图操作,而是不需要显示地调用show()方法来显示画出来的图片。
Matplotlib下的pyplot子包提供了面向对象的画图程序接口,几乎所有的画图函数都与MATLAB类似,连参数都类似。在实际开发工作中,有时候甚至可以访问MATLAB的官方文档 cn.mathworks.com/help/matlab 来查询画图的接口和参数,这些参数可以直接在pyplot下的画图函数里使用。使用pyplot的习惯性写法是:
from matplotlib import pyplot as plt
在机器学习领域中,我们经常需要把数据可视化,以便观察数据的模式。此外,在对算法性能进行评估时,也需要把模型相关的数据可视化,才能观察出模型里需要改进的地方。例如,我们把算法的准确度和训练数据集大小的变化曲线画出来,可以清晰地看出训练数据集大小与算法准确度的关系。这就是我们需要学习Matplotlib的原因。
通常使用IPython notebook的Matplotlib模式来画图,这样画出来的图片会直接显示在网页上。要记得在notebook的最上面写上魔术命令%matplotlib inline。
使用Matplotlib的默认样式在一个坐标轴上画出正弦和余弦曲线。
%matplotlib inline
from matplotlib import pyplot as plt
import numpy as np
x=np.linspace(-np.pi,np.pi,200)
C,S=np.cos(x),np.sin(x)
plt.plot(x,C)
plt.plot(x,S)
plt.show()
接着,通过修改Matplotlib的默认样式,画出我们需要的样式图片。
(1)把正弦曲线的线条加粗,并且定制合适的颜色
# 画出余弦曲线,并设置线条颜色、宽度、样式
cos,=plt.plot(x,C,color="blue",linewidth=2.0,linestyle="-")
# 画出正弦曲线,并设置线条样色,宽度,样式
sin,=plt.plot(x,S,color="red",linewidth=2.0,linestyle="-")
(2)设置坐标轴的长度
# 设置坐标轴的长度
plt.xlim(x.min()*1.1,x.max()*1.1)
plt.ylim(C.min()*1.1,C.max()*1.1)
(3)重新设置坐标轴的刻度,X轴的刻度使用自定义的标签,标签的文本使用了LaTeX来显示圆周率符号π。
# 设置坐标轴的刻度和标签
plt.xticks((-np.pi,-np.pi/2,np.pi/2,np.pi),
(r'$-\pi$',r'$-\pi/2$',r'$+\pi/2$',r'$+\pi$'))
plt.yticks([-1,-0.5,0,0.5,1])
(4)把左侧图片中的4个方向的坐标轴改为两个方向的交叉坐标轴。方法是通过设置颜色为透明色,把上方和右侧的坐标边线隐藏起来。然后移动左侧和下方的坐标轴边线到原点(0,0)的位置。
# 坐标轴总共有4个连线,我们通过设置透明色隐藏上方的右方的边线
# 通过set_position()移动左侧和下侧的边线
# 通过set_ticks_position()设置坐标轴的刻度线的显示位置
ax=plt.gca() # gca代表当前坐标轴,即'get current axis'
ax.spines['right'].set_color('none') # 隐藏坐标轴
ax.spines['top'].set_color('none')
ax.xaxis.set_ticks_position('bottom') # 设置刻度线的显示位置
ax.spines['bottom'].set_position('center') # 设置下方坐标轴的位置
ax.yaxis.set_ticks_position('left')
ax.spines['left'].set_position('center')
(5)在图片左上角添加一个图例,用来标识图片中的正弦曲线和余弦曲线。
# 在左上角添加图例
plt.legend([cos,sin],["cos(x)", "sin(x)"],loc='upper left')
(6)在图片中标识处cos(2π/3)=-1/2,不但把这个公式画到图片上,还要再余弦曲线上标识出这个点,同时用虚线画出这个点对应的X轴的坐标。
t=2*np.pi/3
# 画出cos(t)所在的点在X轴上的位置,使用虚线画出(t,0)-->(t,cos(t))线段
plt.plot([t,t],[0,np.cos(t)],color='blue',linewidth=1.5,linestyle='--')
# 画出标识的坐标点,在(t,cos(t))处画一个大小为50的蓝色点
plt.scatter([t,],[np.cos(t),],50,color='blue')
# 画出标识点的值,cos(t)
plt.annotate(r'$cos(\frac{2\pi}{3})=-\frac{1}{2}$',
xy=(t,np.cos(t)),
xycoords='data',
xytext=(-90,-50),
textcoords='offset points',
fontsize=16,
arrowprops=dict(arrowstyle="->",connectionstyle="arc3,rad=.2"),
)
其中,plt.annotate()函数的功能是在图片上画出标识文本,其文本内容也是使用LaTeX公式书写,这个函数参数众多,具体可以参阅官方的API说明文档。
使用相同的方法也可以在正弦曲线上标识出一个点。
plt.plot([t,t],[0,np.sin(t)],color='red',linewidth=1.5,linestyle='--')
plt.scatter([t,],[np.sin(t),],50,color='red')
plt.annotate(r'$sin(\frac{2\pi}{3})=-\frac{\sqrt{3}}{2}$',
xy=(t,np.sin(t)),
xycoords='data',
xytext=(20,20),
textcoords='offset points',
fontsize=16,
arrowprops=dict(arrowstyle="->",connectionstyle="arc3,rad=.2"),
)
(7)定制坐标轴上的刻度标签的字体,同时为了避免正弦曲线覆盖掉刻度标识,在刻度标签上添加一个半透明的方框作为背景。
# 设置坐标刻度的字体大小,添加半透明背景
for label in ax.get_xticklabels()+ax.get_yticklabels():
label.set_fontsize(16)
label.set_bbox(dict(facecolor='white',edgecolor='None',alpha=0.65))
plt.show()
这样就完成了一个Matplotlib样式配置的过程,把默认的样式修改成我们需要的样式。
在Matplotlib里,一个图形figure是指图片的全部可视区域,可以使用plt.figure()来创建。在一个图形里,可以包含多个子图subplot,可以使用plt.subplot()来创建子图。子图按照网格形状排列显示在图形里,可以在每个子图上单独作画。坐标轴Axes和子图类似,唯一不同的是,坐标轴可以在图形上任意摆放,而不需要按照网格排列,这样显示起来更灵活,可以使用plt.axes()来创建坐标轴。
当我们使用默认配置进行作画时,Matplotlib调用plt.gca()函数来获取当前的坐标轴,并在当前坐标轴上作画。plt.gca()函数调用plt.gcf()函数来获取当前当前图形对象,如果当前不存在图形对象,则会调用plt.figure()函数创建一个图形对象。
plt.figure()函数有以下几个常用的参数:
下面的代码创建了两个图形,一个是’sin’,并且把正弦曲线画在这个图形上,另一个是’cos’,并把余弦曲线画在这个图形上。接着切换到之前创建的’sin’图形上,把余弦曲线也画在这个图形上。
%matplotlib inline
from matplotlib import pyplot as plt
import numpy as np
X=np.linspace(-np.pi,np.pi,200,endpoint=True)
C,S=np.cos(X),np.sin(X)
plt.figure(num='sin',figsize=(16,4))
plt.plot(X,S)
plt.figure(num='cos',figsize=(16,4))
plt.plot(X,C)
plt.figure(num='sin')
plt.plot(X,C)
print(plt.figure(num='sin').number)
print(plt.figure(num='cos').number)
不同的图形可以单独保存为一个图片文件,但是子图是指一个图形里分成几个区域,在不同的区域里单独作画,所有的子图最终都保存在一个文件里。plt.subplot()函数的关键参数是一个包含3个元素的元组,分别代表子图的行,列及当前激活的子图序号。比如plt.subplot(2,2,1)表示把图形对象分成两行两列,激活第一个子图来作画。我们看一个网格状的子图的例子。
%matplotlib inline
from matplotlib import pyplot as plt
plt.figure(figsize=(18,4))
plt.subplot(2,2,1)
plt.xticks(())
plt.yticks(())
plt.text(0.5,0.5,'subplot(2,2,1)',ha='center',va='center',size=20,alpha=0.5)
plt.subplot(2,2,2)
plt.xticks(())
plt.yticks(())
plt.text(0.5,0.5,'subplot(2,2,2)',ha='center',va='center',size=20,alpha=0.5)
plt.subplot(2,2,3)
plt.xticks(())
plt.yticks(())
plt.text(0.5,0.5,'subplot(2,2,3)',ha='center',va='center',size=20,alpha=0.5)
plt.subplot(2,2,4)
plt.xticks(())
plt.yticks(())
plt.text(0.5,0.5,'subplot(2,2,1)',ha='center',va='center',size=20,alpha=0.5)
plt.tight_layout()
plt.show()
更复杂的子图布局,可以使用gridspec来实现,其优点是可以指定某个子图横跨多个列或者多个行。
%matplotlib inline
from matplotlib import pyplot as plt
import matplotlib.gridspec as gridspec
plt.figure(figsize=(18,4))
G=gridspec.GridSpec(3,3)
axes_1=plt.subplot(G[0,:])
plt.xticks(())
plt.yticks(())
plt.text(0.5,0.5,'Axes 1',ha='center',va='center',size=24,alpha=0.5)
axes_2=plt.subplot(G[1:,0])
plt.xticks(())
plt.yticks(())
plt.text(0.5,0.5,'Axes 2',ha='center',va='center',size=24,alpha=0.5)
axes_3=plt.subplot(G[1:,-1])
plt.xticks(())
plt.yticks(())
plt.text(0.5,0.5,'Axes 3',ha='center',va='center',size=24,alpha=0.5)
axes_4=plt.subplot(G[1,-2])
plt.xticks(())
plt.yticks(())
plt.text(0.5,0.5,'Axes 4',ha='center',va='center',size=24,alpha=0.5)
axes_5=plt.subplot(G[-1,-2])
plt.xticks(())
plt.yticks(())
plt.text(0.5,0.5,'Axes 5',ha='center',va='center',size=24,alpha=0.5)
plt.tight_layout()
plt.show()
坐标轴使用plt.axes()来创建,它用一个矩形来给坐标轴定位,矩形使用[left,bottom,width,height]来表达。其数据为图形对象对应坐标轴长度的百分比。
%matplotlib inline
from matplotlib import pyplot as plt
plt.figure(figsize=(18,4))
plt.axes([0.1,0.1,0.8,0.8])
plt.xticks(())
plt.yticks(())
plt.text(0.2,0.5,'axes[0.1,0.1,0.8,0.8]',ha='center',va='center',size=20,alpha=0.5)
plt.axes([0.5,0.5,0.3,0.3])
plt.xticks(())
plt.yticks(())
plt.text(0.5,0.5,'axes[0.5,0.5,0.3,0.3]',ha='center',va='center',size=20,alpha=0.5)
plt.show()
一个优美而恰当的坐标轴刻度对理解数据异常重要,Matplotlib内置提供了以下几个坐标轴刻度。
除了内置标签外,我们也可以继承matplotlib.ticker.Locater类来实现自定义样式的刻度标签。
通过下面的代码把内置坐标刻度全部画出来,可以直观地观察到内置坐标刻度的样式。
%matplotlib inline
from matplotlib import pyplot as plt
import numpy as np
def tickline():
plt.xlim(0,10)
plt.ylim(-1,1)
plt.yticks([])
ax = plt.gca()
ax.spines['right'].set_color('none')
ax.spines['left'].set_color('none')
ax.spines['top'].set_color('none')
ax.xaxis.set_ticks_position('bottom')
ax.spines['bottom'].set_position(('data',0))
ax.yaxis.set_ticks_position('none')
ax.xaxis.set_major_locator(plt.MultipleLocator(0.1))
for label in ax.get_xticklabels()+ax.get_yticklabels():
label.set_fontsize(16)
ax.plot(np.arange(11),np.zeros(11))
return ax
locators = [
'plt.NullLocator()',
'plt.MultipleLocator(base=1.0)',
'plt.FixedLocator(locs=[0,2,8,9,10])',
'plt.IndexLocator(base=3,offset=1)',
'plt.LinearLocator(numticks=5)',
'plt.LogLocator(base=2,subs=[1.0])',
'plt.MaxNLocator(nbins=3,steps=[1,3,5,7,9,10])',
'plt.AutoLocator()'
]
n_locators = len(locators)
size = 1024, 60*n_locators
dpi = 72.0
figsize = size[0]/float(dpi), size[1]/float(dpi)
fig = plt.figure(figsize=figsize,dpi=dpi)
fig.patch.set_alpha(0)
for i,locator in enumerate(locators):
plt.subplot(n_locators,1,i+1)
ax = tickline()
ax.xaxis.set_major_locator(eval(locator))
plt.text(5,0.3,locator[3:],ha='center',size=16)
plt.subplots_adjust(bottom=0.01,top=0.99,left=0.01,right=0.99)
plt.show()
本节通过一系列的例子,来演示Matplotlib的画图操作
例:画散点图(使用plt.scatter()函数)
n = 1024
X = np.random.normal(0,1,n)
Y = np.random.normal(0,1,n)
T = np.arctan2(Y,X)
plt.scatter(X,Y,s=75,c=T,alpha=0.5)
plt.xlim(-1.5,1.5)
plt.xticks(())
plt.ylim(-1.5,1.5)
plt.yticks(())
例:图形填充(使用plt.fill_between()函数)
n = 256
X = np.linspace(-np.pi,np.pi,n,endpoint=True)
Y = np.sin(2*X)
plt.plot(X,Y+1,color='blue',alpha=1.00)
plt.fill_between(X,1,Y+1,color='blue',alpha=0.25)
plt.plot(X,Y-1,color='blue',alpha=1.00)
plt.fill_between(X,-1,Y-1,(Y-1)>-1,color='blue',alpha=0.25)
plt.fill_between(X,-1,Y-1,(Y-1)<-1,color='red',alpha=0.25)
plt.xlim(-np.pi,np.pi)
plt.xticks(())
plt.ylim(-2.5,2.5)
plt.yticks(())
n = 12
X = np.arange(n)
Y1 = (1-X/float(n))*np.random.uniform(0.5,1.0,n)
Y2 = (1-X/float(n))*np.random.uniform(0.5,1.0,n)
plt.bar(X,+Y1,facecolor='#9999ff',edgecolor='white')
plt.bar(X,-Y2,facecolor='#ff9999',edgecolor='white')
for x,y in zip(X,Y1):
plt.text(x+0.4,y+0.05,'%.2f'%y,ha='center',va='center')
for x,y in zip(X,Y2):
plt.text(x+0.4,-y-0.05,'%.2f'%y,ha='center',va='center')
plt.xlim(-0.5,n)
plt.xticks(())
plt.ylim(-1.25,1.25)
plt.yticks(())
def f(x,y):
return (1-x/2+x**5+y**3)*np.exp(-x**2,-y**2)
n = 256
x = np.linspace(-3,3,n)
y = np.linspace(-3,3,n)
X,Y = np.meshgrid(x,y)
plt.contourf(X,Y,f(X,Y),8,alpha=0.75,cmap=plt.cm.hot)
C = plt.contourf(X,Y,f(X,Y),8,color='black',linewidth=0.5)
plt.clabel(C,inline=1,fontsize=10)
plt.xticks(())
plt.yticks(())
例:画热成像图(使用plt.imshow()和plt.colorbar()函数)
def f(x,y):
return (1-x/2+x**5+y**3)*np.exp(-x**2,-y**2)
n = 10
x = np.linspace(-3,3,4*n)
y = np.linspace(-3,3,3*n)
X,Y = np.meshgrid(x,y)
plt.imshow(f(X,Y),cmap='hot',origin='low')
plt.colorbar(shrink=0.83)
plt.xticks(())
plt.yticks(())
n = 20
Z = np.ones(n)
Z[-1] *= 2
plt.pie(Z,explode=Z*0.05,colors=['%f'%(i/float(n)) for i in range(n)])
plt.axis('equal')
plt.xticks(())
plt.yticks(())
ax = plt.gca()
ax.set_xlim(0,4)
ax.set_ylim(0,3)
ax.xaxis.set_major_locator(plt.MultipleLocator(1.0))
ax.xaxis.set_minor_locator(plt.MultipleLocator(0.1))
ax.yaxis.set_major_locator(plt.MultipleLocator(1.0))
ax.yaxis.set_minor_locator(plt.MultipleLocator(0.1))
ax.grid(which='major',axis='x',linewidth=0.75,linestyle='-',color='0.75')
ax.grid(which='minor',axis='x',linewidth=0.25,linestyle='-',color='0.75')
ax.grid(which='major',axis='y',linewidth=0.75,linestyle='-',color='0.75')
ax.grid(which='minor',axis='y',linewidth=0.25,linestyle='-',color='0.75')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax = plt.subplot(1,1,1, projection='polar')
N = 20
theta = np.arange(0.0,2*np.pi,2*np.pi/N)
radii = 10 * np.random.rand(N)
width = np.pi / 4 * np.random.rand(N)
bars = plt.bar(theta,radii,width=width,bottom=0.0)
for r,bar in zip(radii,bars):
bar.set_facecolor(plt.cm.jet(r/10.0))
bar.set_alpha(0.5)
ax.set_xticklabels([])
ax.set_yticklabels([])
本节只简要介绍了Matplotlib的入门知识,关于Matplotlib的更多细节,可以参考官方网站 matplotlib.org 上的介绍。