
Commit ca2ae66

Merge pull request #2273 from AI-Hypercomputer:gagik-supported-models
PiperOrigin-RevId: 808409228
2 parents 87131c1 + fd20367 commit ca2ae66

2 files changed: +90 -0 lines changed

docs/reference.md

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@
 :maxdepth: 1

 reference/terminology.md
+reference/supported_models_and_architectures.md
 reference/alternatives.md
 reference/benchmark_and_performance.md
 reference/architecture_overview.md
docs/reference/supported_models_and_architectures.md

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
# Supported Models & Architectures

> **Purpose**: This page provides detailed, reference-style information about the model families supported in MaxText. It is a technical dictionary for quick lookup, reproducibility, and customization.

## Overview

MaxText is an open-source, high-performance LLM framework written in Python/JAX. It targets Google Cloud TPUs and NVIDIA GPUs for training. MaxText prioritizes scalability (from a single host to recent runs with tens of thousands of chips), high Model FLOPs Utilization (MFU), and simplicity by leveraging JAX with the XLA compiler and optimized JAX Pallas kernels.
**Key capabilities and features**:

* **Supported Precisions**: FP32, BF16, INT8, and FP8.
* **Ahead-of-Time Compilation (AOT)**: For faster model development/prototyping and earlier OOM detection.
* **Quantization**: Via **Qwix** (recommended) and AQT. See the [Quantization Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/explanations/quantization.md).
* **Diagnostics**: Structured error context via **`cloud_tpu_diagnostics`** (filters stack traces to user code), simple logging via `max_logging`, profiling in **XProf**, and visualization in **TensorBoard**.
* **Multi-Token Prediction (MTP)**: Enables token-efficient training by predicting multiple future tokens.
* **Elastic Training**: Fault-tolerant, dynamic scale-up/scale-down on Cloud TPUs with Pathways.
* **Flexible Remat Policy**: Provides fine-grained control over memory-compute trade-offs. Users can select pre-defined policies (like 'full' or 'minimal') or set the policy to **'custom'**.
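To make the memory-compute trade-off behind remat concrete, here is a minimal standalone JAX sketch using `jax.checkpoint` directly (MaxText exposes this through its config rather than hand-written code like this; shapes and policies below are illustrative):

```python
import jax
import jax.numpy as jnp

def mlp_block(x, w1, w2):
  # The wide intermediate activation (x @ w1) is what a remat policy decides
  # to either keep in HBM or recompute during the backward pass.
  h = jax.nn.gelu(x @ w1)
  return h @ w2

# Save nothing: lowest activation memory, maximum recompute ("full"-style).
mlp_full = jax.checkpoint(
    mlp_block, policy=jax.checkpoint_policies.nothing_saveable)

# A selective policy: keep certain matmul outputs, recompute everything else,
# trading a bit more memory for fewer recomputed FLOPs.
mlp_selective = jax.checkpoint(
    mlp_block, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)

x, w1, w2 = jnp.ones((8, 1024)), jnp.ones((1024, 4096)), jnp.ones((4096, 1024))
grad_of = lambda f: jax.grad(lambda a: f(a, w1, w2).sum())(x)
g_full = grad_of(mlp_full)            # least memory, extra backward FLOPs
g_selective = grad_of(mlp_selective)  # more memory, less recompute
```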
## Supported Model Families

> _**Note on GPU Coverage**: Support and tested configurations for NVIDIA GPUs can vary by model family. Please see the specific model guides for details._

**Primary Platforms**: All model families listed below target **TPU** and **NVIDIA GPUs**.
### Llama

* **Variants**: Llama 2; **Llama 3 / 3.1 / 3.3**; Llama 4 (**Scout**, **Maverick**; text & multimodal)
* **Notes**: RoPE, RMSNorm, SwiGLU; GQA; routed experts (Llama 4); **QK-Norm** (Llama 4); multimodal projector & vision encoder.
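For quick reference, a minimal JAX sketch of two of the building blocks listed above, RMSNorm and a SwiGLU feed-forward, written from their published definitions rather than taken from the MaxText layer code:

```python
import jax
import jax.numpy as jnp

def rms_norm(x, gamma, eps=1e-6):
  # Scale by the root-mean-square of the features; no mean subtraction or bias.
  rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
  return (x / rms) * gamma

def swiglu_ffn(x, w_gate, w_up, w_down):
  # SwiGLU: a SiLU-gated linear unit followed by a down projection.
  return (jax.nn.silu(x @ w_gate) * (x @ w_up)) @ w_down

d_model, d_hidden = 16, 64
x = jnp.ones((2, 8, d_model))
y = swiglu_ffn(rms_norm(x, jnp.ones(d_model)),
               jnp.ones((d_model, d_hidden)),
               jnp.ones((d_model, d_hidden)),
               jnp.ones((d_hidden, d_model)))
```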
### Mistral / Mixtral

* **Variants**: Mistral (dense); Mixtral 8×7B, 8×22B (MoE)
* **Notes**: Sliding-Window Attention (SWA), GQA; MoE top-k routing with load-balancing loss.
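A much-simplified sketch of top-k expert routing with a Switch/Mixtral-style load-balancing auxiliary loss; the actual MaxText MoE path (dropless routing, Megablox/ragged kernels) is considerably more involved:

```python
import jax
import jax.numpy as jnp

def route_tokens(x, router_w, k=2):
  # x: [tokens, d_model]; router_w: [d_model, num_experts]
  probs = jax.nn.softmax(x @ router_w, axis=-1)      # [tokens, num_experts]
  topk_probs, topk_idx = jax.lax.top_k(probs, k)     # per-token expert choices

  # Load-balancing auxiliary loss: penalize correlation between the fraction
  # of tokens assigned to each expert and its mean router probability.
  num_experts = router_w.shape[-1]
  assigned = jax.nn.one_hot(topk_idx[:, 0], num_experts)   # top-1 assignment
  tokens_per_expert = assigned.mean(axis=0)
  mean_probs = probs.mean(axis=0)
  aux_loss = num_experts * jnp.sum(tokens_per_expert * mean_probs)
  return topk_probs, topk_idx, aux_loss

weights, experts, aux = route_tokens(jnp.ones((32, 16)), 0.01 * jnp.ones((16, 8)))
```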
### Gemma

* **Variants**: Gemma 1 (2B/7B), Gemma 2 (2B/9B/27B), **Gemma 3 (4B/12B/27B)** (text & multimodal)
* **Notes**: RMSNorm; RoPE; GELU/SwiGLU; **QK-Norm** (Gemma 3); Local–Global interleaved attention; long-context scaling.
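A rough sketch of how local-global interleaving can be expressed as per-layer attention masks (the same banded mask also illustrates the sliding-window attention noted for Mistral); the layer schedule and window size here are illustrative, not Gemma's actual configuration:

```python
import jax.numpy as jnp

def layer_mask(seq_len, layer_idx, window=4, global_every=2):
  # Causal mask shared by every layer.
  i = jnp.arange(seq_len)[:, None]   # query positions
  j = jnp.arange(seq_len)[None, :]   # key positions
  causal = j <= i
  if layer_idx % global_every == 0:
    return causal                    # "global" layer: full causal attention
  # "local" layer: each query only sees the previous `window` positions.
  return causal & (i - j < window)

print(layer_mask(6, layer_idx=1, window=3).astype(int))
```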
### DeepSeek

* **Variants**: V2 (16B, 236B), V3 (671B), R1
* **Notes**: Multi-Head Latent Attention (MLA); shared/finer-grained experts; MTP; YaRN-style scaling.
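A deliberately simplified sketch of where a multi-token-prediction auxiliary loss comes from: extra heads predict tokens further into the future and their cross-entropy is averaged into the training loss. DeepSeek's actual MTP uses sequential transformer modules rather than independent heads, so treat this only as an illustration of the idea:

```python
import jax
import jax.numpy as jnp

def mtp_aux_loss(hidden, mtp_heads, token_ids):
  # hidden: [batch, seq, d_model]; token_ids: [batch, seq].
  # The main LM head predicts token t+1 from position t; auxiliary head k
  # (k = 1, 2, ...) additionally predicts token t+1+k.
  losses = []
  for k, w in enumerate(mtp_heads, start=1):
    logits = hidden[:, : -(1 + k)] @ w              # [batch, seq-1-k, vocab]
    labels = token_ids[:, 1 + k :]                  # [batch, seq-1-k]
    logp = jax.nn.log_softmax(logits, axis=-1)
    nll = -jnp.take_along_axis(logp, labels[..., None], axis=-1)
    losses.append(nll.mean())
  return jnp.mean(jnp.stack(losses))

hidden = jnp.ones((2, 16, 8))
heads = [0.01 * jnp.ones((8, 100))]    # one extra prediction depth
ids = jnp.zeros((2, 16), dtype=jnp.int32)
aux = mtp_aux_loss(hidden, heads, ids)  # added to the main loss with a small weight
```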
### Qwen3

* **Variants**: Dense (0.6B–32B); MoE (30B-A3B, 235B-A22B, 480B Coder)
* **Notes**: **QK-Norm**, GQA, SwiGLU, RMSNorm, RoPE.
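QK-Norm (shared by Qwen3, Gemma 3, and Llama 4 above) simply normalizes the query and key projections before the attention logits are computed; a minimal sketch, not the MaxText implementation:

```python
import jax
import jax.numpy as jnp

def rms_norm(x, gamma, eps=1e-6):
  return x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps) * gamma

def qk_norm_logits(q, k, q_gamma, k_gamma):
  # q, k: [batch, heads, seq, head_dim]. Normalizing q and k keeps the scale
  # of the attention logits stable during training.
  q = rms_norm(q, q_gamma)
  k = rms_norm(k, k_gamma)
  return jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])

q = jnp.ones((1, 2, 8, 4))
k = jnp.ones((1, 2, 8, 4))
logits = qk_norm_logits(q, k, jnp.ones(4), jnp.ones(4))
```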
## Parallelism Building Blocks

MaxText supports a wide range of parallelism strategies for scaling training and inference across TPUs and GPUs:

* **FSDP (Fully Sharded Data Parallel)**: Reduces memory footprint by sharding parameters and optimizer states across devices.
* **TP (Tensor Parallelism)**: Splits tensor computations (e.g., matrix multiplications, attention heads) across devices for intra-layer speedups.
* **EP (Expert Parallelism)**: Distributes MoE experts across devices, supporting dropless routing and load balancing to ensure efficient utilization.
* **DP (Data Parallelism)**: Replicates the model across devices while splitting the input data batches.
* **PP (Pipeline Parallelism)**: Splits layers across device stages to support extremely large models by managing inter-stage communication.
* **CP (Context Parallelism)**: Splits sequence tokens across devices, complementing tensor parallelism for long-context workloads.
* **Hybrid Parallelism**: Allows for flexible combinations of FSDP, TP, EP, DP, PP, and CP to maximize hardware utilization based on model size and topology.
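A minimal JAX sketch of combining two of these axes: an `fsdp` axis that shards parameters and batch, and a `tensor` axis that splits a weight's output dimension. MaxText drives this through its sharding configuration rather than hand-placed `device_put` calls, and the device layout below simply assumes an even number of local devices:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Lay out the local devices as a (fsdp, tensor) grid; tensor=2 assumes an
# even device count (e.g. 8 chips -> a 4 x 2 mesh).
devices = np.array(jax.devices()).reshape(-1, 2)
mesh = Mesh(devices, axis_names=("fsdp", "tensor"))

# FSDP shards the weight's first dim; TP splits its output dim.
w = jax.device_put(jnp.ones((4096, 4096)), NamedSharding(mesh, P("fsdp", "tensor")))
# Activations: batch split over "fsdp", features replicated over "tensor".
x = jax.device_put(jnp.ones((16, 4096)), NamedSharding(mesh, P("fsdp", None)))

# XLA inserts the collectives needed to compute the sharded matmul.
y = jax.jit(lambda a, b: a @ b)(x, w)
```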
## Performance Characteristics

The following summarizes observed runtime efficiency and scaling behaviors of MaxText across different hardware and model types, based on published benchmarks and large-scale runs.

* **High MFU**: MaxText targets high Model FLOPs Utilization across scales; exact numbers vary by model, hardware, and config. See [**Performance Metrics → MFU**](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/performance_metrics.md) for the definition and how we calculate it; a rough worked example follows this list.
* **Quantization**: MaxText supports quantization via both the AQT and Qwix libraries. Qwix is the recommended approach, providing a non-intrusive way to apply various quantization techniques, including Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ).
* **MoE**: The Mixture-of-Experts implementation features dropless routing with Megablox and `jax.lax.ragged_dot` kernels for enhanced performance.
* **Multi-Token Prediction (MTP)**: This feature improves training efficiency on DeepSeek-style models by adding an auxiliary loss based on predicting multiple future tokens.
* **Long-Context Optimizations**: Implements efficient attention mechanisms, including Grouped-Query Attention (GQA), Sliding-Window Attention (SWA), Local–Global interleaved attention, and Multi-Head Latent Attention (MLA), which reduce the KV-cache size and make long contexts tractable.
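As the worked example mentioned above, here is a back-of-the-envelope MFU calculation using the common `6 * parameters` FLOPs-per-token approximation; all numbers are illustrative, and the linked guide describes the accounting MaxText actually uses:

```python
# Rough MFU estimate for a dense decoder-only model (illustrative numbers only).
params = 70e9                  # model parameters
tokens_per_sec = 6.0e4         # observed end-to-end training throughput
flops_per_token = 6 * params   # ~2 FLOPs/param forward + ~4 backward; ignores attention FLOPs
achieved_flops = flops_per_token * tokens_per_sec

num_chips = 256
peak_flops_per_chip = 197e12   # peak BF16 FLOP/s; depends on hardware and dtype
peak_flops = num_chips * peak_flops_per_chip

mfu = achieved_flops / peak_flops
print(f"MFU ~= {mfu:.1%}")     # ~50% with these numbers
```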
## References

* [MaxText Repo](https://github.com/AI-Hypercomputer/maxtext)

* **Model Implementation Guides & Source Code:**
  * **Llama**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/tutorials/run_llama2.md) | [Llama2 and Llama3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/llama2.py) | [Llama4 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/llama4.py)
  * **Gemma**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/gemma/Run_Gemma.md) | [Gemma Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/gemma.py) | [Gemma2 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/gemma2.py) | [Gemma3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/gemma3.py)
  * **Mixtral**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/mixtral/Run_Mixtral.md) | [Mixtral Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/mixtral.py) | [Mistral Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/mistral.py)
  * **DeepSeek**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/deepseek/Run_DeepSeek.md) | [DeepSeek Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/deepseek.py)
  * **Qwen3**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/qwen/moe/run_qwen_moe.md) | [Qwen3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/qwen3.py)

* **Technical Explanations:**
  * [Parallelism & Sharding](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/explanations/sharding.md)
  * [Quantization Documentation](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/explanations/quantization.md)
  * [AOT Compilation Instructions](https://github.com/AI-Hypercomputer/maxtext/blob/main/README.md#ahead-of-time-compilation-aot)
