Skip to content

Commit 81eca03

Browse files
Merge pull request #2632 from AI-Hypercomputer:anisha-fix-trainrl
PiperOrigin-RevId: 831929294
2 parents 67f43b8 + be4576a commit 81eca03

File tree

11 files changed

+186
-175
lines changed

11 files changed

+186
-175
lines changed

dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,38 +22,17 @@ RUN echo "Installing Post-Training dependencies (vLLM, tpu-common, tunix) with M
2222

2323

2424
# Uninstall existing jax to avoid conflicts
25-
RUN pip uninstall -y jax jaxlib libtpu
26-
27-
RUN pip install aiohttp==3.12.15
28-
29-
# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.
30-
RUN pip install keyring keyrings.google-artifactregistry-auth
31-
32-
RUN pip install numba==0.61.2
33-
34-
# Install vLLM for Jax and TPUs from the artifact registry
35-
RUN VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \
36-
--index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
37-
--extra-index-url https://pypi.org/simple/ \
38-
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
39-
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
40-
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
41-
--find-links https://storage.googleapis.com/libtpu-wheels/index.html \
42-
--find-links https://storage.googleapis.com/libtpu-releases/index.html \
43-
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
44-
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
45-
vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu
46-
47-
# Install tpu-commons from the artifact registry
48-
RUN pip install --no-cache-dir --pre \
49-
--index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
50-
--extra-index-url https://pypi.org/simple/ \
51-
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
52-
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
53-
tpu-commons==0.1.2
25+
RUN uv pip uninstall -y jax jaxlib libtpu
26+
27+
RUN uv pip install aiohttp==3.12.15
28+
29+
RUN uv pip install numba==0.61.2
30+
31+
# Install vLLM for Jax and TPUs
32+
RUN uv pip install vllm-tpu
5433

5534
RUN if [ "$MODE" = "post-training-experimental" ]; then \
56-
pip uninstall -y jax jaxlib libtpu && \
57-
pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
58-
pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
35+
uv pip uninstall -y jax jaxlib libtpu && \
36+
uv pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
37+
uv pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
5938
fi
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
ARG BASEIMAGE
16+
FROM ${BASEIMAGE}
17+
ARG MODE
18+
ENV MODE=$MODE
19+
20+
RUN echo "Installing GRPO dependencies (vLLM, tpu-inference) with MODE=${MODE}"
21+
RUN pip uninstall -y jax jaxlib libtpu
22+
23+
RUN pip install aiohttp==3.12.15
24+
25+
# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.
26+
RUN pip install keyring keyrings.google-artifactregistry-auth
27+
28+
RUN pip install numba==0.61.2
29+
30+
COPY tunix /tunix
31+
RUN pip install -e /tunix --no-cache-dir
32+
33+
34+
COPY vllm /vllm
35+
RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir --pre \
36+
--extra-index-url https://pypi.org/simple/ \
37+
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
38+
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
39+
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
40+
--find-links https://storage.googleapis.com/libtpu-wheels/index.html \
41+
--find-links https://storage.googleapis.com/libtpu-releases/index.html \
42+
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
43+
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
44+
45+
46+
COPY tpu-inference /tpu-inference
47+
RUN pip install -e /tpu-inference --no-cache-dir --pre \
48+
--extra-index-url https://pypi.org/simple/ \
49+
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
50+
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
51+
52+
53+
RUN if [ "$MODE" = "post-training-experimental" ]; then \
54+
echo "MODE=post-training-experimental: Re-installing JAX/libtpu"; \
55+
pip uninstall -y jax jaxlib libtpu && \
56+
pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
57+
pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
58+
fi

dependencies/scripts/docker_build_dependency_image.sh

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# you may not use this file except in compliance with the License.
77
# You may obtain a copy of the License at
88
#
9-
# https://www.apache.org/licenses/LICENSE-2.0
9+
# https://www.apache.org/licenses/LICENSE-2.0
1010
#
1111
# Unless required by applicable law or agreed to in writing, software
1212
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -20,14 +20,15 @@
2020
# bash docker_build_dependency_image.sh MODE=nightly
2121
# bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.13
2222
# Nightly build with JAX_VERSION for GPUs. Available versions listed at https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax:
23-
# bash docker_build_dependency_image.sh DEVICE=gpu MODE=nightly JAX_VERSION=0.4.36.dev20241109 # Note: this sets both jax-nightly and jaxlib-nightly
23+
# bash docker_build_dependency_image.sh DEVICE=gpu MODE=nightly JAX_VERSION=0.4.36.dev20241109 # Note: this sets both jax-nightly and jaxlib-nightly
2424
# MODE=custom_wheels is the same as nightly except that it reinstalls any
2525
# additional wheels that are present in the maxtext directory.
2626
# The main use case is to install custom jax or jaxlib wheels but it also
2727
# works with any custom wheels.
2828
# bash docker_build_dependency_image.sh MODE=custom_wheels
2929

3030
# bash docker_build_dependency_image.sh MODE=post-training
31+
# bash docker_build_dependency_image.sh MODE=post-training POST_TRAINING_SOURCE=local
3132

3233
if [ "${BASH_SOURCE-}" ]; then
3334
this_file="${BASH_SOURCE[0]}"
@@ -97,6 +98,12 @@ if [[ -z ${DEVICE} ]]; then
9798
echo "Default DEVICE=${DEVICE}"
9899
fi
99100

101+
# New flag for post-training source
102+
if [[ -z ${POST_TRAINING_SOURCE} ]]; then
103+
export POST_TRAINING_SOURCE=remote # Default to the original Dockerfile
104+
echo "Default POST_TRAINING_SOURCE=${POST_TRAINING_SOURCE}"
105+
fi
106+
100107
# Function to build with MODE=jax_ai_image
101108
build_ai_image() {
102109
if [[ -z ${BASEIMAGE+x} ]]; then
@@ -171,24 +178,34 @@ if [[ ${INSTALL_POST_TRAINING} -eq 1 ]] ; then
171178
exit 1
172179
fi
173180

174-
# # To install tpu_commons from a local path, we copy it into the build context, excluding __pycache__.
175-
# # This assumes vllm, tunix, tpu_commons is a sibling directory to the current one (maxtext).
176-
# rsync -a --exclude='__pycache__' ../tpu_commons .
177-
# # To install vllm from a local path, we copy it into the build context, excluding __pycache__.
178-
# # This assumes vllm is a sibling directory to the current one (maxtext).
179-
# rsync -a --exclude='__pycache__' ../vllm .
180-
181-
# rsync -a --exclude='__pycache__' ../tunix .
182-
183-
# # The cleanup is set to run even if the build fails to remove the copied directory.
184-
# trap "rm -rf ./tpu_commons ./vllm ./tunix" EXIT INT TERM
185-
186-
docker build \
187-
--network host \
188-
--build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \
189-
--build-arg MODE=${MODE} \
190-
-f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile' \
191-
-t ${LOCAL_IMAGE_NAME} .
181+
DOCKERFILE_NAME=""
182+
if [[ ${POST_TRAINING_SOURCE} == "local" ]] ; then
183+
184+
# To install tpu-inference from a local path, we copy it into the build context, excluding __pycache__.
185+
# This assumes vllm, tunix, tpu-inference is a sibling directory to the current one (maxtext).
186+
rsync -a --exclude='__pycache__' ../tpu-inference .
187+
# To install vllm from a local path, we copy it into the build context, excluding __pycache__.
188+
# This assumes vllm is a sibling directory to the current one (maxtext).
189+
rsync -a --exclude='__pycache__' ../vllm .
190+
191+
rsync -a --exclude='__pycache__' ../tunix .
192+
193+
# The cleanup is set to run even if the build fails to remove the copied directory.
194+
trap "rm -rf ./tpu-inference ./vllm ./tunix" EXIT INT TERM
195+
196+
DOCKERFILE_NAME='maxtext_post_training_local_dependencies.Dockerfile'
197+
echo "Using local post-training dependencies Dockerfile: $DOCKERFILE_NAME"
198+
else
199+
DOCKERFILE_NAME='maxtext_post_training_dependencies.Dockerfile'
200+
echo "Using remote post-training dependencies Dockerfile: $DOCKERFILE_NAME"
201+
fi
202+
203+
docker build \
204+
--network host \
205+
--build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \
206+
--build-arg MODE=${MODE} \
207+
-f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/'"$DOCKERFILE_NAME" \
208+
-t ${LOCAL_IMAGE_NAME} .
192209
fi
193210

194211
if [[ ${CUSTOM_JAX} -eq 1 ]] ; then

docs/tutorials/grpo.md

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,35 +25,20 @@ And we use vLLM as the library for efficient model inference and generation.
2525

2626
In this tutorial we use a single host TPUVM such as `v6e-8/v5p-8`. Let's get started!
2727

28-
## Setup your virtual environment
28+
## Create virtual environment and Install MaxText dependencies
29+
Follow instructions in [Install MaxText](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/install_maxtext.md), but
30+
we recommend creating the virtual environment outside the `maxtext` directory.
2931

30-
### Create a Python3.12 venv if not already pre-existing and install MaxText dependencies
31-
```sh
32-
bash tools/setup/setup.sh
33-
```
34-
35-
### Activate your virtual environment (Skip if you have already done this for running `bash tools/setup/setup.sh` )
36-
```
37-
# Replace with your virtual environment name if not using this default name
38-
venv_name="maxtext_venv"
39-
source ~/$venv_name/bin/activate
40-
```
41-
42-
## vLLM and tpu-commons installations
32+
## vLLM and tpu-inference installations
4333

44-
Next, run the following bash script to get all the necessary installations inside the virtual environment.
34+
Next, run the following bash script to get all the necessary installations inside the virtual environment (e.g., `maxtext_venv`).
4535
This will take a few minutes. Follow the installation logs and look out for any issues!
4636

4737
```
4838
bash ~/maxtext/src/MaxText/examples/install_tunix_vllm_requirement.sh
4939
```
5040

51-
1. It installs `pip install keyring keyrings.google-artifactregistry-auth` which enables pip to authenticate with Google Artifact Registry automatically.
52-
2. Next, it installs `vLLM` for Jax and TPUs from the artifact registry `https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/`
53-
3. Then, it installs `tpu-commons` from the same artifact registry.
54-
55-
`tpu_commons` is the TPU backend for vLLM. You will need both libraries to run vLLM on tpus.
56-
We use the scheduler code from vLLM, and the model runner code from `tpu_commons`
41+
Primarily, it installs `vllm-tpu`, which comprises [vllm](https://github.com/vllm-project/vllm) and [tpu-inference](https://github.com/vllm-project/tpu-inference), thereby providing TPU inference for vLLM with unified JAX and PyTorch support.
5742

5843

5944
## Run GRPO
@@ -62,15 +47,15 @@ Finally, run the command
6247

6348
```
6449
python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
65-
--model_name=llama3.1-8b \
66-
--tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
67-
--load_parameters_path=gs://path/to/checkpoint/0/items \
68-
--run_name=$WORKLOAD \
69-
--base_output_directory=$OUTPUT_PATH \
70-
--hf_access_token=$HF_TOKEN
50+
model_name=llama3.1-8b \
51+
tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
52+
load_parameters_path=gs://path/to/checkpoint/0/items \
53+
run_name=$WORKLOAD \
54+
base_output_directory=$OUTPUT_PATH \
55+
hf_access_token=$HF_TOKEN
7156
```
7257

73-
The overview of the demo script is as follows:
58+
The overview of what this run will do is as follows:
7459

7560
1. We load a policy model and a reference model. Both are copies of `Llama3.1-8b-Instruct`.
7661
2. Evaluate the policy model's performance on GSM8K math reasoning benchmark.

docs/tutorials/grpo_with_pathways.md

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,27 +24,28 @@ We use Tunix as the library for GRPO.
2424
And we use vLLM as the library for efficient model inference and generation.
2525

2626
Furthermore, we use Pathways for [orchestration](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro). Using Pathways, you can also run GRPO in a disaggregated mode where the trainer and the samplers run on separate meshes. Try out the following recipe on `v5p-64`. You can submit jobs to a Pathways-enabled GKE cluster.
27-
28-
## Build and Upload MaxText Docker Image with Tunix, vLLM, tpu-commons dependencies
29-
Run the following bash script to create a docker image with all the dependencies of MaxText, Tunix, vLLM and tpu-commons installed.
30-
31-
In addition to MaxText dependencies,
3227

33-
1. It installs `pip install keyring keyrings.google-artifactregistry-auth` which enables pip to authenticate with Google Artifact Registry automatically.
34-
2. Next, it installs `vLLM` for Jax and TPUs from the artifact registry `https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/`
35-
3. Then, it installs `tpu-commons` from the same artifact registry.
28+
## Create virtual environment and Install MaxText dependencies
29+
Follow instructions in [Install MaxText](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/install_maxtext.md), but
30+
we recommend creating the virtual environment outside the `maxtext` directory.
3631

32+
## Build and Upload MaxText Docker Image with Tunix, vLLM, tpu-inference dependencies
3733

38-
`tpu_commons` is the TPU backend for vLLM. You will need both libraries to run vLLM on tpus.
39-
We use the scheduler code from vLLM, and the model runner code from `tpu_commons`
34+
### Installing stable releases of tunix and vllm-tpu
35+
Run the following bash script to create a docker image with all the dependencies of MaxText, Tunix, vLLM and tpu-inference installed.
4036

37+
In addition to MaxText dependencies, it primarily installs `vllm-tpu`, which comprises [vllm](https://github.com/vllm-project/vllm) and [tpu-inference](https://github.com/vllm-project/tpu-inference), thereby providing TPU inference for vLLM with unified JAX and PyTorch support.
38+
4139
```
42-
bash docker_build_dependency_image.sh MODE=post-training
40+
bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training
4341
```
4442

45-
You can also use `bash docker_build_dependency_image.sh MODE=post-training-experimental` to try out new features via experimental dependencies such as improved pathwaysutils resharding API
43+
You can also use `bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training-experimental` to try out new features via experimental dependencies such as improved pathwaysutils resharding API
4644

45+
### Install from locally git-cloned repos
4746

47+
You can also locally git clone [tunix](https://github.com/google/tunix), [tpu-inference](https://github.com/vllm-project/tpu-inference), [vllm](https://github.com/vllm-project/vllm.git) and then use the following command to build a docker image using them:
48+
`bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training POST_TRAINING_SOURCE=local`
4849

4950
### Upload the dependency docker image along with MaxText code
5051
```
@@ -61,12 +62,12 @@ xpk workload create-pathways --workload $WORKLOAD \
6162
--project=$PROJECT_ID --priority=high \
6263
--command "HF_TOKEN=$HF_TOKEN TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' # Llama3.1-70B-Instruct
6364
python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
64-
--model_name=llama3.1-70b \
65-
--tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
66-
--load_parameters_path=gs://path/to/checkpoint/0/items \
67-
--run_name=$WORKLOAD \
68-
--base_output_directory=$OUTPUT_PATH \
69-
--hf_access_token=$HF_TOKEN"
65+
model_name=llama3.1-70b \
66+
tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
67+
load_parameters_path=gs://path/to/checkpoint/0/items \
68+
run_name=$WORKLOAD \
69+
base_output_directory=$OUTPUT_PATH \
70+
hf_access_token=$HF_TOKEN"
7071
```
7172

7273
The overview of the demo script `~/maxtext/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py` is as follows:

src/MaxText/configs/rl.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ value_proj: 'offload'
7171
checkpoint_storage_use_ocdbt: False # For Pathways
7272
checkpoint_storage_use_zarr3: False # For Pathways
7373
use_pathways: True
74+
log_period: 20
7475

7576
# ====== Debugging ======
7677
debug:

src/MaxText/examples/install_tunix_vllm_requirement.sh

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,34 +19,15 @@
1919
set -e
2020
set -x
2121

22-
python -m ensurepip --default-pip
23-
24-
pip uninstall -y jax jaxlib libtpu
25-
26-
pip install aiohttp==3.12.15
27-
28-
# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.
29-
pip install keyring keyrings.google-artifactregistry-auth
30-
31-
# Install vLLM for Jax and TPUs from the artifact registry
32-
VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \
33-
--index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
34-
--extra-index-url https://pypi.org/simple/ \
35-
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
36-
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
37-
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
38-
--find-links https://storage.googleapis.com/libtpu-wheels/index.html \
39-
--find-links https://storage.googleapis.com/libtpu-releases/index.html \
40-
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
41-
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
42-
vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu
43-
44-
# Install tpu-commons from the artifact registry
45-
pip install --no-cache-dir --pre \
46-
--index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
47-
--extra-index-url https://pypi.org/simple/ \
48-
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
49-
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
50-
tpu-commons==0.1.2
51-
52-
pip install numba==0.61.2
22+
uv pip uninstall -y jax jaxlib libtpu
23+
24+
uv pip install aiohttp==3.12.15
25+
26+
# Install vLLM for Jax and TPUs
27+
uv pip install vllm-tpu
28+
29+
uv pip install numba==0.61.2
30+
31+
uv pip install qwix==0.1.1
32+
33+
uv pip install flax==0.11.1

0 commit comments

Comments
 (0)