Skip to content

Commit cac78ea

Browse files
abrarsheikhdstrodtman
authored andcommitted
Add rank and world size in replica context (#55827)
1. `rank` and `world_size` are now available on `ReplicaContext`. 2. Replica initialization now requires providing a rank 3. Any change to the replicas rank will be communicated from controller via `.reconfigure()` method. 4. Assigned rank to replica can be fetched from `get_metadata()` function, this will be useful during controller recovery to reconstruct the state. This PR fills in a dummy rank value, in the future PR we will fetch the replica from DeploymentRankManager and pass in the correct value. Part 2 of #54938 Next diff in pipeline #55829 --------- Signed-off-by: abrar <[email protected]> Signed-off-by: Douglas Strodtman <[email protected]>
1 parent de3d5d1 commit cac78ea

File tree

6 files changed

+56
-11
lines changed

6 files changed

+56
-11
lines changed

python/ray/serve/_private/deployment_state.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,8 @@ def __init__(
250250
self._initialization_latency_s: Optional[float] = None
251251
self._port: Optional[int] = None
252252
self._docs_path: Optional[str] = None
253+
# Rank assigned to the replica.
254+
self._rank: Optional[int] = None
253255
# Populated in `on_scheduled` or `recover`.
254256
self._actor_handle: ActorHandle = None
255257
self._placement_group: PlacementGroup = None
@@ -282,6 +284,10 @@ def replica_id(self) -> str:
282284
def deployment_name(self) -> str:
283285
return self._deployment_id.name
284286

287+
@property
288+
def rank(self) -> Optional[int]:
289+
return self._rank
290+
285291
@property
286292
def app_name(self) -> str:
287293
return self._deployment_id.app_name
@@ -454,6 +460,8 @@ def start(self, deployment_info: DeploymentInfo) -> ReplicaSchedulingRequest:
454460
if self._deployment_is_cross_language
455461
else deployment_info.replica_config.serialized_init_args
456462
)
463+
# TODO(abrar): Fill in the correct rank
464+
rank = 0
457465
init_args = (
458466
self.replica_id,
459467
cloudpickle.dumps(deployment_info.replica_config.deployment_def)
@@ -467,6 +475,7 @@ def start(self, deployment_info: DeploymentInfo) -> ReplicaSchedulingRequest:
467475
self._version,
468476
deployment_info.ingress,
469477
deployment_info.route_prefix,
478+
rank,
470479
)
471480
# TODO(simon): unify the constructor arguments across language
472481
elif (
@@ -598,8 +607,12 @@ def reconfigure(self, version: DeploymentVersion) -> bool:
598607
deployment_config.user_config = self._format_user_config(
599608
deployment_config.user_config
600609
)
610+
# TODO(abrar): FIll in the correct rank
611+
rank = 0
601612
self._ready_obj_ref = self._actor_handle.reconfigure.remote(
602-
deployment_config, version.route_prefix
613+
deployment_config,
614+
rank,
615+
version.route_prefix,
603616
)
604617

605618
self._version = version
@@ -729,6 +742,7 @@ def check_ready(self) -> Tuple[ReplicaStartupStatus, Optional[str]]:
729742
self._initialization_latency_s,
730743
self._port,
731744
self._docs_path,
745+
self._rank,
732746
) = ray.get(self._ready_obj_ref)
733747
except RayTaskError as e:
734748
logger.exception(

python/ray/serve/_private/replica.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@
115115
Optional[float],
116116
Optional[int],
117117
Optional[str],
118+
int,
118119
]
119120

120121

@@ -356,6 +357,7 @@ def __init__(
356357
version: DeploymentVersion,
357358
ingress: bool,
358359
route_prefix: str,
360+
rank: int,
359361
):
360362
self._version = version
361363
self._replica_id = replica_id
@@ -402,7 +404,7 @@ def __init__(
402404

403405
# Set metadata for logs and metrics.
404406
# servable_object will be populated in `initialize_and_get_metadata`.
405-
self._set_internal_replica_context(servable_object=None)
407+
self._set_internal_replica_context(servable_object=None, rank=rank)
406408

407409
self._metrics_manager = create_replica_metrics_manager(
408410
replica_id=replica_id,
@@ -422,19 +424,27 @@ def get_num_ongoing_requests(self) -> int:
422424
return self._metrics_manager.get_num_ongoing_requests()
423425

424426
def get_metadata(self) -> ReplicaMetadata:
427+
current_rank = ray.serve.context._get_internal_replica_context().rank
425428
return (
426429
self._version.deployment_config,
427430
self._version,
428431
self._initialization_latency,
429432
self._port,
430433
self._docs_path,
434+
current_rank,
431435
)
432436

433-
def _set_internal_replica_context(self, *, servable_object: Callable = None):
437+
def _set_internal_replica_context(
438+
self, *, servable_object: Callable = None, rank: int = None
439+
):
440+
# Calculate world_size from deployment config instead of storing it
441+
world_size = self._deployment_config.num_replicas
434442
ray.serve.context._set_internal_replica_context(
435443
replica_id=self._replica_id,
436444
servable_object=servable_object,
437445
_deployment_config=self._deployment_config,
446+
rank=rank,
447+
world_size=world_size,
438448
)
439449

440450
def _configure_logger_and_profilers(
@@ -752,7 +762,10 @@ async def initialize(self, deployment_config: DeploymentConfig):
752762
raise RuntimeError(traceback.format_exc()) from None
753763

754764
async def reconfigure(
755-
self, deployment_config: DeploymentConfig, route_prefix: Optional[str] = None
765+
self,
766+
deployment_config: DeploymentConfig,
767+
rank: int,
768+
route_prefix: Optional[str] = None,
756769
):
757770
try:
758771
user_config_changed = (
@@ -782,9 +795,10 @@ async def reconfigure(
782795
)
783796

784797
# We need to update internal replica context to reflect the new
785-
# deployment_config.
798+
# deployment_config and rank.
786799
self._set_internal_replica_context(
787-
servable_object=self._user_callable_wrapper.user_callable
800+
servable_object=self._user_callable_wrapper.user_callable,
801+
rank=rank,
788802
)
789803

790804
self._route_prefix = self._version.route_prefix
@@ -894,8 +908,11 @@ async def record_routing_stats(self) -> Dict[str, Any]:
894908

895909
class Replica(ReplicaBase):
896910
async def _on_initialized(self):
911+
# Get current rank from replica context during initialization
912+
current_rank = ray.serve.context._get_internal_replica_context().rank
897913
self._set_internal_replica_context(
898-
servable_object=self._user_callable_wrapper.user_callable
914+
servable_object=self._user_callable_wrapper.user_callable,
915+
rank=current_rank,
899916
)
900917

901918
# Save the initialization latency if the replica is initializing
@@ -969,6 +986,7 @@ async def __init__(
969986
version: DeploymentVersion,
970987
ingress: bool,
971988
route_prefix: str,
989+
rank: int,
972990
):
973991
deployment_config = DeploymentConfig.from_proto_bytes(
974992
deployment_config_proto_bytes
@@ -985,6 +1003,7 @@ async def __init__(
9851003
version=version,
9861004
ingress=ingress,
9871005
route_prefix=route_prefix,
1006+
rank=rank,
9881007
)
9891008

9901009
def push_proxy_handle(self, handle: ActorHandle):
@@ -1047,9 +1066,9 @@ async def record_routing_stats(self) -> Dict[str, Any]:
10471066
return await self._replica_impl.record_routing_stats()
10481067

10491068
async def reconfigure(
1050-
self, deployment_config, route_prefix: Optional[str] = None
1069+
self, deployment_config, rank: int, route_prefix: Optional[str] = None
10511070
) -> ReplicaMetadata:
1052-
await self._replica_impl.reconfigure(deployment_config, route_prefix)
1071+
await self._replica_impl.reconfigure(deployment_config, rank, route_prefix)
10531072
return self._replica_impl.get_metadata()
10541073

10551074
def _preprocess_request_args(

python/ray/serve/context.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,15 @@ class ReplicaContext:
4141
- deployment: name of the deployment the replica is a part of.
4242
- replica_tag: unique ID for the replica.
4343
- servable_object: instance of the user class/function this replica is running.
44+
- rank: the rank of the replica.
45+
- world_size: the number of replicas in the deployment.
4446
"""
4547

4648
replica_id: ReplicaID
4749
servable_object: Callable
4850
_deployment_config: DeploymentConfig
51+
rank: int
52+
world_size: int
4953

5054
@property
5155
def app_name(self) -> str:
@@ -108,12 +112,16 @@ def _set_internal_replica_context(
108112
replica_id: ReplicaID,
109113
servable_object: Callable,
110114
_deployment_config: DeploymentConfig,
115+
rank: int,
116+
world_size: int,
111117
):
112118
global _INTERNAL_REPLICA_CONTEXT
113119
_INTERNAL_REPLICA_CONTEXT = ReplicaContext(
114120
replica_id=replica_id,
115121
servable_object=servable_object,
116122
_deployment_config=_deployment_config,
123+
rank=rank,
124+
world_size=world_size,
117125
)
118126

119127

python/ray/serve/tests/test_controller_recovery.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __call__(self, *args):
6565
replica_version_hash = None
6666
for replica in deployment_dict[id]:
6767
ref = replica.actor_handle.initialize_and_get_metadata.remote()
68-
_, version, _, _, _ = ray.get(ref)
68+
_, version, _, _, _, _ = ray.get(ref)
6969
if replica_version_hash is None:
7070
replica_version_hash = hash(version)
7171
assert replica_version_hash == hash(version), (
@@ -118,7 +118,7 @@ def __call__(self, *args):
118118
for replica_name in recovered_replica_names:
119119
actor_handle = ray.get_actor(replica_name, namespace=SERVE_NAMESPACE)
120120
ref = actor_handle.initialize_and_get_metadata.remote()
121-
_, version, _, _, _ = ray.get(ref)
121+
_, version, _, _, _, _ = ray.get(ref)
122122
assert replica_version_hash == hash(
123123
version
124124
), "Replica version hash should be the same after recover from actor names"

python/ray/serve/tests/test_multiplex.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ def start_serve_with_context():
3434
),
3535
servable_object=None,
3636
_deployment_config=DeploymentConfig(),
37+
rank=0,
38+
world_size=1,
3739
)
3840
try:
3941
yield

python/ray/serve/tests/unit/test_batching.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
replica_id=ReplicaID(unique_id="test", deployment_id=DeploymentID(name="test")),
2121
servable_object=None,
2222
_deployment_config=default_deployment_config,
23+
rank=0,
24+
world_size=1,
2325
)
2426

2527

0 commit comments

Comments
 (0)