# Supported Models & Architectures

> **Purpose**: This page provides detailed, reference-style information about the model families supported in MaxText. It serves as a technical dictionary for quick lookup, reproducibility, and customization.

## Overview

MaxText is an open-source, high-performance LLM framework written in Python/JAX. It targets Google Cloud TPUs and NVIDIA GPUs for training. MaxText prioritizes scalability (from a single host to recent runs with tens of thousands of chips), high Model FLOPs Utilization (MFU), and simplicity by leveraging JAX with the XLA compiler and optimized JAX Pallas kernels.

**Key capabilities and features**:

* **Supported Precisions**: FP32, BF16, INT8, and FP8.
* **Ahead-of-Time Compilation (AOT)**: For faster model development/prototyping and earlier OOM detection (see the sketch after this list).
* **Quantization**: Via **Qwix** (recommended) and AQT. See the Quantization [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/explanations/quantization.md).
* **Diagnostics**: Structured error context via **`cloud_tpu_diagnostics`** (filters stack traces to user code), simple logging via `max_logging`, profiling in **XProf**, and visualization in **TensorBoard**.
* **Multi-Token Prediction (MTP)**: Enables token-efficient training by predicting multiple future tokens.
* **Elastic Training**: Fault-tolerant training with dynamic scale-up/scale-down on Cloud TPUs with Pathways.
* **Flexible Remat Policy**: Provides fine-grained control over memory-compute trade-offs. Users can select pre-defined policies (such as 'full' or 'minimal') or set the policy to **'custom'** to define their own.
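
The AOT idea can be pictured with plain JAX: a step function is lowered and compiled against abstract input shapes, so shape mismatches and compile-time memory problems surface before any real training run. Below is a minimal, generic JAX sketch; the `train_step`, shapes, and loss are hypothetical placeholders, not MaxText's actual training step or its own AOT tooling.

```python
# Minimal JAX illustration of ahead-of-time (AOT) compilation: the step function
# is lowered and compiled against abstract shapes, so shape errors and
# compile-time OOM issues surface before any training step runs.
import jax
import jax.numpy as jnp

def train_step(params, batch):
  # Hypothetical loss: a single linear layer with a mean-squared-error loss.
  preds = batch["x"] @ params["w"]
  return jnp.mean((preds - batch["y"]) ** 2)

# Abstract shapes only -- no real tensors are materialized.
params = {"w": jax.ShapeDtypeStruct((4096, 4096), jnp.bfloat16)}
batch = {
    "x": jax.ShapeDtypeStruct((8, 4096), jnp.bfloat16),
    "y": jax.ShapeDtypeStruct((8, 4096), jnp.bfloat16),
}

compiled = jax.jit(train_step).lower(params, batch).compile()
print(compiled.cost_analysis())  # FLOPs/bytes estimates (availability varies by backend)
```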
## Supported Model Families

> _**Note on GPU Coverage**: Support and tested configurations for NVIDIA GPUs can vary by model family. Please see the specific model guides for details._

**Primary Platforms**: All model families listed below target **TPU** and **NVIDIA GPUs**.

### Llama

* **Variants**: Llama 2; **Llama 3 / 3.1 / 3.3**; Llama 4 (**Scout**, **Maverick**; text & multimodal)
* **Notes**: RoPE, RMSNorm, SwiGLU; GQA; routed experts (Llama 4); **QK-Norm** (Llama 4); multimodal projector & vision encoder.
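
To make these building blocks concrete, here is a minimal JAX sketch of RMSNorm and a SwiGLU feed-forward block as they commonly appear in Llama-style models; the tensor names and shapes are illustrative, not MaxText's layer code.

```python
# Minimal JAX sketch of two Llama-style building blocks (illustrative only).
import jax.numpy as jnp
from jax import nn

def rms_norm(x, scale, eps=1e-6):
  # Normalize by the root-mean-square of the features, then apply a learned scale.
  rms = jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)
  return (x / rms) * scale

def swiglu_ffn(x, w_gate, w_up, w_down):
  # SwiGLU: a SiLU-gated linear unit followed by a down-projection.
  return (nn.silu(x @ w_gate) * (x @ w_up)) @ w_down
```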
### Mistral / Mixtral

* **Variants**: Mistral (dense); Mixtral 8×7B, 8×22B (MoE)
* **Notes**: Sliding-Window Attention (SWA), GQA; MoE top-k with load-balancing loss.
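
A simplified picture of top-k MoE routing with an auxiliary load-balancing loss follows. This is a didactic sketch, not MaxText's dropless implementation; the function names and the loss weighting are illustrative.

```python
# Simplified top-k MoE routing with a Switch/Mixtral-style load-balancing loss.
import jax
import jax.numpy as jnp

def route_tokens(router_logits, k=2):
  # router_logits: [tokens, experts]; pick the top-k experts per token.
  weights, expert_ids = jax.lax.top_k(jax.nn.softmax(router_logits, axis=-1), k)
  weights = weights / jnp.sum(weights, axis=-1, keepdims=True)  # renormalize over the chosen k
  return weights, expert_ids

def load_balancing_loss(router_logits, expert_ids, num_experts):
  # Encourage uniform expert usage: fraction of tokens routed to each expert
  # times the mean router probability for that expert.
  probs = jax.nn.softmax(router_logits, axis=-1)                     # [tokens, experts]
  token_fraction = jnp.mean(
      jax.nn.one_hot(expert_ids, num_experts).sum(axis=1), axis=0)   # [experts]
  prob_fraction = jnp.mean(probs, axis=0)                            # [experts]
  return num_experts * jnp.sum(token_fraction * prob_fraction)
```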
### Gemma

* **Variants**: Gemma 1 (2B/7B), Gemma 2 (2B/9B/27B), **Gemma 3 (4B/12B/27B)** (text & multimodal)
* **Notes**: RMSNorm; RoPE; GELU/SwiGLU; **QK-Norm** (Gemma 3); Local–Global interleaved attention; long-context scaling.
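
Local–Global interleaving can be pictured as alternating attention masks: some layers use a sliding-window (local) causal mask while others keep the full causal mask. The sketch below is a hedged illustration; the window size and interleaving ratio are placeholders, not Gemma's published configuration.

```python
# Sketch of causal vs. sliding-window (local) attention masks for interleaved
# local/global layers. Window size and layer pattern are illustrative.
import jax.numpy as jnp

def causal_mask(seq_len):
  i = jnp.arange(seq_len)[:, None]
  j = jnp.arange(seq_len)[None, :]
  return j <= i                       # each query attends to all previous tokens

def sliding_window_mask(seq_len, window=1024):
  i = jnp.arange(seq_len)[:, None]
  j = jnp.arange(seq_len)[None, :]
  return (j <= i) & (i - j < window)  # ...but only within the last `window` tokens

def mask_for_layer(layer_idx, seq_len, local_to_global_ratio=5):
  # Hypothetical interleaving: one global layer after every few local layers.
  is_global = (layer_idx + 1) % (local_to_global_ratio + 1) == 0
  return causal_mask(seq_len) if is_global else sliding_window_mask(seq_len)
```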
### DeepSeek

* **Variants**: V2 (16B, 236B), V3 (671B), R1
* **Notes**: MLA (Multi-Head Latent Attention); shared/finer-grained experts; MTP; YaRN-style scaling.
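
Multi-Token Prediction can be illustrated with a simplified auxiliary loss that also predicts the token two positions ahead from each hidden state. DeepSeek's actual MTP uses dedicated sequential prediction modules, so treat this purely as a schematic; all names, shapes, and the loss weight are hypothetical.

```python
# Schematic of a multi-token-prediction style auxiliary loss (not DeepSeek's exact design).
import jax.numpy as jnp
from jax.nn import log_softmax

def next_and_next_next_loss(hidden, lm_head, mtp_head, targets, mtp_weight=0.1):
  # hidden:  [batch, seq, dim]  final hidden states
  # targets: [batch, seq]       next-token labels aligned with `hidden`
  def xent(logits, labels):
    logp = log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.take_along_axis(logp, labels[..., None], axis=-1))

  main_loss = xent(hidden @ lm_head, targets)
  # Auxiliary target: the token one step further ahead (drop the last position).
  aux_loss = xent((hidden @ mtp_head)[:, :-1], targets[:, 1:])
  return main_loss + mtp_weight * aux_loss
```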
### Qwen3

* **Variants**: Dense (0.6B–32B); MoE (30B-A3B, 235B-A22B, 480B Coder)
* **Notes**: **QK-Norm**, GQA, SwiGLU, RMSNorm, RoPE.
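
QK-Norm applies a normalization (typically RMSNorm) to the query and key projections before the attention dot product, which helps keep attention logits stable at scale. A minimal sketch follows; head layout and scale parameters are illustrative.

```python
# Minimal sketch of QK-Norm: RMS-normalize queries and keys before the dot product.
import jax.numpy as jnp

def rms_norm(x, scale, eps=1e-6):
  return x * scale / jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)

def attention_logits_with_qk_norm(q, k, q_scale, k_scale):
  # q, k: [batch, heads, seq, head_dim]
  q = rms_norm(q, q_scale)
  k = rms_norm(k, k_scale)
  return jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
```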
## Parallelism Building Blocks

MaxText supports a wide range of parallelism strategies for scaling training and inference across TPUs and GPUs:

* **FSDP (Fully Sharded Data Parallel)**: Reduces memory footprint by sharding parameters and optimizer states across devices.
* **TP (Tensor Parallelism)**: Splits tensor computations (e.g., matrix multiplications, attention heads) across devices for intra-layer speedups.
* **EP (Expert Parallelism)**: Distributes MoE experts across devices, supporting dropless routing and load balancing to ensure efficient utilization.
* **DP (Data Parallelism)**: Replicates the model across devices while splitting the input data batches.
* **PP (Pipeline Parallelism)**: Splits layers across device stages to support extremely large models by managing inter-stage communication.
* **CP (Context Parallelism)**: Splits sequence tokens across devices, complementing tensor parallelism for long-context workloads.
* **Hybrid Parallelism**: Allows flexible combinations of FSDP, TP, EP, DP, PP, and CP to maximize hardware utilization based on model size and topology (see the sketch below).
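
In JAX these strategies compose through a device mesh with named axes; parameters and activations are then placed with `PartitionSpec`s over those axes. The sketch below shows a hybrid FSDP + tensor-parallel layout; the axis names, shapes, and mesh split are illustrative, not MaxText's configuration keys.

```python
# Hybrid FSDP + tensor parallelism expressed with a named JAX device mesh.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 2D mesh: "fsdp" for parameter/optimizer sharding, "tensor" for
# splitting matmuls. Assumes the device count is divisible by 2.
devices = np.array(jax.devices()).reshape(-1, 2)
mesh = Mesh(devices, axis_names=("fsdp", "tensor"))

# Weight matrix sharded along both axes: rows over the FSDP axis, columns over TP.
w = jax.device_put(jnp.zeros((8192, 8192)),
                   NamedSharding(mesh, P("fsdp", "tensor")))

# Activations usually shard the batch dimension over the data/FSDP axis.
x = jax.device_put(jnp.zeros((16, 8192)),
                   NamedSharding(mesh, P("fsdp", None)))

y = x @ w  # under jit, XLA inserts the collectives implied by these shardings
```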
## Performance Characteristics

The following summarizes observed runtime efficiency and scaling behaviors of MaxText across different hardware and model types, based on published benchmarks and large-scale runs.

* **High MFU**: MaxText targets high Model FLOPs Utilization across scales; exact numbers vary by model, hardware, and config. See [**Performance Metrics → MFU**](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/performance_metrics.md) for the definition and how we calculate it.
* **Quantization**: MaxText supports quantization via both the AQT and Qwix libraries. Qwix is the recommended approach, providing a non-intrusive way to apply various quantization techniques, including Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ).
* **MoE**: The Mixture-of-Experts implementation features dropless routing with Megablox and `jax.lax.ragged_dot` kernels for enhanced performance.
* **Multi-Token Prediction (MTP)**: Improves training efficiency on DeepSeek-style models by adding an auxiliary loss based on predicting multiple future tokens.
* **Long-Context Optimizations**: Implements efficient attention mechanisms, including Grouped-Query Attention (GQA), Sliding-Window Attention (SWA), Local–Global interleaved attention, and Multi-Head Latent Attention (MLA). These reduce KV-cache size and attention cost, making long contexts practical (see the sketch below).
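
The KV-cache saving from GQA, for example, comes from storing keys and values for only a small number of KV heads and sharing each one across a group of query heads. A simplified sketch, with illustrative head counts:

```python
# Grouped-Query Attention sketch: many query heads share a few KV heads,
# shrinking the KV cache by a factor of num_q_heads / num_kv_heads.
import jax.numpy as jnp

def gqa_scores(q, k, num_q_heads=32, num_kv_heads=8):
  # q: [batch, num_q_heads, seq, head_dim]; k: [batch, num_kv_heads, seq, head_dim]
  group = num_q_heads // num_kv_heads
  k_expanded = jnp.repeat(k, group, axis=1)   # broadcast each KV head to its query group
  return jnp.einsum("bhqd,bhkd->bhqk", q, k_expanded) / jnp.sqrt(q.shape[-1])
```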
## References

* [MaxText Repo](https://github.com/AI-Hypercomputer/maxtext)

* **Model Implementation Guides & Source Code:**
  * **Llama**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/tutorials/run_llama2.md) | [Llama2 and Llama3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/llama2.py) | [Llama4 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/llama4.py)
  * **Gemma**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/gemma/Run_Gemma.md) | [Gemma Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/gemma.py) | [Gemma2 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/gemma2.py) | [Gemma3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/gemma3.py)
  * **Mixtral**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/mixtral/Run_Mixtral.md) | [Mixtral Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/mixtral.py) | [Mistral Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/mistral.py)
  * **DeepSeek**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/deepseek/Run_DeepSeek.md) | [DeepSeek Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/deepseek.py)
  * **Qwen3**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/qwen/moe/run_qwen_moe.md) | [Qwen3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/qwen3.py)

* **Technical Explanations:**
  * [Parallelism & Sharding](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/explanations/sharding.md)
  * [Quantization Documentation](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/explanations/quantization.md)
  * [AOT Compilation Instructions](https://github.com/AI-Hypercomputer/maxtext/blob/main/README.md#ahead-of-time-compilation-aot)