如下
import numpy as np
import matplotlib
import matplotlib.ticker as ticker
from matplotlib.colors import ListedColormap, BoundaryNorm
#随机生成mask,像素值为0、1、2
mask=np.random.randint(0,3,[255,255])
#mask中像素的种类数
unique_values = np.unique(mask).size
#选择一种cmap
cmap = matplotlib.colormaps.get_cmap('cool')
#选取n种颜色
cmap = ListedColormap(cmap(np.linspace(0, 255, unique_values).astype(np.uint8)))
# 创建边界规范。bounds 指定了颜色分界线在哪里
bounds = np.array(list((np.unique(mask)))+[unique_values]) #需要+1,否则少一种颜色
norm = BoundaryNorm(bounds, cmap.N)
# 绘制颜色条
cb = plt.colorbar(
plt.cm.ScalarMappable(cmap=cmap, norm=norm),
ticks=bounds+0.5, #ticks指定了刻度标在哪里,+0.5是为了使刻度标在颜色中间
boundaries=bounds,
orientation='vertical',
format='%d'
)
plt.imshow(mask, cmap=cmap, norm=norm)
plt.show()