Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion python/ray/serve/_private/deployment_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ def __init__(
self._initialization_latency_s: Optional[float] = None
self._port: Optional[int] = None
self._docs_path: Optional[str] = None
# Rank assigned to the replica.
self._rank: Optional[int] = None
# Populated in `on_scheduled` or `recover`.
self._actor_handle: ActorHandle = None
self._placement_group: PlacementGroup = None
Expand Down Expand Up @@ -282,6 +284,10 @@ def replica_id(self) -> str:
def deployment_name(self) -> str:
return self._deployment_id.name

# Read-only accessor for the rank recorded for this replica.
# `_rank` is initialized to None in `__init__` and is overwritten in
# `check_ready` with the value unpacked from the replica's readiness
# metadata, so callers must be prepared to receive None.
@property
def rank(self) -> Optional[int]:
return self._rank

# Application name, read straight from the deployment ID held by this
# object (`self._deployment_id.app_name`).
@property
def app_name(self) -> str:
return self._deployment_id.app_name
Expand Down Expand Up @@ -454,6 +460,8 @@ def start(self, deployment_info: DeploymentInfo) -> ReplicaSchedulingRequest:
if self._deployment_is_cross_language
else deployment_info.replica_config.serialized_init_args
)
# TODO(abrar): Fill in the correct rank
rank = 0
init_args = (
self.replica_id,
cloudpickle.dumps(deployment_info.replica_config.deployment_def)
Expand All @@ -467,6 +475,7 @@ def start(self, deployment_info: DeploymentInfo) -> ReplicaSchedulingRequest:
self._version,
deployment_info.ingress,
deployment_info.route_prefix,
rank,
)
# TODO(simon): unify the constructor arguments across language
elif (
Expand Down Expand Up @@ -598,8 +607,12 @@ def reconfigure(self, version: DeploymentVersion) -> bool:
deployment_config.user_config = self._format_user_config(
deployment_config.user_config
)
# TODO(abrar): Fill in the correct rank
rank = 0
self._ready_obj_ref = self._actor_handle.reconfigure.remote(
deployment_config, version.route_prefix
deployment_config,
rank,
version.route_prefix,
)

self._version = version
Expand Down Expand Up @@ -729,6 +742,7 @@ def check_ready(self) -> Tuple[ReplicaStartupStatus, Optional[str]]:
self._initialization_latency_s,
self._port,
self._docs_path,
self._rank,
) = ray.get(self._ready_obj_ref)
except RayTaskError as e:
logger.exception(
Expand Down
35 changes: 27 additions & 8 deletions python/ray/serve/_private/replica.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
Optional[float],
Optional[int],
Optional[str],
int,
]


Expand Down Expand Up @@ -356,6 +357,7 @@ def __init__(
version: DeploymentVersion,
ingress: bool,
route_prefix: str,
rank: int,
):
self._version = version
self._replica_id = replica_id
Expand Down Expand Up @@ -402,7 +404,7 @@ def __init__(

# Set metadata for logs and metrics.
# servable_object will be populated in `initialize_and_get_metadata`.
self._set_internal_replica_context(servable_object=None)
self._set_internal_replica_context(servable_object=None, rank=rank)

self._metrics_manager = create_replica_metrics_manager(
replica_id=replica_id,
Expand All @@ -422,19 +424,27 @@ def get_num_ongoing_requests(self) -> int:
return self._metrics_manager.get_num_ongoing_requests()

# Build the metadata tuple reported for this replica.
# The tuple is positional: (deployment_config, version,
# initialization latency, port, docs path, rank). Consumers unpack it by
# position (see `check_ready` in deployment_state.py and the controller
# recovery tests), so the element order must not change.
def get_metadata(self) -> ReplicaMetadata:
# Rank is not stored on `self`; it is read back from the process-wide
# internal replica context, which is (re)written by
# `_set_internal_replica_context`.
# NOTE(review): this assumes the internal replica context has already been
# set when `get_metadata` is called - confirm for all call sites.
current_rank = ray.serve.context._get_internal_replica_context().rank
return (
self._version.deployment_config,
self._version,
self._initialization_latency,
self._port,
self._docs_path,
current_rank,
)

def _set_internal_replica_context(self, *, servable_object: Callable = None):
# Publish this replica's identity and configuration to the module-level
# replica context in `ray.serve.context`.
# Both arguments are keyword-only and default to None, so the annotations are
# spelled `Optional[...]` explicitly.
# NOTE(review): `ReplicaContext.rank` is declared as `int` in context.py, but
# None can be forwarded from here when `rank` is omitted - confirm whether
# None is an intended value for the context field.
def _set_internal_replica_context(
self, *, servable_object: Optional[Callable] = None, rank: Optional[int] = None
):
# Calculate world_size from deployment config instead of storing it
# NOTE(review): presumably `num_replicas` is always a concrete int at this
# point - verify for deployments whose replica count is not fixed.
world_size = self._deployment_config.num_replicas
ray.serve.context._set_internal_replica_context(
replica_id=self._replica_id,
servable_object=servable_object,
_deployment_config=self._deployment_config,
rank=rank,
world_size=world_size,
)

def _configure_logger_and_profilers(
Expand Down Expand Up @@ -752,7 +762,10 @@ async def initialize(self, deployment_config: DeploymentConfig):
raise RuntimeError(traceback.format_exc()) from None

async def reconfigure(
self, deployment_config: DeploymentConfig, route_prefix: Optional[str] = None
self,
deployment_config: DeploymentConfig,
rank: int,
route_prefix: Optional[str] = None,
):
try:
user_config_changed = (
Expand Down Expand Up @@ -782,9 +795,10 @@ async def reconfigure(
)

# We need to update internal replica context to reflect the new
# deployment_config.
# deployment_config and rank.
self._set_internal_replica_context(
servable_object=self._user_callable_wrapper.user_callable
servable_object=self._user_callable_wrapper.user_callable,
rank=rank,
)

self._route_prefix = self._version.route_prefix
Expand Down Expand Up @@ -894,8 +908,11 @@ async def record_routing_stats(self) -> Dict[str, Any]:

class Replica(ReplicaBase):
async def _on_initialized(self):
# Get current rank from replica context during initialization
current_rank = ray.serve.context._get_internal_replica_context().rank
self._set_internal_replica_context(
servable_object=self._user_callable_wrapper.user_callable
servable_object=self._user_callable_wrapper.user_callable,
rank=current_rank,
)

# Save the initialization latency if the replica is initializing
Expand Down Expand Up @@ -969,6 +986,7 @@ async def __init__(
version: DeploymentVersion,
ingress: bool,
route_prefix: str,
rank: int,
):
deployment_config = DeploymentConfig.from_proto_bytes(
deployment_config_proto_bytes
Expand All @@ -985,6 +1003,7 @@ async def __init__(
version=version,
ingress=ingress,
route_prefix=route_prefix,
rank=rank,
)

def push_proxy_handle(self, handle: ActorHandle):
Expand Down Expand Up @@ -1047,9 +1066,9 @@ async def record_routing_stats(self) -> Dict[str, Any]:
return await self._replica_impl.record_routing_stats()

async def reconfigure(
self, deployment_config, route_prefix: Optional[str] = None
self, deployment_config, rank: int, route_prefix: Optional[str] = None
) -> ReplicaMetadata:
await self._replica_impl.reconfigure(deployment_config, route_prefix)
await self._replica_impl.reconfigure(deployment_config, rank, route_prefix)
return self._replica_impl.get_metadata()

def _preprocess_request_args(
Expand Down
8 changes: 8 additions & 0 deletions python/ray/serve/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,15 @@ class ReplicaContext:
- deployment: name of the deployment the replica is a part of.
- replica_tag: unique ID for the replica.
- servable_object: instance of the user class/function this replica is running.
- rank: the rank of the replica.
- world_size: the number of replicas in the deployment.
"""

replica_id: ReplicaID
servable_object: Callable
_deployment_config: DeploymentConfig
rank: int
world_size: int

@property
def app_name(self) -> str:
Expand Down Expand Up @@ -108,12 +112,16 @@ def _set_internal_replica_context(
replica_id: ReplicaID,
servable_object: Callable,
_deployment_config: DeploymentConfig,
rank: int,
world_size: int,
):
global _INTERNAL_REPLICA_CONTEXT
_INTERNAL_REPLICA_CONTEXT = ReplicaContext(
replica_id=replica_id,
servable_object=servable_object,
_deployment_config=_deployment_config,
rank=rank,
world_size=world_size,
)


Expand Down
4 changes: 2 additions & 2 deletions python/ray/serve/tests/test_controller_recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __call__(self, *args):
replica_version_hash = None
for replica in deployment_dict[id]:
ref = replica.actor_handle.initialize_and_get_metadata.remote()
_, version, _, _, _ = ray.get(ref)
_, version, _, _, _, _ = ray.get(ref)
if replica_version_hash is None:
replica_version_hash = hash(version)
assert replica_version_hash == hash(version), (
Expand Down Expand Up @@ -118,7 +118,7 @@ def __call__(self, *args):
for replica_name in recovered_replica_names:
actor_handle = ray.get_actor(replica_name, namespace=SERVE_NAMESPACE)
ref = actor_handle.initialize_and_get_metadata.remote()
_, version, _, _, _ = ray.get(ref)
_, version, _, _, _, _ = ray.get(ref)
assert replica_version_hash == hash(
version
), "Replica version hash should be the same after recover from actor names"
Expand Down
2 changes: 2 additions & 0 deletions python/ray/serve/tests/test_multiplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def start_serve_with_context():
),
servable_object=None,
_deployment_config=DeploymentConfig(),
rank=0,
world_size=1,
)
try:
yield
Expand Down
2 changes: 2 additions & 0 deletions python/ray/serve/tests/unit/test_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
replica_id=ReplicaID(unique_id="test", deployment_id=DeploymentID(name="test")),
servable_object=None,
_deployment_config=default_deployment_config,
rank=0,
world_size=1,
)


Expand Down