OpenBLAS矩阵乘法源码结构分析

用于记录阅读分析OpenBLAS源代码的各种知识点,防止遗忘。这里主要记录OpenBLAS的代码结构,因为确实比较复杂,直接看源代码很可能比较蒙比,如果知道其结构,看起来就比较轻松了。至于OpenBLAS矩阵乘法的算法,这篇不涉及,我会在另一篇文章中简单(瞎jb)分析。

OpenBLAS代码总体上可以分成三个层次:

1.接口层
在OpenBLAS接口层中,运算又分为三个类型,分别是level1到3,其中level3对应矩阵和矩阵的运算,level2和level1依次维度越来越低。不过这些level1~3都是BLAS内部计算时使用的接口(源代码基本在driver/level下),对外界用户的接口是不涉及这个概念的(对外接口基本都在interface文件夹下):
OpenBLAS矩阵乘法源码结构分析_第1张图片
每一个源代码文件对应一种操作,如这里的gemm指的是普通矩阵乘法(General Matrix Multiplication)而gemv指的是普通矩阵向量乘法(General Matrix Vector)。打开gemm.c可以大致观察一下其源代码(大幅度阉割版):

//后边函数体中要使用的函数表,是计算矩阵乘法的核心函数,有这么多不同的函数指针,是区分了大量特殊情况,如GEMM_NN是两个普通矩阵相乘,而GEMM_TN说明第一个矩阵是转置过的,而带THREAD标签的是多线程实现的核心函数,其执行效率是单核执行的若干倍。
static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, FLOAT *, FLOAT *, BLASLONG) = {
#ifndef GEMM3M
  GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN,
  GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT,
  GEMM_NR, GEMM_TR, GEMM_RR, GEMM_CR,
  GEMM_NC, GEMM_TC, GEMM_RC, GEMM_CC,
#if defined(SMP) && !defined(USE_SIMPLE_THREADED_LEVEL3)
  GEMM_THREAD_NN, GEMM_THREAD_TN, GEMM_THREAD_RN, GEMM_THREAD_CN,
  GEMM_THREAD_NT, GEMM_THREAD_TT, GEMM_THREAD_RT, GEMM_THREAD_CT,
  GEMM_THREAD_NR, GEMM_THREAD_TR, GEMM_THREAD_RR, GEMM_THREAD_CR,
  GEMM_THREAD_NC, GEMM_THREAD_TC, GEMM_THREAD_RC, GEMM_THREAD_CC,
#endif
#else
  GEMM3M_NN, GEMM3M_TN, GEMM3M_RN, GEMM3M_CN,
  GEMM3M_NT, GEMM3M_TT, GEMM3M_RT, GEMM3M_CT,
  GEMM3M_NR, GEMM3M_TR, GEMM3M_RR, GEMM3M_CR,
  GEMM3M_NC, GEMM3M_TC, GEMM3M_RC, GEMM3M_CC,
#if defined(SMP) && !defined(USE_SIMPLE_THREADED_LEVEL3)
  GEMM3M_THREAD_NN, GEMM3M_THREAD_TN, GEMM3M_THREAD_RN, GEMM3M_THREAD_CN,
  GEMM3M_THREAD_NT, GEMM3M_THREAD_TT, GEMM3M_THREAD_RT, GEMM3M_THREAD_CT,
  GEMM3M_THREAD_NR, GEMM3M_THREAD_TR, GEMM3M_THREAD_RR, GEMM3M_THREAD_CR,
  GEMM3M_THREAD_NC, GEMM3M_THREAD_TC, GEMM3M_THREAD_RC, GEMM3M_THREAD_CC,
#endif
#endif
};

//gemm的函数体
void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB,
       blasint m, blasint n, blasint k,
#ifndef COMPLEX
       FLOAT alpha,
#else
       FLOAT *alpha,
#endif
       FLOAT *a, blasint lda,
       FLOAT *b, blasint ldb,
#ifndef COMPLEX
       FLOAT beta,
#else
       FLOAT *beta,
#endif
       FLOAT *c, blasint ldc) {
       //大量代码实现,主要是对输入矩阵格式进行操作,并且选择正确分支,最后调用上边函数表中的一个函数完成运算
                                。
                                。
                                。
                                。
        //调用函数表中的一个核心计算函数
        (gemm[(transb << 2) | transa])(&args, NULL, NULL, sa, sb, 0);
       }

首先比较蒙比的是,其函数名字为CNAME而不是gemm,而且其函数体中有大量ifdef的分支,这是因为OpenBLAS采用了类似模版编程的办法,因为大量操作涉及相同的接口结构,所以OpenBLAS使用同一个源代码文件gemm.c来生成各种不同的矩阵相乘操作的真正代码文件。在使用cmake生成工程之后(生成之前是没有的),我们可以在interface找到一个文件sgemm.c,代码很简洁:

#define ASMNAME _sgemm
#define ASMFNAME _sgemm_
#define NAME sgemm_
#define CNAME sgemm
#define CHAR_NAME "sgemm_"
#define CHAR_CNAME "sgemm"
#include "PATH_TO_YOUR_OPENBLAS/OpenBLAS-develop/interface/gemm.c"

可以看到,这个sgemm的接口define了CNAME,且include的了上边的gemm.c,这样gemm.c的源代码就会被粘贴进来,而CNAME会被替换成这个接口真正的名字,形成一个有效的编译单元。同理,interface的zgemm,cgemm等等都会include上边的gemm.c,生成一个完整的函数实现。这种写法就省去了大量冗余代码,因为无论zgemm、sgemm、cgemm等等他们都是矩阵乘法,大致实现相同,只是一些细节不同,所以不需要每个函数都从头写。至于zgemm,sgemm这些接口有啥区别,这里有比较详细的说明,大致为C代表复数计算,Z代表双精度复数,而S代表单精度常数(链接中说是半精度,实际上OpenBLAS中是单精度),则zgemm是双精度复数矩阵乘法,他们通过这里的CNAME来区分名字,同时通过很多其他的预编译宏来区分一些细微的实现细节,因为没有详细研究过复数矩阵乘法的源代码,就不多赘述那些分支的区别了。

2.核心函数层
之后可以深入到gemm中函数表里的核心计算函数了。我们挑选最简单的GEMM_NN,两个普通矩阵相乘来看。GEMM_NN这个预编译宏最后指向哪个函数,是和很多其他预编译宏相关的(对预编译宏的谜の热爱),比如define了COMPLEX,表示复数矩阵乘法时,GEMM_NN就指向qgemm_nn函数,而如果是实数矩阵乘法,就会指向sgemm_nn函数。核心函数层依然是那种模版编程的思路,直接在VS里边是索引不到sgemm_nn的实现的,因为sgemm_nn的实现代码如下:

#define NN
#define ASMNAME _sgemm_nn
#define ASMFNAME _sgemm_nn_
#define NAME sgemm_nn_
#define CNAME sgemm_nn
#define CHAR_NAME "sgemm_nn_"
#define CHAR_CNAME "sgemm_nn"
#include "PATH_TO_YOUR_OPENBLAS/OpenBLAS-develop/driver/level3/gemm.c"

和上边sgemm.c的代码思路相同,这个源代码中仅仅定义函数的名字,真正的源代码模版在gemm.c中,不过注意了,这里的gemm.c和上边提到的gemm.c不是同一份代码,这个gemm.c在level3下,其内容是真正的矩阵乘法实现,上边的gemm.c干的事情主要是一些准备和分支控制工作(他是暴漏给用户的接口函数,这个是内部的干活函数)。level3下的gemm.c:

#include 
#include "common.h"

#undef  TIMING

#ifdef PARAMTEST
#undef GEMM_P
#undef GEMM_Q
#undef GEMM_R

#define GEMM_P  (args -> gemm_p)
#define GEMM_Q  (args -> gemm_q)
#define GEMM_R  (args -> gemm_r)
#endif

#if 0
#undef GEMM_P
#undef GEMM_Q

#define GEMM_P 504
#define GEMM_Q 128
#endif

#ifdef THREADED_LEVEL3
#include "level3_thread.c"
#else
#include "level3.c"
#endif

又是故技重施,这里使用预编译宏控制了一些变量,而有效代码都在level3.c中。level3代码(大量精简版):

//这里大量的define比较吓人,但是除去分支控制,主要内容是定义了三种操作(函数),BETA_OPERATION,ICOPY_OPERATION,OCOPY_OPERATION,KERNEL_OPERATION。BETA_OPERATION用于给整个矩阵乘以一个系数,因为gemm计算的是c=alpha*a*b+beta*c,这个BETA_OPERATION就是用于乘以那个beta的。两个COPY函数用于矩阵的拷贝,这个和OpenBLAS的实现算法有关,在算法分析中在详细说明。而KERNEL_OPERATION就是核心函数中的核心函数了,他是真正做乘法和乘累加的地方,整个矩阵乘法就是他算完的了。
#ifndef BETA_OPERATION
#if !defined(XDOUBLE) || !defined(QUAD_PRECISION)
#ifndef COMPLEX
#define BETA_OPERATION(M_FROM, M_TO, N_FROM, N_TO, BETA, C, LDC) \
    GEMM_BETA((M_TO) - (M_FROM), (N_TO - N_FROM), 0, \
          BETA[0], NULL, 0, NULL, 0, \
          (FLOAT *)(C) + ((M_FROM) + (N_FROM) * (LDC)) * COMPSIZE, LDC)
#else
                                。
                                。
                                。
                                。


#ifndef ICOPY_OPERATION
#if defined(NN) || defined(NT) || defined(NC) || defined(NR) || \
    defined(RN) || defined(RT) || defined(RC) || defined(RR)
#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
#else
#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
#endif
#endif

#ifndef OCOPY_OPERATION
#if defined(NN) || defined(TN) || defined(CN) || defined(RN) || \
    defined(NR) || defined(TR) || defined(CR) || defined(RR)
#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
#else
#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
#endif
#endif
                                。
                                。
                                。
                                。



#ifndef KERNEL_OPERATION
#if !defined(XDOUBLE) || !defined(QUAD_PRECISION)
#ifndef COMPLEX
#define KERNEL_OPERATION(M, N, K, ALPHA, SA, SB, C, LDC, X, Y) \
    KERNEL_FUNC(M, N, K, ALPHA[0], SA, SB, (FLOAT *)(C) + ((X) + (Y) * LDC) * COMPSIZE, LDC)
#else
#define KERNEL_OPERATION(M, N, K, ALPHA, SA, SB, C, LDC, X, Y) \
    KERNEL_FUNC(M, N, K, ALPHA[0], ALPHA[1], SA, SB, (FLOAT *)(C) + ((X) + (Y) * LDC) * COMPSIZE, LDC)
#endif
#else
#define KERNEL_OPERATION(M, N, K, ALPHA, SA, SB, C, LDC, X, Y) \
    KERNEL_FUNC(M, N, K, ALPHA, SA, SB, (FLOAT *)(C) + ((X) + (Y) * LDC) * COMPSIZE, LDC)
#endif
#endif
                                。
                                。
                                。
                                。



//sgemm_nn的函数体,应该说是所有gemm的模版函数体,负责将矩阵乘法完成
int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
          XFLOAT *sa, XFLOAT *sb, BLASLONG dummy){
                                。
                                。
                                。
                                。
//对矩阵c调用beta操作(c=alpha*a*b+beta*c)
BETA_OPERATION(m_from, m_to, n_from, n_to, beta, c, ldc);
                                。
                                。
                                。
                                。
//大量复杂的for循环控制,因为这一篇不详细分析算法细节,这里就简化了,源代码中是多个不同层级for嵌套,用于分片
for(。。。。。。)
{
    //进行一次矩阵拷贝,大致意义是从两个操作数矩阵之一中切出一小块,而这一小块会被放入L1Cache中反复使用,提高缓存命中率以达到加速效果
    ICOPY_OPERATION(min_l, min_i, a, lda, ls, m_from, sa);
    //同上
    OCOPY_OPERATION(min_l, min_jj, b, ldb, ls, jjs, sb + min_l * (jjs - js) * COMPSIZE * l1stride);
    //对切出的小块进行矩阵乘法
    KERNEL_OPERATION(min_i, min_jj, min_l, alpha, sa, sb + min_l * (jjs - js)  * COMPSIZE * l1stride, c, ldc, m_from, jjs);
}

可见sgemm_nn已经完成了两个不转置矩阵的乘法了,但是sgemm_nn中的BETA_OPERATION,ICOPY_OPERATION,OCOPY_OPERATION,KERNEL_OPERATION这三个最热点的关键函数又是怎么实现的?你会发现VS f12是找不到这三个函数的实现的,因为他们又是使用模版的思路编写的。。。如这里的BETA_OPERATION最后指向sgemm_beta,我们可以找到sgemm_beta.c:

#define ASMNAME _sgemm_beta
#define ASMFNAME _sgemm_beta_
#define NAME sgemm_beta_
#define CNAME sgemm_beta
#define CHAR_NAME "sgemm_beta_"
#define CHAR_CNAME "sgemm_beta"
#include "PATH_TO_YOUR_OPENBLAS/OpenBLAS-develop/kernel/x86/../generic/gemm_beta.c"

可见实现代码在kernel文件夹下的gemm_beta.c中。
不过这里有一点值得一提,就是使用VS编译OpenBLAS时,这三个最核心最热点的OPERATION都是像上边sgemm_beta.c一样,最后是使用c代码实现的,其执行效率非常低,而使用官方推荐mingw或者在其他平台,如arm上使用cmake编译时,这些OPERATION会使用更高效的汇编.S代码实现(这个不是编译器的锅,似乎是makefile方面的问题)。我们可以进入kernel文件夹,发现大量kernel代码:
OpenBLAS矩阵乘法源码结构分析_第2张图片
可以看到对于不同平台,OpenBLAS都准备的大量的kernel代码,这些代码基本都是平台特定的高效汇编代码,我们进入arm64文件夹中,查看KERNEL.ARMV8文件:

                。
                。
                。
                。
SGEMMKERNEL    =  sgemm_kernel_4x4.S
SGEMMONCOPY    =  ../generic/gemm_ncopy_4.c
SGEMMOTCOPY    =  ../generic/gemm_tcopy_4.c
SGEMMONCOPYOBJ =  sgemm_oncopy.o
SGEMMOTCOPYOBJ =  sgemm_otcopy.o

DGEMMKERNEL    =  ../generic/gemmkernel_2x2.c
DGEMMONCOPY    = ../generic/gemm_ncopy_2.c
DGEMMOTCOPY    = ../generic/gemm_tcopy_2.c
DGEMMONCOPYOBJ = dgemm_oncopy.o
DGEMMOTCOPYOBJ = dgemm_otcopy.o

CGEMMKERNEL    = ../generic/zgemmkernel_2x2.c
CGEMMONCOPY    = ../generic/zgemm_ncopy_2.c
CGEMMOTCOPY    = ../generic/zgemm_tcopy_2.c
CGEMMONCOPYOBJ =  cgemm_oncopy.o
CGEMMOTCOPYOBJ =  cgemm_otcopy.o

ZGEMMKERNEL    = ../generic/zgemmkernel_2x2.c
ZGEMMONCOPY    = ../generic/zgemm_ncopy_2.c
ZGEMMOTCOPY    = ../generic/zgemm_tcopy_2.c
ZGEMMONCOPYOBJ =  zgemm_oncopy.o
ZGEMMOTCOPYOBJ =  zgemm_otcopy.o
                。
                。
                。
                。

可以看到对于特定平台的特定CPU构架,makefile最后会控制这些最热点函数的实现文件,如果这里sgemm_kernel(KERNEL_OPERATION)指向了.S文件,最后编译时就会编译对应的.S文件,然后链接进入库中,而不像在windows环境下使用vs默认的全部使用.c文件实现(所以不使用官方推荐的mingw在windows平台编译的话,编译出来OpenBLAS性能会慢若干倍)。当然,对于一些性能要求不高,或者编写实在麻烦的操作,比如ICOPY_OPERATION,OCOPY_OPERATION,arm64位平台还是指向了generic文件夹,使用了通用的.c实现。
至于平台特定的.S代码如何链接进入最后的库中,我们可以简单查看下sgemm_kernel_4x4.S:

                。
                。
                。
                。
                。

    PROLOGUE

    .align 5
    add sp, sp, #-(11 * 16)
    stp d8, d9, [sp, #(0 * 16)]
    stp d10, d11, [sp, #(1 * 16)]
    stp d12, d13, [sp, #(2 * 16)]
    stp d14, d15, [sp, #(3 * 16)]
    stp d16, d17, [sp, #(4 * 16)]
    stp x18, x19, [sp, #(5 * 16)]
    stp x20, x21, [sp, #(6 * 16)]
    stp x22, x23, [sp, #(7 * 16)]
    stp x24, x25, [sp, #(8 * 16)]
    stp x26, x27, [sp, #(9 * 16)]
    str x28, [sp, #(10 * 16)]

    fmov    alpha0, s0
    fmov    alpha1, s0
    fmov    alpha2, s0
    fmov    alpha3, s0

    lsl LDC, LDC, #2            // ldc = ldc * 4

    mov pB, origPB

    mov counterJ, origN
    asr     counterJ, counterJ, #2      // J = J / 4
    cmp     counterJ, #0
    ble sgemm_kernel_L2_BEGIN
                。
                。
                。
                。
                。

这里的PROLOGUE是一个预编译宏(OpenBLAS到底多喜欢预编译宏啊),最后编译.S的时候会被替换成声明global sgemm_kernel和label名sgemm_kernel,而OpenBLAS的头文件cblas中声明了同样名字的sgemm_kernel函数,于是因为函数名相同,这个汇编代码就被链接进最后的二进制库中了。

这样整个OpenBLAS的代码结构就差不多讲完了,因为OpenBLAS开发者对于预编译宏的谜の执着和对不同平台的特定优化,整个库源代码的结构还是比较复杂的,这就比较阻碍我们进一步去理解他的算法(找个函数半天找不到实现实在是(╯‵□′)╯︵┴─┴ )。OpenBLAS对于矩阵乘法实现的算法也是非常精彩,会在另一篇文章中简单(瞎jb)分析。

4

你可能感兴趣的:(c++,openblas)