import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
class ScatterPlot:
    """Scatter plot of ``y`` against ``x`` drawn with matplotlib.

    Each instance keeps a reference to the figure it draws, so :meth:`save`
    writes that figure even after :meth:`draw` has shown it.
    """

    def __init__(self, x, y, title=None, xlabel=None, ylabel=None):
        """Store the data and the optional text labels.

        Args:
            x: x coordinates, passed unchanged to ``Axes.scatter``.
            y: y coordinates, passed unchanged to ``Axes.scatter``.
            title: optional axes title (skipped when falsy).
            xlabel: optional x-axis label (skipped when falsy).
            ylabel: optional y-axis label (skipped when falsy).
        """
        self.x = x
        self.y = y
        self.title = title
        self.xlabel = xlabel
        self.ylabel = ylabel
        # Figure created by draw(); None until the plot has been rendered.
        self.fig = None

    def draw(self, show=True):
        """Render the scatter plot into a new figure.

        Args:
            show: when True (the default, matching the previous behaviour),
                call ``plt.show()`` after drawing.
        """
        self.fig, ax = plt.subplots()
        ax.scatter(self.x, self.y)
        self._set_labels(ax)
        if show:
            plt.show()

    def _set_labels(self, ax=None):
        """Apply whichever title and axis labels were provided.

        Args:
            ax: axes to label; defaults to pyplot's current axes.
        """
        if ax is None:
            ax = plt.gca()
        if self.xlabel:
            ax.set_xlabel(self.xlabel)
        if self.ylabel:
            ax.set_ylabel(self.ylabel)
        if self.title:
            ax.set_title(self.title)

    def save(self, path, format='png'):
        """Write the plot to ``path`` in the given ``format``.

        The figure created by draw() is saved directly instead of through
        ``plt.savefig``. Once the window opened by ``plt.show()`` is closed,
        pyplot no longer tracks that figure, so ``plt.savefig`` would create
        and save a new, empty figure. If draw() has not been called yet, the
        plot is rendered first without being shown.
        """
        if self.fig is None:
            self.draw(show=False)
        self.fig.savefig(path, format=format)
class LinePlot:
    """Line plot of ``y`` against ``x`` drawn with matplotlib.

    Each instance keeps a reference to the figure it draws, so :meth:`save`
    writes that figure even after :meth:`draw` has shown it.
    """

    def __init__(self, x, y, title=None, xlabel=None, ylabel=None):
        """Store the data and the optional text labels.

        Args:
            x: x coordinates, passed unchanged to ``Axes.plot``.
            y: y coordinates, passed unchanged to ``Axes.plot``.
            title: optional axes title (skipped when falsy).
            xlabel: optional x-axis label (skipped when falsy).
            ylabel: optional y-axis label (skipped when falsy).
        """
        self.x = x
        self.y = y
        self.title = title
        self.xlabel = xlabel
        self.ylabel = ylabel
        # Figure created by draw(); None until the plot has been rendered.
        self.fig = None

    def draw(self, show=True):
        """Render the line plot into a new figure.

        Args:
            show: when True (the default, matching the previous behaviour),
                call ``plt.show()`` after drawing.
        """
        self.fig, ax = plt.subplots()
        ax.plot(self.x, self.y)
        self._set_labels(ax)
        if show:
            plt.show()

    def _set_labels(self, ax=None):
        """Apply whichever title and axis labels were provided.

        Args:
            ax: axes to label; defaults to pyplot's current axes.
        """
        if ax is None:
            ax = plt.gca()
        if self.xlabel:
            ax.set_xlabel(self.xlabel)
        if self.ylabel:
            ax.set_ylabel(self.ylabel)
        if self.title:
            ax.set_title(self.title)

    def save(self, path, format='png'):
        """Write the plot to ``path`` in the given ``format``.

        The figure created by draw() is saved directly instead of through
        ``plt.savefig``. Once the window opened by ``plt.show()`` is closed,
        pyplot no longer tracks that figure, so ``plt.savefig`` would create
        and save a new, empty figure. If draw() has not been called yet, the
        plot is rendered first without being shown.
        """
        if self.fig is None:
            self.draw(show=False)
        self.fig.savefig(path, format=format)
class Heatmap:
    """Heatmap of a 2-D array drawn with ``Axes.imshow``.

    Each instance keeps a reference to the figure it draws, so :meth:`save`
    writes that figure even after :meth:`draw` has shown it.
    """

    def __init__(self, data, title=None, xlabel=None, ylabel=None, cmap='hot'):
        """Store the data, the optional text labels and the colormap.

        Args:
            data: array-like image data, passed unchanged to ``Axes.imshow``.
            title: optional axes title (skipped when falsy).
            xlabel: optional x-axis label (skipped when falsy).
            ylabel: optional y-axis label (skipped when falsy).
            cmap: matplotlib colormap name or object; defaults to ``'hot'``,
                the colormap previously hard-coded in draw().
        """
        self.data = data
        self.title = title
        self.xlabel = xlabel
        self.ylabel = ylabel
        self.cmap = cmap
        # Figure created by draw(); None until the plot has been rendered.
        self.fig = None

    def draw(self, show=True):
        """Render the heatmap into a new figure.

        Args:
            show: when True (the default, matching the previous behaviour),
                call ``plt.show()`` after drawing.
        """
        self.fig, ax = plt.subplots()
        # 'nearest' keeps each cell a solid block instead of blending neighbours.
        ax.imshow(self.data, cmap=self.cmap, interpolation='nearest')
        self._set_labels(ax)
        if show:
            plt.show()

    def _set_labels(self, ax=None):
        """Apply whichever title and axis labels were provided.

        Args:
            ax: axes to label; defaults to pyplot's current axes.
        """
        if ax is None:
            ax = plt.gca()
        if self.xlabel:
            ax.set_xlabel(self.xlabel)
        if self.ylabel:
            ax.set_ylabel(self.ylabel)
        if self.title:
            ax.set_title(self.title)

    def save(self, path, format='png'):
        """Write the heatmap to ``path`` in the given ``format``.

        The figure created by draw() is saved directly instead of through
        ``plt.savefig``. Once the window opened by ``plt.show()`` is closed,
        pyplot no longer tracks that figure, so ``plt.savefig`` would create
        and save a new, empty figure. If draw() has not been called yet, the
        plot is rendered first without being shown.
        """
        if self.fig is None:
            self.draw(show=False)
        self.fig.savefig(path, format=format)
class SurfacePlot:
    """3-D surface plot drawn with ``Axes3D.plot_surface``.

    Each instance keeps a reference to the figure it draws, so :meth:`save`
    writes that figure even after :meth:`draw` has shown it.
    """

    def __init__(self, x, y, z, title=None, xlabel=None, ylabel=None, zlabel=None):
        """Store the surface data and the optional text labels.

        Args:
            x, y, z: coordinate arrays passed unchanged to ``plot_surface``,
                which expects 2-D arrays of matching shape (e.g. built with
                ``np.meshgrid``).
            title: optional axes title (skipped when falsy).
            xlabel: optional x-axis label (skipped when falsy).
            ylabel: optional y-axis label (skipped when falsy).
            zlabel: optional z-axis label (skipped when falsy).
        """
        self.x = x
        self.y = y
        self.z = z
        self.title = title
        self.xlabel = xlabel
        self.ylabel = ylabel
        self.zlabel = zlabel
        # Figure created by draw(); None until the plot has been rendered.
        self.fig = None

    def draw(self, show=True):
        """Render the surface into a new figure with a 3-D axes.

        Args:
            show: when True (the default, matching the previous behaviour),
                call ``plt.show()`` after drawing.
        """
        self.fig = plt.figure()
        ax = self.fig.add_subplot(111, projection='3d')
        ax.plot_surface(self.x, self.y, self.z)
        self._set_labels(ax)
        if show:
            plt.show()

    def _set_labels(self, ax):
        """Apply whichever title and axis labels were provided to ``ax``."""
        if self.xlabel:
            ax.set_xlabel(self.xlabel)
        if self.ylabel:
            ax.set_ylabel(self.ylabel)
        if self.zlabel:
            ax.set_zlabel(self.zlabel)
        if self.title:
            # Set the title on the given axes rather than via pyplot's
            # implicit "current axes" state.
            ax.set_title(self.title)

    def save(self, path, format='png'):
        """Write the surface plot to ``path`` in the given ``format``.

        The figure created by draw() is saved directly instead of through
        ``plt.savefig``. Once the window opened by ``plt.show()`` is closed,
        pyplot no longer tracks that figure, so ``plt.savefig`` would create
        and save a new, empty figure. If draw() has not been called yet, the
        plot is rendered first without being shown.
        """
        if self.fig is None:
            self.draw(show=False)
        self.fig.savefig(path, format=format)
def main():
    """Demonstrate each chart class: draw it on screen, then save it as a PNG.

    NOTE(review): matplotlib's default font may lack CJK glyphs, so the
    Chinese titles/labels below can render as empty boxes; configure a
    CJK-capable font via ``plt.rcParams`` if that happens.
    """
    # Example usage:
    # Scatter plot
    x = np.random.rand(50)
    y = np.random.rand(50)
    scatter = ScatterPlot(x, y, title='散点图', xlabel='X轴', ylabel='Y轴')
    scatter.draw()
    scatter.save('scatter_plot.png')

    # Line plot
    x = np.linspace(0, 2 * np.pi, 100)
    y = np.sin(x)
    line = LinePlot(x, y, title='正弦波', xlabel='X轴', ylabel='Y轴')
    line.draw()
    line.save('line_plot.png')

    # Heatmap
    data = np.random.rand(10, 10)
    heatmap = Heatmap(data, title='热力图')
    heatmap.draw()
    heatmap.save('heatmap.png')

    # 3-D surface plot (plot_surface needs 2-D grids, hence meshgrid)
    x = np.linspace(-5, 5, 50)
    y = np.linspace(-5, 5, 50)
    x, y = np.meshgrid(x, y)
    z = np.sin(np.sqrt(x**2 + y**2))
    surface = SurfacePlot(x, y, z, title='3D面图', xlabel='X轴', ylabel='Y轴', zlabel='Z轴')
    surface.draw()
    surface.save('surface_plot.png')


# Run the demo only when executed as a script, so importing this module does
# not open plot windows or write image files.
if __name__ == "__main__":
    main()
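# This code wraps each chart type in its own class, which makes it easier to understand and maintain. Every class provides drawing, label-setting and saving.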