cuBLAS矩阵乘法

cuBLAS是cuda封装好的一个数学库,头文件为,其中的矩阵乘法函数是我们做深度学习绕不开的函数,

下面是一个常用的函数,后面又有扩展,但是核心函数使用逻辑一样。

#define cublasSgemm cublasSgemm_v2
CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemm_v2
(
    cublasHandle_t handle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float *alpha,
    const float *A, int lda,
    const float *B, int ldb,
    const float *beta,
    float *C, int ldc
);
下面对容易犯错的函数参数进行总结,估计看完后还是一脸蒙逼,所以看看下面的
实战操作就可以明白每个参数真实的含义:

C = alpha*A*B + beta*C(看起来没啥问题,使用起来呵呵哒)

cublasHandle_t handle:调用 cuBLAS 库时的句柄
cublasOperation_t transa, 是否对A转置
cublasOperation_t transb, 是否对B转置
int m, int n, int k, mnk表示矩阵计算时候的维度
const float *alpha,
const float *A, //矩阵A
int lda,  //按列读取的长度
const float *B, 
int ldb, 
const float *beta,
float *C, 
int ldc ldc:按列取的个数

           

     一、本来以为的操作

cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, 2 ,4, 3, &alpha, *A, 2, *B, 3, &beta, *C, 2);

cuBLAS矩阵乘法_第1张图片

cublas和mkl里面的cblas的区别在于,这个cublas不能按行存储,只能选择按列存储,而c++/c的数据存储格式都是按行存储的,但是他这个API不管,人家就是按你是列优先的储存格式,所以在实际使用的时候会一直出错,为此需要理清楚这个函数的每个

参数的意思,避免一直出现问题。

首先来看int m,int n, int k这三个参数,这三个参数依次代表的是A1的行,B1的列,A1的列(B1的行)

lda: 这个参数就比较魔幻,A(c/c++是按行存的)进来的数据【1,2,3,4,5,6】按照每lda一列排好,得到A1;

ldb:B同上

ldc:按列从C1中取数,每次取ldc的长度,只能大于等于C1的行数,当大于C1时,c1单列由于数字不够会补零。

可以看到呀,这个函数一旦进入,其实就给数据调换了位置,因为一个协议是按列,一个协议是按行,所以c/c++在调用

这个函数时就有问题,而CUBLAS_OP_T (转置)和CUBLS_OP_N(不转置)利用这两个参数是不是就可以解决这个问题?

二、思考后的操作

cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_T, 2, 4, 3, &alpha, d_A, 3, d_B, 4, &beta, d_C, 2);

cuBLAS矩阵乘法_第2张图片

c1---行优先--->输出c2[[38 83 44 98],[50 113 56 128]]

所以结果还是和预期结果不一致,为此,来了一个骚操作:

三、正确的操作

cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, 4, 2, 3, &alpha, *B, 4, *A, 3 , &beta, *C, 4);

cuBLAS矩阵乘法_第3张图片

结果正确,利用的是 A B = (BT AT)T,真是天秀。

 

附上代码:没有释放内存,随便做实验用的

#include 
#include 
#include 
#include 


using namespace std;


//cuBLAS代码
int main()
{
const float alpha = 1.0f;
const float beta  = 0.0f;
int m = 2, n = 4, k = 3;


float A[6] = {1,2,3,4,5,6};
float B[12] = {1,2,3,4,5,6,7,8,9,10,11,12};
float *C;

float* d_A,*d_B, *d_C;

C = (float*)malloc(sizeof(float)*8);
cudaMalloc((void**)&d_A, sizeof(float)*6);
cudaMalloc((void**)&d_B, sizeof(float)*12);
cudaMalloc((void**)&d_C, sizeof(float)*8);


cudaMemcpy(d_A, A, 6*sizeof(float),  cudaMemcpyHostToDevice);
cudaMemcpy(d_B, B, 12*sizeof(float), cudaMemcpyHostToDevice);

cublasHandle_t handle;
cublasCreate(&handle);
cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_T, m, n, k, &alpha, d_A, 3, d_B, 4, &beta, d_C, 2);
cublasDestroy(handle);

cudaMemcpy(C, d_C, 8*sizeof(float), cudaMemcpyDeviceToHost);

for(int i=0; i<8; i++)
{
   cout<

参考链接:https://www.cnblogs.com/cuancuancuanhao/p/7763256.html(有些结果有问题)

 

你可能感兴趣的:(CUDA编程)