Skip to content

Commit 9f266a6

Browse files
committed
updates
1 parent ce71603 commit 9f266a6

File tree

9 files changed

+154
-107
lines changed

9 files changed

+154
-107
lines changed
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
ARG BASEIMAGE
16+
FROM ${BASEIMAGE}
17+
ARG MODE
18+
ENV MODE=$MODE
19+
20+
RUN echo "Installing GRPO dependencies (vLLM, tpu-inference) with MODE=${MODE}"
21+
RUN pip uninstall -y jax jaxlib libtpu
22+
23+
RUN pip install aiohttp==3.12.15
24+
25+
# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.
26+
RUN pip install keyring keyrings.google-artifactregistry-auth
27+
28+
RUN pip install numba==0.61.2
29+
30+
COPY tunix /tunix
31+
RUN pip install -e /tunix --no-cache-dir
32+
33+
34+
COPY vllm /vllm
35+
RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir --pre \
36+
--extra-index-url https://pypi.org/simple/ \
37+
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
38+
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
39+
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
40+
--find-links https://storage.googleapis.com/libtpu-wheels/index.html \
41+
--find-links https://storage.googleapis.com/libtpu-releases/index.html \
42+
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
43+
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
44+
45+
46+
COPY tpu-inference /tpu-inference
47+
RUN pip install -e /tpu-inference --no-cache-dir --pre \
48+
--extra-index-url https://pypi.org/simple/ \
49+
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
50+
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
51+
52+
53+
RUN if [ "$MODE" = "post-training-experimental" ]; then \
54+
echo "MODE=post-training-experimental: Re-installing JAX/libtpu"; \
55+
pip uninstall -y jax jaxlib libtpu && \
56+
pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
57+
pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
58+
fi

dependencies/scripts/docker_build_dependency_image.sh

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# you may not use this file except in compliance with the License.
77
# You may obtain a copy of the License at
88
#
9-
# https://www.apache.org/licenses/LICENSE-2.0
9+
# https://www.apache.org/licenses/LICENSE-2.0
1010
#
1111
# Unless required by applicable law or agreed to in writing, software
1212
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -20,14 +20,15 @@
2020
# bash docker_build_dependency_image.sh MODE=nightly
2121
# bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.13
2222
# Nightly build with JAX_VERSION for GPUs. Available versions listed at https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax:
23-
# bash docker_build_dependency_image.sh DEVICE=gpu MODE=nightly JAX_VERSION=0.4.36.dev20241109 # Note: this sets both jax-nightly and jaxlib-nightly
23+
# bash docker_build_dependency_image.sh DEVICE=gpu MODE=nightly JAX_VERSION=0.4.36.dev20241109 # Note: this sets both jax-nightly and jaxlib-nightly
2424
# MODE=custom_wheels is the same as nightly except that it reinstalls any
2525
# additional wheels that are present in the maxtext directory.
2626
# The main use case is to install custom jax or jaxlib wheels but it also
2727
# works with any custom wheels.
2828
# bash docker_build_dependency_image.sh MODE=custom_wheels
2929

3030
# bash docker_build_dependency_image.sh MODE=post-training
31+
# bash docker_build_dependency_image.sh MODE=post-training POST_TRAINING_SOURCE=local
3132

3233
if [ "${BASH_SOURCE-}" ]; then
3334
this_file="${BASH_SOURCE[0]}"
@@ -97,6 +98,12 @@ if [[ -z ${DEVICE} ]]; then
9798
echo "Default DEVICE=${DEVICE}"
9899
fi
99100

101+
# New flag for post-training source
102+
if [[ -z ${POST_TRAINING_SOURCE} ]]; then
103+
export POST_TRAINING_SOURCE=remote # Default to the original Dockerfile
104+
echo "Default POST_TRAINING_SOURCE=${POST_TRAINING_SOURCE}"
105+
fi
106+
100107
# Function to build with MODE=jax_ai_image
101108
build_ai_image() {
102109
if [[ -z ${BASEIMAGE+x} ]]; then
@@ -171,24 +178,34 @@ if [[ ${INSTALL_POST_TRAINING} -eq 1 ]] ; then
171178
exit 1
172179
fi
173180

174-
# # To install tpu_commons from a local path, we copy it into the build context, excluding __pycache__.
175-
# # This assumes vllm, tunix, tpu_commons is a sibling directory to the current one (maxtext).
176-
# rsync -a --exclude='__pycache__' ../tpu_commons .
177-
# # To install vllm from a local path, we copy it into the build context, excluding __pycache__.
178-
# # This assumes vllm is a sibling directory to the current one (maxtext).
179-
# rsync -a --exclude='__pycache__' ../vllm .
180-
181-
# rsync -a --exclude='__pycache__' ../tunix .
182-
183-
# # The cleanup is set to run even if the build fails to remove the copied directory.
184-
# trap "rm -rf ./tpu_commons ./vllm ./tunix" EXIT INT TERM
185-
186-
docker build \
187-
--network host \
188-
--build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \
189-
--build-arg MODE=${MODE} \
190-
-f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile' \
191-
-t ${LOCAL_IMAGE_NAME} .
181+
DOCKERFILE_NAME=""
182+
if [[ ${POST_TRAINING_SOURCE} == "local" ]] ; then
183+
184+
# To install tpu-inference from a local path, we copy it into the build context, excluding __pycache__.
185+
# This assumes vllm, tunix, and tpu-inference are sibling directories of the current one (maxtext).
186+
rsync -a --exclude='__pycache__' ../tpu-inference .
187+
# To install vllm from a local path, we copy it into the build context, excluding __pycache__.
188+
# This assumes vllm is a sibling directory to the current one (maxtext).
189+
rsync -a --exclude='__pycache__' ../vllm .
190+
191+
rsync -a --exclude='__pycache__' ../tunix .
192+
193+
# The cleanup is set to run even if the build fails to remove the copied directory.
194+
trap "rm -rf ./tpu-inference ./vllm ./tunix" EXIT INT TERM
195+
196+
DOCKERFILE_NAME='maxtext_post_training_local_dependencies.Dockerfile'
197+
echo "Using local post-training dependencies Dockerfile: $DOCKERFILE_NAME"
198+
else
199+
DOCKERFILE_NAME='maxtext_post_training_dependencies.Dockerfile'
200+
echo "Using remote post-training dependencies Dockerfile: $DOCKERFILE_NAME"
201+
fi
202+
203+
docker build \
204+
--network host \
205+
--build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \
206+
--build-arg MODE=${MODE} \
207+
-f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/'"$DOCKERFILE_NAME" \
208+
-t ${LOCAL_IMAGE_NAME} .
192209
fi
193210

194211
if [[ ${CUSTOM_JAX} -eq 1 ]] ; then

docs/tutorials/grpo.md

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -27,33 +27,19 @@ In this tutorial we use a single host TPUVM such as `v6e-8/v5p-8`. Let's get sta
2727

2828
## Setup your virtual environment
2929

30-
### Create a Python3.12 venv if not already pre-existing and install MaxText dependencies
31-
```sh
32-
bash tools/setup/setup.sh
33-
```
30+
### Create a virtual environment and install MaxText dependencies
31+
Follow instructions in [Install MaxText](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/install_maxtext.md)
3432

35-
### Activate your virtual environment (Skip if you have already done this for running `bash tools/setup/setup.sh` )
36-
```
37-
# Replace with your virtual environment name if not using this default name
38-
venv_name="maxtext_venv"
39-
source ~/$venv_name/bin/activate
40-
```
33+
## vLLM and tpu-inference installations
4134

42-
## vLLM and tpu-commons installations
43-
44-
Next, run the following bash script to get all the necessary installations inside the virtual environment.
35+
Next, run the following bash script to get all the necessary installations inside the virtual environment (e.g., `maxtext_venv`).
4536
This will take a few minutes. Follow along with the installation logs and look out for any issues!
4637

4738
```
4839
bash ~/maxtext/src/MaxText/examples/install_tunix_vllm_requirement.sh
4940
```
5041

51-
1. It installs `pip install keyring keyrings.google-artifactregistry-auth` which enables pip to authenticate with Google Artifact Registry automatically.
52-
2. Next, it installs `vLLM` for Jax and TPUs from the artifact registry `https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/`
53-
3. Then, it installs `tpu-commons` from the same artifact registry.
54-
55-
`tpu_commons` is the TPU backend for vLLM. You will need both libraries to run vLLM on tpus.
56-
We use the scheduler code from vLLM, and the model runner code from `tpu_commons`
42+
Primarily, it installs `vllm-tpu`, which is [vllm](https://github.com/vllm-project/vllm) plus [tpu-inference](https://github.com/vllm-project/tpu-inference), thereby providing TPU inference for vLLM with unified JAX and PyTorch support.
5743

5844

5945
## Run GRPO
@@ -62,15 +48,15 @@ Finally, run the command
6248

6349
```
6450
python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
65-
--model_name=llama3.1-8b \
66-
--tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
67-
--load_parameters_path=gs://path/to/checkpoint/0/items \
68-
--run_name=$WORKLOAD \
69-
--base_output_directory=$OUTPUT_PATH \
70-
--hf_access_token=$HF_TOKEN
51+
model_name=llama3.1-8b \
52+
tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
53+
load_parameters_path=gs://path/to/checkpoint/0/items \
54+
run_name=$WORKLOAD \
55+
base_output_directory=$OUTPUT_PATH \
56+
hf_access_token=$HF_TOKEN
7157
```
7258

73-
The overview of the demo script is as follows:
59+
The overview of what this run will do is as follows:
7460

7561
1. We load a policy model and a reference model. Both are copies of `Llama3.1-8b-Instruct`.
7662
2. Evaluate the policy model's performance on GSM8K math reasoning benchmark.

docs/tutorials/grpo_with_pathways.md

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,26 +25,23 @@ And we use vLLM as the library for efficient model inference and generation.
2525

2626
Furthermore, we use Pathways for [orchestration](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro). Using Pathways, you can also run GRPO in a disaggregated mode where the trainer and the samplers are running on separate mesh. Try out the following recipe `v5p-64`. You can submit jobs to a Pathways enabled GKE cluster.
2727

28-
## Build and Upload MaxText Docker Image with Tunix, vLLM, tpu-commons dependencies
29-
Run the following bash script to create a docker image with all the dependencies of MaxText, Tunix, vLLM and tpu-commons installed.
30-
31-
In addition to MaxText dependencies,
32-
33-
1. It installs `pip install keyring keyrings.google-artifactregistry-auth` which enables pip to authenticate with Google Artifact Registry automatically.
34-
2. Next, it installs `vLLM` for Jax and TPUs from the artifact registry `https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/`
35-
3. Then, it installs `tpu-commons` from the same artifact registry.
28+
## Build and Upload MaxText Docker Image with Tunix, vLLM, tpu-inference dependencies
3629

30+
### Installing stable releases of tunix and vllm-tpu
31+
Run the following bash script to create a docker image with all the dependencies of MaxText, Tunix, vLLM and tpu-inference installed.
3732

38-
`tpu_commons` is the TPU backend for vLLM. You will need both libraries to run vLLM on tpus.
39-
We use the scheduler code from vLLM, and the model runner code from `tpu_commons`
40-
33+
In addition to the MaxText dependencies, it primarily installs `vllm-tpu`, which is [vllm](https://github.com/vllm-project/vllm) plus [tpu-inference](https://github.com/vllm-project/tpu-inference), thereby providing TPU inference for vLLM with unified JAX and PyTorch support.
34+
4135
```
42-
bash docker_build_dependency_image.sh MODE=post-training
36+
bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training
4337
```
4438

45-
You can also use `bash docker_build_dependency_image.sh MODE=post-training-experimental` to try out new features via experimental dependencies such as improved pathwaysutils resharding API
39+
You can also use `bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training-experimental` to try out new features via experimental dependencies such as improved pathwaysutils resharding API
4640

41+
### Install from locally cloned git repositories
4742

43+
You can also locally git clone [tunix](https://github.com/google/tunix), [tpu-inference](https://github.com/vllm-project/tpu-inference), [vllm](https://github.com/vllm-project/vllm.git) and then use the following command to build a docker image using them:
44+
`bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training POST_TRAINING_SOURCE=local`
4845

4946
### Upload the dependency docker image along with MaxText code
5047
```
@@ -61,12 +58,12 @@ xpk workload create-pathways --workload $WORKLOAD \
6158
--project=$PROJECT_ID --priority=high \
6259
--command "HF_TOKEN=$HF_TOKEN TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' # Llama3.1-70B-Instruct
6360
python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
64-
--model_name=llama3.1-70b \
65-
--tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
66-
--load_parameters_path=gs://path/to/checkpoint/0/items \
67-
--run_name=$WORKLOAD \
68-
--base_output_directory=$OUTPUT_PATH \
69-
--hf_access_token=$HF_TOKEN"
61+
model_name=llama3.1-70b \
62+
tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
63+
load_parameters_path=gs://path/to/checkpoint/0/items \
64+
run_name=$WORKLOAD \
65+
base_output_directory=$OUTPUT_PATH \
66+
hf_access_token=$HF_TOKEN"
7067
```
7168

7269
The overview of the demo script `~/maxtext/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py` is as follows:

src/MaxText/configs/rl.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ value_proj: 'offload'
7171
checkpoint_storage_use_ocdbt: False # For Pathways
7272
checkpoint_storage_use_zarr3: False # For Pathways
7373
use_pathways: True
74+
log_period: 20
7475

7576
# ====== Debugging ======
7677
debug:

src/MaxText/examples/install_tunix_vllm_requirement.sh

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,28 +25,11 @@ pip uninstall -y jax jaxlib libtpu
2525

2626
pip install aiohttp==3.12.15
2727

28-
# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.
29-
pip install keyring keyrings.google-artifactregistry-auth
30-
31-
# Install vLLM for Jax and TPUs from the artifact registry
32-
VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \
33-
--index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
34-
--extra-index-url https://pypi.org/simple/ \
35-
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
36-
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
37-
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
38-
--find-links https://storage.googleapis.com/libtpu-wheels/index.html \
39-
--find-links https://storage.googleapis.com/libtpu-releases/index.html \
40-
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
41-
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
42-
vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu
43-
44-
# Install tpu-commons from the artifact registry
45-
pip install --no-cache-dir --pre \
46-
--index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
47-
--extra-index-url https://pypi.org/simple/ \
48-
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
49-
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
50-
tpu-commons==0.1.2
28+
# Install vLLM for Jax and TPUs
29+
pip install vllm-tpu
5130

5231
pip install numba==0.61.2
32+
33+
pip install qwix==0.1.1
34+
35+
pip install flax==0.11.1

src/MaxText/rl/evaluate_rl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def generate_responses(
6868
responses = rl_cluster.rollout.generate(
6969
prompts,
7070
rollout_config=RolloutConfig(
71-
max_tokens_to_generate=tmvp_config.max_target_length,
71+
max_tokens_to_generate=tmvp_config.max_target_length - tmvp_config.max_prefill_predict_length,
7272
temperature=eval_strategy["eval_temperature"],
7373
top_k=eval_strategy["eval_top_k"],
7474
top_p=eval_strategy["eval_top_p"],

0 commit comments

Comments
 (0)