CCF认证202305-2矩阵运算(C语言版)

试题编号: 202305-2
试题名称: 矩阵运算
时间限制: 5.0s
内存限制: 512.0MB
问题描述:

题目背景

 \mathrm{Softmax}(\frac{\mathbf{Q} \times \mathbf{K}^{T}}{\sqrt{d}}) \times \mathbf{V}是 Transformer 中注意力模块的核心算式,其中 \mathbf{Q}\mathbf{K} 和 \mathbf{V} 均是  行、列的矩阵, \mathbf{K}^{T}表示矩阵 \mathbf{K} 的转置,\times 表示矩阵乘法。

问题描述

为了方便计算,顿顿同学将 \mathrm{Softmax} 简化为了点乘一个大小为 n 的一维向量 \mathbf{W}
        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        \left(\mathbf{W} \cdot (\mathbf{Q} \times \mathbf{K}^{T})\right) \times \mathbf{V}
点乘即对应位相乘,记 \mathbf{W}^{(i)} 为向量 \mathbf{W} 的第 i 个元素,即将 (\mathbf{Q} \times \mathbf{K}^{T}) 第 i 行中的每个元素都与\mathbf{W}^{(i)}  相乘。

现给出矩阵\mathbf{Q} 、\mathbf{K} 和 \mathbf{V} 和向量\mathbf{W} ,试计算顿顿按简化的算式计算的结果。

输入格式

从标准输入读入数据。

输入的第一行包含空格分隔的两个正整数 n 和 d,表示矩阵的大小。

接下来依次输入矩阵\mathbf{Q} 、\mathbf{K} 和\mathbf{V} 。每个矩阵输入n  行,每行包含空格分隔的 d 个整数,其中第  I行的第 j 个数对应矩阵的第 I​​​​​​​行、第 j 列。

最后一行输入 n 个整数,表示向量\mathbf{W} 。

输出格式

输出到标准输出中。

输出共 n 行,每行包含空格分隔的 d 个整数,表示计算的结果。

样例输入

3 2
1 2
3 4
5 6
10 10
-20 -20
30 30
6 5
4 3
2 1
4 0 -5

Data

样例输出

480 240
0 0
-2200 -1100

Data

子任务

 的测试数据满足:n \le 100 且 d \le 10 ;输入矩阵、向量中的元素均为整数,且绝对值均不超过30 。

全部的测试数据满足:n \le 10^4 且 d \le 20;输入矩阵、向量中的元素均为整数,且绝对值均不超过1000 。

提示

请谨慎评估矩阵乘法运算后的数值范围,并使用适当数据类型存储矩阵中的整数。

代码一:(强制类型转换)

#include "stdio.h"
long long a[30][30];
long long f[10010][30];
int main(){
    int n,d;
    int Q[10010][30],K[10010][30],V[10010][30];
    int W[10010];
    scanf("%d %d",&n,&d);
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < d; ++j) {
            scanf("%d",&Q[i][j]);
        }
    }
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < d; ++j) {
            scanf("%d",&K[i][j]);
        }
    }
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < d; ++j) {
            scanf("%d",&V[i][j]);
        }
    }
    for (int i = 0; i < n; ++i) {
        scanf("%d",&W[i]);
    }
    for (int i = 0; i < d; ++i) {
        for (int j = 0; j < d; ++j) {
            for (int k = 0; k < n; ++k) {
                a[i][j]+=(K[k][i]*V[k][j]);
            }
        }
    }
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < d; ++j) {
            for (int k = 0; k < d; ++k) {
                f[i][j]+=(a[k][j]*Q[i][k]);
            }
        }
    }
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < d; ++j) {
            f[i][j]*=(long long)W[i];
        }
    }
    for (int i = 0; i < n; ++i) {
        printf("%lld",f[i][0]);
        for (int j = 1; j < d; ++j) {
            printf(" %lld",f[i][j]);
        }
        printf("\n");
    }
    return 0;
}

代码二:

#include "stdio.h"
long long a[30][30];
long long f[10010][30];
int main(){
    int n,d;
    int Q[10010][30],K[10010][30],V[10010][30];
    long long W[10010];
    scanf("%d %d",&n,&d);
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < d; ++j) {
            scanf("%d",&Q[i][j]);
        }
    }
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < d; ++j) {
            scanf("%d",&K[i][j]);
        }
    }
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < d; ++j) {
            scanf("%d",&V[i][j]);
        }
    }
    for (int i = 0; i < n; ++i) {
        scanf("%lld",&W[i]);
    }
    for (int i = 0; i < d; ++i) {
        for (int j = 0; j < d; ++j) {
            for (int k = 0; k < n; ++k) {
                a[i][j]+=(K[k][i]*V[k][j]);
            }
        }
    }
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < d; ++j) {
            for (int k = 0; k < d; ++k) {
                f[i][j]+=(a[k][j]*Q[i][k]);
            }
        }
    }
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < d; ++j) {
            f[i][j]*=W[i];
        }
    }
    for (int i = 0; i < n; ++i) {
        printf("%lld",f[i][0]);
        for (int j = 1; j < d; ++j) {
            printf(" %lld",f[i][j]);
        }
        printf("\n");
    }
    return 0;
}

欢迎评论区留言

你可能感兴趣的:(算法)