diff --git a/mlflow_export_import/common/mlflow_utils.py b/mlflow_export_import/common/mlflow_utils.py
index 44f82e90..95485ae2 100644
--- a/mlflow_export_import/common/mlflow_utils.py
+++ b/mlflow_export_import/common/mlflow_utils.py
@@ -43,7 +43,7 @@ def set_experiment(mlflow_client, dbx_client, exp_name, tags=None):
     :return: Experiment ID
     """
     from mlflow_export_import.common import utils
-    if utils.importing_into_databricks():
+    if utils.get_import_target_implementation() == utils.MLFlowImplementation.DATABRICKS:
         create_workspace_dir(dbx_client, os.path.dirname(exp_name))
     try:
         if not tags: tags = {}
diff --git a/mlflow_export_import/common/utils.py b/mlflow_export_import/common/utils.py
index aa40395c..5c525120 100644
--- a/mlflow_export_import/common/utils.py
+++ b/mlflow_export_import/common/utils.py
@@ -1,7 +1,12 @@
 import pandas as pd
 from tabulate import tabulate
 import mlflow
+from enum import Enum, auto
 
+class MLFlowImplementation(Enum):
+    DATABRICKS = auto()
+    AZURE_ML = auto()
+    OSS = auto()
 
 # Databricks tags that cannot or should not be set
 _DATABRICKS_SKIP_TAGS = set([
@@ -11,15 +16,26 @@
     "mlflow.experiment.sourceType",
     "mlflow.experiment.sourceId"
     ])
 
+_AZURE_ML_SKIP_TAGS = set([
+    "mlflow.user",
+    "mlflow.source.git.commit"
+    ])
+
 
 def create_mlflow_tags_for_databricks_import(tags):
-    if importing_into_databricks():
-        tags = { k:v for k,v in tags.items() if not k in _DATABRICKS_SKIP_TAGS }
-    return tags
+    environment = get_import_target_implementation()
+    if environment == MLFlowImplementation.DATABRICKS:
+        return { k:v for k,v in tags.items() if k not in _DATABRICKS_SKIP_TAGS }
+    if environment == MLFlowImplementation.AZURE_ML:
+        return { k:v for k,v in tags.items() if k not in _AZURE_ML_SKIP_TAGS }
+    if environment == MLFlowImplementation.OSS:
+        return tags
+    raise Exception(f"Unsupported MLflow implementation: {environment}")
 
 
 def set_dst_user_id(tags, user_id, use_src_user_id):
-    if importing_into_databricks():
+    if get_import_target_implementation() in (MLFlowImplementation.DATABRICKS,
+                                              MLFlowImplementation.AZURE_ML):
         return
     from mlflow.entities import RunTag
     from mlflow.utils.mlflow_tags import MLFLOW_USER
@@ -59,8 +75,13 @@ def nested_tags(dst_client, run_ids_mapping):
             dst_client.set_tag(dst_run_id, "mlflow.parentRunId", dst_parent_run_id)
 
 
-def importing_into_databricks():
-    return mlflow.tracking.get_tracking_uri().startswith("databricks")
+def get_import_target_implementation() -> MLFlowImplementation:
+    tracking_uri = mlflow.tracking.get_tracking_uri()
+    if tracking_uri.startswith("databricks"):
+        return MLFlowImplementation.DATABRICKS
+    if tracking_uri.startswith("azureml"):
+        return MLFlowImplementation.AZURE_ML
+    return MLFlowImplementation.OSS
 
 
 def show_table(title, lst, columns):
diff --git a/mlflow_export_import/run/export_run.py b/mlflow_export_import/run/export_run.py
index 5a8acf3f..537a21c6 100644
--- a/mlflow_export_import/run/export_run.py
+++ b/mlflow_export_import/run/export_run.py
@@ -15,6 +15,8 @@
 from mlflow_export_import.common.timestamp_utils import fmt_ts_millis
 from mlflow_export_import.common.http_client import DatabricksHttpClient
 from mlflow_export_import.notebook.download_notebook import download_notebook
+from mlflow_export_import.common import MlflowExportImportException
+
 from mlflow.utils.mlflow_tags import MLFLOW_DATABRICKS_NOTEBOOK_PATH
 
 MLFLOW_DATABRICKS_NOTEBOOK_REVISION_ID = "mlflow.databricks.notebookRevisionID" # NOTE: not in mlflow/utils/mlflow_tags.py
@@ -32,8 +34,14 @@ def __init__(self, mlflow_client, notebook_formats=None):
         if notebook_formats is None:
             notebook_formats = []
         self.mlflow_client = mlflow_client
-        self.dbx_client = DatabricksHttpClient()
-        print("Databricks REST client:", self.dbx_client)
+        try:
+            self.dbx_client = DatabricksHttpClient()
+        except MlflowExportImportException as e:
+            self.dbx_client = None
+            print("WARNING: Databricks REST client could not be initialized. "
+                 f"Notebook export functionality will be unavailable: {e}")
+        else:
+            print("Databricks REST client:", self.dbx_client)
         self.notebook_formats = notebook_formats
 
 
@@ -92,6 +100,8 @@ def export_run(self, run_id, output_dir):
 
 
     def _export_notebook(self, output_dir, notebook, run, fs):
+        if self.dbx_client is None:
+            raise MlflowExportImportException("Cannot export notebook: the Databricks REST client was not initialized")
         notebook_dir = os.path.join(output_dir, "artifacts", "notebooks")
         fs.mkdirs(notebook_dir)
         revision_id = run.data.tags.get(MLFLOW_DATABRICKS_NOTEBOOK_REVISION_ID, None)
diff --git a/mlflow_export_import/run/import_run.py b/mlflow_export_import/run/import_run.py
index bb9f12ab..8eaf0f59 100644
--- a/mlflow_export_import/run/import_run.py
+++ b/mlflow_export_import/run/import_run.py
@@ -31,8 +31,9 @@ class RunImporter():
     def __init__(self,
                  mlflow_client,
                  import_source_tags=False,
                  mlmodel_fix=True,
                  use_src_user_id=False,
+                 in_azure_ml=False,
                  dst_notebook_dir_add_run_id=False):
         """
         :param mlflow_client: MLflow client.
@@ -50,11 +51,13 @@
         self.mlmodel_fix = mlmodel_fix
         self.use_src_user_id = use_src_user_id
         self.in_databricks = "DATABRICKS_RUNTIME_VERSION" in os.environ
+        self.in_azure_ml = in_azure_ml
+
         self.dst_notebook_dir_add_run_id = dst_notebook_dir_add_run_id
         self.dbx_client = DatabricksHttpClient()
         self.import_source_tags = import_source_tags
         print(f"in_databricks: {self.in_databricks}")
-        print(f"importing_into_databricks: {utils.importing_into_databricks()}")
+        print(f"importing_into_environment: {utils.get_import_target_implementation().name}")
 
 
     def import_run(self, exp_name, input_dir, dst_notebook_dir=None):
@@ -93,7 +96,7 @@ def _import_run(self, dst_exp_name, input_dir, dst_notebook_dir):
             import traceback
             traceback.print_exc()
             raise MlflowExportImportException(e, f"Importing run {run_id} of experiment '{exp.name}' failed")
-        if utils.importing_into_databricks() and dst_notebook_dir:
+        if utils.get_import_target_implementation() == utils.MLFlowImplementation.DATABRICKS and dst_notebook_dir:
             ndir = os.path.join(dst_notebook_dir, run_id) if self.dst_notebook_dir_add_run_id else dst_notebook_dir
             self._upload_databricks_notebook(input_dir, src_run_dct, ndir)
 
@@ -132,7 +135,7 @@ def _import_run_data(self, run_dct, run_id, src_user_id):
             run_id,
             MAX_PARAMS_TAGS_PER_BATCH,
             self.import_source_tags,
-            self.in_databricks,
+            self.in_databricks or self.in_azure_ml,
             src_user_id,
             self.use_src_user_id
         )
diff --git a/tests/compare_utils.py b/tests/compare_utils.py
index 720a31c4..b2a6d2c9 100644
--- a/tests/compare_utils.py
+++ b/tests/compare_utils.py
@@ -93,7 +93,7 @@ def compare_versions(mlflow_client_src, mlflow_client_dst, vr_src, vr_dst, output_dir):
     assert vr_src.status_message == vr_dst.status_message
     if mlflow_client_src != mlflow_client_dst:
         assert vr_src.name == vr_dst.name
-    if not utils.importing_into_databricks():
+    if utils.get_import_target_implementation() != utils.MLFlowImplementation.DATABRICKS:
         assert vr_src.user_id == vr_dst.user_id
     tags_dst = { k:v for k,v in vr_dst.tags.items() if not k.startswith(ExportTags.PREFIX_ROOT) }
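
Reviewer note: a minimal sketch of how the new enum-based environment detection is meant to behave. The tracking URIs below are illustrative placeholders, not values taken from this PR; `mlflow.set_tracking_uri` is the standard MLflow call.

```python
import mlflow
from mlflow_export_import.common import utils

# "databricks" (and "databricks://<profile>") URIs map to DATABRICKS.
mlflow.set_tracking_uri("databricks")
assert utils.get_import_target_implementation() == utils.MLFlowImplementation.DATABRICKS

# Any URI starting with "azureml" maps to AZURE_ML (placeholder URI).
mlflow.set_tracking_uri("azureml://example-region.api.azureml.ms/mlflow/v1.0/ws")
assert utils.get_import_target_implementation() == utils.MLFlowImplementation.AZURE_ML

# Everything else (local files, plain HTTP servers) falls through to OSS.
mlflow.set_tracking_uri("http://localhost:5000")
assert utils.get_import_target_implementation() == utils.MLFlowImplementation.OSS

# Tag filtering then follows the detected environment: Databricks strips
# _DATABRICKS_SKIP_TAGS, Azure ML strips _AZURE_ML_SKIP_TAGS, OSS is a no-op.
tags = {"mlflow.user": "someone", "custom.tag": "kept"}
filtered = utils.create_mlflow_tags_for_databricks_import(tags)
print(filtered)  # under OSS: unchanged; under AZURE_ML: "mlflow.user" dropped
```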
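Similarly, a hypothetical call site for the new `in_azure_ml` flag on `RunImporter` (the experiment name and input directory are placeholders). Note that, unlike `RunExporter`, this constructor still builds `DatabricksHttpClient()` unguarded, so the sketch assumes an environment where that construction succeeds.

```python
import mlflow
from mlflow_export_import.run.import_run import RunImporter

# Hypothetical import into an Azure ML workspace: in_azure_ml=True is OR-ed
# with in_databricks when importing run data (see _import_run_data above),
# so Azure ML destinations get the same tag handling as Databricks.
client = mlflow.tracking.MlflowClient()
importer = RunImporter(client, mlmodel_fix=True, in_azure_ml=True)
importer.import_run("placeholder-experiment", "out/run1")
```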