
Commit d41d7b9

wip MoE refactor
Summary: now that the PyTorch grouped_mm kernels no longer require padding, this refactors the MoE implementation to use them instead of the previous approach.

DONE
- implement MoE with grouped_mm [x]
- add handling for generic module swap to AOQuantizable (MoEMapping) [x]
- refactor MoEQuantConfig to swap generic modules [x]

TODO
- add dispatch from grouped_mm to the linear decomposition of the quantized kernel
- compare linear decomposition vs new linear decomposition vs grouped_mm for eager, compile, and autotuned compile of the linear decomposition
- compare linear decomposition vs new linear decomposition for quantized kernels
- add scaled_group_gemm and fbgemm kernel (probably in a new PR)

ISSUE: the autotuned grouped_mm kernels don't give the correct output, but the same code works in eager and in compile with reduce-overhead. Why? See the new_run.log output: the first 2 runs are fine, line 144 is nonsense.

Test Plan: sh run.sh

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 0e00df3 commit d41d7b9
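For context on the approach in the summary: a grouped GEMM takes one stacked, expert-contiguous batch of rows, splits it into per-expert groups via an offsets tensor, and multiplies each group by that expert's own weight matrix in a single kernel, so ragged groups no longer need to be padded to a common size. Below is a minimal sketch of that call pattern, mirroring how the diff invokes torch._grouped_mm; the op is private and currently expects bf16 CUDA tensors, and all sizes here are illustrative rather than taken from the commit.

import torch

# illustrative sizes, not from the commit
E, D, I = 4, 64, 128                         # experts, model dim, intermediate dim
counts = torch.tensor([3, 4, 5, 2])          # tokens routed to each expert (ragged, no padding)

# rows are already sorted so each expert's tokens are contiguous
x = torch.randn(int(counts.sum()), D, device="cuda", dtype=torch.bfloat16)
w1 = torch.randn(E, I, D, device="cuda", dtype=torch.bfloat16)   # stacked per-expert weights
offs = counts.cumsum(0).to(torch.int32).cuda()                   # cumulative group end offsets

# one kernel performs all E matmuls, taking group boundaries from `offs`
h = torch._grouped_mm(x, w1.transpose(-2, -1), offs)             # [sum(counts), I]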

File tree

11 files changed (+806, -349 lines)


torchao/_models/mixtral-moe/generate.py

Lines changed: 55 additions & 43 deletions
@@ -14,6 +14,9 @@
 import torch._inductor.config

 from torchao.utils import get_model_size_in_bytes
+from torchao.prototype.moe_quant import MoEFeedForwardAOQuantizable
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+from model import MoEFeedForward

 torch.manual_seed(0)

@@ -187,7 +190,6 @@ def _load_model(checkpoint_path, device, precision):

 B_INST, E_INST = "[INST]", "[/INST]"

-
 def main(
     prompt: str = "Hello, my name is",
     interactive: bool = False,
@@ -199,6 +201,7 @@ def main(
     checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"),
     compile: bool = True,
     compile_prefill: bool = False,
+    compile_mode: str = "reduce-overhead",
     moe_quant: Optional[str] = None,
     profile: Optional[Path] = None,
     memory_profile: Optional[Path] = None,
@@ -212,6 +215,13 @@ def main(
     precision = torch.bfloat16
     is_chat = "chat" in str(checkpoint_path)

+    if batch_size > 1 and moe_quant is None:
+        print(
+            "Warning: Detected batch_size > 1 but no moe_quant. The default MoE implementation uses a lot"
+            " of memory when batched; if it OOMs, you can instead run the AO quantizable module without"
+            " quantization by specifying --moe_quant noquant."
+        )
+
     if device == "cuda" and memory_profile is not None:
         torch.cuda.memory._record_memory_history(
             True, trace_alloc_max_entries=500000, trace_alloc_record_context=True
@@ -236,10 +246,12 @@ def main(
         ]
     )

-    from torchao.prototype.moe_quant.utils import (
+    from torchao.prototype.moe_quant import (
         MoEQuantConfig,
+        MoEMapping,
         UseFakeExtraDimTensor,
-        cond_ffn_filter,
+        MoEFeedForwardAOQuantizable,
+
     )
     from torchao.quantization.quant_api import (
         Float8DynamicActivationFloat8WeightConfig,
@@ -255,71 +267,64 @@

     if moe_quant:
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
-        config = None
+        config = MoEQuantConfig(mapping=MoEMapping(target_module_type=MoEFeedForward))
         if "int8wo-base" in moe_quant:
-            config = MoEQuantConfig(Int8WeightOnlyConfig())
+            config.base_config = Int8WeightOnlyConfig()

         elif "int8wo" in moe_quant:
-            config = MoEQuantConfig(
-                Int8WeightOnlyConfig(),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Int8WeightOnlyConfig()
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE

         elif "int8dq-base" in moe_quant:
-            config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig())
+            config.base_config = Int8DynamicActivationInt8WeightConfig()

         elif "int8dq" in moe_quant:
-            config = MoEQuantConfig(
-                Int8DynamicActivationInt8WeightConfig(),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Int8DynamicActivationInt8WeightConfig()
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE

         elif "int4wo-base" in moe_quant:
-            config = MoEQuantConfig(Int4WeightOnlyConfig())
+            config.base_config = Int4WeightOnlyConfig()

         elif "int4wo" in moe_quant:
-            config = MoEQuantConfig(
-                Int4WeightOnlyConfig(),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Int4WeightOnlyConfig()
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE

         elif "fp8wo-base" in moe_quant:
-            config = MoEQuantConfig(Float8WeightOnlyConfig())
+            config.base_config = Float8WeightOnlyConfig()

         elif "fp8wo" in moe_quant:
-            config = MoEQuantConfig(
-                Float8WeightOnlyConfig(),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Float8WeightOnlyConfig()
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE

         elif "fp8dq-base" in moe_quant:
-            config = MoEQuantConfig(
-                Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
-            )
+            config.base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

         elif "fp8dq" in moe_quant:
-            config = MoEQuantConfig(
-                Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE

         elif "intxdq" in moe_quant:
-            config = MoEQuantConfig(
-                Int8DynamicActivationIntxWeightConfig(
+            config.base_config = Int8DynamicActivationIntxWeightConfig(
                 layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
             ),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE
+        elif "noquant" in moe_quant:
+            pass
         else:
             assert config is not None, (
                 f"expected moe_quant to match one of the options but got {moe_quant}"
             )

-        if config is not None:
-            quantize_(model, config, filter_fn=cond_ffn_filter, device=device)
-            print(
-                f"Time to apply quantization with config {config} to model: {time.time() - t0:.02f} seconds"
-            )
+        def filter_fn(mod, fqn):
+            return isinstance(mod, MoEFeedForward)
+
+        model.layers = model.layers
+
+        quantize_(model, config, filter_fn=filter_fn, device=device)
+
+        print(
+            f"Time to apply quantization with config {config} to model: {time.time() - t0:.02f} seconds"
+        )

     model.to(device=device)
     device_sync(device=device)
@@ -335,12 +340,12 @@ def main(

         global decode_one_token, prefill

-        if batch_size == 1 and (isinstance(moe_quant, str) and "base" in moe_quant):
+        if True and (batch_size == 1 and (isinstance(moe_quant, str) and "base" in moe_quant)):
             decode_one_token = torch.compile(
-                decode_one_token, mode="reduce-overhead", fullgraph=True
+                decode_one_token, mode=compile_mode, fullgraph=True
             )
         else:
-            decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead")
+            decode_one_token = torch.compile(decode_one_token, mode=compile_mode)

         if args.compile_prefill:
             prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
@@ -474,6 +479,12 @@ def callback(x):
         action="store_true",
         help="Whether to compile the prefill (improves prefill perf, but higher compile times)",
     )
+    parser.add_argument(
+        "--compile_mode",
+        type=str,
+        default="reduce-overhead",
+        help="which torch.compile mode to use: reduce-overhead or max-autotune, does nothing if --compile is not set.",
+    )
     # parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8')
     parser.add_argument(
         "--moe_quant",
@@ -499,6 +510,7 @@ def callback(x):
         args.checkpoint_path,
         args.compile,
         args.compile_prefill,
+        args.compile_mode,
         args.moe_quant,
         args.profile,
         args.memory_profile,
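Condensed for reference, the quantization flow generate.py lands on after this change looks roughly like the sketch below. The module and config names come from the diff above; the prototype moe_quant API is still WIP in this commit, so the exact constructor signature and attribute assignments are assumptions based on the diff, not a stable interface.

from torchao.quantization.quant_api import quantize_, Int8WeightOnlyConfig
from torchao.prototype.moe_quant import MoEQuantConfig, MoEMapping
from model import MoEFeedForward  # the eager MoE module defined in model.py below

# one config object: the mapping names the module type to swap for the AO-quantizable
# implementation, base_config says how to quantize its expert weights
config = MoEQuantConfig(mapping=MoEMapping(target_module_type=MoEFeedForward))
config.base_config = Int8WeightOnlyConfig()

# `model` is assumed to be the already-loaded Mixtral transformer;
# swap and quantize every MoEFeedForward module inside it
quantize_(model, config, filter_fn=lambda mod, fqn: isinstance(mod, MoEFeedForward), device="cuda")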

torchao/_models/mixtral-moe/model.py

Lines changed: 98 additions & 38 deletions
@@ -156,7 +156,7 @@ class TransformerBlock(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.attention = Attention(config)
-        self.block_sparse_moe = MOEFeedForwardAOQuantizable(config)
+        self.block_sparse_moe = MoEFeedForward(config)
         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
         self.attention_norm = RMSNorm(config.dim, config.norm_eps)

@@ -225,41 +225,39 @@ def forward(
         y = self.wo(y)
         return y

+class MoEFeedForward(nn.Module):
+    def __init__(self, config) -> None:
+        super().__init__()
+        self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
+        self.cond_ffn = ConditionalFeedForward(config)
+        self.dim = config.dim
+        self.num_activated_experts = config.num_activated_experts
+    def forward(self, x: Tensor) -> Tensor:
+        x = x.view(-1, self.dim)
+        # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts
+        # x: [T, D]
+        scores = self.gate(x)  # [T, E]
+        expert_weights = F.softmax(scores, dim=-1)
+        expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1)  # [T, A], [T, A]
+        expert_weights /= expert_weights.sum(dim=-1, keepdim=True)  # [T, A]
+        expert_outs = self.cond_ffn(x, expert_indices)
+        return torch.einsum('tai,ta -> ti', expert_outs, expert_weights)

-# class ConditionalFeedForward(nn.Module):
-#     def __init__(self, config):
-#         super().__init__()
-#         self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
-#         self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size))
-#         self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
-
-#     def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
-#         w1_weights = self.w1[expert_indices] # [T, A, D, D]
-#         w3_weights = self.w3[expert_indices] # [T, A, D, D]
-#         w2_weights = self.w2[expert_indices] # [T, A, D, D]
-#         x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights))
-#         x3 = torch.einsum('ti, taoi -> tao', x, w3_weights)
-#         expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights)
-#         return expert_outs
-
-
-# class MOEFeedForward(nn.Module):
-#     def __init__(self, config) -> None:
-#         super().__init__()
-#         self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
-#         self.cond_ffn = ConditionalFeedForward(config)
-#         self.dim = config.dim
-#         self.num_activated_experts = config.num_activated_experts
-#     def forward(self, x: Tensor) -> Tensor:
-#         x = x.view(-1, self.dim)
-#         # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts
-#         # x: [T, D]
-#         scores = self.gate(x) # [T, E]
-#         expert_weights = F.softmax(scores, dim=-1)
-#         expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A]
-#         expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A]
-#         expert_outs = self.cond_ffn(x, expert_indices)
-#         return torch.einsum('tai,ta -> ti', expert_outs, expert_weights)
+class ConditionalFeedForward(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
+        self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size))
+        self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
+
+    def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
+        w1_weights = self.w1[expert_indices]  # [T, A, D, D]
+        w3_weights = self.w3[expert_indices]  # [T, A, D, D]
+        w2_weights = self.w2[expert_indices]  # [T, A, D, D]
+        x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights))
+        x3 = torch.einsum('ti, taoi -> tao', x, w3_weights)
+        expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights)
+        return expert_outs


 class RMSNorm(nn.Module):
@@ -301,6 +299,8 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
     x_out2 = x_out2.flatten(3)
     return x_out2.type_as(x)

+#TODO delete
+

 # T tokens
 # E experts
@@ -310,7 +310,7 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
 # T'(e) tokens for expert e


-class MOEFeedForwardAOQuantizable(nn.Module):
+class MoEFeedForwardAOQuantizable(nn.Module):
     def __init__(self, config) -> None:
         super().__init__()
         self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
@@ -337,7 +337,7 @@ class ConditionalFeedForwardAOQuantizable(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
-        self.w1 = nn.Parameter(
+        self.w1 = nn.Parameter(  # num exp, expert_dim, hidden_dim
             torch.empty(config.num_experts, config.intermediate_size, config.dim)
         )  # E, I, D
         self.w2 = nn.Parameter(
@@ -347,6 +347,14 @@ def __init__(self, config):
             torch.empty(config.num_experts, config.intermediate_size, config.dim)
         )  # E, I, D
         self.num_experts = config.num_experts
+        self.perf_is_optimized = False
+
+    def optimize_perf(self):
+        self.w13 = torch.cat((self.w1, self.w3), dim=1)
+        self.w13 = torch.nn.Parameter(self.w13.transpose(-2, -1).contiguous().transpose(-2, -1))
+        self.w2 = torch.nn.Parameter(self.w2.transpose(-2, -1).contiguous().transpose(-2, -1))
+        del self.w1, self.w3
+        self.perf_is_optimized = True

     def forward(
         self,
@@ -355,8 +363,60 @@ def forward(
         expert_weights: Tensor,  # T, A
         num_activated_experts: int,
     ) -> Tensor:
+
+
         num_tokens, dim = x.shape
-        num_token_activations = num_tokens * num_activated_experts
+        num_token_activations = expert_indices.numel()
+
+
+        ordered_token_activations = expert_indices.view(-1).argsort(stable=True)
+        ordered_token_indices = (
+            ordered_token_activations.div(num_activated_experts)
+            .floor()
+            .to(torch.int32)
+        )  # [T]
+
+        indices_for_histc = expert_indices.view(-1) if expert_indices.is_cuda else expert_indices.float().view(-1)  # histc doesn't work on cpu for integers
+        num_tokens_per_expert = torch.histc(
+            indices_for_histc,
+            bins=self.num_experts,
+            min=0,
+            max=self.num_experts,
+        )
+        offs = num_tokens_per_expert.cumsum(dim=0).to(torch.int32)
+        ordered_inputs = x[ordered_token_indices]
+
+        if self.optimized_perf:
+            x1, x3 = torch._grouped_mm(ordered_inputs, self.w13.transpose(-2, -1), offs).split(self.config.intermediate_size, dim=1)
+            y1 = F.silu(x1) * x3
+        else:
+            x1 = F.silu(torch._grouped_mm(ordered_inputs, self.w1.transpose(-2, -1), offs))
+            x3 = torch._grouped_mm(ordered_inputs, self.w3.transpose(-2, -1), offs)
+            y1 = x1 * x3
+        ordered_outs = torch._grouped_mm(y1, self.w2.transpose(-2, -1), offs)
+        # ordered_outs = torch._grouped_mm(y1, self.w2, offs)
+
+        ordered_token_activation_weights = expert_weights.view(-1, 1)[
+            ordered_token_activations
+        ].view(-1, 1)  # [T*A, 1]
+        weighted_ordered_outs = (
+            ordered_outs * ordered_token_activation_weights
+        )  # [T*A, D]
+
+        # sum weighted token-activation outputs together for each token
+        final_out = torch.zeros_like(x)  # [T, D]
+        final_out = final_out.scatter_add(
+            dim=0,
+            index=ordered_token_indices.unsqueeze(-1)
+            .expand(num_token_activations, dim)
+            .to(torch.int64),
+            src=weighted_ordered_outs,
+        )
+
+        return final_out
+
+
+
         if x.shape[0] == 1 and not isinstance(
             self.w1, FakeExtraDimTensor
         ):  # only 1 token (can be done without graph breaks when compiled)
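The new forward above does its routing purely through index bookkeeping: sort token-activations so each expert's rows are contiguous, derive per-expert offsets for the grouped matmul, then weight each activation's output and scatter-add it back to its source token. Here is a small CPU-runnable sketch of just that bookkeeping, with a plain per-expert loop standing in for torch._grouped_mm; the shapes and the reference loop are illustrative and not part of the commit.

import torch

T, A, E, D = 4, 2, 3, 8     # tokens, active experts per token, experts, dim (illustrative)
x = torch.randn(T, D)
expert_indices = torch.tensor([[0, 2], [1, 2], [0, 1], [2, 0]])   # router picks, [T, A]
expert_weights = torch.rand(T, A)
expert_weights = expert_weights / expert_weights.sum(-1, keepdim=True)
w = torch.randn(E, D, D)    # stand-in for one weight matrix per expert

# same bookkeeping as ConditionalFeedForwardAOQuantizable.forward
order = expert_indices.view(-1).argsort(stable=True)     # token-activations grouped by expert
token_idx = order // A                                   # source token of each activation
counts = torch.histc(expert_indices.view(-1).float(), bins=E, min=0, max=E)
offs = counts.cumsum(0).to(torch.int32)                  # group end offsets, one per expert

ordered_in = x[token_idx]                                # [T*A, D], rows contiguous per expert
# reference for what the grouped matmul computes: one mm per expert over its slice
ordered_out = torch.empty_like(ordered_in)
start = 0
for e in range(E):
    end = int(offs[e])
    ordered_out[start:end] = ordered_in[start:end] @ w[e].T
    start = end

# weight each activation's output, then scatter-add back into token order
act_w = expert_weights.view(-1, 1)[order]                # [T*A, 1]
out = torch.zeros_like(x).scatter_add(
    0, token_idx.unsqueeze(-1).expand(-1, D), ordered_out * act_w
)
print(out.shape)  # torch.Size([4, 8]) -- one combined output row per token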
