如何快速看出矩阵乘法的时间复杂度

以 Attention Score 的计算为例
A t t n ( K , Q , V ) = S o f t m a x ( Q ⋅ K T / d ) ⋅ V Attn(K,Q,V) = Softmax(Q\cdot K^T/\sqrt{d})\cdot V Attn(K,Q,V)=Softmax(QKT/d )V
咱姑且把 Softmax 和 Softmax里面的除以 d \sqrt{d} d 去掉(其运算时间复杂度小),表示为
A t t n ( K , Q , V ) = Q ⋅ K T ⋅ V Attn(K,Q,V) = Q\cdot K^T\cdot V Attn(K,Q,V)=QKTV
其中, Q , K , V ∈ R N × d Q,K,V \in \mathbb{R}^{N\times d} Q,K,VRN×d N N N 是token的数量, d d d 是每个token的维度,一般认为 N N N> d d d

Q ⋅ K T Q\cdot K^T QKT 从矩阵乘法上看维度变换是 N × d × d × N N\times d \times d \times N N×d×d×N,得到的矩阵维度是 N × N N\times N N×N,即得到的矩阵有 N 2 N^2 N2 个元素,每个元素需要经过d个元素相乘再相加得到(加权求和),所以 Q ⋅ K T Q\cdot K^T QKT 计算的时间复杂度为 O ( N 2 d ) O(N^2d) O(N2d)

  • 总结一个快速得出结论的方法

如果你不想鸟我上面写的,你只需要按照这个规则来看

比如两个矩阵 M ⋅ N , M ∈ R m × n , M ∈ R n × k M\cdot N, M\in\mathbb{R}^{m\times n}, M\in\mathbb{R}^{n\times k} MN,MRm×n,MRn×k

按照维度表示为 m × n × n × k m\times n \times n \times k m×n×n×k只需要把中间的两个 n n n 删掉一个即可表示时间复杂度,为 O ( m × n × k ) O(m\times n\times k) O(m×n×k)

你可能感兴趣的:(深度学习,矩阵,算法,线性代数)