【详解】CS231n assignment1KNN中不使用循环计算距离:从原理到程序

本文主要讲述不使用循环结构来计算两个矩阵的欧氏距离, 设训练集矩阵为train,size为num_train * num_features,设验证集矩阵为validate,size为num_test,num_features。
因此我们计算每一个验证集样本到训练集样本的距离,就是将训练集矩阵train的某一行拿出来与验证集矩阵validate的某一行计算欧式距离。
这在两层循环中就是这么做的,相比大家都明白。但是不使用循环可以有一点难受。本文就是从算法原理上到程序本身上去解释怎么做的。

首先,我不得不承认,我没有认识到不使用循环和使用循环中时间的对比情况,先看一个我实际运行的时间结果:可以看见不使用循环要快非常多,只与为什么两层循环比一层循环还快我不是那么明白。
【详解】CS231n assignment1KNN中不使用循环计算距离:从原理到程序_第1张图片
本来我还没觉得有不使用循环的必要,做完这个实验,我才开始认真考虑到时间成本,并且花了一些时间在不使用循环来求距离上。

以下正文

第一步:完全平方公式

我们明白,所谓的欧氏距离就是先求差,再求平方和,以及求二次方根;我们假设求两个向量的欧式距离:
在这里插入图片描述
不是一般性,我假设x=x1, x2, x3, y=y1, y2, y3,因此,欧氏距离也就是
在这里插入图片描述
我们可以做一个变换,根据
在这里插入图片描述
我们可以知道,当x=x1, x2, x3, y=y1, y2, y3时,有
在这里插入图片描述

第二步:维度验证

前面,我们假设训练集矩阵为train,size为num_train * num_features,设验证集矩阵为validate,size为num_test,num_features。
那么,假如我们计算好了之后,距离矩阵应该是怎样的维度呢?
我们假设让它这样排列:第i行第j列的距离表示验证集的第i行向量和训练集的第j行向量的距离。
因此,这个距离矩阵dist应该是num_validate, num_train。
在计算过程中,我们始终要注意维度是否合理

第三步:计算方法

根据上面解释,我们将矩阵距离换成多项来运算,也就是
在这里插入图片描述
而我们知道,电脑是擅长矩阵运算的。
我们将训练矩阵train看做a, 将验证集validate看做b,我们就是要求(a-b)^2
但是这里都是矩阵,数据是二维的。

比较清楚的,如果我们想要将两个数据的两行进行处理,一个方法就是将其中一个矩阵转置,这样就变成了一行与一列进行运算,这样非常适合矩阵运算。

第四步:程序实现

特别注意这里的数据维度

def compute_distances_no_loops(self, X):
    """
    Compute the distance between each test point in X and each training point
    in self.X_train using no explicit loops.

    Input / Output: Same as compute_distances_two_loops
    """
	num_test = X.shape[0]
	num_train = self.X_train.shape[0]
	dists = np.zeros((num_test, num_train)) 
	
	ab = np.dot(X, self.X_train.T)  # num_test * num_train
	a2 = np.sum(np.square(X), axis=1).reshape(-1, 1)   # num_test * 1
	b2 = np.sum(np.square(self.X_train.T), axis=0).reshape(1, -1)  # 1 * num_train
	dists = -2 * ab + a2 + b2 # 不同维度计算会自动 broadcast
	dists = np.sqrt(dists)

	return dists

程序中,我就是先将X矩阵(也就是验证集)进行转置,使其和目标矩阵(距离矩阵)的行数相同。
转置之后,计算ab就变成矩阵乘法;
对于最后的距离矩阵来说,每一行都是验证集的对应行与训练集的距离,所以a^2是相同的;
与此相同,距离矩阵的每一列都是训练集的对应行与验证集的距离,所以都加上b^2
最后的加法是可以进行broadcast。会自动的将a2 和 b2 分别加到每一行和每一列

后记

感觉还是没能说清楚。
在此提醒大家不要着急,尤其是新学者,可能多遇到几次就好了。
祝各位最后都能顺利理解。这里只能是抛砖引玉了。

你可能感兴趣的:(opencv,数学理论,numpy,python)