warp矩阵函数:
C ++ warp矩阵操作利用Tensor Cores来加速D = A * B + C形式的矩阵问题。 这需要一个warp中所有线程的协作。
这些warp矩阵函数是计算能力为7.0或更高的设备支持的预览功能。 此处描述的数据结构和API在未来版本中可能会有所变化,并且可能与这些未来版本不兼容。
描述:
以下所有函数和类型都在命名空间nvcuda :: wmma中定义。
template
class fragment;
template<> class fragment
template<> class fragment
template<> class fragment
template<> class fragment
template<> class fragment
template<> class fragment
void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm);
void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm, layout_t,layout);
void store_matrix_sync(T* mptr, const fragment<...> &a, unsigned ldm, layout_t
layout);
void fill_fragment(fragment<...> &a, const T& v);
void mma_sync(fragment<...> &d, const fragment<...> &a, const fragment<...>
&b, const fragment<...> &c, bool satf = false);
fragment:
一个重载类,包含分布在warp中所有线程上的矩阵部分。 矩阵元素到片段内部存储的映射是未指定的,并且在将来的架构中可能会改变。
只允许某些模板参数的组合。 第一个模板参数指定片段如何参与矩阵操作。 使用的可接受值为:
m,n和k大小描述了参与乘法累加运算的整经矩阵矩阵的形状。 每个瓦片的尺寸取决于其作用。 对于矩阵_a,瓦片的维数为m×k; 对于matrix_b,维数为k×n,累加器瓦片为m×n。
数据类型T对于被乘数必须是__half
,对于累加器可以是__half
或float。 必须为matrix_a和matrix_b片段指定Layout参数。 row_major或col_major分别指示矩阵行或列内的元素在内存中是连续的。 累加器矩阵的Layout参数应保留void的缺省值。 只有当累加器被加载或存储时才会指定行或列布局,如下所述。
load_matrix_sync:
等待直到经线中的所有线程都收敛为止,然后从存储器中加载矩阵片段a。 mptr必须是一个128位对齐的指针,指向内存中矩阵的第一个元素。 ldm描述连续行(对于行主要布局)或列(对于列主要布局)之间的元素的跨度,并且必须是16个字节的倍数(即,8个__half元素或4个浮动元素)。 如果片段是累加器,则布局参数必须指定为mem_row_major或mem_col_major。 对于matrix_a和matrix_b片段,根据片段的布局参数推断布局。 对于变形中的所有线程,mptr,ldm,布局和所有模板参数的值必须相同。 此函数必须由warp中的所有线程调用,否则结果未定义。
store_matrix_sync:
等待直到经线中的所有线程都收敛为止,然后将矩阵片段a存储到存储器中。 mptr必须是一个128位对齐的指针,指向内存中矩阵的第一个元素。 ldm描述连续行(对于行主布局)或列(对于列主要布局)之间的元素的跨度,并且必须是16个字节的倍数。 输出矩阵的布局必须指定为mem_row_major或mem_col_major。 对于变形中的所有线程,mptr,ldm,布局和所有模板参数的值必须相同。 此函数必须由warp中的所有线程调用,否则结果未定义。
fill_fragment:
填充具有常数值v的矩阵片段。由于未指定矩阵元素到每个片段的映射,因此该函数通常由warp中的所有线程调用,并具有v的公共值。
mma_sync:
等待直到warp中的所有线程收敛,然后执行warpsynchronous矩阵乘 - 累积运算D = A B + C。 就地操作C = A B + C也受支持。 每个矩阵片段的satf和模板参数的值对于warp中的所有线程都必须相同。 此外,模板参数m,n和k必须在分段A,B,C和D之间匹配。此函数必须由warp中的所有线程调用,否则结果未定义。
如果饱和(有限值)模式为真,则以下附加数值属性适用于目标累加器:
enum fragment
作为一个例子,下面的代码将累加器矩阵图块缩放一半。
wmma::fragment frag;
float alpha = 0.5f; // Same value for all threads in warp
...
for (int t = 0; t
例子:
以下代码在单个warp中实现了16x16x16矩阵乘法。
#include
using namespace nvcuda;
__global__ void wmma_ker(half *a, half *b, float *c) {
// Declare the fragments
wmma::fragment a_frag;
wmma::fragment b_frag;
wmma::fragment c_frag;
// Initialize the output to zero
wmma::fill_fragment(c_frag, 0.0f);
// Load the inputs
wmma::load_matrix_sync(a_frag, a, 16);
wmma::load_matrix_sync(b_frag, b, 16);
// Perform the matrix multiplication
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
// Store the output
wmma::store_matrix_sync(c, c_frag, 16, wmma::mem_row_major);
}