Python 绘制基础图表

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

class ScatterPlot:
    """Scatter plot of ``y`` against ``x`` with optional title and axis labels.

    Every call to :meth:`draw` renders into a brand-new figure, so successive
    plots never overlay each other. The figure object is remembered so that
    :meth:`save` still writes the drawn plot after ``plt.show()`` returns.
    """

    def __init__(self, x, y, title=None, xlabel=None, ylabel=None):
        """Store the data and optional labels; nothing is drawn yet.

        Args:
            x: x coordinates, passed to ``Axes.scatter``.
            y: y coordinates, passed to ``Axes.scatter``.
            title: Optional plot title; skipped when falsy.
            xlabel: Optional x-axis label; skipped when falsy.
            ylabel: Optional y-axis label; skipped when falsy.
        """
        self.x = x
        self.y = y
        self.title = title
        self.xlabel = xlabel
        self.ylabel = ylabel
        # Figure produced by the most recent draw(); reused by save().
        self._fig = None

    def _render(self):
        """Create a new figure containing the scatter plot and return it."""
        fig, ax = plt.subplots()
        ax.scatter(self.x, self.y)
        self._set_labels(ax)
        return fig

    def draw(self):
        """Render the plot into a new figure and display it via ``plt.show()``."""
        self._fig = self._render()
        plt.show()

    def _set_labels(self, ax=None):
        """Apply the non-empty title/axis labels to ``ax`` (current axes if None)."""
        if ax is None:
            ax = plt.gca()
        if self.xlabel:
            ax.set_xlabel(self.xlabel)
        if self.ylabel:
            ax.set_ylabel(self.ylabel)
        if self.title:
            ax.set_title(self.title)

    def save(self, path, format='png'):
        """Write the plot to ``path`` in the given image ``format``.

        Uses the figure from the last :meth:`draw`; if :meth:`draw` was never
        called, the plot is rendered off-screen, saved, and then closed.
        """
        if self._fig is not None:
            # Save the figure object itself: once plt.show() has returned, the
            # pyplot "current figure" is typically a new, empty one.
            self._fig.savefig(path, format=format)
            return
        fig = self._render()
        try:
            fig.savefig(path, format=format)
        finally:
            # Don't leave a hidden figure registered with pyplot.
            plt.close(fig)


class LinePlot:
    """Line plot of ``y`` against ``x`` with optional title and axis labels.

    Every call to :meth:`draw` renders into a brand-new figure, so successive
    plots never overlay each other. The figure object is remembered so that
    :meth:`save` still writes the drawn plot after ``plt.show()`` returns.
    """

    def __init__(self, x, y, title=None, xlabel=None, ylabel=None):
        """Store the data and optional labels; nothing is drawn yet.

        Args:
            x: x coordinates, passed to ``Axes.plot``.
            y: y coordinates, passed to ``Axes.plot``.
            title: Optional plot title; skipped when falsy.
            xlabel: Optional x-axis label; skipped when falsy.
            ylabel: Optional y-axis label; skipped when falsy.
        """
        self.x = x
        self.y = y
        self.title = title
        self.xlabel = xlabel
        self.ylabel = ylabel
        # Figure produced by the most recent draw(); reused by save().
        self._fig = None

    def _render(self):
        """Create a new figure containing the line plot and return it."""
        fig, ax = plt.subplots()
        ax.plot(self.x, self.y)
        self._set_labels(ax)
        return fig

    def draw(self):
        """Render the plot into a new figure and display it via ``plt.show()``."""
        self._fig = self._render()
        plt.show()

    def _set_labels(self, ax=None):
        """Apply the non-empty title/axis labels to ``ax`` (current axes if None)."""
        if ax is None:
            ax = plt.gca()
        if self.xlabel:
            ax.set_xlabel(self.xlabel)
        if self.ylabel:
            ax.set_ylabel(self.ylabel)
        if self.title:
            ax.set_title(self.title)

    def save(self, path, format='png'):
        """Write the plot to ``path`` in the given image ``format``.

        Uses the figure from the last :meth:`draw`; if :meth:`draw` was never
        called, the plot is rendered off-screen, saved, and then closed.
        """
        if self._fig is not None:
            # Save the figure object itself: once plt.show() has returned, the
            # pyplot "current figure" is typically a new, empty one.
            self._fig.savefig(path, format=format)
            return
        fig = self._render()
        try:
            fig.savefig(path, format=format)
        finally:
            # Don't leave a hidden figure registered with pyplot.
            plt.close(fig)


class Heatmap:
    """Heatmap of a 2-D array rendered with ``imshow``.

    Every call to :meth:`draw` renders into a brand-new figure, so successive
    plots never overlay each other. The figure object is remembered so that
    :meth:`save` still writes the drawn plot after ``plt.show()`` returns.
    """

    def __init__(self, data, title=None, xlabel=None, ylabel=None, cmap='hot'):
        """Store the data, optional labels and colormap; nothing is drawn yet.

        Args:
            data: Array-like passed to ``Axes.imshow``.
            title: Optional plot title; skipped when falsy.
            xlabel: Optional x-axis label; skipped when falsy.
            ylabel: Optional y-axis label; skipped when falsy.
            cmap: Matplotlib colormap name or object (default ``'hot'``).
        """
        self.data = data
        self.title = title
        self.xlabel = xlabel
        self.ylabel = ylabel
        self.cmap = cmap
        # Figure produced by the most recent draw(); reused by save().
        self._fig = None

    def _render(self):
        """Create a new figure containing the heatmap and return it."""
        fig, ax = plt.subplots()
        # 'nearest' keeps each cell a crisp block instead of smoothing cells together.
        ax.imshow(self.data, cmap=self.cmap, interpolation='nearest')
        self._set_labels(ax)
        return fig

    def draw(self):
        """Render the heatmap into a new figure and display it via ``plt.show()``."""
        self._fig = self._render()
        plt.show()

    def _set_labels(self, ax=None):
        """Apply the non-empty title/axis labels to ``ax`` (current axes if None)."""
        if ax is None:
            ax = plt.gca()
        if self.xlabel:
            ax.set_xlabel(self.xlabel)
        if self.ylabel:
            ax.set_ylabel(self.ylabel)
        if self.title:
            ax.set_title(self.title)

    def save(self, path, format='png'):
        """Write the heatmap to ``path`` in the given image ``format``.

        Uses the figure from the last :meth:`draw`; if :meth:`draw` was never
        called, the heatmap is rendered off-screen, saved, and then closed.
        """
        if self._fig is not None:
            # Save the figure object itself: once plt.show() has returned, the
            # pyplot "current figure" is typically a new, empty one.
            self._fig.savefig(path, format=format)
            return
        fig = self._render()
        try:
            fig.savefig(path, format=format)
        finally:
            # Don't leave a hidden figure registered with pyplot.
            plt.close(fig)


class SurfacePlot:
    """3-D surface plot of ``z`` over the grid given by ``x`` and ``y``.

    Every call to :meth:`draw` renders into a brand-new figure. The figure
    object is remembered so that :meth:`save` still writes the drawn plot
    after ``plt.show()`` returns.
    """

    def __init__(self, x, y, z, title=None, xlabel=None, ylabel=None, zlabel=None):
        """Store the grid data and optional labels; nothing is drawn yet.

        Args:
            x, y, z: Arrays passed to ``Axes3D.plot_surface`` (the example
                builds ``x``/``y`` with ``np.meshgrid``).
            title: Optional plot title; skipped when falsy.
            xlabel, ylabel, zlabel: Optional axis labels; skipped when falsy.
        """
        self.x = x
        self.y = y
        self.z = z
        self.title = title
        self.xlabel = xlabel
        self.ylabel = ylabel
        self.zlabel = zlabel
        # Figure produced by the most recent draw(); reused by save().
        self._fig = None

    def _render(self):
        """Create a new figure with a 3-D axes holding the surface and return it."""
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.plot_surface(self.x, self.y, self.z)
        self._set_labels(ax)
        return fig

    def draw(self):
        """Render the surface into a new figure and display it via ``plt.show()``."""
        self._fig = self._render()
        plt.show()

    def _set_labels(self, ax):
        """Apply the non-empty title and x/y/z labels to the 3-D axes ``ax``."""
        if self.xlabel:
            ax.set_xlabel(self.xlabel)
        if self.ylabel:
            ax.set_ylabel(self.ylabel)
        if self.zlabel:
            ax.set_zlabel(self.zlabel)
        if self.title:
            ax.set_title(self.title)

    def save(self, path, format='png'):
        """Write the surface plot to ``path`` in the given image ``format``.

        Uses the figure from the last :meth:`draw`; if :meth:`draw` was never
        called, the plot is rendered off-screen, saved, and then closed.
        """
        if self._fig is not None:
            # Save the figure object itself: once plt.show() has returned, the
            # pyplot "current figure" is typically a new, empty one.
            self._fig.savefig(path, format=format)
            return
        fig = self._render()
        try:
            fig.savefig(path, format=format)
        finally:
            # Don't leave a hidden figure registered with pyplot.
            plt.close(fig)

# Example usage. Guarded so that importing this module does not open plot
# windows or write image files as a side effect.
if __name__ == "__main__":
    # The titles/labels below are Chinese. Matplotlib's default font
    # (DejaVu Sans) has no CJK glyphs and would draw them as empty boxes, so
    # put common CJK fonts ahead of the defaults; an installed one is used.
    plt.rcParams['font.sans-serif'] = [
        'SimHei', 'Microsoft YaHei', 'Noto Sans CJK SC', 'WenQuanYi Zen Hei',
    ] + list(plt.rcParams['font.sans-serif'])
    # CJK fonts often lack the Unicode minus sign used on negative tick labels.
    plt.rcParams['axes.unicode_minus'] = False

    # Scatter plot
    x = np.random.rand(50)
    y = np.random.rand(50)
    scatter = ScatterPlot(x, y, title='散点图', xlabel='X轴', ylabel='Y轴')
    scatter.draw()
    scatter.save('scatter_plot.png')

    # Line plot
    x = np.linspace(0, 2 * np.pi, 100)
    y = np.sin(x)
    line = LinePlot(x, y, title='正弦波', xlabel='X轴', ylabel='Y轴')
    line.draw()
    line.save('line_plot.png')

    # Heatmap
    data = np.random.rand(10, 10)
    heatmap = Heatmap(data, title='热力图')
    heatmap.draw()
    heatmap.save('heatmap.png')

    # 3-D surface plot
    x = np.linspace(-5, 5, 50)
    y = np.linspace(-5, 5, 50)
    x, y = np.meshgrid(x, y)
    z = np.sin(np.sqrt(x**2 + y**2))
    surface = SurfacePlot(x, y, z, title='3D面图', xlabel='X轴', ylabel='Y轴', zlabel='Z轴')
    surface.draw()
    surface.save('surface_plot.png')

这段代码将不同类型的图表分别封装到独立的类中，使代码更易于理解和维护。每个类都提供了绘制、设置标签和保存图像的功能。

你可能感兴趣的：(python, 开发语言)