Change LLMRayActor to continually process individual prompts #859
base: main
Conversation
* Fixed timing.
* Cleaned up code.
* Cleaned up code.
Generally I think this is pretty good to merge, mostly just minor nits left, assuming the slowdown stuff is worked out!
vllm_top_p: float = 1.0
"""vLLM top p for nucleus sampling"""
inference_batch_size: Optional[int] = None
"""Number of inference requests to batch together for vLLM processing"""
maybe we could be more specific: (something like) "number of unique prompts sent to a single vllm engine at once"?
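For concreteness, a minimal sketch of how the suggested wording could read on the field from the diff above (the surrounding dataclass is assumed for illustration, not copied from the repo):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Args:
    # Field name taken from the diff; docstring wording is the reviewer's suggestion.
    inference_batch_size: Optional[int] = None
    """Number of unique prompts sent to a single vLLM engine at once."""
```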
finish_reasons += [finish_reasons[i] for i in sampled_indices]
print(
can we keep this logging?
open_instruct/grpo_fast.py
Outdated
raise ValueError(f"Unknown tool: {tool}")
actor_manager = vllm_utils3.ActorManager.remote() | ||
actor_manager = ray.remote(vllm_utils3.ActorManager).remote() |
Is this change necessary since ActorManager is already decorated with ray.remote?
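For context on the question, here is a minimal sketch of the two Ray instantiation patterns being compared. The classes below are stand-ins, not the ActorManager in vllm_utils3; if the real class already carries the @ray.remote decorator, wrapping it again with ray.remote(...) is redundant:

```python
import ray

ray.init(ignore_reinit_error=True)

@ray.remote
class DecoratedManager:
    """Class defined with the decorator: call .remote() on it directly."""
    def ping(self) -> str:
        return "pong"

class PlainManager:
    """Plain class: wrap with ray.remote(...) before instantiating."""
    def ping(self) -> str:
        return "pong"

decorated = DecoratedManager.remote()
wrapped = ray.remote(PlainManager).remote()

print(ray.get([decorated.ping.remote(), wrapped.ping.remote()]))  # ['pong', 'pong']
```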
if has_beaker_job:
    logger.info(f"is_beaker_job: BEAKER_JOB_ID value: {os.environ.get('BEAKER_JOB_ID')}")
return has_beaker_job
return "BEAKER_JOB_ID" in os.environ
nice!
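For reference, a minimal sketch of the simplified helper implied by the diff above; the function name is_beaker_job is inferred from the log message and may not match the repo exactly:

```python
import os

def is_beaker_job() -> bool:
    # Return the membership test directly; no intermediate variable or logging.
    return "BEAKER_JOB_ID" in os.environ
```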
Code Review for PR #859: Change LLMRayActor to continually process individual prompts

Overview

This PR implements a significant architectural change from batch-based to individual prompt processing, introducing an inflight_updates flag and achieving a reported 40% performance improvement (349.22 → 500.42 tokens/second).

✅ Strengths

Performance Improvements:
Code Quality:
Architecture:
New & improved version of #807.
Now, we process individual prompts through the queues, and do inflight weight updates (configurable with the inflight_weight_updates flag, set to False by default).

Runs with inflight_updates=False (which should recreate existing behaviour):

With inflight_updates=True:

This is 40% more efficient (tokens per second go from 349.22 -> 500.42 in the benchmark).
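A rough sketch, under stated assumptions, of the individual-prompt loop described above: an actor continually pulls single prompts from an inbound Ray queue, generates a completion, and pushes the result onward. The class, queue handling, and method names are illustrative only, not open-instruct's actual implementation:

```python
import ray
from ray.util.queue import Queue

@ray.remote
class PromptWorker:
    def __init__(self, prompts: Queue, results: Queue, inflight_updates: bool = False):
        self.prompts = prompts
        self.results = results
        self.inflight_updates = inflight_updates

    def run(self):
        while True:
            prompt = self.prompts.get()      # one prompt at a time, not a batch
            if prompt is None:               # sentinel: shut down cleanly
                break
            # With inflight_updates enabled, a weight update could be applied
            # here, between prompts, instead of waiting for a batch to drain.
            self.results.put(self._generate(prompt))

    def _generate(self, prompt: str) -> str:
        return prompt.upper()                # stand-in for a vLLM generate call
```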
Benchmark results for main at HEAD:
Benchmark results in this PR: