一种用向量化的方式实现 L2 Distance 的数学技巧

背景描述

在 CS231n 的作业题中有一个需要实现 L2 distance 的题目,题目要求不能用循环语句。
已知待分类的数据集合矩阵 X ,训练用数据集合矩阵 X_train。
待分类数据集中任何一个数据点到训练用数据集合中的任一点之间的距离用矩阵 dists[i,j] 来表示。

理解这个问题有这样要点:

  • 在计算的过程中借助广播的性质从向量运算一口气生成最后的矩阵,而不是手动的生成矩阵。
  • 利用这样一个数学等价变化 ( x − y ) 2 = x 2 − 2 x y + y 2 (x - y)^2 = x^2 - 2xy + y^2 (xy)2=x22xy+y2
首先这个问题最好还是要形象一下。

比如说我设待分类的矩阵(也就是 test 矩阵)为 X ,它有两个样本, 3 个属性。也就是
X = ( x 1 ( 1 ) x 2 ( 1 ) x 3 ( 1 ) x 1 ( 2 ) x 2 ( 2 ) x 3 ( 2 ) ) X = \begin{pmatrix} x^{(1)}_1 & x^{(1)}_2 & x^{(1)}_3 \\ x^{(2)}_1 & x^{(2)}_2 & x^{(2)}_3 \end{pmatrix} X=(x1(1)x1(2)x2(1)x2(2)x3(1)x3(2))
我设测试机矩阵(也就是 X_train)为 Y ,它有三个样本,3个属性。也就是
Y = ( y 1 ( 1 ) y 2 ( 1 ) y 3 ( 1 ) y 1 ( 2 ) y 2 ( 2 ) y 3 ( 2 ) y 1 ( 3 ) y 2 ( 3 ) y 3 ( 3 ) ) Y = \begin{pmatrix} y^{(1)}_1 & y^{(1)}_2 & y^{(1)}_3 \\ y^{(2)}_1 & y^{(2)}_2 & y^{(2)}_3 \\ y^{(3)}_1 & y^{(3)}_2 & y^{(3)}_3 \end{pmatrix} Y=y1(1)y1(2)y1(3)y2(1)y2(2)y2(3)y3(1)y3(2)y3(3)
我们最后要求的这个矩阵dist[i,j],第 i,j 个位置对应于待分类矩阵中的第 i 个元素到测试矩阵第 j 个元素的L2距离。这是一个 2 X 3 的矩阵(num_test X num_train)。

我们进一步把 dist 矩阵形象化一下,并拿出其中的一个元素,试着将其展开,推导一下看看能不能看出点什么东西出来。

d i s t = ( L ( x ( 1 ) , y ( 1 ) ) L ( x ( 1 ) , y ( 2 ) ) L ( x ( 1 ) , y ( 3 ) ) L ( x ( 2 ) , y ( 1 ) ) L ( x ( 2 ) , y ( 2 ) ) L ( x ( 2 ) , y ( 3 ) ) ) dist = \begin{pmatrix} L(x^{(1)},y^{(1)}) & L(x^{(1)},y^{(2)}) & L(x^{(1)},y^{(3)}) \\ L(x^{(2)},y^{(1)}) & L(x^{(2)},y^{(2)}) & L(x^{(2)},y^{(3)}) \end{pmatrix} dist=(L(x(1),y(1))L(x(2),y(1))L(x(1),y(2))L(x(2),y(2))L(x(1),y(3))L(x(2),y(3)))
拿出 L ( x ( 1 ) , y ( 1 ) ) L(x^{(1)},y^{(1)}) L(x(1),y(1)) 这一项来,这表示待分类数据集中的 x ( 1 ) x^{(1)} x(1)这个样本到测试数据集中的 y ( 1 ) y^{(1)} y(1)这个点之间的距离。 x ( 1 ) x^{(1)} x(1) y ( 1 ) y^{(1)} y(1) 这两点都是向量,每个向量都有 3 个分量。我们把 L ( x ( 1 ) , y ( 1 ) ) L(x^{(1)},y^{(1)}) L(x(1),y(1)) 这项也展开来。
L ( x ( 1 ) , y ( 1 ) ) = ( x 1 ( 1 ) − y 1 ( 1 ) ) 2 + ( x 2 ( 1 ) − y 2 ( 1 ) ) 2 + ( x 3 ( 1 ) − y 3 ( 1 ) ) 2 = [ ( x 1 ( 1 ) ) 2 + ( x 2 ( 1 ) ) 2 + ( x 3 ( 1 ) ) 2 ] + [ ( y 1 ( 1 ) ) 2 + ( y 2 ( 1 ) ) 2 + ( y 3 ( 1 ) ) 2 ] − 2 x 1 ( 1 ) y 1 ( 1 ) − 2 x 2 ( 1 ) y 2 ( 1 ) − 2 x 3 ( 1 ) y 3 ( 1 ) L(x^{(1)},y^{(1)}) = (x^{(1)}_1 - y^{(1)}_1)^2 + (x^{(1)}_2 - y^{(1)}_2)^2 + (x^{(1)}_3 - y^{(1)}_3)^2 = [(x^{(1)}_1)^2 + (x^{(1)}_2)^2 + (x^{(1)}_3)^2] +[ (y^{(1)}_1)^2 + (y^{(1)}_2)^2 + (y^{(1)}_3)^2 ] - 2 x^{(1)}_1 y^{(1)}_1 - 2 x^{(1)}_2 y^{(1)}_2 - 2 x^{(1)}_3 y^{(1)}_3 L(x(1),y(1))=(x1(1)y1(1))2+(x2(1)y2(1))2+(x3(1)y3(1))2=[(x1(1))2+(x2(1))2+(x3(1))2]+[(y1(1))2+(y2(1))2+(y3(1))2]2x1(1)y1(1)2x2(1)y2(1)2x3(1)y3(1)
我们尝试着把dist中的其它项也都展开,为了书写方便把 ( x 1 ( 1 ) ) 2 + ( x 2 ( 1 ) ) 2 + ( x 3 ( 1 ) ) 2 (x^{(1)}_1)^2 + (x^{(1)}_2)^2 + (x^{(1)}_3)^2 (x1(1))2+(x2(1))2+(x3(1))2 记为 s u m ( ( x ( 1 ) ) 2 ) sum( (x^{(1)})^2) sum((x(1))2) ( y 1 ( 1 ) ) 2 + ( y 2 ( 1 ) ) 2 + ( y 3 ( 1 ) ) 2 (y^{(1)}_1)^2 + (y^{(1)}_2)^2 + (y^{(1)}_3)^2 (y1(1))2+(y2(1))2+(y3(1))2 记为 s u m ( ( y ( 1 ) ) 2 ) sum( (y^{(1)}) ^2) sum((y(1))2) x 1 ( 1 ) y 1 ( 1 ) + x 2 ( 1 ) y 2 ( 1 ) + x 3 ( 1 ) y 3 ( 1 ) x^{(1)}_1 y^{(1)}_1 + x^{(1)}_2 y^{(1)}_2 + x^{(1)}_3 y^{(1)}_3 x1(1)y1(1)+x2(1)y2(1)+x3(1)y3(1) 记为 x ( 1 ) ⋅ y ( 1 ) x^{(1)} \cdot y^{(1)} x(1)y(1) 依次类推。于是
d i s t = ( s u m ( ( x ( 1 ) ) 2 ) + s u m ( ( y ( 1 ) ) 2 ) − 2 x ( 1 ) ⋅ y ( 1 ) s u m ( ( x ( 1 ) ) 2 ) + s u m ( ( y ( 2 ) ) 2 ) − 2 x ( 1 ) ⋅ y ( 2 ) s u m ( ( x ( 1 ) ) 2 ) + s u m ( ( y ( 3 ) ) 2 ) − 2 x ( 1 ) ⋅ y ( 3 ) s u m ( ( x ( 2 ) ) 2 ) + s u m ( ( y ( 1 ) ) 2 ) − 2 x ( 2 ) ⋅ y ( 1 ) s u m ( ( x ( 2 ) ) 2 ) + s u m ( ( y ( 2 ) ) 2 ) − 2 x ( 2 ) ⋅ y ( 2 ) s u m ( ( x ( 2 ) ) 2 ) + s u m ( ( y ( 3 ) ) 2 ) − 2 x ( 2 ) ⋅ y ( 3 ) ) dist = \begin{pmatrix} sum( (x^{(1)})^2) + sum( (y^{(1)}) ^2) - 2 x^{(1)} \cdot y^{(1)} & sum( (x^{(1)})^2) + sum( (y^{(2)}) ^2) - 2 x^{(1)} \cdot y^{(2)} & sum( (x^{(1)})^2) + sum( (y^{(3)}) ^2) - 2 x^{(1)} \cdot y^{(3)} \\ sum( (x^{(2)})^2) + sum( (y^{(1)}) ^2) - 2 x^{(2)} \cdot y^{(1)} & sum( (x^{(2)})^2) + sum( (y^{(2)}) ^2) - 2 x^{(2)} \cdot y^{(2)} & sum( (x^{(2)})^2) + sum( (y^{(3)}) ^2) - 2 x^{(2)} \cdot y^{(3)} \end{pmatrix} dist=(sum((x(1))2)+sum((y(1))2)2x(1)y(1)sum((x(2))2)+sum((y(1))2)2x(2)y(1)sum((x(1))2)+sum((y(2))2)2x(1)y(2)sum((x(2))2)+sum((y(2))2)2x(2)y(2)sum((x(1))2)+sum((y(3))2)2x(1)y(3)sum((x(2))2)+sum((y(3))2)2x(2)y(3))
进一步,dist 可以拆成 3 个矩阵的和
d i s t = ( s u m ( ( x ( 1 ) ) 2 ) s u m ( ( x ( 1 ) ) 2 ) s u m ( ( x ( 1 ) ) 2 ) s u m ( ( x ( 2 ) ) 2 ) s u m ( ( x ( 2 ) ) 2 ) s u m ( ( x ( 2 ) ) 2 ) ) + ( s u m ( ( y ( 1 ) ) 2 ) s u m ( ( y ( 2 ) ) 2 ) s u m ( ( y ( 3 ) ) 2 ) s u m ( ( y ( 1 ) ) 2 ) s u m ( ( y ( 2 ) ) 2 ) s u m ( ( y ( 3 ) ) 2 ) ) − 2 ( x ( 1 ) ⋅ y ( 1 ) x ( 1 ) ⋅ y ( 2 ) x ( 1 ) ⋅ y ( 3 ) x ( 2 ) ⋅ y ( 1 ) x ( 2 ) ⋅ y ( 2 ) x ( 2 ) ⋅ y ( 3 ) ) dist = \begin{pmatrix} sum( (x^{(1)})^2) & sum( (x^{(1)})^2) & sum( (x^{(1)})^2) \\ sum( (x^{(2)})^2) & sum( (x^{(2)})^2) & sum( (x^{(2)})^2) \end{pmatrix} + \begin{pmatrix} sum( (y^{(1)}) ^2) & sum( (y^{(2)}) ^2) & sum( (y^{(3)}) ^2) \\ sum( (y^{(1)}) ^2) & sum( (y^{(2)}) ^2) & sum( (y^{(3)}) ^2) \end{pmatrix} - 2 \begin{pmatrix} x^{(1)} \cdot y^{(1)} & x^{(1)} \cdot y^{(2)} & x^{(1)} \cdot y^{(3)} \\ x^{(2)} \cdot y^{(1)} & x^{(2)} \cdot y^{(2)} & x^{(2)} \cdot y^{(3)} \end{pmatrix} dist=(sum((x(1))2)sum((x(2))2)sum((x(1))2)sum((x(2))2)sum((x(1))2)sum((x(2))2))+(sum((y(1))2)sum((y(1))2)sum((y(2))2)sum((y(2))2)sum((y(3))2)sum((y(3))2))2(x(1)y(1)x(2)y(1)x(1)y(2)x(2)y(2)x(1)y(3)x(2)y(3))

接下来,我们看看这三个矩阵上能不能看出一些什么规律,进一步接近答案

注意看
( s u m ( ( x ( 1 ) ) 2 ) s u m ( ( x ( 1 ) ) 2 ) s u m ( ( x ( 1 ) ) 2 ) s u m ( ( x ( 2 ) ) 2 ) s u m ( ( x ( 2 ) ) 2 ) s u m ( ( x ( 2 ) ) 2 ) ) \begin{pmatrix} sum( (x^{(1)})^2) & sum( (x^{(1)})^2) & sum( (x^{(1)})^2) \\ sum( (x^{(2)})^2) & sum( (x^{(2)})^2) & sum( (x^{(2)})^2) \end{pmatrix} (sum((x(1))2)sum((x(2))2)sum((x(1))2)sum((x(2))2)sum((x(1))2)sum((x(2))2))这个矩阵可以看成 ( s u m ( ( x ( 1 ) ) 2 ) s u m ( ( x ( 2 ) ) 2 ) ) \begin{pmatrix} sum( (x^{(1)})^2) \\ sum( (x^{(2)})^2) \end{pmatrix} (sum((x(1))2)sum((x(2))2)) 这个列向量在列方向上的复制吧,这可以利用广播得到。
( s u m ( ( y ( 1 ) ) 2 ) s u m ( ( y ( 2 ) ) 2 ) s u m ( ( y ( 3 ) ) 2 ) s u m ( ( y ( 1 ) ) 2 ) s u m ( ( y ( 2 ) ) 2 ) s u m ( ( y ( 3 ) ) 2 ) ) \begin{pmatrix} sum( (y^{(1)}) ^2) & sum( (y^{(2)}) ^2) & sum( (y^{(3)}) ^2) \\ sum( (y^{(1)}) ^2) & sum( (y^{(2)}) ^2) & sum( (y^{(3)}) ^2) \end{pmatrix} (sum((y(1))2)sum((y(1))2)sum((y(2))2)sum((y(2))2)sum((y(3))2)sum((y(3))2))这个矩阵可以看成 ( s u m ( ( y ( 1 ) ) 2 ) s u m ( ( y ( 2 ) ) 2 ) s u m ( ( y ( 3 ) ) 2 ) ) \begin{pmatrix} sum( (y^{(1)}) ^2) & sum( (y^{(2)}) ^2) & sum( (y^{(3)}) ^2) \end{pmatrix} (sum((y(1))2)sum((y(2))2)sum((y(3))2)) 这个行向量在行方向上的复制吧。
( x ( 1 ) ⋅ y ( 1 ) x ( 1 ) ⋅ y ( 2 ) x ( 1 ) ⋅ y ( 3 ) x ( 2 ) ⋅ y ( 1 ) x ( 2 ) ⋅ y ( 2 ) x ( 2 ) ⋅ y ( 3 ) ) \begin{pmatrix} x^{(1)} \cdot y^{(1)} & x^{(1)} \cdot y^{(2)} & x^{(1)} \cdot y^{(3)} \\ x^{(2)} \cdot y^{(1)} & x^{(2)} \cdot y^{(2)} & x^{(2)} \cdot y^{(3)} \end{pmatrix} (x(1)y(1)x(2)y(1)x(1)y(2)x(2)y(2)x(1)y(3)x(2)y(3)) 可以看成 矩阵 X 和矩阵 Y 的转置的点积 X ⋅ Y T X \cdot Y^T XYT

转成代码 X那一项

( s u m ( ( x ( 1 ) ) 2 ) s u m ( ( x ( 2 ) ) 2 ) ) \begin{pmatrix} sum( (x^{(1)})^2) \\ sum( (x^{(2)})^2) \end{pmatrix} (sum((x(1))2)sum((x(2))2)) 这个向量可以由

np.sum(X **2,axis = 1) 

得到。
注意,虽然 axis = 1 是加到同一列上去,但是结果还是一个行向量,numpy 默认的向量是行向量
比如

>>> b
array([[-4, -3, -2],
       [-1, 0, 1],
       [ 2, 3, 4]])
>>> c = np.sum(b,axis=1)
>>> c
array([-9, 0, 9])

c 是一个行向量。你还要手动的把它转为列向量

>>> d = c[:,np.newaxis]
>>> d
array([[-9],
       [ 0],
       [ 9]])

所以,对于待分类的矩阵来说,这一项为

 np.sum(X\**2,axis = 1)[:,np.newaxis] 

或者

np.sum(X**2,axis = 1).reshape(num_test,1)
转成代码 Y那一项

( s u m ( ( y ( 1 ) ) 2 ) s u m ( ( y ( 2 ) ) 2 ) s u m ( ( y ( 3 ) ) 2 ) s u m ( ( y ( 1 ) ) 2 ) s u m ( ( y ( 2 ) ) 2 ) s u m ( ( y ( 3 ) ) 2 ) ) \begin{pmatrix} sum( (y^{(1)}) ^2) & sum( (y^{(2)}) ^2) & sum( (y^{(3)}) ^2) \\ sum( (y^{(1)}) ^2) & sum( (y^{(2)}) ^2) & sum( (y^{(3)}) ^2) \end{pmatrix} (sum((y(1))2)sum((y(1))2)sum((y(2))2)sum((y(2))2)sum((y(3))2)sum((y(3))2))这一项依然可以用sum axis = 1 的方式,而且就是行向量,不用变了。注意,这里Y是 self.X_train

np.sum(self.X_train ** 2 , axis =1)
转成代码 XY那一项

( x ( 1 ) ⋅ y ( 1 ) x ( 1 ) ⋅ y ( 2 ) x ( 1 ) ⋅ y ( 3 ) x ( 2 ) ⋅ y ( 1 ) x ( 2 ) ⋅ y ( 2 ) x ( 2 ) ⋅ y ( 3 ) ) \begin{pmatrix} x^{(1)} \cdot y^{(1)} & x^{(1)} \cdot y^{(2)} & x^{(1)} \cdot y^{(3)} \\ x^{(2)} \cdot y^{(1)} & x^{(2)} \cdot y^{(2)} & x^{(2)} \cdot y^{(3)} \end{pmatrix} (x(1)y(1)x(2)y(1)x(1)y(2)x(2)y(2)x(1)y(3)x(2)y(3)) 这一项用数学表达式表达就是 X ⋅ Y T X \cdot Y^T XYT,转成代码为

np.dot(X,self.X_train.T)
最后的代码为

别忘了,最后开方

dists = np.sqrt( np.sum(X ** 2,axis = 1)[:,np.newaxis]  +  np.sum(self.X_train ** 2 , axis =1) - 2 * np.dot(X,self.X_train.T) )
这部分的参考

这部分主要是看Burton这哥们儿的答案。
它给了一个特别好的连接 https://medium.com/dataholiks-distillery/l2-distance-matrix-vectorization-trick-26aa3247ac6c
但是这个连接老实说 写得不太好。跳跃太大。因为 ( x − y ) 2 = x 2 − 2 x y + y 2 (x - y)^2 = x^2 - 2xy + y^2 (xy)2=x22xy+y2 这个公式严格意义上是元素层面的加减,不能直接推广到矩阵形式使用。必须要按照我的这种方式推导一遍才行。虽说最后向量形式和这个公式形式上一样。但是一开始时我是不能保证在元素层面上的公式就能直接在向量形式上应用。

你可能感兴趣的:(CS231n)