
Commit a2ac1d7

Optimized Triton RMSNorm Backwards
Differential Revision: D82789616 Pull Request resolved: #456
1 parent dee73ff commit a2ac1d7
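
Reading the new kernel (this is not stated in the commit message itself): with $r = \big(\tfrac{1}{N}\sum_{j} x_j^2 + \varepsilon\big)^{-1/2}$ saved from the forward pass, the fused backward computes, per row,

$$\frac{\partial L}{\partial x_i} = r\,w_i\,dy_i \;-\; \frac{r^{3}}{N}\,x_i \sum_{j=1}^{N} dy_j\,w_j\,x_j, \qquad \frac{\partial L}{\partial w_i} = \sum_{\text{rows}} dy_i\,x_i\,r .$$

Each Triton program handles a block of rows, writes dx in place, and accumulates a per-program partial dW that the Python wrapper reduces with torch.sum(_dw, dim=0).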

File tree

2 files changed: +147 −0 lines changed
tritonbench/operators/rms_norm/fused_triton.py (new file)

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
import math

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config(
            {"M_INCREMENT": M_INCREMENT},
            num_warps=w,
        )
        for M_INCREMENT in [1, 2, 4, 8, 16]
        for w in [2, 4, 8]
    ],
    key=["M", "N"],
)
@triton.jit
def _rms_norm_bwd_fused(
    DX,  # pointer to the input gradient
    DY,  # pointer to the output gradient
    DW,  # pointer to the partial sum of weights gradient
    X,  # pointer to the input
    W,  # pointer to the weights
    RMS,  # pointer to the rms
    stride,  # how much to increase the pointer when moving by 1 row
    N,  # number of columns in X
    M,  # number of rows in X
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    M_INCREMENT: tl.constexpr,
    N_POW_2: tl.constexpr,
):
    # Map the program id to the elements of X, DX, and DY it should compute.
    pid = tl.program_id(0)
    start_row = pid * BLOCK_SIZE_M
    grad_w = tl.full([BLOCK_SIZE_N], 0, tl.float32)
    cols = tl.arange(0, BLOCK_SIZE_N)
    if N_POW_2:
        col_mask = None
    else:
        col_mask = cols < N

    w = tl.load(W + cols, mask=col_mask).to(tl.float32)[None, :]

    for cur_row in tl.range(0, BLOCK_SIZE_M, M_INCREMENT):
        rows = start_row + cur_row + tl.arange(0, M_INCREMENT)
        row_indices = rows * stride
        row_mask = rows < M

        rms = tl.load(RMS + rows, mask=row_mask).to(tl.float32)[:, None]

        if N_POW_2:
            index_mask = row_mask[:, None]
        else:
            index_mask = row_mask[:, None] & col_mask[None, :]

        indices = row_indices[:, None] + cols[None, :]

        # Load data to SRAM
        x = tl.load(X + indices, mask=index_mask, other=0).to(tl.float32)
        dy = tl.load(DY + indices, mask=index_mask, other=0).to(tl.float32)

        # Compute dx
        m = dy * w
        row_dot = tl.sum(m * x, axis=1)[:, None]
        scale = -(1.0 / N) * rms * rms * rms
        dx = rms * m
        dx += scale * row_dot * x

        # Write dx
        tl.store(DX + indices, dx, mask=index_mask)

        grad_w += tl.sum((dy * x) * rms, axis=0)

    tl.store(DW + pid * N + cols, grad_w, mask=col_mask)


class RMSNorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, normalized_shape, weight, eps):
        # allocate output
        y = torch.empty_like(x)
        # reshape input data into 2D tensor
        x_arg = x.reshape(-1, x.shape[-1]).to(weight.dtype)

        def rmsnorm_ref(inp, w, eps=1e-6):
            rms = 1.0 / torch.sqrt(torch.mean(inp.square(), dim=-1, keepdim=True) + eps)
            return (inp * rms * w).to(inp.dtype), rms

        y, rms = rmsnorm_ref(x_arg, weight, eps)
        ctx.save_for_backward(x, weight, rms)
        ctx.eps = eps
        return y

    @staticmethod
    def backward(ctx, dy):
        x, w, rms = ctx.saved_tensors
        x_arg = x.reshape(-1, x.shape[-1])
        # heuristics for amount of parallel reduction stream for DW/DB
        N = w.shape[0]
        # allocate output
        dw = torch.empty((N,), dtype=w.dtype, device=w.device)
        dx = torch.empty_like(dy)

        M, N = x_arg.shape
        NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
        BLOCK_SIZE_M = min(2048, triton.next_power_of_2(M // (8 * NUM_SMS)))
        PARTIAL_SIZE = math.ceil(M / BLOCK_SIZE_M)

        # Columnwise stride for reducing partial sums at end, contiguous loads
        _dw = torch.empty((PARTIAL_SIZE, N), dtype=w.dtype, device=w.device)

        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE = triton.next_power_of_2(N)
        assert (
            BLOCK_SIZE <= MAX_FUSED_SIZE
        ), "This layer norm doesn't support feature dim >= 64KB."

        _rms_norm_bwd_fused[(PARTIAL_SIZE,)](  #
            dx,
            dy,
            _dw,
            x_arg,
            w,
            rms,
            x_arg.stride(0),
            N,
            M,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE,
            N_POW_2=(N % BLOCK_SIZE == 0),
        )

        dw = torch.sum(_dw, dim=0)

        return dx, None, dw, None, None


rms_norm = RMSNorm.apply
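
For illustration, a minimal smoke test of the exported rms_norm entry point, mirroring the call shape registered in operator.py below (rms_norm(input, H, weight, eps)). The problem size, dtype, and tolerances are assumptions made for this sketch, not part of the commit, and it requires a CUDA device:

import torch

# Hypothetical smoke test; rms_norm is RMSNorm.apply from the file above.
M, H = 4096, 2048  # assumed problem size
eps = 1e-6
x = torch.randn(M, H, device="cuda", dtype=torch.float16, requires_grad=True)
w = torch.randn(H, device="cuda", dtype=torch.float16, requires_grad=True)

# Fused path: eager forward, Triton kernel on the backward pass.
y = rms_norm(x, H, w, eps)
y.backward(torch.ones_like(y))
dx_fused, dw_fused = x.grad.clone(), w.grad.clone()

# Plain PyTorch autograd reference.
x.grad, w.grad = None, None
r = torch.rsqrt(x.float().square().mean(dim=-1, keepdim=True) + eps)
y_ref = (x.float() * r * w.float()).to(x.dtype)
y_ref.backward(torch.ones_like(y_ref))

# Loose fp16 tolerances, chosen for this sketch.
torch.testing.assert_close(dx_fused, x.grad, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dw_fused, w.grad, rtol=2e-2, atol=2e-2)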

tritonbench/operators/rms_norm/operator.py

Lines changed: 6 additions & 0 deletions
@@ -12,6 +12,8 @@
     register_x_val,
 )
 
+from . import fused_triton
+
 try:
     from liger_kernel.transformers.rms_norm import LigerRMSNorm
 except ModuleNotFoundError:
@@ -140,6 +142,10 @@ def torch_compile_rms(self, H, input, weight) -> Callable:
         compiled = torch.compile(module, mode="max-autotune-no-cudagraphs")
         return lambda: compiled(input)
 
+    @register_benchmark()
+    def triton_fused_rmsnorm(self, H, input, weight) -> Callable:
+        return lambda: fused_triton.rms_norm(input, H, weight, self.eps)
+
     @register_benchmark(enabled=is_hip() and HAS_AITER)
     def aiter(self, H, input, weight) -> Callable:
         module = AITerRMSNorm(hidden_size=H, eps=self.eps).to(self.device)
