Commit d972e35

Merge pull request #2335 from AI-Hypercomputer:tunix_fix
PiperOrigin-RevId: 806064971
2 parents 1781eb2 + 1196b26

3 files changed: +3, -4 lines

pyproject.toml

1 addition, 1 deletion

@@ -236,7 +236,7 @@ dependencies = [
   "tqdm>=4.67.1",
   "transformers>=4.56.1",
   "treescope>=0.1.10",
-  "tunix @ https://github.com/google/tunix/archive/4c5561be36d8a2f1f0858c2685554ca4e1a65fd2.zip",
+  "tunix @ https://github.com/google/tunix/archive/d770659621eb16ef6588268e26fa687fa068df20.zip",
   "typing-extensions>=4.14.1",
   "typing-inspection>=0.4.1",
   "tzdata>=2025.2",

requirements.txt

1 addition, 1 deletion

@@ -36,7 +36,7 @@ tensorflow-text
 tensorflow
 tiktoken
 transformers
-tunix @ https://github.com/google/tunix/archive/4c5561be36d8a2f1f0858c2685554ca4e1a65fd2.zip
+tunix @ https://github.com/google/tunix/archive/d770659621eb16ef6588268e26fa687fa068df20.zip
 google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/daedc21c393f23449fb54ddc4f75fca34348ea9c.zip
 mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
 qwix @ https://github.com/google/qwix/archive/f2fd7b9114ff8d09e5b0131a453351578502da8a.zip
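
Note: pyproject.toml and requirements.txt pin tunix to the same commit archive via a PEP 508 direct URL reference, so the two pins must be bumped together. For reference, pip can install such a pin directly; the command below simply restates the new pin from this diff:

    pip install "tunix @ https://github.com/google/tunix/archive/d770659621eb16ef6588268e26fa687fa068df20.zip"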

src/MaxText/examples/grpo_llama3_demo.py

1 addition, 2 deletions

@@ -474,7 +474,6 @@ def get_ref_maxtext_model(config):
   if DEBUG:
     print("Model initialized successfully")
     print(f"Model mesh shape: {mesh_policy.shape}")
-    print(f"Model config: {model_config_policy}")

   # Sanity check that weights are loaded correctly
   _maxtext_state_flatten = nnx.state(llama3_1_8b_policy).flat_state()

@@ -955,7 +954,7 @@ def evaluate(
   # verify if vllm sampler works
   output = rl_cluster.rollout.generate(
       ["The capital of France is"],
-      rollout_config=RolloutConfig(n=1, max_tokens_to_generate=64, temperature=0.1),
+      rollout_config=RolloutConfig(max_tokens_to_generate=64, temperature=0.1),
   )

   print(f"Output: {output}")
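
Aside: the demo change tracks the tunix bump above. The old call passed n=1 to RolloutConfig explicitly; the new call omits it, presumably because the updated tunix either dropped that parameter or makes one sample per prompt the default. A minimal sketch of the two forms, using only names visible in this diff (the RolloutConfig import path is not shown here and is an assumption):

    # Assumed import path -- only the RolloutConfig name appears in this diff.
    from tunix.rl import RolloutConfig

    # Old pin (tunix @ 4c5561b): sample count requested explicitly.
    # rollout_config = RolloutConfig(n=1, max_tokens_to_generate=64, temperature=0.1)

    # New pin (tunix @ d770659): rely on the default sample count.
    rollout_config = RolloutConfig(max_tokens_to_generate=64, temperature=0.1)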
