2 changes: 1 addition & 1 deletion mlflow_export_import/common/mlflow_utils.py
@@ -43,7 +43,7 @@ def set_experiment(mlflow_client, dbx_client, exp_name, tags=None):
:return: Experiment ID
"""
from mlflow_export_import.common import utils
if utils.importing_into_databricks():
if utils.get_import_target_implementation() == utils.MLFlowImplementation.DATABRICKS:
create_workspace_dir(dbx_client, os.path.dirname(exp_name))
try:
if not tags: tags = {}
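For context, a minimal equivalence sketch of the new enum-based check that replaces importing_into_databricks(); the tracking URI is a stand-in for this sketch:

import mlflow
from mlflow_export_import.common import utils

mlflow.set_tracking_uri("databricks")  # stand-in URI for this sketch
assert utils.get_import_target_implementation() == utils.MLFlowImplementation.DATABRICKS
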
32 changes: 26 additions & 6 deletions mlflow_export_import/common/utils.py
@@ -1,7 +1,12 @@
import pandas as pd
from tabulate import tabulate
import mlflow
from enum import Enum, auto

class MLFlowImplementation(Enum):
DATABRICKS = auto()
AZURE_ML = auto()
OSS = auto()

# Databricks tags that cannot or should not be set
_DATABRICKS_SKIP_TAGS = set([
@@ -11,15 +16,26 @@
"mlflow.experiment.sourceType", "mlflow.experiment.sourceId"
])

_AZURE_ML_SKIP_TAGS = set([
"mlflow.user",
"mlflow.source.git.commit"
])


def create_mlflow_tags_for_databricks_import(tags):
if importing_into_databricks():
tags = { k:v for k,v in tags.items() if not k in _DATABRICKS_SKIP_TAGS }
return tags
environment = get_import_target_implementation()
if environment == MLFlowImplementation.DATABRICKS:
    return { k:v for k,v in tags.items() if k not in _DATABRICKS_SKIP_TAGS }
if environment == MLFlowImplementation.AZURE_ML:
    return { k:v for k,v in tags.items() if k not in _AZURE_ML_SKIP_TAGS }
if environment == MLFlowImplementation.OSS:
    return tags
raise Exception(f"Unsupported MLflow implementation: {environment}")


def set_dst_user_id(tags, user_id, use_src_user_id):
if importing_into_databricks():
if get_import_target_implementation() in (MLFlowImplementation.DATABRICKS,
MLFlowImplementation.AZURE_ML):
return
from mlflow.entities import RunTag
from mlflow.utils.mlflow_tags import MLFLOW_USER
@@ -59,8 +75,12 @@ def nested_tags(dst_client, run_ids_mapping):
dst_client.set_tag(dst_run_id, "mlflow.parentRunId", dst_parent_run_id)


def importing_into_databricks():
return mlflow.tracking.get_tracking_uri().startswith("databricks")
def get_import_target_implementation() -> MLFlowImplementation:
    tracking_uri = mlflow.tracking.get_tracking_uri()
    if tracking_uri.startswith("databricks"):
        return MLFlowImplementation.DATABRICKS
    if tracking_uri.startswith("azureml"):
        return MLFlowImplementation.AZURE_ML
    return MLFlowImplementation.OSS


def show_table(title, lst, columns):
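A short usage sketch of the per-environment tag filtering; the tag keys and values below are illustrative only:

import mlflow
from mlflow_export_import.common import utils

mlflow.set_tracking_uri("azureml://example")  # illustrative URI
tags = {
    "mlflow.user": "someone@example.com",
    "mlflow.source.git.commit": "abc123",
    "custom.tag": "kept",
}
filtered = utils.create_mlflow_tags_for_databricks_import(tags)
assert filtered == {"custom.tag": "kept"}  # AZURE_ML skip tags are dropped
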
13 changes: 11 additions & 2 deletions mlflow_export_import/run/export_run.py
@@ -15,6 +15,8 @@
from mlflow_export_import.common.timestamp_utils import fmt_ts_millis
from mlflow_export_import.common.http_client import DatabricksHttpClient
from mlflow_export_import.notebook.download_notebook import download_notebook
from mlflow_export_import.common import MlflowExportImportException


from mlflow.utils.mlflow_tags import MLFLOW_DATABRICKS_NOTEBOOK_PATH
MLFLOW_DATABRICKS_NOTEBOOK_REVISION_ID = "mlflow.databricks.notebookRevisionID" # NOTE: not in mlflow/utils/mlflow_tags.py
@@ -32,8 +34,13 @@ def __init__(self, mlflow_client, notebook_formats=None):
if notebook_formats is None:
notebook_formats = []
self.mlflow_client = mlflow_client
self.dbx_client = DatabricksHttpClient()
print("Databricks REST client:", self.dbx_client)
try:
    self.dbx_client = DatabricksHttpClient()
except MlflowExportImportException as e:
    self.dbx_client = None
    print(f"WARNING: Databricks REST client could not be initialized ({e}). "
          "Notebook export functionality will be unavailable.")
else:
    print("Databricks REST client:", self.dbx_client)
self.notebook_formats = notebook_formats


@@ -92,6 +99,8 @@ def export_run(self, run_id, output_dir):


def _export_notebook(self, output_dir, notebook, run, fs):
if self.dbx_client is None:
raise MlflowExportImportException("Cannot export notebooks without an HTTP connection")
notebook_dir = os.path.join(output_dir, "artifacts", "notebooks")
fs.mkdirs(notebook_dir)
revision_id = run.data.tags.get(MLFLOW_DATABRICKS_NOTEBOOK_REVISION_ID, None)
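A sketch of the intended graceful degradation, assuming DatabricksHttpClient raises MlflowExportImportException when no Databricks host or token is configured:

import mlflow
from mlflow_export_import.run.export_run import RunExporter

# Construction no longer fails outside Databricks: on error dbx_client is set
# to None, plain run export still works, and only notebook export raises
# MlflowExportImportException.
exporter = RunExporter(mlflow.tracking.MlflowClient())
print(exporter.dbx_client)  # None when the REST client could not be built
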
11 changes: 7 additions & 4 deletions mlflow_export_import/run/import_run.py
@@ -31,8 +31,9 @@ class RunImporter():
def __init__(self,
mlflow_client,
import_source_tags=False,
mlmodel_fix=True,
mlmodel_fix=True,
use_src_user_id=False,
in_azure_ml=False,
dst_notebook_dir_add_run_id=False):
"""
:param mlflow_client: MLflow client.
@@ -50,11 +51,13 @@ def __init__(self,
self.mlmodel_fix = mlmodel_fix
self.use_src_user_id = use_src_user_id
self.in_databricks = "DATABRICKS_RUNTIME_VERSION" in os.environ
self.in_azure_ml = in_azure_ml

self.dst_notebook_dir_add_run_id = dst_notebook_dir_add_run_id
self.dbx_client = DatabricksHttpClient()
self.import_source_tags = import_source_tags
print(f"in_databricks: {self.in_databricks}")
print(f"importing_into_databricks: {utils.importing_into_databricks()}")
print(f"importing_into_environment: {utils.get_import_target_implementation().name}")


def import_run(self, exp_name, input_dir, dst_notebook_dir=None):
@@ -93,7 +96,7 @@ def _import_run(self, dst_exp_name, input_dir, dst_notebook_dir):
import traceback
traceback.print_exc()
raise MlflowExportImportException(e, f"Importing run {run_id} of experiment '{exp.name}' failed")
if utils.importing_into_databricks() and dst_notebook_dir:
if utils.get_import_target_implementation() == utils.MLFlowImplementation.DATABRICKS and dst_notebook_dir:
ndir = os.path.join(dst_notebook_dir, run_id) if self.dst_notebook_dir_add_run_id else dst_notebook_dir
self._upload_databricks_notebook(input_dir, src_run_dct, ndir)

@@ -132,7 +135,7 @@ def _import_run_data(self, run_dct, run_id, src_user_id):
run_id,
MAX_PARAMS_TAGS_PER_BATCH,
self.import_source_tags,
self.in_databricks,
self.in_databricks or self.in_azure_ml,
src_user_id,
self.use_src_user_id
)
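A hypothetical call path for the new in_azure_ml flag; note that RunImporter still constructs a DatabricksHttpClient unconditionally, so this sketch assumes Databricks credentials resolve:

import mlflow
from mlflow_export_import.run.import_run import RunImporter

# With in_azure_ml=True, _import_run_data suppresses source system tags the
# same way it already does when running inside Databricks.
importer = RunImporter(mlflow.tracking.MlflowClient(), in_azure_ml=True)
importer.import_run("my-experiment", "out/run1")  # illustrative arguments
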
2 changes: 1 addition & 1 deletion tests/compare_utils.py
@@ -93,7 +93,7 @@ def compare_versions(mlflow_client_src, mlflow_client_dst, vr_src, vr_dst, outpu
assert vr_src.status_message == vr_dst.status_message
if mlflow_client_src != mlflow_client_dst:
assert vr_src.name == vr_dst.name
if not utils.importing_into_databricks():
if utils.get_import_target_implementation() != utils.MLFlowImplementation.DATABRICKS:
assert vr_src.user_id == vr_dst.user_id

tags_dst = { k:v for k,v in vr_dst.tags.items() if not k.startswith(ExportTags.PREFIX_ROOT) }