
Conversation

Collaborator

@NuojCheng NuojCheng commented Nov 12, 2025

Description

This PR updates the training loop to align with MaxDiffusion patterns, improving the stability of TFLOP/s/device reporting and eliminating metric outliers.

Motivation

Previously, reported step times were heavily influenced by two factors, leading to noisy benchmarking results:

  1. Profiler Overhead: The device synchronization required for clean profiling (blocking until all pending work completes) caused large step-time outliers at the steps where the profiler was activated or deactivated (see the sketch after this list).
  2. Synchronous I/O: Data loading on the main thread introduced latency spikes that were misattributed to model performance.
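For the first point, here is a minimal sketch of keeping profiler activation outside the timed window; train_step, profile_dir, and the step bounds are hypothetical names for illustration, not the PR's actual code:

import time

import jax

def run_with_profile(train_step, state, batches, profile_dir, start_step, stop_step):
  # Sketch: start/stop the profiler outside the timed region so the
  # device sync it requires never shows up as a step-time outlier.
  for step, batch in enumerate(batches):
    if step == start_step:
      jax.block_until_ready(state)           # drain pending device work first
      jax.profiler.start_trace(profile_dir)  # sync cost lands outside the timer

    t0 = time.perf_counter()
    state = train_step(state, batch)
    jax.block_until_ready(state)             # measure only this step's compute
    print(f"step {step}: {time.perf_counter() - t0:.4f}s")

    if step == stop_step:
      jax.profiler.stop_trace()
  return state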

Changes

  1. Asynchronous Data Loading: A ThreadPoolExecutor now prefetches the next batch in a background thread, so data loading overlaps with device computation.
  2. Isolated Step Timing: The timing logic now tracks p_train_step execution strictly, via an explicit block_until_ready() call on the loss metric (see the combined sketch after this list).
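A minimal sketch of how the two changes can fit together; p_train_step, the data iterator, and the metrics layout here are assumptions for illustration, not the PR's actual code:

import time
from concurrent.futures import ThreadPoolExecutor

import jax

def train_loop(p_train_step, state, data_iter, nextrng, num_steps):
  # Sketch: a single worker thread prefetches batches while the device computes.
  executor = ThreadPoolExecutor(max_workers=1)
  future = executor.submit(next, data_iter)    # prefetch the first batch
  for _ in range(num_steps):
    example_batch = future.result()            # usually already loaded
    future = executor.submit(next, data_iter)  # overlap the next load with compute

    start = time.perf_counter()
    state, metrics = p_train_step(state, example_batch, nextrng)
    # Block on the loss so the timer covers only p_train_step execution.
    jax.block_until_ready(metrics["loss"])
    print(f"step time: {time.perf_counter() - start:.4f}s")
  executor.shutdown()
  return state

A single worker keeps batches arriving in order, and future.result() only blocks when loading is slower than the step itself.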

Tests

Example log difference

Checklist

Before submitting this PR, please make sure (put an X in the square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have added necessary comments to my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@NuojCheng NuojCheng added the draft Draft PR label Nov 12, 2025
@NuojCheng NuojCheng force-pushed the chengnuojin-thread-profiler branch 3 times, most recently from 150872d to a22c9fd on November 12, 2025 17:25
@NuojCheng NuojCheng force-pushed the chengnuojin-thread-profiler branch from a22c9fd to 47163f7 on November 12, 2025 17:56
@NuojCheng NuojCheng changed the title Add subprocess on profiler and dataloader to remove outlier steps Async Data Loading & Accurate Step Timing to Exclude Profiler Overhead Nov 12, 2025
@github-actions

🤖 Hi @NuojCheng, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions bot left a comment


📋 Review Summary

This pull request effectively addresses performance metric inaccuracies by introducing asynchronous data loading and refining step timing. The use of ThreadPoolExecutor for prefetching batches is a solid approach to overlap I/O with computation. The core logic changes are sound and well-motivated.

🔍 General Feedback

  • The change to isolate p_train_step for timing purposes by using jax.block_until_ready() is a significant improvement for accurate benchmarking.
  • The code is well-structured, and the addition of asynchronous loading is handled cleanly.
  • One minor suggestion was made to add a comment for clarity.

Overall, this is a valuable enhancement for performance analysis.

jax.block_until_ready(state)  # Ensure compilation has finished.
gcs_utils.upload_dump(
    config.dump_hlo_local_dir,
    config.dump_hlo_gcs_dir,

🟢 To enhance clarity for future maintainers, it would be beneficial to add a comment explaining why jax.block_until_ready(state) is called here: it ensures that all asynchronous device operations for the current training step are complete before the step time is recorded, guaranteeing an accurate measurement of the true execution time.

Suggested change:

state, metrics = p_train_step(state, example_batch, nextrng)
# Block to get accurate step time
jax.block_until_ready(state)
