From 2d4b3b539f0e986461e051d9ea138259c792b3a1 Mon Sep 17 00:00:00 2001 From: Yangmin Li Date: Mon, 10 Nov 2025 01:16:45 -0800 Subject: [PATCH 1/6] dp rank routing code --- components/src/dynamo/router/__main__.py | 18 +- components/src/dynamo/sglang/main.py | 14 +- components/src/dynamo/sglang/register.py | 6 + .../request_handlers/llm/decode_handler.py | 92 ++++++--- .../request_handlers/llm/prefill_handler.py | 45 ++++- tests/router/test_dp_rank_routing.py | 177 ++++++++++++++++++ 6 files changed, 311 insertions(+), 41 deletions(-) create mode 100644 tests/router/test_dp_rank_routing.py diff --git a/components/src/dynamo/router/__main__.py b/components/src/dynamo/router/__main__.py index a11b0cb492..e355791711 100644 --- a/components/src/dynamo/router/__main__.py +++ b/components/src/dynamo/router/__main__.py @@ -124,20 +124,20 @@ async def generate(self, request): yield llm_engine_output async def best_worker_id(self, token_ids, router_config_override=None): - """ - Get the best worker ID for a given set of tokens without actually routing. + """Get the best worker for given tokens without routing. - This method returns the worker ID that would be selected based on KV cache - overlap, but does NOT actually route the request or update router states. - It's useful for debugging, monitoring, or implementing custom routing logic. + Returns (worker_id, dp_rank, overlap_blocks) based on KV cache overlap. + Does NOT update router states - useful for preview/debugging. """ if self.kv_push_router is None: - logger.error("KvPushRouter not initialized - cannot get best worker") raise RuntimeError("Router not initialized") - result = await self.kv_push_router.best_worker_id( - token_ids, router_config_override - ) + # Use new API with dp_rank support, fallback to old API for compatibility + if hasattr(self.kv_push_router, "best_worker"): + result = await self.kv_push_router.best_worker(token_ids, router_config_override) + else: + wid, overlap = await self.kv_push_router.best_worker_id(token_ids, router_config_override) + result = (wid, None, overlap) yield result diff --git a/components/src/dynamo/sglang/main.py b/components/src/dynamo/sglang/main.py index a549ca7997..a94b564e2b 100644 --- a/components/src/dynamo/sglang/main.py +++ b/components/src/dynamo/sglang/main.py @@ -110,8 +110,7 @@ async def init(runtime: DistributedRuntime, config: Config): health_check_payload = SglangHealthCheckPayload(engine).to_dict() try: - # Start endpoint immediately and register model concurrently - # Requests queue until ready_event is set (TODO: Part of new PR) + # Start endpoint and register model concurrently await asyncio.gather( generate_endpoint.serve_endpoint( handler.generate, @@ -161,6 +160,17 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict() + # Register Prefill to expose dp_size to Router + await register_llm_with_readiness_gate( + engine, + generate_endpoint, + server_args, + dynamo_args, + input_type=ModelInput.Tokens, + output_type=ModelType.Chat | ModelType.Completions, + readiness_gate=None, + ) + tasks = [ generate_endpoint.serve_endpoint( handler.generate, diff --git a/components/src/dynamo/sglang/register.py b/components/src/dynamo/sglang/register.py index 9819a92733..a7c4a22c51 100644 --- a/components/src/dynamo/sglang/register.py +++ b/components/src/dynamo/sglang/register.py @@ -94,6 +94,12 @@ async def _get_runtime_config( if max_prefill_tokens: runtime_config.max_num_batched_tokens = max_prefill_tokens + # Expose data_parallel_size for DP-aware routing + dp_size = getattr(server_args, "dp_size", 1) + runtime_config.data_parallel_size = dp_size + if dp_size > 1: + logging.info(f"Registering with data_parallel_size={dp_size}") + try: # Try to check if the engine has a scheduler attribute with the computed values if hasattr(engine, "scheduler_info") and engine.scheduler_info is not None: diff --git a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py index 0ceea13ed3..cdaa6b1792 100644 --- a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py +++ b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py @@ -114,6 +114,13 @@ async def generate( logging.debug(f"New Request ID: {context.id()}") sampling_params = self._build_sampling_params(request) input_param = self._get_input_param(request) + + # Extract data_parallel_rank (explicit None check to preserve dp_rank=0) + data_parallel_rank = ( + request.get("data_parallel_rank") + if "data_parallel_rank" in request and request["data_parallel_rank"] is not None + else request.get("dp_rank") + ) if self.serving_mode == DisaggregationMode.DECODE: # request the bootstrap info from the target prefill worker @@ -124,17 +131,38 @@ async def generate( token_ids = request["token_ids"] stream = await self.prefill_router_client.generate(token_ids) result = await anext(stream) - ( - worker_id, - overlap, - ) = result.data() # Returns tuple (worker_id, overlap_amount) - logging.info(f"Best prefill worker ID: {worker_id}, overlap: {overlap}") - + # Unpack router response: (worker_id, dp_rank, overlap_blocks) + result_data = result.data() + if len(result_data) == 3: + worker_id, prefill_dp_rank, overlap = result_data + if not hasattr(self, '_dp_routing_active_logged'): + logging.info(f"DP-aware routing active: dp_rank={prefill_dp_rank}") + self._dp_routing_active_logged = True + logging.debug(f"Router selected: worker={worker_id}, dp_rank={prefill_dp_rank}, overlap={overlap}") + else: + # Backward compatibility: (worker_id, overlap) + worker_id, overlap = result_data + prefill_dp_rank = None + logging.debug(f"Router selected: worker={worker_id}, overlap={overlap}") + if not hasattr(self, '_dp_routing_unavailable_warned'): + logging.warning("Router not returning dp_rank - DP-aware routing unavailable") + self._dp_routing_unavailable_warned = True + + # Build prefill request + prefill_request_dict = DisaggPreprocessedRequest( + request=request, + sampling_params=sampling_params, + ).model_dump() + + # Inject dp_rank after serialization (Pydantic drops unknown fields) + if prefill_dp_rank is not None: + prefill_request_dict["dp_rank"] = prefill_dp_rank # For Dynamo routing + if isinstance(prefill_request_dict.get("request"), dict): + prefill_request_dict["request"]["dp_rank"] = prefill_dp_rank + prefill_request_dict["request"]["data_parallel_rank"] = prefill_dp_rank + prefill_stream = await self.prefill_client.direct( - DisaggPreprocessedRequest( - request=request, - sampling_params=sampling_params, - ).model_dump(), + prefill_request_dict, worker_id, ) else: @@ -154,14 +182,29 @@ async def generate( if not bootstrap_info: raise RuntimeError("No bootstrap info received from prefill worker") - decode = await self.engine.async_generate( + # Prefill and Decode must use same dp_rank for bootstrap connection + generate_kwargs = { **input_param, - sampling_params=sampling_params, - stream=True, - bootstrap_host=bootstrap_info["bootstrap_host"], - bootstrap_port=bootstrap_info["bootstrap_port"], - bootstrap_room=bootstrap_info["bootstrap_room"], - ) + "sampling_params": sampling_params, + "stream": True, + "bootstrap_host": bootstrap_info["bootstrap_host"], + "bootstrap_port": bootstrap_info["bootstrap_port"], + "bootstrap_room": bootstrap_info["bootstrap_room"], + } + + # Use router-selected dp_rank (fallback to request-level if not provided) + if 'prefill_dp_rank' in locals() and prefill_dp_rank is not None: + effective_dp_rank = prefill_dp_rank + elif data_parallel_rank is not None: + effective_dp_rank = data_parallel_rank + else: + effective_dp_rank = None + + if effective_dp_rank is not None: + generate_kwargs["data_parallel_rank"] = effective_dp_rank + logging.debug(f"Using dp_rank={effective_dp_rank} for decode") + + decode = await self.engine.async_generate(**generate_kwargs) if self.skip_tokenizer_init: async for out in self._process_token_stream(decode, context): @@ -170,11 +213,18 @@ async def generate( async for out in self._process_text_stream(decode, context): yield out else: - agg = await self.engine.async_generate( + # Aggregated mode + generate_kwargs = { **input_param, - sampling_params=sampling_params, - stream=True, - ) + "sampling_params": sampling_params, + "stream": True, + } + + if data_parallel_rank is not None: + generate_kwargs["data_parallel_rank"] = data_parallel_rank + logging.debug(f"Using dp_rank={data_parallel_rank} for aggregated mode") + + agg = await self.engine.async_generate(**generate_kwargs) if self.skip_tokenizer_init: async for out in self._process_token_stream(agg, context): yield out diff --git a/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py b/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py index dc55ab9762..408c32ccce 100644 --- a/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py +++ b/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py @@ -74,16 +74,43 @@ async def generate( yield bootstrap_info - input_param = self._get_input_param(request["request"]) - - results = await self.engine.async_generate( + # Validate disaggregated request format + if "request" not in request or "sampling_params" not in request: + raise ValueError( + f"Expected disaggregated format with 'request' and 'sampling_params', " + f"got keys: {list(request.keys())}" + ) + + inner_request = request["request"] + sampling_params_dict = request["sampling_params"] + + # Extract data_parallel_rank (explicit None check to preserve dp_rank=0) + if "data_parallel_rank" in request and request["data_parallel_rank"] is not None: + data_parallel_rank = request["data_parallel_rank"] + elif "data_parallel_rank" in inner_request and inner_request["data_parallel_rank"] is not None: + data_parallel_rank = inner_request["data_parallel_rank"] + elif "dp_rank" in inner_request and inner_request["dp_rank"] is not None: + data_parallel_rank = inner_request["dp_rank"] + else: + data_parallel_rank = None + + input_param = self._get_input_param(inner_request) + + # Build engine kwargs + generate_kwargs = { **input_param, - sampling_params=request["sampling_params"], - stream=True, - bootstrap_host=self.bootstrap_host, - bootstrap_port=self.bootstrap_port, - bootstrap_room=bootstrap_room, - ) + "sampling_params": sampling_params_dict, + "stream": True, + "bootstrap_host": self.bootstrap_host, + "bootstrap_port": self.bootstrap_port, + "bootstrap_room": bootstrap_room, + } + + if data_parallel_rank is not None: + generate_kwargs["data_parallel_rank"] = data_parallel_rank + logging.debug(f"Using dp_rank={data_parallel_rank} for prefill") + + results = await self.engine.async_generate(**generate_kwargs) task = asyncio.create_task(self._consume_results(results, context)) self._consume_tasks.add(task) diff --git a/tests/router/test_dp_rank_routing.py b/tests/router/test_dp_rank_routing.py new file mode 100644 index 0000000000..b2ae6ddb49 --- /dev/null +++ b/tests/router/test_dp_rank_routing.py @@ -0,0 +1,177 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for data parallel rank-aware routing.""" + +import logging +import random +import string +from typing import List + +import pytest + +from dynamo._core import DistributedRuntime, KvPushRouter, KvRouterConfig +from tests.utils.constants import ROUTER_MODEL_NAME +from tests.utils.managed_process import ManagedProcess + +pytestmark = pytest.mark.pre_merge + +logger = logging.getLogger(__name__) + +MODEL_NAME = ROUTER_MODEL_NAME +DP_SIZE = 4 # Test with 4 data parallel ranks +BLOCK_SIZE = 16 + + +def generate_random_suffix() -> str: + """Generate random suffix for namespace isolation.""" + return "".join(random.choices(string.ascii_lowercase, k=10)) + + +def get_runtime(): + """Get or create a DistributedRuntime instance.""" + try: + return DistributedRuntime.detached() + except Exception: + import asyncio + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return DistributedRuntime(loop, False) + + +class MockerProcess: + """Manages mocker engine instance with DP support""" + + def __init__( + self, + request, + namespace: str, + dp_size: int = 1, + ): + self.namespace = namespace + self.endpoint = f"dyn://{namespace}.prefill.generate" + self.dp_size = dp_size + self.mocker_processes: List[ManagedProcess] = [] + self.request = request + + # Create mocker process + command = [ + "python", "-m", "dynamo.mocker", + "--model-path", MODEL_NAME, + "--endpoint", self.endpoint, + "--speedup-ratio", "10.0", + "--block-size", str(BLOCK_SIZE), + "--num-gpu-blocks-override", "1000", + "--data-parallel-size", str(dp_size), + ] + + process = ManagedProcess( + command=command, + timeout=60, + display_output=True, + log_dir=request.node.name, + terminate_existing=False, + ) + self.mocker_processes.append(process) + + def __enter__(self): + """Start mocker process""" + for process in self.mocker_processes: + process.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Stop mocker process""" + for process in self.mocker_processes: + process.__exit__(exc_type, exc_val, exc_tb) + + +@pytest.mark.asyncio +async def test_router_returns_valid_dp_rank(request): + """Verify router returns valid dp_rank in routing decision.""" + runtime = get_runtime() + namespace = f"dp-routing-{generate_random_suffix()}" + + with MockerProcess(request, namespace, dp_size=DP_SIZE): + try: + router_comp = runtime.namespace(namespace).component("router") + await router_comp.create_service() + + endpoint = runtime.namespace(namespace).component("prefill").endpoint("generate") + + import asyncio + await asyncio.sleep(2) + + kv_router_config = KvRouterConfig( + overlap_score_weight=2.0, + router_temperature=0.0, + ) + kv_push_router = KvPushRouter(endpoint, BLOCK_SIZE, kv_router_config) + + if hasattr(kv_push_router, "best_worker"): + worker_id, dp_rank, overlap = await kv_push_router.best_worker([1, 2, 3, 4, 5]) + + assert dp_rank is not None, "Router should return dp_rank" + assert isinstance(dp_rank, int), "dp_rank should be integer" + assert 0 <= dp_rank < DP_SIZE, f"dp_rank {dp_rank} out of range [0, {DP_SIZE})" + + logger.info(f"✅ Router returned valid dp_rank={dp_rank}") + else: + logger.warning("Router API doesn't support best_worker - skipping test") + + except Exception as e: + logger.error(f"Test failed: {e}") + raise + + +@pytest.mark.asyncio +async def test_dp_rank_coverage(request): + """Verify router selects from full range of DP ranks.""" + runtime = get_runtime() + namespace = f"dp-routing-{generate_random_suffix()}" + + with MockerProcess(request, namespace, dp_size=DP_SIZE): + try: + router_comp = runtime.namespace(namespace).component("router") + await router_comp.create_service() + + endpoint = runtime.namespace(namespace).component("prefill").endpoint("generate") + + import asyncio + await asyncio.sleep(2) + + kv_router_config = KvRouterConfig( + overlap_score_weight=2.0, + router_temperature=0.0, + ) + kv_push_router = KvPushRouter(endpoint, BLOCK_SIZE, kv_router_config) + + if hasattr(kv_push_router, "best_worker"): + dp_ranks_used = set() + + # Query with varied sequences to cover all ranks + for i in range(50): + test_tokens = list(range(i * 7, i * 7 + 10)) + worker_id, dp_rank, overlap = await kv_push_router.best_worker(test_tokens) + + assert dp_rank is not None + assert isinstance(dp_rank, int) + assert 0 <= dp_rank < DP_SIZE + dp_ranks_used.add(dp_rank) + + # Expect reasonable coverage across DP ranks + num_ranks = len(dp_ranks_used) + assert num_ranks >= 2, f"Poor coverage: only {num_ranks} ranks used" + logger.info(f"✅ Router coverage: {num_ranks}/{DP_SIZE} ranks used - {sorted(dp_ranks_used)}") + else: + logger.warning("Router API doesn't support best_worker - skipping test") + + except Exception as e: + logger.error(f"Test failed: {e}") + raise + + + From 51bd5d75133c1b8deb7f9b5f78ecd1d10b39bef6 Mon Sep 17 00:00:00 2001 From: Yangmin Li Date: Mon, 10 Nov 2025 09:45:29 -0800 Subject: [PATCH 2/6] format fix --- components/src/dynamo/router/__main__.py | 8 +- .../request_handlers/llm/decode_handler.py | 49 +++++++---- .../request_handlers/llm/prefill_handler.py | 20 +++-- tests/router/test_dp_rank_routing.py | 86 ++++++++++++------- 4 files changed, 103 insertions(+), 60 deletions(-) diff --git a/components/src/dynamo/router/__main__.py b/components/src/dynamo/router/__main__.py index e355791711..c2b102c987 100644 --- a/components/src/dynamo/router/__main__.py +++ b/components/src/dynamo/router/__main__.py @@ -134,9 +134,13 @@ async def best_worker_id(self, token_ids, router_config_override=None): # Use new API with dp_rank support, fallback to old API for compatibility if hasattr(self.kv_push_router, "best_worker"): - result = await self.kv_push_router.best_worker(token_ids, router_config_override) + result = await self.kv_push_router.best_worker( + token_ids, router_config_override + ) else: - wid, overlap = await self.kv_push_router.best_worker_id(token_ids, router_config_override) + wid, overlap = await self.kv_push_router.best_worker_id( + token_ids, router_config_override + ) result = (wid, None, overlap) yield result diff --git a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py index cdaa6b1792..bdf4e23062 100644 --- a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py +++ b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py @@ -114,11 +114,12 @@ async def generate( logging.debug(f"New Request ID: {context.id()}") sampling_params = self._build_sampling_params(request) input_param = self._get_input_param(request) - + # Extract data_parallel_rank (explicit None check to preserve dp_rank=0) data_parallel_rank = ( request.get("data_parallel_rank") - if "data_parallel_rank" in request and request["data_parallel_rank"] is not None + if "data_parallel_rank" in request + and request["data_parallel_rank"] is not None else request.get("dp_rank") ) @@ -135,17 +136,25 @@ async def generate( result_data = result.data() if len(result_data) == 3: worker_id, prefill_dp_rank, overlap = result_data - if not hasattr(self, '_dp_routing_active_logged'): - logging.info(f"DP-aware routing active: dp_rank={prefill_dp_rank}") + if not hasattr(self, "_dp_routing_active_logged"): + logging.info( + f"DP-aware routing active: dp_rank={prefill_dp_rank}" + ) self._dp_routing_active_logged = True - logging.debug(f"Router selected: worker={worker_id}, dp_rank={prefill_dp_rank}, overlap={overlap}") + logging.debug( + f"Router selected: worker={worker_id}, dp_rank={prefill_dp_rank}, overlap={overlap}" + ) else: # Backward compatibility: (worker_id, overlap) worker_id, overlap = result_data prefill_dp_rank = None - logging.debug(f"Router selected: worker={worker_id}, overlap={overlap}") - if not hasattr(self, '_dp_routing_unavailable_warned'): - logging.warning("Router not returning dp_rank - DP-aware routing unavailable") + logging.debug( + f"Router selected: worker={worker_id}, overlap={overlap}" + ) + if not hasattr(self, "_dp_routing_unavailable_warned"): + logging.warning( + "Router not returning dp_rank - DP-aware routing unavailable" + ) self._dp_routing_unavailable_warned = True # Build prefill request @@ -153,14 +162,18 @@ async def generate( request=request, sampling_params=sampling_params, ).model_dump() - + # Inject dp_rank after serialization (Pydantic drops unknown fields) if prefill_dp_rank is not None: - prefill_request_dict["dp_rank"] = prefill_dp_rank # For Dynamo routing + prefill_request_dict[ + "dp_rank" + ] = prefill_dp_rank # For Dynamo routing if isinstance(prefill_request_dict.get("request"), dict): prefill_request_dict["request"]["dp_rank"] = prefill_dp_rank - prefill_request_dict["request"]["data_parallel_rank"] = prefill_dp_rank - + prefill_request_dict["request"][ + "data_parallel_rank" + ] = prefill_dp_rank + prefill_stream = await self.prefill_client.direct( prefill_request_dict, worker_id, @@ -191,19 +204,19 @@ async def generate( "bootstrap_port": bootstrap_info["bootstrap_port"], "bootstrap_room": bootstrap_info["bootstrap_room"], } - + # Use router-selected dp_rank (fallback to request-level if not provided) - if 'prefill_dp_rank' in locals() and prefill_dp_rank is not None: + if "prefill_dp_rank" in locals() and prefill_dp_rank is not None: effective_dp_rank = prefill_dp_rank elif data_parallel_rank is not None: effective_dp_rank = data_parallel_rank else: effective_dp_rank = None - + if effective_dp_rank is not None: generate_kwargs["data_parallel_rank"] = effective_dp_rank logging.debug(f"Using dp_rank={effective_dp_rank} for decode") - + decode = await self.engine.async_generate(**generate_kwargs) if self.skip_tokenizer_init: @@ -219,11 +232,11 @@ async def generate( "sampling_params": sampling_params, "stream": True, } - + if data_parallel_rank is not None: generate_kwargs["data_parallel_rank"] = data_parallel_rank logging.debug(f"Using dp_rank={data_parallel_rank} for aggregated mode") - + agg = await self.engine.async_generate(**generate_kwargs) if self.skip_tokenizer_init: async for out in self._process_token_stream(agg, context): diff --git a/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py b/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py index 408c32ccce..8224ffcda6 100644 --- a/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py +++ b/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py @@ -80,20 +80,26 @@ async def generate( f"Expected disaggregated format with 'request' and 'sampling_params', " f"got keys: {list(request.keys())}" ) - + inner_request = request["request"] sampling_params_dict = request["sampling_params"] - + # Extract data_parallel_rank (explicit None check to preserve dp_rank=0) - if "data_parallel_rank" in request and request["data_parallel_rank"] is not None: + if ( + "data_parallel_rank" in request + and request["data_parallel_rank"] is not None + ): data_parallel_rank = request["data_parallel_rank"] - elif "data_parallel_rank" in inner_request and inner_request["data_parallel_rank"] is not None: + elif ( + "data_parallel_rank" in inner_request + and inner_request["data_parallel_rank"] is not None + ): data_parallel_rank = inner_request["data_parallel_rank"] elif "dp_rank" in inner_request and inner_request["dp_rank"] is not None: data_parallel_rank = inner_request["dp_rank"] else: data_parallel_rank = None - + input_param = self._get_input_param(inner_request) # Build engine kwargs @@ -105,11 +111,11 @@ async def generate( "bootstrap_port": self.bootstrap_port, "bootstrap_room": bootstrap_room, } - + if data_parallel_rank is not None: generate_kwargs["data_parallel_rank"] = data_parallel_rank logging.debug(f"Using dp_rank={data_parallel_rank} for prefill") - + results = await self.engine.async_generate(**generate_kwargs) task = asyncio.create_task(self._consume_results(results, context)) diff --git a/tests/router/test_dp_rank_routing.py b/tests/router/test_dp_rank_routing.py index b2ae6ddb49..e47ed63ee9 100644 --- a/tests/router/test_dp_rank_routing.py +++ b/tests/router/test_dp_rank_routing.py @@ -34,6 +34,7 @@ def get_runtime(): return DistributedRuntime.detached() except Exception: import asyncio + try: loop = asyncio.get_running_loop() except RuntimeError: @@ -59,13 +60,21 @@ def __init__( # Create mocker process command = [ - "python", "-m", "dynamo.mocker", - "--model-path", MODEL_NAME, - "--endpoint", self.endpoint, - "--speedup-ratio", "10.0", - "--block-size", str(BLOCK_SIZE), - "--num-gpu-blocks-override", "1000", - "--data-parallel-size", str(dp_size), + "python", + "-m", + "dynamo.mocker", + "--model-path", + MODEL_NAME, + "--endpoint", + self.endpoint, + "--speedup-ratio", + "10.0", + "--block-size", + str(BLOCK_SIZE), + "--num-gpu-blocks-override", + "1000", + "--data-parallel-size", + str(dp_size), ] process = ManagedProcess( @@ -94,34 +103,41 @@ async def test_router_returns_valid_dp_rank(request): """Verify router returns valid dp_rank in routing decision.""" runtime = get_runtime() namespace = f"dp-routing-{generate_random_suffix()}" - + with MockerProcess(request, namespace, dp_size=DP_SIZE): try: router_comp = runtime.namespace(namespace).component("router") await router_comp.create_service() - - endpoint = runtime.namespace(namespace).component("prefill").endpoint("generate") - + + endpoint = ( + runtime.namespace(namespace).component("prefill").endpoint("generate") + ) + import asyncio + await asyncio.sleep(2) - + kv_router_config = KvRouterConfig( overlap_score_weight=2.0, router_temperature=0.0, ) kv_push_router = KvPushRouter(endpoint, BLOCK_SIZE, kv_router_config) - + if hasattr(kv_push_router, "best_worker"): - worker_id, dp_rank, overlap = await kv_push_router.best_worker([1, 2, 3, 4, 5]) - + worker_id, dp_rank, overlap = await kv_push_router.best_worker( + [1, 2, 3, 4, 5] + ) + assert dp_rank is not None, "Router should return dp_rank" assert isinstance(dp_rank, int), "dp_rank should be integer" - assert 0 <= dp_rank < DP_SIZE, f"dp_rank {dp_rank} out of range [0, {DP_SIZE})" - + assert ( + 0 <= dp_rank < DP_SIZE + ), f"dp_rank {dp_rank} out of range [0, {DP_SIZE})" + logger.info(f"✅ Router returned valid dp_rank={dp_rank}") else: logger.warning("Router API doesn't support best_worker - skipping test") - + except Exception as e: logger.error(f"Test failed: {e}") raise @@ -132,46 +148,50 @@ async def test_dp_rank_coverage(request): """Verify router selects from full range of DP ranks.""" runtime = get_runtime() namespace = f"dp-routing-{generate_random_suffix()}" - + with MockerProcess(request, namespace, dp_size=DP_SIZE): try: router_comp = runtime.namespace(namespace).component("router") await router_comp.create_service() - - endpoint = runtime.namespace(namespace).component("prefill").endpoint("generate") - + + endpoint = ( + runtime.namespace(namespace).component("prefill").endpoint("generate") + ) + import asyncio + await asyncio.sleep(2) - + kv_router_config = KvRouterConfig( overlap_score_weight=2.0, router_temperature=0.0, ) kv_push_router = KvPushRouter(endpoint, BLOCK_SIZE, kv_router_config) - + if hasattr(kv_push_router, "best_worker"): dp_ranks_used = set() - + # Query with varied sequences to cover all ranks for i in range(50): test_tokens = list(range(i * 7, i * 7 + 10)) - worker_id, dp_rank, overlap = await kv_push_router.best_worker(test_tokens) - + worker_id, dp_rank, overlap = await kv_push_router.best_worker( + test_tokens + ) + assert dp_rank is not None assert isinstance(dp_rank, int) assert 0 <= dp_rank < DP_SIZE dp_ranks_used.add(dp_rank) - + # Expect reasonable coverage across DP ranks num_ranks = len(dp_ranks_used) assert num_ranks >= 2, f"Poor coverage: only {num_ranks} ranks used" - logger.info(f"✅ Router coverage: {num_ranks}/{DP_SIZE} ranks used - {sorted(dp_ranks_used)}") + logger.info( + f"✅ Router coverage: {num_ranks}/{DP_SIZE} ranks used - {sorted(dp_ranks_used)}" + ) else: logger.warning("Router API doesn't support best_worker - skipping test") - + except Exception as e: logger.error(f"Test failed: {e}") raise - - - From 881be2f1964c27031158f72ab7096daf2b3a9717 Mon Sep 17 00:00:00 2001 From: Yangmin Li Date: Mon, 10 Nov 2025 11:08:34 -0800 Subject: [PATCH 3/6] optimize use of dp_rank/data_parallel_rank --- components/src/dynamo/router/__main__.py | 14 +++------- .../request_handlers/llm/decode_handler.py | 26 ++++++++----------- .../request_handlers/llm/prefill_handler.py | 22 +++++----------- 3 files changed, 21 insertions(+), 41 deletions(-) diff --git a/components/src/dynamo/router/__main__.py b/components/src/dynamo/router/__main__.py index c2b102c987..d33932194a 100644 --- a/components/src/dynamo/router/__main__.py +++ b/components/src/dynamo/router/__main__.py @@ -132,17 +132,9 @@ async def best_worker_id(self, token_ids, router_config_override=None): if self.kv_push_router is None: raise RuntimeError("Router not initialized") - # Use new API with dp_rank support, fallback to old API for compatibility - if hasattr(self.kv_push_router, "best_worker"): - result = await self.kv_push_router.best_worker( - token_ids, router_config_override - ) - else: - wid, overlap = await self.kv_push_router.best_worker_id( - token_ids, router_config_override - ) - result = (wid, None, overlap) - + result = await self.kv_push_router.best_worker( + token_ids, router_config_override + ) yield result diff --git a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py index bdf4e23062..95347b32a4 100644 --- a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py +++ b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py @@ -145,15 +145,12 @@ async def generate( f"Router selected: worker={worker_id}, dp_rank={prefill_dp_rank}, overlap={overlap}" ) else: - # Backward compatibility: (worker_id, overlap) + # Fallback for older router versions (2-tuple response) worker_id, overlap = result_data prefill_dp_rank = None - logging.debug( - f"Router selected: worker={worker_id}, overlap={overlap}" - ) if not hasattr(self, "_dp_routing_unavailable_warned"): logging.warning( - "Router not returning dp_rank - DP-aware routing unavailable" + "Router returned 2-tuple, DP routing unavailable (update router)" ) self._dp_routing_unavailable_warned = True @@ -163,16 +160,15 @@ async def generate( sampling_params=sampling_params, ).model_dump() - # Inject dp_rank after serialization (Pydantic drops unknown fields) - if prefill_dp_rank is not None: - prefill_request_dict[ - "dp_rank" - ] = prefill_dp_rank # For Dynamo routing - if isinstance(prefill_request_dict.get("request"), dict): - prefill_request_dict["request"]["dp_rank"] = prefill_dp_rank - prefill_request_dict["request"][ - "data_parallel_rank" - ] = prefill_dp_rank + # Inject dp_rank into inner request after serialization (Pydantic drops unknown fields) + if prefill_dp_rank is not None and isinstance( + prefill_request_dict.get("request"), dict + ): + # SGLang engine reads data_parallel_rank from inner request + prefill_request_dict["request"][ + "data_parallel_rank" + ] = prefill_dp_rank + logging.info(f"Routing to prefill dp_rank={prefill_dp_rank}") prefill_stream = await self.prefill_client.direct( prefill_request_dict, diff --git a/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py b/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py index 8224ffcda6..243a8aecdb 100644 --- a/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py +++ b/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py @@ -84,21 +84,13 @@ async def generate( inner_request = request["request"] sampling_params_dict = request["sampling_params"] - # Extract data_parallel_rank (explicit None check to preserve dp_rank=0) - if ( - "data_parallel_rank" in request - and request["data_parallel_rank"] is not None - ): - data_parallel_rank = request["data_parallel_rank"] - elif ( - "data_parallel_rank" in inner_request + # Extract data_parallel_rank from inner request (explicit None check to preserve dp_rank=0) + data_parallel_rank = ( + inner_request.get("data_parallel_rank") + if "data_parallel_rank" in inner_request and inner_request["data_parallel_rank"] is not None - ): - data_parallel_rank = inner_request["data_parallel_rank"] - elif "dp_rank" in inner_request and inner_request["dp_rank"] is not None: - data_parallel_rank = inner_request["dp_rank"] - else: - data_parallel_rank = None + else None + ) input_param = self._get_input_param(inner_request) @@ -114,7 +106,7 @@ async def generate( if data_parallel_rank is not None: generate_kwargs["data_parallel_rank"] = data_parallel_rank - logging.debug(f"Using dp_rank={data_parallel_rank} for prefill") + logging.info(f"Prefill using dp_rank={data_parallel_rank}") results = await self.engine.async_generate(**generate_kwargs) From 424836af0e43451be7fb5fa99d4afef4b2ea59f2 Mon Sep 17 00:00:00 2001 From: Yangmin Li Date: Mon, 10 Nov 2025 21:30:03 -0800 Subject: [PATCH 4/6] code optimize and add e2e test --- components/src/dynamo/sglang/main.py | 44 +++-- tests/router/test_dp_rank_routing.py | 197 ------------------- tests/router/test_router_e2e_with_mockers.py | 170 ++++++++++++++++ 3 files changed, 193 insertions(+), 218 deletions(-) delete mode 100644 tests/router/test_dp_rank_routing.py diff --git a/components/src/dynamo/sglang/main.py b/components/src/dynamo/sglang/main.py index a94b564e2b..84e5ad8ef7 100644 --- a/components/src/dynamo/sglang/main.py +++ b/components/src/dynamo/sglang/main.py @@ -156,32 +156,34 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): engine, config, component, generate_endpoint ) + # Readiness gate: requests wait until model is registered + ready_event = asyncio.Event() + handler = PrefillWorkerHandler(component, engine, config, publisher) health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict() - # Register Prefill to expose dp_size to Router - await register_llm_with_readiness_gate( - engine, - generate_endpoint, - server_args, - dynamo_args, - input_type=ModelInput.Tokens, - output_type=ModelType.Chat | ModelType.Completions, - readiness_gate=None, - ) - - tasks = [ - generate_endpoint.serve_endpoint( - handler.generate, - graceful_shutdown=True, - metrics_labels=metrics_labels, - health_check_payload=health_check_payload, - ) - ] - try: - await asyncio.gather(*tasks) + # Start endpoint and register model concurrently + # Requests queue until ready_event is set + # Register Prefill to expose dp_size to Router + await asyncio.gather( + generate_endpoint.serve_endpoint( + handler.generate, + graceful_shutdown=True, + metrics_labels=metrics_labels, + health_check_payload=health_check_payload, + ), + register_llm_with_readiness_gate( + engine, + generate_endpoint, + server_args, + dynamo_args, + input_type=ModelInput.Tokens, + output_type=ModelType.Chat | ModelType.Completions, + readiness_gate=ready_event, + ), + ) except Exception as e: logging.error(f"Failed to serve endpoints: {e}") raise diff --git a/tests/router/test_dp_rank_routing.py b/tests/router/test_dp_rank_routing.py deleted file mode 100644 index e47ed63ee9..0000000000 --- a/tests/router/test_dp_rank_routing.py +++ /dev/null @@ -1,197 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -"""Tests for data parallel rank-aware routing.""" - -import logging -import random -import string -from typing import List - -import pytest - -from dynamo._core import DistributedRuntime, KvPushRouter, KvRouterConfig -from tests.utils.constants import ROUTER_MODEL_NAME -from tests.utils.managed_process import ManagedProcess - -pytestmark = pytest.mark.pre_merge - -logger = logging.getLogger(__name__) - -MODEL_NAME = ROUTER_MODEL_NAME -DP_SIZE = 4 # Test with 4 data parallel ranks -BLOCK_SIZE = 16 - - -def generate_random_suffix() -> str: - """Generate random suffix for namespace isolation.""" - return "".join(random.choices(string.ascii_lowercase, k=10)) - - -def get_runtime(): - """Get or create a DistributedRuntime instance.""" - try: - return DistributedRuntime.detached() - except Exception: - import asyncio - - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return DistributedRuntime(loop, False) - - -class MockerProcess: - """Manages mocker engine instance with DP support""" - - def __init__( - self, - request, - namespace: str, - dp_size: int = 1, - ): - self.namespace = namespace - self.endpoint = f"dyn://{namespace}.prefill.generate" - self.dp_size = dp_size - self.mocker_processes: List[ManagedProcess] = [] - self.request = request - - # Create mocker process - command = [ - "python", - "-m", - "dynamo.mocker", - "--model-path", - MODEL_NAME, - "--endpoint", - self.endpoint, - "--speedup-ratio", - "10.0", - "--block-size", - str(BLOCK_SIZE), - "--num-gpu-blocks-override", - "1000", - "--data-parallel-size", - str(dp_size), - ] - - process = ManagedProcess( - command=command, - timeout=60, - display_output=True, - log_dir=request.node.name, - terminate_existing=False, - ) - self.mocker_processes.append(process) - - def __enter__(self): - """Start mocker process""" - for process in self.mocker_processes: - process.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """Stop mocker process""" - for process in self.mocker_processes: - process.__exit__(exc_type, exc_val, exc_tb) - - -@pytest.mark.asyncio -async def test_router_returns_valid_dp_rank(request): - """Verify router returns valid dp_rank in routing decision.""" - runtime = get_runtime() - namespace = f"dp-routing-{generate_random_suffix()}" - - with MockerProcess(request, namespace, dp_size=DP_SIZE): - try: - router_comp = runtime.namespace(namespace).component("router") - await router_comp.create_service() - - endpoint = ( - runtime.namespace(namespace).component("prefill").endpoint("generate") - ) - - import asyncio - - await asyncio.sleep(2) - - kv_router_config = KvRouterConfig( - overlap_score_weight=2.0, - router_temperature=0.0, - ) - kv_push_router = KvPushRouter(endpoint, BLOCK_SIZE, kv_router_config) - - if hasattr(kv_push_router, "best_worker"): - worker_id, dp_rank, overlap = await kv_push_router.best_worker( - [1, 2, 3, 4, 5] - ) - - assert dp_rank is not None, "Router should return dp_rank" - assert isinstance(dp_rank, int), "dp_rank should be integer" - assert ( - 0 <= dp_rank < DP_SIZE - ), f"dp_rank {dp_rank} out of range [0, {DP_SIZE})" - - logger.info(f"✅ Router returned valid dp_rank={dp_rank}") - else: - logger.warning("Router API doesn't support best_worker - skipping test") - - except Exception as e: - logger.error(f"Test failed: {e}") - raise - - -@pytest.mark.asyncio -async def test_dp_rank_coverage(request): - """Verify router selects from full range of DP ranks.""" - runtime = get_runtime() - namespace = f"dp-routing-{generate_random_suffix()}" - - with MockerProcess(request, namespace, dp_size=DP_SIZE): - try: - router_comp = runtime.namespace(namespace).component("router") - await router_comp.create_service() - - endpoint = ( - runtime.namespace(namespace).component("prefill").endpoint("generate") - ) - - import asyncio - - await asyncio.sleep(2) - - kv_router_config = KvRouterConfig( - overlap_score_weight=2.0, - router_temperature=0.0, - ) - kv_push_router = KvPushRouter(endpoint, BLOCK_SIZE, kv_router_config) - - if hasattr(kv_push_router, "best_worker"): - dp_ranks_used = set() - - # Query with varied sequences to cover all ranks - for i in range(50): - test_tokens = list(range(i * 7, i * 7 + 10)) - worker_id, dp_rank, overlap = await kv_push_router.best_worker( - test_tokens - ) - - assert dp_rank is not None - assert isinstance(dp_rank, int) - assert 0 <= dp_rank < DP_SIZE - dp_ranks_used.add(dp_rank) - - # Expect reasonable coverage across DP ranks - num_ranks = len(dp_ranks_used) - assert num_ranks >= 2, f"Poor coverage: only {num_ranks} ranks used" - logger.info( - f"✅ Router coverage: {num_ranks}/{DP_SIZE} ranks used - {sorted(dp_ranks_used)}" - ) - else: - logger.warning("Router API doesn't support best_worker - skipping test") - - except Exception as e: - logger.error(f"Test failed: {e}") - raise diff --git a/tests/router/test_router_e2e_with_mockers.py b/tests/router/test_router_e2e_with_mockers.py index 63506c6707..7aca8f620c 100644 --- a/tests/router/test_router_e2e_with_mockers.py +++ b/tests/router/test_router_e2e_with_mockers.py @@ -1545,3 +1545,173 @@ async def test_sync(): # Clean up mockers if "mockers" in locals(): mockers.__exit__(None, None, None) + + +@pytest.mark.pre_merge +@pytest.mark.model(MODEL_NAME) +def test_disagg_dp_routing_e2e(request, runtime_services, predownload_tokenizers): + """ + E2E test for DP-aware routing in disaggregated prefill-decode mode. + + This test validates the complete DP routing flow: + 1. Router correctly selects worker_id and dp_rank for prefill requests + 2. Decode worker receives and uses the router-selected dp_rank + 3. Bootstrap connection works correctly with matching dp_ranks between prefill and decode + 4. End-to-end requests complete successfully through the full pipeline + + Flow: + - Start multiple prefill workers with dp_size=4 (each worker has 4 DP ranks) + - Create KV router and send requests with varying token sequences + - For each request: + * Query router for best (worker_id, dp_rank, overlap) + * Verify dp_rank is valid (0 <= dp_rank < DP_SIZE) + * Send request through router using selected dp_rank + * Verify request completes successfully (proves bootstrap connection works) + - Verify all DP ranks get utilized across multiple requests + """ + + logger.info("Starting disaggregated DP routing E2E test") + + DP_SIZE = 4 # Each worker has 4 DP ranks + NUM_PREFILL_WORKERS = 2 # 2 workers × 4 DP ranks = 8 total DP ranks + + # Create mocker args with DP support + mocker_args = { + "speedup_ratio": SPEEDUP_RATIO, + "block_size": BLOCK_SIZE, + "dp_size": DP_SIZE, + "num_gpu_blocks": 1000, + } + + try: + # Start prefill workers with DP support + logger.info( + f"Starting {NUM_PREFILL_WORKERS} prefill workers with dp_size={DP_SIZE} each " + f"({NUM_PREFILL_WORKERS * DP_SIZE} total DP ranks)" + ) + prefill_mockers = MockerProcess( + request, mocker_args=mocker_args, num_mockers=NUM_PREFILL_WORKERS + ) + logger.info(f"Prefill workers using endpoint: {prefill_mockers.endpoint}") + prefill_mockers.__enter__() + + async def test_dp_routing(): + # Get runtime and create components + runtime = get_runtime() + namespace = runtime.namespace(prefill_mockers.namespace) + component = namespace.component("mocker") + endpoint = component.endpoint("generate") + + # Create router with configuration + kv_router_config = KvRouterConfig( + overlap_score_weight=2.0, + router_temperature=0.0, # Deterministic routing for testing + ) + kv_push_router = KvPushRouter( + endpoint=endpoint, + block_size=BLOCK_SIZE, + kv_router_config=kv_router_config, + ) + + logger.info("Created KvPushRouter for DP routing test") + + # Wait for prefill workers to be ready + instance_ids = await wait_for_mockers_ready( + endpoint, kv_push_router, expected_num_workers=NUM_PREFILL_WORKERS + ) + logger.info(f"Prefill workers ready: {instance_ids}") + + # Track which DP ranks are used across requests + dp_ranks_used = set() + num_test_requests = 20 + + # Send multiple requests to test DP routing + for i in range(num_test_requests): + # Generate different token sequences to exercise routing logic + num_tokens = random.randint(30, 100) + test_tokens = [random.randint(1, 10000) for _ in range(num_tokens)] + + # Query router for best worker and dp_rank (without actually routing yet) + if hasattr(kv_push_router, "best_worker"): + worker_id, dp_rank, overlap = await kv_push_router.best_worker( + test_tokens + ) + + # Verify dp_rank is valid + assert dp_rank is not None, ( + f"Router should return dp_rank for request {i+1}, " + f"but got None" + ) + assert isinstance(dp_rank, int), ( + f"dp_rank should be integer for request {i+1}, " + f"but got {type(dp_rank)}" + ) + assert ( + 0 <= dp_rank < DP_SIZE + ), f"Request {i+1}: dp_rank {dp_rank} out of valid range [0, {DP_SIZE})" + + dp_ranks_used.add(dp_rank) + + logger.info( + f"Request {i+1}/{num_test_requests}: Router selected " + f"worker={worker_id}, dp_rank={dp_rank}, overlap={overlap} blocks " + f"for {num_tokens} input tokens" + ) + + # Send actual request through the router with selected dp_rank + # This tests the full pipeline: Router -> Prefill -> Decode (via bootstrap) + await send_request_via_python_kv_router( + kv_python_router=kv_push_router, + token_ids=test_tokens, + initial_wait=1.0, + max_retries=8, + stop_conditions={ + "ignore_eos": True, + "max_tokens": 5, # Short generation for fast test + }, + worker_id=worker_id, + dp_rank=dp_rank, + ) + + logger.info( + f"Request {i+1}/{num_test_requests} completed successfully " + f"with dp_rank={dp_rank}" + ) + + else: + logger.warning( + "Router doesn't support best_worker API - skipping DP routing test" + ) + return + + # Verify DP rank coverage across all requests + num_ranks_used = len(dp_ranks_used) + logger.info( + f"DP rank coverage: {num_ranks_used}/{DP_SIZE} ranks used across " + f"{num_test_requests} requests: {sorted(dp_ranks_used)}" + ) + + # We expect reasonable coverage (at least 2 different ranks for 20 requests) + assert num_ranks_used >= 2, ( + f"Poor DP rank coverage: only {num_ranks_used}/{DP_SIZE} ranks used " + f"across {num_test_requests} requests. Expected at least 2 different ranks." + ) + + logger.info( + f"Successfully validated DP-aware routing E2E:\n" + f" - {num_test_requests} requests completed successfully\n" + f" - {num_ranks_used}/{DP_SIZE} DP ranks utilized\n" + f" - Router correctly selected (worker_id, dp_rank, overlap) tuples\n" + f" - Prefill-Decode bootstrap connections worked with matching dp_ranks\n" + f" - All requests completed through full pipeline" + ) + + # Run the async test + asyncio.run(test_dp_routing()) + + logger.info("Disaggregated DP routing E2E test completed successfully") + + finally: + # Clean up prefill mockers + if "prefill_mockers" in locals(): + prefill_mockers.__exit__(None, None, None) From 06243546459cf69b15df0cf7ab60a0cceebd46bc Mon Sep 17 00:00:00 2001 From: Yangmin Li Date: Mon, 10 Nov 2025 22:10:54 -0800 Subject: [PATCH 5/6] remove redundant logging --- components/src/dynamo/sglang/main.py | 46 ++++++++++--------- .../request_handlers/llm/decode_handler.py | 8 ---- 2 files changed, 24 insertions(+), 30 deletions(-) diff --git a/components/src/dynamo/sglang/main.py b/components/src/dynamo/sglang/main.py index 84e5ad8ef7..29594ad883 100644 --- a/components/src/dynamo/sglang/main.py +++ b/components/src/dynamo/sglang/main.py @@ -156,34 +156,36 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): engine, config, component, generate_endpoint ) - # Readiness gate: requests wait until model is registered - ready_event = asyncio.Event() - handler = PrefillWorkerHandler(component, engine, config, publisher) health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict() + # Register Prefill to expose dp_size to Router try: - # Start endpoint and register model concurrently - # Requests queue until ready_event is set - # Register Prefill to expose dp_size to Router - await asyncio.gather( - generate_endpoint.serve_endpoint( - handler.generate, - graceful_shutdown=True, - metrics_labels=metrics_labels, - health_check_payload=health_check_payload, - ), - register_llm_with_readiness_gate( - engine, - generate_endpoint, - server_args, - dynamo_args, - input_type=ModelInput.Tokens, - output_type=ModelType.Chat | ModelType.Completions, - readiness_gate=ready_event, - ), + await register_llm_with_readiness_gate( + engine, + generate_endpoint, + server_args, + dynamo_args, + input_type=ModelInput.Tokens, + output_type=ModelType.Chat | ModelType.Completions, + readiness_gate=None, ) + except Exception as e: + logging.error(f"Failed to register prefill worker: {e}") + raise + + tasks = [ + generate_endpoint.serve_endpoint( + handler.generate, + graceful_shutdown=True, + metrics_labels=metrics_labels, + health_check_payload=health_check_payload, + ) + ] + + try: + await asyncio.gather(*tasks) except Exception as e: logging.error(f"Failed to serve endpoints: {e}") raise diff --git a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py index 95347b32a4..284b5e48fb 100644 --- a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py +++ b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py @@ -136,14 +136,6 @@ async def generate( result_data = result.data() if len(result_data) == 3: worker_id, prefill_dp_rank, overlap = result_data - if not hasattr(self, "_dp_routing_active_logged"): - logging.info( - f"DP-aware routing active: dp_rank={prefill_dp_rank}" - ) - self._dp_routing_active_logged = True - logging.debug( - f"Router selected: worker={worker_id}, dp_rank={prefill_dp_rank}, overlap={overlap}" - ) else: # Fallback for older router versions (2-tuple response) worker_id, overlap = result_data From bf1d4716206578a00d4135b20fbeaa763dc1add2 Mon Sep 17 00:00:00 2001 From: Yangmin Li Date: Tue, 11 Nov 2025 12:16:50 -0800 Subject: [PATCH 6/6] small fix --- .../dynamo/sglang/request_handlers/llm/decode_handler.py | 7 +++---- .../dynamo/sglang/request_handlers/llm/prefill_handler.py | 8 -------- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py index 284b5e48fb..8967854966 100644 --- a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py +++ b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py @@ -124,6 +124,8 @@ async def generate( ) if self.serving_mode == DisaggregationMode.DECODE: + prefill_dp_rank = None + # request the bootstrap info from the target prefill worker if ( self.prefill_router_client is not None @@ -160,7 +162,6 @@ async def generate( prefill_request_dict["request"][ "data_parallel_rank" ] = prefill_dp_rank - logging.info(f"Routing to prefill dp_rank={prefill_dp_rank}") prefill_stream = await self.prefill_client.direct( prefill_request_dict, @@ -194,7 +195,7 @@ async def generate( } # Use router-selected dp_rank (fallback to request-level if not provided) - if "prefill_dp_rank" in locals() and prefill_dp_rank is not None: + if prefill_dp_rank is not None: effective_dp_rank = prefill_dp_rank elif data_parallel_rank is not None: effective_dp_rank = data_parallel_rank @@ -203,7 +204,6 @@ async def generate( if effective_dp_rank is not None: generate_kwargs["data_parallel_rank"] = effective_dp_rank - logging.debug(f"Using dp_rank={effective_dp_rank} for decode") decode = await self.engine.async_generate(**generate_kwargs) @@ -223,7 +223,6 @@ async def generate( if data_parallel_rank is not None: generate_kwargs["data_parallel_rank"] = data_parallel_rank - logging.debug(f"Using dp_rank={data_parallel_rank} for aggregated mode") agg = await self.engine.async_generate(**generate_kwargs) if self.skip_tokenizer_init: diff --git a/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py b/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py index 243a8aecdb..60f70c1325 100644 --- a/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py +++ b/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py @@ -74,13 +74,6 @@ async def generate( yield bootstrap_info - # Validate disaggregated request format - if "request" not in request or "sampling_params" not in request: - raise ValueError( - f"Expected disaggregated format with 'request' and 'sampling_params', " - f"got keys: {list(request.keys())}" - ) - inner_request = request["request"] sampling_params_dict = request["sampling_params"] @@ -106,7 +99,6 @@ async def generate( if data_parallel_rank is not None: generate_kwargs["data_parallel_rank"] = data_parallel_rank - logging.info(f"Prefill using dp_rank={data_parallel_rank}") results = await self.engine.async_generate(**generate_kwargs)