# Matplotlib设置网格线之major和minor
# (Matplotlib: setting major and minor grid lines)

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator
# Per-layer values in layer order; enumerate() supplies the integer keys 0..11,
# giving the same mapping as an explicit {index: value} literal.
data_dict = dict(enumerate([
    7.088511943817139,
    1.941696047782898,
    4.185360431671143,
    3.784451723098755,
    5.676402568817139,
    6.908511161804199,
    0.0035664942115545273,
    0.00400047842413187,
    0.00456645293161273,
    0.003707129508256912,
    20.32890510559082,
    0.0040948097594082355,
]))

# Layers 0-5 form the encoder row, layers 6-11 the decoder row.
encoder_dict = {layer: value for layer, value in data_dict.items() if layer < 6}
decoder_dict = {layer: value for layer, value in data_dict.items() if 6 <= layer < 12}

# imshow needs 2-D input, so each group becomes a single-row (1, 6) matrix.
data1 = np.fromiter(encoder_dict.values(), dtype=float)[np.newaxis, :]
data2 = np.fromiter(decoder_dict.values(), dtype=float)[np.newaxis, :]

# Create a figure with two vertically stacked axes: one heatmap row for the
# encoder values (data1) and one for the decoder values (data2).
fig, (ax1, ax2) = plt.subplots(2, 1, sharey=True)

# Both heatmaps share ONE colorbar (added below), so they must also share one
# colour scale. Without explicit vmin/vmax, imshow normalises each image to its
# own min/max, and a colorbar built from a single image would only be correct
# for that image's row.
vmin = min(data1.min(), data2.min())
vmax = max(data1.max(), data2.max())

im1 = ax1.imshow(data1, cmap='Oranges', interpolation='nearest', vmin=vmin, vmax=vmax)
im2 = ax2.imshow(data2, cmap='Oranges', interpolation='nearest', vmin=vmin, vmax=vmax)

for ax, row_data, row_label in ((ax1, data1, 'Encoder'), (ax2, data2, 'Decoder')):
    # Each heatmap is a single row, so y ticks carry no information.
    ax.set_yticks([])
    # Major ticks on the cell centres, one per column (layer index).
    ax.set_xticks(np.arange(row_data.shape[1]))
    # Minor ticks every 0.5. Matplotlib drops minor ticks that coincide with
    # major ticks (Axis.remove_overlapping_locs), so the surviving minor ticks
    # sit on the cell borders (k +/- 0.5).
    ax.xaxis.set_minor_locator(MultipleLocator(0.5))
    ax.set_ylabel(row_label)
    ax.set_xlabel('Layer')
    # Grid only on the minor (cell-border) ticks so the lines separate cells
    # instead of cutting through their centres.
    ax.grid(True, which='minor', axis='both', linestyle='-', color='gray', alpha=0.5)

# Add a single horizontal colorbar below both subplots. Both images use the
# same vmin/vmax, so this colorbar is valid for both rows.
cbar = fig.colorbar(im2, ax=[ax1, ax2], orientation='horizontal', pad=0.2)

# Save the figure as a PDF BEFORE showing it. Pyplot's documentation
# recommends savefig before show: a blocking show() closes and unregisters the
# figure when its window is closed. Saving first also guarantees the file is
# written even if the interactive session is interrupted.
fig.savefig('norm_visualization.pdf', format='pdf', dpi=300)

# Display the figure.
plt.show()

# Related tags (你可能感兴趣的): python, matplotlib