python自带的scatter函数参数中颜色和大小可以输入列表进行控制,即可以让不同的点有不同的颜色和大小,但是只能是同一种形状。例如例一:
import numpy as np
import matplotlib.pyplot as plt
def plotMatrixPoint(Mat, Label):
"""
:param Mat: 二维点坐标矩阵
:param Label: 点的类别标签
:return:
"""
x = Mat[:, 0]
y = Mat[:, 1]
map_size = {-1: 50, 1: 100}
size = list(map(lambda x: map_size[x], Label))
map_color = {-1: 'r', 1: 'g'}
color = list(map(lambda x: map_color[x], Label))
map_marker = {-1: 'o', 1: 'v'}
markers = list(map(lambda x: map_marker[x], Label))
# 下面一行代码会出错,因为marker参数不支持列表
# plt.scatter(np.array(x), np.array(y), s=size, c=color, marker=markers)
# 下面一行代码为修正过的代码
plt.scatter(np.array(x), np.array(y), s=size, c=color, marker='o') # scatter函数只支持array类型数据
plt.show()
def loadSimpData():
datMat = np.matrix([[1., 2.1], [1.5, 1.6], [1.3, 1.], [1., 1.], [2., 1.]])
classLabels = [1.0, 1.0, -1.0, -1.0, 1.0]
plotMatrixPoint(datMat, classLabels)
return datMat, classLabels
if __name__ == "__main__":
loadSimpData()
下面的代码的mscatter函数 颜色、大小、形状参数都能接受列表,如例二代码:
import numpy as np
import matplotlib.pyplot as plt
def mscatter(x, y, ax=None, m=None, **kw):
import matplotlib.markers as mmarkers
if not ax: ax = plt.gca()
sc = ax.scatter(x, y, **kw)
if (m is not None) and (len(m) == len(x)):
paths = []
for marker in m:
if isinstance(marker, mmarkers.MarkerStyle):
marker_obj = marker
else:
marker_obj = mmarkers.MarkerStyle(marker)
path = marker_obj.get_path().transformed(
marker_obj.get_transform())
paths.append(path)
sc.set_paths(paths)
return sc
def plotMatrixPoint(Mat, Label):
"""
输入二维点矩阵和标签,能够改变不同形状
:param Mat:
:param Label:
:return:
"""
x = Mat[:, 0]
y = Mat[:, 1]
map_size = {-1: 50, 1: 100}
size = list(map(lambda x: map_size[x], Label))
map_color = {-1: 'r', 1: 'g'}
color = list(map(lambda x: map_color[x], Label))
map_marker = {-1: 'o', 1: 's'}
markers = list(map(lambda x: map_marker[x], Label))
mscatter(np.array(x), np.array(y), s=size, c=color, m=markers) # scatter函数只支持array类型数据
plt.show()
def loadSimpData():
datMat = np.matrix([[1., 2.1], [1.5, 1.6], [1.3, 1.], [1., 1.], [2., 1.]])
classLabels = [1.0, 1.0, -1.0, -1.0, 1.0]
plotMatrixPoint(datMat, classLabels)
return datMat, classLabels
if __name__ == "__main__":
loadSimpData()