
Commit a551941

Make token group alignment size configurable (#1503)
- For mxfp8, token group sizes must be multiples of the "block_size" because in the backward pass for `grad_weight = grad_output_t @ input`, the "M" (token) dimension is the contracting dimension, and each token group is a logically distinct subtensor that is scaled separately. This means each token group's size along the contracting dimension must be divisible by the mxfp8 block_size (default 32). Here is a diagram showing the problem: https://www.internalfb.com/excalidraw/EX521879
- To solve this, this PR makes the token group M alignment configurable.
- Integration test with torchao passes: pytorch/ao#2642
- Ran a manual test with the llama4 debug model using bf16.
1 parent d655e16 commit a551941
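The wrapper in expert_parallel.py reads a module-level TOKEN_GROUP_ALIGN_SIZE_M global (see the diff below). As a rough sketch of how a training recipe might configure it — the helper function and the non-mxfp8 default of 16 below are illustrative assumptions, not part of this commit:

from torchtitan.experiments.llama4.infra import expert_parallel

def configure_token_group_alignment(recipe: str) -> None:
    # mxfp8 scales each token group in blocks along the contracting (token)
    # dimension, so every group must be padded to a multiple of the block size (32).
    # The value 16 for other recipes is an assumption for illustration.
    expert_parallel.TOKEN_GROUP_ALIGN_SIZE_M = 32 if recipe == "mxfp8" else 16

configure_token_group_alignment("mxfp8")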

3 files changed (+81, -71 lines)


test.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
import torch
from torch import nn
from torch.nn import functional as F
from dataclasses import dataclass


@dataclass
class Config:
    num_experts: int = 2
    intermediate_size: int = 1024
    dim: int = 2048


class MoE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.w1 = nn.Parameter(  # num exp, expert_dim, hidden_dim
            torch.empty(config.num_experts, config.intermediate_size, config.dim)
        )  # E, I, D
        self.w2 = nn.Parameter(
            torch.empty(config.num_experts, config.dim, config.intermediate_size)
        )  # E, D, I
        self.w3 = nn.Parameter(
            torch.empty(config.num_experts, config.intermediate_size, config.dim)
        )  # E, I, D
        # Round-trip transpose so w2 is stored column-major over its last two dims.
        self.w2 = torch.nn.Parameter(self.w2.transpose(-2, -1).contiguous().transpose(-2, -1))

        nn.init.normal_(self.w1, std=0.02)
        nn.init.normal_(self.w2, std=0.02)
        nn.init.normal_(self.w3, std=0.02)
        self.w13 = torch.cat((self.w1, self.w3), dim=1)
        self.w13 = torch.nn.Parameter(self.w13.transpose(-2, -1).contiguous().transpose(-2, -1))

    def forward(self, ordered_inputs, optim=True):
        M = ordered_inputs.size(0)
        group_size = M // self.config.num_experts
        offs = torch.arange(group_size, M + 1, group_size, device="cuda", dtype=torch.int32)

        # x @ w1 and x @ w3 are grouped gemms with identical shapes:
        #   (M, D) @ (E, D, I) -> (M, I)
        # so we can concatenate w1 and w3 along their output (intermediate) dim into a
        # single weight of shape (E, 2*I, D) and run one grouped gemm:
        #   (M, D) @ (E, D, 2*I) -> (M, 2*I)
        # then split the result back into x1 and x3.
        x13 = torch._grouped_mm(ordered_inputs, self.w13.transpose(-2, -1), offs)
        x1, x3 = x13.split(self.config.intermediate_size, dim=1)
        y1 = F.silu(x1) * x3
        ordered_outs1 = torch._grouped_mm(y1, self.w2.transpose(-2, -1), offs)

        # Reference path: two separate grouped gemms against w1 and w3.
        x1 = F.silu(torch._grouped_mm(ordered_inputs, self.w1.transpose(-2, -1), offs))
        x3 = torch._grouped_mm(ordered_inputs, self.w3.transpose(-2, -1), offs)
        y1 = x1 * x3
        ordered_outs2 = torch._grouped_mm(y1, self.w2.transpose(-2, -1), offs)

        assert torch.equal(ordered_outs1, ordered_outs2)
        return ordered_outs1


config = Config()
m = MoE(config).cuda().bfloat16()
ordered_inputs = torch.randn(256, config.dim, device="cuda", dtype=torch.bfloat16)
m(ordered_inputs, optim=True)
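The assert in test.py checks the grouped-mm case of a general matmul identity: multiplying by the concatenation of two weights and splitting the output equals computing the two products separately. A minimal standalone check of that identity (plain matmul, no grouped-mm kernel or GPU needed):

import torch

x = torch.randn(4, 8)
A = torch.randn(8, 16)
B = torch.randn(8, 16)

# One matmul against the concatenated weight...
fused = x @ torch.cat((A, B), dim=-1)
# ...split back into the two logical outputs.
x1, x3 = fused.chunk(2, dim=-1)

assert torch.allclose(x1, x @ A)
assert torch.allclose(x3, x @ B)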

torchtitan/experiments/llama4/infra/expert_parallel.py

Lines changed: 5 additions & 7 deletions
@@ -272,24 +272,22 @@ def expert_parallel(func: Callable) -> Callable:
     """
 
     def wrapper(
-        w1: torch.Tensor,
+        w13: torch.Tensor,
         w2: torch.Tensor,
-        w3: torch.Tensor,
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor | None = None,
     ) -> torch.Tensor:
         global TOKEN_GROUP_ALIGN_SIZE_M
-        if isinstance(w1, DTensor):
-            w1 = w1.to_local()
+        if isinstance(w13, DTensor):
+            w13 = w13.to_local()
             w2 = w2.to_local()
-            w3 = w3.to_local()
 
         if num_tokens_per_expert is not None:
             from torchtitan.experiments.kernels.moe.indices import (
                 generate_permute_indices,
             )
 
-            experts_per_ep_rank = w1.shape[0]
+            experts_per_ep_rank = w13.shape[0]
             num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank
 
             with torch.no_grad():
@@ -309,7 +307,7 @@ def wrapper(
             input_shape = x.shape
             x = x[permuted_indices, :]
 
-        out = func(w1, w2, w3, x, num_tokens_per_expert)
+        out = func(w13, w2, x, num_tokens_per_expert)
 
         if num_tokens_per_expert is not None:
             out_unpermuted = out.new_empty(input_shape)
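For context, generate_permute_indices (called just below the first hunk) is what pads each expert's token group up to a multiple of TOKEN_GROUP_ALIGN_SIZE_M, so the contracting dimension stays divisible by the quantization block size. A simplified sketch of that rounding, with made-up token counts (not the actual kernel):

import torch

TOKEN_GROUP_ALIGN_SIZE_M = 32  # e.g. the mxfp8 block_size
num_tokens_per_expert = torch.tensor([100, 7, 25])

# Round each group size up to the next multiple of the alignment; the extra
# rows are padding tokens, so every group's "M" dim is divisible by 32.
padded_sizes = (
    (num_tokens_per_expert + TOKEN_GROUP_ALIGN_SIZE_M - 1)
    // TOKEN_GROUP_ALIGN_SIZE_M
    * TOKEN_GROUP_ALIGN_SIZE_M
)
print(padded_sizes)  # tensor([128, 32, 32])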

torchtitan/experiments/llama4/model/moe.py

Lines changed: 14 additions & 64 deletions
@@ -23,76 +23,27 @@ def __init__(
     ):
         super().__init__()
         self.num_experts = num_experts
-        self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
-        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
-        self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
+        # Combine w1 and w3 into a single tensor so we can fuse
+        # `x @ w1` and `x @ w3` into a single grouped mm.
+        self.w13 = nn.Parameter(torch.empty(num_experts, hidden_dim * 2, dim))
+        self.w2 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
         self.use_grouped_mm = use_grouped_mm
 
     def forward(
         self,
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        if self.use_grouped_mm:
-            return GroupedExperts._run_experts_grouped_mm(
-                self.w1, self.w2, self.w3, x, num_tokens_per_expert
-            )
-        else:
-            return GroupedExperts._run_experts_for_loop(
-                self.w1, self.w2, self.w3, x, num_tokens_per_expert
-            )
-
-    # TODO: keeping this for-loop implementation for comparison
-    # and readability, may remove later
-    @expert_parallel
-    @staticmethod
-    def _run_experts_for_loop(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            # NOTE: this would incur a synchronization between device and host
-            num_tokens_per_expert = num_tokens_per_expert.tolist()
-
-            # side-effect code due to the usage of generate_permute_indices
-            num_padding = x.shape[0] - sum(num_tokens_per_expert)
-
-            # a tuple of tensors indexed by experts
-            # each with shape (tokens_per_expert(varying), dim)
-            x = torch.split(
-                x[: sum(num_tokens_per_expert)],
-                split_size_or_sections=num_tokens_per_expert,
-                dim=0,
-            )
-            out_experts_splits = []
-            for expert_idx, x_expert in enumerate(x):
-                h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
-                h = h * torch.matmul(x_expert, w3[expert_idx])
-                h = torch.matmul(h, w2[expert_idx])
-                # h shape (tokens_per_expert(varying), dim)
-                out_experts_splits.append(h)
-            out = torch.cat(out_experts_splits, dim=0)
-
-            # side-effect code due to the usage of generate_permute_indices
-            out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
-        else:
-            # x shape (num_experts, tokens_per_expert, dim)
-            h = F.silu(torch.bmm(x, w1))
-            h = h * torch.bmm(x, w3)
-            # out shape (num_experts, tokens_per_expert, dim)
-            out = torch.bmm(h, w2)
+        return GroupedExperts._run_experts_grouped_mm(
+            self.w13, self.w2, x, num_tokens_per_expert
+        )
 
-        return out
 
     @expert_parallel
     @staticmethod
     def _run_experts_grouped_mm(
-        w1: torch.Tensor,
+        w13: torch.Tensor,
         w2: torch.Tensor,
-        w3: torch.Tensor,
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor | None = None,
     ) -> torch.Tensor:
@@ -105,16 +56,14 @@ def _run_experts_grouped_mm(
         # fall back to regular bmm between 3D tensors
         assert x.dim() == 3
 
-        h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
-        h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
-        out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
-
+        x1, x3 = torch._grouped_mm(x, w13.transpose(-2, -1), offs=offsets).chunk(2, dim=-1)
+        y = F.silu(x1) * x3
+        out = torch._grouped_mm(y, w2.transpose(-2, -1), offs=offsets).type_as(x)
         return out
 
     def init_weights(self, init_std: float):
-        nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
+        nn.init.trunc_normal_(self.w13, mean=0.0, std=0.02)
         nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
-        nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std)
 
 
 class TokenChoiceTopKRouter(nn.Module):
@@ -299,7 +248,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         # shared expert
         if self.shared_expert is not None:
-            out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape(
+            out = self.shared_expert(x.reshape(1, bs * slen, dim))
+            out = out.reshape(
                 bs * slen, dim
             )
         else:
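One detail worth noting about the grouped path above: the offsets tensor passed to torch._grouped_mm marks where each expert's token group ends, so tokens must already be sorted (and padded) by expert. A small sketch of deriving it from per-expert token counts, assuming that layout:

import torch

num_tokens_per_expert = torch.tensor([96, 160], dtype=torch.int32)

# Running totals give each group's (exclusive) end row:
# expert 0 owns rows [0, 96), expert 1 owns rows [96, 256).
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
print(offsets)  # tensor([ 96, 256], dtype=torch.int32)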
