
Commit 3626393

[mxfp8] add usage documentation and benchmarks
1 parent b39377f commit 3626393

3 files changed: +190, -0 lines changed


README.md

Lines changed: 1 addition & 0 deletions
@@ -59,6 +59,7 @@ To accelerate contributions to and innovations around torchtitan, we host an [`e
 - [Interoperable checkpoints](docs/checkpoint.md) which can be loaded directly into [`torchtune`](https://github.com/pytorch/torchtune) for fine-tuning
 5. `torch.compile` support
 6. [Float8](https://discuss.pytorch.org/t/distributed-w-torchtitan-enabling-float8-all-gather-in-fsdp2/209323) support ([how-to](docs/float8.md))
+7. [MXFP8 training for dense and MoE models](docs/mxfp8.md) on Blackwell GPUs.
 7. DDP and HSDP
 8. [TorchFT](https://github.com/pytorch/torchft) integration
 9. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries) and support for [custom datasets](docs/datasets.md)

assets/images/mxfp8_with_loss.png

45.9 KB

docs/mxfp8.md

Lines changed: 189 additions & 0 deletions
## MXFP8 Training on B200 GPUs

MXFP8 training can provide substantial training speedups for models where the majority of GEMMs are sufficiently large. MXFP8 is a microscaling format from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) that uses block-based scaling to maintain numerical accuracy while leveraging low-precision tensor cores. On NVIDIA B200 GPUs, MXFP8 training achieves up to **28% speedup** over the bfloat16 baseline with minimal accuracy degradation.

> **📖 For a comprehensive case study of using TorchTitan MXFP8 to train dense models at scale**, see our blog post: [Accelerating 2K+ Scale Pre-training up to 1.28x with TorchAO MXFP8 and TorchTitan on Crusoe B200 Cluster](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/)

### Table of Contents
- [Requirements](#requirements)
- [How MXFP8 Works](#how-mxfp8-works)
- [MXFP8 for Linear Modules](#mxfp8-for-linear-modules)
  - [Usage](#usage)
- [MXFP8 for Grouped GEMMs (MoE)](#mxfp8-for-grouped-gemms-moe)
  - [Usage](#usage-1)
- [Example TOML Configuration](#example-toml-configuration)
- [Performance](#performance)
  - [Dense Models](#dense-models)
  - [MoE models](#moe-models)
- [Composability](#composability)
- [Known Limitations](#known-limitations)
- [Additional Resources](#additional-resources)

### Requirements

- NVIDIA B200 (SM100 or SM100a)
- PyTorch nightly
- TorchAO v0.14.0 or newer ([TorchAO Installation Guide](https://github.com/pytorch/ao#installation))

Note: GB200 is also supported but requires building torchao from source (see installation guide above).
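As a quick pre-flight check (illustrative only, not part of torchtitan), you can confirm the GPU architecture and library versions from Python before launching a run:

```python
# Illustrative environment check for MXFP8 runs; not part of torchtitan.
import torch
import torchao

major, minor = torch.cuda.get_device_capability()
print(f"compute capability: sm{major}{minor}")       # expect sm100 on B200 / GB200
print(f"torch version:      {torch.__version__}")    # expect a recent nightly
print(f"torchao version:    {torchao.__version__}")  # expect >= 0.14.0
assert (major, minor) >= (10, 0), "MXFP8 kernels require Blackwell (SM100) or newer"
```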
### How MXFP8 Works

MXFP8 differs from standard Float8 training in its scaling approach:

- **Block-based scaling**: Instead of using a single scale factor per tensor (tensorwise) or per row/column (rowwise), MXFP8 uses block-based scaling with a default block size of 1x32 elements. Each block of 32 elements shares a common scale factor. The data dtype is `torch.float8_e4m3fn`, and the scale factor dtype is `torch.float8_e8m0fnu` (see the sketch after this list).
- **Native hardware support**: On NVIDIA B200 (Blackwell) GPUs, MXFP8 GEMMs are accelerated using cuBLAS kernels exposed via `torch._scaled_mm`, achieving up to 2x speedup over bfloat16 on common shapes.
- **Dynamic quantization**: Both activations and weights are dynamically quantized to MXFP8 during forward and backward passes, with high-precision accumulation.
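To make the block-based scaling concrete, here is a simplified, self-contained sketch of 1x32 block quantization in plain PyTorch. It only illustrates the scheme described above and is not the torchao/cuBLAS implementation; the real kernels store scales in `torch.float8_e8m0fnu` and handle edge cases this sketch glosses over:

```python
import torch

F8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def mxfp8_quantize_blocks(x: torch.Tensor, block_size: int = 32):
    """Simplified sketch: one shared power-of-two scale per 1x32 block of elements.

    Assumes x.numel() is divisible by block_size. Real MXFP8 kernels store the scale
    in torch.float8_e8m0fnu and run as fused Triton/CUDA/cuBLAS code.
    """
    blocks = x.float().reshape(-1, block_size)
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=2**-127)
    # Shared power-of-two scale maps the block max into float8_e4m3fn's range (emax = 8).
    scale = torch.exp2(torch.floor(torch.log2(amax)) - 8)
    q = (blocks / scale).clamp(-F8_MAX, F8_MAX).to(torch.float8_e4m3fn)  # 1 byte per element
    return q.reshape(x.shape), scale                                     # + 1 scale per 32 elements

def mxfp8_dequantize_blocks(q: torch.Tensor, scale: torch.Tensor, block_size: int = 32):
    """Inverse of the sketch above, back to float32 for comparison."""
    return (q.float().reshape(-1, block_size) * scale).reshape(q.shape)

x = torch.randn(4, 128)
q, s = mxfp8_quantize_blocks(x)
print((x - mxfp8_dequantize_blocks(q, s)).abs().max())  # small block-quantization error
```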
### MXFP8 for Linear Modules

#### Usage

To enable MXFP8 training for linear layers, launch your training job with the following command (or set the equivalent options in your TOML config file):

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \
--model.converters="quantize.linear.mx" \
--quantize.linear.mx.recipe_name="mxfp8_cublas" \
--compile.enable
```
**Configuration Options:**

* `--model.converters="quantize.linear.mx"`: Swap `nn.Linear` with `MXLinear` to perform MXFP8 matmuls.
* `--quantize.linear.mx.recipe_name="mxfp8_cublas"`: Use the cuBLAS-based MXFP8 recipe for best performance on B200 GPUs. Alternative: `"mxfp8_cublas_rceil"` uses round-ceiling mode for scale calculation.
* `--quantize.linear.mx.mxfp8_dim1_cast_kernel_choice="triton"`: Choose the kernel for dimension-1 quantization. Options: `"triton"` (default), `"cuda"`, or `"torch"`.
* `--quantize.linear.mx.filter_fqns="..."` (optional): Comma-separated list of fully qualified names of modules to exclude from MXFP8 conversion.
  * Example: `--quantize.linear.mx.filter_fqns="attention.wq,attention.wk,attention.wv,output"`
  * This allows you to selectively apply MXFP8 only to the layers that will benefit from it.
* `--compile.enable` (required for competitive performance): Use `torch.compile` to fuse the MXFP8 scaling/casting kernels.
**Hardware Requirements:**

MXFP8 training requires NVIDIA B200 (SM100) or newer GPUs. The implementation uses native cuBLAS MXFP8 kernels available on these architectures.
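Under the hood, the `quantize.linear.mx` converter swaps eligible `nn.Linear` modules for `MXLinear` during model setup, so in torchtitan you never call torchao directly. If you want to reproduce a similar conversion outside torchtitan, the torchao prototype API looks roughly like the sketch below. This is a best-effort illustration of the `torchao.prototype.mx_formats` prototype: the import paths, `MXLinearConfig.from_recipe_name`, and the `filter_fn` contract are assumptions that may shift between torchao releases.

```python
import torch
import torch.nn as nn

# Prototype torchao APIs; names and import paths may change while MX support is in prototype.
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import MXLinearConfig

model = nn.Sequential(
    nn.Linear(8192, 8192, bias=False),
    nn.Linear(8192, 8192, bias=False),
).to(device="cuda", dtype=torch.bfloat16)

# Mirrors --quantize.linear.mx.recipe_name="mxfp8_cublas". If your torchao version lacks
# from_recipe_name, construct MXLinearConfig with the equivalent fields directly.
config = MXLinearConfig.from_recipe_name("mxfp8_cublas")

# Mirrors --quantize.linear.mx.filter_fqns: only convert nn.Linear modules whose FQN is
# not in the exclusion set (no FQN in this toy model matches the set).
def should_convert(module: nn.Module, fqn: str) -> bool:
    return isinstance(module, nn.Linear) and fqn not in {"output", "router.gate"}

quantize_(model, config, filter_fn=should_convert)
model = torch.compile(model)  # fuse the MXFP8 casting/scaling ops for competitive performance
```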
### MXFP8 for Grouped GEMMs (MoE)

For Mixture-of-Experts (MoE) models, MXFP8 can accelerate the expert computation through dynamically quantized grouped GEMMs. This is particularly beneficial for MoE models where multiple experts are processed in parallel.
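To picture what the grouped GEMM computes, the sketch below is a loop-based, pure-PyTorch reference for the routed-expert computation with 3D expert weights (illustrative shapes and names only; the actual path runs a single fused grouped GEMM kernel, with activations and expert weights dynamically quantized to MXFP8):

```python
import torch

def grouped_mm_reference(x: torch.Tensor, w: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
    """Loop-based reference for a routed-expert grouped GEMM.

    x:    (total_tokens, K)   tokens permuted so each expert's tokens are contiguous
    w:    (num_experts, K, N) expert weights stored as a single 3D parameter
    offs: (num_experts,)      cumulative token counts, i.e. the end offset of each group
    """
    out = x.new_empty(x.shape[0], w.shape[-1])
    start = 0
    for expert_idx, end in enumerate(offs.tolist()):
        out[start:end] = x[start:end] @ w[expert_idx]  # one GEMM per expert's token group
        start = end
    return out

num_experts, K, N = 4, 64, 128
x = torch.randn(100, K, dtype=torch.bfloat16)
w = torch.randn(num_experts, K, N, dtype=torch.bfloat16)
offs = torch.tensor([30, 50, 90, 100])  # per-expert token counts of 30, 20, 40, 10
print(grouped_mm_reference(x, w, offs).shape)  # torch.Size([100, 128])
```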
#### Usage

To enable MXFP8 for MoE expert layers:

```bash
CONFIG_FILE="./torchtitan/models/llama4/train_configs/llama4_17bx16e.toml" ./run_train.sh \
--model.converters="quantize.grouped_mm.mx" \
--quantize.grouped_mm.mx.fqns="experts" \
--quantize.grouped_mm.mx.recipe_name="mxfp8" \
--compile.enable \
--model.print_after_conversion
```
**Combined usage**: You can use MXFP8 for both linear modules and grouped GEMMs simultaneously by specifying both converters:

```bash
--model.converters="quantize.linear.mx,quantize.grouped_mm.mx"
```
**Configuration Options:**

* `--model.converters="quantize.grouped_mm.mx"`: Enable MXFP8 grouped GEMM conversion for MoE layers.
* `--quantize.grouped_mm.mx.fqns="experts"`: Comma-separated list of fully qualified names of the MoE modules whose grouped GEMM operations should use MXFP8 dynamic quantization. A module matching an FQN is converted if (1) its experts are represented as 3D `nn.Parameter` instances (which is the case for TorchTitan MoEs), and (2) a `torch._grouped_mm` op performs the actual routed expert computation using those 3D expert weights.
  * You can specify multiple FQNs to target different MoE layers in your model.
* `--quantize.grouped_mm.mx.recipe_name="mxfp8"`: Quantization recipe for grouped GEMMs (currently only `"mxfp8"` is supported).
* `--compile.enable`: Use `torch.compile` for best performance.
**Important Notes:**

* **Token group alignment**: For MoE training with MXFP8, token group sizes must be multiples of 32 (the MXFP8 block size); see the illustration after these notes. This is automatically configured [here](https://github.com/pytorch/torchtitan/blob/b39377f9fe33865fefb9bf64a33f6d74a598be87/torchtitan/components/quantization/mx.py#L131) when you enable MXFP8 grouped GEMMs in TorchTitan.

* **torch.compile recommendation**: All benchmarks in this document were run with `torch.compile` enabled. We recommend using `torch.compile` for best performance.
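The alignment requirement is easy to picture: the per-expert token counts that define the grouped GEMM group boundaries are rounded up to the next multiple of the 32-element block size. A hypothetical helper, for illustration only (torchtitan applies this alignment automatically via the setting linked above):

```python
import torch

BLOCK_SIZE = 32  # MXFP8 scaling block size

def pad_group_sizes(tokens_per_expert: torch.Tensor, multiple: int = BLOCK_SIZE) -> torch.Tensor:
    """Round each expert's token count up to a multiple of the MXFP8 block size.

    Illustrative only; torchtitan configures this automatically when the MXFP8
    grouped GEMM converter is enabled.
    """
    return ((tokens_per_expert + multiple - 1) // multiple) * multiple

tokens_per_expert = torch.tensor([30, 20, 40, 10])
padded = pad_group_sizes(tokens_per_expert)
print(padded)                       # tensor([32, 32, 64, 32])
print(torch.cumsum(padded, dim=0))  # tensor([ 32,  64, 128, 160]) -> every boundary is 32-aligned
```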
### Example TOML Configuration

Here's an example configuration for MXFP8 training in a TOML file:

```toml
[model]
converters = ["quantize.linear.mx", "quantize.grouped_mm.mx"]

[quantize.linear.mx]
recipe_name = "mxfp8_cublas"
mxfp8_dim1_cast_kernel_choice = "cuda"
filter_fqns = ["output", "router.gate"]

[quantize.grouped_mm.mx]
recipe_name = "mxfp8"
fqns = ["experts"]

[compile]
enable = true
components = ["model"]
```
### Performance

#### Dense Models

Single-node training on 8x power-limited B200 GPUs, batch size 1, sequence length 8192, steps 100, `torch.compile`, FSDP2, per-op SAC:

| Scaling Method | Peak Memory (GB) | Median tokens/s | Speedup over BF16 |
|------------------------|------------------|-----------------|-------------------|
| None (bfloat16) | 33.71 | 8307.5 | - |
| mxfp8_cublas | 33.88 | 9969.0 | +20.0% |
| mxfp8_cublas_rceil | 33.88 | 9642.0 | +16.1% |
| float8 tensorwise | 33.38 | 10417.0 | +25.4% |

- pytorch version: `2.9.0.dev20250815+cu128`
- torchao version: `0.13.0+gite4e681be`
- torchtitan commit: `6fc499f6f5b32151a799188be2208cfb09faed30`

*Source: [TorchAO MX Formats Benchmarks](https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats#training-e2e-benchmarks-on-nvidia-b200)*
#### MoE models

512-GPU training on a 64-node GB200 cluster:

| Scaling Method | Median tokens/s | Speedup over BF16 |
|------------------------|-----------------|-------------------|
| None (bfloat16) | 6169 | - |
| mxfp8 | 7401 | +20.3% |
Training runs on the 64-node GB200 cluster with TorchTitan Llama4 Scout show that MXFP8 MoE training has equivalent convergence to the bfloat16 training baseline. In fact, after 3,000 steps it finishes with slightly *lower* loss than bfloat16! This is consistent with our scaling experiments with [MXFP8 training for dense models](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/).
![MXFP8 vs BF16 Training Loss Curves](../assets/images/mxfp8_with_loss.png)
*Training loss curves over 3,000 steps showing MXFP8 achieves equivalent convergence to the bfloat16 baseline.*
Training and model configurations for this run:
- Model: Llama4 Scout
- Dataset: C4
- Sequence length: 8192
- Local batch size: 10
- Learning rate: 1e-4
- LR scheduler warmup steps: 2000
- Parallelisms (64 nodes of 4 devices each = 256 chips):
  - FSDP=256 (on attention layers, shared experts, dense layer FFNs) and 256/4=64 (on routed experts)
  - EP=16 (on routed experts)
- Activation checkpointing mode: `none` (ideally this would use selective per-op AC, but a bug at the time prevented us from using it).
- `torch.compile` enabled
- `mxfp8` applied to routed experts computation (grouped GEMMs)
- `mxfp8` applied to all linear layers except `output`, `router.gate`, `attention.wk`, and `attention.wv` (Wk and Wv are too small to benefit from MXFP8)
### Composability

For distributed training, MXFP8 is compatible with:
- `torch.compile`
- FSDP2/TP/EP/PP
- Full activation checkpointing

All distributed communication for MXFP8 training is currently done in high precision.
### Known Limitations

- Currently in the prototype stage; no backwards compatibility guarantees.
- Requires PyTorch nightly; important bug fixes have landed since the 2.9.1 release.
- For GB200, torchao must be built from source (see Requirements above).
### Additional Resources

- [Accelerating 2K+ Scale Pre-training up to 1.28x with TorchAO MXFP8 and TorchTitan on Crusoe B200 Cluster](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/) - Blog post on accelerating dense model training with MXFP8
- [TorchAO MX Formats Documentation](https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats)
- [TorchAO MoE Training Documentation](https://github.com/pytorch/ao/tree/main/torchao/prototype/moe_training)
