From 78057df03ad33e89da28d687a171c37ede573957 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 23 Sep 2025 16:02:13 -0600 Subject: [PATCH 1/3] Added timing for individual weight syncs --- open_instruct/grpo_fast.py | 27 +++++++++++++++++++-------- open_instruct/utils.py | 18 +++++++++++++----- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index f6f60b9e7..4db909a6e 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1247,9 +1247,9 @@ def __init__( ).remote(world_size, 0, 0, None, None) self.models.append(master_policy) - master_addr, master_port = ray_get_with_progress( + (master_addr, master_port), _ = ray_get_with_progress( [master_policy.get_master_addr_port.remote()], desc="Getting master address" - )[0] + ) def get_bundle_index(rank, num_gpus_per_node): """given a rank and a list of num_gpus_per_node, return the index of the bundle that the rank belongs to""" @@ -2039,7 +2039,8 @@ def create_model_and_optimizer( verbose=args.verbose, ) - resume_training_step = ray_get_with_progress(inits, desc="Initializing models")[0] + 1 + results, _ = ray_get_with_progress(inits, desc="Initializing models") + resume_training_step = results[0] + 1 episode = (resume_training_step - 1) * args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout logger.info("======== ✅ all models and vLLM engines initialized =========") @@ -2169,19 +2170,29 @@ def weight_sync_thread( # First get the futures weight_broadcast_futures: List[ray.ObjectRef] = [m.broadcast_to_vllm.remote() for m in policy_group.models] - # Wait for all weight updates to complete - ray_get_with_progress( + # Wait for all weight updates to complete and collect individual timings + _, actor_sync_times = ray_get_with_progress( weight_broadcast_futures, desc="[Weight Sync Thread] Waiting for weight updates to complete", enable=args.verbose, + collect_timings=True, ) # Allow actors to resume 
ray.get(actor_manager.set_should_stop.remote(False)) logger.debug("[Weight Sync Thread] Set should_stop to False after weight sync") + # Calculate distribution statistics + sync_time_stats = { + "time/weight_sync": timer.duration, + "time/weight_sync_mean": np.mean(actor_sync_times), + "time/weight_sync_min": np.min(actor_sync_times), + "time/weight_sync_max": np.max(actor_sync_times), + "time/weight_sync_median": np.median(actor_sync_times), + } + try: - weight_sync_metrics_Q.put_nowait({"time/weight_sync": timer.duration}) + weight_sync_metrics_Q.put_nowait(sync_time_stats) except Full: logger.warning("[Weight Sync Thread] weight sync metrics queue full, skipping metric") @@ -2193,7 +2204,7 @@ def generate_thread(args, vllm_engines, resume_training_step, stop_event, genera logger.info("[Generate Thread] 🚀 Starting generation thread") while not stop_event.is_set(): with Timer("🔥 Generation time") as timer: - processed_results = ray_get_with_progress( + processed_results, _ = ray_get_with_progress( [engine.process_from_queue.remote(timeout=20) for engine in vllm_engines], desc="[Generate Thread] Waiting for vLLM engines to process", enable=args.verbose, @@ -2229,7 +2240,7 @@ def one_training_step( """Train the model for one step.""" update_ref_policy_future = [] with Timer("[Main Thread] 🗡️ Training") as train_timer: - metrics_list: List[dict[str, float]] = ray_get_with_progress( + metrics_list, _ = ray_get_with_progress( [ policy_group.models[i].train.remote( **collated_data[i], pad_token_id=tokenizer.pad_token_id, num_mini_batches=args.num_mini_batches diff --git a/open_instruct/utils.py b/open_instruct/utils.py index 2c6c7d345..bfd01e651 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -80,8 +80,8 @@ def repeat_each(seq, k): def ray_get_with_progress( ray_refs: List[ray.ObjectRef], desc: str = "Processing", enable: bool = True, timeout: Optional[float] = None -) -> List[Any]: - """Execute ray.get() with a progress bar using futures. 
+): + """Execute ray.get() with a progress bar using futures and collect timings. Args: ray_refs: List of ray object references @@ -90,23 +90,31 @@ def ray_get_with_progress( timeout: Optional timeout in seconds for all operations to complete Returns: - List of results in the same order as ray_refs + (results, completion_times) + - results: List of results in the same order as ray_refs + - completion_times: time from function start until each ref completed (seconds), aligned to ray_refs Raises: TimeoutError: If timeout is specified and operations don't complete in time """ + t0 = time.perf_counter() + ray_futures = [ref.future() for ref in ray_refs] + fut_to_idx = {f: i for i, f in enumerate(ray_futures)} + results = [None] * len(ray_refs) + completion_times = [None] * len(ray_refs) futures_iter = futures.as_completed(ray_futures, timeout=timeout) if enable: futures_iter = tqdm(futures_iter, total=len(ray_futures), desc=desc, bar_format="{l_bar}{bar}{r_bar}\n") for future in futures_iter: - idx = ray_futures.index(future) + idx = fut_to_idx[future] results[idx] = future.result() + completion_times[idx] = time.perf_counter() - t0 - return results + return results, completion_times """ From 543ddc5a20378a4e0f60dffceb74c77cc7391b67 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 23 Sep 2025 16:07:46 -0600 Subject: [PATCH 2/3] Fix master address/port unpacking from ray_get_with_progress results --- open_instruct/grpo_fast.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 4db909a6e..40886594d 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1247,9 +1247,10 @@ def __init__( ).remote(world_size, 0, 0, None, None) self.models.append(master_policy) - (master_addr, master_port), _ = ray_get_with_progress( + results, _ = ray_get_with_progress( [master_policy.get_master_addr_port.remote()], desc="Getting master address" ) + (master_addr, master_port) = results[0] def get_bundle_index(rank, num_gpus_per_node): 
"""given a rank and a list of num_gpus_per_node, return the index of the bundle that the rank belongs to""" From 4426c41e1eb388a15e26dee0b5d415ea952bb1bf Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 23 Sep 2025 16:13:05 -0600 Subject: [PATCH 3/3] Remove unsupported collect_timings argument from weight sync call --- open_instruct/grpo_fast.py | 1 - 1 file changed, 1 deletion(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 40886594d..1647b3b2d 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -2176,7 +2176,6 @@ def weight_sync_thread( weight_broadcast_futures, desc="[Weight Sync Thread] Waiting for weight updates to complete", enable=args.verbose, - collect_timings=True, ) # Allow actors to resume