方阵乘法,矩阵乘法,Strassen 算法——算法作业 2.3,EOJ 1050

方阵乘法,矩阵乘法,Strassen 算法——算法作业 2.3,EOJ 1050

方阵相乘
Time Limit:1000MS Memory Limit:30000KB

Description

实现两个n*n 方阵相乘的Strassen 算法,这里假设 n 为 2 的方幂。

Input

第一行为一个正整数N,表示有几组测试数据。
每组测试数据的第一行为一个正整数n(1<=n<=100),n为2的方幂,表示方阵n*n
接下去的n行表示第一个方阵,每行有n个整数,用空格分开。
再接下去的n行表示第二个方阵,每行有n个整数,用空格分开。

Output

对于每组测试出据,输出n行,每行有n个整数,用空格分开,不能有多余的空格。

Sample Input

1
2
1 2
3 4
5 6
7 8

Sample Output

19 22
43 50



朴素的矩阵乘法
 1 #include  < iostream >
 2 #include  < cstdio >
 3  
 4 using   namespace  std;
 5  
 6 const   int  L  =   103 ;
 7  
 8 int  a[ L ][ L ], b[ L ][ L ], c[ L ][ L ];
 9  
10 int  main()  {
11        int td, n, i, j, k, tmp;
12        scanf( "%d"&td );
13        while ( td-- ) {
14                scanf( "%d"&n );
15                for ( i = 0; i < n; ++i )
16                        for ( j = 0; j < n; ++j )
17                                scanf( "%d"&a[ i ][ j ] );
18                for ( i = 0; i < n; ++i )
19                        for ( j = 0; j < n; ++j )
20                                scanf( "%d"&b[ i ][ j ] );
21                for ( i = 0; i < n; ++i )
22                        for ( j = 0; j < n; ++j ) {
23                                tmp = 0;
24                                for ( k = 0; k < n; ++k )
25                                        tmp += a[ i ][ k ] * b[ k ][ j ];
26                                c[ i ][ j ] = tmp;
27                        }

28                for ( i = 0; i < n; ++i ) {
29                        printf( "%d", c[ i ][ 0 ] );
30                        for ( j = 1; j < n; ++j )
31                                printf( " %d", c[ i ][ j ] );
32                        printf( "\n" );
33                }

34        }

35        return 0;
36}

37


Strassen 算法

  1 #include  < iostream >
  2 #include  < cstdio >
  3  
  4 using   namespace  std;
  5  
  6 #define   L     102
  7 #define   LIM   400
  8  
  9 typedef  int  Mat[ L ][ L ];
 10  
 11 Mat buf[ LIM ];
 12 int  top;
 13  
 14 void  input(  int  a[][L],  int  n )  {
 15        int i, j;
 16        for ( i = 1; i <= n; ++i ) {
 17                for ( j = 1; j <= n; ++j ) {
 18                        scanf( "%d"&a[ i ][ j ] );
 19                }

 20        }

 21}

 22  
 23 void  output(  int  c[][L],  int  n )  {
 24        int i, j;
 25        for ( i = 1; i <= n; ++i ) {
 26                for ( j = 1; j < n; ++j ) {
 27                        printf( "%d ", c[ i ][ j ] );
 28                }

 29                printf( "%d\n", c[ i ][ j ] );
 30        }

 31}

 32  
 33 void   get int  a[][L],  int  a11[][L],  int  a12[][L],  int  a21[][L],  int  a22[][L],  int  n )  {
 34        int i, j;
 35        for ( i = 1; i <= n; ++i ) {
 36                for ( j = 1; j <= n; ++j ) {
 37                        a11[ i ][ j ] = a[ i     ][ j     ];
 38                        a12[ i ][ j ] = a[ i     ][ j + n ];
 39                        a21[ i ][ j ] = a[ i + n ][ j     ];
 40                        a22[ i ][ j ] = a[ i + n ][ j + n ];
 41                }

 42        }

 43}

 44  
 45 void  put(  int  a[][L],  int  a11[][L],  int  a12[][L],  int  a21[][L],  int  a22[][L],  int  n )  {
 46        int i, j;
 47        for ( i = 1; i <= n; ++i ) {
 48                for ( j = 1; j <= n; ++j ) {
 49                        a[ i     ][ j     ] = a11[ i ][ j ];
 50                        a[ i     ][ j + n ] = a12[ i ][ j ];
 51                        a[ i + n ][ j     ] = a21[ i ][ j ];
 52                        a[ i + n ][ j + n ] = a22[ i ][ j ];
 53                }

 54        }

 55}

 56  
 57 void  add(  int  c[][L],  int  a[][L],  int  b[][L],  int  n )  {
 58        int i, j;
 59        for ( i = 1; i <= n; ++i ) {
 60                for ( j = 1; j <= n; ++j ) {
 61                        c[ i ][ j ] = a[ i ][ j ] + b[ i ][ j ];
 62                }

 63        }

 64}

 65  
 66 void  sub(  int  c[][L],  int  a[][L],  int  b[][L],  int  n )  {
 67        int i, j;
 68        for ( i = 1; i <= n; ++i ) {
 69                for ( j = 1; j <= n; ++j ) {
 70                        c[ i ][ j ] = a[ i ][ j ] - b[ i ][ j ];
 71                }

 72        }

 73}

 74  
 75 void  mul(  int  c[][L],  int  a[][L],  int  b[][L],  int  n )  {
 76#define  ADD(m)  Mat &m = buf[ top++ ]
 77#define  ADDS(a)  ADD(a##11); ADD(a##12); ADD(a##21); ADD(a##22)
 78#define  ENTER  ADDS(a); ADDS(b); ADDS(c); ADD(d1); ADD(d2); ADD(d3); ADD(d4); ADD(d5); ADD(d6); ADD(d7); ADD(t1); ADD(t2)
 79#define  LEAVE  top -= 21
 80 
 81        ENTER;
 82 
 83        if ( top >= LIM ) {
 84                // for debug
 85                fprintf( stderr, "buf overflow!!" );
 86                LEAVE;
 87                return;
 88        }

 89 
 90 
 91        if ( n < 1 ) {
 92                LEAVE;
 93                return;
 94        }

 95        if ( n == 1 ) {
 96                c[ 1 ][ 1 ] = a[ 1 ][ 1 ] * b[ 1 ][ 1 ];
 97                LEAVE;
 98                return;
 99        }

100        n >>= 1;
101        get( a, a11, a12, a21, a22, n );
102        get( b, b11, b12, b21, b22, n );
103 
104        add( t1, a11, a22, n );
105        add( t2, b11, b22, n );
106        mul( d1, t1, t2, n );
107 
108        add( t1, a21, a22, n );
109        mul( d2, t1, b11, n );
110 
111        sub( t2, b12, b22, n );
112        mul( d3, a11, t2, n );
113 
114        sub( t2, b21, b11, n );
115        mul( d4, a22, t2, n );
116 
117        add( t1, a11, a12, n );
118        mul( d5, t1, b22, n );
119 
120        sub( t1, a21, a11, n );
121        add( t2, b11, b12, n );
122        mul( d6, t1, t2, n );
123 
124        sub( t1, a12, a22, n );
125        add( t2, b21, b22, n );
126        mul( d7, t1, t2, n );
127 
128        add( t1, d1, d4, n );
129        sub( t2, d5, d7, n );
130        sub( c11, t1, t2, n );
131 
132        add( c12, d3, d5, n );
133 
134        add( c21, d2, d4, n );
135 
136        add( t1, d1, d3, n );
137        sub( t2, d2, d6, n );
138        sub( c22, t1, t2, n );
139 
140        put( c, c11, c12, c21, c22, n );
141 
142        LEAVE;
143}

144  
145 int  main()  {
146        int td, n, a[ L ][ L ], b[ L ][ L ], c[ L ][ L ];
147        scanf( "%d"&td );
148        while ( td-- > 0 ) {
149                top = 0;
150                scanf( "%d"&n );
151                input( a, n );
152                input( b, n );
153                mul( c, a, b, n );
154                output( c, n );
155        }

156        return 0;
157}

158


我的实现有点丑。。。

你可能感兴趣的:(方阵乘法,矩阵乘法,Strassen 算法——算法作业 2.3,EOJ 1050)