Skip to content

Commit d472541

Browse files
committed
Added to dict function to account for get cluster fails
Signed-off-by: Pat O'Connor <[email protected]>
1 parent 5b4999d commit d472541

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

src/codeflare_sdk/ray/cluster/cluster.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from time import sleep
2222
from typing import List, Optional, Tuple, Dict
2323
import copy
24-
import dataclasses
2524

2625
from ray.job_submission import JobSubmissionClient, JobStatus
2726
import time
@@ -1060,9 +1059,7 @@ def get_cluster(
10601059
)
10611060
# 1. Prepare RayClusterSpec from ClusterConfiguration
10621061
# Create a temporary config with appwrapper=False to ensure build_ray_cluster returns RayCluster YAML
1063-
temp_cluster_config_dict = dataclasses.asdict(cluster_config)
1064-
# Filter out None values similar to exclude_none=True in Pydantic
1065-
temp_cluster_config_dict = {k: v for k, v in temp_cluster_config_dict.items() if v is not None}
1062+
temp_cluster_config_dict = cluster_config.dict(exclude_none=True)
10661063
temp_cluster_config_dict["appwrapper"] = False
10671064
# Set overwrite_default_resource_mapping=True to avoid conflicts when extended_resource_mapping
10681065
# already contains the combined default and custom mappings from the original object

src/codeflare_sdk/ray/cluster/config.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,3 +289,19 @@ def check_type(value, expected_type):
289289
return isinstance(value, expected_type)
290290

291291
return check_type(value, expected_type)
292+
293+
def dict(self, exclude_none: bool = False):
294+
"""
295+
Convert the ClusterConfiguration to a dictionary.
296+
297+
Args:
298+
exclude_none (bool): If True, exclude fields with None values.
299+
300+
Returns:
301+
dict: Dictionary representation of the ClusterConfiguration.
302+
"""
303+
import dataclasses
304+
result = dataclasses.asdict(self)
305+
if exclude_none:
306+
result = {k: v for k, v in result.items() if v is not None}
307+
return result

0 commit comments

Comments
 (0)