第 9 章 · GEMM 深入

⏱️ 90 分钟🎯 触摸 Tensor Core📂 code/ch09_gemm/🔥 关键瓶颈章

学习目标

9.1 GEMM 在 LLM 里的地位

统计:LLM 推理 80%+ 的 FLOPs 在 GEMM 上。Transformer 每层主要算子:

所以 GEMM 快 1×,整个推理快 80%。这就是为什么 NVIDIA 把"Tensor Core"做到第 5 代。

9.2 寄存器 Tile:从 Ch6 再进一步

Ch6 里每 thread 算 1 个 cell。瓶颈是每读 As/Bs 一次只用 2 个 FLOP。 如果每 thread 算 8×8 = 64 个 cell:每读 8+8 个 shared 值就做 64 个 FLOP,shared 访问减少 8×。

graph TD
    A["block 处理 C 的 128×128 tile
thread block = 16×16 = 256 threads"] A --> B["每 thread 持有 8×8 reg accumulator"] B --> C["K 维分 BK=16 切片
每片协作加载 128×16 + 16×128 到 shared"] C --> D["内层 K 循环 (BK=16):
读 8 个 As 行 + 8 个 Bs 列 → 64 个 fmadd"] D --> E["写回 C"] style B fill:#f3f1e8,stroke:#2f5d3a style D fill:#f3f1e8,stroke:#a86420
constexpr int BM = 128, BN = 128, BK = 16;
constexpr int TM = 8,   TN = 8;

__global__ void gemm_reg_tile(const float* A, const float* B, float* C, int M, int N, int K) {
    __shared__ float As[BM][BK], Bs[BK][BN];
    int ty = threadIdx.y, tx = threadIdx.x;
    int row0 = blockIdx.y * BM + ty * TM;
    int col0 = blockIdx.x * BN + tx * TN;
    float acc[TM][TN] = {0};

    for (int kt = 0; kt < K; kt += BK) {
        /* 协作 load A 的 128x16 + B 的 16x128 (每 thread load 8+8 个数) */
        __syncthreads();

        #pragma unroll
        for (int k = 0; k < BK; ++k) {
            float a_reg[TM], b_reg[TN];
            #pragma unroll for (int i=0;i<TM;++i) a_reg[i] = As[ty*TM+i][k];
            #pragma unroll for (int j=0;j<TN;++j) b_reg[j] = Bs[k][tx*TN+j];
            #pragma unroll for (int i=0;i<TM;++i)
            #pragma unroll for (int j=0;j<TN;++j) acc[i][j] += a_reg[i] * b_reg[j];
        }
        __syncthreads();
    }
    /* write back acc → C */
}

关键洞察:内层 8×8 = 64 个 fmadd 全部在寄存器之间发生。BK=16 步内只触发 16×(8+8) = 256 次 shared load,每 thread 做 16×64 = 1024 个 FLOP。算术强度 4 FLOP/shared-byte。

9.3 Tensor Core — WMMA API

fp32 CUDA core 路线最多到 ~10 TFLOPS(A100)。要继续上 30+ TFLOPS 必须用 Tensor Core——它是专门做 16×16×16 fp16 矩阵乘加的硬件。

WMMA fragment 三件套

#include <mma.h>
using namespace nvcuda;

wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::row_major> b_frag;
wmma::fragment<wmma::accumulator, 16, 16, 16, float>                c_frag;

wmma::fill_fragment(c_frag, 0.f);
for (int kt = 0; kt < K; kt += 16) {
    wmma::load_matrix_sync(a_frag, A + row*K + kt, K);
    wmma::load_matrix_sync(b_frag, B + kt*N + col, N);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);   // <-- TC magic
}
wmma::store_matrix_sync(C + row*N + col, c_frag, N, wmma::mem_row_major);

每个 warp 共享一个 fragment,硬件保证整个 warp 32 个 lane 协作完成 16×16×16 = 4096 FMA,一次 mma_sync 指令吞 8 个时钟。

限制: 仅 sm_70+;M/N/K 通常要 16 倍数;fp16 输入精度有限(fp32 累加缓解,但仍不如纯 fp32)。 实际应用通常配 per-channel scaling + loss scaling 防止下溢。

9.4 cuBLAS baseline

调 cuBLAS 是判断"自己写得好不好"的标尺:

#include <cublas_v2.h>
cublasHandle_t h; cublasCreate(&h);
float alpha = 1.f, beta = 0.f;

// 注意:cuBLAS 是 column-major!
// row-major C(M,N) = A(M,K) @ B(K,N)
//   ≡  C^T(N,M) = B^T(N,K) @ A^T(K,M)
//   ≡  cublasSgemm(N, N, N, M, K, B, N, A, K, C, N)
cublasSgemm(h, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K,
            &alpha, dB, N, dA, K, &beta, dC, N);

9.5 性能榜单(A100,1024³ MatMul)

实现精度GFLOPS% of peak
Ch6 tiled (32×32)fp32~18009%
Ch9 register tile (128×128, 8×8)fp32~7000-900040-50%
cuBLAS sgemmfp32~1600082%
本章 WMMA fp16fp16~50000-8000020-30% (of 312 TF)
cuBLAS gemmEx fp16 + TCfp16~250000+80%+
CUTLASS fp16 + TCfp16~280000+90%+

本章手写 WMMA 还有大量优化空间(async copy、swizzle、warpgroup mma)——但拿到 ~30% peak 已经是用 100 行代码做到的,足以理解原理。

9.6 下一步技术(不在本章实现,但要知道存在)

9.7 自检

Q1: register tile 越大越好吗?

不是。每 thread 寄存器有 255 上限,TM=TN=16 就用了 256 个累加寄存器,spill 到 local memory 反而慢。

Q2: 为什么 WMMA 用 fp32 accumulator?

fp16 累加范围太窄(mantissa 仅 10 bit),长 K 的求和会下溢/上溢。fp32 累加保持精度,最后才转 fp16 输出。

Q3: GEMM 在 M=1 (batch=1 decode) 时性能为啥差?

1×K @ K×N = GEMV,瓶颈是带宽不是算力,Tensor Core 完全用不上。所以 LLM decode 阶段是 memory-bound。第 13 章会讲量化怎么救它。

Q4: CUTLASS 是什么?

NVIDIA 开源的 C++ GEMM 模板库,把 tile size / warp 切分 / async copy 都抽象成 type traits。生产 GEMM 几乎都用 CUTLASS 或它的 DSL (CuTe)。第 12 章 FlashAttention 也是基于 CUTLASS 写的。

Q5: Triton 是什么?

OpenAI 的 Python-like GPU DSL,编译到 PTX。比 CUDA 易写、性能接近 CUTLASS,vLLM 大量算子用 Triton 写。

9.8 练习

  1. gemm_reg_tile 加 double buffer(开两块 shared,K-loop 中 prefetch 下一片)。
  2. __pipeline_memcpy_async(sm_80+)替换 shared 加载,看 GFLOPS 提升。
  3. 把 WMMA 改成支持 bf16(把 __half 换成 __nv_bfloat16)。
  4. 用 ncu 看 gemm_wmma 的 Tensor Active %,如果只有 30-50%,调整 block size 提到 70%+。

9.9 工业实战:CUTLASS、低精度、autotune、生产 GEMM 决策树

9.9.1 CUTLASS — 不是库,是 GEMM 工厂

cuBLAS 是闭源黑盒,你不能改 kernel 内部行为;CUTLASS 是 NVIDIA 开源的 C++ 模板库,把 GEMM 拆成可拼装的组件:

// CUTLASS 3.x (CuTe DSL)
#include <cutlass/gemm/device/gemm_universal.h>

using Gemm = cutlass::gemm::device::GemmUniversal<
    cutlass::half_t, cutlass::layout::RowMajor,            // A: fp16 row-major
    cutlass::half_t, cutlass::layout::ColumnMajor,         // B: fp16 col-major
    cutlass::half_t, cutlass::layout::RowMajor,            // C: fp16 row-major
    float,                                                  // accumulator: fp32
    cutlass::arch::OpClassTensorOp,                        // Tensor Core
    cutlass::arch::Sm80,                                   // A100
    cutlass::gemm::GemmShape<128, 256, 32>,                // ThreadBlock tile
    cutlass::gemm::GemmShape<64, 64, 32>,                  // Warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,                   // Instruction shape (mma.sync)
    cutlass::epilogue::thread::LinearCombinationRelu<...>, // epilogue: out = relu(alpha*acc + beta*C)
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3                                                       // pipeline stages (软流水深度)
>;

Gemm gemm;
gemm({M, N, K}, alpha, A, lda, B, ldb, beta, C, ldc, ...);

看似冗长,但 5 行模板就拿到了 cuBLAS 90% 的性能。大公司生产代码大量基于 CUTLASS(TensorRT-LLM、FlashAttention、xformers)。

CUTLASS profiler — 自动 autotune

不知道你的 size 该选哪个 tile?跑 profiler:

cd cutlass/build
./tools/profiler/cutlass_profiler \
    --kernels=cutlass_tensorop_h16816gemm \
    --m=4096 --n=4096 --k=4096 \
    --A=f16:row --B=f16:col --C=f16 \
    --accumulator-type=f32 \
    --providers=cutlass,cublas

# 输出: top 5 best-performing configurations

9.9.2 mma.sync vs WMMA:CUTLASS 用前者

本章用了 WMMA API(nvcuda::wmma),它是 fragment 级抽象。CUTLASS 用更底层的 mma.sync PTX 指令:

WMMAmma.sync (PTX)
抽象层fragment + load/mma/store32 lane 的 inline ASM
shape 限制固定 16×16×16, 16×8×8 等少数几种所有硬件支持的 shape
fragment layoutopaque (不知道哪 lane 持有哪个值)明确 (能跟 epilogue 配合)
性能~70-80% peak~90-95% peak
易写容易需要细心
// mma.sync 长这样 (Ampere fp16):
asm volatile(
    "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
    "{%0,%1,%2,%3}, "                  // D (acc, 4 fp32)
    "{%4,%5,%6,%7}, "                  // A (4 fp16x2)
    "{%8,%9}, "                        // B (2 fp16x2)
    "{%0,%1,%2,%3};\n"                 // C (= D in/out)
    : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3)
    : "r"(a0), "r"(a1), "r"(a2), "r"(a3),
      "r"(b0), "r"(b1));

普通工程师不需要直接写 mma.sync——CUTLASS 模板已封装。但读 FlashAttention v2 源码必看懂。

9.9.3 低精度类型一览:fp16 / bf16 / tf32 / fp8 / int8 / int4

类型位数动态范围典型用途支持架构
fp321+8+23±10⁻³⁸ 到 ±10³⁸baseline所有
tf321+8+10同 fp32 但精度低训练默认sm_80+
fp161+5+10±10⁻⁵ 到 ±10⁴ ⚠️ 易溢出推理首选sm_70+
bf161+8+7同 fp32 但精度低训练首选sm_80+ (A100, H100)
fp8 E4M31+4+3±0.0049 到 ±448推理 (前向)sm_89+ (Ada, H100)
fp8 E5M21+5+2±10⁻⁵ 到 ±57344训练 (反向)sm_89+
int8±127推理量化sm_75+
int4±7weight-only 量化sm_80+
fp4 (B100)±62025+ 推理sm_100

fp16 vs bf16:选哪个?

fp8 — H100 的 LLM 杀手锏

Llama 70B 用 fp8 推理(权重 + KV cache 都 fp8),相比 fp16:

调用:

cublasGemmEx(h, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K,
             &alpha,
             dB_e4m3, CUDA_R_8F_E4M3, N,
             dA_e4m3, CUDA_R_8F_E4M3, K,
             &beta,
             dC_half, CUDA_R_16F, N,
             CUBLAS_COMPUTE_32F,
             CUBLAS_GEMM_DEFAULT);

9.9.4 production GEMM 决策树

graph TD
    Start["GEMM with M, N, K, dtype"] --> M1{"M 等于 1 ?"}
    M1 -->|"是"| GEMV["走 GEMV kernel
weight-only int4 量化"] M1 -->|"否"| Size{"min of M N 小于 64 ?"} Size -->|"是"| Thin["thin GEMM
每 SM 算一行或一列"] Size -->|"否"| K1{"K 大于 4 倍 min M N ?"} K1 -->|"是"| StreamK["CUTLASS Stream-K"] K1 -->|"否"| Dtype{"dtype 是 fp16 bf16 fp8 ?"} Dtype -->|"是"| TC["cublasGemmEx + TC
或 CUTLASS"] Dtype -->|"否 fp32"| Sgemm["cublasSgemm"] style GEMV fill:#f3f1e8,stroke:#a86420 style TC fill:#f3f1e8,stroke:#2f5d3a

9.9.5 LLM 各层 GEMM 的实际 shape

Llama 7B (D=4096, n_head=32, d_head=128, d_ff=11008),prefill 阶段 (T=2048, batch=1)

GEMMM × N × KFLOPs类型
QKV projection2048 × 12288 × 40960.2 TF普通
Output projection2048 × 4096 × 40960.07 TF普通
FFN gate2048 × 11008 × 40960.18 TF普通
FFN up同上同上普通
FFN down2048 × 4096 × 110080.18 TF普通
logits2048 × 32000 × 40960.5 TF普通但 N=32K 大

同样模型 decode 阶段 (M=1):所有 GEMM 都变 GEMV,FLOPs 减为 1/2048 但 HBM 流量不变 → 完全 memory-bound。

这就是 LLM 推理的"两个世界":prefill 在 compute-bound,可以用大 batch + TC 充分利用算力;decode 在 memory-bound,只能靠量化 + KV cache 优化 + speculative decoding 救。

9.10 研究前沿(2025-2026):FP4 GEMM、DeepSeek 原生 FP8、CUTLASS 3.5

9.10.1 FP4 GEMM — Blackwell 的杀招

2025 NVIDIA Blackwell 把"用 4-bit 浮点训练 / 推理"从研究变成产品。两种 fp4 微缩放格式:

格式编码block scale来源用途
NVFP4 (E2M1)1+2+1fp8 (E4M3), block=16NVIDIA 私有推理首选,精度高
MXFP4 (E2M1)1+2+1E8M0, block=32OCP 开放标准跨厂商兼容
MXFP6 (E3M2/E2M3)1+3+2 或 1+2+3E8M0, block=32OCP激活用,比 fp4 稳

FP4 GEMM 调用(CUTLASS 3.5+)

using Gemm = cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp,
    cutlass::nv_float4_t, cutlass::layout::RowMajor, 16,        // A: NVFP4, block 16
    cutlass::nv_float4_t, cutlass::layout::ColumnMajor, 16,     // B: NVFP4
    float, /*accumulator*/                                       // accumulator fp32
    cutlass::gemm::collective::StageCount<4>,                    // 4-stage pipeline
    ...
>::CollectiveOp;

性能:B200 上 4096³ fp4 GEMM 跑到 ~6.5 PF(占 7.7 PF peak 的 85%)。相比 H100 fp16(~600 TF)提速 ~11×。LLM 推理 throughput 在 B200 上跟 H100 fp16 比典型 4-6×。

9.10.2 DeepSeek-V3 原生 FP8 训练 — 验证可行性

2024.12 DeepSeek-V3 发布,用 fp8 训练 671B MoE 模型成功,是产业第一个公开复现的案例。技术要点:

开源 fp8 训练框架推荐:

9.10.3 CUTLASS 3.5+ 与 CuTe DSL

2024-2025 CUTLASS 3.5 重大改动:

# CUTLASS Python 调用 (2025+)
import cutlass
plan = cutlass.op.Gemm(element_a=cutlass.DataType.bf16,
                       element_b=cutlass.DataType.bf16,
                       element_c=cutlass.DataType.bf16,
                       element_accumulator=cutlass.DataType.f32,
                       layout=cutlass.LayoutType.RowMajor)
plan.run(A, B, C, alpha=1.0, beta=0.0)        # 自动选最优 tile

9.10.4 W4A4 / W4A8KV4 — 极致量化推理(2024-2025)

方案权重激活KV cache典型加速
fp16 baselinefp16fp16fp16
W8A8 (SmoothQuant)int8int8fp161.5-2×
W4A16 (GPTQ/AWQ)int4fp16fp162-3× decode
W4A8 (QServe MIT, 2024)int4int8int42.5× over W4A16
W4A4 (Atom, QuaRot 2024)int4int4int4~3-4×, 需精心实现
NVFP4 (Blackwell 2025)fp4fp4fp8/fp4~4-6×

关键论文:

9.10.5 训练 GEMM 的新范式:Microscaling + Transformer Engine 2

Hopper 上的 fp8 训练已经稳定(NVIDIA TE 1)。Blackwell 把它升级到动态精度训练

9.10.6 工业 GEMM 库格局(2026)

定位覆盖
cuBLAS / cuBLASLtNVIDIA 标配, 闭源黑盒全精度, 全架构
CUTLASS 3.5+开源模板, fused epilogueHopper+Blackwell 完整
TensorRT-LLM kernels推理 fused (qkv+norm+attention+output)Hopper+Blackwell
ThunderKittens极简研究 DSLHopper+ (Blackwell 适配中)
FlashInfer专注 attention 各变体多平台
TritonPython kernel, vLLM/SGLang 用Ampere-Blackwell
Marlin / MacheteW4A16 极致 kernelAmpere+
FBGEMM_GPU (Meta)fp8 + 量化推理Hopper+

9.11 常见坑