为了增强可解释性,我们有时会选择使用matplotlib,这是Python中流行的绘图库,要配置matplotlib生成图形的属性,我们需要定义几个函数。
首先第一个是use_svg_display
函数,我们使用这个函数指定matplotlib输出svg图表以获得更清晰的图像。
def use_svg_display(self):
backend_inline.set_matplotlib_formats('svg')
我们定义set_figsize
函数来设置图表大小:
def set_figsize(figsize=(3.5, 2.5)): #@save
"""设置matplotlib的图表大小"""
use_svg_display()
plt.rcParams['figure.figsize'] = figsize
下面的set_axes
函数用于设置由matplotlib生成图表的轴的属性:
def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
"""设置matplotlib的轴"""
axes.set_xlabel(xlabel)
axes.set_ylabel(ylabel)
axes.set_xscale(xscale)
axes.set_yscale(yscale)
axes.set_xlim(xlim)
axes.set_ylim(ylim)
if legend:
axes.legend(legend)
axes.grid()
通过这三个用于图形配置的函数,我们定义了plot函数来简洁地绘制多条曲线。
def plot(X, Y=None, xlabel=None, ylabel=None, legend=None, xlim=None,
ylim=None, xscale='linear', yscale='linear',
fmts=('-', 'm--', 'g-.', 'r:'), figsize=(3.5, 2.5), axes=None):
"""绘制数据点"""
if legend is None:
legend = []
set_figsize(figsize)
axes = axes if axes else d2l.plt.gca()
# 如果X有一个轴,输出True
def has_one_axis(X):
return (hasattr(X, "ndim") and X.ndim == 1 or isinstance(X, list)
and not hasattr(X[0], "__len__"))
if has_one_axis(X):
X = [X]
if Y is None:
X, Y = [[]] * len(X), X
elif has_one_axis(Y):
Y = [Y]
if len(X) != len(Y):
X = X * len(Y)
axes.cla()
for x, y, fmt in zip(X, Y, fmts):
if len(x):
axes.plot(x, y, fmt)
else:
axes.plot(y, fmt)
set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
到此,我们定义完了新的matplotlib。我们把其归结为一个类:
import math
import time
import numpy as np
import torch
from d2l import torch as d2l
import matplotlib.pyplot as plt
from matplotlib_inline import backend_inline
class Plot:
def __init__(self):
return None
# 使用svg格式显示绘图
def use_svg_display(self):
backend_inline.set_matplotlib_formats('svg')
# 设置matplotlib的图表大小
def set_figsize(self, figsize=(12,8)):
self.use_svg_display()
plt.rcParams['figure.figsize'] = figsize
# 设置matplotlib的轴
def set_axes(self, axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
axes.set_xlabel(xlabel)
axes.set_ylabel(ylabel)
axes.set_xlim(xlim)
axes.set_ylim(ylim)
axes.set_xscale(xscale)
axes.set_yscale(yscale)
if legend:
axes.legend(legend)
axes.grid()
# 定义plot函数来实现绘图功能
def plot(self, X, Y=None, xlabel=None, ylabel=None, legend=None, xlim=None, ylim=None, xscale='linear', yscale='linear',
fmts=('-','m--','g-','r:'), figsize=(12,8), axes=None):
# 绘制数据点
if legend is None:
legend = []
self.set_figsize(figsize)
axes = axes if axes else d2l.plt.gca()
# 如果X有一个轴,输出为True
def has_one_axis(X):
return (hasattr(X, 'ndim') and X.ndim == 1 or isinstance(X, list) and not hasattr(X[0], '__len__'))
if has_one_axis(X):
X=[X]
if Y is None:
X,Y=[[]]*len(X),X
elif has_one_axis(Y):
Y=[Y]
if len(X) != len(Y):
X = X*len(Y)
axes.cla()
for x,y,fmt in zip(X, Y, fmts):
if len(x):
axes.plot(x, y, fmt)
else:
axes.plot(y, fmt)
self.set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
我们举一个例子:
if __name__ == '__main__':
def demo_1():
# 计算正态分布
def normal(x, mu, sigma):
p = 1/math.sqrt(2*math.pi*sigma**2)
return p*np.exp(-0.5/sigma**2*(x-mu)**2)
x = np.arange(-7, 7, 0.01)
params = [(0,1), (0,2), (3,1)]
estimator = Plot()
estimator.plot(x, [normal(x,mu,sigma) for mu,sigma in params], xlabel='x', ylabel='p(x)', figsize=(12, 8), legend=[f'mean {mu},std {sigma}'for mu,sigma in params])
plt.show()
demo_1()