|
| 1 | +## MXFP8 Training on B200 GPUs |
| 2 | + |
| 3 | +MXFP8 training can provide substantial training speedups for models where the majority of GEMMs are sufficiently large. MXFP8 is a microscaling format from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) that uses block-based scaling to maintain numerical accuracy while leveraging low-precision tensor cores. On NVIDIA B200 GPUs, MXFP8 training achieves up to **28% speedup** over bfloat16 baseline with minimal accuracy degradation. |
| 4 | + |
| 5 | +> **📖 For a comprehensive case study of using TorchTitan MXFP8 to train dense models at scale**, see our blog post: [Accelerating 2K+ Scale Pre-training up to 1.28x with TorchAO MXFP8 and TorchTitan on Crusoe B200 Cluster](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/) |
| 6 | +
|
| 7 | +### Table of Contents |
| 8 | + |
| 9 | +- [Requirements](#requirements) |
| 10 | +- [How MXFP8 Works](#how-mxfp8-works) |
| 11 | +- [MXFP8 for Linear Modules](#mxfp8-for-linear-modules) |
| 12 | + - [Usage](#usage) |
| 13 | +- [MXFP8 for Grouped GEMMs (MoE)](#mxfp8-for-grouped-gemms-moe) |
| 14 | + - [Usage](#usage-1) |
| 15 | +- [Example TOML Configuration](#example-toml-configuration) |
| 16 | +- [Performance](#performance) |
| 17 | + - [Dense Models](#dense-models) |
| 18 | + - [MoE models](#moe-models) |
| 19 | +- [Composability](#composability) |
| 20 | +- [Known Limitations](#known-limitations) |
| 21 | +- [Additional Resources](#additional-resources) |
| 22 | + |
| 23 | +### Requirements |
| 24 | + |
| 25 | +- NVIDIA B200 (SM100 or SM100a) |
| 26 | +- PyTorch nightly |
| 27 | +- TorchAO v0.14.0 or newer ([TorchAO Installation Guide](https://github.com/pytorch/ao#installation)) |
| 28 | + |
| 29 | +Note: GB200 is also supported but requires building torchao from source (see installation guide above). |
| 30 | + |
| 31 | +### How MXFP8 Works |
| 32 | + |
| 33 | +MXFP8 differs from standard Float8 training in its scaling approach: |
| 34 | + |
| 35 | +- **Block-based scaling**: Instead of using a single scale factor per tensor (tensorwise) or per row/column (rowwise), MXFP8 uses block-based scaling with a default block size of 1x32 elements. Each block of 32 elements shares a common scale factor. The data dtype is `torch.float8_e4m3fn`, and the scale factor dtype is `torch.float8_e8mfnu`. |
| 36 | +- **Native hardware support**: On NVIDIA B200 (Blackwell) GPUs, MXFP8 GEMMs are accelerated using cuBLAS kernels exposed via `torch._scaled_mm`, achieving up to 2x speedup over bfloat16 on common shapes. |
| 37 | +- **Dynamic quantization**: Both activations and weights are dynamically quantized to MXFP8 during forward and backward passes, with high-precision accumulation. |
| 38 | + |
| 39 | +### MXFP8 for Linear Modules |
| 40 | + |
| 41 | +#### Usage |
| 42 | + |
| 43 | +To enable MXFP8 training for linear layers, launch your training job with the following command (or alternatively set configs in toml files): |
| 44 | + |
| 45 | +```bash |
| 46 | +CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \ |
| 47 | + --model.converters="quantize.linear.mx" \ |
| 48 | + --quantize.linear.mx.recipe_name="mxfp8_cublas" \ |
| 49 | + --compile.enable |
| 50 | +``` |
| 51 | + |
| 52 | +**Configuration Options:** |
| 53 | + |
| 54 | +* `--model.converters="quantize.linear.mx"`: Swap `nn.Linear` with `MXLinear` to perform MXFP8 matmul. |
| 55 | +* `--quantize.linear.mx.recipe_name="mxfp8_cublas"`: Use the cuBLAS-based MXFP8 recipe for best performance on B200 GPUs. Alternative: `"mxfp8_cublas_rceil"` uses round-ceiling mode for scale calculation. |
| 56 | +* `--quantize.linear.mx.mxfp8_dim1_cast_kernel_choice="triton"`: Choose the kernel for dimension-1 quantization. Options: `"triton"` (default), `"cuda"`, or `"torch"`. |
| 57 | +* `--quantize.linear.mx.filter_fqns="..."` (optional): Comma-separated list of fully qualified names of modules not to convert to MXFP8 training. |
| 58 | + * Example: `--quantize.linear.mx.filter_fqns="attention.wq,attention.wk,attention.wv,output"` |
| 59 | + * This allows you to selectively apply MXFP8 only to layers that will benefit from it. |
| 60 | +* `--compile.enable` (required for competitive performance): Use `torch.compile` to fuse the MXFP8 scaling/casting kernels. |
| 61 | + |
| 62 | +**Hardware Requirements:** |
| 63 | + |
| 64 | +MXFP8 training requires NVIDIA B200 (SM100) or newer GPUs. The implementation uses native cuBLAS MXFP8 kernels available on these architectures. |
| 65 | + |
| 66 | +### MXFP8 for Grouped GEMMs (MoE) |
| 67 | + |
| 68 | +For Mixture-of-Experts (MoE) models, MXFP8 can accelerate the expert computation through dynamically quantized grouped GEMMs. This is particularly beneficial for MoE models where multiple experts are processed in parallel. |
| 69 | + |
| 70 | +#### Usage |
| 71 | + |
| 72 | +To enable MXFP8 for MoE expert layers: |
| 73 | + |
| 74 | +```bash |
| 75 | +CONFIG_FILE="./torchtitan/models/llama4/train_configs/llama4_17bx16e.toml" ./run_train.sh \ |
| 76 | + --model.converters="quantize.grouped_mm.mx" \ |
| 77 | + --quantize.grouped_mm.mx.fqns="experts" \ |
| 78 | + --quantize.grouped_mm.mx.recipe_name="mxfp8" \ |
| 79 | + --compile.enable \ |
| 80 | + --model.print_after_conversion |
| 81 | +``` |
| 82 | + |
| 83 | +**Combined usage**: You can use MXFP8 for both linear modules and grouped GEMMs simultaneously by specifying both converters: |
| 84 | + ```bash |
| 85 | + --model.converters="quantize.linear.mx,quantize.grouped_mm.mx" |
| 86 | + ``` |
| 87 | + |
| 88 | +**Configuration Options:** |
| 89 | + |
| 90 | +* `--model.converters="quantize.grouped_mm.mx"`: Enable MXFP8 grouped GEMM conversion for MoE layers. |
| 91 | +* `--quantize.grouped_mm.mx.fqns="experts"`: Comma-separated list of fully qualified names of MoE modules to apply MXFP8 dynamic quantization on grouped GEMM operations. Any module that matches the FQN will be converted, if it has (1) experts represented as 3d nn.Parameter instances (which is the case for TorchTitan MoEs), and (2) a `torch._grouped_mm` op performs the actual routed expert computation using those 3d expert weights. |
| 92 | + * You can specify multiple FQNs to target different MoE layers in your model. |
| 93 | +* `--quantize.grouped_mm.mx.recipe_name="mxfp8"`: Quantization recipe for grouped GEMMs (currently only `"mxfp8"` is supported). |
| 94 | +* `--compile.enable`: Use `torch.compile` for best performance. |
| 95 | + |
| 96 | +**Important Notes:** |
| 97 | + |
| 98 | +* **Token group alignment**: For MoE training with MXFP8, token group sizes must be multiples of 32 (the MXFP8 block size). This is automatically configured [here](https://github.com/pytorch/torchtitan/blob/b39377f9fe33865fefb9bf64a33f6d74a598be87/torchtitan/components/quantization/mx.py#L131) when you enable MXFP8 grouped GEMMs in TorchTitan. |
| 99 | + |
| 100 | +* **torch.compile recommendation**: All benchmarks in this document were run with `torch.compile` enabled. We recommend using `torch.compile` for best performance. |
| 101 | + |
| 102 | +### Example TOML Configuration |
| 103 | + |
| 104 | +Here's an example configuration for MXFP8 training in a TOML file: |
| 105 | + |
| 106 | +```toml |
| 107 | +[model] |
| 108 | +converters = ["quantize.linear.mx", "quantize.grouped_mm.mx"] |
| 109 | + |
| 110 | +[quantize.linear.mx] |
| 111 | +recipe_name = "mxfp8_cublas" |
| 112 | +mxfp8_dim1_cast_kernel_choice = "cuda" |
| 113 | +filter_fqns = ["output", "router.gate"] |
| 114 | + |
| 115 | +[quantize.grouped_mm.mx] |
| 116 | +recipe_name = "mxfp8" |
| 117 | +fqns = ["experts"] |
| 118 | + |
| 119 | +[compile] |
| 120 | +enable = true |
| 121 | +components = ["model"] |
| 122 | +``` |
| 123 | + |
| 124 | +### Performance |
| 125 | + |
| 126 | +#### Dense Models |
| 127 | + |
| 128 | +Single-node training on 8x power limited B200 GPUs, batch size 1, sequence length 8192, steps 100, torch.compile, FSDP2, per-op SAC: |
| 129 | + |
| 130 | +| Scaling Method | Peak Memory (GB) | Median tokens/s | Speedup over BF16 | |
| 131 | +|------------------------|------------------|-----------------|-------------------| |
| 132 | +| None (bfloat16) | 33.71 | 8307.5 | - | |
| 133 | +| mxfp8_cublas | 33.88 | 9969.0 | +20.0% | |
| 134 | +| mxfp8_cublas_rceil | 33.88 | 9642.0 | +16.1% | |
| 135 | +| float8 tensorwise | 33.38 | 10417.0 | +25.4% | |
| 136 | + |
| 137 | +- pytorch version: `2.9.0.dev20250815+cu128` |
| 138 | +- torchao version: `0.13.0+gite4e681be` |
| 139 | +- torchtitan commit: `6fc499f6f5b32151a799188be2208cfb09faed30` |
| 140 | + |
| 141 | +*Source: [TorchAO MX Formats Benchmarks](https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats#training-e2e-benchmarks-on-nvidia-b200)* |
| 142 | + |
| 143 | +#### MoE models |
| 144 | + |
| 145 | +512 GPU training on 64 node GB200 cluster: |
| 146 | + |
| 147 | +| Scaling Method | Median tokens/s | Speedup over BF16 | |
| 148 | +|------------------------|-----------------|-------------------| |
| 149 | +| None (bfloat16) | 6169 | - | |
| 150 | +| mxfp8 | 7401 | +20.3% | |
| 151 | + |
| 152 | +Training runs on 64 node GB200 cluster with TorchTitan Llama4 Scout show that MXFP8 MoE training has equivalent convergence to bfloat16 training baseline. In fact, after 3,000 steps it finishes with slightly *lower* loss than bfloat16! This is consistent with our scaling experiments with [MXFP8 training for dense models](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/). |
| 153 | + |
| 154 | + |
| 155 | +*Training loss curves over 3,000 steps showing MXFP8 achieves equivalent convergence to bfloat16 baseline.* |
| 156 | + |
| 157 | +Training and model configurations for this run: |
| 158 | +- Model: Llama4 Scout |
| 159 | +- Dataset: C4 |
| 160 | +- Sequence length: 8192 |
| 161 | +- Local batch size: 10 |
| 162 | +- Learning rate: 1e-4 |
| 163 | +- LR scheduler warmup steps: 2000 |
| 164 | +- Parallelisms (64 nodes of 4 devices each = 256 chips): |
| 165 | + - FSDP=256 (on attention layers, shared experts, dense layer FFNs) and 256/4=64 (on routed experts) |
| 166 | + - EP=16 (on routed experts) |
| 167 | +- Activation checkpointing mode: `none` (ideally this should use selective per op AC but there was a bug at the time preventing us from using it). |
| 168 | +- `torch.compile` enabled |
| 169 | +- `mxfp8` applied to routed experts computation (grouped GEMMs) |
| 170 | +- `mxfp8` applied to all linear layers except: `output`, `router.gate`, `attention.wk`, `attention.wv` (Wk and Wv too small to benefit from mxfp8) |
| 171 | + |
| 172 | +### Composability |
| 173 | +For distributed training, MXFP8 is compatible with: |
| 174 | +- `torch.compile` |
| 175 | +- FSDP2/TP/EP/PP |
| 176 | +- Full activation checkpointing |
| 177 | + |
| 178 | +All distributed communication for MXFP8 training is currently done in high precision. |
| 179 | + |
| 180 | +### Known Limitations |
| 181 | +- Currently in prototype stage - no BC guarantees. |
| 182 | +- Requires torch nightly - important bug fixes have landed since 2.9.1 |
| 183 | +- For GB200s, requires building torchao from source |
| 184 | + |
| 185 | +### Additional Resources |
| 186 | + |
| 187 | +- [Accelerating 2K+ Scale Pre-training up to 1.28x with TorchAO MXFP8 and TorchTitan on Crusoe B200 Cluster](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/) - Blog post on accelerating dense model training with MXFP8 |
| 188 | +- [TorchAO MX Formats Documentation](https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats) |
| 189 | +- [TorchAO MoE Training Documentation](https://github.com/pytorch/ao/tree/main/torchao/prototype/moe_training) |
0 commit comments