-
Notifications
You must be signed in to change notification settings - Fork 694
feat(SGLang): Add DP-aware routing and dp_rank propagation across Prefill/Decode #4221
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
YAMY1234
wants to merge
6
commits into
ai-dynamo:main
Choose a base branch
from
YAMY1234:dp_rank
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 5 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
2d4b3b5
dp rank routing code
YAMY1234 51bd5d7
format fix
YAMY1234 881be2f
optimize use of dp_rank/data_parallel_rank
YAMY1234 424836a
code optimize and add e2e test
YAMY1234 0624354
remove redundant logging
YAMY1234 bf1d471
small fix
YAMY1234 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -115,6 +115,14 @@ async def generate( | |
| sampling_params = self._build_sampling_params(request) | ||
| input_param = self._get_input_param(request) | ||
|
|
||
| # Extract data_parallel_rank (explicit None check to preserve dp_rank=0) | ||
| data_parallel_rank = ( | ||
| request.get("data_parallel_rank") | ||
| if "data_parallel_rank" in request | ||
| and request["data_parallel_rank"] is not None | ||
| else request.get("dp_rank") | ||
| ) | ||
|
|
||
| if self.serving_mode == DisaggregationMode.DECODE: | ||
| # request the bootstrap info from the target prefill worker | ||
| if ( | ||
|
|
@@ -124,17 +132,38 @@ async def generate( | |
| token_ids = request["token_ids"] | ||
| stream = await self.prefill_router_client.generate(token_ids) | ||
| result = await anext(stream) | ||
| ( | ||
| worker_id, | ||
| overlap, | ||
| ) = result.data() # Returns tuple (worker_id, overlap_amount) | ||
| logging.info(f"Best prefill worker ID: {worker_id}, overlap: {overlap}") | ||
| # Unpack router response: (worker_id, dp_rank, overlap_blocks) | ||
| result_data = result.data() | ||
| if len(result_data) == 3: | ||
| worker_id, prefill_dp_rank, overlap = result_data | ||
| else: | ||
| # Fallback for older router versions (2-tuple response) | ||
| worker_id, overlap = result_data | ||
| prefill_dp_rank = None | ||
| if not hasattr(self, "_dp_routing_unavailable_warned"): | ||
| logging.warning( | ||
| "Router returned 2-tuple, DP routing unavailable (update router)" | ||
| ) | ||
| self._dp_routing_unavailable_warned = True | ||
|
|
||
| # Build prefill request | ||
| prefill_request_dict = DisaggPreprocessedRequest( | ||
| request=request, | ||
| sampling_params=sampling_params, | ||
| ).model_dump() | ||
|
|
||
| # Inject dp_rank into inner request after serialization (Pydantic drops unknown fields) | ||
| if prefill_dp_rank is not None and isinstance( | ||
| prefill_request_dict.get("request"), dict | ||
| ): | ||
| # SGLang engine reads data_parallel_rank from inner request | ||
| prefill_request_dict["request"][ | ||
| "data_parallel_rank" | ||
| ] = prefill_dp_rank | ||
| logging.info(f"Routing to prefill dp_rank={prefill_dp_rank}") | ||
|
|
||
| prefill_stream = await self.prefill_client.direct( | ||
| DisaggPreprocessedRequest( | ||
| request=request, | ||
| sampling_params=sampling_params, | ||
| ).model_dump(), | ||
| prefill_request_dict, | ||
| worker_id, | ||
| ) | ||
| else: | ||
|
|
@@ -154,14 +183,29 @@ async def generate( | |
| if not bootstrap_info: | ||
| raise RuntimeError("No bootstrap info received from prefill worker") | ||
|
|
||
| decode = await self.engine.async_generate( | ||
| # Prefill and Decode must use same dp_rank for bootstrap connection | ||
| generate_kwargs = { | ||
| **input_param, | ||
| sampling_params=sampling_params, | ||
| stream=True, | ||
| bootstrap_host=bootstrap_info["bootstrap_host"], | ||
| bootstrap_port=bootstrap_info["bootstrap_port"], | ||
| bootstrap_room=bootstrap_info["bootstrap_room"], | ||
| ) | ||
| "sampling_params": sampling_params, | ||
| "stream": True, | ||
| "bootstrap_host": bootstrap_info["bootstrap_host"], | ||
| "bootstrap_port": bootstrap_info["bootstrap_port"], | ||
| "bootstrap_room": bootstrap_info["bootstrap_room"], | ||
| } | ||
|
|
||
| # Use router-selected dp_rank (fallback to request-level if not provided) | ||
| if "prefill_dp_rank" in locals() and prefill_dp_rank is not None: | ||
|
||
| effective_dp_rank = prefill_dp_rank | ||
| elif data_parallel_rank is not None: | ||
| effective_dp_rank = data_parallel_rank | ||
| else: | ||
| effective_dp_rank = None | ||
|
|
||
| if effective_dp_rank is not None: | ||
| generate_kwargs["data_parallel_rank"] = effective_dp_rank | ||
| logging.debug(f"Using dp_rank={effective_dp_rank} for decode") | ||
|
|
||
| decode = await self.engine.async_generate(**generate_kwargs) | ||
|
|
||
| if self.skip_tokenizer_init: | ||
| async for out in self._process_token_stream(decode, context): | ||
|
|
@@ -170,11 +214,18 @@ async def generate( | |
| async for out in self._process_text_stream(decode, context): | ||
| yield out | ||
| else: | ||
| agg = await self.engine.async_generate( | ||
| # Aggregated mode | ||
| generate_kwargs = { | ||
| **input_param, | ||
| sampling_params=sampling_params, | ||
| stream=True, | ||
| ) | ||
| "sampling_params": sampling_params, | ||
| "stream": True, | ||
| } | ||
|
|
||
| if data_parallel_rank is not None: | ||
| generate_kwargs["data_parallel_rank"] = data_parallel_rank | ||
| logging.debug(f"Using dp_rank={data_parallel_rank} for aggregated mode") | ||
|
|
||
| agg = await self.engine.async_generate(**generate_kwargs) | ||
| if self.skip_tokenizer_init: | ||
| async for out in self._process_token_stream(agg, context): | ||
| yield out | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are adding a good bit of logic to the critical path of a request. Do you mind running some benchmarks (maybe a 4 GPU benchmark) and ramping up to high concurrency?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah sure! I used
prefix_ratio_benchmark.pyand conducted some benchmark experiments~ Already put the results in the PR Description