矩阵乘法的定义:
若 A = ( a i j ) A=(a_{ij}) A=(aij) 和 B = ( b i j ) B=(b_{ij}) B=(bij) 是 n × n n \times n n×n 的方阵, 则对 i , j = 1 , 2 , . . . , n i, j = 1, 2, ..., n i,j=1,2,...,n,定义乘积 C = A ⋅ B C=A\cdot B C=A⋅B 中的元素 c i j c_{ij} cij 为:
c i j = ∑ k = 1 n a i k ⋅ b k j c_{ij} = \sum_{k=1}^{n} a_{ik} \cdot b_{kj} cij=k=1∑naik⋅bkj
按照矩阵乘法的定义不难写出算法1
// 算法1:定义法
pub fn square_matrix_multiply<T: Default + Clone + AddAssign + Mul<Output = T>>(
a: &Vec<Vec<T>>,
b: &Vec<Vec<T>>,
) -> Vec<Vec<T>>
where
T: Mul<T>,
{
let mut c = vec![vec![T::default(); a.len()]; a.len()];
if a.len() == a[0].len() && b.len() == a[0].len() && a.len() == b.len() {
let n = a.len();
for i in 0..n {
for j in 0..n {
c[i][j] = T::default();
for k in 0..a.len() {
c[i][j] += a[i][k].clone() * b[k][j].clone();
}
}
}
}
return c;
}
由于三重 for
循环的每一重都恰好执行 n
步,因此该算法的时间复杂度为 O ( n 3 ) O(n^3) O(n3)
为此我们尝试改进
为简单起见,假定三个矩阵均为 n × n n \times n n×n 矩阵, n n n 是 2 2 2 的幂,矩阵 A , B , C A, B, C A,B,C 均分解为 4 4 4 个 n / 2 × n / 2 n / 2 \times n / 2 n/2×n/2 的子矩阵,则计算公式等价于
C 11 = A 11 ⋅ B 11 + A 12 ⋅ B 21 C_{11} = A_{11} \cdot B_{11} + A_{12} \cdot B_{21} C11=A11⋅B11+A12⋅B21
C 12 = A 11 ⋅ B 12 + A 12 ⋅ B 22 C_{12} = A_{11} \cdot B_{12} + A_{12} \cdot B_{22} C12=A11⋅B12+A12⋅B22
C 21 = A 21 ⋅ B 11 + A 22 ⋅ B 21 C_{21} = A_{21} \cdot B_{11} + A_{22} \cdot B_{21} C21=A21⋅B11+A22⋅B21
C 22 = A 21 ⋅ B 12 + A 22 ⋅ B 22 C_{22} = A_{21} \cdot B_{12} + A_{22} \cdot B_{22} C22=A21⋅B12+A22⋅B22
我们可以设计一个直接的递归分治算法2
// 算法2:直接分治法
pub fn square_matrix_multiply_recursive<
T: Default + Clone + AddAssign + Mul<Output = T> + Add<Output = T> + Debug,
>(
a: &Vec<Vec<T>>,
b: &Vec<Vec<T>>,
) -> Vec<Vec<T>>
where
T: Mul<T> + Add<T>,
{
let mut n = a.len();
let mut c = vec![vec![T::default(); n]; n];
if n == 1 {
c[0][0] = a[0][0].clone() * b[0][0].clone();
} else {
n /= 2;
let mut a11 = vec![vec![T::default(); n]; n];
let mut a12 = vec![vec![T::default(); n]; n];
let mut a21 = vec![vec![T::default(); n]; n];
let mut a22 = vec![vec![T::default(); n]; n];
let mut b11 = vec![vec![T::default(); n]; n];
let mut b12 = vec![vec![T::default(); n]; n];
let mut b21 = vec![vec![T::default(); n]; n];
let mut b22 = vec![vec![T::default(); n]; n];
partition_four(&a, &mut a11, &mut a12, &mut a21, &mut a22);
partition_four(&b, &mut b11, &mut b12, &mut b21, &mut b22);
let c11 = square_matrix_add(
&square_matrix_multiply_recursive(&a11, &b11),
&square_matrix_multiply_recursive(&a12, &b21),
);
let c12 = square_matrix_add(
&square_matrix_multiply_recursive(&a11, &b12),
&square_matrix_multiply_recursive(&a12, &b22),
);
let c21 = square_matrix_add(
&square_matrix_multiply_recursive(&a21, &b11),
&square_matrix_multiply_recursive(&a22, &b21),
);
let c22 = square_matrix_add(
&square_matrix_multiply_recursive(&a21, &b12),
&square_matrix_multiply_recursive(&a22, &b22),
);
merge_four(&mut c, &c11, &c12, &c21, &c22);
}
c
}
其中 partition_four()
和 merge_four()
是将一个矩阵拆分成4个子矩阵以及将4个子矩阵合成一个矩阵的函数。
// partition_four()
pub fn partition_four<T: Clone>(
a: &Vec<Vec<T>>,
a11: &mut Vec<Vec<T>>,
a12: &mut Vec<Vec<T>>,
a21: &mut Vec<Vec<T>>,
a22: &mut Vec<Vec<T>>,
) {
let n = a.len();
for i in 0..(n / 2) {
for j in 0..(n / 2) {
a11[i][j] = a[i][j].clone();
}
}
for i in 0..(n / 2) {
for j in (n / 2)..n {
a12[i][j - n / 2] = a[i][j].clone();
}
}
for i in (n / 2)..n {
for j in 0..(n / 2) {
a21[i - n / 2][j] = a[i][j].clone();
}
}
for i in (n / 2)..n {
for j in (n / 2)..n {
a22[i - n / 2][j - n / 2] = a[i][j].clone();
}
}
}
// merge_four()
pub fn merge_four<T: Clone>(
a: &mut Vec<Vec<T>>,
a11: &Vec<Vec<T>>,
a12: &Vec<Vec<T>>,
a21: &Vec<Vec<T>>,
a22: &Vec<Vec<T>>,
) {
let n = a.len();
for i in 0..(n / 2) {
for j in 0..(n / 2) {
a[i][j] = a11[i][j].clone();
}
}
for i in 0..(n / 2) {
for j in (n / 2)..n {
a[i][j] = a12[i][j - n / 2].clone();
}
}
for i in (n / 2)..n {
for j in 0..(n / 2) {
a[i][j] = a21[i - n / 2][j].clone();
}
}
for i in (n / 2)..n {
for j in (n / 2)..n {
a[i][j] = a22[i - n / 2][j - n / 2].clone();
}
}
}
其实这部分的分解合并不是必要的,完全可以通过原矩阵的一组行下标和一组列下标来指明一个子矩阵,这里为了方便编码就不做进一步处理了。
对 n = 1 n = 1 n=1 的基本情况,只需进行一次标量乘法,时间复杂度为 O ( 1 ) O(1) O(1)。
当 n > 1 n > 1 n>1 时是递归情况,8次递归调用总时间为 8 T ( n / 2 ) 8T(n/2) 8T(n/2) ,矩阵加法总时间为 Θ ( n 2 ) \Theta(n^2) Θ(n2),因此递归情况的总时间为分解时间、递归调用时间以及矩阵加法时间之和:
T ( n ) = Θ ( 1 ) + 8 T ( n / 2 ) + Θ ( n 2 ) = 8 T ( n / 2 ) + Θ ( n 2 ) T(n) = \Theta(1) + 8T(n / 2) + \Theta(n^2) = 8T(n / 2) + \Theta(n^2) T(n)=Θ(1)+8T(n/2)+Θ(n2)=8T(n/2)+Θ(n2)
如果通过复制元素来实现矩阵分解,正如我们所做的那样,额外开销为 Θ ( n 2 ) \Theta(n^2) Θ(n2),递归式不会发生改变,总运行时间会提高常数倍。
该算法的时间复杂度仍为 O ( n 3 ) O(n^3) O(n3)
Strassen 算法的核心思想是令递归树稍微不那么茂盛一点,只递归进行7次而不是8次。
Strassen 算法的步骤:
按公式将输入矩阵和输出矩阵分解为4个子矩阵,与算法2相同。
创建 10 10 10 个 n / 2 × n / 2 n / 2 \times n /2 n/2×n/2 的矩阵 S 0 , S 1 , . . . , S 9 S_0, S_1, ..., S_9 S0,S1,...,S9,每个矩阵保存步骤1中创建的两个子矩阵的和或差。
用步骤1中创建的子矩阵和步骤2中创建的10个矩阵,递归计算7个矩阵积 P 1 , P 2 , . . . , P 7 P_1, P_2, ..., P_7 P1,P2,...,P7,每个矩都是 n / 2 × n / 2 n / 2 \times n /2 n/2×n/2。
通过 P i P_i Pi 矩阵的不同组合进行加减运算,计算出矩阵 C 的子矩阵。
步骤2:
S 0 = B 12 − B 22 S_0 = B_{12} - B_{22} S0=B12−B22
S 1 = A 11 + A 12 S_1 = A_{11} + A_{12} S1=A11+A12
S 2 = A 21 + A 22 S_2 = A_{21} + A_{22} S2=A21+A22
S 3 = B 21 − B 11 S_3 = B_{21} - B_{11} S3=B21−B11
S 4 = A 11 + A 22 S_4 = A_{11} + A_{22} S4=A11+A22
S 5 = B 11 + B 22 S_5 = B_{11} + B_{22} S5=B11+B22
S 6 = A 12 − A 22 S_6 = A_{12} - A_{22} S6=A12−A22
S 7 = B 21 + B 22 S_7 = B_{21} + B_{22} S7=B21+B22
S 8 = A 11 − A 21 S_8 = A_{11} - A_{21} S8=A11−A21
S 9 = B 11 + B 12 S_9 = B_{11} + B_{12} S9=B11+B12
步骤3:
P 1 = A 11 ⋅ S 0 P_1 = A_{11} \cdot S_0 P1=A11⋅S0
P 2 = S 1 ⋅ B 22 P_2 = S_1 \cdot B_{22} P2=S1⋅B22
P 3 = S 2 ⋅ B 11 P_3 = S_2 \cdot B_{11} P3=S2⋅B11
P 4 = A 22 ⋅ S 3 P_4 = A_{22} \cdot S_3 P4=A22⋅S3
P 5 = S 4 ⋅ S 5 P_5 = S_4 \cdot S_5 P5=S4⋅S5
P 6 = S 6 ⋅ S 7 P_6 = S_6 \cdot S_7 P6=S6⋅S7
P 7 = S 8 ⋅ S 9 P_7 = S_8 \cdot S_9 P7=S8⋅S9
步骤4:
C 11 = P 5 + P 4 − P 2 + P 6 C_{11} = P_5 + P_4 - P_2 + P_6 C11=P5+P4−P2+P6
C 12 = P 2 + P 2 C_{12} = P_2 + P_2 C12=P2+P2
C 21 = P 3 + P 4 C_{21} = P_3 + P_4 C21=P3+P4
C 22 = P 5 + P 1 − P 3 − P 7 C_{22} = P_5 + P_1 - P_3 - P_7 C22=P5+P1−P3−P7
写成代码即为:
// 算法3:Strassen算法
pub fn square_matrix_multiply_strassen<
T: Default
+ Clone
+ AddAssign
+ SubAssign
+ Mul<Output = T>
+ Add<Output = T>
+ Sub<Output = T>
+ Debug,
>(
a: &Vec<Vec<T>>,
b: &Vec<Vec<T>>,
) -> Vec<Vec<T>>
where
T: Mul<T> + Add<T> + Sub<T>,
{
let mut n = a.len();
let mut c = vec![vec![T::default(); n]; n];
if n == 1 {
c[0][0] = a[0][0].clone() * b[0][0].clone();
} else {
n /= 2;
let mut a11 = vec![vec![T::default(); n]; n];
let mut a12 = vec![vec![T::default(); n]; n];
let mut a21 = vec![vec![T::default(); n]; n];
let mut a22 = vec![vec![T::default(); n]; n];
let mut b11 = vec![vec![T::default(); n]; n];
let mut b12 = vec![vec![T::default(); n]; n];
let mut b21 = vec![vec![T::default(); n]; n];
let mut b22 = vec![vec![T::default(); n]; n];
partition_four(&a, &mut a11, &mut a12, &mut a21, &mut a22);
partition_four(&b, &mut b11, &mut b12, &mut b21, &mut b22);
let s0 = square_matrix_sub(&b12, &b22);
let s1 = square_matrix_add(&a11, &a12);
let s2 = square_matrix_add(&a21, &a22);
let s3 = square_matrix_sub(&b21, &b11);
let s4 = square_matrix_add(&a11, &a22);
let s5 = square_matrix_add(&b11, &b22);
let s6 = square_matrix_sub(&a12, &a22);
let s7 = square_matrix_add(&b21, &b22);
let s8 = square_matrix_sub(&a11, &a21);
let s9 = square_matrix_add(&b11, &b12);
let p1 = square_matrix_multiply(&a11, &s0);
let p2 = square_matrix_multiply(&s1, &b22);
let p3 = square_matrix_multiply(&s2, &b11);
let p4 = square_matrix_multiply(&a22, &s3);
let p5 = square_matrix_multiply(&s4, &s5);
let p6 = square_matrix_multiply(&s6, &s7);
let p7 = square_matrix_multiply(&s8, &s9);
let mut c11 = square_matrix_sub(&square_matrix_add(&p5, &p4), &square_matrix_sub(&p2, &p6));
let mut c12 = square_matrix_add(&p1, &p2);
let mut c21 = square_matrix_add(&p3, &p4);
let mut c22 = square_matrix_sub(&square_matrix_add(&p5, &p1), &square_matrix_add(&p3, &p7));
merge_four(&mut c, &mut c11, &mut c12, &mut c21, &mut c22);
}
c
}
时间复杂度 O ( n l g 7 ) O(n^{lg7}) O(nlg7)
Strassen 算法的渐进复杂性低于直接的定义法
编写测试函数测试正确性:
#[cfg(test)]
mod tests{
#[test]
fn test_recursive() {
// 4 * 4 matrix
let n = 4;
// generate matrix a and b
let mut a = vec![vec![0i64; n]; n];
random_square_matrix_range(&mut a, n, (1, 10));
let mut b = vec![vec![0i64; n]; n];
random_square_matrix_range(&mut b, n, (1, 10));
// calcu c by algo2
let c = square_matrix_multiply_recursive(&a, &b);
assert_eq!(c, square_matrix_multiply(&a, &b));
// calcu c by algo3
let c = square_matrix_multiply_strassen(&a, &b);
assert_eq!(c, square_matrix_multiply(&a, &b));
println!("{:?}", c);
}
}
我们可以看到Strassen算法确实生成了正确的矩阵乘积。