43 changes: 22 additions & 21 deletions docs/guides/understand_logs_and_metrics.md
@@ -32,15 +32,15 @@ per_device_batch_size=24 max_target_length=2048 steps=10 dataset_type=synthetic
The first section of the log details the configuration of your run. This is crucial for debugging, as it shows you exactly which parameters were used.

MaxText builds its configuration in layers.
- It starts with the **default configuration** from a YAML file. In our example, the file is [src/MaxText/configs/base.yml](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/configs/base.yml).
- It starts with the **default configuration** from a YAML file. In our example, the file is [`src/MaxText/configs/base.yml`](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/configs/base.yml).

- Then, it overwrites any of these values with the arguments you provide in the **command line**.
```
```none
Updating keys from env and command line: ['run_name', 'model_name', 'enable_checkpointing', 'base_output_directory', 'per_device_batch_size', 'dataset_type', 'steps', 'max_target_length']
```
- It updates keys based on the **model-specific configuration** file. When you specify a model, like `deepseek2-16b`, MaxText reads the corresponding parameters from the [deepseek2-16b.yml](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/configs/models/deepseek2-16b.yml) file.

```
```none
Running Model: deepseek2-16b
Updating following parameters in config

@@ -52,7 +52,7 @@ MaxText builds its configuration in layers.
Note that you cannot modify the same key from both the model config and the command line.
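
To make the layering concrete, here is a minimal sketch of the merge order described above. It is only an illustration using plain Python dictionaries with hypothetical values, not MaxText's configuration code:

```python
# Illustration only -- not MaxText's config implementation.
# All values other than opt_type are hypothetical placeholders.
base = {"opt_type": "adamw", "steps": 150_001, "per_device_batch_size": 12.0}
cli_overrides = {"steps": 10, "per_device_batch_size": 24}
model_overrides = {"base_emb_dim": 2048}  # hypothetical key from deepseek2-16b.yml

# The same key may not be set by both the model config and the command line.
conflict = cli_overrides.keys() & model_overrides.keys()
assert not conflict, f"cannot modify {sorted(conflict)} from both model config and command line"

config = {**base, **cli_overrides, **model_overrides}
print(config["steps"])  # 10 -- the command-line value wins over the base default
```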

The final, consolidated configuration is printed last.
```
```none
# From base.yml default
Config param opt_type: adamw
...
@@ -104,8 +104,9 @@ Within this base path, MaxText creates several subdirectories for different type
* Path: `gs://runner-maxtext-logs/demo/checkpoints/`

To generate all optional artifacts in one run, you can set the corresponding flags in the command line, like in the example below.

This command enables tensorboard, profiler, text metrics, config saving, and checkpointing:
```bash
# This command enables tensorboard, profiler, text metrics, config saving, and checkpointing
python3 -m MaxText.train src/MaxText/configs/base.yml \
base_output_directory=gs://runner-maxtext-logs run_name=demo2 \
model_name=deepseek2-16b \
@@ -121,7 +122,7 @@ enable_checkpointing=True

Next, the log displays the software and hardware environment for your run. This is useful for verifying your setup and understanding how parallelism is being applied.

```
```none
System Information: Jax Version: 0.7.2.dev20250826
System Information: Jaxlib Version: 0.7.2.dev20250826
System Information: Jax Backend: PJRT C API
@@ -140,9 +141,8 @@ Before executing training, the program analyzes the resource requirements for yo

### 3.1. Memory analysis

We first perform a "dry run" compilation of a training step to [analyze its memory requirement](https://github.com/AI-Hypercomputer/maxtext/blob/f82ce194c490d668b14574a072a0a630c27bbd6e/MaxText/train.py#L630-L632
). This static analysis is performed by the XLA compiler. The log outputs [memory sizes](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/max_utils.py#L700-L718):
```
We first perform a "dry run" compilation of a training step to [analyze its memory requirement](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/train.py#L380-L382). This static analysis is performed by the XLA compiler. The log outputs [memory sizes](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/max_utils.py#L672-L690):
```none
Total memory size: 100.4 GB, Output size: 44.5 GB, Temp size: 55.9 GB, Argument size: 44.5 GB, Host temp size: 0.0 GB.
```
The most important number is `Total memory size: 100.4 GB`. This is the total High Bandwidth Memory (HBM) the TPU device needs to execute the program. Here is a breakdown:
@@ -153,13 +153,14 @@ The most important number is `Total memory size: 100.4 GB`. This is the total Hi

In addition, it shows temporary memory used on the host CPU. In this case, `Host temp size: 0.0 GB`, indicating that all the significant memory allocation happens on the accelerator device.
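
A rough consistency check on these numbers (our reading of the log, not something it states explicitly): if the train-state arguments are donated to the computation, so that their buffers are reused for the outputs, the total is approximately the output size plus the temporary size:

$$44.5\ \text{GB (output)} + 55.9\ \text{GB (temp)} = 100.4\ \text{GB (total)}$$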


### 3.2. Memory snapshot

The previous section is a forecast of memory usage for the entire training step, based on static analysis of the compiled code from the XLA compiler. To see the actual memory usage, we now turn to a real-time snapshot from the JAX runtime, captured right after the training state is initialized.

To set the stage for training, we first initialize the training state, which include parameter and optimizer states. At the [beginning](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/train.py#L445), the log shows a real-time snapshot of the [memory statistics](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/max_utils.py#L673-L682) on your TPU devices.
To set the stage for training, we first initialize the training state, which includes parameter and optimizer states. At the [beginning](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/train.py#L445), the log shows a real-time snapshot of the [memory statistics](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/max_utils.py#L645-L654) on your TPU devices.

```
```none
number parameters: 15.933 billion

Memstats: After params initialized:
@@ -173,14 +174,14 @@ This log shows that each of the four TPUs has `95.74 GB` of available High Bandw
### 3.3. Model TFLOP per device

The **model FLOPs** are the floating-point operations required to perform the model computation. For training, this computation includes a single forward and backward pass.
- In MaxText, we estimate model FLOPs by summing operations in matrix multiplications (matmuls); see [calculate_tflops_training_per_device](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/maxtext_utils.py#L454).
- In MaxText, we estimate model FLOPs by summing operations in matrix multiplications (matmuls); see [calculate_tflops_training_per_device](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/maxtext_utils.py#L480).
- The number of model FLOPs is dependent on model architecture, input size (batch size, sequence length), and gradient accumulation steps. It does not include optimization operations.
- We break down the FLOPs into two parts:
- "Learnable weight FLOPs" are matmuls between activations and learnable weights. Specifically, this occurs in embedding, feed forward networks, attention-related projections, and unembedding.
- "Attention FLOPs" are matmuls in attention score computation like $\mathrm{softmax}{\left(\frac{QK^\top}{\sqrt{d}}\right)} V$.

One **TFLOP** (TeraFLOP) is equal to $10^{12}$ FLOPs. The log shows the theoretical estimate of **model TFLOP per device**:
```
```none
Per train step:
Total TFLOPs: 764.67
split as 94.54% learnable weight flops and 5.46% attention flops
@@ -195,7 +196,7 @@ You can find more information about model FLOPs and MFU in the [Performance Metr
## 4. Training metrics

Finally, we are getting to the training steps! In this section, we introduce performance metrics including TFLOP/s/device, MFU, and Tokens/s/device (throughput). We briefly cover learning metrics including loss and total weights.
```
```none
completed step: 0, seconds: 44.923, TFLOP/s/device: 17.022, Tokens/s/device: 1094.129, total_weights: 196608, loss: 12.038
completed step: 1, seconds: 0.319, TFLOP/s/device: 2400.734, Tokens/s/device: 154316.608, total_weights: 196608, loss: 12.038
completed step: 2, seconds: 5.658, TFLOP/s/device: 135.158, Tokens/s/device: 8687.815, total_weights: 196608, loss: 11.689
@@ -215,14 +216,14 @@ Before we dive deep here, recall a few numbers from previous sections:
### 4.1. Performance metrics

The performance metrics fluctuate at the beginning and become stable towards the end. Therefore, we usually read them from the last step. Let's take a closer look at Step 9.
```
```none
completed step: 9, seconds: 5.667, TFLOP/s/device: 134.924, Tokens/s/device: 8672.758, total_weights: 196608, loss: 10.374
```
As shown in `seconds: 5.667`, $\text{measured step time in seconds} \approx 5.667$ (rounded).

**TFLOP per second per device**

- It is [computed](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/metric_logger.py#L211-L213) as
- It is [computed](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/metric_logger.py#L211-L213) as

$$\text{tflop/s/device} = \frac{\text{model tflop per device}}{\text{measured step time in seconds}}$$

@@ -235,11 +236,11 @@ $$\text{MFU} = \frac{\text{tflop/s/device}}{\text{peak hardware tflop/s}}$$

**Tokens per second per device (throughput)**

- It is [computed](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/metric_logger.py#L215-L217) as
- It is [computed](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/metric_logger.py#L215-L217) as

$$\text{token/s/device} = \frac{\text{number of tokens per device}}{\text{measured step time in seconds}}$$

- The numerator is from [calculate_tokens_training_per_device](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/maxtext_utils.py#L122)
- The numerator is from [calculate_tokens_training_per_device](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/maxtext_utils.py#L148)

$$\text{number of tokens per device} = \text{per device batch size} \times \text{max target length}$$
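
Putting the formulas above together, here is a minimal sketch (plain Python, not MaxText code) that reproduces the Step 9 numbers. The peak TFLOP/s value used for MFU is an assumption for a TPU v5p chip; substitute your hardware's spec.

```python
# A minimal sketch (not MaxText code) that reproduces the Step 9 metrics
# from the quantities discussed above.
model_tflops_per_device = 764.67   # "Total TFLOPs" per train step, from Section 3.3
step_time_s = 5.667                # "seconds" logged at Step 9
per_device_batch_size = 24
max_target_length = 2048

tflops_per_s = model_tflops_per_device / step_time_s
print(f"TFLOP/s/device: {tflops_per_s:.3f}")  # ~134.9, close to the logged 134.924

tokens_per_device = per_device_batch_size * max_target_length
print(f"Tokens/s/device: {tokens_per_device / step_time_s:.3f}")  # ~8673, close to the logged 8672.758

# MFU needs the accelerator's peak throughput. 459 TFLOP/s (bf16) is an
# assumed value for a TPU v5p chip; replace it with your hardware's spec.
peak_tflops_per_s = 459.0
print(f"MFU: {tflops_per_s / peak_tflops_per_s:.1%}")  # ~29.4% under this assumption
```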

@@ -249,11 +250,11 @@ $$\text{number of tokens per device} = \text{per device batch size} \times \text{max target length}$$

**Loss**. The loss is the key indicator of learning progress, which should decrease over training steps. In this example, the loss is `12.038` at Step 0 and decreases to `10.374` at Step 9. Ideally, we want the loss to converge to a small value with sufficiently large training steps.

**Total weights**. When discussing the throughput, we have $\text{number of tokens} = \text{per device batch size} \times \text{max target length} \times \text{number of device}$. In this example, $\text{number of tokens} = 24 \times 2048 \times 4 = 196608$. There are two types of tokens: real tokens and pad tokens. The pad tokens are placeholders introduced by data preprocessing: We truncate or pad each sentence to max target length. Only real tokens contribute to the learning signal (i.e., loss). Therefore, we monitor $\text{number of real tokens}$, which is shown as [total weights](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/train.py#L151).
**Total weights**. When discussing throughput, we have $\text{number of tokens} = \text{per device batch size} \times \text{max target length} \times \text{number of devices}$. In this example, $\text{number of tokens} = 24 \times 2048 \times 4 = 196608$. There are two types of tokens: real tokens and pad tokens. The pad tokens are placeholders introduced by data preprocessing: we truncate or pad each sentence to the max target length. Only real tokens contribute to the learning signal (i.e., loss). Therefore, we monitor the $\text{number of real tokens}$, which is shown as [total weights](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/train.py#L151).
- Here we see `total_weights: 196608` for all steps. This is because we are using `dataset_type=synthetic`, where all sentences are generated with a length of `max_target_length=2048`. As a result, there are no pad tokens and total weights = number of tokens.
- However, in real datasets, sentences can have variable lengths and total weights < number of tokens. For example, we can set `dataset_type=tfds dataset_path=gs://maxtext-dataset dataset_name='c4/en:3.0.1'`, and will see total weights smaller than `196608`:
```
```none
completed step: 8, seconds: 5.670, TFLOP/s/device: 134.856, Tokens/s/device: 8668.393, total_weights: 163259, loss: 9.596
completed step: 9, seconds: 5.669, TFLOP/s/device: 134.884, Tokens/s/device: 8670.184, total_weights: 155934, loss: 9.580
```
- For better convergence, we want to have large total weights. Towards this end, MaxText supports [packing](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/sequence_packing.py#L37) multiple short sequences into one. This is enabled by default with `packing=True` in [base.yml](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/configs/base.yml#L477).
- For better convergence, we want to have large total weights. Towards this end, MaxText supports [packing](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/sequence_packing.py#L37) multiple short sequences into one. This is enabled by default with `packing=True` in [base.yml](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/configs/base.yml#L465).
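
To illustrate what packing does, here is a toy sketch. It is not MaxText's implementation, which also builds segment IDs and positions so that attention cannot cross sequence boundaries:

```python
# Toy illustration of sequence packing (not MaxText's implementation):
# greedily concatenate short tokenized sequences so fewer positions are
# wasted on pad tokens.
from typing import List

def pack_sequences(sequences: List[List[int]], max_target_length: int) -> List[List[int]]:
  """Greedily concatenates short sequences up to max_target_length."""
  packed: List[List[int]] = []
  current: List[int] = []
  for seq in sequences:
    seq = seq[:max_target_length]  # truncate overly long sequences
    if len(current) + len(seq) > max_target_length:
      packed.append(current)
      current = []
    current.extend(seq)
  if current:
    packed.append(current)
  return packed

# Three short "sentences" become two packed examples of length <= 8.
print(pack_sequences([[1, 2, 3], [4, 5], [6, 7, 8, 9]], max_target_length=8))
# [[1, 2, 3, 4, 5], [6, 7, 8, 9]]
```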
4 changes: 3 additions & 1 deletion docs/tutorials/sft.md
@@ -20,7 +20,7 @@ Supervised fine-tuning (SFT) is a process where a pre-trained large language mod
This tutorial demonstrates step-by-step instructions for setting up the environment and then training the model on a Hugging Face dataset using SFT.

We use [Tunix](https://github.com/google/tunix), a JAX-based library designed for post-training tasks, to perform SFT.

@bvandermoon (Collaborator) commented on Nov 11, 2025:
Nit: I think there is an extra whitespace added here. It might fail the linter

Contributor Author replied:
Fixed!

In this tutorial we use a single-host TPU VM such as `v6e-8` or `v5p-8`. Let's get started!

## Install dependencies
@@ -41,7 +41,9 @@ install_maxtext_github_deps
```

## Set up environment variables

Set the following environment variables before running SFT.

```sh
# -- Model configuration --
export PRE_TRAINED_MODEL=<model name> # e.g., 'llama3.1-8b'