Python绘制柱状图

import matplotlib.pyplot as plt

# Draws a 2x2 grid of bar charts comparing networks on four metrics
# (mIoU, parameter count, FLOPs, time), labels every bar with its value,
# saves the figure to 'network_comparison.png' and then shows it.

# Experimental data: one value per network for each metric.
# NOTE(review): the 1..11 sequences look like placeholders - replace with real results.
networks = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K']
mIoU = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
params = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
flops = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
time = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

# The bundled seaborn style was renamed to 'seaborn-v0_8' in matplotlib 3.6
# and the old 'seaborn' name was removed in 3.8, so pick whichever name this
# installation provides instead of hard-coding one that may raise OSError.
style_name = next(
    (name for name in ('seaborn-v0_8', 'seaborn') if name in plt.style.available),
    'default',
)
plt.style.use(style_name)

# Custom font. rc settings only affect text created afterwards, so this has
# to run before the figure, titles and labels exist (and after style.use,
# which would otherwise overwrite it).
font = {'family': 'sans-serif', 'weight': 'normal', 'size': 12}
plt.rc('font', **font)

# Create the subplots and give the figure a light grey background.
fig, axs = plt.subplots(2, 2, figsize=(12, 8))
fig.patch.set_facecolor('#f2f2f2')

# Bar colour palette (these ten hex values match matplotlib's 'tab10' cycle).
sci_nature_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
                     '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
# There are more networks than palette entries; cycle the palette explicitly
# so the colour reuse (the 11th bar gets the 1st colour) is visible in code.
bar_colors = [sci_nature_colors[i % len(sci_nature_colors)]
              for i in range(len(networks))]

# One (values, title, y-label) entry per subplot, in axs.flat order:
# top-left, top-right, bottom-left, bottom-right.
panels = [
    (mIoU, 'mIoU Comparison', 'mIoU (%)'),
    (params, 'Parameter Comparison', 'Params (M)'),
    (flops, 'FLOPs Comparison', 'FLOPs (G)'),
    (time, 'Time Comparison', 'Time (s)'),
]

for ax, (values, title, ylabel) in zip(axs.flat, panels):
    container = ax.bar(networks, values, color=bar_colors)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.tick_params(axis='x', rotation=45)

    # Print each bar's value just above it (3 points of vertical padding).
    ax.bar_label(container, fmt='%.3f', padding=3)

    # Grid lines: solid major lines on both axes (slightly transparent on y),
    # dashed and fainter minor lines. Minor ticks must be on for the minor
    # grid to have anything to draw.
    ax.minorticks_on()
    ax.grid(visible=True, which='major', linestyle='-')
    ax.grid(axis='y', which='major', alpha=0.7)
    ax.grid(visible=True, which='minor', linestyle='--', alpha=0.5)

# Adjust spacing only once every artist is in place, so rotated tick labels
# and bar labels are taken into account.
plt.tight_layout()

# Save the image file. Passing the facecolor explicitly keeps the grey
# background even when rcParams['savefig.facecolor'] is not 'auto'.
plt.savefig('network_comparison.png', dpi=300, bbox_inches='tight',
            facecolor=fig.get_facecolor())

# Show the chart.
plt.show()

你可能感兴趣的:(工具代码,python,matplotlib)