Changes from 5 commits
12 changes: 4 additions & 8 deletions components/src/dynamo/router/__main__.py
@@ -124,21 +124,17 @@ async def generate(self, request):
         yield llm_engine_output

     async def best_worker_id(self, token_ids, router_config_override=None):
-        """
-        Get the best worker ID for a given set of tokens without actually routing.
+        """Get the best worker for given tokens without routing.

-        This method returns the worker ID that would be selected based on KV cache
-        overlap, but does NOT actually route the request or update router states.
-        It's useful for debugging, monitoring, or implementing custom routing logic.
+        Returns (worker_id, dp_rank, overlap_blocks) based on KV cache overlap.
+        Does NOT update router states - useful for preview/debugging.
         """
         if self.kv_push_router is None:
             logger.error("KvPushRouter not initialized - cannot get best worker")
             raise RuntimeError("Router not initialized")

-        result = await self.kv_push_router.best_worker_id(
+        result = await self.kv_push_router.best_worker(
             token_ids, router_config_override
         )

         yield result
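Per the new docstring, the handler now yields a 3-tuple. A minimal sketch of previewing a routing decision through this endpoint (hypothetical usage; `router` stands for an instance of the handler above):

    async def preview_route(router, token_ids):
        # best_worker_id is an async generator: it yields the would-be routing
        # decision without dispatching the request or mutating router state
        async for result in router.best_worker_id(token_ids):
            worker_id, dp_rank, overlap_blocks = result
            print(f"would pick worker={worker_id}, dp_rank={dp_rank}, "
                  f"overlap={overlap_blocks} blocks")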
18 changes: 16 additions & 2 deletions components/src/dynamo/sglang/main.py
@@ -110,8 +110,7 @@ async def init(runtime: DistributedRuntime, config: Config):
     health_check_payload = SglangHealthCheckPayload(engine).to_dict()

     try:
-        # Start endpoint immediately and register model concurrently
-        # Requests queue until ready_event is set (TODO: Part of new PR)
+        # Start endpoint and register model concurrently
         await asyncio.gather(
             generate_endpoint.serve_endpoint(
                 handler.generate,
@@ -161,6 +160,21 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):

     health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict()

+    # Register Prefill to expose dp_size to Router
+    try:
+        await register_llm_with_readiness_gate(
+            engine,
+            generate_endpoint,
+            server_args,
+            dynamo_args,
+            input_type=ModelInput.Tokens,
+            output_type=ModelType.Chat | ModelType.Completions,
+            readiness_gate=None,
+        )
+    except Exception as e:
+        logging.error(f"Failed to register prefill worker: {e}")
+        raise
+
     tasks = [
         generate_endpoint.serve_endpoint(
             handler.generate,
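Both init paths serve the endpoint and register the model concurrently. A stripped-down illustration of the asyncio.gather pattern used here (stand-in coroutines, not the real Dynamo APIs):

    import asyncio

    async def serve_endpoint():
        await asyncio.sleep(0.1)  # stand-in for generate_endpoint.serve_endpoint(...)

    async def register_model():
        await asyncio.sleep(0.1)  # stand-in for register_llm_with_readiness_gate(...)

    async def init():
        # run both concurrently; if either raises, gather propagates the exception
        await asyncio.gather(serve_endpoint(), register_model())

    asyncio.run(init())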
6 changes: 6 additions & 0 deletions components/src/dynamo/sglang/register.py
@@ -94,6 +94,12 @@ async def _get_runtime_config(
     if max_prefill_tokens:
         runtime_config.max_num_batched_tokens = max_prefill_tokens

+    # Expose data_parallel_size for DP-aware routing
+    dp_size = getattr(server_args, "dp_size", 1)
+    runtime_config.data_parallel_size = dp_size
+    if dp_size > 1:
+        logging.info(f"Registering with data_parallel_size={dp_size}")
+
     try:
         # Try to check if the engine has a scheduler attribute with the computed values
         if hasattr(engine, "scheduler_info") and engine.scheduler_info is not None:
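Once data_parallel_size is part of the runtime config, a DP-aware router can treat every (worker_id, dp_rank) pair as its own routing target. A hedged sketch of that bookkeeping (hypothetical structure and values, not the actual router code):

    # worker_id -> data_parallel_size, as reported at registration (hypothetical values)
    registered = {"worker-a": 2, "worker-b": 1}

    # enumerate every routable (worker_id, dp_rank) slot
    slots = [
        (worker_id, dp_rank)
        for worker_id, dp_size in registered.items()
        for dp_rank in range(dp_size)
    ]
    print(slots)  # [('worker-a', 0), ('worker-a', 1), ('worker-b', 0)]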
@@ -115,6 +115,14 @@ async def generate(
         sampling_params = self._build_sampling_params(request)
         input_param = self._get_input_param(request)

+        # Extract data_parallel_rank (explicit None check to preserve dp_rank=0)
+        data_parallel_rank = (
+            request.get("data_parallel_rank")
+            if "data_parallel_rank" in request
+            and request["data_parallel_rank"] is not None
+            else request.get("dp_rank")
+        )
+
         if self.serving_mode == DisaggregationMode.DECODE:
             # request the bootstrap info from the target prefill worker
             if (
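The explicit None check above is load-bearing: dp_rank 0 is falsy in Python, so a terser `or` fallback would silently discard it. A quick illustration (plain Python, toy request dict):

    request = {"data_parallel_rank": 0}

    # naive fallback: `or` treats rank 0 as missing and falls through
    naive = request.get("data_parallel_rank") or request.get("dp_rank")
    print(naive)  # None — dp_rank=0 is lost

    # explicit None check, as in the diff above, preserves rank 0
    explicit = (
        request.get("data_parallel_rank")
        if request.get("data_parallel_rank") is not None
        else request.get("dp_rank")
    )
    print(explicit)  # 0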
@@ -124,17 +132,38 @@
                 token_ids = request["token_ids"]
                 stream = await self.prefill_router_client.generate(token_ids)
                 result = await anext(stream)
-                (
-                    worker_id,
-                    overlap,
-                ) = result.data()  # Returns tuple (worker_id, overlap_amount)
-                logging.info(f"Best prefill worker ID: {worker_id}, overlap: {overlap}")
+                # Unpack router response: (worker_id, dp_rank, overlap_blocks)
+                result_data = result.data()
Contributor:
We are adding a good bit of logic to the critical path of a request. Do you mind running some benchmarks (maybe a 4 GPU benchmark) and ramping up to high concurrency?

Contributor Author:
Sure! I used prefix_ratio_benchmark.py and ran some benchmark experiments; the results are in the PR description.

+                if len(result_data) == 3:
+                    worker_id, prefill_dp_rank, overlap = result_data
+                else:
+                    # Fallback for older router versions (2-tuple response)
+                    worker_id, overlap = result_data
+                    prefill_dp_rank = None
+                    if not hasattr(self, "_dp_routing_unavailable_warned"):
+                        logging.warning(
+                            "Router returned 2-tuple, DP routing unavailable (update router)"
+                        )
+                        self._dp_routing_unavailable_warned = True
+
+                # Build prefill request
+                prefill_request_dict = DisaggPreprocessedRequest(
+                    request=request,
+                    sampling_params=sampling_params,
+                ).model_dump()
+
+                # Inject dp_rank into inner request after serialization (Pydantic drops unknown fields)
+                if prefill_dp_rank is not None and isinstance(
+                    prefill_request_dict.get("request"), dict
+                ):
+                    # SGLang engine reads data_parallel_rank from inner request
+                    prefill_request_dict["request"][
+                        "data_parallel_rank"
+                    ] = prefill_dp_rank
+                    logging.info(f"Routing to prefill dp_rank={prefill_dp_rank}")
+
                 prefill_stream = await self.prefill_client.direct(
-                    DisaggPreprocessedRequest(
-                        request=request,
-                        sampling_params=sampling_params,
-                    ).model_dump(),
+                    prefill_request_dict,
                     worker_id,
                 )
             else:
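The post-serialization injection above exists because Pydantic silently drops fields that are not declared on the model. A minimal reproduction (hypothetical stand-in models for DisaggPreprocessedRequest; Pydantic v2 defaults, extra="ignore"):

    from pydantic import BaseModel

    class Inner(BaseModel):
        token_ids: list[int]

    class DisaggRequest(BaseModel):
        request: Inner

    # the extra key is ignored during validation
    req = DisaggRequest(request={"token_ids": [1, 2], "data_parallel_rank": 3})
    print(req.model_dump())  # {'request': {'token_ids': [1, 2]}}

    # hence the PR injects the field into the serialized dict instead
    payload = req.model_dump()
    payload["request"]["data_parallel_rank"] = 3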
@@ -154,14 +183,29 @@
             if not bootstrap_info:
                 raise RuntimeError("No bootstrap info received from prefill worker")

-            decode = await self.engine.async_generate(
+            # Prefill and Decode must use same dp_rank for bootstrap connection
+            generate_kwargs = {
                 **input_param,
-                sampling_params=sampling_params,
-                stream=True,
-                bootstrap_host=bootstrap_info["bootstrap_host"],
-                bootstrap_port=bootstrap_info["bootstrap_port"],
-                bootstrap_room=bootstrap_info["bootstrap_room"],
-            )
+                "sampling_params": sampling_params,
+                "stream": True,
+                "bootstrap_host": bootstrap_info["bootstrap_host"],
+                "bootstrap_port": bootstrap_info["bootstrap_port"],
+                "bootstrap_room": bootstrap_info["bootstrap_room"],
+            }

+            # Use router-selected dp_rank (fallback to request-level if not provided)
+            if "prefill_dp_rank" in locals() and prefill_dp_rank is not None:
Contributor:
What is locals()?

Contributor Author:
Fixed! I initialized prefill_dp_rank = None at the start of the branch to avoid using locals() (see the sketch after this hunk).

+                effective_dp_rank = prefill_dp_rank
+            elif data_parallel_rank is not None:
+                effective_dp_rank = data_parallel_rank
+            else:
+                effective_dp_rank = None
+
+            if effective_dp_rank is not None:
+                generate_kwargs["data_parallel_rank"] = effective_dp_rank
+                logging.debug(f"Using dp_rank={effective_dp_rank} for decode")
+
+            decode = await self.engine.async_generate(**generate_kwargs)

             if self.skip_tokenizer_init:
                 async for out in self._process_token_stream(decode, context):
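As the thread above notes, the locals() guard was dropped after review. A sketch of the revised selection logic, reconstructed from the author's reply (illustrative, not the final committed code):

    def pick_dp_rank(prefill_dp_rank=None, data_parallel_rank=None):
        # prefill_dp_rank is always bound (initialized to None), so no locals() check
        if prefill_dp_rank is not None:
            return prefill_dp_rank
        if data_parallel_rank is not None:
            return data_parallel_rank
        return None

    print(pick_dp_rank(prefill_dp_rank=0))     # 0 (router-selected rank wins, even rank 0)
    print(pick_dp_rank(data_parallel_rank=2))  # 2 (request-level fallback)
    print(pick_dp_rank())                      # None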
@@ -170,11 +214,18 @@
             async for out in self._process_text_stream(decode, context):
                 yield out
         else:
-            agg = await self.engine.async_generate(
+            # Aggregated mode
+            generate_kwargs = {
                 **input_param,
-                sampling_params=sampling_params,
-                stream=True,
-            )
+                "sampling_params": sampling_params,
+                "stream": True,
+            }
+
+            if data_parallel_rank is not None:
+                generate_kwargs["data_parallel_rank"] = data_parallel_rank
+                logging.debug(f"Using dp_rank={data_parallel_rank} for aggregated mode")
+
+            agg = await self.engine.async_generate(**generate_kwargs)
             if self.skip_tokenizer_init:
                 async for out in self._process_token_stream(agg, context):
                     yield out
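The earlier review thread asked for high-concurrency benchmarks of this path. A generic concurrency-ramp harness in the spirit of that request (illustrative only; prefix_ratio_benchmark.py mentioned by the author is not shown in this diff):

    import asyncio
    import time

    async def one_request(send):
        t0 = time.perf_counter()
        await send()  # stand-in for a single inference request
        return time.perf_counter() - t0

    async def ramp(send, levels=(8, 32, 128, 512)):
        for concurrency in levels:
            t0 = time.perf_counter()
            latencies = await asyncio.gather(
                *(one_request(send) for _ in range(concurrency))
            )
            elapsed = time.perf_counter() - t0
            p50 = sorted(latencies)[len(latencies) // 2]
            print(f"conc={concurrency} p50={p50:.3f}s "
                  f"throughput={concurrency / elapsed:.1f} req/s")

    # usage: asyncio.run(ramp(lambda: asyncio.sleep(0.05)))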
@@ -74,16 +74,41 @@ async def generate(

         yield bootstrap_info

-        input_param = self._get_input_param(request["request"])
+        # Validate disaggregated request format
+        if "request" not in request or "sampling_params" not in request:
+            raise ValueError(
+                f"Expected disaggregated format with 'request' and 'sampling_params', "
+                f"got keys: {list(request.keys())}"
+            )
+
+        inner_request = request["request"]
+        sampling_params_dict = request["sampling_params"]
+
+        # Extract data_parallel_rank from inner request (explicit None check to preserve dp_rank=0)
+        data_parallel_rank = (
+            inner_request.get("data_parallel_rank")
+            if "data_parallel_rank" in inner_request
+            and inner_request["data_parallel_rank"] is not None
+            else None
+        )
+
+        input_param = self._get_input_param(inner_request)

-        results = await self.engine.async_generate(
+        # Build engine kwargs
+        generate_kwargs = {
             **input_param,
-            sampling_params=request["sampling_params"],
-            stream=True,
-            bootstrap_host=self.bootstrap_host,
-            bootstrap_port=self.bootstrap_port,
-            bootstrap_room=bootstrap_room,
-        )
+            "sampling_params": sampling_params_dict,
+            "stream": True,
+            "bootstrap_host": self.bootstrap_host,
+            "bootstrap_port": self.bootstrap_port,
+            "bootstrap_room": bootstrap_room,
+        }
+
+        if data_parallel_rank is not None:
+            generate_kwargs["data_parallel_rank"] = data_parallel_rank
+            logging.info(f"Prefill using dp_rank={data_parallel_rank}")
+
+        results = await self.engine.async_generate(**generate_kwargs)

         task = asyncio.create_task(self._consume_results(results, context))
         self._consume_tasks.add(task)
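A quick check of the new validation guard's behavior (illustrative only; bare dicts stand in for real disaggregated requests):

    def validate(request: dict) -> None:
        # mirrors the guard added in the prefill handler above
        if "request" not in request or "sampling_params" not in request:
            raise ValueError(
                f"Expected disaggregated format with 'request' and 'sampling_params', "
                f"got keys: {list(request.keys())}"
            )

    validate({"request": {}, "sampling_params": {}})  # ok
    try:
        validate({"token_ids": [1, 2, 3]})            # missing both keys
    except ValueError as e:
        print(e)  # Expected disaggregated format ... got keys: ['token_ids']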