|
| 1 | +import math |
| 2 | + |
| 3 | +import torch |
| 4 | +import triton |
| 5 | +import triton.language as tl |
| 6 | + |
| 7 | + |
| 8 | +@triton.autotune( |
| 9 | + configs=[ |
| 10 | + triton.Config( |
| 11 | + {"M_INCREMENT": M_INCREMENT}, |
| 12 | + num_warps=w, |
| 13 | + ) |
| 14 | + for M_INCREMENT in [1, 2, 4, 8, 16] |
| 15 | + for w in [2, 4, 8] |
| 16 | + ], |
| 17 | + key=["M", "N"], |
| 18 | +) |
| 19 | +@triton.jit |
| 20 | +def _rms_norm_bwd_fused( |
| 21 | + DX, # pointer to the input gradient |
| 22 | + DY, # pointer to the output gradient |
| 23 | + DW, # pointer to the partial sum of weights gradient |
| 24 | + X, # pointer to the input |
| 25 | + W, # pointer to the weights |
| 26 | + RMS, # pointer to the rms |
| 27 | + stride, # how much to increase the pointer when moving by 1 row |
| 28 | + N, # number of columns in X |
| 29 | + M, # number of rows in X |
| 30 | + BLOCK_SIZE_N: tl.constexpr, |
| 31 | + BLOCK_SIZE_M: tl.constexpr, |
| 32 | + M_INCREMENT: tl.constexpr, |
| 33 | + N_POW_2: tl.constexpr, |
| 34 | +): |
| 35 | + # Map the program id to the elements of X, DX, and DY it should compute. |
| 36 | + pid = tl.program_id(0) |
| 37 | + start_row = pid * BLOCK_SIZE_M |
| 38 | + grad_w = tl.full([BLOCK_SIZE_N], 0, tl.float32) |
| 39 | + cols = tl.arange(0, BLOCK_SIZE_N) |
| 40 | + if N_POW_2: |
| 41 | + col_mask = None |
| 42 | + else: |
| 43 | + col_mask = cols < N |
| 44 | + |
| 45 | + w = tl.load(W + cols, mask=col_mask).to(tl.float32)[None, :] |
| 46 | + |
| 47 | + for cur_row in tl.range(0, BLOCK_SIZE_M, M_INCREMENT): |
| 48 | + rows = start_row + cur_row + tl.arange(0, M_INCREMENT) |
| 49 | + row_indices = rows * stride |
| 50 | + row_mask = rows < M |
| 51 | + |
| 52 | + rms = tl.load(RMS + rows, mask=row_mask).to(tl.float32)[:, None] |
| 53 | + |
| 54 | + if N_POW_2: |
| 55 | + index_mask = row_mask[:, None] |
| 56 | + else: |
| 57 | + index_mask = row_mask[:, None] & col_mask[None, :] |
| 58 | + |
| 59 | + indices = row_indices[:, None] + cols[None, :] |
| 60 | + |
| 61 | + # Load data to SRAM |
| 62 | + x = tl.load(X + indices, mask=index_mask, other=0).to(tl.float32) |
| 63 | + dy = tl.load(DY + indices, mask=index_mask, other=0).to(tl.float32) |
| 64 | + |
| 65 | + # Compute dx |
| 66 | + m = dy * w |
| 67 | + row_dot = tl.sum(m * x, axis=1)[:, None] |
| 68 | + scale = -(1.0 / N) * rms * rms * rms |
| 69 | + dx = rms * m |
| 70 | + dx += scale * row_dot * x |
| 71 | + |
| 72 | + # Write dx |
| 73 | + tl.store(DX + indices, dx, mask=index_mask) |
| 74 | + |
| 75 | + grad_w += tl.sum((dy * x) * rms, axis=0) |
| 76 | + |
| 77 | + tl.store(DW + pid * N + cols, grad_w, mask=col_mask) |
| 78 | + |
| 79 | + |
| 80 | +class RMSNorm(torch.autograd.Function): |
| 81 | + @staticmethod |
| 82 | + def forward(ctx, x, normalized_shape, weight, eps): |
| 83 | + # allocate output |
| 84 | + y = torch.empty_like(x) |
| 85 | + # reshape input data into 2D tensor |
| 86 | + x_arg = x.reshape(-1, x.shape[-1]).to(weight.dtype) |
| 87 | + |
| 88 | + def rmsnorm_ref(inp, w, eps=1e-6): |
| 89 | + rms = 1.0 / torch.sqrt(torch.mean(inp.square(), dim=-1, keepdim=True) + eps) |
| 90 | + return (inp * rms * w).to(inp.dtype), rms |
| 91 | + |
| 92 | + y, rms = rmsnorm_ref(x_arg, weight, eps) |
| 93 | + ctx.save_for_backward(x, weight, rms) |
| 94 | + ctx.eps = eps |
| 95 | + return y |
| 96 | + |
| 97 | + @staticmethod |
| 98 | + def backward(ctx, dy): |
| 99 | + x, w, rms = ctx.saved_tensors |
| 100 | + x_arg = x.reshape(-1, x.shape[-1]) |
| 101 | + # heuristics for amount of parallel reduction stream for DW/DB |
| 102 | + N = w.shape[0] |
| 103 | + # allocate output |
| 104 | + dw = torch.empty((N,), dtype=w.dtype, device=w.device) |
| 105 | + dx = torch.empty_like(dy) |
| 106 | + |
| 107 | + M, N = x_arg.shape |
| 108 | + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count |
| 109 | + BLOCK_SIZE_M = min(2048, triton.next_power_of_2(M // (8 * NUM_SMS))) |
| 110 | + PARTIAL_SIZE = math.ceil(M / BLOCK_SIZE_M) |
| 111 | + |
| 112 | + # Columnwise stride for reducing partial sums at end, contiguous loads |
| 113 | + _dw = torch.empty((PARTIAL_SIZE, N), dtype=w.dtype, device=w.device) |
| 114 | + |
| 115 | + MAX_FUSED_SIZE = 65536 // x.element_size() |
| 116 | + BLOCK_SIZE = triton.next_power_of_2(N) |
| 117 | + assert ( |
| 118 | + BLOCK_SIZE <= MAX_FUSED_SIZE |
| 119 | + ), "This layer norm doesn't support feature dim >= 64KB." |
| 120 | + |
| 121 | + _rms_norm_bwd_fused[(PARTIAL_SIZE,)]( # |
| 122 | + dx, |
| 123 | + dy, |
| 124 | + _dw, |
| 125 | + x_arg, |
| 126 | + w, |
| 127 | + rms, |
| 128 | + x_arg.stride(0), |
| 129 | + N, |
| 130 | + M, |
| 131 | + BLOCK_SIZE_M=BLOCK_SIZE_M, |
| 132 | + BLOCK_SIZE_N=BLOCK_SIZE, |
| 133 | + N_POW_2=(N % BLOCK_SIZE == 0), |
| 134 | + ) |
| 135 | + |
| 136 | + dw = torch.sum(_dw, dim=0) |
| 137 | + |
| 138 | + return dx, None, dw, None, None |
| 139 | + |
| 140 | + |
| 141 | +rms_norm = RMSNorm.apply |
0 commit comments