【自定义函数】配置matplotlib生成图形的属性(含源代码)

为了增强可解释性,我们有时会选择使用matplotlib,这是Python中流行的绘图库,要配置matplotlib生成图形的属性,我们需要定义几个函数。

首先第一个是use_svg_display函数,我们使用这个函数指定matplotlib输出svg图表以获得更清晰的图像。

def use_svg_display(self):
    backend_inline.set_matplotlib_formats('svg')

我们定义set_figsize函数来设置图表大小:

def set_figsize(figsize=(3.5, 2.5)):  #@save
    """设置matplotlib的图表大小"""
    use_svg_display()
    plt.rcParams['figure.figsize'] = figsize

下面的set_axes函数用于设置由matplotlib生成图表的轴的属性:

def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
    """设置matplotlib的轴"""
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    axes.set_xscale(xscale)
    axes.set_yscale(yscale)
    axes.set_xlim(xlim)
    axes.set_ylim(ylim)
    if legend:
        axes.legend(legend)
    axes.grid()

通过这三个用于图形配置的函数,我们定义了plot函数来简洁地绘制多条曲线。

def plot(X, Y=None, xlabel=None, ylabel=None, legend=None, xlim=None,
         ylim=None, xscale='linear', yscale='linear',
         fmts=('-', 'm--', 'g-.', 'r:'), figsize=(3.5, 2.5), axes=None):
    """绘制数据点"""
    if legend is None:
        legend = []

    set_figsize(figsize)
    axes = axes if axes else d2l.plt.gca()

    # 如果X有一个轴,输出True
    def has_one_axis(X):
        return (hasattr(X, "ndim") and X.ndim == 1 or isinstance(X, list)
                and not hasattr(X[0], "__len__"))

    if has_one_axis(X):
        X = [X]
    if Y is None:
        X, Y = [[]] * len(X), X
    elif has_one_axis(Y):
        Y = [Y]
    if len(X) != len(Y):
        X = X * len(Y)
    axes.cla()
    for x, y, fmt in zip(X, Y, fmts):
        if len(x):
            axes.plot(x, y, fmt)
        else:
            axes.plot(y, fmt)
    set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend)

到此,我们定义完了新的matplotlib。我们把其归结为一个类:

import math
import time
import numpy as np
import torch
from d2l import torch as d2l
import matplotlib.pyplot as plt
from matplotlib_inline import backend_inline

class Plot:
    def __init__(self):
        return None
    
    # 使用svg格式显示绘图
    def use_svg_display(self):
        backend_inline.set_matplotlib_formats('svg')

    # 设置matplotlib的图表大小
    def set_figsize(self, figsize=(12,8)):
        self.use_svg_display()
        plt.rcParams['figure.figsize'] = figsize

    # 设置matplotlib的轴
    def set_axes(self, axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
        axes.set_xlabel(xlabel)
        axes.set_ylabel(ylabel)
        axes.set_xlim(xlim)
        axes.set_ylim(ylim)
        axes.set_xscale(xscale)
        axes.set_yscale(yscale)
        if legend:
            axes.legend(legend)
        axes.grid()

    # 定义plot函数来实现绘图功能
    def plot(self, X, Y=None, xlabel=None, ylabel=None, legend=None, xlim=None, ylim=None, xscale='linear', yscale='linear',
             fmts=('-','m--','g-','r:'), figsize=(12,8), axes=None):
        # 绘制数据点
        if legend is None:
            legend = []
        self.set_figsize(figsize)
        axes = axes if axes else d2l.plt.gca()

        # 如果X有一个轴,输出为True
        def has_one_axis(X):
            return (hasattr(X, 'ndim') and X.ndim == 1 or isinstance(X, list) and not hasattr(X[0], '__len__'))
        if has_one_axis(X):
            X=[X]
        if Y is None:
            X,Y=[[]]*len(X),X
        elif has_one_axis(Y):
            Y=[Y]
        if len(X) != len(Y):
            X = X*len(Y)
        axes.cla()
        for x,y,fmt in zip(X, Y, fmts):
            if len(x):
                axes.plot(x, y, fmt)
            else:
                axes.plot(y, fmt)
        self.set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend)

我们举一个例子:

if __name__ == '__main__':
    def demo_1():
        # 计算正态分布
        def normal(x, mu, sigma):
            p = 1/math.sqrt(2*math.pi*sigma**2)
            return p*np.exp(-0.5/sigma**2*(x-mu)**2)
        x = np.arange(-7, 7, 0.01)
        params = [(0,1), (0,2), (3,1)]
        estimator = Plot()
        estimator.plot(x, [normal(x,mu,sigma) for mu,sigma in params], xlabel='x', ylabel='p(x)', figsize=(12, 8), legend=[f'mean {mu},std {sigma}'for mu,sigma in params])
        plt.show()
    demo_1()

结果展示为:
【自定义函数】配置matplotlib生成图形的属性(含源代码)_第1张图片

你可能感兴趣的:(深度学习笔记,matplotlib,python,深度学习)