diff --git a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py index 8a33bc5eedc9..76398e044775 100644 --- a/python/ray/serve/_private/deployment_state.py +++ b/python/ray/serve/_private/deployment_state.py @@ -250,6 +250,8 @@ def __init__( self._initialization_latency_s: Optional[float] = None self._port: Optional[int] = None self._docs_path: Optional[str] = None + # Rank assigned to the replica. + self._rank: Optional[int] = None # Populated in `on_scheduled` or `recover`. self._actor_handle: ActorHandle = None self._placement_group: PlacementGroup = None @@ -282,6 +284,10 @@ def replica_id(self) -> str: def deployment_name(self) -> str: return self._deployment_id.name + @property + def rank(self) -> Optional[int]: + return self._rank + @property def app_name(self) -> str: return self._deployment_id.app_name @@ -454,6 +460,8 @@ def start(self, deployment_info: DeploymentInfo) -> ReplicaSchedulingRequest: if self._deployment_is_cross_language else deployment_info.replica_config.serialized_init_args ) + # TODO(abrar): Fill in the correct rank + rank = 0 init_args = ( self.replica_id, cloudpickle.dumps(deployment_info.replica_config.deployment_def) @@ -467,6 +475,7 @@ def start(self, deployment_info: DeploymentInfo) -> ReplicaSchedulingRequest: self._version, deployment_info.ingress, deployment_info.route_prefix, + rank, ) # TODO(simon): unify the constructor arguments across language elif ( @@ -598,8 +607,12 @@ def reconfigure(self, version: DeploymentVersion) -> bool: deployment_config.user_config = self._format_user_config( deployment_config.user_config ) + # TODO(abrar): Fill in the correct rank + rank = 0 self._ready_obj_ref = self._actor_handle.reconfigure.remote( - deployment_config, version.route_prefix + deployment_config, + rank, + version.route_prefix, ) self._version = version @@ -729,6 +742,7 @@ def check_ready(self) -> Tuple[ReplicaStartupStatus, Optional[str]]: self._initialization_latency_s, self._port, 
self._docs_path, + self._rank, ) = ray.get(self._ready_obj_ref) except RayTaskError as e: logger.exception( diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index 5ce22606f4b9..6ddfe873d97b 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -115,6 +115,7 @@ Optional[float], Optional[int], Optional[str], + int, ] @@ -356,6 +357,7 @@ def __init__( version: DeploymentVersion, ingress: bool, route_prefix: str, + rank: int, ): self._version = version self._replica_id = replica_id @@ -402,7 +404,7 @@ def __init__( # Set metadata for logs and metrics. # servable_object will be populated in `initialize_and_get_metadata`. - self._set_internal_replica_context(servable_object=None) + self._set_internal_replica_context(servable_object=None, rank=rank) self._metrics_manager = create_replica_metrics_manager( replica_id=replica_id, @@ -422,19 +424,27 @@ def get_num_ongoing_requests(self) -> int: return self._metrics_manager.get_num_ongoing_requests() def get_metadata(self) -> ReplicaMetadata: + current_rank = ray.serve.context._get_internal_replica_context().rank return ( self._version.deployment_config, self._version, self._initialization_latency, self._port, self._docs_path, + current_rank, ) - def _set_internal_replica_context(self, *, servable_object: Callable = None): + def _set_internal_replica_context( + self, *, servable_object: Callable = None, rank: int = None + ): + # Calculate world_size from deployment config instead of storing it + world_size = self._deployment_config.num_replicas ray.serve.context._set_internal_replica_context( replica_id=self._replica_id, servable_object=servable_object, _deployment_config=self._deployment_config, + rank=rank, + world_size=world_size, ) def _configure_logger_and_profilers( @@ -752,7 +762,10 @@ async def initialize(self, deployment_config: DeploymentConfig): raise RuntimeError(traceback.format_exc()) from None async def reconfigure( - self, 
deployment_config: DeploymentConfig, route_prefix: Optional[str] = None + self, + deployment_config: DeploymentConfig, + rank: int, + route_prefix: Optional[str] = None, ): try: user_config_changed = ( @@ -782,9 +795,10 @@ async def reconfigure( ) # We need to update internal replica context to reflect the new - # deployment_config. + # deployment_config and rank. self._set_internal_replica_context( - servable_object=self._user_callable_wrapper.user_callable + servable_object=self._user_callable_wrapper.user_callable, + rank=rank, ) self._route_prefix = self._version.route_prefix @@ -894,8 +908,11 @@ async def record_routing_stats(self) -> Dict[str, Any]: class Replica(ReplicaBase): async def _on_initialized(self): + # Get current rank from replica context during initialization + current_rank = ray.serve.context._get_internal_replica_context().rank self._set_internal_replica_context( - servable_object=self._user_callable_wrapper.user_callable + servable_object=self._user_callable_wrapper.user_callable, + rank=current_rank, ) # Save the initialization latency if the replica is initializing @@ -969,6 +986,7 @@ async def __init__( version: DeploymentVersion, ingress: bool, route_prefix: str, + rank: int, ): deployment_config = DeploymentConfig.from_proto_bytes( deployment_config_proto_bytes @@ -985,6 +1003,7 @@ async def __init__( version=version, ingress=ingress, route_prefix=route_prefix, + rank=rank, ) def push_proxy_handle(self, handle: ActorHandle): @@ -1047,9 +1066,9 @@ async def record_routing_stats(self) -> Dict[str, Any]: return await self._replica_impl.record_routing_stats() async def reconfigure( - self, deployment_config, route_prefix: Optional[str] = None + self, deployment_config, rank: int, route_prefix: Optional[str] = None ) -> ReplicaMetadata: - await self._replica_impl.reconfigure(deployment_config, route_prefix) + await self._replica_impl.reconfigure(deployment_config, rank, route_prefix) return self._replica_impl.get_metadata() def 
_preprocess_request_args( diff --git a/python/ray/serve/context.py b/python/ray/serve/context.py index ecb412aab37b..99b8e5f0d9e8 100644 --- a/python/ray/serve/context.py +++ b/python/ray/serve/context.py @@ -41,11 +41,15 @@ class ReplicaContext: - deployment: name of the deployment the replica is a part of. - replica_tag: unique ID for the replica. - servable_object: instance of the user class/function this replica is running. + - rank: the rank of the replica. + - world_size: the number of replicas in the deployment. """ replica_id: ReplicaID servable_object: Callable _deployment_config: DeploymentConfig + rank: int + world_size: int @property def app_name(self) -> str: @@ -108,12 +112,16 @@ def _set_internal_replica_context( replica_id: ReplicaID, servable_object: Callable, _deployment_config: DeploymentConfig, + rank: int, + world_size: int, ): global _INTERNAL_REPLICA_CONTEXT _INTERNAL_REPLICA_CONTEXT = ReplicaContext( replica_id=replica_id, servable_object=servable_object, _deployment_config=_deployment_config, + rank=rank, + world_size=world_size, ) diff --git a/python/ray/serve/tests/test_controller_recovery.py b/python/ray/serve/tests/test_controller_recovery.py index 493931872ea9..52f405153a1e 100644 --- a/python/ray/serve/tests/test_controller_recovery.py +++ b/python/ray/serve/tests/test_controller_recovery.py @@ -65,7 +65,7 @@ def __call__(self, *args): replica_version_hash = None for replica in deployment_dict[id]: ref = replica.actor_handle.initialize_and_get_metadata.remote() - _, version, _, _, _ = ray.get(ref) + _, version, _, _, _, _ = ray.get(ref) if replica_version_hash is None: replica_version_hash = hash(version) assert replica_version_hash == hash(version), ( @@ -118,7 +118,7 @@ def __call__(self, *args): for replica_name in recovered_replica_names: actor_handle = ray.get_actor(replica_name, namespace=SERVE_NAMESPACE) ref = actor_handle.initialize_and_get_metadata.remote() - _, version, _, _, _ = ray.get(ref) + _, version, _, _, _, _ = 
ray.get(ref) assert replica_version_hash == hash( version ), "Replica version hash should be the same after recover from actor names" diff --git a/python/ray/serve/tests/test_multiplex.py b/python/ray/serve/tests/test_multiplex.py index 1ebc29066181..b93857f12c03 100644 --- a/python/ray/serve/tests/test_multiplex.py +++ b/python/ray/serve/tests/test_multiplex.py @@ -34,6 +34,8 @@ def start_serve_with_context(): ), servable_object=None, _deployment_config=DeploymentConfig(), + rank=0, + world_size=1, ) try: yield diff --git a/python/ray/serve/tests/unit/test_batching.py b/python/ray/serve/tests/unit/test_batching.py index fa83a018c85a..9c08ffdf6813 100644 --- a/python/ray/serve/tests/unit/test_batching.py +++ b/python/ray/serve/tests/unit/test_batching.py @@ -20,6 +20,8 @@ replica_id=ReplicaID(unique_id="test", deployment_id=DeploymentID(name="test")), servable_object=None, _deployment_config=default_deployment_config, + rank=0, + world_size=1, )