CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels and scales within CUDA. It employs hierarchical decomposition and data-movement strategies similar to those used to implement cuBLAS and cuDNN.
The latest CUTLASS release is 3.3, which differs substantially from 1.3.3. Revisiting 1.3.3 is still worthwhile, though, because it is easier to understand: Tensor Core math comes down to a single instruction, HMMA.884.F16.F16. Demystifying Tensor Cores to Optimize Half-Precision Matrix Multiply notes that on the T4 GPU, GEMM, formerly compute-bound, became IO-bound once Tensor Cores were introduced. Although the V100 has three times the bandwidth of the T4, the bandwidth shortfall exists there as well. CUTLASS therefore optimizes the data path in two ways: it uses the 128-bit-wide instructions LDG.128, STS.128, LDS.128, and STG.128, and it overlaps the three instruction streams LDG.128, LDS.128, and HMMA.884.F16.F16 to hide data movement. Below, a matrix-multiplication test case is used to walk through the Volta884_h884gemm implementation.
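To make that data path concrete, here is a minimal, self-contained sketch (not CUTLASS code) of staging 16-byte float4 values through shared memory; the compiler lowers these accesses to the same wide instructions, LDG.128, STS.128, LDS.128, and STG.128.

#include <cuda_runtime.h>

// Copy global -> shared -> global with 16-byte vectors. nvcc emits LDG.128
// for the global read, STS.128 for the shared store, LDS.128 for the shared
// read, and STG.128 for the global write.
__global__ void staged_copy(const float4* in, float4* out, int n) {
  extern __shared__ float4 tile[];
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    tile[threadIdx.x] = in[i];   // LDG.128 + STS.128
  }
  __syncthreads();
  if (i < n) {
    out[i] = tile[threadIdx.x];  // LDS.128 + STG.128
  }
}

Launching with staged_copy<<<grid, block, block * sizeof(float4)>>>(in, out, n) sizes the dynamic shared memory at one float4 per thread.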
OutputTile is the threadblock tile; in this test case it is set to 32x64x128. WarpGemmShape is 32x64x64, a fixed value.
run_gemm initializes Volta884GemmTraits::Params and GemmTestbed, calls Gemm::launch to run the kernel, and then verifies the results.
TEST(Volta884_h884gemm_64x64x32_nt, 520x264x136) {
  typedef cutlass::gemm::Volta884GemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kRowMajor,
      cutlass::Shape<32, 64, 128>,
      cutlass::Shape<32, 64, 64>,
      half,
      half,
      half,
      2
  > GemmTraits;
  run_gemm<GemmTraits>(520, 264, 136);
}
The hierarchy of the Volta884 implementation in CUTLASS is shown in the figure below.
The kernel function declares dynamic shared memory, passes it to GemmMainloop, and then calls GemmMainloop::multiply_add to perform the computation.
/// GEMM kernel without launch bounds specified
template <typename Gemm_>
__global__ /* __launch_bounds__(Gemm_::kThreads) */
void gemm_kernel_nolb(typename Gemm_::Params params) {
  // Dynamic shared memory base pointer
  extern __shared__ int GemmSharedStorageBase[];

  // Declare pointer to dynamic shared memory.
  typename Gemm_::SharedStorage *shared_storage =
      reinterpret_cast<typename Gemm_::SharedStorage *>(GemmSharedStorageBase);

  // Construct the GEMM object.
  Gemm_ gemm(params, *shared_storage);

  // Run GEMM.
  gemm.multiply_add();
}
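As a usage note, a hedged sketch of the host-side launch for a kernel like this; Gemm, grid, block, and params are placeholder names here (the real logic lives in Gemm::launch), and the cudaFuncSetAttribute opt-in is only needed when SharedStorage exceeds the 48 KB default.

// Hypothetical host-side launch (placeholder names, not CUTLASS's Gemm::launch).
size_t smem_size = sizeof(typename Gemm::SharedStorage);
if (smem_size >= (48 << 10)) {
  // Opt in to dynamic shared memory allocations beyond the 48 KB default.
  cudaFuncSetAttribute(gemm_kernel_nolb<Gemm>,
                       cudaFuncAttributeMaxDynamicSharedMemorySize,
                       int(smem_size));
}
gemm_kernel_nolb<Gemm><<<grid, block, smem_size>>>(params);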
Shared memory and registers each need two buffers, and scheduling on the SM keeps three pipelines running in parallel. Loading from global memory into shared memory requires synchronization, while moving data from shared memory into registers does not. Because architectures before Ampere cannot copy directly from global memory to shared memory, the whole data-movement path is fairly involved. As the figure below shows, the program calls Copy::transform in several places to produce a transformed_fragment. The intent is presumably type conversion, but since Volta only supports half, the transform has no practical effect.
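What a transform stage amounts to can be sketched as an elementwise conversion between fragment element types; this is illustrative only (kElements and Element are assumed accessors, not necessarily the real interface), and with source and destination both half it degenerates to a plain copy.

// Illustrative sketch, not the real cutlass::gemm::Copy: convert each
// element of the source fragment into the destination fragment's type.
template <typename DstFragment, typename SrcFragment>
__device__ void transform_fragment(DstFragment& dst, SrcFragment const& src) {
  #pragma unroll
  for (int i = 0; i < SrcFragment::kElements; ++i) {
    dst[i] = static_cast<typename DstFragment::Element>(src[i]);
  }
}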
template <typename Traits_>
struct GemmMainloop {
  //
  // Type definitions
  //

  /// The traits.
  typedef Traits_ Traits;
  /// The GEMM mainloop
  typedef typename Traits::KernelClass KernelClass;
  /// The shared storage.
  typedef typename Traits::SharedStorage SharedStorage;
  /// The scalar for A.
  typedef typename Traits::ScalarA ScalarA;
  /// The scalar for B.
  typedef typename Traits::ScalarB ScalarB;
  /// The scalar in the epilogue.
  typedef typename Traits::Epilogue::Scalar ScalarEpilogue;
  /// The scalar for C.
  typedef typename Traits::Epilogue::ScalarC ScalarC;
  /// The scalar for D.
  typedef typename Traits::Epilogue::ScalarD ScalarD;
  /// The index.
  typedef typename Traits::Index Index;
  /// Define the mainloop iteration size
  typedef typename Traits::MultiplyAdd MultiplyAdd;
  /// The number of threads.
  static int const kThreads = Traits::GemmConfig::kThreads;
AccumulatorsPerWarp is GemmConfig::AccumulatorsPerWarp, i.e. Volta884MultiplyAdd::WarpGemmShape, which is 32x64x64. Volta884MultiplyAdd::InstructionShape is 4x32x32. kWarpGemmSteps is therefore 32 / 4 = 8.
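As a standalone check of that arithmetic (values copied from the shapes above):

// kWarpGemmSteps = AccumulatorsPerWarp::kD / InstructionShape::kD.
constexpr int kAccumulatorsPerWarpD = 32;  // WarpGemmShape <32, 64, 64>, kD = 32
constexpr int kInstructionShapeD = 4;      // InstructionShape <4, 32, 32>, kD = 4
static_assert(kAccumulatorsPerWarpD / kInstructionShapeD == 8,
              "eight warp-level MMA steps per threadblock k-tile");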
  // Number of warp-level multiply-accumulate steps executed by each warp.
  static Index const kWarpGemmSteps =
      Traits::GemmConfig::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD;

  /*
  // Make sure we have at least 2 unrolling steps or our pipeling is not going to work.
  static_assert(kWarpGemmSteps >= 2, "The pipelining assumes at least two steps");
  */
  /// Use the params object defined in traits
  typedef typename Traits::Params Params;

  //
  // Data members
  //

  /// The params.
  Params const& params;
  /// SharedStorage object
  SharedStorage& shared_storage;

  //
  // Methods
  //

  /// Ctor.
  CUTLASS_DEVICE GemmMainloop(Params const& params_, SharedStorage& shared_storage_)
      : params(params_), shared_storage(shared_storage_) {}
Volta884GemmTraits::GlobalLoadStream is the GlobalLoadStreamPair type.
GlobalLoadStreamPair::residue calls MMAGlobalLoadStream::residue twice to compute the predicate masks needed for the final load of the threadblock tile.
GlobalLoadStreamPair::copy calls MMAGlobalLoadStream::copy twice to copy matrix elements from global memory into registers; the latter calls TileLoadIterator::load_post_increment.
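The pair type is essentially a forwarder. A minimal sketch of the pattern (illustrative names, not the real class):

// Each call fans out to the A-stream and the B-stream.
template <typename StreamA, typename StreamB>
struct LoadStreamPair {
  StreamA stream_a;
  StreamB stream_b;
  __device__ void copy()         { stream_a.copy();     stream_b.copy();     }
  __device__ void commit()       { stream_a.commit();   stream_b.commit();   }
  __device__ void residue(int k) { stream_a.residue(k); stream_b.residue(k); }
};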
  /// Fetches global stream pair
  template <bool Residue>
  CUTLASS_DEVICE void fetch_global(typename Traits::GlobalLoadStream& global_to_shared_stream,
                                   Index outer_k) {
    // If residue portion and not calculating residue in prolog, update residue predicates now.
    if (Residue) {
      global_to_shared_stream.residue(outer_k);
    }
    global_to_shared_stream.copy();
  }
If kWarpGemmSteps is at most 4, kGlobalStreamFirst is true, and the data for the next iteration is loaded from global memory first.
  /// Computes a warp-level GEMM on data held in shared memory
  template <bool Residue, bool LastIteration>
  CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_to_shared_stream,
                                   typename Traits::SharedStream& shared_load_stream,
                                   typename MultiplyAdd::Accumulators& accumulators,
                                   Index outer_k) {
    // Whether to load global stream before loading shared stream
    const bool kGlobalStreamFirst = (kWarpGemmSteps <= 4);

    // Load data for the next iteration of the main loop (unless it's the last iteration).
    if (kGlobalStreamFirst && !LastIteration) {
      fetch_global<Residue>(global_to_shared_stream, outer_k);
    }
Next, the inputs for the following iteration are loaded from shared memory, which is double-buffered. MMASharedLoadStream::copy calls Volta884WarpMultiplicandLoadIterator::load to bring the data into registers.
One might ask: if the previous step did not call GemmMainloop::fetch_global, is copying from shared memory safe? It is: the copy reads the shared-memory stage committed in the previous iteration, while fetch_global only fills register fragments with the next tile from global memory, so the two only need to be ordered before the commit at step kWarpGemmSteps - 2.
    CUTLASS_PRAGMA_UNROLL
    for (int step = 0; step < kWarpGemmSteps; ++step) {
      // Trigger the copy from shared memory for the next A/B values.
      shared_load_stream.copy((step + 1) % kWarpGemmSteps);
If kGlobalStreamFirst is false, GemmMainloop::fetch_global is called on the first step of the loop to load the inputs.
      // Load data for the next iteration of the main loop (unless it's the last iteration).
      if (!kGlobalStreamFirst && (step == 0) && !LastIteration) {
        fetch_global<Residue>(global_to_shared_stream, outer_k);
      }
On the second-to-last step, the data must be guaranteed to have arrived in shared memory.
Volta884GemmTraits::shared_load_fence decides, based on the externally supplied StageCount, whether to synchronize the threads.
GlobalLoadStreamPair::commit calls GlobalLoadStream::commit for each of the two matrices to copy them into shared memory.
Volta884GemmTraits::shared_store_fence synchronizes the threads.
MMASharedLoadStream::inc_stage increments stage_index.
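A hedged sketch of what such a stage-dependent fence could look like (kStageCount is an assumed parameter; the actual logic lives in Volta884GemmTraits):

// Illustrative only: with a single shared memory stage the buffer is about
// to be overwritten, so all readers must drain first.
template <int kStageCount>
__device__ void shared_load_fence() {
  if (kStageCount < 2) {
    __syncthreads();
  }
}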
      if (step == kWarpGemmSteps - 2) {
        // Make sure the data from shared memory has been entirely consumed.
        Traits::shared_load_fence(true);
        global_to_shared_stream.commit();
        // Make sure the data is in shared memory.
        Traits::shared_store_fence(true);
        // Move to the next stage for the load (if it makes sense).
        shared_load_stream.inc_stage();
      }
MMASharedLoadStream::commit uses Copy to perform the copy; Volta884WarpMultiplicandLoadIterator::Fragment is simply Fragment.
Volta884MultiplyAdd::multiply_add then performs the warp-tile computation.
      // Make sure the values are available for the current iteration to do the multiply-add.
      shared_load_stream.commit(step);

      // Do the math on the fragments of the current iteration.
      MultiplyAdd multiply_add;
      multiply_add.multiply_add(shared_load_stream.fragment_a(step),
                                shared_load_stream.fragment_b(step),
                                accumulators,
                                accumulators);
    }
  }
make_Coord_from_shape creates a Coord object from a shape.
IdentityBlockSwizzle::get_threadblock_offset obtains the current threadblock's offset within the two-dimensional output.
Volta884GemmTraits::ClearAccumulators is just ClearAccumulators.
IdentityBlockSwizzle::get_threadblock_bounds returns the threadblock's three-dimensional bounds.
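For reference, a sketch of the identity mapping (abbreviated, and only approximating the 1.3.3 source): the threadblock offset is blockIdx scaled by the output tile, with no reordering.

// block.x tiles the W dimension, block.y the H dimension; the K (D)
// coordinate of a threadblock is always 0.
CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const& output_tile) {
  return make_Coord(0, blockIdx.y * output_tile[1], blockIdx.x * output_tile[2]);
}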
  /// Do the GEMM.
  CUTLASS_DEVICE void multiply_add() {
    // Swizzle the IDs of the block (to enable better cache behavior).
    typename Traits::BlockSwizzle block_swizzle;
    Coord<3> threadblock_offset =
        block_swizzle.get_threadblock_offset(make_Coord_from_shape<typename Traits::OutputTile>());

    // We may want to use shared memory to clear the registers.
    typedef typename Traits::ClearAccumulators ClearAccumulators;

    // Get the bounds for each thread, it maybe different than problem_size
    Coord<3> bounds = block_swizzle.get_threadblock_bounds(params.problem_size,
                                                           params.partitionK_range);
params.global_to_shared_stream is GlobalLoadStreamPair::Params.
shared_storage.main_loop.global_to_shared_stream is GlobalLoadStreamPair::SharedStorage.
shared_storage.main_loop.threadblock_tile is GlobalLoadStreamPair::ThreadblockTileStorage, i.e. ZipTileAllocation; ZipTileAllocation::reference returns a ZipTensorRef pointing at the data.
global_to_shared_stream is Volta884GemmTraits::GlobalLoadStream, i.e. GlobalLoadStreamPair.
GlobalLoadStreamPair::add_batch_offset calls GlobalLoadStream::add_batch_offset to set the iterators' batch offsets.
    // The streams to read A/B from global memory to shared memory.
    typename Traits::GlobalLoadStream global_to_shared_stream(
        params.global_to_shared_stream,
        shared_storage.main_loop.global_to_shared_stream,
        shared_storage.main_loop.threadblock_tile.reference(),
        bounds,
        threadblock_offset);

    // update A and B pointer offset based on batch_id and batch_stride_offset
    global_to_shared_stream.add_batch_offset(block_swizzle.get_batch_id());

    // Create the accumulator clear.
    ClearAccumulators clear;
GlobalLoadStreamPair::move_to_residue calls MMAGlobalLoadStream::move_to_residue to move the pointer if the residue is handled in the prolog; otherwise it calls GlobalLoadStreamPair::residue directly.
GlobalLoadStreamPair::copy calls MMAGlobalLoadStream::copy, which uses TileLoadIterator::load_post_increment to load fragments of A and B into Fragment registers.
GlobalLoadStreamPair::commit calls MMAGlobalLoadStream::commit, which applies Copy::transform and then calls Volta884ThreadblockMultiplicandStoreIterator::store_post_increment to write into shared memory.
Volta884GemmTraits::shared_store_fence synchronizes the threads within the threadblock.
GlobalLoadStreamPair::rollback calls MMAGlobalLoadStream::rollback, which calls TileLoadIterator::initialize_predicates to initialize the predicate vector and then moves the offset back.
    // Deal with residue in prolog.
    // global_to_shared_stream.move_to_residue(params.problem_size[0], Traits::OutputTile::kD);
    global_to_shared_stream.move_to_residue(bounds[0], Traits::OutputTile::kD);

    // Fetch the fragments for A and B from global memory.
    global_to_shared_stream.copy();

    // Copy the elements to shared memory (after transformation if needed).
    global_to_shared_stream.commit();

    // Make sure the data is in shared memory.
    Traits::shared_store_fence(false);

    // Rollback to the beginning of the first tile (if residue exists).
    // global_to_shared_stream.rollback(params.problem_size[0] % Traits::OutputTile::kD);
    global_to_shared_stream.rollback(bounds[0] % Traits::OutputTile::kD);
shared_load_stream is of type Volta884GemmTraits::SharedStream, i.e. SharedStreamPair.
SharedStreamPair::copy calls MMASharedLoadStream::copy, which calls Volta884WarpMultiplicandLoadIterator::load to load from shared memory.
accumulators is of type Volta884MultiplyAdd::Accumulators, i.e. Fragment.
ClearAccumulators::clear calls Fragment::clear to zero the storage.
What is outer_k? As the code below shows, it is the main loop's induction variable over K: it starts at bounds[0] - Traits::OutputTile::kD and decreases by OutputTile::kD per iteration, i.e. it is the starting K coordinate of the next tile to fetch.
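Zeroing a fragment is a plain register fill; a minimal sketch (Fragment::kElements and Element are assumed accessors, not necessarily the real interface):

// Illustrative: set every element of the accumulator fragment to zero.
template <typename Fragment>
__device__ void clear_fragment(Fragment& frag) {
  #pragma unroll
  for (int i = 0; i < Fragment::kElements; ++i) {
    frag[i] = typename Fragment::Element(0);
  }
}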
    // The stream of data from shared memory to fragments.
    typename Traits::SharedStream shared_load_stream(
        params.shared_stream,
        shared_storage.main_loop.threadblock_tile.reference());

    // Trigger the copy from shared memory for the 1st stream.
    shared_load_stream.copy(0);

    // Allocate the accumulators.
    typename MultiplyAdd::Accumulators accumulators;

    // Clear the accumulators.
    clear.clear(accumulators);

    // Initial index
    // Index outer_k = params.problem_size[0] - Traits::OutputTile::kD;
    // problem_size[0] might be bigger than bounds[0]
    Index outer_k = bounds[0] - Traits::OutputTile::kD;
If the residue is computed in the prolog, only the last iteration needs to handle it. Each GemmMainloop::consume_tile call processes one tile of k = Traits::OutputTile::kD.
    // Check if we are computing residue in prolog or not.
    if (Traits::GemmConfig::kResidueInProlog) {
      // Execute all mainloop iterations but the last one.
      CUTLASS_GEMM_LOOP
      for (; outer_k > 0; outer_k -= Traits::OutputTile::kD) {
        CUTLASS_GEMM_LOOP_HEADER
        consume_tile<false, false>(
            global_to_shared_stream, shared_load_stream, accumulators, outer_k);
      }
      consume_tile<false, true>(
          global_to_shared_stream, shared_load_stream, accumulators, outer_k);
Otherwise, the residue is considered on each iteration (with kResidueSeparate, the steady-state iterations are peeled off first and run without residue checks, as the comment below explains).
    } else {
      // When kResidueSeparate = true, execute all mainloop iterations but the last two without any
      // consideration for K-residue or predicate updates. This improves the steady state of some
      // kernels.
      if (Traits::GemmConfig::kResidueSeparate) {
        CUTLASS_GEMM_LOOP
        for (; outer_k > Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
          CUTLASS_GEMM_LOOP_HEADER
          consume_tile<false, false>(
              global_to_shared_stream, shared_load_stream, accumulators, outer_k);
        }
      }

      // Execute remaining tiles with K-residue predicate updates enabled.
      CUTLASS_GEMM_LOOP
      for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
        CUTLASS_GEMM_LOOP_HEADER
        consume_tile<true, false>(
            global_to_shared_stream, shared_load_stream, accumulators, outer_k);
      }
    }
Finally, an MMAEpilogue object is created and MMAEpilogue::epilogue is called.
    typedef typename Traits::Epilogue Epilogue;
    Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.problem_size.knm());
    epilogue.epilogue(accumulators, threadblock_offset, block_swizzle.get_batch_id());
  }
};