27 changes: 19 additions & 8 deletions open_instruct/grpo_fast.py
@@ -1247,9 +1247,10 @@ def __init__(
         ).remote(world_size, 0, 0, None, None)
 
         self.models.append(master_policy)
-        master_addr, master_port = ray_get_with_progress(
+        results, _ = ray_get_with_progress(
             [master_policy.get_master_addr_port.remote()], desc="Getting master address"
-        )[0]
+        )
+        (master_addr, master_port) = results[0]
 
         def get_bundle_index(rank, num_gpus_per_node):
             """given a rank and a list of num_gpus_per_node, return the index of the bundle that the rank belongs to"""
@@ -2042,7 +2043,8 @@ def create_model_and_optimizer(
         verbose=args.verbose,
     )
 
-    resume_training_step = ray_get_with_progress(inits, desc="Initializing models")[0] + 1
+    results, _ = ray_get_with_progress(inits, desc="Initializing models")
+    resume_training_step = results[0] + 1
     episode = (resume_training_step - 1) * args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout
     logger.info("======== ✅ all models and vLLM engines initialized =========")
 
@@ -2173,8 +2175,8 @@ def weight_sync_thread(
             # First get the futures
             weight_broadcast_futures: List[ray.ObjectRef] = [m.broadcast_to_vllm.remote() for m in policy_group.models]
 
-            # Wait for all weight updates to complete
-            ray_get_with_progress(
+            # Wait for all weight updates to complete and collect individual timings
+            _, actor_sync_times = ray_get_with_progress(
                 weight_broadcast_futures,
                 desc="[Weight Sync Thread] Waiting for weight updates to complete",
                 enable=args.verbose,
@@ -2184,8 +2186,17 @@
         ray.get(actor_manager.set_should_stop.remote(False))
         logger.debug("[Weight Sync Thread] Set should_stop to False after weight sync")
 
+        # Calculate distribution statistics
+        sync_time_stats = {
+            "time/weight_sync": timer.duration,
+            "time/weight_sync_mean": np.mean(actor_sync_times),
+            "time/weight_sync_min": np.min(actor_sync_times),
+            "time/weight_sync_max": np.max(actor_sync_times),
+            "time/weight_sync_median": np.median(actor_sync_times),
+        }
+
Comment on lines +2189 to +2197 (Contributor): you could consider keeping the actual list and logging a histogram

         try:
-            weight_sync_metrics_Q.put_nowait({"time/weight_sync": timer.duration})
+            weight_sync_metrics_Q.put_nowait(sync_time_stats)
         except Full:
             logger.warning("[Weight Sync Thread] weight sync metrics queue full, skipping metric")
 
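A minimal sketch of the reviewer's histogram suggestion, assuming Weights & Biases as the metrics sink; the `wandb` consumer and the extra `time/weight_sync_times` key are illustrative assumptions, not part of this PR:

```python
import numpy as np
import wandb  # assumed metrics sink; the project's actual consumer may differ

actor_sync_times = [4.2, 4.8, 5.1, 4.5]  # per-actor sync durations (seconds), as collected above

# Keep the raw list alongside the scalar aggregates so the metrics
# consumer can log the full distribution, not just its summary.
sync_time_stats = {
    "time/weight_sync_mean": float(np.mean(actor_sync_times)),
    "time/weight_sync_max": float(np.max(actor_sync_times)),
    "time/weight_sync_times": actor_sync_times,  # hypothetical extra key
}

# On the consumer side, wandb can render a list-valued metric as a histogram:
wandb.init(mode="disabled")  # disabled so this sketch runs standalone
wandb.log({"time/weight_sync_hist": wandb.Histogram(sync_time_stats["time/weight_sync_times"])})
```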
@@ -2197,7 +2208,7 @@ def generate_thread(args, vllm_engines, resume_training_step, stop_event, generate
logger.info("[Generate Thread] 🚀 Starting generation thread")
while not stop_event.is_set():
with Timer("🔥 Generation time") as timer:
processed_results = ray_get_with_progress(
processed_results, _ = ray_get_with_progress(
[engine.process_from_queue.remote(timeout=20) for engine in vllm_engines],
desc="[Generate Thread] Waiting for vLLM engines to process",
enable=args.verbose,
@@ -2235,7 +2246,7 @@ def one_training_step(
"""Train the model for one step."""
update_ref_policy_future = []
with Timer("[Main Thread] 🗡️ Training") as train_timer:
metrics_list: List[dict[str, float]] = ray_get_with_progress(
metrics_list, _ = ray_get_with_progress(
[
policy_group.models[i].train.remote(
**collated_data[i], pad_token_id=tokenizer.pad_token_id, num_mini_batches=args.num_mini_batches
18 changes: 13 additions & 5 deletions open_instruct/utils.py
@@ -80,8 +80,8 @@ def repeat_each(seq, k):
 
 def ray_get_with_progress(
     ray_refs: List[ray.ObjectRef], desc: str = "Processing", enable: bool = True, timeout: Optional[float] = None
-) -> List[Any]:
-    """Execute ray.get() with a progress bar using futures.
+):
+    """Execute ray.get() with a progress bar using futures and collect timings.
 
     Args:
         ray_refs: List of ray object references
@@ -90,23 +90,31 @@ def ray_get_with_progress(
         timeout: Optional timeout in seconds for all operations to complete
 
     Returns:
-        List of results in the same order as ray_refs
+        (results, completion_times)
+        - results: List of results in the same order as ray_refs
+        - completion_times: time from function start until each ref completed (seconds), aligned to ray_refs
 
     Raises:
         TimeoutError: If timeout is specified and operations don't complete in time
     """
+    t0 = time.perf_counter()
+
     ray_futures = [ref.future() for ref in ray_refs]
+    fut_to_idx = {f: i for i, f in enumerate(ray_futures)}
 
     results = [None] * len(ray_refs)
+    completion_times = [None] * len(ray_refs)
 
     futures_iter = futures.as_completed(ray_futures, timeout=timeout)
     if enable:
         futures_iter = tqdm(futures_iter, total=len(ray_futures), desc=desc, bar_format="{l_bar}{bar}{r_bar}\n")
 
     for future in futures_iter:
-        idx = ray_futures.index(future)
+        idx = fut_to_idx[future]
         results[idx] = future.result()
+        completion_times[idx] = time.perf_counter() - t0
 
-    return results
+    return results, completion_times
 
 
 """
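For reference, a minimal usage sketch of the updated two-value API; the `square` task is a hypothetical example, not from this PR:

```python
import time

import ray

from open_instruct.utils import ray_get_with_progress


@ray.remote
def square(x: int) -> int:
    """Hypothetical task used only to exercise the API."""
    time.sleep(0.1 * x)
    return x * x


ray.init()
refs = [square.remote(i) for i in range(4)]

# results is ordered like refs; completion_times[i] is the elapsed time
# (seconds) from the call until refs[i] finished, so the slowest worker
# shows up as max(completion_times).
results, completion_times = ray_get_with_progress(refs, desc="Squaring")
print(results)  # [0, 1, 4, 9]

# Call sites that only need one of the two values unpack with a placeholder,
# as the grpo_fast.py changes above do:
results, _ = ray_get_with_progress([square.remote(3)], desc="One more")
```

Note also that `fut_to_idx` replaces the linear `ray_futures.index(future)` lookup with a dict lookup, so matching completed futures back to their slots is O(1) per future instead of O(n).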