https://blog.csdn.net/u013733326/article/details/79847918
np.meshgrid 生成网格点坐标
二维坐标系中,X轴可以取三个值1,2,3, Y轴可以取三个值4,5,
(1,4)(2,4)(3,4)
(1,5)(2,5)(3,5)
返回list,有两个元素,第一个元素是X轴的取值,第二个元素是Y轴的取值
返回结果: [array([ [1,2,3] [1,2,3] ]), array([ [4,4,4] [5,5,5] ])]
xx, yy= np.meshgrid(a,b)
"""
xx [1 2 3 1 2 3]
yy [4 4 4 5 5 5]
z [[1 4]
[2 4]
[3 4]
[1 5]
[2 5]
[3 5]]
"""
def plot_decision_boundary(model,x,y):
x_min, x_max = x[0,:].min() - 1, x[0,:].max() +1 # 第一个数据中最大最小值
y_min, y_max = x[1,:].min() - 1, x[1,:].max() +1 # 第二个数据中最大最小值
h = 0.01
xx,yy = np.meshgrid(np.arange(x_min,x_max,h),np.arange(y_min,y_max,h))
z = model(np.c_[xx.ravel(),yy.raverl()])
z = z.reshape(xx.shape)
plt.contourf(xx, yy, z,cmap=plt.cm.Spectral)
plt.ylabel('x2')
plt.xlabel('x1')
plt.scatter(x[0, :], x[1, :], c=np.squeeze(y), cmap=plt.cm.Spectral)
plt.show()
https://blog.csdn.net/lens___/article/details/83960810
绘制等高线的,contour和contourf都是画三维等高线图的,不同点在于contour() 是绘制轮廓线,
contourf()会填充轮廓。除非另有说明,否则两个版本的函数是相同的。
参数: X,Y:类似数组,可选为Z中的坐标值
当 X,Y,Z 都是 2 维数组时,它们的形状必须相同。如果都是 1 维数组时,
len(X)是 Z 的列数,而 len(Y) 是 Z 中的行数。(例如,经由创建numpy.meshgrid())
Z:类似矩阵绘制轮廓的高度值
levels:int或类似数组,可选确定轮廓线/区域的数量和位置
其他参数: aalpha:float ,可选
alpha混合值,介于0(透明)和1(不透明)之间。
cmap:str或colormap ,可选
Colormap用于将数据值(浮点数)从间隔转换为相应Colormap表示的RGBA颜 色。用于将数据缩放到间隔中看 。
x = np.array([1,2,3])
y = np.array([4,5])
xx,yy = np.meshgrid(x,y)
z = xx**3 + yy**3
# z.reshape(xx.shape)
print('xx',xx.ravel())
print('yy',yy.ravel())
print('z',z)
plt.contourf(xx, yy, z)
plt.show()
绘制预测和实际不同的图像
def print_mislabeled_images(classes, X, Y, p):
"""
绘制预测和实际不同的图像
:param classes: 种类
:param X: 数据集
:param Y: 标签
:param p: 预测
:return:
"""
a = p + Y # 预测 和 实际 概率 累加 ,其中预测错的和为1
print("a.shape", a.shape)
mislabeled_indices = np.asarray(np.where(a == 1))
print("mislabeled_indices", mislabeled_indices)
# mislabeled_indices
# [[ 0 0 0 0 0 0 0 0 0 0 0]
# [ 5 6 13 19 28 29 34 44 45 46 48]]
# 原矩阵a是个一维矩阵(行数只有一行)
# 第一行的[ 0 0 0 0 0 0 0 0 0 0 0]代表的是对应点的行下标
# 第二行的[ 5 6 13 19 28 29 34 44 45 46 48]代表的是列下标
print("mislabeled_indices.shape", mislabeled_indices.shape)
plt.rcParams['figure.figsize'] = (40.0, 40.0) # set default size of plots
num_images = len(mislabeled_indices[0])
# print("mislabeled_indices.shape[0]")
for i in range(num_images):
index = mislabeled_indices[1][i]
print('index', index)
plt.subplot(2, num_images, i + 1)
plt.imshow(X[:, index].reshape(64, 64, 3), interpolation='nearest')
plt.axis('off')
plt.title('Prediction:' + classes[int(p[0, index])].decode('utf-8') + '\n Class' + classes[Y[0, index]].decode(
'utf-8'))
plt.show()