From 70980432bb45cedf4c40c66c6b6408d45483173e Mon Sep 17 00:00:00 2001 From: Danny LI Date: Tue, 4 Nov 2025 03:20:44 +0000 Subject: [PATCH] Add Diagon installation during cluster creation and modify the workload.py Add wait_for_deployment_ready() Added unit test update goldens.yaml update goldens.yaml update goldens.yaml Fixed parser/cluster.py update goldens.yaml fixed linter fixed linter pyink Test unit test --- src/xpk/commands/cluster.py | 8 + src/xpk/commands/cluster_test.py | 150 ++++++++++++ src/xpk/commands/managed_ml_diagnostics.py | 254 +++++++++++++++++++++ src/xpk/parser/cluster.py | 13 ++ src/xpk/parser/cluster_test.py | 15 ++ 5 files changed, 440 insertions(+) create mode 100644 src/xpk/commands/managed_ml_diagnostics.py diff --git a/src/xpk/commands/cluster.py b/src/xpk/commands/cluster.py index e1895f8b4..635a7e04e 100644 --- a/src/xpk/commands/cluster.py +++ b/src/xpk/commands/cluster.py @@ -81,6 +81,7 @@ from ..utils.templates import get_templates_absolute_path import shutil import os +from . import managed_ml_diagnostics CLUSTER_PREHEAT_JINJA_FILE = 'cluster_preheat.yaml.j2' @@ -407,6 +408,13 @@ def cluster_create(args) -> None: # pylint: disable=line-too-long f' https://console.cloud.google.com/kubernetes/clusters/details/{get_cluster_location(args.project, args.cluster, args.zone)}/{args.cluster}/details?project={args.project}' ) + + if args.managed_ml_diagnostics: + return_code = managed_ml_diagnostics.install_mldiagnostics_prerequisites() + if return_code != 0: + xpk_print('Installation of MLDiagnostics failed.') + xpk_exit(return_code) + xpk_exit(0) diff --git a/src/xpk/commands/cluster_test.py b/src/xpk/commands/cluster_test.py index 153ace154..3a3b454f3 100644 --- a/src/xpk/commands/cluster_test.py +++ b/src/xpk/commands/cluster_test.py @@ -21,6 +21,7 @@ import pytest from xpk.commands.cluster import _install_kueue, _validate_cluster_create_args, run_gke_cluster_create_command +from xpk.commands.managed_ml_diagnostics import install_mldiagnostics_prerequisites from xpk.core.system_characteristics import SystemCharacteristics, UserFacingNameToSystemCharacteristics from xpk.core.testing.commands_tester import CommandsTester from xpk.utils.feature_flags import FeatureFlags @@ -56,6 +57,9 @@ def mocks(mocker) -> _Mocks: run_command_with_updates_path=( 'xpk.commands.cluster.run_command_with_updates' ), + run_command_for_value_path=( + 'xpk.commands.cluster.run_command_for_value' + ), ), ) @@ -87,6 +91,7 @@ def construct_args(**kwargs: Any) -> Namespace: memory_limit='100Gi', cpu_limit=100, cluster_cpu_machine_type='', + managed_mldiagnostics=False, ) args_dict.update(kwargs) return Namespace(**args_dict) @@ -247,3 +252,148 @@ def test_run_gke_cluster_create_command_with_gke_version_has_no_autoupgrade_flag mocks.commands_tester.assert_command_run( 'clusters create', ' --no-enable-autoupgrade' ) + + +def test_install_mldiagnostics_prerequisites_commands_executed( + mocks: _Mocks, + mocker, +): + + mocks.commands_tester.set_result_for_command( + (0, ''), + 'kubectl', + 'rollout', + 'status', + 'deployment/kueue-controller-manager', + ) + + mocks.commands_tester.set_result_for_command( + (0, ''), + 'kubectl', + 'rollout', + 'status', + 'deployment/cert-manager-webhook', + ) + + mocks.commands_tester.set_result_for_command( + (0, ''), + 'kubectl', + 'apply', + '-f', + 'https://github.com/cert-manager/cert-manager/releases/', + ) + + mocks.commands_tester.set_result_for_command( + (0, ''), + 'gcloud', + 'artifacts', + 'generic', + 'download', + ) + + mocks.commands_tester.set_result_for_command( + (0, ''), + 'kubectl', + 'create', + 'namespace', + 'gke-mldiagnostics', + ) + + mocks.commands_tester.set_result_for_command( + (0, ''), + 'kubectl', + 'apply', + '-f', + '-n', + 'gke-mldiagnostics', + ) + + mocks.commands_tester.set_result_for_command( + (0, ''), + 'kubectl', + 'label', + 'namespace', + 'default', + 'managed-mldiagnostics-gke=true', + ) + + mocks.commands_tester.assert_command_run( + 'kubectl', + 'rollout', + 'status', + 'deployment/kueue-controller-manager', + times=1, + ) + + mocks.commands_tester.assert_command_run( + 'kubectl', + 'apply', + '-f', + 'https://github.com/cert-manager/cert-manager/', + times=1, + ) + + mocks.commands_tester.assert_command_run( + 'kubectl', 'rollout', 'status', 'deployment/cert-manager-webhook', times=1 + ) + + mocks.commands_tester.assert_command_run( + 'gcloud', + 'artifacts', + 'generic', + 'download', + '--package=mldiagnostics-injection-webhook', + '--version=v0.5.0', + times=1, + ) + + mocks.commands_tester.assert_command_run( + 'kubectl', 'create', 'namespace', 'gke-mldiagnostics', times=1 + ) + + mocks.commands_tester.assert_command_run( + 'kubectl', + 'apply', + '-f', + 'mldiagnostics-injection-webhook-v0.5.0.yaml', + '-n', + 'gke-mldiagnostics', + times=1, + ) + + mocks.commands_tester.assert_command_run( + 'kubectl', + 'label', + 'namespace', + 'default', + 'managed-mldiagnostics-gke=true', + times=1, + ) + + mocks.commands_tester.assert_command_run( + 'gcloud', + 'artifacts', + 'generic', + 'download', + '--package=mldiagnostics-connection-operator', + '--version=v0.5.0', + times=1, + ) + + mocks.commands_tester.assert_command_run( + 'kubectl', + 'apply', + '-f', + 'mldiagnostics-connection-operator-v0.5.0.yaml', + '-n', + 'gke-mldiagnostics', + times=1, + ) + + mocks.commands_tester.assert_command_run( + 'gcloud', 'artifacts', 'generic', 'download', times=2 + ) + + mocks.commands_tester.assert_command_run( + 'kubectl', 'apply', '-f', '-n', 'gke-mldiagnostics', times=2 + ) diff --git a/src/xpk/commands/managed_ml_diagnostics.py b/src/xpk/commands/managed_ml_diagnostics.py new file mode 100644 index 000000000..cdc556206 --- /dev/null +++ b/src/xpk/commands/managed_ml_diagnostics.py @@ -0,0 +1,254 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import time +from packaging.version import Version +from ..core.commands import run_command_for_value, run_command_with_updates +from ..utils.console import xpk_exit, xpk_print +import os + + +def _install_cert_manager(version: Version = Version('v1.13.0')) -> int: + """ + Apply the cert-manager manifest. + + Returns: + 0 if successful and 1 otherwise. + """ + + command = ( + 'kubectl apply -f' + ' https://github.com/cert-manager/cert-manager/releases/download/' + f'{version}/cert-manager.yaml' + ) + + return_code = run_command_with_updates( + command, f'Applying cert-manager {version} manifest...' + ) + + if return_code != 0: + xpk_exit(return_code) + + return return_code + + +def _download_mldiagnostics_yaml(package_name: str, version: Version) -> int: + """ + Downloads the mldiagnostics injection webhook YAML from Artifact Registry. + + Returns: + 0 if successful and 1 otherwise. + """ + + command = ( + 'gcloud artifacts generic download' + ' --repository=mldiagnostics-webhook-and-operator-yaml --location=us' + f' --package={package_name} --version={version} --destination=/tmp/' + ' --project=ai-on-gke' + ) + + return_code, return_output = run_command_for_value( + command, + f'Download {package_name} {version}...', + ) + + if return_code != 0: + if 'already exists' in return_output: + xpk_print( + f'Artifact file for {package_name} {version} already exists locally.' + ' Skipping download.' + ) + return 0 + + return return_code + + +def _create_mldiagnostics_namespace() -> int: + """ + Creates the 'gke-mldiagnostics' namespace. + + Returns: + 0 if successful and 1 otherwise. + """ + + command = 'kubectl create namespace gke-mldiagnostics' + + return_code, return_output = run_command_for_value( + command, 'Create gke-mldiagnostics namespace...' + ) + + if return_code != 0: + if 'already exists' in return_output: + xpk_print('Namespace already exists. Skipping creation.') + return 0 + + return return_code + + +def _install_mldiagnostics_yaml(artifact_filename: str) -> int: + """ + Applies the mldiagnostics injection webhook YAML manifest. + + Returns: + 0 if successful and 1 otherwise. + """ + full_artifact_path = os.path.join('/tmp', artifact_filename) + + command = f'kubectl apply -f {full_artifact_path} -n gke-mldiagnostics' + + return_code = run_command_with_updates( + command, + f'Install {full_artifact_path}...', + ) + + if return_code != 0: + xpk_print(f'kubectl apply returned with ERROR {return_code}.\n') + xpk_exit(return_code) + + xpk_print(f'{artifact_filename} applied successfully.') + + return 0 + + +def _label_default_namespace_mldiagnostics() -> int: + """ + Labels the 'default' namespace with 'managed-mldiagnostics-gke=true'. + + Returns: + 0 if successful and 1 otherwise. + """ + + command = 'kubectl label namespace default managed-mldiagnostics-gke=true' + + return_code = run_command_with_updates( + command, + 'Label default namespace with managed-mldiagnostics-gke=true', + ) + + if return_code != 0: + xpk_exit(return_code) + + return return_code + + +def install_mldiagnostics_prerequisites() -> int: + """ + Mldiagnostics installation requirements. + + Returns: + 0 if successful and 1 otherwise. + """ + + kueue_deployment_name = 'kueue-controller-manager' + kueue_namespace_name = 'kueue-system' + cert_webhook_deployment_name = 'cert-manager-webhook' + cert_webhook_namespace_name = 'cert-manager' + + if not _wait_for_deployment_ready( + deployment_name=kueue_deployment_name, namespace=kueue_namespace_name + ): + xpk_print( + f'Application {kueue_deployment_name} failed to become ready within the' + ' timeout.' + ) + return 1 + + return_code = _install_cert_manager() + if return_code != 0: + return return_code + + cert_webhook_ready = _wait_for_deployment_ready( + deployment_name=cert_webhook_deployment_name, + namespace=cert_webhook_namespace_name, + ) + if not cert_webhook_ready: + xpk_print('The cert-manager-webhook installation failed.') + return 1 + + webhook_package = 'mldiagnostics-injection-webhook' + webhook_version = 'v0.5.0' + webhook_filename = f'{webhook_package}-{webhook_version}.yaml' + + return_code = _download_mldiagnostics_yaml( + package_name=webhook_package, version=Version(webhook_version) + ) + if return_code != 0: + return return_code + + return_code = _create_mldiagnostics_namespace() + if return_code != 0: + return return_code + + return_code = _install_mldiagnostics_yaml(artifact_filename=webhook_filename) + if return_code != 0: + return return_code + + return_code = _label_default_namespace_mldiagnostics() + if return_code != 0: + return return_code + + operator_package = 'mldiagnostics-connection-operator' + operator_version = 'v0.5.0' + operator_filename = f'{operator_package}-{operator_version}.yaml' + + return_code = _download_mldiagnostics_yaml( + package_name=operator_package, version=Version(operator_version) + ) + if return_code != 0: + return return_code + + return_code = _install_mldiagnostics_yaml(artifact_filename=operator_filename) + if return_code != 0: + return return_code + + xpk_print( + 'All mldiagnostics installation and setup steps have been' + ' successfully completed!' + ) + return 0 + + +def _wait_for_deployment_ready( + deployment_name: str, namespace: str, timeout_seconds: int = 300 +) -> bool: + """ + Polls the Kubernetes Deployment status using kubectl rollout status + until it successfully rolls out (all replicas are ready) or times out. + + Args: + deployment_name: The name of the Kubernetes Deployment (e.g., 'kueue-controller-manager'). + namespace: The namespace where the Deployment is located (e.g., 'kueue-system'). + timeout_seconds: Timeout duration in seconds (default is 300s / 5 minutes). + + Returns: + bool: True if the Deployment successfully rolled out, False otherwise (timeout or error). + """ + + command = ( + f'kubectl rollout status deployment/{deployment_name} -n {namespace}' + f' --timeout={timeout_seconds}s' + ) + + return_code = run_command_with_updates( + command, f'Checking status of deployment {deployment_name}...' + ) + + if return_code != 0: + return False + + # When the status changes to 'running,' it might need about 10 seconds to fully stabilize. + time.sleep(30) + return True diff --git a/src/xpk/parser/cluster.py b/src/xpk/parser/cluster.py index e0d5af7bd..57ba8dca0 100644 --- a/src/xpk/parser/cluster.py +++ b/src/xpk/parser/cluster.py @@ -150,6 +150,13 @@ def set_cluster_create_parser(cluster_create_parser: ArgumentParser): ' enable cluster to accept Pathways workloads.' ), ) + + cluster_create_optional_arguments.add_argument( + '--managed-ml-diagnostics', + action='store_true', + help='Enables the installation of required ML Diagnostics components.', + ) + if FeatureFlags.SUB_SLICING_ENABLED: cluster_create_optional_arguments.add_argument( '--sub-slicing', @@ -222,6 +229,12 @@ def set_cluster_create_pathways_parser( ), ) + cluster_create_pathways_required_arguments.add_argument( + '--managed-ml-diagnostics', + action='store_true', + help='Enables the installation of required ML Diagnostics components.', + ) + ### Optional arguments specific to "cluster create-pathways" cluster_create_pathways_optional_arguments = ( cluster_create_pathways_parser.add_argument_group( diff --git a/src/xpk/parser/cluster_test.py b/src/xpk/parser/cluster_test.py index 2b2706b4f..9ce122c40 100644 --- a/src/xpk/parser/cluster_test.py +++ b/src/xpk/parser/cluster_test.py @@ -64,3 +64,18 @@ def test_cluster_create_sub_slicing_can_be_set(): ) assert args.sub_slicing is True + + +def test_cluster_create_managed_mldiagnostics(): + parser = argparse.ArgumentParser() + + set_cluster_create_parser(parser) + args = parser.parse_args([ + "--cluster", + "test-cluster", + "--tpu-type", + "v5p-8", + "--managed-ml-diagnostics", + ]) + + assert args.managed_ml_diagnostics is True