如何修改以下代码以添加图例?
# Code source: Gae"l Varoquaux
# License: BSD 3 clause
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn import decomposition
from sklearn import datasets
np.random.seed(5)
centers = [[1, 1], [-1, -1], [1, -1]]
iris = datasets.load_iris()
X = iris.data#the floating point values
y = iris.target#unsigned integers specifying group
fig = plt.figure(1, figsize=(4, 3))
plt.clf()
ax = Axes3D(fig, rect=[0, 0, .95, 1], elev=48, azim=134)
plt.cla()
pca = decomposition.PCA(n_components=3)
pca.fit(X)
X = pca.transform(X)
for name, label in [('Setosa', 0), ('Versicolour', 1), ('Virginica', 2)]:
ax.text3D(X[y == label, 0].mean(),
X[y == label, 1].mean() + 1.5,
X[y == label, 2].mean(), name,
horizontalalignment='center',
bbox=dict(alpha=.5, edgecolor='w', facecolor='w'))
# Reorder the labels to have colors matching the cluster results
y = np.choose(y, [1, 2, 0]).astype(np.float)
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=y, cmap=plt.cm.spectral,
edgecolor='k')
ax.w_xaxis.set_ticklabels([])
ax.w_yaxis.set_ticklabels([])
ax.w_zaxis.set_ticklabels([])
plt.show()
解决方法:
另一个答案中存在一些问题,OP和回答者似乎都不清楚;因此,这不是一个完整的答案,而是现有答案的附录.
> 2.2版中的matplotlib已删除光谱色图,
使用Spectral或nipy_spectral或任何other valid colormap.
> matplotlib中的任何色彩映射范围都是0到1.如果使用该范围之外的任何值调用它,
它会给你最出色的颜色.要从色图中获取颜色,您需要对值进行标准化.
这是通过Normalize实例完成的.在这种情况下,这是分散的内部.
因此,使用sc = ax.scatter(…)然后使用sc.cmap(sc.norm(value))根据散点图中使用的相同映射获取值.
因此代码应该使用
[sc.cmap(sc.norm(i)) for i in [1, 2, 0]]
>传说不在图中.这个数字是4 x 3英寸(figsize =(4,3)).
轴占据宽度的95%(rect = [0,0,.95,1]).
对图例的调用将图例的右中心点置于轴宽度的1.7倍= 4 * 0.95 * 1.7 = 6.46英寸. (bbox_to_anchor =(1.7,0.5)).
来自我方的替代建议:使图形更大(figsize =(5.5,3)),使图例适合,使轴只占图形宽度的70%,这样你就有30%的图形.将图例的左侧放置在靠近轴边界的位置(bbox_to_anchor =(1.0,.5)).
你仍然可以看到包括jupyter笔记本中的图例在内的完整图形的原因是jupyter会将所有内容保存在画布中,即使它重叠并因此放大了图形.
总的来说,代码可能看起来像
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np; np.random.seed(5)
from sklearn import decomposition, datasets
centers = [[1, 1], [-1, -1], [1, -1]]
iris = datasets.load_iris()
X = iris.data #the floating point values
y = iris.target #unsigned integers specifying group
fig = plt.figure(figsize=(5.5, 3))
ax = Axes3D(fig, rect=[0, 0, .7, 1], elev=48, azim=134)
pca = decomposition.PCA(n_components=3)
pca.fit(X)
X = pca.transform(X)
labelTups = [('Setosa', 0), ('Versicolour', 1), ('Virginica', 2)]
for name, label in labelTups:
ax.text3D(X[y == label, 0].mean(),
X[y == label, 1].mean() + 1.5,
X[y == label, 2].mean(), name,
horizontalalignment='center',
bbox=dict(alpha=.5, edgecolor='w', facecolor='w'))
# Reorder the labels to have colors matching the cluster results
y = np.choose(y, [1, 2, 0]).astype(np.float)
sc = ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=y, cmap="Spectral", edgecolor='k')
ax.w_xaxis.set_ticklabels([])
ax.w_yaxis.set_ticklabels([])
ax.w_zaxis.set_ticklabels([])
colors = [sc.cmap(sc.norm(i)) for i in [1, 2, 0]]
custom_lines = [plt.Line2D([],[], ls="", marker='.',
mec='k', mfc=c, mew=.1, ms=20) for c in colors]
ax.legend(custom_lines, [lt[0] for lt in labelTups],
loc='center left', bbox_to_anchor=(1.0, .5))
plt.show()
并生产
标签:python,scikit-learn,matplotlib
来源: https://codeday.me/bug/20190828/1747897.html