DeepGEMM:统一高性能张量核心内核库,多功能升级提升性能

张开发
2026/4/19 13:35:15 15 分钟阅读

分享文章

DeepGEMM:统一高性能张量核心内核库,多功能升级提升性能
DeepGEMM统一的高性能张量核心内核库DeepGEMM 是一个统一的高性能张量核心内核库它将现代大语言模型的关键计算原语整合到一个统一的 CUDA 代码库中这些原语包括通用矩阵乘法GEMMs支持 FP8、FP4、BF16、融合专家混合MoE与重叠通信Mega MoE、闪电索引器的多查询注意力MQA评分、超连接HyperConnectionHC等。所有内核都通过轻量级即时编译Just - In - TimeJIT模块在运行时编译安装过程中无需进行 CUDA 编译。DeepGEMM 借鉴了 CUTLASS 和 CuTe 的一些概念但避免了对它们的模板或代数的过度依赖。该库设计简洁核心内核函数数量有限是学习 NVIDIA GPU 内核优化技术的优质资源。尽管设计轻量但 DeepGEMM 在各种矩阵形状下的性能与经过专家调优的库相当甚至更优。新闻动态2026 年 4 月 16 日新增 Mega MoE、FP8xFP4 GEMM、FP4 索引器、程序依赖启动PDL、更快的 JIT 编译等功能。性能对比后续公布详情见 #304。2025 年 9 月 28 日DeepGEMM 现在支持 DeepSeek v3.2 闪电索引器的评分内核加权 ReLU MQA 对数详情见 #200。2025 年 7 月 20 日DeepGEMM 同时支持 SM90 和 SM100 架构对低 CPU 开销的 JIT CPP 模块进行了全面重构禁用了 NVRTC 和编译后 SASS 优化后续将支持 NVRTC。由于 NVCC 12.9 会自动进行 FFMA 交错不再支持所有编译后优化详情见 #112。2025 年 5 月 14 日DeepGEMM 现在提供用于密集和 MoE 反向传播的权重梯度内核详情见 #95。2025 年 5 月 7 日DeepGEMM 现在支持 NVRTC编译速度最高可提升 10 倍详情见 #94。使用 DG_JIT_USE_NVRTC 1 启用某些情况下可能会有性能损失。2025 年 4 月 18 日DeepGEMM 在 H800 上实现了高达 1550 TFLOPS 的性能详情见 #74、#78、#81、#86 和 340d988。快速开始要求- NVIDIA SM90 或 SM100 架构的 GPU- Python 3.8 或更高版本- 支持 C20 的编译器- CUDA 工具包SM90 建议使用 CUDA 12.3 或更高版本为获得最佳性能强烈推荐 12.9 或更高版本SM100 建议使用 CUDA 12.9 或更高版本- PyTorch 2.1 或更高版本- CUTLASS 4.0 或更高版本可通过 Git 子模块克隆- {fmt} 库可通过 Git 子模块克隆开发bash# 必须克隆子模块git clone --recursive gitgithub.com:deepseek - ai/DeepGEMM.gitcd DeepGEMM# 链接一些必要的头文件并构建 CPP JIT 模块cat develop.sh./develop.sh安装bashcat install.sh./install.sh然后在你的 Python 项目中导入 deep_gemm 即可开始使用接口说明通用信息该库为 NVIDIA GPU 提供优化的 GEMM 内核命名规则为 D C A B。输入形状布局为 NTA 不转置B 转置。SM90 实现仅支持 NT 内存布局行主序、列主序而 SM100 实现支持所有内存布局NT、TN、NN、TT。例如fp8_gemm_nt 会执行 D C A B.T。对于两种架构左侧缩放因子都需要具有 TMA 对齐和转置的布局。SM90 和 SM100 的缩放因子数据格式不同SM90 需要 FP32 格式的缩放因子SM100 需要打包的 UE8M0 格式即将 4 个 UE8M0 打包成一个 torch.int。请注意输入转置或 FP8 转换等操作需要用户单独处理请自行实现或将其融合到之前的内核中。虽然库中提供了一些简单的 PyTorch 实用函数但可能会导致性能下降库的主要重点是优化 GEMM 内核本身。普通密集 GEMMs非分组要执行基本的非分组 FP8 GEMM调用 fp8_gemm_{nt, nn, tn, tt} 函数具体详情请参考函数文档。分组 GEMMs连续布局与 CUTLASS 中的传统分组 GEMMs 不同DeepGEMM 仅对 M 轴进行分组N 和 K 必须保持固定。这种设计适用于 MoE 模型中专家共享相同形状的场景。在训练前向传播或推理预填充阶段每个专家可能处理不同数量的令牌我们将这些令牌连接成一个张量即“连续”布局。请注意每个专家段必须与 GEMM M 块大小对齐get_mk_alignment_for_contiguous_layout()。更多信息请参考 m_grouped_fp8_gemm_{nt, nn}_contiguous 函数文档。我们还提供了用于 MoE 权重反向传播的 K 轴分组 APIM 和 N 必须保持固定详情请参考 k_grouped_fp8_gemm_tn_contiguous。分组 GEMMs掩码布局在推理解码阶段当启用 CUDA 图且 CPU 不知道每个专家接收的令牌数量时我们支持掩码分组 GEMMs。通过提供掩码张量内核仅计算有效部分。使用 m_grouped_fp8_gemm_nt_masked 并参考相关文档。一个示例用法是使用 DeepEP 低延迟内核的输出作为输入。V3.2 索引器的 MQA 内核该内核家族有两个版本非分页用于预填充和分页用于解码。以非分页版本 fp8_mqa_logits 为例它有 6 个输入- q形状为 [seq_len, num_heads, head_dim] 的 E4M3 张量- kv形状为 [seq_len_kv, head_dim] 的 E4M3 张量带有形状为 [seq_len_kv] 的浮点缩放因子- weights形状为 [seq_len, num_heads] 的浮点张量- cu_seq_len_k_start 和 cu_seq_len_k_end形状为 [seq_len] 的整数张量- clean_logits是否将未填充的对数清零为 - inf输出张量形状为 [seq_len, seq_len_kv]表示令牌到令牌的对数。对于 q 中的每个令牌 i它将遍历 [cu_seq_len_k_start[i], cu_seq_len_k_end[i]) 中的所有令牌 j并计算对数 out[i, j] 如下pythonkv_j kv[0][j, :] * kv[1][j].unsqueeze(1) # [head_dim]out_ij q[i, :, :] kv_j # [num_heads]out_ij out_ij.relu() * weights[i, :] # [num_heads]out_ij out_ij.sum() # 标量更多详情和分页版本 fp8_paged_mqa_logits 请参考 tests/test_attention.py。Mega MoEMega MoE 将专家并行EP调度、线性层 1FP8xFP4、SwiGLU、线性层 2FP8xFP4和 EP 合并融合到一个巨型内核中实现了 NVLink 通信和张量核心计算的重叠。它需要使用对称内存进行多进程启动。使用方法python# 分配对称内存缓冲区# 注意需要 PyTorch 2.9buffer deep_gemm.get_symm_buffer_for_mega_moe(group, num_experts, num_max_tokens_per_rank, num_topk, hidden, intermediate_hidden)# 将权重FP4 带 UE8M0 缩放因子转换为所需布局transformed_l1, transformed_l2 deep_gemm.transform_weights_for_mega_moe(l1_weights, l2_weights)# 在每次调用前将输入复制到缓冲区# 你可以将这些操作融合到之前的内核中buffer.x[:num_tokens].copy_(x_fp8)buffer.x_sf[:num_tokens].copy_(x_sf)buffer.topk_idx[:num_tokens].copy_(topk_idx)buffer.topk_weights[:num_tokens].copy_(topk_weights)# 运行融合的 Mega MoE 内核y torch.empty((num_tokens, hidden), dtype torch.bfloat16, device cuda)deep_gemm.fp8_fp4_mega_moe(y, transformed_l1, transformed_l2, buffer)完整的多进程设置和基准测试示例请参考 tests/test_mega_moe.py。实用函数除了上述内核库还提供了一些实用函数- deep_gemm.set_num_sms / get_num_sms设置/获取要使用的最大 SM 数量- deep_gemm.set_tc_util / get_tc_util设置/获取近似的张量核心利用率- deep_gemm.set_pdl / get_pdl启用/禁用程序依赖启动PDL- deep_gemm.set_mk_alignment_for_contiguous_layout / get_mk_alignment_for_contiguous_layout设置/获取连续布局的组级 M/K 对齐- deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout获取理论最小 M/K 对齐- deep_gemm.set_ignore_compile_dims配置 JIT 编译时要忽略的维度- deep_gemm.set_block_size_multiple_of将块大小限制为给定值的倍数- deep_gemm.transform_sf_into_required_layout将缩放因子转换为所需布局- deep_gemm.get_tma_aligned_size获取所需的 TMA 对齐大小- deep_gemm.get_mn_major_tma_aligned_tensor获取 MN 主序 TMA 对齐的张量- deep_gemm.get_mn_major_tma_aligned_packed_ue8m0_tensor获取 MN 主序 TMA 对齐打包成 UE8M0 的张量- deep_gemm.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensorK 分组 GEMM 打包内核环境变量通用- DG_JIT_DEBUG0 或 1打印 JIT 调试信息默认为 0- DG_PRINT_CONFIGS0 或 1打印每个形状的选定配置默认为 0JIT 缓存- DG_JIT_CACHE_DIR字符串编译内核的缓存目录默认为 $HOME/.deep_gemm编译器选择- DG_JIT_USE_NVRTC0 或 1使用 NVRTC 代替 NVCC编译速度更快某些情况下可能性能较低默认为 0- DG_JIT_NVCC_COMPILER字符串NVCC 编译器路径默认为 torch.utils.cpp_extension.CUDA_HOME- DG_JIT_CPP_STANDARD整数C 标准版本默认为 20编译器输出- DG_JIT_PRINT_COMPILER_COMMAND0 或 1打印编译命令默认为 0- DG_JIT_PTXAS_VERBOSE0 或 1显示详细的 PTXAS 输出默认为 0- DG_JIT_PTXAS_CHECK0 或 1断言编译内核中无本地内存使用默认为 0- DG_JIT_PRINT_LOAD_TIME0 或 1打印内核加载时间默认为 0调试和性能分析- DG_JIT_WITH_LINEINFO0 或 1为性能分析工具嵌入源代码行信息默认为 0- DG_JIT_DUMP_ASM0 或 1转储 PTX 和 SASS默认为 0- DG_JIT_DUMP_PTX0 或 1转储 PTX 输出默认为 0- DG_JIT_DUMP_SASS0 或 1转储 SASS 输出默认为 0- DG_COMM_KERNEL_DEBUG0 或 1在每次 Mega MoE 调用前将对称缓冲区清零以进行调试默认为 0- DG_USE_NVIDIA_TOOLS0 或 1在外部 NVIDIA 工具下运行时跳过内部性能分析默认为 0构建选项- DG_SKIP_CUDA_BUILD0 或 1安装过程中跳过 CUDA 扩展构建默认为 0- DG_FORCE_BUILD0 或 1强制本地构建而不是下载预构建的 wheel 文件默认为 0- DG_JIT_USE_RUNTIME_API0 或 1使用 CUDA 运行时 API 加载内核需要 CUDA 运行时 12.8默认为 0更多示例和详细信息请参考测试代码或查看相应的 Python 文档。致谢DeepGEMM 受到 CUTLASS 项目的启发感谢并尊重开发者们许可证此代码库根据 MIT 许可证发布。引用bibtexmisc{deepgemm2025,title{DeepGEMM: clean and efficient BLAS kernel library on GPU},author{Chenggang Zhao and Zhean Xu and Liang Zhao and Jiashi Li and Chenhao Xu and Anyi Xu and Shengyu Liu and Kexing Zhou and Kuai Yu},year{2025},publisher {GitHub},howpublished {\url{https://github.com/deepseek - ai/DeepGEMM}},}

更多文章