CUDA内联汇编和PTX ISA入门指南

Example

首先给出一个具体例子,下文根据该例讲述基本用法

        asm("{\n\t"
                ".reg .s32 b;\n\t"
                ".reg .pred p;\n\t"
                "add.cc.u32 %1, %1, %2;\n\t"
                "addc.s32 b, 0, 0;\n\t"
                "sub.cc.u32 %0, %0, %2;\n\t"
                "subc.cc.u32 %1, %1, 0;\n\t"
                "subc.s32 b, b, 0;\n\t"
                "setp.eq.s32 p, b, 1;\n\t"
                "@p add.cc.u32 %0, %0, 0xffffffff;\n\t"
                "@p addc.u32 %1, %1, 0;\n\t"
                "}"
                : "+r"(x[0]), "+r"(x[1])
                : "r"(x[2]));

内联汇编(inline assembly)

官方参考:https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html

CUDA内联汇编语法和C/C++相同(除不允许设置clobbered registers)

       asm ( "assembler template string" 
           : "constraint"(output operands)                  /* optional */
           : "constraint"(input operands)                   /* optional */
           );

引用

"assembler template string"指PTX汇编指令代码,其中可以用%n引用操作数,n为从0开始的整数。编号顺序和在出现在输出操作数、输入操作数中的先后顺序一致(例如Example中,操作数出现顺序为x[0],x[1],x[2]分别对应汇编代码中的引用%0,%1,%2

constriant

规定使用相应的PTX寄存器类型

"h" = .u16 reg	// u表示无符号 16表示位数
"r" = .u32 reg
"l" = .u64 reg
"f" = .f32 reg  // f表示浮点寄存器
"d" = .f64 reg

补充(modifier)

由于官方参考关于操作数(operand)的修饰符描述较少,这里稍作简略描述(参考OpenXL C/C++ inline asemmbly):

  • =:表示该操作数只写(write-only),之前的值会被弃用,被这一新的输出数据取而代之
  • +:表示该操作数既可读又可写

输出操作数必须带上修饰符=+,输入操作数由于官方文档并未说明可直接认为修饰符可选(或没有)

PTX ISA

概述

PTX:Parallel Thread Execution
ISA:Instruction Set Architecture(指令集)

PTX ISA比CUDA C++更底层,但编程模型完全相同,仅在一些称谓上有所区别,例如

CTA(Cooperative Thread Array):等同于CUDA线程模型中的Block(PS:Grid这一概念依然保持不变)

变量声明

Example第2、3行均为变量声明语句,单独来看

.reg .s32 b;
.reg .pred p;

由三部分组成:State Space + Type + Identifier

State Space

规定存储位置,带有.前缀,区别标识符

名称 描述
.reg 寄存器,访存快
.sreg 特殊寄存器。只读;预定义;平台有关
.const 共享内存(Shared memory),只读
.global 全局内存(Global memory),全部线程共享
.local 本地内存(Local memory),每个线程私有
.param Kernel parameter
.shared 可按地址访问的共享内存,一个CTA内共享

具体特性参见官方文档State Spaces

Type

基本类型 对应符号
有符号整数 .s8, .s16, .s32, .s64
无符号整数 .u8, .u16, .u32, .u64
浮点数 .f16, .f16x2, .f32, .f64
位串(无类型) .b8, .b16, .b32, .b64
谓词 .pred

Identifier

标识符基本和所有语言规定相同,细微区别参见官方文档Identifier


回到例子

.reg .s32 b;	// 有符号32位整数变量b,使用寄存器存放
.reg .pred p;	// 谓词变量p(用于存放谓词逻辑结果,布尔值),使用寄存器存放

关于谓词的使用方法下文会有详细描述

指令集

指令格式

@p opcode;
@p opcode a;
@p opcode d, a;
@p opcode d, a, b;
@p opcode d, a, b, c;
  • 注意指令的操作数个数,以及操作数的含义(源、目的等等)
  • 最左边的@p是可选的guard predicate,即根据对应谓词结果选择是否执行该条指令

指令类型信息

多类型指令必须带上类型及大小描述符,例如Example中第4行add.cc.u32 %1, %1, %2;.u32表示进行无符号32位整数的加法。

扩展精度的整数运算

主要用于处理进位,例如Example中的4、5行:

add.cc.u32 %1, %1, %2;
addc.s32 b, 0, 0;

add.cc表示该条指令会改写条件码(Condition Code,简称CC)寄存器中的进位标志位(Carry Flag,简称CF);addc表示执行带进位的加法,也就是说除了源操作数外还会加上CC.CF

其它扩展运算指令参考官方文档Extended-Precision Integer Arithmetic Instructions

谓词逻辑执行分支

谓词寄存器本质上是虚拟的寄存器,用于处理PTX中的分支(类比其他ISA的条件跳转指令beq等)

  1. 声明谓词寄存器

    如Example中第3行.reg .pred p,声明谓词变量p

  2. step指令给谓词变量绑定具体谓词逻辑

    如Example中第9行setp.eq.s32 p, b, 1setp指set predicate register

    基本语法格式:setp.CmpOp.type p, a, b

    • type:规定源操作数a,b的类型
    • CmpOp:比较运算符
    • p必须是.pred类型变量

    关于step指令的详细用法参见官方文档Comparison and Selection Instructions: setp

  3. 设置条件分支

    如Example中10、11行

    @p add.cc.u32 %0, %0, 0xffffffff;
    @p addc.u32 %1, %1, 0;
    

    @p表示当p=True时,执行该条指令;而@!p则表示p=False时执行。

翻译为C-like language

// 函数功能:X mod p
// p = 0xFFFFFFFF00000001
// x = {x[0], x[1], x[2]}
// x[0]: LSW(least significant word) word=32-bit
void _uint96_modP(uint32 *x) {
	x[1] += x[2];
    b = /*上条指令发生溢出*/ ? 1 : 0;
    x[0] -= x[2];
    x[1] -= /*上条指令发生溢出*/ ? 1 : 0;
    if (b == 1) {
        x[0] += 0xFFFFFFFF; // x[0] += UINT_MAX
        x[1] += /*上条指令发生溢出*/ ? 1 : 0;
    }
}

(以下内容为分析函数功能,考虑到文章完整性才添加,和主旨无关)
X = x 0 + x 1 ⋅ 2 32 + x 2 ⋅ 2 64 m o d    ( P = 2 64 − 2 32 + 1 ) = x 0 + x 1 ⋅ 2 32 + x 2 ⋅ 2 64 − x 2 P m o d    P = ( x 0 − x 2 ) + ( x 1 + x 2 ) ⋅ 2 32 m o d    P \begin{aligned} X&=x_0+x_1\cdot 2^{32}+x_2\cdot 2^{64} \mod {(P=2^{64}-2^{32}+1)}\\ &=x_0+x_1\cdot 2^{32}+x_2\cdot 2^{64}-x_2P \mod P\\ &=(x_0-x_2)+(x_1+x_2)\cdot 2^{32} \mod P \end{aligned} X=x0+x1232+x2264mod(P=264232+1)=x0+x1232+x2264x2PmodP=(x0x2)+(x1+x2)232modP
代码中x[0] = x[0] - x[2]x[0]位长为32,可能发生溢出即实际所求为模 2 32 2^{32} 232的结果 [ x 0 − x 2 ] 2 32 = x 0 − x 2 + 2 32 [x_0-x_2]_{2^{32}}=x_0-x_2+2^{32} [x0x2]232=x0x2+232,额外加的 2 32 2^{32} 232相当于从高32位 x 1 x_1 x1借位,体现在上面代码第9行;

同样,x[1] = x[1] + x[2],也可能溢出 [ x 1 + x 2 ] 2 32 = x 1 + x 2 − 2 32 [x_1+x_2]_{2^{32}}=x_1+x_2-2^{32} [x1+x2]232=x1+x2232

于是 X X X中应该继续处理多减掉的 2 32 2^{32} 232
( [ x 1 + x 2 ] 2 32 + 2 32 ) ⋅ 2 32 m o d    P = [ x 1 + x 2 ] 2 32 ⋅ 2 32 + ⋅ 2 64 − P m o d    P = [ x 1 + x 2 ] 2 32 + ( 2 32 − 1 ) m o d    P \begin{aligned} &([x_1+x_2]_{2^{32}}+2^{32})\cdot2^{32}\mod P\\ =&[x_1+x_2]_{2^{32}}\cdot2^{32}+\cdot2^{64}-P\mod P\\ =&[x_1+x_2]_{2^{32}}+(2^{32} - 1) \mod P \end{aligned} ==([x1+x2]232+232)232modP[x1+x2]232232+264PmodP[x1+x2]232+(2321)modP
2 32 − 1 = 0 xffff ffff 2^{32} - 1=0\text {xffff ffff} 2321=0xffff ffff应该加到低32位,对应代码第11、12行(包括处理来自低32位的进位)


Example2

 uint64 _mul_modP(uint64 x, uint64 y) {
	volatile register uint32 mul[4]; // NEVER REMOVE VOLATILE HERE!!!
    // 128-bit = 64-bit * 64-bit
    asm("mul.lo.u32 %0, %4, %6;\n\t"
        "mul.hi.u32 %1, %4, %6;\n\t"
        "mul.lo.u32 %2, %5, %7;\n\t"
        "mul.hi.u32 %3, %5, %7;\n\t"
        "mad.lo.cc.u32 %1, %4, %7, %1;\n\t"
        "madc.hi.cc.u32 %2, %4, %7, %2;\n\t"
        "addc.u32 %3, %3, 0;\n\t"
        "mad.lo.cc.u32 %1, %5, %6, %1;\n\t"
        "madc.hi.cc.u32 %2, %5, %6, %2;\n\t"
        "addc.u32 %3, %3, 0;\n\t"
        : "+r"(mul[0]), "+r"(mul[1]), "+r"(mul[2]), "+r"(mul[3])
        : "r"(((uint32 *)&x)[0]), "r"(((uint32 *)&x)[1]),
        "r"(((uint32 *)&y)[0]), "r"(((uint32 *)&y)[1]));
 	/* ... */
 }

// output: %0 = mul[0], %1 = mul[1], %2 = mul[2], %3 = mul[3]
// input: %4 = x[0..31], %5 = x[32...63], %6 = y[0...31], %7 = y[32...63]

mul指令

语法如下:

mul.mode.type  d, a, b;

.mode = { .hi, .lo, .wide };
.type = { .u16, .u32, .u64,
          .s16, .s32, .s64 };
  • type规定操作数类型和大小
  • mode有三种:当使用.wide时,要求目的操作数d的长度为源操作数a,b的两倍,PTX规定此时type只能使用位长16和32的类型;当使用.hi.lo时对type不作要求,此时d的位长和a,b相同,分别取高位和低位。

mad指令

"multiplication and addition"的意思,自然有四个操作数,具体语法如下:

mad.mode.type  d, a, b, c;

.mode = { .hi, .lo, .wide };
.type = { .u16, .u32, .u64,
          .s16, .s32, .s64 };
  • mode和type规定决定乘法的中间值,完全同上;c的位宽和d相同
  • 由于涉及加法,同样有和CC寄存器相关的扩展指令:mad.cc、madc等

翻译为C-like language

// output: %0 = mul[0], %1 = mul[1], %2 = mul[2], %3 = mul[3]
// input: %4 = x[0..31], %5 = x[32...63], %6 = y[0...31], %7 = y[32...63]
__inline__ __device__
uint64 _mul_modP(uint64 x, uint64 y) {
	volatile register uint32 mul[4]; // NEVER REMOVE VOLATILE HERE!!!
	mul[0] = (x[0:31] * y[0:31])[0:31];
	mul[1] = (x[0:31] * y[0:31])[32:63];
	mul[2] = (x[32:63] * y[32:63])[0:31];
	mul[3] = (x[32:63] * y[32:63])[32:63];
 	mul[1] += (x[0:31] * y[32:63])[0:31];
 	mul[2] += (x[0:31] * y[32:63])[32:63] + /*上条指令发生溢出*/ ? 1 : 0;
	mul[3] += /*上条指令发生溢出*/ ? 1 : 0;
 	mul[1] += (x[32:63] * y[0:31])[0:31];
	mul[2] += (x[32:63] * y[0:31])[32:63] + /*上条指令发生溢出*/ ? 1 : 0;
	mul[3] += /*上条指令发生溢出*/ ? 1 : 0;
    _uint128_modP(mul);
    if (*(uint64 *)mul > valP)
            *(uint64 *)mul -= valP;
        return *(uint64 *)mul;
}

X = x 0 + x 1 ⋅ 2 32 , Y = y 0 + y 1 ⋅ 2 32 X=x_0 + x_1\cdot2^{32},Y=y_0+y_1\cdot2^{32} X=x0+x1232,Y=y0+y1232,注意到 x i , y i , i = 0 , 1 x_i,y_i, i =0,1 xi,yi,i=0,1都是一字长(32位)
X Y = ( x 0 + x 1 ⋅ 2 32 ) ( y 0 + y 1 ⋅ 2 32 ) = x 0 y 0 + ( x 0 y 1 + x 1 y 0 ) ⋅ 2 32 + x 1 y 1 ⋅ 2 64 \begin{aligned} XY &=(x_0 + x_1\cdot2^{32})(y_0+y_1\cdot2^{32})\\ &=x_0y_0+(x_0y_1+x_1y_0)\cdot2^{32}+x_1y_1\cdot2^{64} \end{aligned} XY=(x0+x1232)(y0+y1232)=x0y0+(x0y1+x1y0)232+x1y1264
乘法结果为4字长,可以复用_uint128_modP,而 x i y j x_iy_j xiyj为2字长,继续分解
X Y = m u l 0 + ( m u l 1 + x 0 y 1 [ 0 : 31 ] + x 1 y 0 [ 0 : 31 ] ) ⋅ 2 32 + ( m u l 2 + x 0 y 1 [ 32 : 63 ] + x 1 y 0 [ 32 : 63 ] ) ⋅ 2 64 + m u l 3 ⋅ 2 96 XY=mul_0+(mul_1+x_0y_1[0:31]+x_1y_0[0:31])\cdot2^{32}+(mul_2+x_0y_1[32:63]+x_1y_0[32:63])\cdot2^{64}+mul_3\cdot2^{96} XY=mul0+(mul1+x0y1[0:31]+x1y0[0:31])232+(mul2+x0y1[32:63]+x1y0[32:63])264+mul3296
其中 m u l 0 ∼ 3 mul_{0\sim3} mul03的计算对应代码6~9行; m u l 1 + x 0 y 1 [ 0 : 31 ] mul_1+x_0y_1[0:31] mul1+x0y1[0:31] m u l 2 + x 0 y 1 [ 32 : 63 ] mul_2+x_0y_1[32:63] mul2+x0y1[32:63](包括加法进位处理)对应代码10~12行, m u l 1 + x 1 y 0 [ 0 : 31 ] mul_1+x_1y_0[0:31] mul1+x1y0[0:31] m u l 2 + x 1 y 0 [ 32 : 63 ] mul_2+x_1y_0[32:63] mul2+x1y0[32:63](包括加法进位处理)对应代码13~15行

你可能感兴趣的:(c++)