POJ 3213 PM3 矩阵乘法

PM3

USTC has recently developed the Parallel Matrix Multiplication Machine – PM3, which is used for very large matrix multiplication.

Given two matrices A and B, where A is an N × P matrix and B is a P × M matrix, PM3 can compute matrix C = AB in O(P(N + P + M)) time. However the developers of PM3 soon discovered a small problem: there is a small chance that PM3 makes a mistake, and whenever a mistake occurs, the resultant matrix C will contain exactly one incorrect element.

The developers come up with a natural remedy. After PM3 gives the matrix C, they check and correct it. They think it is a simple task, because there will be at most one incorrect element.

So you are to write a program to check and correct the result computed by PM3.

Input

The first line of the input three integers N, P and M (0 < N, P, M ≤ 1,000), which indicate the dimensions of A and B. Then follow N lines with P integers each, giving the elements of A in row-major order. After that the elements of B and C are given in the same manner.

Elements of A and B are bounded by 1,000 in absolute values which those of C are bounded by 2,000,000,000.

Output

If C contains no incorrect element, print “Yes”. Otherwise print “No” followed by two more lines, with two integers r and c on the first one, and another integer v on the second one, which indicates the element of C at row r, column c should be corrected to v.

Sample Input

2 3 2
1 2 -1
3 -1 0
-1 0
0 2
1 3
-2 -1
-3 -2
Sample Output

No
1 2
1

题意:给矩阵A、B、C,问A*B是否完全等于C,等于输出C,不等于输出坐标和那个坐标给的值,而且只会有一个地方有错误。

简单暴力直接乘不用long long就可以过,但是有一种简化的方法。
给A(2,3),B(3,2)可以得到一个C(2,2)的矩阵。
写出C11和C12。 ///字母后的数字是下标
C11 = A11*B11 + A12*B21+ A13*B31;
C12 = A11*B12 + A12*B22+ A13*B32;
两式相加可得
C11+C12 = B11(A11+A21) + B12(A12+A22) + B13(A13+A23);

可以看得出来,C的每一列的和等于B的每一列的每个元素与对应的A的每一列的和的乘积。
这样只需要根据C每一列的和就可以判断是哪一列出了问题,然后在出错的这一列进行矩阵乘法找出错误的元素并输出就行了。

CODE

#include"stdio.h"
#include"algorithm"
#include"iostream"
#include"string.h"
#define maxn 1000+10
using namespace std;

int n,p,m;
int a[maxn][maxn];
int b[maxn][maxn];
int c[maxn][maxn];
int A[maxn];  ///a的每列和
int C[maxn];  ///c的每列和

bool check(int r,int l)   ///对应的那一列进行矩阵乘法
{
    int t = 0;
    for(int i = 1;i <= p;i++)
    {
        t += a[r][i]*b[i][l];
    }
    if(t != c[r][l])
    {
        printf("No\n%d %d\n%d\n",r,l,t);
        return true;
    }
    return false;
}
int main(void)
{
    while(scanf("%d%d%d",&n,&p,&m) !=EOF)
    {
        memset(A,0,sizeof A);
        memset(B,0,sizeof B);
        memset(C,0,sizeof C);
        for(int i = 1;i <= n;i++)
            for(int j = 1;j <= p;j++)
            {
                scanf("%d",&a[i][j]);
                A[j] += a[i][j];   ///A的每列和
            }
        for(int i = 1;i <= p;i++)
        {
            for(int j = 1;j <= m;j++)
            {
                scanf("%d",&b[i][j]);
            }
        }
        for(int i = 1;i <= n;i++)
        {
            for(int j = 1;j <= m;j++)
            {
                scanf("%d",&c[i][j]);
                C[j] += c[i][j];   ///C的每列和
            }
        }
        int flag = 1;   ///判断是否有错误
        for(int i = 1;i <= m;i++)  ///枚举B的每一列
        {
            int sum = 0;
            for(int j = 1;j <= p;j++)
                sum += b[j][i]*A[j];
            if(sum != C[i]) ///第i列出错
            {
                flag = 0;
                for(int j = 1;j <= n;j++)
                {
                    if(check(j,i))
                        break;
                }
            }
        }
        if(flag)
            printf("Yes\n");
    }
    return 0;
}

你可能感兴趣的:(poj,矩阵乘法)