手写 Triton Softmax Kernel:程序实例、块大小、mask 与指针算术

简介: 本文以Softmax为切入点,用通俗图解与手写Triton内核,揭开GPU编程黑箱:从块并行模型、片上计算融合,到内存带宽瓶颈与性能悬崖现象,带你真正理解AI算子在GPU上如何高效运行。

GPU 编程看起来总像黑魔法,满眼是 warpsshared memorytensor cores,还有 kernel 里古怪的索引运算。但是这篇文章从一个具体例子入手帮你理解 Triton:从头实现一个 softmax kernel。

以官方 Triton 教程为基础,深入代码背后的原理并配上手绘图解。如果你觉得 GPU 编程教程总是太晦涩,这篇文章正好可以用来入门。

我们的目标不止是写一个 kernel而是理解现代 AI 工作负载在 GPU 上到底怎么跑。

最后会把 kernel 放到 RTX 5090 上跟 PyTorch 的原生 softmax 跑个 benchmark。结果不是简单的"Triton 赢了"——这里有个性能悬崖,教会你 GPU 编程里很重要的一件事。

Softmax:简单的数学,隐藏的内存问题

逐行 softmax 从数学上很简单:每行是一个独立 logit 向量,softmax 把它转成概率。

比如一个

2×3

矩阵,不是对六个值算一个大 softmax,而是算两个独立的 softmax——行 0 一个、行 1 一个。

难点不在数学而是在 GPU 上的执行方式:数据搬几次、中间值存在哪、GPU 是花时间算还是在等内存。

简单的 PyTorch 实现把 softmax 拆成几个独立的张量操作:max、减法、指数、求和、除法。每一步都可能从全局内存读数据再把中间值写回去。

而融合的 Triton kernel 改变了这个模式:一次加载一行,所有 softmax 步骤在数据留在片上时完成,最后一次性写回结果。

这里的片外指 GPU 全局内存/DRAM:大但慢。片上指 GPU 计算单元内部的内存(寄存器或共享内存/SRAM):快得多但小得多。

从概念上说一个 Triton 程序处理一行,但实际运行时是大量 Triton 程序并行跑。

一个简单的 Triton模型

在看 softmax kernel 之前,先搭个简单的模型。

一个

3072

长度的向量

X

,要给每个元素减 1。

CPU 思路是顺序循环:

 foriinrange(3072):  
     X[i] =X[i] -1

在 GPU 上就不是这样了,GPU 要把向量切成块,并行处理。

Triton 里,一个 kernel 描述一个程序实例的行为。启动 kernel 时,启动一个网格,里面很多程序实例并行跑。

 BLOCK_SIZE=1024

每个程序实例处理

1024

个元素。

 3072 / 1024 = 3 → 需要 3 个程序实例。  
 program 0 → elements 0-1023  
 program 1 → elements 1024-2047  
 program 2 → elements 2048-3071

每个程序实例拿到自己的

program_id

,用它定位数据切片,执行相同操作。

Softmax kernel 里也一样,只是每个程序实例处理矩阵的一行,不是向量的一块。

逐行拆解 Triton Softmax Kernel

一个 Triton 程序实例一次处理一行。启动的程序数少于行数时,每个程序以固定步长在矩阵中跳跃,处理多行。

 @triton.jit  
def softmax_kernel(  
    output_ptr, input_ptr,  
    input_row_stride, output_row_stride,  
    n_rows, n_cols,  
    BLOCK_SIZE: tl.constexpr,  
    num_stages: tl.constexpr,  
):  
    row_start = tl.program_id(0)   # 当前程序实例 ID  
    row_step = tl.num_programs(0)  # 轴 0 上的实例总数  

     for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):

tl.program_id(0)拿到当前实例的 id。

如果启了 4 个程序,program 0 从 row 0 开始,program 1 从 row 1 开始以此类推,每个程序按

row_step

跳跃处理后续行。

row_stride

告诉程序在内存里走多远才到下一行的开头。一个常见错误是认为下一行总在

n_cols

个元素之后开始——对紧凑连续张量是对的但不是所有布局都这样。

 # 指向当前行在内存中的起始位置  
 row_start_ptr = input_ptr + row_idx * input_row_stride  
 col_offsets = tl.arange(0, BLOCK_SIZE)  
 input_ptrs = row_start_ptr + col_offsets


区分两个概念:

n_cols

是逻辑列数,

input_row_stride

是两行之间的物理内存距离。

 mask = col_offsets < n_cols  
 row = tl.load(input_ptrs, mask=mask, other=-float("inf"))

mask 告诉 Triton 只加载实际列,假列用

-inf

填充,因为

exp(-inf) = 0

不影响 softmax 分母。

 row_minus_max = row - tl.max(row, axis=0)    
 numerator = tl.exp(row_minus_max)    
 denominator = tl.sum(numerator, axis=0)    
 softmax_output = numerator / denominator

先减最大值保数值稳定,不改变 softmax 结果但防止指数溢出。这些操作都在同一个融合的 Triton 程序里——

row_minus_max

numerator

denominator

不会作为中间张量写回全局内存。

启动 Kernel:Python 包装器

Triton kernel 描述一个程序实例内部干什么,但实际问题需要 Python 代码来回答:块多大?多少 warp?启动几个程序?

 def softmax(x):  
     n_rows, n_cols = x.shape  
     BLOCK_SIZE = triton.next_power_of_2(n_cols)

选择 2 的幂的 BLOCK_SIZE——适合 Triton 的块编程模型和归约操作。一行 3000 列?BLOCK_SIZE 用 4096,多余的用 mask 屏蔽。

 num_warps = 8

Warp 是一组一起执行的 GPU 线程,

num_warps = 8

意味着每个 Triton 程序实例用 8 个 warp。

 num_stages = 4 if SIZE_SMEM > 200000 else 2

num_stages和程序、warp 是不同的,它帮助同一程序内的循环迭代重叠——比如一轮加载、一轮计算、一轮写入同时进行。不过更多阶段用更多片上资源并不一定更好。

 y = torch.empty_like(x)

为输出分配和输入同 shape、dtype、device 的张量。

 kernel = softmax_kernel.warmup(  
    y, x, x.stride(0), y.stride(0),  
    n_rows, n_cols,  
    BLOCK_SIZE=BLOCK_SIZE, num_stages=num_stages, num_warps=num_warps,  
    grid=(1,),  
)  
kernel._init_handles()  
n_regs = kernel.n_regs  
 size_smem = kernel.metadata.shared

先编译一次 kernel,看看一个程序实例消耗多少寄存器和共享内存。

GPU 流多处理器资源有限。每个 SM 有固定的寄存器和共享内存预算。一个程序用太多,同一 SM 能同时跑的程序就少。

 occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)  
 occupancy = min(occupancy, SIZE_SMEM // size_smem)  
 num_programs = NUM_SM * occupancy  
 num_programs = min(num_programs, n_rows)

占用率受限于最先耗尽的资源。这是持久化风格 kernel:不是每行启一个程序,而是启足够程序占满 GPU,每个程序循环处理多行。

基准测试

RTX 5090 上逐行 softmax benchmark,

_M = 4096_

*,

_N_

变化。*

中小行大小下 PyTorch 更快,意料之中。但

N ≈ 8700

附近两边都撞上性能悬崖。之后 Triton kernel 反超。

这不意味着 Triton 万能更快,因为GPU 性能高度依赖张量形状、块大小、资源使用。y 轴是有效带宽,从输入输出张量大小算出,不是每次内部内存事务。

Triton 实现中,

N

超过

8192

BLOCK_SIZE

跳到

16384

,每个程序实例内部操作更大的块,资源压力上升,性能出现突变。

总结

Triton 可以让你在接近 Python 的层面写 GPU kernel 的方式。这个例子也告诉我们不是 Triton 总比 PyTorch 快,因为PyTorch 已经高度优化了。

本文代码

https://avoid.overfit.cn/post/cc8e8c270bb340a9abfb795f730546a6

by Lounis Hamroun

目录
相关文章
|
15天前
|
人工智能 自然语言处理 文字识别
阿里云百炼Qwen3.7-Max简介:能力、优势、支持订阅计划参考
Qwen3.7-Max是阿里云百炼面向智能体时代推出的新一代旗舰模型,对标GPT-5.5、Claude Opus 4.7等闭源旗舰。该模型支持百万级token上下文窗口,具备顶级推理能力、多模态搜索与视觉理解增强、流式输出低延迟响应等核心优势,覆盖编程、办公、长周期自主执行等复杂场景。同时支持OpenAI接口兼容,便于系统快速迁移。用户可通过Token Plan团队或节省计划等订阅方式灵活调用,适合企业级高要求场景使用。
5801 29
阿里云百炼Qwen3.7-Max简介:能力、优势、支持订阅计划参考
|
10天前
|
存储 定位技术 数据库
CodeGraph 如何让 Claude Code减少 7 成工具调用?
CodeGraph 为 Coding Agent 提供本地代码知识图谱,把函数、类、调用链和框架路由提前整理成“项目地图”,减少盲目搜索和文件读取。它不是新 Agent,而是上下文基础设施,让 Agent 更快找到正确代码路径,平均减少 7 成工具调用。
1168 2
|
7天前
|
人工智能 安全 定位技术
CodeGraph深度解析 让Claude Code工具调用直降七成的核心原理与实操教程
如今以Claude Code为代表的AI编程智能体已经成为开发者日常编码、项目重构、漏洞修复的必备工具。但在长期使用过程中,几乎所有开发者都会遇到同一个明显痛点:AI虽然具备强大的代码生成与分析能力,却常常陷入盲目探索的循环中。
944 1
|
17天前
|
人工智能 自然语言处理 供应链
|
8天前
|
人工智能 弹性计算 安全
阿里云618活动时间、活动入口、优惠活动详细解读
2026年阿里云618创新加速季已全面开启,作为年度力度最大的云产品促销活动,本次大促覆盖轻量应用服务器、ECS云服务器、GPU云服务器、数据库、AI算力、安全服务、CDN等全品类产品,推出5亿元算力补贴、新用户限时秒杀、普惠满减、企业专享、免费试用、云大使返佣等多重福利,个人开发者、中小企业、AI团队均可享受专属低价。本文将系统梳理2026年阿里云618活动的完整时间节点、官方参与入口、各类优惠细则、使用规则、热门产品推荐及实操代码,帮助用户精准参与、高效省钱,以最低成本完成上云部署。
737 4
|
23天前
|
人工智能 开发工具 iOS开发
Claude Code 新手完全上手指南:安装、国产模型配置与常用命令全解
Claude Code 是一款运行在终端环境中的 AI 编程助手,能够直接在命令行中完成代码生成、项目分析、文件修改、命令执行、Git 管理等开发全流程工作。它最大的特点是**任务驱动、终端原生、轻量高效、多模型兼容**,无需图形界面、不依赖 IDE 插件,能够深度融入开发者日常工作流。
3831 15
|
8天前
|
运维
欢迎报名|2026 Agentic AICon—智能体基础设施与AgentOps专场,邀您参会
欢迎报名|2026 Agentic AICon—智能体基础设施与AgentOps专场,邀您参会
1426 0