Matplotlib设置网格线之major和minor
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator
# Per-layer scalar values keyed by layer index; keys 0-5 are treated as the
# encoder layers and keys 6-11 as the decoder layers below.
data_dict = {
    0: 7.088511943817139,
    1: 1.941696047782898,
    2: 4.185360431671143,
    3: 3.784451723098755,
    4: 5.676402568817139,
    5: 6.908511161804199,
    6: 0.0035664942115545273,
    7: 0.00400047842413187,
    8: 0.00456645293161273,
    9: 0.003707129508256912,
    10: 20.32890510559082,
    11: 0.0040948097594082355,
}

# Split into the two halves by explicit key lookup (a missing layer raises
# KeyError instead of being silently skipped).
encoder_dict = {layer: data_dict[layer] for layer in range(0, 6)}
decoder_dict = {layer: data_dict[layer] for layer in range(6, 12)}

# imshow needs 2-D input: wrap each list of values as a single row,
# giving arrays of shape (1, 6).
data1 = np.array([list(encoder_dict.values())])
data2 = np.array([list(decoder_dict.values())])
# Draw the encoder and decoder rows as two stacked one-row heatmaps with
# grid lines between cells, add one shared colorbar, and save as PDF.

# Use one colour scale for both heatmaps. Without explicit vmin/vmax each
# imshow call autoscales to its own data range, and the single colorbar
# (built from im2) would then misrepresent the encoder row's colours.
vmin = min(data1.min(), data2.min())
vmax = max(data1.max(), data2.max())

fig, (ax1, ax2) = plt.subplots(2, 1, sharey=True)
im1 = ax1.imshow(data1, cmap='Oranges', interpolation='nearest', vmin=vmin, vmax=vmax)
im2 = ax2.imshow(data2, cmap='Oranges', interpolation='nearest', vmin=vmin, vmax=vmax)

for ax, data, label in ((ax1, data1, 'Encoder'), (ax2, data2, 'Decoder')):
    ax.set_yticks([])
    # One major tick (and label) per column / layer.
    ax.set_xticks(list(range(data.shape[1])))
    # Minor ticks every 0.5. Matplotlib drops minor ticks that coincide with
    # major ticks, so the remaining minor ticks sit on the half-integer
    # boundaries between cells.
    ax.xaxis.set_minor_locator(MultipleLocator(0.5))
    ax.set_ylabel(label)
    ax.set_xlabel('Layer')
    # Grid lines only at minor ticks, i.e. separators between adjacent cells.
    ax.grid(True, which='minor', axis='both', linestyle='-', color='gray', alpha=0.5)

cbar = fig.colorbar(im2, ax=[ax1, ax2], orientation='horizontal', pad=0.2)

# Save before show(): with interactive backends the figure may be torn down
# once the window is closed, which would leave an empty PDF.
fig.savefig('norm_visualization.pdf', format='pdf', dpi=300)
plt.show()