Skip to content

Commit e5c6e76

Browse files
committed
feat(RHOAIENG-26482): add gcs fault tolerance
1 parent 3228739 commit e5c6e76

File tree

4 files changed

+91
-2
lines changed

4 files changed

+91
-2
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .rayjob import RayJob, RayJobClusterConfig
22
from .status import RayJobDeploymentStatus, CodeflareRayJobStatus, RayJobInfo
3+
from .config import RayJobClusterConfig

src/codeflare_sdk/ray/rayjobs/config.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
"""
16-
The config sub-module contains the definition of the RayJobClusterConfigV2 dataclass,
16+
The config sub-module contains the definition of the RayJobClusterConfig dataclass,
1717
which is used to specify resource requirements and other details when creating a
1818
Cluster object.
1919
"""
@@ -139,6 +139,16 @@ class RayJobClusterConfig:
139139
A list of V1Volume objects to add to the Cluster
140140
volume_mounts:
141141
A list of V1VolumeMount objects to add to the Cluster
142+
enable_gcs_ft:
143+
A boolean indicating whether to enable GCS fault tolerance.
144+
enable_usage_stats:
145+
A boolean indicating whether to capture and send Ray usage stats externally.
146+
redis_address:
147+
The address of the Redis server to use for GCS fault tolerance. Required when enable_gcs_ft is True.
148+
redis_password_secret:
149+
A Kubernetes secret reference containing the Redis password, e.g. {"name": "secret-name", "key": "password-key"}.
150+
external_storage_namespace:
151+
The storage namespace to use for GCS fault tolerance. By default, KubeRay sets it to the UID of RayCluster.
142152
"""
143153

144154
head_cpu_requests: Union[int, str] = 2
@@ -165,8 +175,39 @@ class RayJobClusterConfig:
165175
annotations: Dict[str, str] = field(default_factory=dict)
166176
volumes: list[V1Volume] = field(default_factory=list)
167177
volume_mounts: list[V1VolumeMount] = field(default_factory=list)
178+
enable_gcs_ft: bool = False
179+
enable_usage_stats: bool = False
180+
redis_address: Optional[str] = None
181+
redis_password_secret: Optional[Dict[str, str]] = None
182+
external_storage_namespace: Optional[str] = None
168183

169184
def __post_init__(self):
185+
if self.enable_usage_stats:
186+
self.envs["RAY_USAGE_STATS_ENABLED"] = "1"
187+
else:
188+
self.envs["RAY_USAGE_STATS_ENABLED"] = "0"
189+
190+
if self.enable_gcs_ft:
191+
if not self.redis_address:
192+
raise ValueError(
193+
"redis_address must be provided when enable_gcs_ft is True"
194+
)
195+
196+
if self.redis_password_secret and not isinstance(
197+
self.redis_password_secret, dict
198+
):
199+
raise ValueError(
200+
"redis_password_secret must be a dictionary with 'name' and 'key' fields"
201+
)
202+
203+
if self.redis_password_secret and (
204+
"name" not in self.redis_password_secret
205+
or "key" not in self.redis_password_secret
206+
):
207+
raise ValueError(
208+
"redis_password_secret must contain both 'name' and 'key' fields"
209+
)
210+
170211
self._validate_types()
171212
self._memory_to_string()
172213
self._validate_gpu_config(self.head_accelerators)
@@ -251,6 +292,11 @@ def build_ray_cluster_spec(self, cluster_name: str) -> Dict[str, Any]:
251292
"workerGroupSpecs": [self._build_worker_group_spec(cluster_name)],
252293
}
253294

295+
# Add GCS fault tolerance if enabled
296+
if self.enable_gcs_ft:
297+
gcs_ft_options = self._build_gcs_ft_options()
298+
ray_cluster_spec["gcsFaultToleranceOptions"] = gcs_ft_options
299+
254300
return ray_cluster_spec
255301

256302
def _build_head_group_spec(self) -> Dict[str, Any]:
@@ -453,3 +499,25 @@ def _generate_volumes(self) -> list:
453499
def _build_env_vars(self) -> list:
454500
"""Build environment variables list."""
455501
return [V1EnvVar(name=key, value=value) for key, value in self.envs.items()]
502+
503+
def _build_gcs_ft_options(self) -> Dict[str, Any]:
504+
"""Build GCS fault tolerance options."""
505+
gcs_ft_options = {"redisAddress": self.redis_address}
506+
507+
if (
508+
hasattr(self, "external_storage_namespace")
509+
and self.external_storage_namespace
510+
):
511+
gcs_ft_options["externalStorageNamespace"] = self.external_storage_namespace
512+
513+
if hasattr(self, "redis_password_secret") and self.redis_password_secret:
514+
gcs_ft_options["redisPassword"] = {
515+
"valueFrom": {
516+
"secretKeyRef": {
517+
"name": self.redis_password_secret["name"],
518+
"key": self.redis_password_secret["key"],
519+
}
520+
}
521+
}
522+
523+
return gcs_ft_options

src/codeflare_sdk/ray/rayjobs/rayjob.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ def __init__(
140140
self.cluster_name = cluster_name
141141
logger.info(f"Using existing cluster: {self.cluster_name}")
142142

143-
# Initialize the KubeRay job API client
144143
self._api = RayjobApi()
145144

146145
logger.info(f"Initialized RayJob: {self.name} in namespace: {self.namespace}")

src/codeflare_sdk/ray/rayjobs/test_rayjob.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,3 +971,24 @@ def test_rayjob_user_override_shutdown_behavior(mocker):
971971
)
972972

973973
assert rayjob_override_priority.shutdown_after_job_finishes is True
974+
975+
976+
def test_build_ray_cluster_spec_with_gcs_ft(mocker):
977+
"""Test build_ray_cluster_spec with GCS fault tolerance enabled."""
978+
from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig
979+
980+
# Create a test cluster config with GCS FT enabled
981+
cluster_config = RayJobClusterConfig(
982+
enable_gcs_ft=True,
983+
redis_address="redis://redis-service:6379",
984+
external_storage_namespace="storage-ns",
985+
)
986+
987+
# Build the spec using the method on the cluster config
988+
spec = cluster_config.build_ray_cluster_spec("test-cluster")
989+
990+
# Verify GCS fault tolerance options
991+
assert "gcsFaultToleranceOptions" in spec
992+
gcs_ft = spec["gcsFaultToleranceOptions"]
993+
assert gcs_ft["redisAddress"] == "redis://redis-service:6379"
994+
assert gcs_ft["externalStorageNamespace"] == "storage-ns"

0 commit comments

Comments
 (0)