diff --git a/components/src/dynamo/router/__main__.py b/components/src/dynamo/router/__main__.py index a11b0cb492..d33932194a 100644 --- a/components/src/dynamo/router/__main__.py +++ b/components/src/dynamo/router/__main__.py @@ -124,21 +124,17 @@ async def generate(self, request): yield llm_engine_output async def best_worker_id(self, token_ids, router_config_override=None): - """ - Get the best worker ID for a given set of tokens without actually routing. + """Get the best worker for given tokens without routing. - This method returns the worker ID that would be selected based on KV cache - overlap, but does NOT actually route the request or update router states. - It's useful for debugging, monitoring, or implementing custom routing logic. + Returns (worker_id, dp_rank, overlap_blocks) based on KV cache overlap. + Does NOT update router states - useful for preview/debugging. """ if self.kv_push_router is None: - logger.error("KvPushRouter not initialized - cannot get best worker") raise RuntimeError("Router not initialized") - result = await self.kv_push_router.best_worker_id( + result = await self.kv_push_router.best_worker( token_ids, router_config_override ) - yield result diff --git a/components/src/dynamo/sglang/main.py b/components/src/dynamo/sglang/main.py index a549ca7997..29594ad883 100644 --- a/components/src/dynamo/sglang/main.py +++ b/components/src/dynamo/sglang/main.py @@ -110,8 +110,7 @@ async def init(runtime: DistributedRuntime, config: Config): health_check_payload = SglangHealthCheckPayload(engine).to_dict() try: - # Start endpoint immediately and register model concurrently - # Requests queue until ready_event is set (TODO: Part of new PR) + # Start endpoint and register model concurrently await asyncio.gather( generate_endpoint.serve_endpoint( handler.generate, @@ -161,6 +160,21 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict() + # Register Prefill to expose dp_size to Router + try: + await register_llm_with_readiness_gate( + engine, + generate_endpoint, + server_args, + dynamo_args, + input_type=ModelInput.Tokens, + output_type=ModelType.Chat | ModelType.Completions, + readiness_gate=None, + ) + except Exception as e: + logging.error(f"Failed to register prefill worker: {e}") + raise + tasks = [ generate_endpoint.serve_endpoint( handler.generate, diff --git a/components/src/dynamo/sglang/register.py b/components/src/dynamo/sglang/register.py index 9819a92733..a7c4a22c51 100644 --- a/components/src/dynamo/sglang/register.py +++ b/components/src/dynamo/sglang/register.py @@ -94,6 +94,12 @@ async def _get_runtime_config( if max_prefill_tokens: runtime_config.max_num_batched_tokens = max_prefill_tokens + # Expose data_parallel_size for DP-aware routing + dp_size = getattr(server_args, "dp_size", 1) + runtime_config.data_parallel_size = dp_size + if dp_size > 1: + logging.info(f"Registering with data_parallel_size={dp_size}") + try: # Try to check if the engine has a scheduler attribute with the computed values if hasattr(engine, "scheduler_info") and engine.scheduler_info is not None: diff --git a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py index 0ceea13ed3..8967854966 100644 --- a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py +++ b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py @@ -115,7 +115,17 @@ async def generate( sampling_params = self._build_sampling_params(request) input_param = self._get_input_param(request) + # Extract data_parallel_rank (explicit None check to preserve dp_rank=0) + data_parallel_rank = ( + request.get("data_parallel_rank") + if "data_parallel_rank" in request + and request["data_parallel_rank"] is not None + else request.get("dp_rank") + ) + if self.serving_mode == DisaggregationMode.DECODE: + prefill_dp_rank = None + # request the bootstrap info from the target prefill worker if ( self.prefill_router_client is not None @@ -124,17 +134,37 @@ async def generate( token_ids = request["token_ids"] stream = await self.prefill_router_client.generate(token_ids) result = await anext(stream) - ( - worker_id, - overlap, - ) = result.data() # Returns tuple (worker_id, overlap_amount) - logging.info(f"Best prefill worker ID: {worker_id}, overlap: {overlap}") + # Unpack router response: (worker_id, dp_rank, overlap_blocks) + result_data = result.data() + if len(result_data) == 3: + worker_id, prefill_dp_rank, overlap = result_data + else: + # Fallback for older router versions (2-tuple response) + worker_id, overlap = result_data + prefill_dp_rank = None + if not hasattr(self, "_dp_routing_unavailable_warned"): + logging.warning( + "Router returned 2-tuple, DP routing unavailable (update router)" + ) + self._dp_routing_unavailable_warned = True + + # Build prefill request + prefill_request_dict = DisaggPreprocessedRequest( + request=request, + sampling_params=sampling_params, + ).model_dump() + + # Inject dp_rank into inner request after serialization (Pydantic drops unknown fields) + if prefill_dp_rank is not None and isinstance( + prefill_request_dict.get("request"), dict + ): + # SGLang engine reads data_parallel_rank from inner request + prefill_request_dict["request"][ + "data_parallel_rank" + ] = prefill_dp_rank prefill_stream = await self.prefill_client.direct( - DisaggPreprocessedRequest( - request=request, - sampling_params=sampling_params, - ).model_dump(), + prefill_request_dict, worker_id, ) else: @@ -154,14 +184,28 @@ async def generate( if not bootstrap_info: raise RuntimeError("No bootstrap info received from prefill worker") - decode = await self.engine.async_generate( + # Prefill and Decode must use same dp_rank for bootstrap connection + generate_kwargs = { **input_param, - sampling_params=sampling_params, - stream=True, - bootstrap_host=bootstrap_info["bootstrap_host"], - bootstrap_port=bootstrap_info["bootstrap_port"], - bootstrap_room=bootstrap_info["bootstrap_room"], - ) + "sampling_params": sampling_params, + "stream": True, + "bootstrap_host": bootstrap_info["bootstrap_host"], + "bootstrap_port": bootstrap_info["bootstrap_port"], + "bootstrap_room": bootstrap_info["bootstrap_room"], + } + + # Use router-selected dp_rank (fallback to request-level if not provided) + if prefill_dp_rank is not None: + effective_dp_rank = prefill_dp_rank + elif data_parallel_rank is not None: + effective_dp_rank = data_parallel_rank + else: + effective_dp_rank = None + + if effective_dp_rank is not None: + generate_kwargs["data_parallel_rank"] = effective_dp_rank + + decode = await self.engine.async_generate(**generate_kwargs) if self.skip_tokenizer_init: async for out in self._process_token_stream(decode, context): @@ -170,11 +214,17 @@ async def generate( async for out in self._process_text_stream(decode, context): yield out else: - agg = await self.engine.async_generate( + # Aggregated mode + generate_kwargs = { **input_param, - sampling_params=sampling_params, - stream=True, - ) + "sampling_params": sampling_params, + "stream": True, + } + + if data_parallel_rank is not None: + generate_kwargs["data_parallel_rank"] = data_parallel_rank + + agg = await self.engine.async_generate(**generate_kwargs) if self.skip_tokenizer_init: async for out in self._process_token_stream(agg, context): yield out diff --git a/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py b/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py index dc55ab9762..60f70c1325 100644 --- a/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py +++ b/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py @@ -74,16 +74,33 @@ async def generate( yield bootstrap_info - input_param = self._get_input_param(request["request"]) + inner_request = request["request"] + sampling_params_dict = request["sampling_params"] + + # Extract data_parallel_rank from inner request (explicit None check to preserve dp_rank=0) + data_parallel_rank = ( + inner_request.get("data_parallel_rank") + if "data_parallel_rank" in inner_request + and inner_request["data_parallel_rank"] is not None + else None + ) + + input_param = self._get_input_param(inner_request) - results = await self.engine.async_generate( + # Build engine kwargs + generate_kwargs = { **input_param, - sampling_params=request["sampling_params"], - stream=True, - bootstrap_host=self.bootstrap_host, - bootstrap_port=self.bootstrap_port, - bootstrap_room=bootstrap_room, - ) + "sampling_params": sampling_params_dict, + "stream": True, + "bootstrap_host": self.bootstrap_host, + "bootstrap_port": self.bootstrap_port, + "bootstrap_room": bootstrap_room, + } + + if data_parallel_rank is not None: + generate_kwargs["data_parallel_rank"] = data_parallel_rank + + results = await self.engine.async_generate(**generate_kwargs) task = asyncio.create_task(self._consume_results(results, context)) self._consume_tasks.add(task) diff --git a/tests/router/test_router_e2e_with_mockers.py b/tests/router/test_router_e2e_with_mockers.py index 63506c6707..7aca8f620c 100644 --- a/tests/router/test_router_e2e_with_mockers.py +++ b/tests/router/test_router_e2e_with_mockers.py @@ -1545,3 +1545,173 @@ async def test_sync(): # Clean up mockers if "mockers" in locals(): mockers.__exit__(None, None, None) + + +@pytest.mark.pre_merge +@pytest.mark.model(MODEL_NAME) +def test_disagg_dp_routing_e2e(request, runtime_services, predownload_tokenizers): + """ + E2E test for DP-aware routing in disaggregated prefill-decode mode. + + This test validates the complete DP routing flow: + 1. Router correctly selects worker_id and dp_rank for prefill requests + 2. Decode worker receives and uses the router-selected dp_rank + 3. Bootstrap connection works correctly with matching dp_ranks between prefill and decode + 4. End-to-end requests complete successfully through the full pipeline + + Flow: + - Start multiple prefill workers with dp_size=4 (each worker has 4 DP ranks) + - Create KV router and send requests with varying token sequences + - For each request: + * Query router for best (worker_id, dp_rank, overlap) + * Verify dp_rank is valid (0 <= dp_rank < DP_SIZE) + * Send request through router using selected dp_rank + * Verify request completes successfully (proves bootstrap connection works) + - Verify all DP ranks get utilized across multiple requests + """ + + logger.info("Starting disaggregated DP routing E2E test") + + DP_SIZE = 4 # Each worker has 4 DP ranks + NUM_PREFILL_WORKERS = 2 # 2 workers × 4 DP ranks = 8 total DP ranks + + # Create mocker args with DP support + mocker_args = { + "speedup_ratio": SPEEDUP_RATIO, + "block_size": BLOCK_SIZE, + "dp_size": DP_SIZE, + "num_gpu_blocks": 1000, + } + + try: + # Start prefill workers with DP support + logger.info( + f"Starting {NUM_PREFILL_WORKERS} prefill workers with dp_size={DP_SIZE} each " + f"({NUM_PREFILL_WORKERS * DP_SIZE} total DP ranks)" + ) + prefill_mockers = MockerProcess( + request, mocker_args=mocker_args, num_mockers=NUM_PREFILL_WORKERS + ) + logger.info(f"Prefill workers using endpoint: {prefill_mockers.endpoint}") + prefill_mockers.__enter__() + + async def test_dp_routing(): + # Get runtime and create components + runtime = get_runtime() + namespace = runtime.namespace(prefill_mockers.namespace) + component = namespace.component("mocker") + endpoint = component.endpoint("generate") + + # Create router with configuration + kv_router_config = KvRouterConfig( + overlap_score_weight=2.0, + router_temperature=0.0, # Deterministic routing for testing + ) + kv_push_router = KvPushRouter( + endpoint=endpoint, + block_size=BLOCK_SIZE, + kv_router_config=kv_router_config, + ) + + logger.info("Created KvPushRouter for DP routing test") + + # Wait for prefill workers to be ready + instance_ids = await wait_for_mockers_ready( + endpoint, kv_push_router, expected_num_workers=NUM_PREFILL_WORKERS + ) + logger.info(f"Prefill workers ready: {instance_ids}") + + # Track which DP ranks are used across requests + dp_ranks_used = set() + num_test_requests = 20 + + # Send multiple requests to test DP routing + for i in range(num_test_requests): + # Generate different token sequences to exercise routing logic + num_tokens = random.randint(30, 100) + test_tokens = [random.randint(1, 10000) for _ in range(num_tokens)] + + # Query router for best worker and dp_rank (without actually routing yet) + if hasattr(kv_push_router, "best_worker"): + worker_id, dp_rank, overlap = await kv_push_router.best_worker( + test_tokens + ) + + # Verify dp_rank is valid + assert dp_rank is not None, ( + f"Router should return dp_rank for request {i+1}, " + f"but got None" + ) + assert isinstance(dp_rank, int), ( + f"dp_rank should be integer for request {i+1}, " + f"but got {type(dp_rank)}" + ) + assert ( + 0 <= dp_rank < DP_SIZE + ), f"Request {i+1}: dp_rank {dp_rank} out of valid range [0, {DP_SIZE})" + + dp_ranks_used.add(dp_rank) + + logger.info( + f"Request {i+1}/{num_test_requests}: Router selected " + f"worker={worker_id}, dp_rank={dp_rank}, overlap={overlap} blocks " + f"for {num_tokens} input tokens" + ) + + # Send actual request through the router with selected dp_rank + # This tests the full pipeline: Router -> Prefill -> Decode (via bootstrap) + await send_request_via_python_kv_router( + kv_python_router=kv_push_router, + token_ids=test_tokens, + initial_wait=1.0, + max_retries=8, + stop_conditions={ + "ignore_eos": True, + "max_tokens": 5, # Short generation for fast test + }, + worker_id=worker_id, + dp_rank=dp_rank, + ) + + logger.info( + f"Request {i+1}/{num_test_requests} completed successfully " + f"with dp_rank={dp_rank}" + ) + + else: + logger.warning( + "Router doesn't support best_worker API - skipping DP routing test" + ) + return + + # Verify DP rank coverage across all requests + num_ranks_used = len(dp_ranks_used) + logger.info( + f"DP rank coverage: {num_ranks_used}/{DP_SIZE} ranks used across " + f"{num_test_requests} requests: {sorted(dp_ranks_used)}" + ) + + # We expect reasonable coverage (at least 2 different ranks for 20 requests) + assert num_ranks_used >= 2, ( + f"Poor DP rank coverage: only {num_ranks_used}/{DP_SIZE} ranks used " + f"across {num_test_requests} requests. Expected at least 2 different ranks." + ) + + logger.info( + f"Successfully validated DP-aware routing E2E:\n" + f" - {num_test_requests} requests completed successfully\n" + f" - {num_ranks_used}/{DP_SIZE} DP ranks utilized\n" + f" - Router correctly selected (worker_id, dp_rank, overlap) tuples\n" + f" - Prefill-Decode bootstrap connections worked with matching dp_ranks\n" + f" - All requests completed through full pipeline" + ) + + # Run the async test + asyncio.run(test_dp_routing()) + + logger.info("Disaggregated DP routing E2E test completed successfully") + + finally: + # Clean up prefill mockers + if "prefill_mockers" in locals(): + prefill_mockers.__exit__(None, None, None)