diff --git a/model-engine/model_engine_server/api/v2/chat_completion.py b/model-engine/model_engine_server/api/v2/chat_completion.py
index 614f159d0..f6b4facce 100644
--- a/model-engine/model_engine_server/api/v2/chat_completion.py
+++ b/model-engine/model_engine_server/api/v2/chat_completion.py
@@ -263,7 +263,7 @@ async def chat_completion(
         )
     else:
         logger.info(
-            f"POST /v2/chat/completion ({('stream' if request.stream else 'sync')}) with {request} to endpoint {model_endpoint_name} for {auth}"
+            f"POST /v2/chat/completion ({('stream' if request.stream else 'sync')}) with request {request} to endpoint {model_endpoint_name} for {auth}"
        )
 
    if request.stream:
diff --git a/model-engine/model_engine_server/api/v2/completion.py b/model-engine/model_engine_server/api/v2/completion.py
index ed529fe3b..eb101e020 100644
--- a/model-engine/model_engine_server/api/v2/completion.py
+++ b/model-engine/model_engine_server/api/v2/completion.py
@@ -262,7 +262,7 @@ async def completion(
         )
     else:
         logger.info(
-            f"POST /v2/completion ({('stream' if request.stream else 'sync')}) with {request} to endpoint {model_endpoint_name} for {auth}"
+            f"POST /v2/completion ({('stream' if request.stream else 'sync')}) with request {request} to endpoint {model_endpoint_name} for {auth}"
        )
 
    if request.stream:
diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
index 1fb8dbed9..4f3602aad 100644
--- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
+++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
@@ -170,8 +170,8 @@
 }
 
 
-NUM_DOWNSTREAM_REQUEST_RETRIES = 80  # has to be high enough so that the retries take the 5 minutes
-DOWNSTREAM_REQUEST_TIMEOUT_SECONDS = 5 * 60  # 5 minutes
+NUM_DOWNSTREAM_REQUEST_RETRIES = 80 * 12  # has to be high enough so that the retries span the full timeout
+DOWNSTREAM_REQUEST_TIMEOUT_SECONDS = 60 * 60  # 1 hour
 
 DEFAULT_BATCH_COMPLETIONS_NODES_PER_WORKER = 1
 
@@ -377,6 +377,87 @@ def check_docker_image_exists_for_image_tag(
             tag=framework_image_tag,
         )
 
+    async def create_sglang_multinode_bundle(
+        self,
+        user: User,
+        model_name: str,
+        framework_image_tag: str,
+        endpoint_unique_name: str,
+        num_shards: int,
+        nodes_per_worker: int,
+        quantize: Optional[Quantization],
+        checkpoint_path: Optional[str],
+        chat_template_override: Optional[str],
+        additional_args: Optional[SGLangEndpointAdditionalArgs] = None,
+    ):
+        leader_command = [
+            "python3",
+            "/root/sglang-startup-script.py",
+            "--model",
+            "deepseek-ai/DeepSeek-R1-0528",
+            "--nnodes",
+            "2",
+            "--node-rank",
+            "0",
+            "--worker-port",
+            "5005",
+            "--leader-port",
+            "5002",
+        ]
+
+        worker_command = [
+            "python3",
+            "/root/sglang-startup-script.py",
+            "--model",
+            "deepseek-ai/DeepSeek-R1-0528",
+            "--nnodes",
+            "2",
+            "--node-rank",
+            "1",
+            "--worker-port",
+            "5005",
+            "--leader-port",
+            "5002",
+        ]
+
+        # NOTE: the most important env var SGLANG_HOST_IP is already established in the sglang startup script
+
+        common_sglang_envs = {  # these are for debugging
+            "NCCL_SOCKET_IFNAME": "eth0",
+            "GLOO_SOCKET_IFNAME": "eth0",
+        }
+
+        # This is the same as the VLLM multinode bundle
+        create_model_bundle_v2_request = CreateModelBundleV2Request(
+            name=endpoint_unique_name,
+            schema_location="TBA",
+            flavor=StreamingEnhancedRunnableImageFlavor(
+                flavor=ModelBundleFlavorType.STREAMING_ENHANCED_RUNNABLE_IMAGE,
+                repository=hmi_config.sglang_repository,
+                tag=framework_image_tag,
+                command=leader_command,
+                streaming_command=leader_command,
+                protocol="http",
+                readiness_initial_delay_seconds=10,
+                healthcheck_route="/health",
+                predict_route="/predict",
+                streaming_predict_route="/stream",
+                extra_routes=[OPENAI_CHAT_COMPLETION_PATH, OPENAI_COMPLETION_PATH],
+                env=common_sglang_envs,
+                worker_command=worker_command,
+                worker_env=common_sglang_envs,
+            ),
+            metadata={},
+        )
+
+        return (
+            await self.create_model_bundle_use_case.execute(
+                user,
+                create_model_bundle_v2_request,
+                do_auth_check=False,
+            )
+        ).model_bundle_id
+
     async def execute(
         self,
         user: User,
@@ -400,7 +481,10 @@ async def execute(
         self.check_docker_image_exists_for_image_tag(
             framework_image_tag, INFERENCE_FRAMEWORK_REPOSITORY[framework]
         )
-        if multinode and framework != LLMInferenceFramework.VLLM:
+        if multinode and framework not in [
+            LLMInferenceFramework.VLLM,
+            LLMInferenceFramework.SGLANG,
+        ]:
             raise ObjectHasInvalidValueException(
                 f"Multinode is not supported for framework {framework}."
             )
@@ -481,16 +565,30 @@ async def execute(
                     if additional_args
                     else None
                 )
-                bundle_id = await self.create_sglang_bundle(
-                    user,
-                    model_name,
-                    framework_image_tag,
-                    endpoint_name,
-                    num_shards,
-                    checkpoint_path,
-                    chat_template_override,
-                    additional_args=additional_sglang_args,
-                )
+                if multinode:
+                    bundle_id = await self.create_sglang_multinode_bundle(
+                        user,
+                        model_name,
+                        framework_image_tag,
+                        endpoint_name,
+                        num_shards,
+                        nodes_per_worker,
+                        quantize,
+                        checkpoint_path,
+                        chat_template_override,
+                        additional_args=additional_sglang_args,
+                    )
+                else:
+                    bundle_id = await self.create_sglang_bundle(
+                        user,
+                        model_name,
+                        framework_image_tag,
+                        endpoint_name,
+                        num_shards,
+                        checkpoint_path,
+                        chat_template_override,
+                        additional_args=additional_sglang_args,
+                    )
             case _:
                 assert_never(framework)
                 raise ObjectHasInvalidValueException(
@@ -1321,10 +1419,10 @@ async def execute(
             request.inference_framework
         )
 
-        if (
-            request.nodes_per_worker > 1
-            and not request.inference_framework == LLMInferenceFramework.VLLM
-        ):
+        if request.nodes_per_worker > 1 and request.inference_framework not in [
+            LLMInferenceFramework.VLLM,
+            LLMInferenceFramework.SGLANG,
+        ]:
             raise ObjectHasInvalidValueException(
-                "Multinode endpoints are only supported for VLLM models."
+                "Multinode endpoints are only supported for VLLM and SGLang models."
             )
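For reference, the leader and worker commands in the `create_sglang_multinode_bundle` hunk above differ only in `--node-rank`, and the model name and node count are currently hardcoded. Below is a minimal sketch of how the same argv lists could be built from one helper; the helper and the idea of deriving values from `model_name`/`nodes_per_worker` are hypothetical and not part of this PR, while the flag names are the ones the startup script accepts in this diff.

```python
from typing import List


def build_sglang_node_command(
    model: str,
    nnodes: int,
    node_rank: int,
    worker_port: int = 5005,
    leader_port: int = 5002,
) -> List[str]:
    # Hypothetical helper: mirrors the argv layout of leader_command/worker_command above.
    return [
        "python3",
        "/root/sglang-startup-script.py",
        "--model",
        model,
        "--nnodes",
        str(nnodes),
        "--node-rank",
        str(node_rank),
        "--worker-port",
        str(worker_port),
        "--leader-port",
        str(leader_port),
    ]


# The leader is rank 0; the single worker in a 2-node setup is rank 1.
leader_command = build_sglang_node_command("deepseek-ai/DeepSeek-R1-0528", nnodes=2, node_rank=0)
worker_command = build_sglang_node_command("deepseek-ai/DeepSeek-R1-0528", nnodes=2, node_rank=1)
```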
diff --git a/model-engine/model_engine_server/inference/forwarding/forwarding.py b/model-engine/model_engine_server/inference/forwarding/forwarding.py
index 096d48bfd..5518e713c 100644
--- a/model-engine/model_engine_server/inference/forwarding/forwarding.py
+++ b/model-engine/model_engine_server/inference/forwarding/forwarding.py
@@ -174,7 +174,9 @@ async def forward(self, json_payload: Any) -> Any:
         logger.info(f"Accepted request, forwarding {json_payload_repr=}")
 
         try:
-            async with aiohttp.ClientSession(json_serialize=_serialize_json) as aioclient:
+            async with aiohttp.ClientSession(
+                json_serialize=_serialize_json, timeout=aiohttp.ClientTimeout(total=60 * 60)
+            ) as aioclient:
                 response_raw = await aioclient.post(
                     self.predict_endpoint,
                     json=json_payload,
@@ -430,7 +432,9 @@ async def forward(self, json_payload: Any) -> AsyncGenerator[Any, None]:  # prag
 
         try:
             response: aiohttp.ClientResponse
-            async with aiohttp.ClientSession(json_serialize=_serialize_json) as aioclient:
+            async with aiohttp.ClientSession(
+                json_serialize=_serialize_json, timeout=aiohttp.ClientTimeout(total=60 * 60)
+            ) as aioclient:
                 response = await aioclient.post(
                     self.predict_endpoint,
                     json=json_payload,
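Both hunks above give the forwarder's `aiohttp.ClientSession` an explicit one-hour total timeout so long-running generations are not cut off by the client-side default (aiohttp's default total timeout is 5 minutes). A self-contained sketch of the same pattern, with a placeholder URL and payload rather than values from this repo:

```python
import aiohttp

ONE_HOUR = 60 * 60


async def post_with_long_timeout(url: str, payload: dict) -> dict:
    # total= bounds the whole request (connect + read), which is what long generations need.
    timeout = aiohttp.ClientTimeout(total=ONE_HOUR)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.post(url, json=payload) as response:
            response.raise_for_status()
            return await response.json()
```

With a live endpoint this could be driven via `asyncio.run(post_with_long_timeout(url, payload))`.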
diff --git a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py
index 89fcb3fb1..39f38a5c5 100644
--- a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py
+++ b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py
@@ -22,6 +22,8 @@
 
 logger = make_logger(logger_name())
 
+LOG_SENSITIVE_DATA = False
+
 
 def get_config():
     overrides = os.getenv("CONFIG_OVERRIDES")
@@ -90,7 +92,10 @@ async def predict(
             )
             return response
     except Exception:
-        logger.error(f"Failed to decode payload from: {request}")
+        if LOG_SENSITIVE_DATA:
+            logger.error(f"Failed to decode payload from: {request}")
+        else:
+            logger.error("Failed to decode payload")
         raise
 
 
@@ -103,10 +108,16 @@ async def stream(
     try:
         payload = request.model_dump()
     except Exception:
-        logger.error(f"Failed to decode payload from: {request}")
+        if LOG_SENSITIVE_DATA:
+            logger.error(f"Failed to decode payload from: {request}")
+        else:
+            logger.error("Failed to decode payload")
         raise
     else:
-        logger.debug(f"Received request: {payload}")
+        if LOG_SENSITIVE_DATA:
+            logger.debug(f"Received request: {request}")
+        else:
+            logger.debug("Received request")
 
     responses = forwarder.forward(payload)
     # We fetch the first response to check if upstream request was successful
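`LOG_SENSITIVE_DATA` is introduced above as a hardcoded module-level constant. If it ever needs to be flipped without a code change, one possible variant (a sketch only, not what this PR does; the environment variable name is hypothetical) is an env-var toggle combined with the same gating pattern used in `predict`/`stream`:

```python
import logging
import os

# Sketch: the PR hardcodes LOG_SENSITIVE_DATA = False; reading it from the environment
# would let operators opt in to verbose payload logging at deploy time.
LOG_SENSITIVE_DATA = os.getenv("LOG_SENSITIVE_DATA", "false").lower() in ("1", "true", "yes")

logger = logging.getLogger(__name__)


def log_decode_failure(request: object) -> None:
    # Same gating as in predict()/stream(): full payload only when explicitly enabled.
    if LOG_SENSITIVE_DATA:
        logger.error(f"Failed to decode payload from: {request}")
    else:
        logger.error("Failed to decode payload")
```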
diff --git a/model-engine/model_engine_server/inference/sglang/Dockerfile.sglang b/model-engine/model_engine_server/inference/sglang/Dockerfile.sglang
index 61e4ae440..cb799c309 100644
--- a/model-engine/model_engine_server/inference/sglang/Dockerfile.sglang
+++ b/model-engine/model_engine_server/inference/sglang/Dockerfile.sglang
@@ -1,4 +1,5 @@
-FROM 692474966980.dkr.ecr.us-west-2.amazonaws.com/sglang:v0.4.1.post7-cu124
+# FROM lmsysorg/sglang:v0.4.6.post5-cu124 -- this one didn't work
+FROM lmsysorg/sglang:v0.4.5.post3-cu121
 
 # These aren't all needed but good to have for debugging purposes
 RUN apt-get -yq update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
@@ -35,7 +36,7 @@ RUN apt-get -yq update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
     tk-dev \
     libffi-dev \
     liblzma-dev \
-    python-openssl \
+    python3-openssl \
     moreutils \
     libcurl4-openssl-dev \
     libssl-dev \
diff --git a/model-engine/model_engine_server/inference/sglang/sglang-startup-script.py b/model-engine/model_engine_server/inference/sglang/sglang-startup-script.py
index 157e9c30c..a688be5c3 100755
--- a/model-engine/model_engine_server/inference/sglang/sglang-startup-script.py
+++ b/model-engine/model_engine_server/inference/sglang/sglang-startup-script.py
@@ -12,7 +12,7 @@ def wait_for_dns(dns_name: str, max_retries: int = 20, sleep_seconds: int = 3):
     sleeping sleep_seconds between attempts.
     Raises RuntimeError if resolution fails repeatedly.
     """
-    for attempt in range(1, max_retries + 1):
+    for attempt in range(1, max_retries + 2):
         try:
             # Use AF_UNSPEC to allow both IPv4 and IPv6
             socket.getaddrinfo(dns_name, None, socket.AF_UNSPEC)
@@ -107,7 +107,7 @@ def main(
         "--tp",
         str(tp),
         "--host",
-        "::",
+        "0.0.0.0",
         "--port",
         str(worker_port),
         "--dist-init-addr",
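In the startup script, the `--host` change switches the server bind address from the IPv6 wildcard `::` to the IPv4 wildcard `0.0.0.0`, and the worker still has to wait until the leader pod's DNS name resolves before it can join the group. A simplified sketch of that wait-for-DNS pattern (the real script's logging and exact retry arithmetic differ):

```python
import socket
import time


def wait_for_dns(dns_name: str, max_retries: int = 20, sleep_seconds: int = 3) -> None:
    """Block until dns_name resolves, sleeping sleep_seconds between attempts."""
    for attempt in range(1, max_retries + 1):
        try:
            # AF_UNSPEC accepts both IPv4 and IPv6 answers.
            socket.getaddrinfo(dns_name, None, socket.AF_UNSPEC)
            return
        except socket.gaierror:
            if attempt == max_retries:
                raise RuntimeError(f"Could not resolve {dns_name} after {max_retries} attempts")
            time.sleep(sleep_seconds)
```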