计算复杂度

提示:计算复杂度的简单理解(第一次写博客)

计算复杂度

  • 计算复杂度

计算复杂度

我们以Vicinity Vision Transformer论文中的图为例。
计算复杂度_第1张图片图注:标准自注意力(左)和线性化自注意力(右)的图示。 N N N表示输入图像的 p a t c h patch patch数, d d d是特征维度。使 N ≫ d N\gg d Nd,线性化自注意力的计算复杂度相对于输入长度线性增长,而标准自注意力的计算复杂度是二次的。

从输入到输出可以这样计算:
( N × d ) × ( d × N ) = N × N × ( d × N ) × ( N × N ) = d × N (N\times d)\times (d\times N)=N\times N\times (d\times N)\times (N\times N)=d\times N (N×d)×(d×N)=N×N×(d×N)×(N×N)=d×N
( d × N ) × ( N × d ) = d × d × ( d × d ) × ( d × N ) = d × N (d\times N)\times (N\times d)=d\times d\times (d\times d)\times (d\times N)=d\times N (d×N)×(N×d)=d×d×(d×d)×(d×N)=d×N

关于计算复杂度:其实可以认为是乘法次数。我们给出最直观的解释。

假设有两个矩阵做乘法,如下:
[ 1 2 3 4 5 6 ] × [ 1 2 3 4 5 6 ] = [ 1 2 3 4 5 6 7 8 9 ] \left[\begin{matrix}1&2\\3&4\\5&6\\\end{matrix}\right]\times\left[\begin{matrix}1&2&3\\4&5&6\\\end{matrix}\right]=\left[\begin{matrix}1&2&3\\4&5&6\\7&8&9\\\end{matrix}\right] 135246×[142536]=147258369,其中行数为 N N N,列数为 d d d

( 3 × 2 ) × ( 2 × 3 ) = ( 3 × 3 ) × ( N × d ) × ( d × N ) = ( N × N ) (3\times 2)\times (2\times 3)=(3\times 3)\times (N\times d)\times (d\times N)=(N\times N) (3×2)×(2×3)=(3×3)×(N×d)×(d×N)=(N×N)

3 × 3 3\times 3 3×3矩阵第一个元素涉及的乘法次数: 1 × 1 + 2 × 4 = 9 1\times 1+2\times 4=9 1×1+2×4=9 共2次乘法;其它元素是一样的。最后可以得到 2 × 9 = 2 × 3 × 3 = d × N × N = N 2 d 2\times 9=2\times 3\times 3=d\times N\times N=N^{2}d 2×9=2×3×3=d×N×N=N2d.

假设又有两个矩阵做乘法,如下:
[ 1 2 3 4 5 6 ] × [ 1 2 3 4 5 6 ] = [ 1 2 3 4 ] \left[\begin{matrix}1&2&3\\4&5&6\\\end{matrix}\right]\times\left[\begin{matrix}1&2\\3&4\\5&6\\\end{matrix}\right]=\left[\begin{matrix}1&2\\3&4\\\end{matrix}\right] [142536]×135246=[1324],其中行数为 d d d,列数为 N N N

( 2 × 3 ) × ( 3 × 2 ) = ( 2 × 2 ) × ( d × N ) × ( N × d ) = ( d × d ) (2\times 3)\times (3\times 2)=(2\times 2)\times (d\times N)\times (N\times d)=(d\times d) (2×3)×(3×2)=(2×2)×(d×N)×(N×d)=(d×d)

2 × 2 2\times 2 2×2矩阵第一个元素涉及的乘法次数: 1 × 1 + 2 × 3 + 2 × 5 = 17 1\times 1+2\times 3+2\times 5=17 1×1+2×3+2×5=17 共3次乘法;其它元素是一样的。最后可以得到 3 × 4 = 3 × 2 × 2 = N × d × d = N d 2 3\times 4=3\times 2\times 2=N\times d\times d=Nd^2 3×4=3×2×2=N×d×d=Nd2 .

为什么会有这种情况呢?以第二个例子为例,可以观察到,所得结果的一个元素的乘法数量和消失的维度大小有关,也就是列数 N N N,或者说,列数 N N N就是所得结果一个元素的乘法次数。那么多少个元素呢?元素个数就要看你是如何进行的乘法操作,其实就是矩阵大小。比如 ( 2 × 3 ) × ( 3 × 2 ) = ( 2 × 2 ) × ( d × N ) × ( N × d ) = ( d × d ) (2\times 3)\times (3\times 2)=(2\times 2)\times (d\times N)\times (N\times d)=(d\times d) (2×3)×(3×2)=(2×2)×(d×N)×(N×d)=(d×d),那么就是 d 2 d^2 d2个元素,最后乘法次数就是 N d 2 Nd^2 Nd2

乘法次数=消失的维度 × 所得矩阵大小

那么计算复杂度呢?我们不要去管 O ( ∙ ) O(\bullet) O()具体代表什么,这不重要。
以第一个图为例,乘法次数1: ( N × d ) × ( d × N ) = N 2 d (N\times d)\times (d\times N)=N^{2}d (N×d)×(d×N)=N2d;乘法次数 2 2 2 ( N × d ) × ( d × N ) = N 2 d (N\times d)\times (d\times N)=N^{2}d (N×d)×(d×N)=N2d O ( N 2 d + N 2 d ) = O ( N 2 ) O(N^{2}d+N^{2}d)=O(N^2) O(N2d+N2d)=O(N2)。因为 N ≫ d N\gg d Nd,所以 d d d(还有常数 2 2 2)被省略了,即 O ( N 2 ) O(N^2) O(N2)
以第二个图为例,乘法次数1: ( d × N ) × ( N × d ) = N d 2 (d\times N)\times (N\times d)=Nd^2 (d×N)×(N×d)=Nd2;乘法次数2: ( d × d ) × ( d × N ) = N d 2 (d\times d)\times (d\times N)=Nd^2 (d×d)×(d×N)=Nd2 O ( N d 2 + N d 2 ) = O ( N ) O(Nd^2+Nd^2)=O(N) O(Nd2+Nd2)=O(N)。因为 N ≫ d N\gg d Nd,所以 d d d(还有常数2)被省略了,即 O ( N ) O(N) O(N)

事实告诉我们,我们两个的结果一样,但是我们可以通过控制中间过程减少计算复杂度。

你可能感兴趣的:(深度学习,机器学习,人工智能)