From 9b236461a8c84168dffddf05fca7102da9c2b181 Mon Sep 17 00:00:00 2001 From: birbalin25 Date: Mon, 30 Jun 2025 13:58:00 -0500 Subject: [PATCH 01/25] fix and engancement --- .DS_Store | Bin 0 -> 6148 bytes databricks_notebooks/.DS_Store | Bin 0 -> 6148 bytes databricks_notebooks/bulk/Export_All.py | 120 +++++++++--- .../bulk/Export_All_log_parsing.py | 61 ++++++ .../bulk/Export_Registered_Models.py | 107 +++++++--- .../bulk/Import_Registered_Models.py | 146 ++++++++++++-- .../bulk/master_Export_Registered_Models.py | 173 +++++++++++++++++ .../bulk/master_Export_all.py | 106 ++++++++++ .../bulk/master_Import_Registered_Models.py | 161 ++++++++++++++++ mlflow_export_import/.DS_Store | Bin 0 -> 6148 bytes mlflow_export_import/bulk/bulk_utils.py | 36 +++- mlflow_export_import/bulk/config.py | 2 + mlflow_export_import/bulk/export_all.py | 51 ++++- .../bulk/export_experiments.py | 182 +++++++++++------- mlflow_export_import/bulk/export_models.py | 148 +++++++++----- .../bulk/import_experiments.py | 3 +- mlflow_export_import/bulk/import_models.py | 18 +- mlflow_export_import/bulk/model_utils.py | 51 ++++- mlflow_export_import/bulk/rename_utils.py | 8 +- mlflow_export_import/client/client_utils.py | 30 ++- .../common/checkpoint_thread.py | 119 ++++++++++++ mlflow_export_import/common/logging_utils.py | 8 +- mlflow_export_import/common/mlflow_utils.py | 3 + mlflow_export_import/common/model_utils.py | 28 ++- .../common/uc_permissions_utils.py | 88 ++++++++- .../experiment/export_experiment.py | 12 +- .../experiment/import_experiment.py | 3 +- mlflow_export_import/model/export_model.py | 57 ++++-- mlflow_export_import/model/import_model.py | 21 +- .../model_version/import_model_version.py | 4 +- mlflow_export_import/run/export_run.py | 41 +++- mlflow_export_import/run/import_run.py | 83 ++++++-- mlflow_export_import/run/run_data_importer.py | 9 +- 33 files changed, 1585 insertions(+), 294 deletions(-) create mode 100644 .DS_Store create mode 100644 databricks_notebooks/.DS_Store create mode 100644 databricks_notebooks/bulk/Export_All_log_parsing.py create mode 100644 databricks_notebooks/bulk/master_Export_Registered_Models.py create mode 100644 databricks_notebooks/bulk/master_Export_all.py create mode 100644 databricks_notebooks/bulk/master_Import_Registered_Models.py create mode 100644 mlflow_export_import/.DS_Store create mode 100644 mlflow_export_import/bulk/config.py create mode 100644 mlflow_export_import/common/checkpoint_thread.py diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..c432ecbfba68177465791335c9ba3d979dd3720d GIT binary patch literal 6148 zcmeHKyH3ME5S)b+k!Vt+puAt;4^C0|0zSZy7$k&)Q38U}9ly=&!zhuVr9ivZ?%b_+ z?(8YNJ^-?MYp;L>fH_?epBkp7_thtM78ymcH8$Ae5nDXrFsc4Iq1+`}+~Jje%y00F zp4oe5dtO|#J#4#8-*>vy5<7Hwp!ax3!?F*&aJ(m_;X~}k*+}9OrlkU@Kq`<5qynjc z_H4EJg=5B4AQeajz7^2#L!m3yz`@Zz9Sn8^AkLUJzFDvMO3_;h>&~NFZvDV5NtQ8Zj6?3Ak`0}W(=$Q9w;NWO<=8aCw9|6@R KEfx3+1-<|VVI=ba literal 0 HcmV?d00001 diff --git a/databricks_notebooks/.DS_Store b/databricks_notebooks/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..c546a39615299a97f594372e4906d38bb9a7068f GIT binary patch literal 6148 zcmeH~F^SV?}7 zi2>O0bvpnffDPS=y@#0@^8pvUFyVfCUZ>0BF?o@;>VT*85wrc=7Nmd_kOERb3P^zk zDUiqbX1kzg(xXTLDXpcS=p~5F1H`^?Ok{*+NhK!Ls>QIRGu|q%FB}t-4vU+4PTg$Pp;+9`c#CvcpQup^ zNP%+&ZgaWx`u{|q>Hp73T1f#Za8(Lew|-u)_@t_>lgD|jZS)tq=X}xKI1dVkD96Mo i$6R 0: - _logger.info(f"{failed_runs}/{total_runs} runs failed") - _logger.info(f"Duration for {len(experiments)} 
experiments export: {duration} seconds") + _logger.info(f"{len(experiments)} experiments exported") + _logger.info(f"{ok_runs}/{total_runs} runs succesfully exported") + if failed_runs > 0: + _logger.info(f"{failed_runs}/{total_runs} runs failed") + _logger.info(f"Duration for {len(experiments)} experiments export: {duration} seconds") - return info_attr + return info_attr + + except Exception as e: #birbal added + _logger.error(f"_export_experiment failed: {e}", exc_info=True) + + finally: #birbal added + checkpoint_thread.stop() + checkpoint_thread.join() + _logger.info("Checkpoint thread flushed and terminated for experiments") def _export_experiment(mlflow_client, exp_id_or_name, output_dir, export_permissions, notebook_formats, export_results, - run_start_time, export_deleted_runs, run_ids): + run_start_time, export_deleted_runs, run_ids, result_queue = None): #birbal added result_queue = None ok_runs = -1; failed_runs = -1 exp_name = exp_id_or_name try: + if not run_ids: + _logger.error(f"no runs to export for experiment {exp_id_or_name}. Throwing exception to capture in checkpoint file") + raise Exception(f"no runs to export for experiment {exp_id_or_name}") + exp = mlflow_utils.get_experiment(mlflow_client, exp_id_or_name) exp_name = exp.name exp_output_dir = os.path.join(output_dir, exp.experiment_id) @@ -166,7 +196,8 @@ def _export_experiment(mlflow_client, exp_id_or_name, output_dir, export_permiss run_start_time = run_start_time, export_deleted_runs = export_deleted_runs, notebook_formats = notebook_formats, - mlflow_client = mlflow_client + mlflow_client = mlflow_client, + result_queue = result_queue #birbal added ) duration = round(time.time() - start_time, 1) result = { @@ -181,14 +212,23 @@ def _export_experiment(mlflow_client, exp_id_or_name, output_dir, export_permiss except RestException as e: mlflow_utils.dump_exception(e) - err_msg = { **{ "message": "Cannot export experiment", "experiment": exp_name }, ** mlflow_utils.mk_msg_RestException(e) } + err_msg = { **{ "message": "Cannot export experiment", "experiment_id": exp_id_or_name }, ** mlflow_utils.mk_msg_RestException(str(e)) } #birbal type casted _logger.error(err_msg) + err_msg["status"] = "failed" #birbal added + result_queue.put(err_msg) #birbal added except MlflowExportImportException as e: - err_msg = { "message": "Cannot export experiment", "experiment": exp_name, "MlflowExportImportException": e.kwargs } + err_msg = { "message": "Cannot export experiment", "experiment_id": exp_id_or_name, "MlflowExportImportException": str(e.kwargs) } #birbal string casted _logger.error(err_msg) + + err_msg["status"] = "failed" #birbal added + result_queue.put(err_msg) #birbal added except Exception as e: - err_msg = { "message": "Cannot export experiment", "experiment": exp_name, "Exception": e } + err_msg = { "message": "Cannot export experiment", "experiment_id": exp_id_or_name, "Exception": str(e) } #birbal string casted _logger.error(err_msg) + + err_msg["status"] = "failed" #birbal added + result_queue.put(err_msg) #birbal added + return Result(exp_name, ok_runs, failed_runs) diff --git a/mlflow_export_import/bulk/export_models.py b/mlflow_export_import/bulk/export_models.py index 83100ea3..90262a1d 100644 --- a/mlflow_export_import/bulk/export_models.py +++ b/mlflow_export_import/bulk/export_models.py @@ -25,6 +25,8 @@ from mlflow_export_import.bulk import export_experiments from mlflow_export_import.bulk.model_utils import get_experiments_runs_of_models from mlflow_export_import.bulk import bulk_utils +from 
mlflow_export_import.common.checkpoint_thread import CheckpointThread, filter_unprocessed_objects #birbal added +from queue import Queue #birbal added _logger = utils.getLogger(__name__) @@ -40,7 +42,11 @@ def export_models( export_version_model = False, notebook_formats = None, use_threads = False, - mlflow_client = None + mlflow_client = None, + task_index = None, #birbal + num_tasks = None, #birbal + checkpoint_dir_experiment = None, #birbal + checkpoint_dir_model = None #birbal ): """ :param: model_names: Can be either: @@ -55,19 +61,34 @@ def export_models( model_names = f.read().splitlines() mlflow_client = mlflow_client or create_mlflow_client() - exps_and_runs = get_experiments_runs_of_models(mlflow_client, model_names) - exp_ids = exps_and_runs.keys() + exps_and_runs = get_experiments_runs_of_models(mlflow_client, model_names, task_index, num_tasks) ##birbal return dict of key=exp_id and value=list[run_id] + + total_run_ids = sum(len(run_id_list) for run_id_list in exps_and_runs.values()) #birbal added + _logger.info(f"TOTAL MODEL EXPERIMENTS TO EXPORT FOR TASK_INDEX={task_index}: {len(exps_and_runs)} AND TOTAL RUN_IDs TO EXPORT: {total_run_ids}") #birbal added + start_time = time.time() out_dir = os.path.join(output_dir, "experiments") - exps_to_export = exp_ids if export_all_runs else exps_and_runs + + ######Birbal block + exps_and_runs = filter_unprocessed_objects(checkpoint_dir_experiment,"experiments",exps_and_runs) + _logger.info(f"AFTER FILTERING OUT THE PROCESSED EXPERIMENTS FROM CHECKPOINT, REMAINING EXPERIMENTS COUNT TO BE PROCESSED: {len(exps_and_runs)} ") #birbal added + ###### + + # if len(exps_and_runs) == 0: + # _logger.info("NO MODEL EXPERIMENTS TO EXPORT") + # return + + res_exps = export_experiments.export_experiments( mlflow_client = mlflow_client, - experiments = exps_to_export, + experiments = exps_and_runs, #birbal added output_dir = out_dir, export_permissions = export_permissions, export_deleted_runs = export_deleted_runs, notebook_formats = notebook_formats, - use_threads = use_threads + use_threads = use_threads, + task_index = task_index, #birbal added + checkpoint_dir_experiment = checkpoint_dir_experiment #birbal added ) res_models = _export_models( mlflow_client, @@ -79,7 +100,10 @@ def export_models( export_latest_versions = export_latest_versions, export_version_model = export_version_model, export_permissions = export_permissions, - export_deleted_runs = export_deleted_runs + export_deleted_runs = export_deleted_runs, + task_index = task_index, #birbal + num_tasks = num_tasks, #birbal + checkpoint_dir_model = checkpoint_dir_model #birbal ) duration = round(time.time()-start_time, 1) _logger.info(f"Duration for total registered models and versions' runs export: {duration} seconds") @@ -112,60 +136,84 @@ def _export_models( export_latest_versions = False, export_version_model = False, export_permissions = False, - export_deleted_runs = False + export_deleted_runs = False, + task_index = None, + num_tasks = None, + checkpoint_dir_model = None ): max_workers = utils.get_threads(use_threads) start_time = time.time() - model_names = bulk_utils.get_model_names(mlflow_client, model_names) + model_names = bulk_utils.get_model_names(mlflow_client, model_names, task_index, num_tasks) + _logger.info(f"TOTAL MODELS TO EXPORT: {len(model_names)}") #birbal added _logger.info("Models to export:") for model_name in model_names: _logger.info(f" {model_name}") futures = [] - with ThreadPoolExecutor(max_workers=max_workers) as executor: - for model_name in 
model_names: - dir = os.path.join(output_dir, model_name) - future = executor.submit(export_model, - model_name = model_name, - output_dir = dir, - stages = stages, - export_latest_versions = export_latest_versions, - export_version_model = export_version_model, - export_permissions = export_permissions, - export_deleted_runs = export_deleted_runs, - notebook_formats = notebook_formats, - mlflow_client = mlflow_client, - ) - futures.append(future) - ok_models = [] ; failed_models = [] - for future in futures: - result = future.result() - if result[0]: ok_models.append(result[1]) - else: failed_models.append(result[1]) - duration = round(time.time()-start_time, 1) - info_attr = { - "model_names": model_names, - "stages": stages, - "export_latest_versions": export_latest_versions, - "notebook_formats": notebook_formats, - "use_threads": use_threads, - "output_dir": output_dir, - "num_total_models": len(model_names), - "num_ok_models": len(ok_models), - "num_failed_models": len(failed_models), - "duration": duration, - "failed_models": failed_models - } - mlflow_attr = { - "models": ok_models, - } - io_utils.write_export_file(output_dir, "models.json", __file__, mlflow_attr, info_attr) + ######## birbal new block + model_names = filter_unprocessed_objects(checkpoint_dir_model,"models",model_names) + _logger.info(f"AFTER FILTERING OUT THE PROCESSED MODELS FROM CHECKPOINT, TOTAL REMAINING COUNT: {len(model_names)}") + result_queue = Queue() + checkpoint_thread = CheckpointThread(result_queue, checkpoint_dir_model, interval=300, batch_size=100) + _logger.info(f"checkpoint_thread started for models") + checkpoint_thread.start() + ######## - _logger.info(f"{len(model_names)} models exported") - _logger.info(f"Duration for registered models export: {duration} seconds") + try: #birbal added + with ThreadPoolExecutor(max_workers=max_workers) as executor: + for model_name in model_names: + dir = os.path.join(output_dir, model_name) + future = executor.submit(export_model, + model_name = model_name, + output_dir = dir, + stages = stages, + export_latest_versions = export_latest_versions, + export_version_model = export_version_model, + export_permissions = export_permissions, + export_deleted_runs = export_deleted_runs, + notebook_formats = notebook_formats, + mlflow_client = mlflow_client, + result_queue = result_queue #birbal added + ) + futures.append(future) + ok_models = [] ; failed_models = [] + for future in futures: + result = future.result() + if result[0]: ok_models.append(result[1]) + else: failed_models.append(result[1]) + duration = round(time.time()-start_time, 1) - return info_attr + info_attr = { + "model_names": model_names, + "stages": stages, + "export_latest_versions": export_latest_versions, + "notebook_formats": notebook_formats, + "use_threads": use_threads, + "output_dir": output_dir, + "num_total_models": len(model_names), + "num_ok_models": len(ok_models), + "num_failed_models": len(failed_models), + "duration": duration, + "failed_models": failed_models + } + mlflow_attr = { + "models": ok_models, + } + io_utils.write_export_file(output_dir, "models.json", __file__, mlflow_attr, info_attr) + + _logger.info(f"{len(model_names)} models exported") + _logger.info(f"Duration for registered models export: {duration} seconds") + + return info_attr + + except Exception as e: #birbal added + _logger.error(f"export_model failed: {e}") + + finally: #birbal added + checkpoint_thread.stop() + checkpoint_thread.join() + _logger.info("Checkpoint thread flushed and terminated for models") 
@click.command() diff --git a/mlflow_export_import/bulk/import_experiments.py b/mlflow_export_import/bulk/import_experiments.py index 100ea9b8..2524a2ef 100644 --- a/mlflow_export_import/bulk/import_experiments.py +++ b/mlflow_export_import/bulk/import_experiments.py @@ -47,12 +47,13 @@ def import_experiments( """ experiment_renames = rename_utils.get_renames(experiment_renames) + mlflow_client = mlflow_client or mlflow.MlflowClient() dct = io_utils.read_file_mlflow(os.path.join(input_dir, "experiments.json")) exps = dct["experiments"] _logger.info("Importing experiments:") for exp in exps: - _logger.info(f" Importing experiment: {exp}") + _logger.info(f"Importing experiment: {exp}") max_workers = utils.get_threads(use_threads) futures = [] diff --git a/mlflow_export_import/bulk/import_models.py b/mlflow_export_import/bulk/import_models.py index 0859ddc9..b25a9ca6 100644 --- a/mlflow_export_import/bulk/import_models.py +++ b/mlflow_export_import/bulk/import_models.py @@ -39,7 +39,9 @@ def import_models( model_renames = None, verbose = False, use_threads = False, - mlflow_client = None + mlflow_client = None, + target_model_catalog = None, #birbal added + target_model_schema = None #birbal added ): mlflow_client = mlflow_client or create_mlflow_client() experiment_renames = rename_utils.get_renames(experiment_renames) @@ -65,7 +67,9 @@ def import_models( model_renames, experiment_renames, verbose, - use_threads + use_threads, + target_model_catalog, #birbal added + target_model_schema #birbal added ) duration = round(time.time()-start_time, 1) dct = { "duration": duration, "experiments_import": exp_info, "models_import": model_res } @@ -132,7 +136,9 @@ def _import_models(mlflow_client, model_renames, experiment_renames, verbose, - use_threads + use_threads, + target_model_catalog = None, #birbal added + target_model_schema = None #birbal added ): max_workers = utils.get_threads(use_threads) start_time = time.time() @@ -150,8 +156,14 @@ def _import_models(mlflow_client, with ThreadPoolExecutor(max_workers=max_workers) as executor: for model_name in model_names: + _logger.info(f"model name BEFORE rename : '{model_name}'") #birbal added dir = os.path.join(models_dir, model_name) model_name = rename_utils.rename(model_name, model_renames, "model") + + if target_model_catalog is not None and target_model_schema is not None: #birbal added + model_name=rename_utils.build_full_model_name(target_model_catalog, target_model_schema, model_name) + _logger.info(f"model name AFTER rename : '{model_name}'") #birbal added + executor.submit(all_importer.import_model, model_name = model_name, input_dir = dir, diff --git a/mlflow_export_import/bulk/model_utils.py b/mlflow_export_import/bulk/model_utils.py index cde43a6c..1939d4b0 100644 --- a/mlflow_export_import/bulk/model_utils.py +++ b/mlflow_export_import/bulk/model_utils.py @@ -7,29 +7,62 @@ _logger = utils.getLogger(__name__) -def get_experiments_runs_of_models(client, model_names, show_experiments=False, show_runs=False): +def get_experiments_runs_of_models(client, model_names, task_index=None, num_tasks=None, show_experiments=False, show_runs=False): """ Get experiments and runs to to export. 
""" - model_names = bulk_utils.get_model_names(client, model_names) - _logger.info(f"{len(model_names)} Models:") + model_names = bulk_utils.get_model_names(client, model_names, task_index, num_tasks) + _logger.info(f"TOTAL MODELS TO EXPORT FOR TASK_INDEX={task_index} : {len(model_names)}") for model_name in model_names: _logger.info(f" {model_name}") exps_and_runs = {} for model_name in model_names: - versions = SearchModelVersionsIterator(client, filter=f"name='{model_name}'") + versions = SearchModelVersionsIterator(client, filter=f""" name="{model_name}" """) #birbal.Changed from "name='{model_name}'" to handle models name with single quote for vr in versions: try: run = client.get_run(vr.run_id) exps_and_runs.setdefault(run.info.experiment_id,[]).append(run.info.run_id) - except mlflow.exceptions.MlflowException as e: - if e.error_code == "RESOURCE_DOES_NOT_EXIST": - _logger.warning(f"run '{vr.run_id}' of version {vr.version} of model '{model_name}' does not exist") - else: - _logger.warning(f"run '{vr.run_id}' of version {vr.version} of model '{model_name}': Error.code: {e.error_code}. Error.message: {e.message}") + except Exception as e: #birbal added + _logger.warning(f"Error with run '{vr.run_id}' of version {vr.version} of model '{model_name}': Error: {e}") + if show_experiments: show_experiments_runs_of_models(exps_and_runs, show_runs) return exps_and_runs +def get_experiment_runs_dict_from_names(client, experiment_names): #birbal added entire function + experiment_runs_dict = {} + for name in experiment_names: + experiment = client.get_experiment_by_name(name) + if experiment is not None: + experiment_id = experiment.experiment_id + runs = client.search_runs(experiment_ids=[experiment_id], max_results=1000) + run_ids = [run.info.run_id for run in runs] + experiment_runs_dict[experiment_id] = run_ids + else: + _logger.info(f"Experiment not found: {name}... in bulk->model_utils.py") + + return experiment_runs_dict + + +def get_experiments_name_of_models(client, model_names): + """ Get experiments name to export. """ + model_names = bulk_utils.get_model_names(client, model_names) + experiment_name_list = [] + for model_name in model_names: + versions = SearchModelVersionsIterator(client, filter=f""" name="{model_name}" """) #birbal. Fix for models name with single quote + for vr in versions: + try: + run = client.get_run(vr.run_id) + experiment_id = run.info.experiment_id + experiment = mlflow.get_experiment(experiment_id) + experiment_name = experiment.name + experiment_name_list.append(experiment_name) + except Exception as e: + _logger.warning(f"run '{vr.run_id}' of version {vr.version} of model '{model_name}': Error: {e}") + + return experiment_name_list + + + def show_experiments_runs_of_models(exps_and_runs, show_runs=False): _logger.info("Experiments for models:") for k,v in exps_and_runs.items(): diff --git a/mlflow_export_import/bulk/rename_utils.py b/mlflow_export_import/bulk/rename_utils.py index 4b381032..fc28060a 100644 --- a/mlflow_export_import/bulk/rename_utils.py +++ b/mlflow_export_import/bulk/rename_utils.py @@ -16,7 +16,7 @@ def read_rename_file(path): def rename(name, replacements, object_name="object"): if not replacements: - return name + return name ## birbal :: corrected to return name instead of None. 
returning None will cause failure for k,v in replacements.items(): if k != "" and name.startswith(k): new_name = name.replace(k,v) @@ -34,3 +34,9 @@ def get_renames(filename_or_dict): return filename_or_dict else: raise MlflowExportImportException(f"Unknown name replacement type '{type(filename_or_dict)}'", http_status_code=400) + +def build_full_model_name(catalog, schema, model_name): #birbal added + if model_name.count('.') == 2: + return model_name + else: + return f"{catalog}.{schema}.{model_name}" diff --git a/mlflow_export_import/client/client_utils.py b/mlflow_export_import/client/client_utils.py index d8dacbf8..f412bbbd 100644 --- a/mlflow_export_import/client/client_utils.py +++ b/mlflow_export_import/client/client_utils.py @@ -1,5 +1,6 @@ import mlflow from . http_client import HttpClient, MlflowHttpClient, DatabricksHttpClient +from mlflow_export_import.bulk import config def create_http_client(mlflow_client, model_name=None): @@ -23,14 +24,27 @@ def create_dbx_client(mlflow_client): return DatabricksHttpClient(creds.host, creds.token) -def create_mlflow_client(): +# def create_mlflow_client(): ##birbal . This is original block. commented out +# """ +# Create MLflowClient. If MLFLOW_TRACKING_URI is UC, then set MlflowClient.tracking_uri to the non-UC variant. +# """ +# registry_uri = mlflow.get_registry_uri() +# if registry_uri: +# tracking_uri = mlflow.get_tracking_uri() +# nonuc_tracking_uri = tracking_uri.replace("databricks-uc","databricks") # NOTE: legacy +# return mlflow.MlflowClient(nonuc_tracking_uri, registry_uri) +# else: +# return mlflow.MlflowClient() + + +def create_mlflow_client(): ##birbal added. Modified version of above. """ Create MLflowClient. If MLFLOW_TRACKING_URI is UC, then set MlflowClient.tracking_uri to the non-UC variant. 
""" - registry_uri = mlflow.get_registry_uri() - if registry_uri: - tracking_uri = mlflow.get_tracking_uri() - nonuc_tracking_uri = tracking_uri.replace("databricks-uc","databricks") # NOTE: legacy - return mlflow.MlflowClient(nonuc_tracking_uri, registry_uri) - else: - return mlflow.MlflowClient() + target_model_registry=config.target_model_registry # Birbal- this is set at Import_Registered_Models.py + + if not target_model_registry or target_model_registry.lower() == "workspace_registry": + mlflow.set_registry_uri('databricks') + elif target_model_registry.lower() == "unity_catalog": + mlflow.set_registry_uri('databricks-uc') + return mlflow.MlflowClient() \ No newline at end of file diff --git a/mlflow_export_import/common/checkpoint_thread.py b/mlflow_export_import/common/checkpoint_thread.py new file mode 100644 index 00000000..18e87ce6 --- /dev/null +++ b/mlflow_export_import/common/checkpoint_thread.py @@ -0,0 +1,119 @@ +import threading +import time +from datetime import datetime +import os +import pandas as pd +import logging +import pyarrow.dataset as ds +from mlflow_export_import.common import utils +from mlflow_export_import.common import filesystem as _fs +from pyspark.sql import SparkSession +spark = SparkSession.builder.getOrCreate() + +_logger = utils.getLogger(__name__) + +class CheckpointThread(threading.Thread): #birbal added + def __init__(self, queue, checkpoint_dir, interval=300, batch_size=100): + super().__init__() + self.queue = queue + self.checkpoint_dir = checkpoint_dir + self.interval = interval + self.batch_size = batch_size + self._stop_event = threading.Event() + self._buffer = [] + self._last_flush_time = time.time() + + + def run(self): + max_drain_batch = 50 # Max items to pull per loop iteration + while not self._stop_event.is_set() or not self.queue.empty(): + items_fetched = False + drain_count = 0 + + try: + while not self.queue.empty(): + _logger.debug(f"drain_count is {drain_count} and buffer len is {len(self._buffer)}") + item = self.queue.get() + self._buffer.append(item) + drain_count += 1 + if drain_count > max_drain_batch: + _logger.info(f" drain_count > max_drain_batch is TRUE") + items_fetched = True + break + + except Exception: + pass # Queue is empty or bounded + + if items_fetched: + _logger.info(f"[Checkpoint] Fetched {drain_count} items from queue.") + + time_since_last_flush = time.time() - self._last_flush_time + if len(self._buffer) >= self.batch_size or time_since_last_flush >= self.interval: + _logger.info(f"ready to flush to delta") + self.flush_to_delta() + self._buffer.clear() + self._last_flush_time = time.time() + + # Final flush + if self._buffer: + self.flush_to_delta() + self._buffer.clear() + + + + def flush_to_delta(self): + _logger.info(f"flush_to_delta called") + try: + df = pd.DataFrame(self._buffer) + if df.empty: + _logger.info(f"[Checkpoint] DataFrame is empty. 
Skipping write to {self.checkpoint_dir}") + return + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + file_path = os.path.join(self.checkpoint_dir, f"checkpoint_{timestamp}.parquet") + df.to_parquet(file_path, index=False) + _logger.info(f"[Checkpoint] Saved len(df) {len(df)} records to {file_path}") + + except Exception as e: + _logger.error(f"[Checkpoint] Failed to write to {self.checkpoint_dir}: {e}", exc_info=True) + + def stop(self): + self._stop_event.set() + _logger.info("STOP event called.") + + @staticmethod + def load_processed_objects(checkpoint_dir, object_type= None): + try: + dataset = ds.dataset(checkpoint_dir, format="parquet") + df = dataset.to_table().to_pandas() + result_list = [] + + if df.empty: + _logger.warning(f"[Checkpoint] Parquet data is empty in {checkpoint_dir}") + return {} + + if object_type == "experiments": + result_list = df["experiment_id"].dropna().unique().tolist() + + if object_type == "models": + result_list = df["model"].dropna().unique().tolist() + + return result_list + + except Exception as e: + _logger.warning(f"[Checkpoint] Failed to load checkpoint data from {checkpoint_dir}: {e}", exc_info=True) + return None + +def filter_unprocessed_objects(checkpoint_dir,object_type,to_be_processed_objects): #birbal added + processed_objects = CheckpointThread.load_processed_objects(checkpoint_dir,object_type) + if isinstance(to_be_processed_objects, dict): + unprocessed_objects = {k: v for k, v in to_be_processed_objects.items() if k not in processed_objects} + return unprocessed_objects + + if isinstance(to_be_processed_objects, list): + unprocessed_objects = list(set(to_be_processed_objects) - set(processed_objects)) + return unprocessed_objects + + return None + + \ No newline at end of file diff --git a/mlflow_export_import/common/logging_utils.py b/mlflow_export_import/common/logging_utils.py index b656d870..efd038c3 100644 --- a/mlflow_export_import/common/logging_utils.py +++ b/mlflow_export_import/common/logging_utils.py @@ -1,6 +1,8 @@ import os import yaml import logging.config +from mlflow_export_import.bulk import config #birbal added +log_path=config.log_path #birbal added _have_loaded_logging_config = False @@ -10,8 +12,10 @@ def get_logger(name): return logging.getLogger(name) config_path = os.environ.get("MLFLOW_EXPORT_IMPORT_LOG_CONFIG_FILE", None) - output_path = os.environ.get("MLFLOW_EXPORT_IMPORT_LOG_OUTPUT_FILE", None) - log_format = os.environ.get("MLFLOW_EXPORT_IMPORT_LOG_FORMAT", None) + output_path = os.environ.get("MLFLOW_EXPORT_IMPORT_LOG_OUTPUT_FILE", log_path) + log_format = os.environ.get("MLFLOW_EXPORT_IMPORT_LOG_FORMAT", "%(asctime)s - %(levelname)s - [%(name)s:%(lineno)d] - %(message)s") #birbal updated + + #print(f"logging_utils.get_logger: config_path: {config_path}") #print(f"logging_utils.get_logger: output_path: {output_path}") #print(f"logging_utils.get_logger: log_format: {log_format}") diff --git a/mlflow_export_import/common/mlflow_utils.py b/mlflow_export_import/common/mlflow_utils.py index d0840969..4f0eeccf 100644 --- a/mlflow_export_import/common/mlflow_utils.py +++ b/mlflow_export_import/common/mlflow_utils.py @@ -31,6 +31,9 @@ def set_experiment(mlflow_client, dbx_client, exp_name, tags=None): if not exp_name.startswith("/"): raise MlflowExportImportException(f"Cannot create experiment '{exp_name}'. 
Databricks experiment must start with '/'.") create_workspace_dir(dbx_client, os.path.dirname(exp_name)) + + else: ##birbal + _logger.error("utils.calling_databricks is false") try: if not tags: tags = {} tags = utils.create_mlflow_tags_for_databricks_import(tags) diff --git a/mlflow_export_import/common/model_utils.py b/mlflow_export_import/common/model_utils.py index 227a6310..a55bac18 100644 --- a/mlflow_export_import/common/model_utils.py +++ b/mlflow_export_import/common/model_utils.py @@ -26,6 +26,11 @@ def model_names_same_registry(name1, name2): not is_unity_catalog_model(name1) and not is_unity_catalog_model(name2) +def model_names_same_registry_nonucsrc_uctgt(name1, name2): + return \ + not is_unity_catalog_model(name1) and is_unity_catalog_model(name2) + + def create_model(client, model_name, model_dct, import_metadata): """ Creates a registered model if it does not exist, and returns the model in either case. @@ -38,6 +43,8 @@ def create_model(client, model_name, model_dct, import_metadata): client.create_registered_model(model_name) _logger.info(f"Created new registered model '{model_name}'") return True + except Exception as e: + _logger.info(f"except Exception trigger, error for '{model_name}': {e}") except RestException as e: if e.error_code != "RESOURCE_ALREADY_EXISTS": raise e @@ -50,7 +57,8 @@ def delete_model(client, model_name, sleep_time=5): Delete a registered model and all its versions. """ try: - versions = SearchModelVersionsIterator(client, filter=f"name='{model_name}'") + # versions = SearchModelVersionsIterator(client, filter=f"name='{model_name}'") + versions = SearchModelVersionsIterator(client, filter=f""" name="{model_name}" """) #birbal added _logger.info(f"Deleting model '{model_name}' and its versions") for vr in versions: msg = utils.get_obj_key_values(vr, [ "name", "version", "current_stage", "status", "run_id" ]) @@ -60,8 +68,10 @@ def delete_model(client, model_name, sleep_time=5): time.sleep(sleep_time) # Wait until stage transition takes hold client.delete_model_version(model_name, vr.version) client.delete_registered_model(model_name) - except RestException: - pass + # except RestException: #birbal commented out + except Exception as e: + _logger.error(f"Error deleting modfel {model_name}. Error: {e}") + def list_model_versions(client, model_name, get_latest_versions=False): @@ -69,14 +79,16 @@ def list_model_versions(client, model_name, get_latest_versions=False): List 'all' or the 'latest' versions of registered model. 
""" if is_unity_catalog_model(model_name): - versions = SearchModelVersionsIterator(client, filter=f"name='{model_name}'") + # versions = SearchModelVersionsIterator(client, filter=f"name='{model_name}'") + versions = SearchModelVersionsIterator(client, filter=f""" name="{model_name}" """) #birbal added # JIRA: ES-834105 - UC-ML MLflow search_registered_models and search_model_versions do not return tags and aliases - 2023-08-21 return [ client.get_model_version(vr.name, vr.version) for vr in versions ] else: if get_latest_versions: return client.get_latest_versions(model_name) else: - return list(SearchModelVersionsIterator(client, filter=f"name='{model_name}'")) + # return list(SearchModelVersionsIterator(client, filter=f"name='{model_name}'")) + return list(SearchModelVersionsIterator(client, filter=f""" name="{model_name}" """)) #birbal added def search_model_versions(client, filter): @@ -201,11 +213,13 @@ def get_registered_model(mlflow_client, model_name, get_permissions=False): return model -def update_model_permissions(mlflow_client, dbx_client, model_name, perms): +def update_model_permissions(mlflow_client, dbx_client, model_name, perms, nonucsrc_uctgt = False): #birbal added nonucsrc_uctgt parameter if perms: _logger.info(f"Updating permissions for registered model '{model_name}'") - if is_unity_catalog_model(model_name): + if is_unity_catalog_model(model_name) and not nonucsrc_uctgt: #birbal added uc_permissions_utils.update_permissions(mlflow_client, model_name, perms) + elif is_unity_catalog_model(model_name) and nonucsrc_uctgt: #birbal added + uc_permissions_utils.update_permissions_nonucsrc_uctgt(mlflow_client, model_name, perms) else: _model = dbx_client.get("mlflow/databricks/registered-models/get", { "name": model_name }) _model = _model["registered_model_databricks"] diff --git a/mlflow_export_import/common/uc_permissions_utils.py b/mlflow_export_import/common/uc_permissions_utils.py index 2812e04d..0ed2a7ea 100644 --- a/mlflow_export_import/common/uc_permissions_utils.py +++ b/mlflow_export_import/common/uc_permissions_utils.py @@ -26,13 +26,23 @@ def get_effective_permissions(self, model_name): resource = f"unity-catalog/effective-permissions/function/{model_name}" return self.client.get(resource) - def update_permissions(self, model_name, changes): + # def update_permissions(self, model_name, changes): #birbal commented out entire func def + # """ + # https://docs.databricks.com/api/workspace/grants/update + # PATCH /api/2.1/unity-catalog/permissions/{securable_type}/{full_name} + # """ + # resource = f"unity-catalog/permissions/function/{model_name}" + # _logger.info(f"Updating {len(changes.get('changes',[]))} permissions for model '{model_name}'. Resource: {resource}") + # return self.client.patch(resource, changes) + + + def update_permissions(self, object_type, object_name , changes): #birbal modified the above block """ https://docs.databricks.com/api/workspace/grants/update PATCH /api/2.1/unity-catalog/permissions/{securable_type}/{full_name} """ - resource = f"unity-catalog/permissions/function/{model_name}" - _logger.info(f"Updating {len(changes.get('changes',[]))} permissions for model '{model_name}'. Resource: {resource}") + resource = f"unity-catalog/permissions/{object_type}/{object_name}" + _logger.info(f"Updating {len(changes.get('changes',[]))} permissions for {object_type} '{object_name}'. 
Resource: {resource}") return self.client.patch(resource, changes) @@ -51,7 +61,77 @@ def get_permissions(mlflow_client, model_name): return {} -def update_permissions(mlflow_client, model_name, perms, unroll_changes=True): + +def update_permissions_nonucsrc_uctgt(mlflow_client, model_name, perms): ##birbal added this entire func. + + try: + _logger.info(f"BEFORE perms is {perms}") + uc_client = UcPermissionsClient(mlflow_client) + model_perm_dict, catalog_perm_dict, schema_perm_dict = format_perms(perms) + _logger.info(f"AFTER model_perm_dict is {model_perm_dict}") + _logger.info(f"AFTER catalog_perm_dict is {catalog_perm_dict}") + _logger.info(f"AFTER schema_perm_dict is {schema_perm_dict}") + + catalog, schema, model = model_name.split(".") + + uc_client.update_permissions("catalog", catalog, catalog_perm_dict) + uc_client.update_permissions("schema", catalog+"."+schema, schema_perm_dict) + uc_client.update_permissions("function", model_name, model_perm_dict) + except Exception as e: + _logger.error(f"error with update_permissions for model '{model_name}'. Error: {e}") + + +def format_perms(perms): ##birbal added this entire func. + model_perm = [] + catalog_perm=[] + schema_perm=[] + + for acl in perms['permissions']['access_control_list']: + permission_type = "EXECUTE" + for perm in acl['all_permissions']: + if perm.get('permission_level') == 'CAN_MANAGE': + permission_type = "MANAGE" + break + if 'user_name' in acl: + model_perm.append({ + "add": [permission_type], + "principal": acl['user_name'] + }) + catalog_perm.append({ + "add": ["USE_CATALOG"], + "principal": acl['user_name'] + }) + schema_perm.append({ + "add": ["USE_SCHEMA"], + "principal": acl['user_name'] + }) + + if 'group_name' in acl: + group_name = acl['group_name'] + if group_name == "admins": + continue + model_perm.append({ + "add": [permission_type], + "principal": group_name + }) + catalog_perm.append({ + "add": ["USE_CATALOG"], + "principal": group_name + }) + schema_perm.append({ + "add": ["USE_SCHEMA"], + "principal": group_name + }) + + model_perm_dict = {"changes": model_perm} + catalog_perm_dict = {"changes": catalog_perm} + schema_perm_dict = {"changes": schema_perm} + return model_perm_dict,catalog_perm_dict,schema_perm_dict + + + + +def update_permissions(mlflow_client, model_name, perms, unroll_changes=True): uc_client = UcPermissionsClient(mlflow_client) changes = _mk_update_changes(perms) if unroll_changes: # NOTE: in order to prevent batch update to fail because one individual update failed diff --git a/mlflow_export_import/experiment/export_experiment.py b/mlflow_export_import/experiment/export_experiment.py index ac1656e4..308e7aad 100644 --- a/mlflow_export_import/experiment/export_experiment.py +++ b/mlflow_export_import/experiment/export_experiment.py @@ -35,7 +35,8 @@ def export_experiment( export_deleted_runs = False, check_nested_runs = False, notebook_formats = None, - mlflow_client = None + mlflow_client = None, + result_queue = None #birbal added ): """ :param: experiment_id_or_name: Experiment ID or name. 
@@ -81,7 +82,7 @@ def export_experiment( for run in runs: _export_run(mlflow_client, run, output_dir, ok_run_ids, failed_run_ids, - run_start_time, run_start_time_str, export_deleted_runs, notebook_formats) + run_start_time, run_start_time_str, export_deleted_runs, notebook_formats, result_queue) #birbal added result_queue num_runs_exported += 1 info_attr = { @@ -114,7 +115,7 @@ def export_experiment( def _export_run(mlflow_client, run, output_dir, ok_run_ids, failed_run_ids, run_start_time, run_start_time_str, - export_deleted_runs, notebook_formats + export_deleted_runs, notebook_formats, result_queue = None #birbal added result_queue ): if run_start_time and run.info.start_time < run_start_time: msg = { @@ -123,14 +124,15 @@ def _export_run(mlflow_client, run, output_dir, "start_time": fmt_ts_millis(run.info.start_time), "run_start_time": run_start_time_str } - _logger.info(f"Not exporting run: {msg}") + _logger.info(f"Not exporting run: {msg} as run.info.start_time < run_start_time ") #birbal updated return is_success = export_run( run_id = run.info.run_id, output_dir = os.path.join(output_dir, run.info.run_id), export_deleted_runs = export_deleted_runs, notebook_formats = notebook_formats, - mlflow_client = mlflow_client + mlflow_client = mlflow_client, + result_queue = result_queue #birbal added ) if is_success: ok_run_ids.append(run.info.run_id) diff --git a/mlflow_export_import/experiment/import_experiment.py b/mlflow_export_import/experiment/import_experiment.py index 72a14c5a..56047429 100644 --- a/mlflow_export_import/experiment/import_experiment.py +++ b/mlflow_export_import/experiment/import_experiment.py @@ -85,7 +85,8 @@ def import_experiment( input_dir = os.path.join(input_dir, src_run_id), dst_notebook_dir = dst_notebook_dir, import_source_tags = import_source_tags, - use_src_user_id = use_src_user_id + use_src_user_id = use_src_user_id, + exp = exp #birbal added ) dst_run_id = dst_run.info.run_id run_ids_map[src_run_id] = { "dst_run_id": dst_run_id, "src_parent_run_id": src_parent_run_id } diff --git a/mlflow_export_import/model/export_model.py b/mlflow_export_import/model/export_model.py index f965be7e..3d3a66f2 100644 --- a/mlflow_export_import/model/export_model.py +++ b/mlflow_export_import/model/export_model.py @@ -23,6 +23,7 @@ from mlflow_export_import.common.timestamp_utils import adjust_timestamps from mlflow_export_import.common import MlflowExportImportException from mlflow_export_import.run.export_run import export_run +import ast #birbal added _logger = utils.getLogger(__name__) @@ -47,7 +48,8 @@ def export_model( export_permissions = False, export_deleted_runs = False, notebook_formats = None, - mlflow_client = None + mlflow_client = None, + result_queue = None #birbal added ): """ :param model_name: Registered model name. 
@@ -74,31 +76,37 @@ def export_model( opts = Options(stages, versions, export_latest_versions, export_deleted_runs, export_version_model, export_permissions, notebook_formats) try: - _export_model(mlflow_client, model_name, output_dir, opts) + _export_model(mlflow_client, model_name, output_dir, opts, result_queue) #birbal added result_queue return True, model_name except RestException as e: - err_msg = { "model": model_name, "RestException": e.json } + err_msg = { "model": model_name, "RestException": str(e.json) } #birbal string casted if e.json.get("error_code") == "RESOURCE_DOES_NOT_EXIST": _logger.error({ **{"message": "Model does not exist"}, **err_msg}) else: _logger.error({**{"message": "Model cannot be exported"}, **err_msg}) import traceback traceback.print_exc() + err_msg["status"] = "failed" #birbal added + result_queue.put(err_msg) #birbal added return False, model_name except Exception as e: - _logger.error({ "model": model_name, "Exception": e }) + _logger.error({ "model": model_name, "Exception": e }) + err_msg = { "model": model_name, "status": "failed","Exception": str(e) } #birbal string casted + result_queue.put(err_msg) #birbal added import traceback traceback.print_exc() return False, model_name -def _export_model(mlflow_client, model_name, output_dir, opts): +def _export_model(mlflow_client, model_name, output_dir, opts, result_queue = None): #birbal added result_queue ori_versions = model_utils.list_model_versions(mlflow_client, model_name, opts.export_latest_versions) + _logger.info(f"TOTAL MODELS VERSIONS TO EXPORT: {len(ori_versions)}") #birbal added + msg = "latest" if opts.export_latest_versions else "all" _logger.info(f"Exporting model '{model_name}': found {len(ori_versions)} '{msg}' versions") model = model_utils.get_registered_model(mlflow_client, model_name, opts.export_permissions) - versions, failed_versions = _export_versions(mlflow_client, model, ori_versions, output_dir, opts) + versions, failed_versions = _export_versions(mlflow_client, model, ori_versions, output_dir, opts, result_queue) #birbal added result_queue _adjust_model(model, versions) info_attr = { @@ -110,12 +118,21 @@ def _export_model(mlflow_client, model_name, output_dir, opts): "export_latest_versions": opts.export_latest_versions, "export_permissions": opts.export_permissions } - _model = { "registered_model": model } - io_utils.write_export_file(output_dir, "model.json", __file__, _model, info_attr) - _logger.info(f"Exported {len(versions)}/{len(ori_versions)} '{msg}' versions for model '{model_name}'") + try: #birbal added + _model = { "registered_model": model } + io_utils.write_export_file(output_dir, "model.json", __file__, _model, info_attr) + _logger.info(f"Exported {len(versions)}/{len(ori_versions)} '{msg}' versions for model '{model_name}'") + except Exception as e: + ##birbal added this block to resolve ""Object of type ModelVersionDeploymentJobState is not JSON" error + model = str(model).replace("<", "\"").replace(">", "\"") + model = ast.literal_eval(model) + #birbal below end + _model = { "registered_model": model } + io_utils.write_export_file(output_dir, "model.json", __file__, _model, info_attr) + _logger.warning(f"Exported {len(versions)}/{len(ori_versions)} '{msg}' versions for model '{model_name}' AFTER applying the FIX(replaced < and > with double quote). Else it will throw this exception due to the presence of < and > in the dict value of key deployment_job_state. 
Exception : {str(e)} which will cause issues during MODEL IMPORT") -def _export_versions(mlflow_client, model_dct, versions, output_dir, opts): +def _export_versions(mlflow_client, model_dct, versions, output_dir, opts, result_queue = None): #birbal added result_queue aliases = model_dct.get("aliases", []) version_aliases = {} [ version_aliases.setdefault(x["version"], []).append(x["alias"]) for x in aliases ] # map of version => its aliases @@ -126,12 +143,12 @@ def _export_versions(mlflow_client, model_dct, versions, output_dir, opts): continue if len(opts.versions) > 0 and not vr.version in opts.versions: continue - _export_version(mlflow_client, vr, output_dir, version_aliases.get(vr.version,[]), output_versions, failed_versions, j, len(versions), opts) + _export_version(mlflow_client, vr, output_dir, version_aliases.get(vr.version,[]), output_versions, failed_versions, j, len(versions), opts, result_queue) #birbal added result_queue output_versions.sort(key=lambda x: x["version"], reverse=False) return output_versions, failed_versions -def _export_version(mlflow_client, vr, output_dir, aliases, output_versions, failed_versions, j, num_versions, opts): +def _export_version(mlflow_client, vr, output_dir, aliases, output_versions, failed_versions, j, num_versions, opts, result_queue = None): #birbal added result_queue _output_dir = os.path.join(output_dir, vr.run_id) msg = { "name": vr.name, "version": vr.version, "stage": vr.current_stage, "aliases": aliases } _logger.info(f"Exporting model verson {j+1}/{num_versions}: {msg} to '{_output_dir}'") @@ -148,7 +165,9 @@ def _export_version(mlflow_client, vr, output_dir, aliases, output_versions, fai export_deleted_runs = opts.export_deleted_runs, notebook_formats = opts.notebook_formats, mlflow_client = mlflow_client, - raise_exception = True + raise_exception = True, + result_queue = result_queue, #birbal added + vr = vr #birbal added ) if not run and not opts.export_deleted_runs: failed_msg = { "message": "deleted run", "version": vr_dct } @@ -158,7 +177,7 @@ def _export_version(mlflow_client, vr, output_dir, aliases, output_versions, fai output_versions.append(vr_dct) except RestException as e: - err_msg = { "model": vr.name, "version": vr.version, "run_id": vr.run_id, "RestException": e.json } + err_msg = { "model": vr.name, "version": vr.version, "run_id": vr.run_id, "RestException": str(e.json) } #birbal string casted if e.json.get("error_code") == "RESOURCE_DOES_NOT_EXIST": err_msg = { **{"message": "Version run probably does not exist"}, **err_msg} _logger.error(f"Version export failed (1): {err_msg}") @@ -171,6 +190,16 @@ def _export_version(mlflow_client, vr, output_dir, aliases, output_versions, fai failed_msg = { "version": vr_dct, "RestException": e.json } failed_versions.append(failed_msg) + err_msg["status"] = "failed" #birbal added + if result_queue: + result_queue.put(err_msg) #birbal added + + except Exception as e: + err_msg = { "model": vr.name, "version": vr.version, "run_id": vr.run_id, "status":"failed", "Exception": str(e) } #birbal string casted + if result_queue: + result_queue.put(err_msg) #birbal added + + def _add_metadata_to_version(mlflow_client, vr_dct, run): vr_dct["_run_artifact_uri"] = run.info.artifact_uri diff --git a/mlflow_export_import/model/import_model.py b/mlflow_export_import/model/import_model.py index 451f0e45..d8f894de 100644 --- a/mlflow_export_import/model/import_model.py +++ b/mlflow_export_import/model/import_model.py @@ -116,11 +116,17 @@ def _import_model(self, created_model = 
model_utils.create_model(self.mlflow_client, model_name, model_dct, True) perms = model_dct.get("permissions") if created_model and self.import_permissions and perms: - if model_utils.model_names_same_registry(model_dct["name"], model_name): - model_utils.update_model_permissions(self.mlflow_client, self.dbx_client, model_name, perms) - else: - _logger.warning(f'Cannot import permissions since models \'{model_dct["name"]}\' and \'{model_name}\' must be either both Unity Catalog model names or both Workspace model names.') - + try: #birbal added + if model_utils.model_names_same_registry(model_dct["name"], model_name): + model_utils.update_model_permissions(self.mlflow_client, self.dbx_client, model_name, perms) + elif model_utils.model_names_same_registry_nonucsrc_uctgt(model_dct["name"], model_name): #birbal added + model_utils.update_model_permissions(self.mlflow_client, self.dbx_client, model_name, perms, True) + else: + _logger.warning(f'Cannot import permissions since models \'{model_dct["name"]}\' and \'{model_name}\' must be either both Unity Catalog model names or both Workspace model names.') + except Exception as e: #birbal added + _logger.error(f"Error updating model permission for model {model_name} . Error: {e}") + else: ##birbal added + _logger.info(f"Model permission update skipped for model {model_name}") return model_dct @@ -157,14 +163,17 @@ def import_model(self, :param verbose: Verbose. :return: Model import manifest. """ + model_dct = self._import_model(model_name, input_dir, delete_model) + _logger.info("Importing versions:") for vr in model_dct.get("versions",[]): try: run_id = self._import_run(input_dir, experiment_name, vr) if run_id: self.import_version(model_name, vr, run_id) - except RestException as e: + # except RestException as e: #birbal commented out + except Exception as e: #birbal added msg = { "model": model_name, "version": vr["version"], "src_run_id": vr["run_id"], "experiment": experiment_name, "RestException": str(e) } _logger.error(f"Failed to import model version: {msg}") import traceback diff --git a/mlflow_export_import/model_version/import_model_version.py b/mlflow_export_import/model_version/import_model_version.py index 6d720a53..9428829a 100644 --- a/mlflow_export_import/model_version/import_model_version.py +++ b/mlflow_export_import/model_version/import_model_version.py @@ -103,7 +103,7 @@ def _import_model_version( src_vr, dst_run_id, dst_source, - import_stages_and_aliases = True, + import_stages_and_aliases = True, import_source_tags = False ): start_time = time.time() @@ -128,7 +128,7 @@ def _import_model_version( tags = tags ) - if import_stages_and_aliases: + if import_stages_and_aliases: for alias in src_vr.get("aliases",[]): mlflow_client.set_registered_model_alias(dst_vr.name, alias, dst_vr.version) diff --git a/mlflow_export_import/run/export_run.py b/mlflow_export_import/run/export_run.py index c0549888..f7a5a8e8 100644 --- a/mlflow_export_import/run/export_run.py +++ b/mlflow_export_import/run/export_run.py @@ -34,7 +34,9 @@ def export_run( skip_download_run_artifacts = False, notebook_formats = None, raise_exception = False, - mlflow_client = None + mlflow_client = None, + result_queue = None, #birbal addedbirbal added + vr = None # ): """ :param run_id: Run ID. 
@@ -62,7 +64,8 @@ def export_run( _logger.warning(f"Not exporting run '{run.info.run_id} because its lifecycle_stage is '{run.info.lifecycle_stage}'") return None experiment_id = run.info.experiment_id - msg = { "run_id": run.info.run_id, "lifecycle_stage": run.info.lifecycle_stage, "experiment_id": run.info.experiment_id } + experiment = mlflow_client.get_experiment(experiment_id) + msg = { "run_id": run.info.run_id, "experiment_id": run.info.experiment_id, "experiment_name": experiment.name} #birbal removed lifecycle_stage tags = run.data.tags tags = dict(sorted(tags.items())) @@ -101,19 +104,47 @@ def export_run( _logger.warning(f"No notebooks to export for run '{run_id}' since tag '{MLFLOW_DATABRICKS_NOTEBOOK_PATH}' is not set.") dur = format_seconds(time.time()-start_time) _logger.info(f"Exported run in {dur}: {msg}") + + msg["status"] = "success" #birbal added + if vr: + msg["model"] = vr.name #birbal added + msg["version"] = vr.version #birbal added + msg["stage"] = vr.current_stage #birbal added + if result_queue: + result_queue.put(msg) #birbal added return run except RestException as e: if raise_exception: raise e - err_msg = { "run_id": run_id, "experiment_id": experiment_id, "RestException": e.json } + # err_msg = { "run_id": run_id, "experiment_id": experiment_id, "RestException": e.json } #birbal commented out + err_msg = { "run_id": run_id, "experiment_id": experiment_id, "RestException": str(e.json) } #birbal string casted + + err_msg["status"] = "failed" #birbal added + if vr: + err_msg["model"] = vr.name #birbal added + err_msg["version"] = vr.version #birbal added + err_msg["stage"] = vr.current_stage #birbal added _logger.error(f"Run export failed (1): {err_msg}") + if result_queue: + result_queue.put(err_msg) #birbal added + return None except Exception as e: if raise_exception: raise e - err_msg = { "run_id": run_id, "experiment_id": experiment_id, "Exception": e } - _logger.error(f"Run export failed (2): {err_msg}") + # err_msg = { "run_id": run_id, "experiment_id": experiment_id, "Exception": e } #birbal commented out + err_msg = { "run_id": run_id, "experiment_id": experiment_id, "Exception": str(e) } #birbal string casted + + err_msg["status"] = "failed" #birbal added + if vr: + err_msg["model"] = vr.name #birbal added + err_msg["version"] = vr.version #birbal added + err_msg["stage"] = vr.current_stage #birbal added + _logger.error(f"Run export failed (2): {err_msg}") + if result_queue: + result_queue.put(err_msg) #birbal added + traceback.print_exc() return None diff --git a/mlflow_export_import/run/import_run.py b/mlflow_export_import/run/import_run.py index ad66a1ae..a56f156f 100644 --- a/mlflow_export_import/run/import_run.py +++ b/mlflow_export_import/run/import_run.py @@ -23,6 +23,9 @@ from mlflow_export_import.client.client_utils import create_mlflow_client, create_dbx_client, create_http_client from . import run_data_importer from . import run_utils +import mlflow.utils.databricks_utils as db_utils #birbal added +import requests #birbal added +import json _logger = utils.getLogger(__name__) @@ -33,7 +36,8 @@ def import_run( dst_notebook_dir = None, use_src_user_id = False, mlmodel_fix = True, - mlflow_client = None + mlflow_client = None, + exp = None ): """ Imports a run into the specified experiment. 
@@ -64,7 +68,10 @@ def _mk_ex(src_run_dct, dst_run_id, exp_name): _logger.info(f"Importing run from '{input_dir}'") - exp = mlflow_utils.set_experiment(mlflow_client, dbx_client, experiment_name) + # exp = mlflow_utils.set_experiment(mlflow_client, dbx_client, experiment_name) + if not exp: #birbal added + exp = mlflow_utils.set_experiment(mlflow_client, dbx_client, experiment_name) + src_run_path = os.path.join(input_dir, "run.json") src_run_dct = io_utils.read_file_mlflow(src_run_path) in_databricks = "DATABRICKS_RUNTIME_VERSION" in os.environ @@ -72,7 +79,7 @@ def _mk_ex(src_run_dct, dst_run_id, exp_name): run = mlflow_client.create_run(exp.experiment_id) run_id = run.info.run_id try: - run_data_importer.import_run_data( + run_data_importer.import_run_data( mlflow_client, src_run_dct, run_id, @@ -99,34 +106,33 @@ def _mk_ex(src_run_dct, dst_run_id, exp_name): traceback.print_exc() raise MlflowExportImportException(e, f"Importing run {run_id} of experiment '{exp.name}' failed") - if utils.calling_databricks() and dst_notebook_dir: - _upload_databricks_notebook(dbx_client, input_dir, src_run_dct, dst_notebook_dir) + if utils.calling_databricks(): #birbal added + _upload_databricks_notebook(mlflow_client, dbx_client, input_dir, src_run_dct, dst_notebook_dir,run_id) #birbal added.. passed mlflow_client res = (run, src_run_dct["tags"].get(MLFLOW_PARENT_RUN_ID, None)) _logger.info(f"Imported run '{run.info.run_id}' into experiment '{experiment_name}'") return res -def _upload_databricks_notebook(dbx_client, input_dir, src_run_dct, dst_notebook_dir): - run_id = src_run_dct["info"]["run_id"] +def _upload_databricks_notebook(mlflow_client, dbx_client, input_dir, src_run_dct, dst_notebook_dir,run_id): #birbal added tag_key = "mlflow.databricks.notebookPath" src_notebook_path = src_run_dct["tags"].get(tag_key,None) if not src_notebook_path: _logger.warning(f"No tag '{tag_key}' for run_id '{run_id}'") return notebook_name = os.path.basename(src_notebook_path) - + dst_notebook_dir = os.path.dirname(src_notebook_path) format = "source" - notebook_path = _fs.make_local_path(os.path.join(input_dir,"artifacts","notebooks",f"{notebook_name}.{format}")) - if not _fs.exists(notebook_path): - _logger.warning(f"Source '{notebook_path}' does not exist for run_id '{run_id}'") - return + + notebook_path = os.path.join(input_dir,"artifacts","notebooks",f"{notebook_name}.{format}") #birbal added with open(notebook_path, "r", encoding="utf-8") as f: content = f.read() - dst_notebook_path = os.path.join(dst_notebook_dir, notebook_name) + dst_notebook_path = src_notebook_path #birbal added + + content = base64.b64encode(content.encode()).decode("utf-8") - data = { + payload = { "path": dst_notebook_path, "language": "PYTHON", "format": format, @@ -136,10 +142,53 @@ def _upload_databricks_notebook(dbx_client, input_dir, src_run_dct, dst_notebook mlflow_utils.create_workspace_dir(dbx_client, dst_notebook_dir) try: _logger.info(f"Importing notebook '{dst_notebook_path}' for run {run_id}") - dbx_client._post("workspace/import", data) - except MlflowExportImportException as e: - _logger.warning(f"Cannot save notebook '{dst_notebook_path}'. {e}") + create_notebook(mlflow_client,payload,run_id) #birbal added + update_notebook_lineage(mlflow_client,run_id,dst_notebook_path) #birbal added + + except Exception as e: #birbal added + _logger.error(f"Error importing notebook '{dst_notebook_path}' for run_id {run_id}. 
Error - {e}") + + +def create_notebook(mlflow_client,payload,run_id): #birbal added this entire block + + creds = mlflow_client._tracking_client.store.get_host_creds() + host = creds.host + token = creds.token + + headers = { + "Authorization": f"Bearer {token}" + } + response = requests.post( + f"{host}/api/2.0/workspace/import", + headers=headers, + json=payload + ) + if response.status_code == 200: + _logger.info(f"Imported notebook for run_id {run_id} using workspace/import api") + else: + _logger.error(f"workspace/import api failed to import notebook for run_id {run_id}. response.text is {response.text}") + + + +def update_notebook_lineage(mlflow_client,run_id,dst_notebook_path): #birbal added this entire block + host=db_utils.get_workspace_url() + token=db_utils.get_databricks_host_creds().token + + HEADERS = { + 'Authorization': f'Bearer {token}' + } + get_url = f'{host}/api/2.0/workspace/get-status' + params = {'path': dst_notebook_path} + response = requests.get(get_url, headers=HEADERS, params=params) + response.raise_for_status() + notebook_object = response.json() + notebook_id = notebook_object.get("object_id") + + mlflow_client.set_tag(run_id, "mlflow.source.name", dst_notebook_path) + mlflow_client.set_tag(run_id, "mlflow.source.type", "NOTEBOOK") + mlflow_client.set_tag(run_id, "mlflow.databricks.notebookID", notebook_id) + mlflow_client.set_tag(run_id, "mlflow.databricks.workspaceURL", host) def _import_inputs(http_client, src_run_dct, run_id): inputs = src_run_dct.get("inputs") diff --git a/mlflow_export_import/run/run_data_importer.py b/mlflow_export_import/run/run_data_importer.py index bc6cec1a..e6876457 100644 --- a/mlflow_export_import/run/run_data_importer.py +++ b/mlflow_export_import/run/run_data_importer.py @@ -13,6 +13,7 @@ from mlflow_export_import.common.source_tags import mk_source_tags_mlflow_tag, mk_source_tags + def _log_data(run_dct, run_id, batch_size, get_data, log_data, args_get_data=None): metadata = get_data(run_dct, args_get_data) num_batches = int(math.ceil(len(metadata) / batch_size)) @@ -24,7 +25,6 @@ def _log_data(run_dct, run_id, batch_size, get_data, log_data, args_get_data=Non log_data(run_id, batch) res = res + batch - def _log_params(client, run_dct, run_id, batch_size): def get_data(run_dct, args): return [ Param(k,v) for k,v in run_dct["params"].items() ] @@ -59,6 +59,7 @@ def get_data(run_dct, args): tags = { **tags, **source_mlflow_tags, **source_info_tags } tags = utils.create_mlflow_tags_for_databricks_import(tags) # remove "mlflow" tags that cannot be imported into Databricks tags = [ RunTag(k,v) for k,v in tags.items() ] + if not in_databricks: utils.set_dst_user_id(tags, args["src_user_id"], args["use_src_user_id"]) return tags @@ -73,13 +74,14 @@ def log_data(run_id, tags): } _log_data(run_dct, run_id, batch_size, get_data, log_data, args_get) + def import_run_data(mlflow_client, run_dct, run_id, import_source_tags, src_user_id, use_src_user_id, in_databricks): from mlflow.utils.validation import MAX_PARAMS_TAGS_PER_BATCH, MAX_METRICS_PER_BATCH _log_params(mlflow_client, run_dct, run_id, MAX_PARAMS_TAGS_PER_BATCH) _log_metrics(mlflow_client, run_dct, run_id, MAX_METRICS_PER_BATCH) - _log_tags( + _log_tags( mlflow_client, run_dct, run_id, @@ -88,8 +90,7 @@ def import_run_data(mlflow_client, run_dct, run_id, import_source_tags, src_user in_databricks, src_user_id, use_src_user_id -) - + ) if __name__ == "__main__": import sys From d524d2eb6c478faec7532b0f898cce39eaea81da Mon Sep 17 00:00:00 2001 From: "birbal.in25@gmail.com" Date: 
Mon, 30 Jun 2025 21:16:59 +0000 Subject: [PATCH 02/25] deleting .DS_Store --- .DS_Store | Bin 6148 -> 0 bytes databricks_notebooks/.DS_Store | Bin 6148 -> 0 bytes mlflow_export_import/.DS_Store | Bin 6148 -> 0 bytes 3 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 .DS_Store delete mode 100644 databricks_notebooks/.DS_Store delete mode 100644 mlflow_export_import/.DS_Store diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index c432ecbfba68177465791335c9ba3d979dd3720d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKyH3ME5S)b+k!Vt+puAt;4^C0|0zSZy7$k&)Q38U}9ly=&!zhuVr9ivZ?%b_+ z?(8YNJ^-?MYp;L>fH_?epBkp7_thtM78ymcH8$Ae5nDXrFsc4Iq1+`}+~Jje%y00F zp4oe5dtO|#J#4#8-*>vy5<7Hwp!ax3!?F*&aJ(m_;X~}k*+}9OrlkU@Kq`<5qynjc z_H4EJg=5B4AQeajz7^2#L!m3yz`@Zz9Sn8^AkLUJzFDvMO3_;h>&~NFZvDV5NtQ8Zj6?3Ak`0}W(=$Q9w;NWO<=8aCw9|6@R KEfx3+1-<|VVI=ba diff --git a/databricks_notebooks/.DS_Store b/databricks_notebooks/.DS_Store deleted file mode 100644 index c546a39615299a97f594372e4906d38bb9a7068f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeH~F^SV?}7 zi2>O0bvpnffDPS=y@#0@^8pvUFyVfCUZ>0BF?o@;>VT*85wrc=7Nmd_kOERb3P^zk zDUiqbX1kzg(xXTLDXpcS=p~5F1H`^?Ok{*+NhK!Ls>QIRGu|q%FB}t-4vU+4PTg$Pp;+9`c#CvcpQup^ zNP%+&ZgaWx`u{|q>Hp73T1f#Za8(Lew|-u)_@t_>lgD|jZS)tq=X}xKI1dVkD96Mo i$6R Date: Mon, 30 Jun 2025 22:16:32 +0000 Subject: [PATCH 03/25] Change log level from DEBUG to INFO for console --- mlflow_export_import/common/default_logging_config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlflow_export_import/common/default_logging_config.py b/mlflow_export_import/common/default_logging_config.py index 67a2782e..881c5d6f 100644 --- a/mlflow_export_import/common/default_logging_config.py +++ b/mlflow_export_import/common/default_logging_config.py @@ -9,7 +9,7 @@ "handlers": { "console": { "class": "logging.StreamHandler", - "level": "DEBUG", + "level": "INFO", "formatter": "simple", "stream": "ext://sys.stdout" }, @@ -22,7 +22,7 @@ }, "loggers": { "sampleLogger": { - "level": "DEBUG", + "level": "INFO", "handlers": [ "console" ], @@ -30,7 +30,7 @@ } }, "root": { - "level": "DEBUG", + "level": "INFO", "handlers": [ "console", "file" From b4433646bf85bd0bf6e949affbbe90a74c94acb9 Mon Sep 17 00:00:00 2001 From: "birbal.in25@gmail.com" Date: Tue, 22 Jul 2025 20:57:18 +0000 Subject: [PATCH 04/25] test file. 
Need to remove --- mlflow_export_import/proj.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 mlflow_export_import/proj.py diff --git a/mlflow_export_import/proj.py b/mlflow_export_import/proj.py new file mode 100644 index 00000000..fda30490 --- /dev/null +++ b/mlflow_export_import/proj.py @@ -0,0 +1,23 @@ +from pyspark.sql import SparkSession +from pyspark.sql.functions import col, udf +from pyspark.sql.types import StringType + +import time + +df = spark.read.option("header", "true").csv("dbfs:/databricks-datasets/retail-org/customers/customers.csv") + +@udf(StringType()) +def to_upper_udf(name): + time.sleep(0.01) + return name.upper() if name else None + +df = df.withColumn("name_upper", to_upper_udf(col("customer_name"))) + +df = df.repartition(200) + +df.cache() +df = df.filter(col("state") == "CA") + +results = df.collect() + +df.write.mode("overwrite").csv("dbfs:/tmp/inefficient_output") \ No newline at end of file From 203846db63da4a25aa359ca09f5b4e9214982821 Mon Sep 17 00:00:00 2001 From: "birbal.in25@gmail.com" Date: Wed, 6 Aug 2025 20:43:19 +0000 Subject: [PATCH 05/25] Code cleanup and few fixes --- databricks_notebooks/bulk/Export_All.py | 7 +- .../bulk/Export_Registered_Models.py | 83 +++++---- .../bulk/Import_Registered_Models.py | 2 +- .../bulk/master_Export_Registered_Models.py | 167 ++++++++---------- .../bulk/master_Export_all.py | 46 ++++- .../bulk/master_Import_Registered_Models.py | 13 +- mlflow_export_import/bulk/bulk_utils.py | 7 +- mlflow_export_import/bulk/export_all.py | 60 ++++--- .../common/ws_permissions_utils.py | 2 +- mlflow_export_import/model/export_model.py | 13 +- .../model_version/import_model_version.py | 47 ++--- mlflow_export_import/run/export_run.py | 14 +- mlflow_export_import/run/import_run.py | 3 +- 13 files changed, 246 insertions(+), 218 deletions(-) diff --git a/databricks_notebooks/bulk/Export_All.py b/databricks_notebooks/bulk/Export_All.py index 8df068e0..7ea00e6d 100644 --- a/databricks_notebooks/bulk/Export_All.py +++ b/databricks_notebooks/bulk/Export_All.py @@ -54,6 +54,9 @@ dbutils.widgets.text("jobrunid", "") jobrunid = dbutils.widgets.get("jobrunid") + +dbutils.widgets.text("model_file_name", "") +model_file_name = dbutils.widgets.get("model_file_name") if run_start_date=="": run_start_date = None @@ -66,6 +69,7 @@ print("num_tasks:", num_tasks) print("run_timestamp:", run_timestamp) print("jobrunid:", jobrunid) +print("model_file_name:", model_file_name) # COMMAND ---------- @@ -119,7 +123,8 @@ task_index = task_index, num_tasks = num_tasks, checkpoint_dir_experiment = checkpoint_dir_experiment, - checkpoint_dir_model = checkpoint_dir_model + checkpoint_dir_model = checkpoint_dir_model, + model_names = model_file_name ) # COMMAND ---------- diff --git a/databricks_notebooks/bulk/Export_Registered_Models.py b/databricks_notebooks/bulk/Export_Registered_Models.py index c44f84a2..71d83477 100644 --- a/databricks_notebooks/bulk/Export_Registered_Models.py +++ b/databricks_notebooks/bulk/Export_Registered_Models.py @@ -25,10 +25,11 @@ from mlflow_export_import.bulk import config import time +import os # COMMAND ---------- -models = dbutils.widgets.get("models") +model_file_name = dbutils.widgets.get("model_file_name") output_dir = dbutils.widgets.get("output_dir") output_dir = output_dir.replace("dbfs:","/dbfs") @@ -37,75 +38,63 @@ export_latest_versions = dbutils.widgets.get("export_latest_versions") == "true" -export_all_runs = dbutils.widgets.get("export_all_runs") == "true" - export_permissions = 
dbutils.widgets.get("export_permissions") == "true" export_deleted_runs = dbutils.widgets.get("export_deleted_runs") == "true" -export_version_model = dbutils.widgets.get("export_version_model") == "true" - -notebook_formats = dbutils.widgets.get("notebook_formats").split(",") - -use_threads = dbutils.widgets.get("use_threads") == "true" - task_index = int(dbutils.widgets.get("task_index")) num_tasks = int(dbutils.widgets.get("num_tasks")) run_timestamp = dbutils.widgets.get("run_timestamp") -# os.environ["OUTPUT_DIR"] = output_dir +dbutils.widgets.text("jobrunid", "") +jobrunid = dbutils.widgets.get("jobrunid") -print("models:", models) +print("model_file_name:", model_file_name) print("output_dir:", output_dir) print("stages:", stages) print("export_latest_versions:", export_latest_versions) -print("export_all_runs:", export_all_runs) print("export_permissions:", export_permissions) print("export_deleted_runs:", export_deleted_runs) -print("export_version_model:", export_version_model) -print("notebook_formats:", notebook_formats) -print("use_threads:", use_threads) - print("task_index:", task_index) print("num_tasks:", num_tasks) print("run_timestamp:", run_timestamp) +print("jobrunid:", jobrunid) # COMMAND ---------- -if task_index == -1 and num_tasks == -1: - task_index = None - num_tasks = None - output_dir = f"{output_dir}/{run_timestamp}" - dbfs_log_path = f"{output_dir}/Export_Registered_Models.log" -else: - output_dir = f"{output_dir}/{run_timestamp}/{task_index}" - dbfs_log_path = f"{output_dir}/Export_Registered_Models_{task_index}.log" - -print("output_dir:", output_dir) -print("dbfs_log_path:", dbfs_log_path) +log_path=f"/tmp/my.log" +log_path # COMMAND ---------- -if dbfs_log_path.startswith("/Workspace"): - dbfs_log_path=dbfs_log_path.replace("/Workspace","file:/Workspace") -dbfs_log_path = dbfs_log_path.replace("/dbfs","dbfs:") -dbfs_log_path +config.log_path=log_path # COMMAND ---------- -# assert_widget(models, "1. Models") -# assert_widget(output_dir, "2. 
Output directory") +checkpoint_dir_experiment = os.path.join(output_dir, run_timestamp,"checkpoint", "experiments") +try: + if not os.path.exists(checkpoint_dir_experiment): + os.makedirs(checkpoint_dir_experiment, exist_ok=True) + print(f"checkpoint_dir_experiment: created {checkpoint_dir_experiment}") +except Exception as e: + raise Exception(f"Failed to create directory {checkpoint_dir_experiment}: {e}") # COMMAND ---------- -log_path=f"/tmp/my.log" -log_path +checkpoint_dir_model = os.path.join(output_dir, run_timestamp,"checkpoint", "models") +try: + if not os.path.exists(checkpoint_dir_model): + os.makedirs(checkpoint_dir_model, exist_ok=True) + print(f"checkpoint_dir_model: created {checkpoint_dir_model}") +except Exception as e: + raise Exception(f"Failed to create directory {checkpoint_dir_model}: {e}") # COMMAND ---------- -config.log_path=log_path +output_dir = os.path.join(output_dir, run_timestamp, jobrunid, str(task_index)) +output_dir # COMMAND ---------- @@ -116,18 +105,20 @@ from mlflow_export_import.bulk.export_models import export_models export_models( - model_names = models, + model_names = model_file_name, output_dir = output_dir, stages = stages, export_latest_versions = export_latest_versions, - export_all_runs = export_all_runs, - export_version_model = export_version_model, + export_all_runs = True, + export_version_model = False, export_permissions = export_permissions, export_deleted_runs = export_deleted_runs, - notebook_formats = notebook_formats, - use_threads = use_threads, + notebook_formats = ['SOURCE'], + use_threads = True, task_index = task_index, - num_tasks = num_tasks + num_tasks = num_tasks, + checkpoint_dir_experiment = checkpoint_dir_experiment, + checkpoint_dir_model = checkpoint_dir_model ) @@ -141,6 +132,14 @@ # COMMAND ---------- +dbfs_log_path = f"{output_dir}/export_all_{task_index}.log" +if dbfs_log_path.startswith("/Workspace"): + dbfs_log_path=dbfs_log_path.replace("/Workspace","file:/Workspace") +dbfs_log_path = dbfs_log_path.replace("/dbfs","dbfs:") +dbfs_log_path + +# COMMAND ---------- + dbutils.fs.cp(f"file:{log_path}", dbfs_log_path) # COMMAND ---------- diff --git a/databricks_notebooks/bulk/Import_Registered_Models.py b/databricks_notebooks/bulk/Import_Registered_Models.py index d7186797..c2f60d04 100644 --- a/databricks_notebooks/bulk/Import_Registered_Models.py +++ b/databricks_notebooks/bulk/Import_Registered_Models.py @@ -155,7 +155,7 @@ # COMMAND ---------- -curr_timestamp = datetime.now().strftime("%Y-%m-%dT%H:%M:%S") +curr_timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S") dbfs_log_path = f"{input_dir}/Import_Registered_Models_{task_index}_{curr_timestamp}.log" if dbfs_log_path.startswith("/Workspace"): diff --git a/databricks_notebooks/bulk/master_Export_Registered_Models.py b/databricks_notebooks/bulk/master_Export_Registered_Models.py index e81a33da..f5da68b0 100644 --- a/databricks_notebooks/bulk/master_Export_Registered_Models.py +++ b/databricks_notebooks/bulk/master_Export_Registered_Models.py @@ -4,10 +4,14 @@ # COMMAND ---------- -dbutils.widgets.text("01. Models", "") -models = dbutils.widgets.get("01. Models") +dbutils.widgets.removeAll() -dbutils.widgets.text("02. Output directory", "/Workspace/Users/birbal.das@databricks.com/logs") +# COMMAND ---------- + +dbutils.widgets.text("01. model_file_name", "") +model_file_name = dbutils.widgets.get("01. model_file_name") + +dbutils.widgets.text("02. Output directory", "/dbfs/mnt/") output_dir = dbutils.widgets.get("02. 
Output directory") output_dir = output_dir.replace("dbfs:","/dbfs") @@ -17,48 +21,52 @@ dbutils.widgets.dropdown("04. Export latest versions","no",["yes","no"]) export_latest_versions = dbutils.widgets.get("04. Export latest versions") == "yes" -dbutils.widgets.dropdown("05. Export all runs","no",["yes","no"]) -export_all_runs = dbutils.widgets.get("05. Export all runs") == "yes" - -dbutils.widgets.dropdown("06. Export permissions","no",["yes","no"]) -export_permissions = dbutils.widgets.get("06. Export permissions") == "yes" - -dbutils.widgets.dropdown("07. Export deleted runs","no",["yes","no"]) -export_deleted_runs = dbutils.widgets.get("07. Export deleted runs") == "yes" - -dbutils.widgets.dropdown("08. Export version MLflow model","no",["yes","no"]) # TODO -export_version_model = dbutils.widgets.get("08. Export version MLflow model") == "yes" +dbutils.widgets.dropdown("05. Export permissions","no",["yes","no"]) +export_permissions = dbutils.widgets.get("05. Export permissions") == "yes" -# notebook_formats = get_notebook_formats("09") +dbutils.widgets.dropdown("06. Export deleted runs","no",["yes","no"]) +export_deleted_runs = dbutils.widgets.get("06. Export deleted runs") == "yes" -dbutils.widgets.multiselect("09. Notebook formats", "SOURCE", [ "SOURCE", "DBC", "HTML", "JUPYTER" ]) -notebook_formats = dbutils.widgets.get("09. Notebook formats") - -dbutils.widgets.dropdown("10. Use threads","no",["yes","no"]) -use_threads = dbutils.widgets.get("10. Use threads") == "yes" - - -dbutils.widgets.text("11. num_tasks", "") -num_tasks = dbutils.widgets.get("11. num_tasks") +dbutils.widgets.text("07. num_tasks", "1") +num_tasks = dbutils.widgets.get("07. num_tasks") import os os.environ["OUTPUT_DIR"] = output_dir -print("models:", models) +print("model_file_name:", model_file_name) print("output_dir:", output_dir) print("stages:", stages) print("export_latest_versions:", export_latest_versions) -print("export_all_runs:", export_all_runs) print("export_permissions:", export_permissions) print("export_deleted_runs:", export_deleted_runs) -print("export_version_model:", export_version_model) -print("notebook_formats:", notebook_formats) -print("use_threads:", use_threads) print("num_tasks:", num_tasks) # COMMAND ---------- +if not output_dir: + raise ValueError("output_dir cannot be empty") +if not output_dir.startswith("/dbfs/mnt"): + raise ValueError("output_dir must start with /dbfs/mnt") +if not num_tasks: + raise ValueError("num_tasks cannot be empty") +if not num_tasks.isdigit(): + raise ValueError("num_tasks must be a number") + +# COMMAND ---------- + +if model_file_name: + if not model_file_name.endswith(".txt"): + raise ValueError("model_file_name must end with .txt if not empty") + if not model_file_name.startswith("/dbfs"): + raise ValueError("model_file_name must start with /dbfs if not empty") +else: + model_file_name = "all" + +model_file_name + +# COMMAND ---------- + DATABRICKS_INSTANCE=dbutils.notebook.entry_point.getDbutils().notebook().getContext().tags().get('browserHostName').getOrElse(None) DATABRICKS_INSTANCE = f"https://{DATABRICKS_INSTANCE}" DATABRICKS_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None) @@ -66,77 +74,38 @@ driver_node_type = "Standard_D4ds_v5" worker_node_type = "Standard_D4ds_v5" -def create_multi_task_job_json(models, output_dir, stages, export_latest_versions, export_all_runs, export_permissions, export_deleted_runs, export_version_model, notebook_formats, use_threads, num_tasks): +def 
create_multi_task_job_json(): tasks = [] - - if models.lower() == "all" or models.endswith(".txt"): - for i in range(1, int(num_tasks)+1): - task = { - "task_key": f"task_{i}", - "description": f"Bir Task for param1 = {i}", - "new_cluster": { - "spark_version": "15.4.x-cpu-ml-scala2.12", - "node_type_id": worker_node_type, - "driver_node_type_id": driver_node_type, - "num_workers": 1, - "data_security_mode": "SINGLE_USER", - "runtime_engine": "STANDARD" - }, - "notebook_task": { - "notebook_path": "/Workspace/Users/birbal.das@databricks.com/bir-mlflow-export-import/databricks_notebooks/bulk/Export_Registered_Models", - "base_parameters": { - "models" : models, - "output_dir" : output_dir, - "stages" : stages, - "export_latest_versions" : export_latest_versions, - "export_all_runs" : export_all_runs, - "export_permissions" : export_permissions, - "export_deleted_runs" : export_deleted_runs, - "export_version_model" : export_version_model, - "notebook_formats" : notebook_formats, - "use_threads" : use_threads, - "task_index": i, - "num_tasks" : num_tasks, - "run_timestamp" : "{{job.start_time.iso_date}}-jobid-{{job.id}}-jobrunid-{{job.run_id}}" - } - } + for i in range(1, int(num_tasks)+1): + task = { + "task_key": f"task_{i}", + "description": f"Bir Task for param1 = {i}", + "new_cluster": { + "spark_version": "15.4.x-cpu-ml-scala2.12", + "node_type_id": worker_node_type, + "driver_node_type_id": driver_node_type, + "num_workers": 1, + "data_security_mode": "SINGLE_USER", + "runtime_engine": "STANDARD" + }, + "notebook_task": { + "notebook_path": "/Workspace/Users/birbal.das@databricks.com/AA_sephora/birnew-mlflow-export-import/databricks_notebooks/bulk/Export_Registered_Models", + "base_parameters": { + "model_file_name" : model_file_name, + "output_dir" : output_dir, + "stages" : stages, + "export_latest_versions" : export_latest_versions, + "export_permissions" : export_permissions, + "export_deleted_runs" : export_deleted_runs, + "task_index": i, + "num_tasks" : num_tasks, + "run_timestamp": "{{job.start_time.iso_date}}-ExportModels-jobid-{{job.id}}", + "jobrunid": "jobrunid-{{job.run_id}}" } - tasks.append(task) - else: - task = { - "task_key": f"task", - "description": f"Bir Task for param1 ", - "new_cluster": { - "spark_version": "15.4.x-cpu-ml-scala2.12", - "node_type_id": worker_node_type, - "driver_node_type_id": driver_node_type, - "num_workers": 1, - "data_security_mode": "SINGLE_USER", - "runtime_engine": "STANDARD" - }, - "notebook_task": { - "notebook_path": "/Workspace/Users/birbal.das@databricks.com/bir-mlflow-export-import/databricks_notebooks/bulk/Export_Registered_Models", - "base_parameters": { - "models" : models, - "output_dir" : output_dir, - "stages" : stages, - "export_latest_versions" : export_latest_versions, - "export_all_runs" : export_all_runs, - "export_permissions" : export_permissions, - "export_deleted_runs" : export_deleted_runs, - "export_version_model" : export_version_model, - "notebook_formats" : notebook_formats, - "use_threads" : use_threads, - "task_index": "-1", - "num_tasks" : "-1", - "run_timestamp" : "{{job.start_time.iso_date}}-jobid-{{job.id}}-jobrunid-{{job.run_id}}" - } - } } - tasks.append(task) - - + } + tasks.append(task) job_json = { "name": "Export_Registered_Models_job", @@ -147,7 +116,7 @@ def create_multi_task_job_json(models, output_dir, stages, export_latest_version return job_json def submit_databricks_job(): - job_payload = create_multi_task_job_json(models, output_dir, stages, export_latest_versions, export_all_runs, 
export_permissions, export_deleted_runs, export_version_model, notebook_formats, use_threads, num_tasks) + job_payload = create_multi_task_job_json() headers = { "Authorization": f"Bearer {DATABRICKS_TOKEN}", @@ -171,3 +140,7 @@ def submit_databricks_job(): # COMMAND ---------- submit_databricks_job() + +# COMMAND ---------- + + diff --git a/databricks_notebooks/bulk/master_Export_all.py b/databricks_notebooks/bulk/master_Export_all.py index 93c62557..abb5ee63 100644 --- a/databricks_notebooks/bulk/master_Export_all.py +++ b/databricks_notebooks/bulk/master_Export_all.py @@ -5,6 +5,10 @@ # COMMAND ---------- +dbutils.widgets.removeAll() + +# COMMAND ---------- + dbutils.widgets.text("1. Output directory", "") output_dir = dbutils.widgets.get("1. Output directory") output_dir = output_dir.replace("dbfs:","/dbfs") @@ -21,8 +25,11 @@ dbutils.widgets.dropdown("5. Export permissions","no",["yes","no"]) export_permissions = dbutils.widgets.get("5. Export permissions") == "yes" -dbutils.widgets.text("11. num_tasks", "") -num_tasks = dbutils.widgets.get("11. num_tasks") +dbutils.widgets.text("6. num_tasks", "") +num_tasks = dbutils.widgets.get("6. num_tasks") + +dbutils.widgets.text("7. model_file_name", "") +model_file_name = dbutils.widgets.get("7. model_file_name") print("output_dir:", output_dir) print("stages:", stages) @@ -30,6 +37,30 @@ print("run_start_date:", run_start_date) print("export_permissions:", export_permissions) print("num_tasks:", num_tasks) +print("model_file_name:", model_file_name) + +# COMMAND ---------- + +if not output_dir: + raise ValueError("output_dir cannot be empty") +if not output_dir.startswith("/dbfs/mnt"): + raise ValueError("output_dir must start with /dbfs/mnt") +if not num_tasks: + raise ValueError("num_tasks cannot be empty") +if not num_tasks.isdigit(): + raise ValueError("num_tasks must be a number") + +# COMMAND ---------- + +if model_file_name: + if not model_file_name.endswith(".txt"): + raise ValueError("model_file_name must end with .txt if not empty") + if not model_file_name.startswith("/dbfs"): + raise ValueError("model_file_name must start with /dbfs if not empty") +else: + model_file_name = "all" + +model_file_name # COMMAND ---------- @@ -55,7 +86,7 @@ def create_multi_task_job_json(): "runtime_engine": "STANDARD" }, "notebook_task": { - "notebook_path": "/Workspace/Users/birbal.das@databricks.com/mlflowimport/bir-mlflow-export-import/databricks_notebooks/bulk/Export_All", + "notebook_path": "/Workspace/Users/birbal.das@databricks.com/AA_sephora/birnew-mlflow-export-import/databricks_notebooks/bulk/Export_All", "base_parameters": { "output_dir": output_dir, "stages": stages, @@ -64,8 +95,9 @@ def create_multi_task_job_json(): "export_permissions": export_permissions, "task_index": i, "num_tasks": num_tasks, - "run_timestamp": "{{job.start_time.iso_date}}-Export-jobid-{{job.id}}", - "jobrunid": "jobrunid-{{job.run_id}}" + "run_timestamp": "{{job.start_time.iso_date}}-ExportAll-jobid-{{job.id}}", + "jobrunid": "jobrunid-{{job.run_id}}", + "model_file_name": model_file_name } } } @@ -104,3 +136,7 @@ def submit_databricks_job(): # COMMAND ---------- submit_databricks_job() + +# COMMAND ---------- + + diff --git a/databricks_notebooks/bulk/master_Import_Registered_Models.py b/databricks_notebooks/bulk/master_Import_Registered_Models.py index 774e2404..188f2a1a 100644 --- a/databricks_notebooks/bulk/master_Import_Registered_Models.py +++ b/databricks_notebooks/bulk/master_Import_Registered_Models.py @@ -81,13 +81,6 @@ # COMMAND ---------- -if 
input_dir.startswith("/Workspace"): - input_dir=input_dir.replace("/Workspace","file:/Workspace") - -input_dir - -# COMMAND ---------- - DATABRICKS_INSTANCE=dbutils.notebook.entry_point.getDbutils().notebook().getContext().tags().get('browserHostName').getOrElse(None) DATABRICKS_INSTANCE = f"https://{DATABRICKS_INSTANCE}" DATABRICKS_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None) @@ -110,7 +103,7 @@ def create_multi_task_job_json(): "runtime_engine": "STANDARD" }, "notebook_task": { - "notebook_path": "/Workspace/Users/birbal.das@databricks.com/test_final/bir-mlflow-export-import/databricks_notebooks/bulk/Import_Registered_Models", + "notebook_path": "/Workspace/Users/birbal.das@databricks.com/AA_sephora/birnew-mlflow-export-import/databricks_notebooks/bulk/Import_Registered_Models", "base_parameters": { "input_dir": os.path.join(input_dir,str(i)), "target_model_registry": target_model_registry, @@ -159,3 +152,7 @@ def submit_databricks_job(): # COMMAND ---------- submit_databricks_job() + +# COMMAND ---------- + + diff --git a/mlflow_export_import/bulk/bulk_utils.py b/mlflow_export_import/bulk/bulk_utils.py index d5b336ba..47910332 100644 --- a/mlflow_export_import/bulk/bulk_utils.py +++ b/mlflow_export_import/bulk/bulk_utils.py @@ -26,8 +26,12 @@ def _get_list(names, func_list, task_index=None, num_tasks=None): #birbal update return [ x for x in func_list() if x.startswith(prefix) ] else: return names.split(",") - else: + + elif isinstance(names, dict): #birbal added return names + + else: + return get_subset_list(names, task_index, num_tasks) #birbal updated @@ -38,7 +42,6 @@ def list_entities(): return _get_list(experiment_ids, list_entities) -# def get_model_names(mlflow_client, model_names): def get_model_names(mlflow_client, model_names,task_index=None,num_tasks=None): #birbal updated def list_entities(): return [ model.name for model in SearchRegisteredModelsIterator(mlflow_client) ] diff --git a/mlflow_export_import/bulk/export_all.py b/mlflow_export_import/bulk/export_all.py index 2ba55b3c..a4088e9c 100644 --- a/mlflow_export_import/bulk/export_all.py +++ b/mlflow_export_import/bulk/export_all.py @@ -47,7 +47,8 @@ def export_all( task_index = None, num_tasks = None, checkpoint_dir_experiment = None, - checkpoint_dir_model = None + checkpoint_dir_model = None, + model_names = None ): mlflow_client = mlflow_client or create_mlflow_client() @@ -58,7 +59,8 @@ def export_all( start_time = time.time() res_models = export_models( mlflow_client = mlflow_client, - model_names = "all", + # model_names = "all", + model_names = model_names, output_dir = output_dir, stages = stages, export_latest_versions = export_latest_versions, @@ -75,41 +77,43 @@ def export_all( ) - all_exps = SearchExperimentsIterator(mlflow_client) - all_exps = list(set(all_exps)) - all_exp_names = [ exp.name for exp in all_exps ] - _logger.info(f"TOTAL WORKSPACE EXPERIMENT COUNT: {len(all_exp_names)}") + res_exps = None + if not model_names.endswith('.txt'): + all_exps = SearchExperimentsIterator(mlflow_client) + all_exps = list(set(all_exps)) + all_exp_names = [ exp.name for exp in all_exps ] + _logger.info(f"TOTAL WORKSPACE EXPERIMENT COUNT: {len(all_exp_names)}") - all_model_exp_names=get_experiments_name_of_models(mlflow_client,model_names = "all") + all_model_exp_names=get_experiments_name_of_models(mlflow_client,model_names = "all") - all_model_exp_names = list(set(all_model_exp_names)) - _logger.info(f"TOTAL WORKSPACE MODEL EXPERIMENT COUNT: 
{len(all_model_exp_names)}") + all_model_exp_names = list(set(all_model_exp_names)) + _logger.info(f"TOTAL WORKSPACE MODEL EXPERIMENT COUNT: {len(all_model_exp_names)}") - remaining_exp_names = list(set(all_exp_names) - set(all_model_exp_names)) - _logger.info(f"TOTAL WORKSPACE EXPERIMENT COUNT WITH NO MODEL: {len(remaining_exp_names)}") + remaining_exp_names = list(set(all_exp_names) - set(all_model_exp_names)) + _logger.info(f"TOTAL WORKSPACE EXPERIMENT COUNT WITH NO MODEL: {len(remaining_exp_names)}") - remaining_exp_names_subset = bulk_utils.get_subset_list(remaining_exp_names, task_index, num_tasks) #birbal added - _logger.info(f"TOTAL WORKSPACE EXPERIMENT COUNT WITH NO MODEL FOR TASK_INDEX={task_index}: {len(remaining_exp_names_subset)}") #birbal added + remaining_exp_names_subset = bulk_utils.get_subset_list(remaining_exp_names, task_index, num_tasks) #birbal added + _logger.info(f"TOTAL WORKSPACE EXPERIMENT COUNT WITH NO MODEL FOR TASK_INDEX={task_index}: {len(remaining_exp_names_subset)}") #birbal added - exps_and_runs = get_experiment_runs_dict_from_names(mlflow_client, remaining_exp_names_subset) #birbal added + exps_and_runs = get_experiment_runs_dict_from_names(mlflow_client, remaining_exp_names_subset) #birbal added - exps_and_runs = filter_unprocessed_objects(checkpoint_dir_experiment,"experiments",exps_and_runs) - _logger.info(f"AFTER FILTERING OUT THE PROCESSED EXPERIMENTS FROM CHECKPOINT, TOTAL REMAINING COUNT: {len(exps_and_runs)}") + exps_and_runs = filter_unprocessed_objects(checkpoint_dir_experiment,"experiments",exps_and_runs) + _logger.info(f"AFTER FILTERING OUT THE PROCESSED EXPERIMENTS FROM CHECKPOINT, TOTAL REMAINING COUNT: {len(exps_and_runs)}") - res_exps = export_experiments( - mlflow_client = mlflow_client, - experiments = exps_and_runs, #birbal added - output_dir = os.path.join(output_dir,"experiments"), - export_permissions = export_permissions, - run_start_time = run_start_time, - export_deleted_runs = export_deleted_runs, - notebook_formats = notebook_formats, - use_threads = use_threads, - task_index = task_index, #birbal added - checkpoint_dir_experiment = checkpoint_dir_experiment #birbal - ) + res_exps = export_experiments( + mlflow_client = mlflow_client, + experiments = exps_and_runs, #birbal added + output_dir = os.path.join(output_dir,"experiments"), + export_permissions = export_permissions, + run_start_time = run_start_time, + export_deleted_runs = export_deleted_runs, + notebook_formats = notebook_formats, + use_threads = use_threads, + task_index = task_index, #birbal added + checkpoint_dir_experiment = checkpoint_dir_experiment #birbal + ) duration = round(time.time() - start_time, 1) info_attr = { "options": { diff --git a/mlflow_export_import/common/ws_permissions_utils.py b/mlflow_export_import/common/ws_permissions_utils.py index 30de5828..b7f5e963 100644 --- a/mlflow_export_import/common/ws_permissions_utils.py +++ b/mlflow_export_import/common/ws_permissions_utils.py @@ -32,7 +32,7 @@ def _call_get(dbx_client, resource): try: return dbx_client.get(resource) except MlflowExportImportException as e: - _logger.error(e.kwargs) + _logger.error(f"Error while retrieving permissions with endpoint {resource}. Most probably due to notebook scoped experiment. Valid experiment type (mlflow.experimentType) should be MLFLOW_EXPERIMENT, not NOTEBOOK. Verify experimentType using SDK. 
Error: {e.kwargs}") #birbal added return {} diff --git a/mlflow_export_import/model/export_model.py b/mlflow_export_import/model/export_model.py index 3d3a66f2..c7e3fef9 100644 --- a/mlflow_export_import/model/export_model.py +++ b/mlflow_export_import/model/export_model.py @@ -90,23 +90,25 @@ def export_model( result_queue.put(err_msg) #birbal added return False, model_name except Exception as e: - _logger.error({ "model": model_name, "Exception": e }) + _logger.error({ "model": model_name, "Exception": str(e) }) err_msg = { "model": model_name, "status": "failed","Exception": str(e) } #birbal string casted result_queue.put(err_msg) #birbal added - import traceback - traceback.print_exc() + # import traceback + # traceback.print_exc() return False, model_name def _export_model(mlflow_client, model_name, output_dir, opts, result_queue = None): #birbal added result_queue ori_versions = model_utils.list_model_versions(mlflow_client, model_name, opts.export_latest_versions) - _logger.info(f"TOTAL MODELS VERSIONS TO EXPORT: {len(ori_versions)}") #birbal added + _logger.info(f"TOTAL MODELS VERSIONS TO EXPORT FOR MODEL {model_name}: {len(ori_versions)}") #birbal added msg = "latest" if opts.export_latest_versions else "all" _logger.info(f"Exporting model '{model_name}': found {len(ori_versions)} '{msg}' versions") model = model_utils.get_registered_model(mlflow_client, model_name, opts.export_permissions) + versions, failed_versions = _export_versions(mlflow_client, model, ori_versions, output_dir, opts, result_queue) #birbal added result_queue + _adjust_model(model, versions) info_attr = { @@ -118,6 +120,8 @@ def _export_model(mlflow_client, model_name, output_dir, opts, result_queue = No "export_latest_versions": opts.export_latest_versions, "export_permissions": opts.export_permissions } + + try: #birbal added _model = { "registered_model": model } io_utils.write_export_file(output_dir, "model.json", __file__, _model, info_attr) @@ -140,6 +144,7 @@ def _export_versions(mlflow_client, model_dct, versions, output_dir, opts, resul output_versions, failed_versions = ([], []) for j,vr in enumerate(versions): if not model_utils.is_unity_catalog_model(model_dct["name"]) and vr.current_stage and (len(opts.stages) > 0 and not vr.current_stage.lower() in opts.stages): + _logger.info(f"_export_version skipped") #birbal continue if len(opts.versions) > 0 and not vr.version in opts.versions: continue diff --git a/mlflow_export_import/model_version/import_model_version.py b/mlflow_export_import/model_version/import_model_version.py index 9428829a..7602a943 100644 --- a/mlflow_export_import/model_version/import_model_version.py +++ b/mlflow_export_import/model_version/import_model_version.py @@ -119,27 +119,32 @@ def _import_model_version( # The client's tracking_uri is not honored. Instead MlflowClient.create_model_version() # seems to use mlflow.tracking_uri internally to download run artifacts for UC models. 
_logger.info(f"Importing model version '{model_name}'") - with MlflowTrackingUriTweak(mlflow_client): - dst_vr = mlflow_client.create_model_version( - name = model_name, - source = dst_source, - run_id = dst_run_id, - description = src_vr.get("description"), - tags = tags - ) - - if import_stages_and_aliases: - for alias in src_vr.get("aliases",[]): - mlflow_client.set_registered_model_alias(dst_vr.name, alias, dst_vr.version) - - if not model_utils.is_unity_catalog_model(model_name): - src_current_stage = src_vr["current_stage"] - if src_current_stage and src_current_stage != "None": # fails for Databricks but not OSS - mlflow_client.transition_model_version_stage(model_name, dst_vr.version, src_current_stage) - - dur = format_seconds(time.time()-start_time) - _logger.info(f"Imported model version '{model_name}/{dst_vr.version}' in {dur}") - return mlflow_client.get_model_version(dst_vr.name, dst_vr.version) + + try: #birbal added + with MlflowTrackingUriTweak(mlflow_client): + dst_vr = mlflow_client.create_model_version( + name = model_name, + source = dst_source, + run_id = dst_run_id, + description = src_vr.get("description"), + tags = tags + ) + + if import_stages_and_aliases: + for alias in src_vr.get("aliases",[]): + mlflow_client.set_registered_model_alias(dst_vr.name, alias, dst_vr.version) + + if not model_utils.is_unity_catalog_model(model_name): + src_current_stage = src_vr["current_stage"] + if src_current_stage and src_current_stage != "None": # fails for Databricks but not OSS + mlflow_client.transition_model_version_stage(model_name, dst_vr.version, src_current_stage) + + dur = format_seconds(time.time()-start_time) + _logger.info(f"Imported model version '{model_name}/{dst_vr.version}' in {dur}") + return mlflow_client.get_model_version(dst_vr.name, dst_vr.version) + + except Exception as e: + _logger.error(f"Error creating model version of {model_name}. 
Error: {str(e)}") def _get_model_path(src_vr): diff --git a/mlflow_export_import/run/export_run.py b/mlflow_export_import/run/export_run.py index f7a5a8e8..dd6a9b90 100644 --- a/mlflow_export_import/run/export_run.py +++ b/mlflow_export_import/run/export_run.py @@ -21,7 +21,7 @@ from mlflow_export_import.client.client_utils import create_mlflow_client, create_dbx_client from mlflow_export_import.notebook.download_notebook import download_notebook -from mlflow.utils.mlflow_tags import MLFLOW_DATABRICKS_NOTEBOOK_PATH +from mlflow.utils.mlflow_tags import MLFLOW_DATABRICKS_NOTEBOOK_PATH, MLFLOW_DATABRICKS_NOTEBOOK_ID #birbal added MLFLOW_DATABRICKS_NOTEBOOK_ID MLFLOW_DATABRICKS_NOTEBOOK_REVISION_ID = "mlflow.databricks.notebookRevisionID" # NOTE: not in mlflow/utils/mlflow_tags.py _logger = utils.getLogger(__name__) @@ -58,6 +58,7 @@ def export_run( experiment_id = None try: run = mlflow_client.get_run(run_id) + dst_path = os.path.join(output_dir, "artifacts") msg = { "run_id": run.info.run_id, "dst_path": dst_path } if run.info.lifecycle_stage == "deleted" and not export_deleted_runs: @@ -94,6 +95,7 @@ def export_run( run_id = run.info.run_id, dst_path = _fs.mk_local_path(dst_path), tracking_uri = mlflow_client._tracking_client.tracking_uri) + notebook = tags.get(MLFLOW_DATABRICKS_NOTEBOOK_PATH) # export notebook as artifact @@ -101,7 +103,7 @@ def export_run( if len(notebook_formats) > 0: _export_notebook(dbx_client, output_dir, notebook, notebook_formats, run, fs) elif len(notebook_formats) > 0: - _logger.warning(f"No notebooks to export for run '{run_id}' since tag '{MLFLOW_DATABRICKS_NOTEBOOK_PATH}' is not set.") + _logger.error(f"No notebooks to export for run '{run_id}' since run tag '{MLFLOW_DATABRICKS_NOTEBOOK_PATH}' is not set.") dur = format_seconds(time.time()-start_time) _logger.info(f"Exported run in {dur}: {msg}") @@ -117,7 +119,6 @@ def export_run( except RestException as e: if raise_exception: raise e - # err_msg = { "run_id": run_id, "experiment_id": experiment_id, "RestException": e.json } #birbal commented out err_msg = { "run_id": run_id, "experiment_id": experiment_id, "RestException": str(e.json) } #birbal string casted err_msg["status"] = "failed" #birbal added @@ -133,7 +134,6 @@ def export_run( except Exception as e: if raise_exception: raise e - # err_msg = { "run_id": run_id, "experiment_id": experiment_id, "Exception": e } #birbal commented out err_msg = { "run_id": run_id, "experiment_id": experiment_id, "Exception": str(e) } #birbal string casted err_msg["status"] = "failed" #birbal added @@ -164,9 +164,9 @@ def _export_notebook(dbx_client, output_dir, notebook, notebook_formats, run, fs notebook_dir = os.path.join(output_dir, "artifacts", "notebooks") fs.mkdirs(notebook_dir) revision_id = run.data.tags.get(MLFLOW_DATABRICKS_NOTEBOOK_REVISION_ID) - if not revision_id: - _logger.warning(f"Cannot download notebook '{notebook}' for run '{run.info.run_id}' since tag '{MLFLOW_DATABRICKS_NOTEBOOK_REVISION_ID}' does not exist. Notebook is probably a Git Repo notebook.") - return + # if not revision_id: #birbal commented out. If not, it simply skips the notebook download due to missing notebook versionID which shouldn't be the case. + # _logger.warning(f"Cannot download notebook '{notebook}' for run '{run.info.run_id}' since tag '{MLFLOW_DATABRICKS_NOTEBOOK_REVISION_ID}' does not exist. 
Notebook is probably a Git Repo notebook.") + # return manifest = { MLFLOW_DATABRICKS_NOTEBOOK_PATH: run.data.tags[MLFLOW_DATABRICKS_NOTEBOOK_PATH], MLFLOW_DATABRICKS_NOTEBOOK_REVISION_ID: revision_id, diff --git a/mlflow_export_import/run/import_run.py b/mlflow_export_import/run/import_run.py index a56f156f..3dc033f0 100644 --- a/mlflow_export_import/run/import_run.py +++ b/mlflow_export_import/run/import_run.py @@ -133,7 +133,8 @@ def _upload_databricks_notebook(mlflow_client, dbx_client, input_dir, src_run_dc content = base64.b64encode(content.encode()).decode("utf-8") payload = { - "path": dst_notebook_path, + # "path": dst_notebook_path, + "path": dst_notebook_path + "_notebook", ##birbal added _notebook to fix issue with Notebook scoped experiment "language": "PYTHON", "format": format, "overwrite": True, From 28db530e37e816d804857b8fac4122ab1bdf2e4c Mon Sep 17 00:00:00 2001 From: "birbal.in25@gmail.com" Date: Wed, 13 Aug 2025 22:19:38 +0000 Subject: [PATCH 06/25] fix --- databricks_notebooks/bulk/Common.py | 20 +++++++++++++++++++ databricks_notebooks/bulk/Export_All.py | 2 ++ .../bulk/Export_All_log_parsing.py | 4 ++-- .../bulk/Import_Registered_Models.py | 4 ++-- .../bulk/master_Export_all.py | 12 ++++------- .../bulk/master_Import_Registered_Models.py | 4 ++-- mlflow_export_import/bulk/model_utils.py | 6 ++++-- mlflow_export_import/common/model_utils.py | 12 +++++------ mlflow_export_import/model/import_model.py | 1 + 9 files changed, 43 insertions(+), 22 deletions(-) diff --git a/databricks_notebooks/bulk/Common.py b/databricks_notebooks/bulk/Common.py index 76201660..23ac7574 100644 --- a/databricks_notebooks/bulk/Common.py +++ b/databricks_notebooks/bulk/Common.py @@ -25,3 +25,23 @@ def get_notebook_formats(num): notebook_formats = notebook_formats.split(",") if "" in notebook_formats: notebook_formats.remove("") return notebook_formats + +# COMMAND ---------- + +import mlflow +display([{"mlflow_version": mlflow.__version__}]) + +# COMMAND ---------- + +# MAGIC %pip install -U mlflow-skinny +# MAGIC %pip install -U git+https:///github.com/mlflow/mlflow-export-import/#egg=mlflow-export-import +# MAGIC dbutils.library.restartPython() + +# COMMAND ---------- + +import mlflow +display([{"mlflow_version": mlflow.__version__}]) + +# COMMAND ---------- + + diff --git a/databricks_notebooks/bulk/Export_All.py b/databricks_notebooks/bulk/Export_All.py index 7ea00e6d..caf493bd 100644 --- a/databricks_notebooks/bulk/Export_All.py +++ b/databricks_notebooks/bulk/Export_All.py @@ -99,12 +99,14 @@ # COMMAND ---------- log_path=f"/tmp/my.log" +dbfs_log_path = f"{output_dir}/export_all_{task_index}.log" log_path # COMMAND ---------- config.log_path=log_path config.export_or_import="export" +config.target_model_registry="unity_catalog" ## birbal...remove # COMMAND ---------- diff --git a/databricks_notebooks/bulk/Export_All_log_parsing.py b/databricks_notebooks/bulk/Export_All_log_parsing.py index 7f21f905..b7ee62e2 100644 --- a/databricks_notebooks/bulk/Export_All_log_parsing.py +++ b/databricks_notebooks/bulk/Export_All_log_parsing.py @@ -1,5 +1,5 @@ # Databricks notebook source -spark.read.parquet("dbfs:/mnt/modelnonuc/2025-06-17-Export-jobid-34179827290231/checkpoint/models/*.parquet").createOrReplaceTempView("models") +spark.read.parquet("/checkpoint/models/*.parquet").createOrReplaceTempView("models") # COMMAND ---------- @@ -11,7 +11,7 @@ # COMMAND ---------- 
-spark.read.parquet("dbfs:/mnt/modelnonuc/2025-06-17-Export-jobid-34179827290231/checkpoint/experiments").createOrReplaceTempView("experiments") +spark.read.parquet("/checkpoint/experiments").createOrReplaceTempView("experiments") # COMMAND ---------- diff --git a/databricks_notebooks/bulk/Import_Registered_Models.py b/databricks_notebooks/bulk/Import_Registered_Models.py index c2f60d04..b3254caf 100644 --- a/databricks_notebooks/bulk/Import_Registered_Models.py +++ b/databricks_notebooks/bulk/Import_Registered_Models.py @@ -81,8 +81,8 @@ if not input_dir: raise ValueError("input_dir cannot be empty") -if not input_dir.startswith("/dbfs/mnt"): - raise ValueError("input_dir must start with /dbfs/mnt") +# if not input_dir.startswith("/dbfs/mnt"): +# raise ValueError("input_dir must start with /dbfs/mnt") if not task_index: raise ValueError("task_index cannot be empty") if not task_index.isdigit(): diff --git a/databricks_notebooks/bulk/master_Export_all.py b/databricks_notebooks/bulk/master_Export_all.py index abb5ee63..7e5125a0 100644 --- a/databricks_notebooks/bulk/master_Export_all.py +++ b/databricks_notebooks/bulk/master_Export_all.py @@ -5,10 +5,6 @@ # COMMAND ---------- -dbutils.widgets.removeAll() - -# COMMAND ---------- - dbutils.widgets.text("1. Output directory", "") output_dir = dbutils.widgets.get("1. Output directory") output_dir = output_dir.replace("dbfs:","/dbfs") @@ -43,8 +39,8 @@ if not output_dir: raise ValueError("output_dir cannot be empty") -if not output_dir.startswith("/dbfs/mnt"): - raise ValueError("output_dir must start with /dbfs/mnt") +# if not output_dir.startswith("/dbfs/mnt"): +# raise ValueError("output_dir must start with /dbfs/mnt") if not num_tasks: raise ValueError("num_tasks cannot be empty") if not num_tasks.isdigit(): @@ -55,8 +51,8 @@ if model_file_name: if not model_file_name.endswith(".txt"): raise ValueError("model_file_name must end with .txt if not empty") - if not model_file_name.startswith("/dbfs"): - raise ValueError("model_file_name must start with /dbfs if not empty") + # if not model_file_name.startswith("/dbfs"): + # raise ValueError("model_file_name must start with /dbfs if not empty") else: model_file_name = "all" diff --git a/databricks_notebooks/bulk/master_Import_Registered_Models.py b/databricks_notebooks/bulk/master_Import_Registered_Models.py index 188f2a1a..15873797 100644 --- a/databricks_notebooks/bulk/master_Import_Registered_Models.py +++ b/databricks_notebooks/bulk/master_Import_Registered_Models.py @@ -50,8 +50,8 @@ if not input_dir: raise ValueError("input_dir cannot be empty") -if not input_dir.startswith("/dbfs/mnt"): - raise ValueError("input_dir must start with /dbfs/mnt") +# if not input_dir.startswith("/dbfs/mnt"): +# raise ValueError("input_dir must start with /dbfs/mnt") if not num_tasks: raise ValueError("num_tasks cannot be empty") if not num_tasks.isdigit(): diff --git a/mlflow_export_import/bulk/model_utils.py b/mlflow_export_import/bulk/model_utils.py index 1939d4b0..c6699b4b 100644 --- a/mlflow_export_import/bulk/model_utils.py +++ b/mlflow_export_import/bulk/model_utils.py @@ -15,7 +15,8 @@ def get_experiments_runs_of_models(client, model_names, task_index=None, num_tas _logger.info(f" {model_name}") exps_and_runs = {} for model_name in model_names: - versions = SearchModelVersionsIterator(client, filter=f""" name="{model_name}" """) #birbal.Changed from "name='{model_name}'" to handle models name with single quote + versions = SearchModelVersionsIterator(client, filter=f"name='{model_name}'" ) + # versions 
= SearchModelVersionsIterator(client, filter=f""" name="{model_name}" """) #birbal.Changed from "name='{model_name}'" to handle models name with single quote for vr in versions: try: run = client.get_run(vr.run_id) @@ -48,7 +49,8 @@ def get_experiments_name_of_models(client, model_names): model_names = bulk_utils.get_model_names(client, model_names) experiment_name_list = [] for model_name in model_names: - versions = SearchModelVersionsIterator(client, filter=f""" name="{model_name}" """) #birbal. Fix for models name with single quote + versions = SearchModelVersionsIterator(client, filter=f"name='{model_name}'") + # versions = SearchModelVersionsIterator(client, filter=f""" name="{model_name}" """) #birbal. Fix for models name with single quote for vr in versions: try: run = client.get_run(vr.run_id) diff --git a/mlflow_export_import/common/model_utils.py b/mlflow_export_import/common/model_utils.py index a55bac18..db03765c 100644 --- a/mlflow_export_import/common/model_utils.py +++ b/mlflow_export_import/common/model_utils.py @@ -57,8 +57,8 @@ def delete_model(client, model_name, sleep_time=5): Delete a registered model and all its versions. """ try: - # versions = SearchModelVersionsIterator(client, filter=f"name='{model_name}'") - versions = SearchModelVersionsIterator(client, filter=f""" name="{model_name}" """) #birbal added + versions = SearchModelVersionsIterator(client, filter=f"name='{model_name}'") + # versions = SearchModelVersionsIterator(client, filter=f""" name="{model_name}" """) #birbal added _logger.info(f"Deleting model '{model_name}' and its versions") for vr in versions: msg = utils.get_obj_key_values(vr, [ "name", "version", "current_stage", "status", "run_id" ]) @@ -79,16 +79,16 @@ def list_model_versions(client, model_name, get_latest_versions=False): List 'all' or the 'latest' versions of registered model. 
""" if is_unity_catalog_model(model_name): - # versions = SearchModelVersionsIterator(client, filter=f"name='{model_name}'") - versions = SearchModelVersionsIterator(client, filter=f""" name="{model_name}" """) #birbal added + versions = SearchModelVersionsIterator(client, filter=f"name='{model_name}'") + # versions = SearchModelVersionsIterator(client, filter=f""" name="{model_name}" """) #birbal added # JIRA: ES-834105 - UC-ML MLflow search_registered_models and search_model_versions do not return tags and aliases - 2023-08-21 return [ client.get_model_version(vr.name, vr.version) for vr in versions ] else: if get_latest_versions: return client.get_latest_versions(model_name) else: - # return list(SearchModelVersionsIterator(client, filter=f"name='{model_name}'")) - return list(SearchModelVersionsIterator(client, filter=f""" name="{model_name}" """)) #birbal added + return list(SearchModelVersionsIterator(client, filter=f"name='{model_name}'")) + # return list(SearchModelVersionsIterator(client, filter=f""" name="{model_name}" """)) #birbal added def search_model_versions(client, filter): diff --git a/mlflow_export_import/model/import_model.py b/mlflow_export_import/model/import_model.py index d8f894de..8bdb9a7d 100644 --- a/mlflow_export_import/model/import_model.py +++ b/mlflow_export_import/model/import_model.py @@ -269,6 +269,7 @@ def import_model(self, _logger.info(f"Importing {len(model_dct['versions'])} versions:") for vr in model_dct["versions"]: src_run_id = vr["run_id"] + _logger.info(f"self.run_info_map is {self.run_info_map}") ##birbal...need to remove dst_run_info = self.run_info_map.get(src_run_id, None) if not dst_run_info: msg = { "model": model_name, "version": vr["version"], "stage": vr["current_stage"], "run_id": src_run_id } From d7728cce0dc326b14407b04f8955a72fcf73d867 Mon Sep 17 00:00:00 2001 From: "birbal.in25@gmail.com" Date: Fri, 15 Aug 2025 22:38:46 +0000 Subject: [PATCH 07/25] Fix- Revision id format change --- mlflow_export_import/notebook/download_notebook.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mlflow_export_import/notebook/download_notebook.py b/mlflow_export_import/notebook/download_notebook.py index 31cdb7a0..1ada3829 100644 --- a/mlflow_export_import/notebook/download_notebook.py +++ b/mlflow_export_import/notebook/download_notebook.py @@ -28,14 +28,15 @@ def _download_notebook(notebook_workspace_path, output_dir, format, extension, r "format": format } if revision_id: - params ["revision"] = { "revision_timestamp": revision_id } # NOTE: not publicly documented + # params ["revision"] = { "revision_timestamp": revision_id } # NOTE: not publicly documented + params ["revision.revision_timestamp"] = revision_id # Birbal. Above longer supports due to change in format. Format change was done sometime in August 2025. notebook_name = os.path.basename(notebook_workspace_path) try: rsp = dbx_client._get("workspace/export", json.dumps(params)) notebook_path = os.path.join(output_dir, f"{notebook_name}.{extension}") io_utils.write_file(notebook_path, rsp.content) except MlflowExportImportException as e: - _logger.warning(f"Cannot download notebook '{notebook_workspace_path}'. {e}") + _logger.error(f"Cannot download notebook '{notebook_workspace_path}'. 
{e}") @click.command() From 6ce581529b1427943e0d1a92579e0a55ad7376f7 Mon Sep 17 00:00:00 2001 From: "birbal.in25@gmail.com" Date: Sat, 16 Aug 2025 01:51:51 +0000 Subject: [PATCH 08/25] fix for notebook scoped experiment --- mlflow_export_import/run/import_run.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlflow_export_import/run/import_run.py b/mlflow_export_import/run/import_run.py index 3dc033f0..f6f8f855 100644 --- a/mlflow_export_import/run/import_run.py +++ b/mlflow_export_import/run/import_run.py @@ -128,13 +128,12 @@ def _upload_databricks_notebook(mlflow_client, dbx_client, input_dir, src_run_dc with open(notebook_path, "r", encoding="utf-8") as f: content = f.read() - dst_notebook_path = src_notebook_path #birbal added + dst_notebook_path = src_notebook_path + "_notebook" ##birbal added _notebook to fix issue with Notebook scoped experiment content = base64.b64encode(content.encode()).decode("utf-8") payload = { - # "path": dst_notebook_path, - "path": dst_notebook_path + "_notebook", ##birbal added _notebook to fix issue with Notebook scoped experiment + "path": dst_notebook_path, "language": "PYTHON", "format": format, "overwrite": True, @@ -190,6 +189,7 @@ def update_notebook_lineage(mlflow_client,run_id,dst_notebook_path): #birbal mlflow_client.set_tag(run_id, "mlflow.source.type", "NOTEBOOK") mlflow_client.set_tag(run_id, "mlflow.databricks.notebookID", notebook_id) mlflow_client.set_tag(run_id, "mlflow.databricks.workspaceURL", host) + mlflow_client.set_tag(run_id, "mlflow.databricks.notebookPath", dst_notebook_path) def _import_inputs(http_client, src_run_dct, run_id): inputs = src_run_dct.get("inputs") From 1490f57b2381d7487a1c23c737bc387b85f0f92f Mon Sep 17 00:00:00 2001 From: "birbal.in25@gmail.com" Date: Sat, 16 Aug 2025 02:54:14 +0000 Subject: [PATCH 09/25] Skip notebook import if same workspace --- mlflow_export_import/run/import_run.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/mlflow_export_import/run/import_run.py b/mlflow_export_import/run/import_run.py index f6f8f855..359539e6 100644 --- a/mlflow_export_import/run/import_run.py +++ b/mlflow_export_import/run/import_run.py @@ -107,8 +107,14 @@ def _mk_ex(src_run_dct, dst_run_id, exp_name): raise MlflowExportImportException(e, f"Importing run {run_id} of experiment '{exp.name}' failed") if utils.calling_databricks(): #birbal added - _upload_databricks_notebook(mlflow_client, dbx_client, input_dir, src_run_dct, dst_notebook_dir,run_id) #birbal added.. passed mlflow_client - + creds = mlflow_client._tracking_client.store.get_host_creds() + src_webappURL = src_run_dct["tags"]["mlflow.databricks.webappURL"] + _logger.info(f"src_webappURL izzzz {src_webappURL} and target webappURL is {creds.host}") + if creds.host != src_webappURL: + _logger.info(f"NOTEBOOK IMPORT STARTED") + _upload_databricks_notebook(mlflow_client, dbx_client, input_dir, src_run_dct, dst_notebook_dir,run_id) #birbal added.. 
passed mlflow_client + else: + _logger.info(f"NOTEBOOK IMPORT SKIPPED DUE TO SAME WORKSPACE") res = (run, src_run_dct["tags"].get(MLFLOW_PARENT_RUN_ID, None)) _logger.info(f"Imported run '{run.info.run_id}' into experiment '{experiment_name}'") return res From 1ab00732042496ed8a637730604358bb44b6bdaf Mon Sep 17 00:00:00 2001 From: "birbal.in25@gmail.com" Date: Sat, 16 Aug 2025 04:29:56 +0000 Subject: [PATCH 10/25] clean up --- databricks_notebooks/bulk/Export_All.py | 24 +++++++--- .../bulk/Import_Registered_Models.py | 32 +++++++------ .../bulk/master_Export_all.py | 26 ++++++++-- .../bulk/master_Import_Registered_Models.py | 47 +++++++++++-------- 4 files changed, 86 insertions(+), 43 deletions(-) diff --git a/databricks_notebooks/bulk/Export_All.py b/databricks_notebooks/bulk/Export_All.py index caf493bd..de84cc5b 100644 --- a/databricks_notebooks/bulk/Export_All.py +++ b/databricks_notebooks/bulk/Export_All.py @@ -1,7 +1,8 @@ # Databricks notebook source -# MAGIC %md ## Export All +# MAGIC %md +# MAGIC ##Export All # MAGIC -# MAGIC Export all the MLflow registered models and all experiments of a tracking server. +# MAGIC ##Export all the MLflow registered models and all experiments of a tracking server. # MAGIC # MAGIC **Widgets** # MAGIC * `1. Output directory` - shared directory between source and destination workspaces. @@ -23,6 +24,7 @@ from mlflow_export_import.bulk import config import time import os +from datetime import datetime # COMMAND ---------- @@ -57,6 +59,12 @@ dbutils.widgets.text("model_file_name", "") model_file_name = dbutils.widgets.get("model_file_name") + +dbutils.widgets.text("source_model_registry", "") +source_model_registry = dbutils.widgets.get("source_model_registry") + +dbutils.widgets.dropdown("Cloud","azure",["azure","aws","gcp"]) +cloud = dbutils.widgets.get("Cloud") if run_start_date=="": run_start_date = None @@ -70,6 +78,7 @@ print("run_timestamp:", run_timestamp) print("jobrunid:", jobrunid) print("model_file_name:", model_file_name) +print("source_model_registry:", source_model_registry) # COMMAND ---------- @@ -98,15 +107,18 @@ # COMMAND ---------- -log_path=f"/tmp/my.log" -dbfs_log_path = f"{output_dir}/export_all_{task_index}.log" +log_path=f"/tmp/exportall_{task_index}.log" log_path # COMMAND ---------- +# curr_timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S") +# log_path = f"{output_dir}/export_all_{task_index}_{curr_timestamp}.log" + +# COMMAND ---------- + config.log_path=log_path -config.export_or_import="export" -config.target_model_registry="unity_catalog" ## birbal...remove +config.target_model_registry=source_model_registry # COMMAND ---------- diff --git a/databricks_notebooks/bulk/Import_Registered_Models.py b/databricks_notebooks/bulk/Import_Registered_Models.py index b3254caf..7e88be25 100644 --- a/databricks_notebooks/bulk/Import_Registered_Models.py +++ b/databricks_notebooks/bulk/Import_Registered_Models.py @@ -101,20 +101,18 @@ # COMMAND ---------- -w = WorkspaceClient() -try: - catalog = w.catalogs.get(name=target_model_catalog) - print(f"Catalog '{target_model_catalog}' exists.") -except Exception as e: - raise ValueError(f"Error - {e}") - -# COMMAND ---------- - -try: - schema = w.schemas.get(full_name=f"{target_model_catalog}.{target_model_schema}") - print(f"Schema '{target_model_catalog}.{target_model_schema}' exists.") -except Exception as e: - raise ValueError(f"Error - {e}") +if target_model_registry == "unity_catalog": + w = WorkspaceClient() + try: + catalog = w.catalogs.get(name=target_model_catalog) + 
print(f"Catalog '{target_model_catalog}' exists.") + except Exception as e: + raise ValueError(f"Error - {e}") + try: + schema = w.schemas.get(full_name=f"{target_model_catalog}.{target_model_schema}") + print(f"Schema '{target_model_catalog}.{target_model_schema}' exists.") + except Exception as e: + raise ValueError(f"Error - {e}") # COMMAND ---------- @@ -130,6 +128,12 @@ # COMMAND ---------- +# curr_timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + +# log_path = f"{input_dir}/Import_Registered_Models_{task_index}_{curr_timestamp}.log" + +# COMMAND ---------- + config.log_path=log_path config.target_model_registry=target_model_registry diff --git a/databricks_notebooks/bulk/master_Export_all.py b/databricks_notebooks/bulk/master_Export_all.py index 7e5125a0..22c85c5a 100644 --- a/databricks_notebooks/bulk/master_Export_all.py +++ b/databricks_notebooks/bulk/master_Export_all.py @@ -26,6 +26,12 @@ dbutils.widgets.text("7. model_file_name", "") model_file_name = dbutils.widgets.get("7. model_file_name") + +dbutils.widgets.dropdown("8. Source model registry","unity_catalog",["unity_catalog","workspace_registry"]) +source_model_registry = dbutils.widgets.get("8. Source model registry") + +dbutils.widgets.dropdown("9. Cloud","azure",["azure","aws","gcp"]) +cloud = dbutils.widgets.get("9. Cloud") print("output_dir:", output_dir) print("stages:", stages) @@ -34,6 +40,8 @@ print("export_permissions:", export_permissions) print("num_tasks:", num_tasks) print("model_file_name:", model_file_name) +print("source_model_registry:", source_model_registry) +print("cloud:", cloud) # COMMAND ---------- @@ -64,8 +72,17 @@ DATABRICKS_INSTANCE = f"https://{DATABRICKS_INSTANCE}" DATABRICKS_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None) -driver_node_type = "Standard_D4ds_v5" -worker_node_type = "Standard_D4ds_v5" +if cloud == "azure": + driver_node_type = "Standard_D4ds_v5" + worker_node_type = "Standard_D4ds_v5" + +if cloud == "aws": + driver_node_type = "m4.xlarge" + worker_node_type = "m4.xlarge" + +if cloud == "gcp": + driver_node_type = "n1-standard-4" + worker_node_type = "n1-standard-4" def create_multi_task_job_json(): tasks = [] @@ -82,7 +99,7 @@ def create_multi_task_job_json(): "runtime_engine": "STANDARD" }, "notebook_task": { - "notebook_path": "/Workspace/Users/birbal.das@databricks.com/AA_sephora/birnew-mlflow-export-import/databricks_notebooks/bulk/Export_All", + "notebook_path": "/Workspace/Users/birbal.das@databricks.com/AA_sephora_notebook_export_fix/birnew-mlflow-export-import/databricks_notebooks/bulk/Export_All", "base_parameters": { "output_dir": output_dir, "stages": stages, @@ -93,7 +110,8 @@ def create_multi_task_job_json(): "num_tasks": num_tasks, "run_timestamp": "{{job.start_time.iso_date}}-ExportAll-jobid-{{job.id}}", "jobrunid": "jobrunid-{{job.run_id}}", - "model_file_name": model_file_name + "model_file_name": model_file_name, + "source_model_registry": source_model_registry } } } diff --git a/databricks_notebooks/bulk/master_Import_Registered_Models.py b/databricks_notebooks/bulk/master_Import_Registered_Models.py index 15873797..2f1a4aa6 100644 --- a/databricks_notebooks/bulk/master_Import_Registered_Models.py +++ b/databricks_notebooks/bulk/master_Import_Registered_Models.py @@ -36,6 +36,9 @@ dbutils.widgets.text("9. num_tasks", "") num_tasks = dbutils.widgets.get("9. num_tasks") +dbutils.widgets.dropdown("10. Cloud","azure",["azure","aws","gcp"]) +cloud = dbutils.widgets.get("10. 
Cloud") + print("input_dir:", input_dir) print("target_model_registry:", target_model_registry) print("target_model_catalog:", target_model_catalog) @@ -45,6 +48,7 @@ print("experiment_rename_file:", experiment_rename_file) print("import_permissions:", import_permissions) print("num_tasks:", num_tasks) +print("cloud:", cloud) # COMMAND ---------- @@ -64,20 +68,19 @@ # COMMAND ---------- -w = WorkspaceClient() -try: - catalog = w.catalogs.get(name=target_model_catalog) - print(f"Catalog '{target_model_catalog}' exists.") -except Exception as e: - raise ValueError(f"Error - {e}") - -# COMMAND ---------- +if target_model_registry == "unity_catalog": + w = WorkspaceClient() + try: + catalog = w.catalogs.get(name=target_model_catalog) + print(f"Catalog '{target_model_catalog}' exists.") + except Exception as e: + raise ValueError(f"Error - {e}") -try: - schema = w.schemas.get(full_name=f"{target_model_catalog}.{target_model_schema}") - print(f"Schema '{target_model_catalog}.{target_model_schema}' exists.") -except Exception as e: - raise ValueError(f"Error - {e}") + try: + schema = w.schemas.get(full_name=f"{target_model_catalog}.{target_model_schema}") + print(f"Schema '{target_model_catalog}.{target_model_schema}' exists.") + except Exception as e: + raise ValueError(f"Error - {e}") # COMMAND ---------- @@ -85,8 +88,18 @@ DATABRICKS_INSTANCE = f"https://{DATABRICKS_INSTANCE}" DATABRICKS_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None) -driver_node_type = "Standard_D4ds_v5" -worker_node_type = "Standard_D4ds_v5" + +if cloud == "azure": + driver_node_type = "Standard_D4ds_v5" + worker_node_type = "Standard_D4ds_v5" + +if cloud == "aws": + driver_node_type = "m4.xlarge" + worker_node_type = "m4.xlarge" + +if cloud == "gcp": + driver_node_type = "n1-standard-4" + worker_node_type = "n1-standard-4" def create_multi_task_job_json(): tasks = [] @@ -152,7 +165,3 @@ def submit_databricks_job(): # COMMAND ---------- submit_databricks_job() - -# COMMAND ---------- - - From 7ba1cab49983c58b5d079cfcb0eb7a558c0fb95a Mon Sep 17 00:00:00 2001 From: "birbal.in25@gmail.com" Date: Sat, 16 Aug 2025 04:36:37 +0000 Subject: [PATCH 11/25] cleanup --- databricks_notebooks/bulk/Import_Registered_Models.py | 3 --- databricks_notebooks/bulk/master_Export_all.py | 4 ---- databricks_notebooks/bulk/master_Import_Registered_Models.py | 2 -- 3 files changed, 9 deletions(-) diff --git a/databricks_notebooks/bulk/Import_Registered_Models.py b/databricks_notebooks/bulk/Import_Registered_Models.py index 7e88be25..1ce5b793 100644 --- a/databricks_notebooks/bulk/Import_Registered_Models.py +++ b/databricks_notebooks/bulk/Import_Registered_Models.py @@ -81,8 +81,6 @@ if not input_dir: raise ValueError("input_dir cannot be empty") -# if not input_dir.startswith("/dbfs/mnt"): -# raise ValueError("input_dir must start with /dbfs/mnt") if not task_index: raise ValueError("task_index cannot be empty") if not task_index.isdigit(): @@ -129,7 +127,6 @@ # COMMAND ---------- # curr_timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - # log_path = f"{input_dir}/Import_Registered_Models_{task_index}_{curr_timestamp}.log" # COMMAND ---------- diff --git a/databricks_notebooks/bulk/master_Export_all.py b/databricks_notebooks/bulk/master_Export_all.py index 22c85c5a..eca1fedf 100644 --- a/databricks_notebooks/bulk/master_Export_all.py +++ b/databricks_notebooks/bulk/master_Export_all.py @@ -47,8 +47,6 @@ if not output_dir: raise ValueError("output_dir cannot be empty") -# if not 
output_dir.startswith("/dbfs/mnt"): -# raise ValueError("output_dir must start with /dbfs/mnt") if not num_tasks: raise ValueError("num_tasks cannot be empty") if not num_tasks.isdigit(): @@ -59,8 +57,6 @@ if model_file_name: if not model_file_name.endswith(".txt"): raise ValueError("model_file_name must end with .txt if not empty") - # if not model_file_name.startswith("/dbfs"): - # raise ValueError("model_file_name must start with /dbfs if not empty") else: model_file_name = "all" diff --git a/databricks_notebooks/bulk/master_Import_Registered_Models.py b/databricks_notebooks/bulk/master_Import_Registered_Models.py index 2f1a4aa6..d6531f82 100644 --- a/databricks_notebooks/bulk/master_Import_Registered_Models.py +++ b/databricks_notebooks/bulk/master_Import_Registered_Models.py @@ -54,8 +54,6 @@ if not input_dir: raise ValueError("input_dir cannot be empty") -# if not input_dir.startswith("/dbfs/mnt"): -# raise ValueError("input_dir must start with /dbfs/mnt") if not num_tasks: raise ValueError("num_tasks cannot be empty") if not num_tasks.isdigit(): From 7328750ccfbd4ff98f304acd9e83f9c7005f9502 Mon Sep 17 00:00:00 2001 From: "birbal.in25@gmail.com" Date: Sat, 16 Aug 2025 05:18:58 +0000 Subject: [PATCH 12/25] cleanup --- mlflow_export_import/model/export_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlflow_export_import/model/export_model.py b/mlflow_export_import/model/export_model.py index c7e3fef9..e8ee7461 100644 --- a/mlflow_export_import/model/export_model.py +++ b/mlflow_export_import/model/export_model.py @@ -144,7 +144,7 @@ def _export_versions(mlflow_client, model_dct, versions, output_dir, opts, resul output_versions, failed_versions = ([], []) for j,vr in enumerate(versions): if not model_utils.is_unity_catalog_model(model_dct["name"]) and vr.current_stage and (len(opts.stages) > 0 and not vr.current_stage.lower() in opts.stages): - _logger.info(f"_export_version skipped") #birbal + _logger.warning(f"_export_version skipped") #birbal continue if len(opts.versions) > 0 and not vr.version in opts.versions: continue From d55326eb845374acf00b2bf94493181b5f0f196d Mon Sep 17 00:00:00 2001 From: "birbal.in25@gmail.com" Date: Sat, 16 Aug 2025 08:13:14 +0000 Subject: [PATCH 13/25] cleanup --- mlflow_export_import/bulk/import_experiments.py | 2 ++ mlflow_export_import/bulk/import_models.py | 4 +++- mlflow_export_import/model/export_model.py | 2 +- mlflow_export_import/model/import_model.py | 1 + mlflow_export_import/run/import_run.py | 3 +++ 5 files changed, 10 insertions(+), 2 deletions(-) diff --git a/mlflow_export_import/bulk/import_experiments.py b/mlflow_export_import/bulk/import_experiments.py index 2524a2ef..2587c74c 100644 --- a/mlflow_export_import/bulk/import_experiments.py +++ b/mlflow_export_import/bulk/import_experiments.py @@ -88,7 +88,9 @@ def _import_experiment(mlflow_client, - None if error happened """ try: + _logger.error(f"EXPERIMENT BEFORE RENAME {exp_name} ") # birbal exp_name = rename_utils.rename(exp_name, experiment_renames, "experiment") + _logger.error(f"EXPERIMENT AFTER RENAME {exp_name} ") # birbal run_info_map = import_experiment( mlflow_client = mlflow_client, experiment_name = exp_name, diff --git a/mlflow_export_import/bulk/import_models.py b/mlflow_export_import/bulk/import_models.py index b25a9ca6..d21ed406 100644 --- a/mlflow_export_import/bulk/import_models.py +++ b/mlflow_export_import/bulk/import_models.py @@ -44,13 +44,15 @@ def import_models( target_model_schema = None #birbal added ): mlflow_client = mlflow_client 
or create_mlflow_client() + experiment_renames_original = experiment_renames #birbal experiment_renames = rename_utils.get_renames(experiment_renames) model_renames = rename_utils.get_renames(model_renames) start_time = time.time() exp_run_info_map, exp_info = _import_experiments( mlflow_client, input_dir, - experiment_renames, + # experiment_renames, + experiment_renames_original, #birbal import_permissions, import_source_tags, use_src_user_id, diff --git a/mlflow_export_import/model/export_model.py b/mlflow_export_import/model/export_model.py index e8ee7461..4be3f64e 100644 --- a/mlflow_export_import/model/export_model.py +++ b/mlflow_export_import/model/export_model.py @@ -144,7 +144,7 @@ def _export_versions(mlflow_client, model_dct, versions, output_dir, opts, resul output_versions, failed_versions = ([], []) for j,vr in enumerate(versions): if not model_utils.is_unity_catalog_model(model_dct["name"]) and vr.current_stage and (len(opts.stages) > 0 and not vr.current_stage.lower() in opts.stages): - _logger.warning(f"_export_version skipped") #birbal + _logger.warning(f"MODEL VERSION EXPORT SKIPPED. Current model stage:{vr.current_stage} does not match with Input stages passed:{opts.stages} ") #birbal continue if len(opts.versions) > 0 and not vr.version in opts.versions: continue diff --git a/mlflow_export_import/model/import_model.py b/mlflow_export_import/model/import_model.py index 8bdb9a7d..daa62bba 100644 --- a/mlflow_export_import/model/import_model.py +++ b/mlflow_export_import/model/import_model.py @@ -277,6 +277,7 @@ def import_model(self, else: dst_run_id = dst_run_info.run_id exp_name = rename_utils.rename(vr["_experiment_name"], self.experiment_renames, "experiment") + _logger.error(f"RENAMED EXPERIMENT FROM {vr["_experiment_name"]} TO {exp_name}") # birbal try: with MlflowTrackingUriTweak(self.mlflow_client): mlflow.set_experiment(exp_name) diff --git a/mlflow_export_import/run/import_run.py b/mlflow_export_import/run/import_run.py index 359539e6..5b9b3f3e 100644 --- a/mlflow_export_import/run/import_run.py +++ b/mlflow_export_import/run/import_run.py @@ -190,12 +190,15 @@ def update_notebook_lineage(mlflow_client,run_id,dst_notebook_path): #birbal response.raise_for_status() notebook_object = response.json() notebook_id = notebook_object.get("object_id") + + creds = mlflow_client._tracking_client.store.get_host_creds() mlflow_client.set_tag(run_id, "mlflow.source.name", dst_notebook_path) mlflow_client.set_tag(run_id, "mlflow.source.type", "NOTEBOOK") mlflow_client.set_tag(run_id, "mlflow.databricks.notebookID", notebook_id) mlflow_client.set_tag(run_id, "mlflow.databricks.workspaceURL", host) mlflow_client.set_tag(run_id, "mlflow.databricks.notebookPath", dst_notebook_path) + mlflow_client.set_tag(run_id, "mlflow.databricks.webappURL", creds.host) def _import_inputs(http_client, src_run_dct, run_id): inputs = src_run_dct.get("inputs") From 83059e201397b6fb6c269545db80ca67aa69d8e2 Mon Sep 17 00:00:00 2001 From: "birbal.in25@gmail.com" Date: Sun, 17 Aug 2025 22:57:35 +0000 Subject: [PATCH 14/25] cleanup --- mlflow_export_import/run/import_run.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlflow_export_import/run/import_run.py b/mlflow_export_import/run/import_run.py index 5b9b3f3e..c9d41e3d 100644 --- a/mlflow_export_import/run/import_run.py +++ b/mlflow_export_import/run/import_run.py @@ -106,10 +106,10 @@ def _mk_ex(src_run_dct, dst_run_id, exp_name): traceback.print_exc() raise MlflowExportImportException(e, f"Importing run {run_id} of 
experiment '{exp.name}' failed") - if utils.calling_databricks(): #birbal added - creds = mlflow_client._tracking_client.store.get_host_creds() - src_webappURL = src_run_dct["tags"]["mlflow.databricks.webappURL"] - _logger.info(f"src_webappURL izzzz {src_webappURL} and target webappURL is {creds.host}") + if utils.calling_databricks(): #birbal added entire block + creds = mlflow_client._tracking_client.store.get_host_creds() + src_webappURL = src_run_dct.get("tags", {}).get("mlflow.databricks.webappURL", None) + _logger.info(f"src_webappURL is {src_webappURL} and target webappURL is {creds.host}") if creds.host != src_webappURL: _logger.info(f"NOTEBOOK IMPORT STARTED") _upload_databricks_notebook(mlflow_client, dbx_client, input_dir, src_run_dct, dst_notebook_dir,run_id) #birbal added.. passed mlflow_client @@ -124,7 +124,7 @@ def _upload_databricks_notebook(mlflow_client, dbx_client, input_dir, src_run_dc tag_key = "mlflow.databricks.notebookPath" src_notebook_path = src_run_dct["tags"].get(tag_key,None) if not src_notebook_path: - _logger.warning(f"No tag '{tag_key}' for run_id '{run_id}'") + _logger.warning(f"No tag '{tag_key}' for run_id '{run_id}'. NOTEBOOK IMPORT SKIPPED") return notebook_name = os.path.basename(src_notebook_path) dst_notebook_dir = os.path.dirname(src_notebook_path) From ee618dd44bc41540fea360943f628221307b8393 Mon Sep 17 00:00:00 2001 From: "birbal.in25@gmail.com" Date: Sun, 17 Aug 2025 23:30:32 +0000 Subject: [PATCH 15/25] check emptiness of experiments and models --- mlflow_export_import/bulk/import_experiments.py | 8 +++++++- mlflow_export_import/bulk/import_models.py | 5 +++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/mlflow_export_import/bulk/import_experiments.py b/mlflow_export_import/bulk/import_experiments.py index 2587c74c..b7e166d3 100644 --- a/mlflow_export_import/bulk/import_experiments.py +++ b/mlflow_export_import/bulk/import_experiments.py @@ -49,7 +49,13 @@ def import_experiments( experiment_renames = rename_utils.get_renames(experiment_renames) mlflow_client = mlflow_client or mlflow.MlflowClient() - dct = io_utils.read_file_mlflow(os.path.join(input_dir, "experiments.json")) + + try: #birbal + dct = io_utils.read_file_mlflow(os.path.join(input_dir, "experiments.json")) + except Exception as e: + _logger.info(f"'experiments.json' does not exist in {input_dir}. NO EXPERIMENTS TO IMPORT") + return [] + exps = dct["experiments"] _logger.info("Importing experiments:") for exp in exps: diff --git a/mlflow_export_import/bulk/import_models.py b/mlflow_export_import/bulk/import_models.py index d21ed406..8530f9e6 100644 --- a/mlflow_export_import/bulk/import_models.py +++ b/mlflow_export_import/bulk/import_models.py @@ -148,6 +148,11 @@ def _import_models(mlflow_client, models_dir = os.path.join(input_dir, "models") models = io_utils.read_file_mlflow(os.path.join(models_dir,"models.json")) model_names = models["models"] + + if len(model_names) == 0: + _logger.warning(f"No models found in {os.path.join(models_dir,"models.json")}. 
NO MODELS TO IMPORT") + return {} + all_importer = BulkModelImporter( mlflow_client = mlflow_client, run_info_map = run_info_map, From 507d85fc3b7bd690cb81609ced2c54f08323feff Mon Sep 17 00:00:00 2001 From: "birbal.in25@gmail.com" Date: Mon, 18 Aug 2025 00:09:26 +0000 Subject: [PATCH 16/25] cleanup --- mlflow_export_import/bulk/import_experiments.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlflow_export_import/bulk/import_experiments.py b/mlflow_export_import/bulk/import_experiments.py index b7e166d3..3392486a 100644 --- a/mlflow_export_import/bulk/import_experiments.py +++ b/mlflow_export_import/bulk/import_experiments.py @@ -94,9 +94,9 @@ def _import_experiment(mlflow_client, - None if error happened """ try: - _logger.error(f"EXPERIMENT BEFORE RENAME {exp_name} ") # birbal + _logger.info(f"EXPERIMENT BEFORE RENAME {exp_name} ") # birbal exp_name = rename_utils.rename(exp_name, experiment_renames, "experiment") - _logger.error(f"EXPERIMENT AFTER RENAME {exp_name} ") # birbal + _logger.info(f"EXPERIMENT AFTER RENAME {exp_name} ") # birbal run_info_map = import_experiment( mlflow_client = mlflow_client, experiment_name = exp_name, From ee706d886ca4308194b41a1ed7107c6ae4c100fb Mon Sep 17 00:00:00 2001 From: "birbal.in25@gmail.com" Date: Mon, 18 Aug 2025 07:03:59 +0000 Subject: [PATCH 17/25] fix and code cleanup --- .../bulk/Import_Registered_Models.py | 1 + mlflow_export_import/bulk/config.py | 3 ++- mlflow_export_import/model/import_model.py | 1 - .../model_version/import_model_version.py | 1 + mlflow_export_import/run/import_run.py | 25 +++++++++++++++++-- 5 files changed, 27 insertions(+), 4 deletions(-) diff --git a/databricks_notebooks/bulk/Import_Registered_Models.py b/databricks_notebooks/bulk/Import_Registered_Models.py index 1ce5b793..a4ae9cd7 100644 --- a/databricks_notebooks/bulk/Import_Registered_Models.py +++ b/databricks_notebooks/bulk/Import_Registered_Models.py @@ -133,6 +133,7 @@ config.log_path=log_path config.target_model_registry=target_model_registry +config.notebook_user_mapping_file="/dbfs/mnt/con1/dinner2/vol1/exportall_uc/notebookuserm.txt" # COMMAND ---------- diff --git a/mlflow_export_import/bulk/config.py b/mlflow_export_import/bulk/config.py index 4fdc20d1..215cbd21 100644 --- a/mlflow_export_import/bulk/config.py +++ b/mlflow_export_import/bulk/config.py @@ -1,2 +1,3 @@ log_path=None -target_model_registry=None \ No newline at end of file +target_model_registry=None +notebook_user_mapping_file=None \ No newline at end of file diff --git a/mlflow_export_import/model/import_model.py b/mlflow_export_import/model/import_model.py index daa62bba..69e8547b 100644 --- a/mlflow_export_import/model/import_model.py +++ b/mlflow_export_import/model/import_model.py @@ -269,7 +269,6 @@ def import_model(self, _logger.info(f"Importing {len(model_dct['versions'])} versions:") for vr in model_dct["versions"]: src_run_id = vr["run_id"] - _logger.info(f"self.run_info_map is {self.run_info_map}") ##birbal...need to remove dst_run_info = self.run_info_map.get(src_run_id, None) if not dst_run_info: msg = { "model": model_name, "version": vr["version"], "stage": vr["current_stage"], "run_id": src_run_id } diff --git a/mlflow_export_import/model_version/import_model_version.py b/mlflow_export_import/model_version/import_model_version.py index 7602a943..26d8f574 100644 --- a/mlflow_export_import/model_version/import_model_version.py +++ b/mlflow_export_import/model_version/import_model_version.py @@ -86,6 +86,7 @@ def import_model_version( model_path = 
_get_model_path(src_vr) dst_source = f"{dst_run.info.artifact_uri}/{model_path}" + dst_vr = _import_model_version( mlflow_client, model_name = model_name, diff --git a/mlflow_export_import/run/import_run.py b/mlflow_export_import/run/import_run.py index c9d41e3d..937d7fd0 100644 --- a/mlflow_export_import/run/import_run.py +++ b/mlflow_export_import/run/import_run.py @@ -26,6 +26,8 @@ import mlflow.utils.databricks_utils as db_utils #birbal added import requests #birbal added import json +from mlflow_export_import.bulk import config #birbal added +from mlflow_export_import.bulk import rename_utils #birbal added _logger = utils.getLogger(__name__) @@ -126,8 +128,28 @@ def _upload_databricks_notebook(mlflow_client, dbx_client, input_dir, src_run_dc if not src_notebook_path: _logger.warning(f"No tag '{tag_key}' for run_id '{run_id}'. NOTEBOOK IMPORT SKIPPED") return + notebook_name = os.path.basename(src_notebook_path) - dst_notebook_dir = os.path.dirname(src_notebook_path) + + try: #birbal added entire block to solve the issue where the source user doesn't exists in target workspace + dst_notebook_dir = os.path.dirname(src_notebook_path) + mlflow_utils.create_workspace_dir(dbx_client, dst_notebook_dir) + + except Exception as e: #birbal added + _logger.warning(f"Failed to create directory '{dst_notebook_dir}'. This is most probably because the user doesn't exist in target workspace. Checking notebook user mapping file...") + notebook_user_mapping_file=config.notebook_user_mapping_file + if notebook_user_mapping_file: + notebook_user_mapping_file = rename_utils.get_renames(notebook_user_mapping_file) + _logger.info(f"notebook_user_mapping_file is {notebook_user_mapping_file}") + _logger.info(f"src_notebook_path BEFORE RENAME {src_notebook_path}") + src_notebook_path = rename_utils.rename(src_notebook_path, notebook_user_mapping_file, "notebook") + _logger.info(f"src_notebook_path AFTER RENAME {src_notebook_path}") + dst_notebook_dir = os.path.dirname(src_notebook_path) + mlflow_utils.create_workspace_dir(dbx_client, dst_notebook_dir) + else: + _logger.error(f"Notebook couldn't be imported because the target directory '{dst_notebook_dir}' could not be created, and no notebook user mapping file was provided as input") + raise e + format = "source" notebook_path = os.path.join(input_dir,"artifacts","notebooks",f"{notebook_name}.{format}") #birbal added @@ -145,7 +167,6 @@ def _upload_databricks_notebook(mlflow_client, dbx_client, input_dir, src_run_dc "overwrite": True, "content": content } - mlflow_utils.create_workspace_dir(dbx_client, dst_notebook_dir) try: _logger.info(f"Importing notebook '{dst_notebook_path}' for run {run_id}") create_notebook(mlflow_client,payload,run_id) #birbal added From 1921586b28609316fff03180b52c919839ebf2fa Mon Sep 17 00:00:00 2001 From: "birbal.in25@gmail.com" Date: Mon, 18 Aug 2025 08:07:46 +0000 Subject: [PATCH 18/25] cleanup --- databricks_notebooks/bulk/Import_Registered_Models.py | 7 ++++++- .../bulk/master_Import_Registered_Models.py | 10 ++++++++-- mlflow_export_import/model/import_model.py | 2 +- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/databricks_notebooks/bulk/Import_Registered_Models.py b/databricks_notebooks/bulk/Import_Registered_Models.py index a4ae9cd7..67bb9756 100644 --- a/databricks_notebooks/bulk/Import_Registered_Models.py +++ b/databricks_notebooks/bulk/Import_Registered_Models.py @@ -55,6 +55,10 @@ dbutils.widgets.text("task_index", "") task_index = dbutils.widgets.get("task_index") 
+dbutils.widgets.text("notebook_user_mapping_file","") +val = dbutils.widgets.get("notebook_user_mapping_file") +notebook_user_mapping_file = {} if val in ("null", None, "") else val + print("input_dir:", input_dir) print("target_model_registry:", target_model_registry) @@ -65,6 +69,7 @@ print("experiment_rename_file:", experiment_rename_file) print("import_permissions:", import_permissions) print("task_index:", task_index) +print("notebook_user_mapping_file:", notebook_user_mapping_file) # COMMAND ---------- @@ -133,7 +138,7 @@ config.log_path=log_path config.target_model_registry=target_model_registry -config.notebook_user_mapping_file="/dbfs/mnt/con1/dinner2/vol1/exportall_uc/notebookuserm.txt" +config.notebook_user_mapping_file=notebook_user_mapping_file # COMMAND ---------- diff --git a/databricks_notebooks/bulk/master_Import_Registered_Models.py b/databricks_notebooks/bulk/master_Import_Registered_Models.py index d6531f82..c2306e92 100644 --- a/databricks_notebooks/bulk/master_Import_Registered_Models.py +++ b/databricks_notebooks/bulk/master_Import_Registered_Models.py @@ -39,6 +39,10 @@ dbutils.widgets.dropdown("10. Cloud","azure",["azure","aws","gcp"]) cloud = dbutils.widgets.get("10. Cloud") +dbutils.widgets.text("11. Notebook user mapping file","") +val = dbutils.widgets.get("11. Notebook user mapping file") +notebook_user_mapping_file = val or None + print("input_dir:", input_dir) print("target_model_registry:", target_model_registry) print("target_model_catalog:", target_model_catalog) @@ -49,6 +53,7 @@ print("import_permissions:", import_permissions) print("num_tasks:", num_tasks) print("cloud:", cloud) +print("notebook_user_mapping_file:", notebook_user_mapping_file) # COMMAND ---------- @@ -114,7 +119,7 @@ def create_multi_task_job_json(): "runtime_engine": "STANDARD" }, "notebook_task": { - "notebook_path": "/Workspace/Users/birbal.das@databricks.com/AA_sephora/birnew-mlflow-export-import/databricks_notebooks/bulk/Import_Registered_Models", + "notebook_path": "/Workspace/Users/birbal.das@databricks.com/birnew-mlflow-export-import/databricks_notebooks/bulk/Import_Registered_Models", "base_parameters": { "input_dir": os.path.join(input_dir,str(i)), "target_model_registry": target_model_registry, @@ -124,7 +129,8 @@ def create_multi_task_job_json(): "model_rename_file": model_rename_file, "experiment_rename_file": experiment_rename_file, "import_permissions": import_permissions, - "task_index": str(i) + "task_index": str(i), + "notebook_user_mapping_file":notebook_user_mapping_file } } } diff --git a/mlflow_export_import/model/import_model.py b/mlflow_export_import/model/import_model.py index 69e8547b..a6564cba 100644 --- a/mlflow_export_import/model/import_model.py +++ b/mlflow_export_import/model/import_model.py @@ -276,7 +276,7 @@ def import_model(self, else: dst_run_id = dst_run_info.run_id exp_name = rename_utils.rename(vr["_experiment_name"], self.experiment_renames, "experiment") - _logger.error(f"RENAMED EXPERIMENT FROM {vr["_experiment_name"]} TO {exp_name}") # birbal + _logger.info(f"RENAMED EXPERIMENT FROM {vr["_experiment_name"]} TO {exp_name}") # birbal try: with MlflowTrackingUriTweak(self.mlflow_client): mlflow.set_experiment(exp_name) From 512d6863001589bfd3cb5258cd7dd0e99a02e350 Mon Sep 17 00:00:00 2001 From: "birbal.in25@gmail.com" Date: Tue, 19 Aug 2025 07:05:49 +0000 Subject: [PATCH 19/25] cleanup --- .../bulk/Import_Registered_Models.py | 4 ++-- mlflow_export_import/bulk/config.py | 3 +-- .../bulk/import_experiments.py | 12 +++++++---- 
mlflow_export_import/bulk/import_models.py | 13 ++++++++---- mlflow_export_import/bulk/rename_utils.py | 3 +++ .../experiment/import_experiment.py | 6 ++++-- mlflow_export_import/run/import_run.py | 20 +++++++++---------- 7 files changed, 36 insertions(+), 25 deletions(-) diff --git a/databricks_notebooks/bulk/Import_Registered_Models.py b/databricks_notebooks/bulk/Import_Registered_Models.py index 67bb9756..d8707a1b 100644 --- a/databricks_notebooks/bulk/Import_Registered_Models.py +++ b/databricks_notebooks/bulk/Import_Registered_Models.py @@ -138,7 +138,6 @@ config.log_path=log_path config.target_model_registry=target_model_registry -config.notebook_user_mapping_file=notebook_user_mapping_file # COMMAND ---------- @@ -153,7 +152,8 @@ import_source_tags = False, ## Birbal:: Do not set to True. else it will import junk mlflow tags. Setting to False WILL import all source tags by default. use_threads = True, target_model_catalog = target_model_catalog, #birbal added - target_model_schema = target_model_schema #birbal added + target_model_schema = target_model_schema, #birbal added, + notebook_user_mapping_file = notebook_user_mapping_file #birbal added ) # COMMAND ---------- diff --git a/mlflow_export_import/bulk/config.py b/mlflow_export_import/bulk/config.py index 215cbd21..4fdc20d1 100644 --- a/mlflow_export_import/bulk/config.py +++ b/mlflow_export_import/bulk/config.py @@ -1,3 +1,2 @@ log_path=None -target_model_registry=None -notebook_user_mapping_file=None \ No newline at end of file +target_model_registry=None \ No newline at end of file diff --git a/mlflow_export_import/bulk/import_experiments.py b/mlflow_export_import/bulk/import_experiments.py index 3392486a..4fb10f18 100644 --- a/mlflow_export_import/bulk/import_experiments.py +++ b/mlflow_export_import/bulk/import_experiments.py @@ -30,7 +30,8 @@ def import_experiments( use_src_user_id = False, experiment_renames = None, use_threads = False, - mlflow_client = None + mlflow_client = None, + notebook_user_mapping = None #birbal ): """ :param input_dir: Source experiment directory. 
@@ -74,7 +75,8 @@ def import_experiments( import_permissions, import_source_tags, use_src_user_id, - experiment_renames + experiment_renames, + notebook_user_mapping #birbal ) futures.append([exp["id"], run_info_map]) return [ (f[0], f[1].result()) for f in futures ] # materialize the future @@ -86,7 +88,8 @@ def _import_experiment(mlflow_client, import_permissions, import_source_tags, use_src_user_id, - experiment_renames + experiment_renames, + notebook_user_mapping ): """ :return: @@ -103,7 +106,8 @@ def _import_experiment(mlflow_client, input_dir = input_dir, import_permissions = import_permissions, import_source_tags = import_source_tags, - use_src_user_id = use_src_user_id + use_src_user_id = use_src_user_id, + notebook_user_mapping = notebook_user_mapping #birbal ) return run_info_map except Exception as e: diff --git a/mlflow_export_import/bulk/import_models.py b/mlflow_export_import/bulk/import_models.py index 8530f9e6..008ec398 100644 --- a/mlflow_export_import/bulk/import_models.py +++ b/mlflow_export_import/bulk/import_models.py @@ -41,12 +41,14 @@ def import_models( use_threads = False, mlflow_client = None, target_model_catalog = None, #birbal added - target_model_schema = None #birbal added + target_model_schema = None, #birbal added + notebook_user_mapping_file = None #birbal added ): mlflow_client = mlflow_client or create_mlflow_client() experiment_renames_original = experiment_renames #birbal experiment_renames = rename_utils.get_renames(experiment_renames) model_renames = rename_utils.get_renames(model_renames) + notebook_user_mapping = rename_utils.get_renames(notebook_user_mapping_file) #birbal start_time = time.time() exp_run_info_map, exp_info = _import_experiments( mlflow_client, @@ -56,7 +58,8 @@ def import_models( import_permissions, import_source_tags, use_src_user_id, - use_threads + use_threads, + notebook_user_mapping #birbal ) run_info_map = _flatten_run_info_map(exp_run_info_map) model_res = _import_models( @@ -94,7 +97,8 @@ def _import_experiments(mlflow_client, import_permissions, import_source_tags, use_src_user_id, - use_threads + use_threads, + notebook_user_mapping ): start_time = time.time() @@ -105,7 +109,8 @@ def _import_experiments(mlflow_client, use_src_user_id = use_src_user_id, experiment_renames = experiment_renames, use_threads = use_threads, - mlflow_client = mlflow_client + mlflow_client = mlflow_client, + notebook_user_mapping = notebook_user_mapping #birbal ) duration = round(time.time()-start_time, 1) diff --git a/mlflow_export_import/bulk/rename_utils.py b/mlflow_export_import/bulk/rename_utils.py index fc28060a..a8bcc77e 100644 --- a/mlflow_export_import/bulk/rename_utils.py +++ b/mlflow_export_import/bulk/rename_utils.py @@ -18,6 +18,8 @@ def rename(name, replacements, object_name="object"): if not replacements: return name ## birbal :: corrected to return name instead of None. 
returning None will cause failure for k,v in replacements.items(): + if object_name == "notebook": #birbal added + k = k.removeprefix("/Workspace") if k != "" and name.startswith(k): new_name = name.replace(k,v) _logger.info(f"Renaming {object_name} '{name}' to '{new_name}'") @@ -25,6 +27,7 @@ def rename(name, replacements, object_name="object"): return name + def get_renames(filename_or_dict): if filename_or_dict is None: return None diff --git a/mlflow_export_import/experiment/import_experiment.py b/mlflow_export_import/experiment/import_experiment.py index 56047429..a062dced 100644 --- a/mlflow_export_import/experiment/import_experiment.py +++ b/mlflow_export_import/experiment/import_experiment.py @@ -33,7 +33,8 @@ def import_experiment( import_permissions = False, use_src_user_id = False, dst_notebook_dir = None, - mlflow_client = None + mlflow_client = None, + notebook_user_mapping = None #birbal ): """ :param experiment_name: Destination experiment name. @@ -86,7 +87,8 @@ def import_experiment( dst_notebook_dir = dst_notebook_dir, import_source_tags = import_source_tags, use_src_user_id = use_src_user_id, - exp = exp #birbal added + exp = exp, #birbal added + notebook_user_mapping = notebook_user_mapping #birbal ) dst_run_id = dst_run.info.run_id run_ids_map[src_run_id] = { "dst_run_id": dst_run_id, "src_parent_run_id": src_parent_run_id } diff --git a/mlflow_export_import/run/import_run.py b/mlflow_export_import/run/import_run.py index 937d7fd0..193711fe 100644 --- a/mlflow_export_import/run/import_run.py +++ b/mlflow_export_import/run/import_run.py @@ -26,7 +26,6 @@ import mlflow.utils.databricks_utils as db_utils #birbal added import requests #birbal added import json -from mlflow_export_import.bulk import config #birbal added from mlflow_export_import.bulk import rename_utils #birbal added _logger = utils.getLogger(__name__) @@ -39,7 +38,8 @@ def import_run( use_src_user_id = False, mlmodel_fix = True, mlflow_client = None, - exp = None + exp = None, + notebook_user_mapping = None #birbal ): """ Imports a run into the specified experiment. @@ -114,7 +114,7 @@ def _mk_ex(src_run_dct, dst_run_id, exp_name): _logger.info(f"src_webappURL is {src_webappURL} and target webappURL is {creds.host}") if creds.host != src_webappURL: _logger.info(f"NOTEBOOK IMPORT STARTED") - _upload_databricks_notebook(mlflow_client, dbx_client, input_dir, src_run_dct, dst_notebook_dir,run_id) #birbal added.. passed mlflow_client + _upload_databricks_notebook(mlflow_client, dbx_client, input_dir, src_run_dct, dst_notebook_dir,run_id,notebook_user_mapping) #birbal added.. 
passed mlflow_client else: _logger.info(f"NOTEBOOK IMPORT SKIPPED DUE TO SAME WORKSPACE") res = (run, src_run_dct["tags"].get(MLFLOW_PARENT_RUN_ID, None)) @@ -122,7 +122,7 @@ def _mk_ex(src_run_dct, dst_run_id, exp_name): return res -def _upload_databricks_notebook(mlflow_client, dbx_client, input_dir, src_run_dct, dst_notebook_dir,run_id): #birbal added +def _upload_databricks_notebook(mlflow_client, dbx_client, input_dir, src_run_dct, dst_notebook_dir,run_id,notebook_user_mapping): #birbal added tag_key = "mlflow.databricks.notebookPath" src_notebook_path = src_run_dct["tags"].get(tag_key,None) if not src_notebook_path: @@ -131,18 +131,16 @@ def _upload_databricks_notebook(mlflow_client, dbx_client, input_dir, src_run_dc notebook_name = os.path.basename(src_notebook_path) - try: #birbal added entire block to solve the issue where the source user doesn't exists in target workspace + try: #birbal added entire try/except block to solve the issue where the source user doesn't exists in target workspace dst_notebook_dir = os.path.dirname(src_notebook_path) mlflow_utils.create_workspace_dir(dbx_client, dst_notebook_dir) - except Exception as e: #birbal added + except Exception as e: _logger.warning(f"Failed to create directory '{dst_notebook_dir}'. This is most probably because the user doesn't exist in target workspace. Checking notebook user mapping file...") - notebook_user_mapping_file=config.notebook_user_mapping_file - if notebook_user_mapping_file: - notebook_user_mapping_file = rename_utils.get_renames(notebook_user_mapping_file) - _logger.info(f"notebook_user_mapping_file is {notebook_user_mapping_file}") + if notebook_user_mapping: + _logger.info(f"notebook_user_mapping is {notebook_user_mapping}") _logger.info(f"src_notebook_path BEFORE RENAME {src_notebook_path}") - src_notebook_path = rename_utils.rename(src_notebook_path, notebook_user_mapping_file, "notebook") + src_notebook_path = rename_utils.rename(src_notebook_path, notebook_user_mapping, "notebook") _logger.info(f"src_notebook_path AFTER RENAME {src_notebook_path}") dst_notebook_dir = os.path.dirname(src_notebook_path) mlflow_utils.create_workspace_dir(dbx_client, dst_notebook_dir) From a9c8e55659caabca6716ac19a08ccda246d756ec Mon Sep 17 00:00:00 2001 From: "birbal.in25@gmail.com" Date: Tue, 19 Aug 2025 20:27:37 +0000 Subject: [PATCH 20/25] cleanup --- databricks_notebooks/bulk/Common.py | 16 ---------------- .../bulk/master_Import_Registered_Models.py | 2 +- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/databricks_notebooks/bulk/Common.py b/databricks_notebooks/bulk/Common.py index 23ac7574..5532fc4b 100644 --- a/databricks_notebooks/bulk/Common.py +++ b/databricks_notebooks/bulk/Common.py @@ -28,20 +28,4 @@ def get_notebook_formats(num): # COMMAND ---------- -import mlflow -display([{"mlflow_version": mlflow.__version__}]) - -# COMMAND ---------- - -# MAGIC %pip install -U mlflow-skinny -# MAGIC %pip install -U git+https:///github.com/mlflow/mlflow-export-import/#egg=mlflow-export-import -# MAGIC dbutils.library.restartPython() - -# COMMAND ---------- - -import mlflow -display([{"mlflow_version": mlflow.__version__}]) - -# COMMAND ---------- - diff --git a/databricks_notebooks/bulk/master_Import_Registered_Models.py b/databricks_notebooks/bulk/master_Import_Registered_Models.py index c2306e92..a4faadfe 100644 --- a/databricks_notebooks/bulk/master_Import_Registered_Models.py +++ b/databricks_notebooks/bulk/master_Import_Registered_Models.py @@ -119,7 +119,7 @@ def create_multi_task_job_json(): 
"runtime_engine": "STANDARD" }, "notebook_task": { - "notebook_path": "/Workspace/Users/birbal.das@databricks.com/birnew-mlflow-export-import/databricks_notebooks/bulk/Import_Registered_Models", + "notebook_path": "/Workspace/Users/birbal.das@databricks.com/AA_final/birnew-mlflow-export-import/databricks_notebooks/bulk/Import_Registered_Models", "base_parameters": { "input_dir": os.path.join(input_dir,str(i)), "target_model_registry": target_model_registry, From 4ca126383e16b39c22924cb8dcc39bdee83f70c6 Mon Sep 17 00:00:00 2001 From: "birbal.in25@gmail.com" Date: Tue, 19 Aug 2025 21:39:32 +0000 Subject: [PATCH 21/25] cleanup --- mlflow_export_import/common/model_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlflow_export_import/common/model_utils.py b/mlflow_export_import/common/model_utils.py index db03765c..93e6ccbe 100644 --- a/mlflow_export_import/common/model_utils.py +++ b/mlflow_export_import/common/model_utils.py @@ -44,7 +44,8 @@ def create_model(client, model_name, model_dct, import_metadata): _logger.info(f"Created new registered model '{model_name}'") return True except Exception as e: - _logger.info(f"except Exception trigger, error for '{model_name}': {e}") + # _logger.info(f"except Exception trigger, error for '{model_name}': {e}") + _logger.error(f"FAILED TO CREATE MODEL: '{model_name}': ERROR- {e}") #birbal except RestException as e: if e.error_code != "RESOURCE_ALREADY_EXISTS": raise e From a2aadfaf113866e77c646aa93cd2941e060db07b Mon Sep 17 00:00:00 2001 From: "birbal.in25@gmail.com" Date: Tue, 19 Aug 2025 21:57:17 +0000 Subject: [PATCH 22/25] cleanup --- databricks_notebooks/bulk/Export_All.py | 9 ++++----- databricks_notebooks/bulk/master_Export_all.py | 4 ++-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/databricks_notebooks/bulk/Export_All.py b/databricks_notebooks/bulk/Export_All.py index de84cc5b..c92edde9 100644 --- a/databricks_notebooks/bulk/Export_All.py +++ b/databricks_notebooks/bulk/Export_All.py @@ -28,8 +28,7 @@ # COMMAND ---------- - -output_dir = dbutils.widgets.get("output_dir") +dbutils.widgets.text("output_dir","") output_dir = dbutils.widgets.get("output_dir") output_dir = output_dir.replace("dbfs:","/dbfs") @@ -45,10 +44,10 @@ dbutils.widgets.dropdown("export_permissions","false",["true","false"]) export_permissions = dbutils.widgets.get("export_permissions") == "true" -dbutils.widgets.text("task_index", "") +dbutils.widgets.text("task_index", "1") task_index = int(dbutils.widgets.get("task_index")) -dbutils.widgets.text("num_tasks", "") +dbutils.widgets.text("num_tasks", "1") num_tasks = int(dbutils.widgets.get("num_tasks")) dbutils.widgets.text("run_timestamp", "") @@ -60,7 +59,7 @@ dbutils.widgets.text("model_file_name", "") model_file_name = dbutils.widgets.get("model_file_name") -dbutils.widgets.text("source_model_registry", "") +dbutils.widgets.dropdown("source_model_registry","unity_catalog",["unity_catalog","workspace_registry"]) source_model_registry = dbutils.widgets.get("source_model_registry") dbutils.widgets.dropdown("Cloud","azure",["azure","aws","gcp"]) diff --git a/databricks_notebooks/bulk/master_Export_all.py b/databricks_notebooks/bulk/master_Export_all.py index eca1fedf..f164da7f 100644 --- a/databricks_notebooks/bulk/master_Export_all.py +++ b/databricks_notebooks/bulk/master_Export_all.py @@ -73,8 +73,8 @@ worker_node_type = "Standard_D4ds_v5" if cloud == "aws": - driver_node_type = "m4.xlarge" - worker_node_type = "m4.xlarge" + driver_node_type = "rd-fleet.xlarge" + 
worker_node_type = "rd-fleet.xlarge" if cloud == "gcp": driver_node_type = "n1-standard-4" From 6e4b5145dbfcf930d8cf69455da6001a2ba704b7 Mon Sep 17 00:00:00 2001 From: "birbal.in25@gmail.com" Date: Wed, 20 Aug 2025 22:52:25 +0000 Subject: [PATCH 23/25] mlflow to 2.19.0 --- databricks_notebooks/bulk/Common.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/databricks_notebooks/bulk/Common.py b/databricks_notebooks/bulk/Common.py index 5532fc4b..e620aeea 100644 --- a/databricks_notebooks/bulk/Common.py +++ b/databricks_notebooks/bulk/Common.py @@ -1,4 +1,10 @@ # Databricks notebook source +# MAGIC %pip install -U mlflow==2.19.0 +# MAGIC %pip install -U git+https:///github.com/mlflow/mlflow-export-import/#egg=mlflow-export-import +# MAGIC dbutils.library.restartPython() + +# COMMAND ---------- + # MAGIC %pip install -U mlflow-skinny # MAGIC %pip install -U git+https:///github.com/mlflow/mlflow-export-import/#egg=mlflow-export-import # MAGIC dbutils.library.restartPython() From b5ba72372c4b0ad3dac5e964ecc6937fd5e5239d Mon Sep 17 00:00:00 2001 From: birbalin25 Date: Thu, 21 Aug 2025 22:27:18 +0000 Subject: [PATCH 24/25] cleanup --- databricks_notebooks/bulk/Common.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/databricks_notebooks/bulk/Common.py b/databricks_notebooks/bulk/Common.py index e620aeea..0fe94c3f 100644 --- a/databricks_notebooks/bulk/Common.py +++ b/databricks_notebooks/bulk/Common.py @@ -1,14 +1,18 @@ # Databricks notebook source -# MAGIC %pip install -U mlflow==2.19.0 +# MAGIC %pip install -U mlflow-skinny # MAGIC %pip install -U git+https:///github.com/mlflow/mlflow-export-import/#egg=mlflow-export-import # MAGIC dbutils.library.restartPython() # COMMAND ---------- -# MAGIC %pip install -U mlflow-skinny +# MAGIC %pip install -U mlflow==2.19.0 # MAGIC %pip install -U git+https:///github.com/mlflow/mlflow-export-import/#egg=mlflow-export-import # MAGIC dbutils.library.restartPython() +# COMMAND ---------- + + + # COMMAND ---------- import mlflow From 51fdf2debebaef98d846b3764cd03c48f77b492e Mon Sep 17 00:00:00 2001 From: birbalin25 Date: Thu, 21 Aug 2025 22:28:19 +0000 Subject: [PATCH 25/25] cleanup --- databricks_notebooks/bulk/Common.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/databricks_notebooks/bulk/Common.py b/databricks_notebooks/bulk/Common.py index 0fe94c3f..cb708fc1 100644 --- a/databricks_notebooks/bulk/Common.py +++ b/databricks_notebooks/bulk/Common.py @@ -1,7 +1,7 @@ # Databricks notebook source -# MAGIC %pip install -U mlflow-skinny -# MAGIC %pip install -U git+https:///github.com/mlflow/mlflow-export-import/#egg=mlflow-export-import -# MAGIC dbutils.library.restartPython() +# %pip install -U mlflow-skinny +# %pip install -U git+https:///github.com/mlflow/mlflow-export-import/#egg=mlflow-export-import +# dbutils.library.restartPython() # COMMAND ---------- @@ -9,10 +9,6 @@ # MAGIC %pip install -U git+https:///github.com/mlflow/mlflow-export-import/#egg=mlflow-export-import # MAGIC dbutils.library.restartPython() -# COMMAND ---------- - - - # COMMAND ---------- import mlflow