diff --git a/databricks_notebooks/bulk/Common.py b/databricks_notebooks/bulk/Common.py index 76201660..cb708fc1 100644 --- a/databricks_notebooks/bulk/Common.py +++ b/databricks_notebooks/bulk/Common.py @@ -1,5 +1,11 @@ # Databricks notebook source -# MAGIC %pip install -U mlflow-skinny +# %pip install -U mlflow-skinny +# %pip install -U git+https:///github.com/mlflow/mlflow-export-import/#egg=mlflow-export-import +# dbutils.library.restartPython() + +# COMMAND ---------- + +# MAGIC %pip install -U mlflow==2.19.0 # MAGIC %pip install -U git+https:///github.com/mlflow/mlflow-export-import/#egg=mlflow-export-import # MAGIC dbutils.library.restartPython() @@ -25,3 +31,7 @@ def get_notebook_formats(num): notebook_formats = notebook_formats.split(",") if "" in notebook_formats: notebook_formats.remove("") return notebook_formats + +# COMMAND ---------- + + diff --git a/databricks_notebooks/bulk/Export_All.py b/databricks_notebooks/bulk/Export_All.py index 6e1de41f..c92edde9 100644 --- a/databricks_notebooks/bulk/Export_All.py +++ b/databricks_notebooks/bulk/Export_All.py @@ -1,7 +1,8 @@ # Databricks notebook source -# MAGIC %md ## Export All +# MAGIC %md +# MAGIC ##Export All # MAGIC -# MAGIC Export all the MLflow registered models and all experiments of a tracking server. +# MAGIC ##Export all the MLflow registered models and all experiments of a tracking server. # MAGIC # MAGIC **Widgets** # MAGIC * `1. Output directory` - shared directory between source and destination workspaces. @@ -20,32 +21,49 @@ # COMMAND ---------- -dbutils.widgets.text("1. Output directory", "") -output_dir = dbutils.widgets.get("1. Output directory") +from mlflow_export_import.bulk import config +import time +import os +from datetime import datetime + +# COMMAND ---------- + +dbutils.widgets.text("output_dir","") +output_dir = dbutils.widgets.get("output_dir") output_dir = output_dir.replace("dbfs:","/dbfs") -dbutils.widgets.multiselect("2. Stages", "Production", ["Production","Staging","Archived","None"]) -stages = dbutils.widgets.get("2. Stages") +dbutils.widgets.multiselect("stages", "Production", ["Production","Staging","Archived","None"]) +stages = dbutils.widgets.get("stages") + +dbutils.widgets.dropdown("export_latest_versions","false",["true","false"]) +export_latest_versions = dbutils.widgets.get("export_latest_versions") == "true" + +dbutils.widgets.text("run_start_date", "") +run_start_date = dbutils.widgets.get("run_start_date") -dbutils.widgets.dropdown("3. Export latest versions","no",["yes","no"]) -export_latest_versions = dbutils.widgets.get("3. Export latest versions") == "yes" +dbutils.widgets.dropdown("export_permissions","false",["true","false"]) +export_permissions = dbutils.widgets.get("export_permissions") == "true" -dbutils.widgets.text("4. Run start date", "") -run_start_date = dbutils.widgets.get("4. Run start date") +dbutils.widgets.text("task_index", "1") +task_index = int(dbutils.widgets.get("task_index")) -dbutils.widgets.dropdown("5. Export permissions","no",["yes","no"]) -export_permissions = dbutils.widgets.get("5. Export permissions") == "yes" +dbutils.widgets.text("num_tasks", "1") +num_tasks = int(dbutils.widgets.get("num_tasks")) -dbutils.widgets.dropdown("6. Export deleted runs","no",["yes","no"]) -export_deleted_runs = dbutils.widgets.get("6. Export deleted runs") == "yes" +dbutils.widgets.text("run_timestamp", "") +run_timestamp = dbutils.widgets.get("run_timestamp") -dbutils.widgets.dropdown("7. 
Export version MLflow model","no",["yes","no"]) # TODO -export_version_model = dbutils.widgets.get("7. Export version MLflow model") == "yes" +dbutils.widgets.text("jobrunid", "") +jobrunid = dbutils.widgets.get("jobrunid") -notebook_formats = get_notebook_formats(8) +dbutils.widgets.text("model_file_name", "") +model_file_name = dbutils.widgets.get("model_file_name") -dbutils.widgets.dropdown("9. Use threads","no",["yes","no"]) -use_threads = dbutils.widgets.get("9. Use threads") == "yes" +dbutils.widgets.dropdown("source_model_registry","unity_catalog",["unity_catalog","workspace_registry"]) +source_model_registry = dbutils.widgets.get("source_model_registry") + +dbutils.widgets.dropdown("Cloud","azure",["azure","aws","gcp"]) +cloud = dbutils.widgets.get("Cloud") if run_start_date=="": run_start_date = None @@ -54,14 +72,52 @@ print("export_latest_versions:", export_latest_versions) print("run_start_date:", run_start_date) print("export_permissions:", export_permissions) -print("export_deleted_runs:", export_deleted_runs) -print("export_version_model:", export_version_model) -print("notebook_formats:", notebook_formats) -print("use_threads:", use_threads) +print("task_index:", task_index) +print("num_tasks:", num_tasks) +print("run_timestamp:", run_timestamp) +print("jobrunid:", jobrunid) +print("model_file_name:", model_file_name) +print("source_model_registry:", source_model_registry) # COMMAND ---------- -assert_widget(output_dir, "1. Output directory") +checkpoint_dir_experiment = os.path.join(output_dir, run_timestamp,"checkpoint", "experiments") +try: + if not os.path.exists(checkpoint_dir_experiment): + os.makedirs(checkpoint_dir_experiment, exist_ok=True) + print(f"checkpoint_dir_experiment: created {checkpoint_dir_experiment}") +except Exception as e: + raise Exception(f"Failed to create directory {checkpoint_dir_experiment}: {e}") + +# COMMAND ---------- + +checkpoint_dir_model = os.path.join(output_dir, run_timestamp,"checkpoint", "models") +try: + if not os.path.exists(checkpoint_dir_model): + os.makedirs(checkpoint_dir_model, exist_ok=True) + print(f"checkpoint_dir_model: created {checkpoint_dir_model}") +except Exception as e: + raise Exception(f"Failed to create directory {checkpoint_dir_model}: {e}") + +# COMMAND ---------- + +output_dir = os.path.join(output_dir, run_timestamp, jobrunid, str(task_index)) +output_dir + +# COMMAND ---------- + +log_path=f"/tmp/exportall_{task_index}.log" +log_path + +# COMMAND ---------- + +# curr_timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S") +# log_path = f"{output_dir}/export_all_{task_index}_{curr_timestamp}.log" + +# COMMAND ---------- + +config.log_path=log_path +config.target_model_registry=source_model_registry # COMMAND ---------- @@ -73,8 +129,38 @@ export_latest_versions = export_latest_versions, run_start_time = run_start_date, export_permissions = export_permissions, - export_deleted_runs = export_deleted_runs, - export_version_model = export_version_model, - notebook_formats = notebook_formats, - use_threads = use_threads + export_deleted_runs = False, + export_version_model = False, + notebook_formats = ['SOURCE'], + use_threads = True, + task_index = task_index, + num_tasks = num_tasks, + checkpoint_dir_experiment = checkpoint_dir_experiment, + checkpoint_dir_model = checkpoint_dir_model, + model_names = model_file_name ) + +# COMMAND ---------- + +time.sleep(10) + +# COMMAND ---------- + +# MAGIC %sh cat /tmp/my.log + +# COMMAND ---------- + +dbfs_log_path = f"{output_dir}/export_all_{task_index}.log" +if 
dbfs_log_path.startswith("/Workspace"): + dbfs_log_path=dbfs_log_path.replace("/Workspace","file:/Workspace") +dbfs_log_path = dbfs_log_path.replace("/dbfs","dbfs:") +dbfs_log_path + +# COMMAND ---------- + + +dbutils.fs.cp(f"file:{log_path}", dbfs_log_path) + +# COMMAND ---------- + +print(dbutils.fs.head(dbfs_log_path)) diff --git a/databricks_notebooks/bulk/Export_All_log_parsing.py b/databricks_notebooks/bulk/Export_All_log_parsing.py new file mode 100644 index 00000000..b7ee62e2 --- /dev/null +++ b/databricks_notebooks/bulk/Export_All_log_parsing.py @@ -0,0 +1,61 @@ +# Databricks notebook source +spark.read.parquet("/checkpoint/models/*.parquet").createOrReplaceTempView("models") + + +# COMMAND ---------- + +# MAGIC %sql +# MAGIC select * from models +# MAGIC -- select count(distinct(model)) from models -- model=1202 +# MAGIC -- select count(distinct(experiment_id)) from models -- experiment=771 + +# COMMAND ---------- + +spark.read.parquet("/checkpoint/experiments").createOrReplaceTempView("experiments") + +# COMMAND ---------- + +# MAGIC %sql +# MAGIC select * from experiments +# MAGIC -- select count(distinct(experiment_id)) from experiments --1774 + +# COMMAND ---------- + + + +# COMMAND ---------- + +from pyspark.sql.functions import regexp_extract, col + +log_df = spark.read.text("dbfs:/mnt/modelnonuc/2025-06-17-Export-jobid-34179827290231/jobrunid-548033559076165/*/export_all_*.log") + +# Define regex pattern +pattern = r"^(\d{2}-\w{3}-\d{2} \d{2}:\d{2}:\d{2}) - (\w+) - \[([^\]:]+):(\d+)\] - (.*)$" + +# Parse fields using regex +parsed_df = log_df.select( + regexp_extract('value', pattern, 1).alias('timestamp'), + regexp_extract('value', pattern, 2).alias('level'), + regexp_extract('value', pattern, 3).alias('module'), + regexp_extract('value', pattern, 4).alias('line_no'), + regexp_extract('value', pattern, 5).alias('message') +) + +parsed_df.createOrReplaceTempView("df") +display(parsed_df) + +# COMMAND ---------- + +# MAGIC %sql +# MAGIC +# MAGIC select line_no,count(*),first(module) from df where level="ERROR" group by line_no + +# COMMAND ---------- + +# MAGIC %sql +# MAGIC +# MAGIC select * from df where level="ERROR" and line_no=78 + +# COMMAND ---------- + + diff --git a/databricks_notebooks/bulk/Export_Registered_Models.py b/databricks_notebooks/bulk/Export_Registered_Models.py index 4915f896..71d83477 100644 --- a/databricks_notebooks/bulk/Export_Registered_Models.py +++ b/databricks_notebooks/bulk/Export_Registered_Models.py @@ -23,57 +23,78 @@ # COMMAND ---------- -dbutils.widgets.text("01. Models", "") -models = dbutils.widgets.get("01. Models") +from mlflow_export_import.bulk import config +import time +import os -dbutils.widgets.text("02. Output directory", "dbfs:/mnt/andre-work/exim/experiments") -output_dir = dbutils.widgets.get("02. Output directory") -output_dir = output_dir.replace("dbfs:","/dbfs") +# COMMAND ---------- -dbutils.widgets.multiselect("03. Stages", "Production", ["Production","Staging","Archived","None"]) -stages = dbutils.widgets.get("03. Stages") +model_file_name = dbutils.widgets.get("model_file_name") -dbutils.widgets.dropdown("04. Export latest versions","no",["yes","no"]) -export_latest_versions = dbutils.widgets.get("04. Export latest versions") == "yes" +output_dir = dbutils.widgets.get("output_dir") +output_dir = output_dir.replace("dbfs:","/dbfs") -dbutils.widgets.dropdown("05. Export all runs","no",["yes","no"]) -export_all_runs = dbutils.widgets.get("05. 
Export all runs") == "yes" +stages = dbutils.widgets.get("stages") -dbutils.widgets.dropdown("06. Export permissions","no",["yes","no"]) -export_permissions = dbutils.widgets.get("06. Export permissions") == "yes" +export_latest_versions = dbutils.widgets.get("export_latest_versions") == "true" -dbutils.widgets.dropdown("07. Export deleted runs","no",["yes","no"]) -export_deleted_runs = dbutils.widgets.get("07. Export deleted runs") == "yes" +export_permissions = dbutils.widgets.get("export_permissions") == "true" -dbutils.widgets.dropdown("08. Export version MLflow model","no",["yes","no"]) # TODO -export_version_model = dbutils.widgets.get("08. Export version MLflow model") == "yes" +export_deleted_runs = dbutils.widgets.get("export_deleted_runs") == "true" -notebook_formats = get_notebook_formats("09") +task_index = int(dbutils.widgets.get("task_index")) -dbutils.widgets.dropdown("10. Use threads","no",["yes","no"]) -use_threads = dbutils.widgets.get("10. Use threads") == "yes" +num_tasks = int(dbutils.widgets.get("num_tasks")) -export_notebook_revision = False -export_all_runs = False +run_timestamp = dbutils.widgets.get("run_timestamp") -import os -os.environ["OUTPUT_DIR"] = output_dir +dbutils.widgets.text("jobrunid", "") +jobrunid = dbutils.widgets.get("jobrunid") -print("models:", models) +print("model_file_name:", model_file_name) print("output_dir:", output_dir) print("stages:", stages) print("export_latest_versions:", export_latest_versions) -print("export_all_runs:", export_all_runs) print("export_permissions:", export_permissions) print("export_deleted_runs:", export_deleted_runs) -print("export_version_model:", export_version_model) -print("notebook_formats:", notebook_formats) -print("use_threads:", use_threads) +print("task_index:", task_index) +print("num_tasks:", num_tasks) +print("run_timestamp:", run_timestamp) +print("jobrunid:", jobrunid) # COMMAND ---------- -assert_widget(models, "1. Models") -assert_widget(output_dir, "2. 
Output directory") +log_path=f"/tmp/my.log" +log_path + +# COMMAND ---------- + +config.log_path=log_path + +# COMMAND ---------- + +checkpoint_dir_experiment = os.path.join(output_dir, run_timestamp,"checkpoint", "experiments") +try: + if not os.path.exists(checkpoint_dir_experiment): + os.makedirs(checkpoint_dir_experiment, exist_ok=True) + print(f"checkpoint_dir_experiment: created {checkpoint_dir_experiment}") +except Exception as e: + raise Exception(f"Failed to create directory {checkpoint_dir_experiment}: {e}") + +# COMMAND ---------- + +checkpoint_dir_model = os.path.join(output_dir, run_timestamp,"checkpoint", "models") +try: + if not os.path.exists(checkpoint_dir_model): + os.makedirs(checkpoint_dir_model, exist_ok=True) + print(f"checkpoint_dir_model: created {checkpoint_dir_model}") +except Exception as e: + raise Exception(f"Failed to create directory {checkpoint_dir_model}: {e}") + +# COMMAND ---------- + +output_dir = os.path.join(output_dir, run_timestamp, jobrunid, str(task_index)) +output_dir # COMMAND ---------- @@ -84,20 +105,49 @@ from mlflow_export_import.bulk.export_models import export_models export_models( - model_names = models, + model_names = model_file_name, output_dir = output_dir, stages = stages, export_latest_versions = export_latest_versions, - export_all_runs = export_all_runs, - export_version_model = export_version_model, + export_all_runs = True, + export_version_model = False, export_permissions = export_permissions, export_deleted_runs = export_deleted_runs, - notebook_formats = notebook_formats, - use_threads = use_threads + notebook_formats = ['SOURCE'], + use_threads = True, + task_index = task_index, + num_tasks = num_tasks, + checkpoint_dir_experiment = checkpoint_dir_experiment, + checkpoint_dir_model = checkpoint_dir_model + ) # COMMAND ---------- +time.sleep(10) + +# COMMAND ---------- + +# MAGIC %sh cat /tmp/my.log + +# COMMAND ---------- + +dbfs_log_path = f"{output_dir}/export_all_{task_index}.log" +if dbfs_log_path.startswith("/Workspace"): + dbfs_log_path=dbfs_log_path.replace("/Workspace","file:/Workspace") +dbfs_log_path = dbfs_log_path.replace("/dbfs","dbfs:") +dbfs_log_path + +# COMMAND ---------- + +dbutils.fs.cp(f"file:{log_path}", dbfs_log_path) + +# COMMAND ---------- + +print(dbutils.fs.head(dbfs_log_path)) + +# COMMAND ---------- + # MAGIC %md ### Display exported files # COMMAND ---------- @@ -125,3 +175,7 @@ # COMMAND ---------- # MAGIC %sh cat $OUTPUT_DIR/experiments/experiments.json + +# COMMAND ---------- + + diff --git a/databricks_notebooks/bulk/Import_Registered_Models.py b/databricks_notebooks/bulk/Import_Registered_Models.py index 3048b498..d8707a1b 100644 --- a/databricks_notebooks/bulk/Import_Registered_Models.py +++ b/databricks_notebooks/bulk/Import_Registered_Models.py @@ -18,41 +18,126 @@ # COMMAND ---------- -dbutils.widgets.text("1. Input directory", "") -input_dir = dbutils.widgets.get("1. Input directory") +from mlflow_export_import.bulk import config +import time +from datetime import datetime +from databricks.sdk import WorkspaceClient + +# COMMAND ---------- + +dbutils.widgets.text("input_dir", "") +input_dir = dbutils.widgets.get("input_dir") input_dir = input_dir.replace("dbfs:","/dbfs") -dbutils.widgets.dropdown("2. Delete model","no",["yes","no"]) -delete_model = dbutils.widgets.get("2. 
Delete model") == "yes" +dbutils.widgets.dropdown("target_model_registry","unity_catalog",["unity_catalog","workspace_registry"]) +target_model_registry = dbutils.widgets.get("target_model_registry") + +dbutils.widgets.text("target_model_catalog", "") +target_model_catalog = dbutils.widgets.get("target_model_catalog") + +dbutils.widgets.text("target_model_schema", "") +target_model_schema = dbutils.widgets.get("target_model_schema") + +dbutils.widgets.dropdown("delete_model","false",["true","false"]) +delete_model = dbutils.widgets.get("delete_model") == "true" -dbutils.widgets.text("3. Model rename file","") -val = dbutils.widgets.get("3. Model rename file") -model_rename_file = val or None +dbutils.widgets.text("model_rename_file","") +val = dbutils.widgets.get("model_rename_file") +model_rename_file = {} if val in ("null", None, "") else val -dbutils.widgets.text("4. Experiment rename file","") -val = dbutils.widgets.get("4. Experiment rename file") -experiment_rename_file = val or None +dbutils.widgets.text("experiment_rename_file","") +val = dbutils.widgets.get("experiment_rename_file") +experiment_rename_file = {} if val in ("null", None, "") else val -dbutils.widgets.dropdown("5. Import permissions","no",["yes","no"]) -import_permissions = dbutils.widgets.get("5. Import permissions") == "yes" +dbutils.widgets.dropdown("import_permissions","false",["true","false"]) +import_permissions = dbutils.widgets.get("import_permissions") == "true" -dbutils.widgets.dropdown("6. Import source tags","no",["yes","no"]) -import_source_tags = dbutils.widgets.get("6. Import source tags") == "yes" +dbutils.widgets.text("task_index", "") +task_index = dbutils.widgets.get("task_index") + +dbutils.widgets.text("notebook_user_mapping_file","") +val = dbutils.widgets.get("notebook_user_mapping_file") +notebook_user_mapping_file = {} if val in ("null", None, "") else val -dbutils.widgets.dropdown("6. Use threads","no",["yes","no"]) -use_threads = dbutils.widgets.get("6. 
Use threads") == "yes" print("input_dir:", input_dir) +print("target_model_registry:", target_model_registry) +print("target_model_catalog:", target_model_catalog) +print("target_model_schema:", target_model_schema) print("delete_model:", delete_model) print("model_rename_file: ", model_rename_file) print("experiment_rename_file:", experiment_rename_file) print("import_permissions:", import_permissions) -print("import_source_tags:", import_source_tags) -print("use_threads:", use_threads) +print("task_index:", task_index) +print("notebook_user_mapping_file:", notebook_user_mapping_file) + +# COMMAND ---------- + +print(f"experiment_rename_file is {experiment_rename_file}") +print(f"experiment_rename_file type is {type(experiment_rename_file)}") + +print(f"model_rename_file is {model_rename_file}") +print(f"model_rename_file type is {type(model_rename_file)}") + +print(f"delete_model is {delete_model}") +print(f"import_permissions is {import_permissions}") + +# COMMAND ---------- + +if not input_dir: + raise ValueError("input_dir cannot be empty") +if not task_index: + raise ValueError("task_index cannot be empty") +if not task_index.isdigit(): + raise ValueError("task_index must be a number") + +# COMMAND ---------- + +if target_model_registry == "workspace_registry": + target_model_catalog = None + target_model_schema = None + +# COMMAND ---------- + +if target_model_registry == "unity_catalog" and (not target_model_catalog or not target_model_schema): + raise ValueError("target_model_catalog and target_model_schema cannot be blank when target_model_registry is 'unity_catalog'") + +# COMMAND ---------- + +if target_model_registry == "unity_catalog": + w = WorkspaceClient() + try: + catalog = w.catalogs.get(name=target_model_catalog) + print(f"Catalog '{target_model_catalog}' exists.") + except Exception as e: + raise ValueError(f"Error - {e}") + try: + schema = w.schemas.get(full_name=f"{target_model_catalog}.{target_model_schema}") + print(f"Schema '{target_model_catalog}.{target_model_schema}' exists.") + except Exception as e: + raise ValueError(f"Error - {e}") # COMMAND ---------- -assert_widget(input_dir, "1. Input directory") +if input_dir.startswith("/Workspace"): + input_dir=input_dir.replace("/Workspace","file:/Workspace") + +input_dir + +# COMMAND ---------- + +log_path=f"/tmp/Import_Registered_Models_{task_index}.log" +log_path + +# COMMAND ---------- + +# curr_timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S") +# log_path = f"{input_dir}/Import_Registered_Models_{task_index}_{curr_timestamp}.log" + +# COMMAND ---------- + +config.log_path=log_path +config.target_model_registry=target_model_registry # COMMAND ---------- @@ -64,6 +149,32 @@ model_renames = model_rename_file, experiment_renames = experiment_rename_file, import_permissions = import_permissions, - import_source_tags = import_source_tags, - use_threads = use_threads + import_source_tags = False, ## Birbal:: Do not set to True. else it will import junk mlflow tags. Setting to False WILL import all source tags by default. 
+ use_threads = True, + target_model_catalog = target_model_catalog, #birbal added + target_model_schema = target_model_schema, #birbal added, + notebook_user_mapping_file = notebook_user_mapping_file #birbal added ) + +# COMMAND ---------- + +time.sleep(10) + +# COMMAND ---------- + +curr_timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + +dbfs_log_path = f"{input_dir}/Import_Registered_Models_{task_index}_{curr_timestamp}.log" +if dbfs_log_path.startswith("/Workspace"): + dbfs_log_path=dbfs_log_path.replace("/Workspace","file:/Workspace") +dbfs_log_path = dbfs_log_path.replace("/dbfs","dbfs:") +dbfs_log_path + +# COMMAND ---------- + + +dbutils.fs.cp(f"file:{log_path}", dbfs_log_path) + +# COMMAND ---------- + +print(dbutils.fs.head(dbfs_log_path)) diff --git a/databricks_notebooks/bulk/master_Export_Registered_Models.py b/databricks_notebooks/bulk/master_Export_Registered_Models.py new file mode 100644 index 00000000..f5da68b0 --- /dev/null +++ b/databricks_notebooks/bulk/master_Export_Registered_Models.py @@ -0,0 +1,146 @@ +# Databricks notebook source +import requests +import json + +# COMMAND ---------- + +dbutils.widgets.removeAll() + +# COMMAND ---------- + +dbutils.widgets.text("01. model_file_name", "") +model_file_name = dbutils.widgets.get("01. model_file_name") + +dbutils.widgets.text("02. Output directory", "/dbfs/mnt/") +output_dir = dbutils.widgets.get("02. Output directory") +output_dir = output_dir.replace("dbfs:","/dbfs") + +dbutils.widgets.multiselect("03. Stages", "Production", ["Production","Staging","Archived","None"]) +stages = dbutils.widgets.get("03. Stages") + +dbutils.widgets.dropdown("04. Export latest versions","no",["yes","no"]) +export_latest_versions = dbutils.widgets.get("04. Export latest versions") == "yes" + +dbutils.widgets.dropdown("05. Export permissions","no",["yes","no"]) +export_permissions = dbutils.widgets.get("05. Export permissions") == "yes" + +dbutils.widgets.dropdown("06. Export deleted runs","no",["yes","no"]) +export_deleted_runs = dbutils.widgets.get("06. Export deleted runs") == "yes" + +dbutils.widgets.text("07. num_tasks", "1") +num_tasks = dbutils.widgets.get("07. 
num_tasks") + + +import os +os.environ["OUTPUT_DIR"] = output_dir + +print("model_file_name:", model_file_name) +print("output_dir:", output_dir) +print("stages:", stages) +print("export_latest_versions:", export_latest_versions) +print("export_permissions:", export_permissions) +print("export_deleted_runs:", export_deleted_runs) +print("num_tasks:", num_tasks) + +# COMMAND ---------- + +if not output_dir: + raise ValueError("output_dir cannot be empty") +if not output_dir.startswith("/dbfs/mnt"): + raise ValueError("output_dir must start with /dbfs/mnt") +if not num_tasks: + raise ValueError("num_tasks cannot be empty") +if not num_tasks.isdigit(): + raise ValueError("num_tasks must be a number") + +# COMMAND ---------- + +if model_file_name: + if not model_file_name.endswith(".txt"): + raise ValueError("model_file_name must end with .txt if not empty") + if not model_file_name.startswith("/dbfs"): + raise ValueError("model_file_name must start with /dbfs if not empty") +else: + model_file_name = "all" + +model_file_name + +# COMMAND ---------- + +DATABRICKS_INSTANCE=dbutils.notebook.entry_point.getDbutils().notebook().getContext().tags().get('browserHostName').getOrElse(None) +DATABRICKS_INSTANCE = f"https://{DATABRICKS_INSTANCE}" +DATABRICKS_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None) + +driver_node_type = "Standard_D4ds_v5" +worker_node_type = "Standard_D4ds_v5" + +def create_multi_task_job_json(): + tasks = [] + + for i in range(1, int(num_tasks)+1): + task = { + "task_key": f"task_{i}", + "description": f"Bir Task for param1 = {i}", + "new_cluster": { + "spark_version": "15.4.x-cpu-ml-scala2.12", + "node_type_id": worker_node_type, + "driver_node_type_id": driver_node_type, + "num_workers": 1, + "data_security_mode": "SINGLE_USER", + "runtime_engine": "STANDARD" + }, + "notebook_task": { + "notebook_path": "/Workspace/Users/birbal.das@databricks.com/AA_sephora/birnew-mlflow-export-import/databricks_notebooks/bulk/Export_Registered_Models", + "base_parameters": { + "model_file_name" : model_file_name, + "output_dir" : output_dir, + "stages" : stages, + "export_latest_versions" : export_latest_versions, + "export_permissions" : export_permissions, + "export_deleted_runs" : export_deleted_runs, + "task_index": i, + "num_tasks" : num_tasks, + "run_timestamp": "{{job.start_time.iso_date}}-ExportModels-jobid-{{job.id}}", + "jobrunid": "jobrunid-{{job.run_id}}" + } + } + } + tasks.append(task) + + job_json = { + "name": "Export_Registered_Models_job", + "tasks": tasks, + "format": "MULTI_TASK" + } + + return job_json + +def submit_databricks_job(): + job_payload = create_multi_task_job_json() + + headers = { + "Authorization": f"Bearer {DATABRICKS_TOKEN}", + "Content-Type": "application/json" + } + + response = requests.post( + f"{DATABRICKS_INSTANCE}/api/2.2/jobs/create", + headers=headers, + data=json.dumps(job_payload) + ) + + if response.status_code == 200: + print("Job submitted successfully.") + print("Response:", response.json()) + else: + print("Error submitting job:", response.status_code, response.text) + + + +# COMMAND ---------- + +submit_databricks_job() + +# COMMAND ---------- + + diff --git a/databricks_notebooks/bulk/master_Export_all.py b/databricks_notebooks/bulk/master_Export_all.py new file mode 100644 index 00000000..f164da7f --- /dev/null +++ b/databricks_notebooks/bulk/master_Export_all.py @@ -0,0 +1,152 @@ +# Databricks notebook source +import requests +import json +from datetime import datetime + +# COMMAND 
---------- + +dbutils.widgets.text("1. Output directory", "") +output_dir = dbutils.widgets.get("1. Output directory") +output_dir = output_dir.replace("dbfs:","/dbfs") + +dbutils.widgets.multiselect("2. Stages", "Production", ["Production","Staging","Archived","None"]) +stages = dbutils.widgets.get("2. Stages") + +dbutils.widgets.dropdown("3. Export latest versions","no",["yes","no"]) +export_latest_versions = dbutils.widgets.get("3. Export latest versions") == "yes" + +dbutils.widgets.text("4. Run start date", "") +run_start_date = dbutils.widgets.get("4. Run start date") + +dbutils.widgets.dropdown("5. Export permissions","no",["yes","no"]) +export_permissions = dbutils.widgets.get("5. Export permissions") == "yes" + +dbutils.widgets.text("6. num_tasks", "") +num_tasks = dbutils.widgets.get("6. num_tasks") + +dbutils.widgets.text("7. model_file_name", "") +model_file_name = dbutils.widgets.get("7. model_file_name") + +dbutils.widgets.dropdown("8. Source model registry","unity_catalog",["unity_catalog","workspace_registry"]) +source_model_registry = dbutils.widgets.get("8. Source model registry") + +dbutils.widgets.dropdown("9. Cloud","azure",["azure","aws","gcp"]) +cloud = dbutils.widgets.get("9. Cloud") + +print("output_dir:", output_dir) +print("stages:", stages) +print("export_latest_versions:", export_latest_versions) +print("run_start_date:", run_start_date) +print("export_permissions:", export_permissions) +print("num_tasks:", num_tasks) +print("model_file_name:", model_file_name) +print("source_model_registry:", source_model_registry) +print("cloud:", cloud) + +# COMMAND ---------- + +if not output_dir: + raise ValueError("output_dir cannot be empty") +if not num_tasks: + raise ValueError("num_tasks cannot be empty") +if not num_tasks.isdigit(): + raise ValueError("num_tasks must be a number") + +# COMMAND ---------- + +if model_file_name: + if not model_file_name.endswith(".txt"): + raise ValueError("model_file_name must end with .txt if not empty") +else: + model_file_name = "all" + +model_file_name + +# COMMAND ---------- + +DATABRICKS_INSTANCE=dbutils.notebook.entry_point.getDbutils().notebook().getContext().tags().get('browserHostName').getOrElse(None) +DATABRICKS_INSTANCE = f"https://{DATABRICKS_INSTANCE}" +DATABRICKS_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None) + +if cloud == "azure": + driver_node_type = "Standard_D4ds_v5" + worker_node_type = "Standard_D4ds_v5" + +if cloud == "aws": + driver_node_type = "rd-fleet.xlarge" + worker_node_type = "rd-fleet.xlarge" + +if cloud == "gcp": + driver_node_type = "n1-standard-4" + worker_node_type = "n1-standard-4" + +def create_multi_task_job_json(): + tasks = [] + for i in range(1, int(num_tasks)+1): + task = { + "task_key": f"task_{i}", + "description": f"Bir Task for param1 = {i}", + "new_cluster": { + "spark_version": "15.4.x-cpu-ml-scala2.12", + "node_type_id": worker_node_type, + "driver_node_type_id": driver_node_type, + "num_workers": 1, + "data_security_mode": "SINGLE_USER", + "runtime_engine": "STANDARD" + }, + "notebook_task": { + "notebook_path": "/Workspace/Users/birbal.das@databricks.com/AA_sephora_notebook_export_fix/birnew-mlflow-export-import/databricks_notebooks/bulk/Export_All", + "base_parameters": { + "output_dir": output_dir, + "stages": stages, + "export_latest_versions": export_latest_versions, + "run_start_date": run_start_date, + "export_permissions": export_permissions, + "task_index": i, + "num_tasks": num_tasks, + "run_timestamp": 
"{{job.start_time.iso_date}}-ExportAll-jobid-{{job.id}}", + "jobrunid": "jobrunid-{{job.run_id}}", + "model_file_name": model_file_name, + "source_model_registry": source_model_registry + } + } + } + tasks.append(task) + + job_json = { + "name": "Export_All_Models", + "tasks": tasks, + "format": "MULTI_TASK" + } + + return job_json + +def submit_databricks_job(): + job_payload = create_multi_task_job_json() + + headers = { + "Authorization": f"Bearer {DATABRICKS_TOKEN}", + "Content-Type": "application/json" + } + + response = requests.post( + f"{DATABRICKS_INSTANCE}/api/2.2/jobs/create", + headers=headers, + data=json.dumps(job_payload) + ) + + if response.status_code == 200: + print("Job submitted successfully.") + print("Response:", response.json()) + else: + print("Error submitting job:", response.status_code, response.text) + + + +# COMMAND ---------- + +submit_databricks_job() + +# COMMAND ---------- + + diff --git a/databricks_notebooks/bulk/master_Import_Registered_Models.py b/databricks_notebooks/bulk/master_Import_Registered_Models.py new file mode 100644 index 00000000..a4faadfe --- /dev/null +++ b/databricks_notebooks/bulk/master_Import_Registered_Models.py @@ -0,0 +1,171 @@ +# Databricks notebook source +import requests +import json +import os +from databricks.sdk import WorkspaceClient + +# COMMAND ---------- + +dbutils.widgets.text("1. Input directory", "") +input_dir = dbutils.widgets.get("1. Input directory") +input_dir = input_dir.replace("dbfs:","/dbfs") + +dbutils.widgets.dropdown("2. Target model registry","unity_catalog",["unity_catalog","workspace_registry"]) +target_model_registry = dbutils.widgets.get("2. Target model registry") + +dbutils.widgets.text("3. Target catalog for model", "") +target_model_catalog = dbutils.widgets.get("3. Target catalog for model") + +dbutils.widgets.text("4. Target schema for model", "") +target_model_schema = dbutils.widgets.get("4. Target schema for model") + +dbutils.widgets.dropdown("5. Delete model","no",["yes","no"]) +delete_model = dbutils.widgets.get("5. Delete model") == "yes" + +dbutils.widgets.text("6. Model rename file","") +val = dbutils.widgets.get("6. Model rename file") +model_rename_file = val or None + +dbutils.widgets.text("7. Experiment rename file","") +val = dbutils.widgets.get("7. Experiment rename file") +experiment_rename_file = val or None + +dbutils.widgets.dropdown("8. Import permissions","no",["yes","no"]) +import_permissions = dbutils.widgets.get("8. Import permissions") == "yes" + +dbutils.widgets.text("9. num_tasks", "") +num_tasks = dbutils.widgets.get("9. num_tasks") + +dbutils.widgets.dropdown("10. Cloud","azure",["azure","aws","gcp"]) +cloud = dbutils.widgets.get("10. Cloud") + +dbutils.widgets.text("11. Notebook user mapping file","") +val = dbutils.widgets.get("11. 
Notebook user mapping file") +notebook_user_mapping_file = val or None + +print("input_dir:", input_dir) +print("target_model_registry:", target_model_registry) +print("target_model_catalog:", target_model_catalog) +print("target_model_schema:", target_model_schema) +print("delete_model:", delete_model) +print("model_rename_file:", model_rename_file) +print("experiment_rename_file:", experiment_rename_file) +print("import_permissions:", import_permissions) +print("num_tasks:", num_tasks) +print("cloud:", cloud) +print("notebook_user_mapping_file:", notebook_user_mapping_file) + +# COMMAND ---------- + +if not input_dir: + raise ValueError("input_dir cannot be empty") +if not num_tasks: + raise ValueError("num_tasks cannot be empty") +if not num_tasks.isdigit(): + raise ValueError("num_tasks must be a number") + +# COMMAND ---------- + +if target_model_registry == "unity_catalog" and (not target_model_catalog or not target_model_schema): + raise ValueError("target_model_catalog and target_model_schema cannot be blank when target_model_registry is 'unity_catalog'") + +# COMMAND ---------- + +if target_model_registry == "unity_catalog": + w = WorkspaceClient() + try: + catalog = w.catalogs.get(name=target_model_catalog) + print(f"Catalog '{target_model_catalog}' exists.") + except Exception as e: + raise ValueError(f"Error - {e}") + + try: + schema = w.schemas.get(full_name=f"{target_model_catalog}.{target_model_schema}") + print(f"Schema '{target_model_catalog}.{target_model_schema}' exists.") + except Exception as e: + raise ValueError(f"Error - {e}") + +# COMMAND ---------- + +DATABRICKS_INSTANCE=dbutils.notebook.entry_point.getDbutils().notebook().getContext().tags().get('browserHostName').getOrElse(None) +DATABRICKS_INSTANCE = f"https://{DATABRICKS_INSTANCE}" +DATABRICKS_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None) + + +if cloud == "azure": + driver_node_type = "Standard_D4ds_v5" + worker_node_type = "Standard_D4ds_v5" + +if cloud == "aws": + driver_node_type = "m4.xlarge" + worker_node_type = "m4.xlarge" + +if cloud == "gcp": + driver_node_type = "n1-standard-4" + worker_node_type = "n1-standard-4" + +def create_multi_task_job_json(): + tasks = [] + for i in range(1, int(num_tasks)+1): + task = { + "task_key": f"task_{i}", + "description": f"Import task for task_index = {i}", + "new_cluster": { + "spark_version": "15.4.x-cpu-ml-scala2.12", + "node_type_id": worker_node_type, + "driver_node_type_id": driver_node_type, + "num_workers": 1, + "data_security_mode": "SINGLE_USER", + "runtime_engine": "STANDARD" + }, + "notebook_task": { + "notebook_path": "/Workspace/Users/birbal.das@databricks.com/AA_final/birnew-mlflow-export-import/databricks_notebooks/bulk/Import_Registered_Models", + "base_parameters": { + "input_dir": os.path.join(input_dir,str(i)), + "target_model_registry": target_model_registry, + "target_model_catalog": target_model_catalog, + "target_model_schema": target_model_schema, + "delete_model": delete_model, + "model_rename_file": model_rename_file, + "experiment_rename_file": experiment_rename_file, + "import_permissions": import_permissions, + "task_index": str(i), + "notebook_user_mapping_file":notebook_user_mapping_file + } + } + } + tasks.append(task) + + job_json = { + "name": "Import_Registered_Models_job", + "tasks": tasks, + "format": "MULTI_TASK" + } + + return job_json + +def submit_databricks_job(): + job_payload = create_multi_task_job_json() + + headers = { + "Authorization": f"Bearer 
{DATABRICKS_TOKEN}", + "Content-Type": "application/json" + } + + response = requests.post( + f"{DATABRICKS_INSTANCE}/api/2.2/jobs/create", + headers=headers, + data=json.dumps(job_payload) + ) + + if response.status_code == 200: + print("Job submitted successfully.") + print("Response:", response.json()) + else: + print("Error submitting job:", response.status_code, response.text) + + + +# COMMAND ---------- + +submit_databricks_job() diff --git a/mlflow_export_import/bulk/bulk_utils.py b/mlflow_export_import/bulk/bulk_utils.py index ac2ee9d5..47910332 100644 --- a/mlflow_export_import/bulk/bulk_utils.py +++ b/mlflow_export_import/bulk/bulk_utils.py @@ -1,8 +1,10 @@ from mlflow_export_import.common.iterators import SearchRegisteredModelsIterator from mlflow_export_import.common.iterators import SearchExperimentsIterator +from mlflow_export_import.common import utils #birbal added +_logger = utils.getLogger(__name__) #birbal added -def _get_list(names, func_list): +def _get_list(names, func_list, task_index=None, num_tasks=None): #birbal updated """ Returns a list of entities specified by the 'names' filter. :param names: Filter of desired list of entities. Can be: "all", comma-delimited string, list of entities or trailing wildcard "*". @@ -11,14 +13,27 @@ def _get_list(names, func_list): """ if isinstance(names, str): if names == "all": - return func_list() + if task_index is None or num_tasks is None: + return func_list() + else: + all_items=func_list() + _logger.info(f"TOTAL MODEL IN THE WORKSPACE REGISTRY IS {len(all_items)}") + return get_subset_list(all_items, task_index, num_tasks) + + elif names.endswith("*"): prefix = names[:-1] return [ x for x in func_list() if x.startswith(prefix) ] else: return names.split(",") - else: + + elif isinstance(names, dict): #birbal added return names + + else: + return get_subset_list(names, task_index, num_tasks) #birbal updated + + def get_experiment_ids(mlflow_client, experiment_ids): @@ -27,7 +42,23 @@ def list_entities(): return _get_list(experiment_ids, list_entities) -def get_model_names(mlflow_client, model_names): +def get_model_names(mlflow_client, model_names,task_index=None,num_tasks=None): #birbal updated def list_entities(): return [ model.name for model in SearchRegisteredModelsIterator(mlflow_client) ] - return _get_list(model_names, list_entities) + return _get_list(model_names, list_entities, task_index, num_tasks) #birbal updated + + +def get_subset_list(fulllist, task_index, num_tasks): + fulllist.sort() + total_items = len(fulllist) + base_size, remainder = divmod(total_items, num_tasks) + + if task_index <= remainder: + start = (base_size + 1) * (task_index - 1) + end = start + base_size + 1 + else: + start = (base_size + 1) * remainder + base_size * (task_index - remainder - 1) + end = start + base_size + + return fulllist[start:end] + diff --git a/mlflow_export_import/bulk/config.py b/mlflow_export_import/bulk/config.py new file mode 100644 index 00000000..4fdc20d1 --- /dev/null +++ b/mlflow_export_import/bulk/config.py @@ -0,0 +1,2 @@ +log_path=None +target_model_registry=None \ No newline at end of file diff --git a/mlflow_export_import/bulk/export_all.py b/mlflow_export_import/bulk/export_all.py index 1ed800c1..a4088e9c 100644 --- a/mlflow_export_import/bulk/export_all.py +++ b/mlflow_export_import/bulk/export_all.py @@ -23,6 +23,10 @@ from mlflow_export_import.client.client_utils import create_mlflow_client from mlflow_export_import.bulk.export_models import export_models from 
mlflow_export_import.bulk.export_experiments import export_experiments +from mlflow_export_import.bulk import bulk_utils +from mlflow_export_import.bulk.model_utils import get_experiments_name_of_models +from mlflow_export_import.common.checkpoint_thread import CheckpointThread, filter_unprocessed_objects #birbal added +from mlflow_export_import.bulk.model_utils import get_experiment_runs_dict_from_names #birbal added ALL_STAGES = "Production,Staging,Archived,None" @@ -39,13 +43,24 @@ def export_all( export_permissions = False, notebook_formats = None, use_threads = False, - mlflow_client = None + mlflow_client = None, + task_index = None, + num_tasks = None, + checkpoint_dir_experiment = None, + checkpoint_dir_model = None, + model_names = None ): mlflow_client = mlflow_client or create_mlflow_client() + + ### + ### + + start_time = time.time() res_models = export_models( mlflow_client = mlflow_client, - model_names = "all", + # model_names = "all", + model_names = model_names, output_dir = output_dir, stages = stages, export_latest_versions = export_latest_versions, @@ -54,25 +69,51 @@ def export_all( export_permissions = export_permissions, export_version_model = export_version_model, notebook_formats = notebook_formats, - use_threads = use_threads + use_threads = use_threads, + task_index = task_index, + num_tasks = num_tasks, + checkpoint_dir_experiment = checkpoint_dir_experiment, + checkpoint_dir_model = checkpoint_dir_model + ) - # Only import those experiments not exported by above export_models() - exported_exp_names = res_models["experiments"]["experiment_names"] - all_exps = SearchExperimentsIterator(mlflow_client) - all_exp_names = [ exp.name for exp in all_exps ] - remaining_exp_names = list(set(all_exp_names) - set(exported_exp_names)) + res_exps = None + if not model_names.endswith('.txt'): + all_exps = SearchExperimentsIterator(mlflow_client) + all_exps = list(set(all_exps)) + all_exp_names = [ exp.name for exp in all_exps ] + _logger.info(f"TOTAL WORKSPACE EXPERIMENT COUNT: {len(all_exp_names)}") - res_exps = export_experiments( - mlflow_client = mlflow_client, - experiments = remaining_exp_names, - output_dir = os.path.join(output_dir,"experiments"), - export_permissions = export_permissions, - run_start_time = run_start_time, - export_deleted_runs = export_deleted_runs, - notebook_formats = notebook_formats, - use_threads = use_threads - ) + all_model_exp_names=get_experiments_name_of_models(mlflow_client,model_names = "all") + + + all_model_exp_names = list(set(all_model_exp_names)) + _logger.info(f"TOTAL WORKSPACE MODEL EXPERIMENT COUNT: {len(all_model_exp_names)}") + + remaining_exp_names = list(set(all_exp_names) - set(all_model_exp_names)) + _logger.info(f"TOTAL WORKSPACE EXPERIMENT COUNT WITH NO MODEL: {len(remaining_exp_names)}") + + + remaining_exp_names_subset = bulk_utils.get_subset_list(remaining_exp_names, task_index, num_tasks) #birbal added + _logger.info(f"TOTAL WORKSPACE EXPERIMENT COUNT WITH NO MODEL FOR TASK_INDEX={task_index}: {len(remaining_exp_names_subset)}") #birbal added + + exps_and_runs = get_experiment_runs_dict_from_names(mlflow_client, remaining_exp_names_subset) #birbal added + + exps_and_runs = filter_unprocessed_objects(checkpoint_dir_experiment,"experiments",exps_and_runs) + _logger.info(f"AFTER FILTERING OUT THE PROCESSED EXPERIMENTS FROM CHECKPOINT, TOTAL REMAINING COUNT: {len(exps_and_runs)}") + + res_exps = export_experiments( + mlflow_client = mlflow_client, + experiments = exps_and_runs, #birbal added + output_dir = 
os.path.join(output_dir,"experiments"), + export_permissions = export_permissions, + run_start_time = run_start_time, + export_deleted_runs = export_deleted_runs, + notebook_formats = notebook_formats, + use_threads = use_threads, + task_index = task_index, #birbal added + checkpoint_dir_experiment = checkpoint_dir_experiment #birbal + ) duration = round(time.time() - start_time, 1) info_attr = { "options": { diff --git a/mlflow_export_import/bulk/export_experiments.py b/mlflow_export_import/bulk/export_experiments.py index 8b9712e2..e788a16f 100644 --- a/mlflow_export_import/bulk/export_experiments.py +++ b/mlflow_export_import/bulk/export_experiments.py @@ -24,6 +24,8 @@ from mlflow_export_import.common import filesystem as _fs from mlflow_export_import.bulk import bulk_utils from mlflow_export_import.experiment.export_experiment import export_experiment +from mlflow_export_import.common.checkpoint_thread import CheckpointThread #birbal added +from queue import Queue #birbal added _logger = utils.getLogger(__name__) @@ -35,7 +37,9 @@ def export_experiments( export_deleted_runs = False, notebook_formats = None, use_threads = False, - mlflow_client = None + mlflow_client = None, + task_index = None, #birbal added + checkpoint_dir_experiment = None #birbal added ): """ :param experiments: Can be either: @@ -50,7 +54,7 @@ def export_experiments( mlflow_client = mlflow_client or mlflow.MlflowClient() start_time = time.time() max_workers = utils.get_threads(use_threads) - experiments_arg = _convert_dict_keys_to_list(experiments) + experiments_arg = _convert_dict_keys_to_list(experiments.keys()) #birbal added if isinstance(experiments,str) and experiments.endswith(".txt"): with open(experiments, "r", encoding="utf-8") as f: @@ -61,6 +65,7 @@ def export_experiments( else: export_all_runs = not isinstance(experiments, dict) experiments = bulk_utils.get_experiment_ids(mlflow_client, experiments) + if export_all_runs: table_data = experiments columns = ["Experiment Name or ID"] @@ -73,87 +78,112 @@ def export_experiments( table_data.append(["Total",num_runs]) columns = ["Experiment ID", "# Runs"] utils.show_table("Experiments",table_data,columns) - _logger.info("") + + if len(experiments) == 0: + _logger.info(f"NO EXPERIMENTS TO PROCESS") + return + + ######## birbal new block + result_queue = Queue() + checkpoint_thread = CheckpointThread(result_queue, checkpoint_dir_experiment, interval=300, batch_size=100) + _logger.info(f"checkpoint_thread started for experiments") + checkpoint_thread.start() + ######## ok_runs = 0 failed_runs = 0 export_results = [] futures = [] - with ThreadPoolExecutor(max_workers=max_workers) as executor: - for exp_id_or_name in experiments: - run_ids = experiments_dct.get(exp_id_or_name, None) - future = executor.submit(_export_experiment, - mlflow_client, - exp_id_or_name, - output_dir, - export_permissions, - notebook_formats, - export_results, - run_start_time, - export_deleted_runs, - run_ids - ) - futures.append(future) - duration = round(time.time() - start_time, 1) - ok_runs = 0 - failed_runs = 0 - experiment_names = [] - for future in futures: - result = future.result() - ok_runs += result.ok_runs - failed_runs += result.failed_runs - experiment_names.append(result.name) - - total_runs = ok_runs + failed_runs - duration = round(time.time() - start_time, 1) - - info_attr = { - "experiment_names": experiment_names, - "options": { - "experiments": experiments_arg, - "output_dir": output_dir, - "export_permissions": export_permissions, - "run_start_time": 
run_start_time, - "export_deleted_runs": export_deleted_runs, - "notebook_formats": notebook_formats, - "use_threads": use_threads - }, - "status": { - "duration": duration, - "experiments": len(experiments), - "total_runs": total_runs, - "ok_runs": ok_runs, - "failed_runs": failed_runs + + try: + with ThreadPoolExecutor(max_workers=max_workers) as executor: + for exp_id_or_name in experiments: + run_ids = experiments_dct.get(exp_id_or_name, []) + future = executor.submit(_export_experiment, + mlflow_client, + exp_id_or_name, + output_dir, + export_permissions, + notebook_formats, + export_results, + run_start_time, + export_deleted_runs, + run_ids, + result_queue #birbal added + ) + futures.append(future) + duration = round(time.time() - start_time, 1) + ok_runs = 0 + failed_runs = 0 + experiment_names = [] + for future in futures: + result = future.result() + ok_runs += result.ok_runs + failed_runs += result.failed_runs + experiment_names.append(result.name) + + total_runs = ok_runs + failed_runs + duration = round(time.time() - start_time, 1) + + info_attr = { + "experiment_names": experiment_names, + "options": { + "experiments": experiments_arg, + "output_dir": output_dir, + "export_permissions": export_permissions, + "run_start_time": run_start_time, + "export_deleted_runs": export_deleted_runs, + "notebook_formats": notebook_formats, + "use_threads": use_threads + }, + "status": { + "duration": duration, + "experiments": len(experiments), + "total_runs": total_runs, + "ok_runs": ok_runs, + "failed_runs": failed_runs + } } - } - mlflow_attr = { "experiments": export_results } + mlflow_attr = { "experiments": export_results } - # NOTE: Make sure we don't overwrite existing experiments.json generated by export_models when being called by export_all. - # Merge this existing experiments.json with the new built by export_experiments. - path = _fs.mk_local_path(os.path.join(output_dir, "experiments.json")) - if os.path.exists(path): - from mlflow_export_import.bulk.experiments_merge_utils import merge_mlflow, merge_info - root = io_utils.read_file(path) - mlflow_attr = merge_mlflow(io_utils.get_mlflow(root), mlflow_attr) - info_attr = merge_info(io_utils.get_info(root), info_attr) - info_attr["note"] = "Merged by export_all from export_models and export_experiments" + # NOTE: Make sure we don't overwrite existing experiments.json generated by export_models when being called by export_all. + # Merge this existing experiments.json with the new built by export_experiments. 
+ path = _fs.mk_local_path(os.path.join(output_dir, "experiments.json")) + if os.path.exists(path): + from mlflow_export_import.bulk.experiments_merge_utils import merge_mlflow, merge_info + root = io_utils.read_file(path) + mlflow_attr = merge_mlflow(io_utils.get_mlflow(root), mlflow_attr) + info_attr = merge_info(io_utils.get_info(root), info_attr) + info_attr["note"] = "Merged by export_all from export_models and export_experiments" - io_utils.write_export_file(output_dir, "experiments.json", __file__, mlflow_attr, info_attr) + io_utils.write_export_file(output_dir, "experiments.json", __file__, mlflow_attr, info_attr) - _logger.info(f"{len(experiments)} experiments exported") - _logger.info(f"{ok_runs}/{total_runs} runs succesfully exported") - if failed_runs > 0: - _logger.info(f"{failed_runs}/{total_runs} runs failed") - _logger.info(f"Duration for {len(experiments)} experiments export: {duration} seconds") + _logger.info(f"{len(experiments)} experiments exported") + _logger.info(f"{ok_runs}/{total_runs} runs succesfully exported") + if failed_runs > 0: + _logger.info(f"{failed_runs}/{total_runs} runs failed") + _logger.info(f"Duration for {len(experiments)} experiments export: {duration} seconds") - return info_attr + return info_attr + + except Exception as e: #birbal added + _logger.error(f"_export_experiment failed: {e}", exc_info=True) + + finally: #birbal added + checkpoint_thread.stop() + checkpoint_thread.join() + _logger.info("Checkpoint thread flushed and terminated for experiments") def _export_experiment(mlflow_client, exp_id_or_name, output_dir, export_permissions, notebook_formats, export_results, - run_start_time, export_deleted_runs, run_ids): + run_start_time, export_deleted_runs, run_ids, result_queue = None): #birbal added result_queue = None ok_runs = -1; failed_runs = -1 exp_name = exp_id_or_name try: + if not run_ids: + _logger.error(f"no runs to export for experiment {exp_id_or_name}. 
Throwing exception to capture in checkpoint file") + raise Exception(f"no runs to export for experiment {exp_id_or_name}") + exp = mlflow_utils.get_experiment(mlflow_client, exp_id_or_name) exp_name = exp.name exp_output_dir = os.path.join(output_dir, exp.experiment_id) @@ -166,7 +196,8 @@ def _export_experiment(mlflow_client, exp_id_or_name, output_dir, export_permiss run_start_time = run_start_time, export_deleted_runs = export_deleted_runs, notebook_formats = notebook_formats, - mlflow_client = mlflow_client + mlflow_client = mlflow_client, + result_queue = result_queue #birbal added ) duration = round(time.time() - start_time, 1) result = { @@ -181,14 +212,23 @@ def _export_experiment(mlflow_client, exp_id_or_name, output_dir, export_permiss except RestException as e: mlflow_utils.dump_exception(e) - err_msg = { **{ "message": "Cannot export experiment", "experiment": exp_name }, ** mlflow_utils.mk_msg_RestException(e) } + err_msg = { **{ "message": "Cannot export experiment", "experiment_id": exp_id_or_name }, ** mlflow_utils.mk_msg_RestException(str(e)) } #birbal type casted _logger.error(err_msg) + err_msg["status"] = "failed" #birbal added + result_queue.put(err_msg) #birbal added except MlflowExportImportException as e: - err_msg = { "message": "Cannot export experiment", "experiment": exp_name, "MlflowExportImportException": e.kwargs } + err_msg = { "message": "Cannot export experiment", "experiment_id": exp_id_or_name, "MlflowExportImportException": str(e.kwargs) } #birbal string casted _logger.error(err_msg) + + err_msg["status"] = "failed" #birbal added + result_queue.put(err_msg) #birbal added except Exception as e: - err_msg = { "message": "Cannot export experiment", "experiment": exp_name, "Exception": e } + err_msg = { "message": "Cannot export experiment", "experiment_id": exp_id_or_name, "Exception": str(e) } #birbal string casted _logger.error(err_msg) + + err_msg["status"] = "failed" #birbal added + result_queue.put(err_msg) #birbal added + return Result(exp_name, ok_runs, failed_runs) diff --git a/mlflow_export_import/bulk/export_models.py b/mlflow_export_import/bulk/export_models.py index 83100ea3..90262a1d 100644 --- a/mlflow_export_import/bulk/export_models.py +++ b/mlflow_export_import/bulk/export_models.py @@ -25,6 +25,8 @@ from mlflow_export_import.bulk import export_experiments from mlflow_export_import.bulk.model_utils import get_experiments_runs_of_models from mlflow_export_import.bulk import bulk_utils +from mlflow_export_import.common.checkpoint_thread import CheckpointThread, filter_unprocessed_objects #birbal added +from queue import Queue #birbal added _logger = utils.getLogger(__name__) @@ -40,7 +42,11 @@ def export_models( export_version_model = False, notebook_formats = None, use_threads = False, - mlflow_client = None + mlflow_client = None, + task_index = None, #birbal + num_tasks = None, #birbal + checkpoint_dir_experiment = None, #birbal + checkpoint_dir_model = None #birbal ): """ :param: model_names: Can be either: @@ -55,19 +61,34 @@ def export_models( model_names = f.read().splitlines() mlflow_client = mlflow_client or create_mlflow_client() - exps_and_runs = get_experiments_runs_of_models(mlflow_client, model_names) - exp_ids = exps_and_runs.keys() + exps_and_runs = get_experiments_runs_of_models(mlflow_client, model_names, task_index, num_tasks) ##birbal return dict of key=exp_id and value=list[run_id] + + total_run_ids = sum(len(run_id_list) for run_id_list in exps_and_runs.values()) #birbal added + _logger.info(f"TOTAL MODEL EXPERIMENTS TO 
EXPORT FOR TASK_INDEX={task_index}: {len(exps_and_runs)} AND TOTAL RUN_IDs TO EXPORT: {total_run_ids}") #birbal added + start_time = time.time() out_dir = os.path.join(output_dir, "experiments") - exps_to_export = exp_ids if export_all_runs else exps_and_runs + + ######Birbal block + exps_and_runs = filter_unprocessed_objects(checkpoint_dir_experiment,"experiments",exps_and_runs) + _logger.info(f"AFTER FILTERING OUT THE PROCESSED EXPERIMENTS FROM CHECKPOINT, REMAINING EXPERIMENTS COUNT TO BE PROCESSED: {len(exps_and_runs)} ") #birbal added + ###### + + # if len(exps_and_runs) == 0: + # _logger.info("NO MODEL EXPERIMENTS TO EXPORT") + # return + + res_exps = export_experiments.export_experiments( mlflow_client = mlflow_client, - experiments = exps_to_export, + experiments = exps_and_runs, #birbal added output_dir = out_dir, export_permissions = export_permissions, export_deleted_runs = export_deleted_runs, notebook_formats = notebook_formats, - use_threads = use_threads + use_threads = use_threads, + task_index = task_index, #birbal added + checkpoint_dir_experiment = checkpoint_dir_experiment #birbal added ) res_models = _export_models( mlflow_client, @@ -79,7 +100,10 @@ def export_models( export_latest_versions = export_latest_versions, export_version_model = export_version_model, export_permissions = export_permissions, - export_deleted_runs = export_deleted_runs + export_deleted_runs = export_deleted_runs, + task_index = task_index, #birbal + num_tasks = num_tasks, #birbal + checkpoint_dir_model = checkpoint_dir_model #birbal ) duration = round(time.time()-start_time, 1) _logger.info(f"Duration for total registered models and versions' runs export: {duration} seconds") @@ -112,60 +136,84 @@ def _export_models( export_latest_versions = False, export_version_model = False, export_permissions = False, - export_deleted_runs = False + export_deleted_runs = False, + task_index = None, + num_tasks = None, + checkpoint_dir_model = None ): max_workers = utils.get_threads(use_threads) start_time = time.time() - model_names = bulk_utils.get_model_names(mlflow_client, model_names) + model_names = bulk_utils.get_model_names(mlflow_client, model_names, task_index, num_tasks) + _logger.info(f"TOTAL MODELS TO EXPORT: {len(model_names)}") #birbal added _logger.info("Models to export:") for model_name in model_names: _logger.info(f" {model_name}") futures = [] - with ThreadPoolExecutor(max_workers=max_workers) as executor: - for model_name in model_names: - dir = os.path.join(output_dir, model_name) - future = executor.submit(export_model, - model_name = model_name, - output_dir = dir, - stages = stages, - export_latest_versions = export_latest_versions, - export_version_model = export_version_model, - export_permissions = export_permissions, - export_deleted_runs = export_deleted_runs, - notebook_formats = notebook_formats, - mlflow_client = mlflow_client, - ) - futures.append(future) - ok_models = [] ; failed_models = [] - for future in futures: - result = future.result() - if result[0]: ok_models.append(result[1]) - else: failed_models.append(result[1]) - duration = round(time.time()-start_time, 1) - info_attr = { - "model_names": model_names, - "stages": stages, - "export_latest_versions": export_latest_versions, - "notebook_formats": notebook_formats, - "use_threads": use_threads, - "output_dir": output_dir, - "num_total_models": len(model_names), - "num_ok_models": len(ok_models), - "num_failed_models": len(failed_models), - "duration": duration, - "failed_models": failed_models - } - mlflow_attr = 
{ - "models": ok_models, - } - io_utils.write_export_file(output_dir, "models.json", __file__, mlflow_attr, info_attr) + ######## birbal new block + model_names = filter_unprocessed_objects(checkpoint_dir_model,"models",model_names) + _logger.info(f"AFTER FILTERING OUT THE PROCESSED MODELS FROM CHECKPOINT, TOTAL REMAINING COUNT: {len(model_names)}") + result_queue = Queue() + checkpoint_thread = CheckpointThread(result_queue, checkpoint_dir_model, interval=300, batch_size=100) + _logger.info(f"checkpoint_thread started for models") + checkpoint_thread.start() + ######## - _logger.info(f"{len(model_names)} models exported") - _logger.info(f"Duration for registered models export: {duration} seconds") + try: #birbal added + with ThreadPoolExecutor(max_workers=max_workers) as executor: + for model_name in model_names: + dir = os.path.join(output_dir, model_name) + future = executor.submit(export_model, + model_name = model_name, + output_dir = dir, + stages = stages, + export_latest_versions = export_latest_versions, + export_version_model = export_version_model, + export_permissions = export_permissions, + export_deleted_runs = export_deleted_runs, + notebook_formats = notebook_formats, + mlflow_client = mlflow_client, + result_queue = result_queue #birbal added + ) + futures.append(future) + ok_models = [] ; failed_models = [] + for future in futures: + result = future.result() + if result[0]: ok_models.append(result[1]) + else: failed_models.append(result[1]) + duration = round(time.time()-start_time, 1) - return info_attr + info_attr = { + "model_names": model_names, + "stages": stages, + "export_latest_versions": export_latest_versions, + "notebook_formats": notebook_formats, + "use_threads": use_threads, + "output_dir": output_dir, + "num_total_models": len(model_names), + "num_ok_models": len(ok_models), + "num_failed_models": len(failed_models), + "duration": duration, + "failed_models": failed_models + } + mlflow_attr = { + "models": ok_models, + } + io_utils.write_export_file(output_dir, "models.json", __file__, mlflow_attr, info_attr) + + _logger.info(f"{len(model_names)} models exported") + _logger.info(f"Duration for registered models export: {duration} seconds") + + return info_attr + + except Exception as e: #birbal added + _logger.error(f"export_model failed: {e}") + + finally: #birbal added + checkpoint_thread.stop() + checkpoint_thread.join() + _logger.info("Checkpoint thread flushed and terminated for models") @click.command() diff --git a/mlflow_export_import/bulk/import_experiments.py b/mlflow_export_import/bulk/import_experiments.py index 100ea9b8..4fb10f18 100644 --- a/mlflow_export_import/bulk/import_experiments.py +++ b/mlflow_export_import/bulk/import_experiments.py @@ -30,7 +30,8 @@ def import_experiments( use_src_user_id = False, experiment_renames = None, use_threads = False, - mlflow_client = None + mlflow_client = None, + notebook_user_mapping = None #birbal ): """ :param input_dir: Source experiment directory. @@ -47,12 +48,19 @@ def import_experiments( """ experiment_renames = rename_utils.get_renames(experiment_renames) + mlflow_client = mlflow_client or mlflow.MlflowClient() - dct = io_utils.read_file_mlflow(os.path.join(input_dir, "experiments.json")) + + try: #birbal + dct = io_utils.read_file_mlflow(os.path.join(input_dir, "experiments.json")) + except Exception as e: + _logger.info(f"'experiments.json' does not exist in {input_dir}. 
NO EXPERIMENTS TO IMPORT") + return [] + exps = dct["experiments"] _logger.info("Importing experiments:") for exp in exps: - _logger.info(f" Importing experiment: {exp}") + _logger.info(f"Importing experiment: {exp}") max_workers = utils.get_threads(use_threads) futures = [] @@ -67,7 +75,8 @@ def import_experiments( import_permissions, import_source_tags, use_src_user_id, - experiment_renames + experiment_renames, + notebook_user_mapping #birbal ) futures.append([exp["id"], run_info_map]) return [ (f[0], f[1].result()) for f in futures ] # materialize the future @@ -79,7 +88,8 @@ def _import_experiment(mlflow_client, import_permissions, import_source_tags, use_src_user_id, - experiment_renames + experiment_renames, + notebook_user_mapping ): """ :return: @@ -87,14 +97,17 @@ def _import_experiment(mlflow_client, - None if error happened """ try: + _logger.info(f"EXPERIMENT BEFORE RENAME {exp_name} ") # birbal exp_name = rename_utils.rename(exp_name, experiment_renames, "experiment") + _logger.info(f"EXPERIMENT AFTER RENAME {exp_name} ") # birbal run_info_map = import_experiment( mlflow_client = mlflow_client, experiment_name = exp_name, input_dir = input_dir, import_permissions = import_permissions, import_source_tags = import_source_tags, - use_src_user_id = use_src_user_id + use_src_user_id = use_src_user_id, + notebook_user_mapping = notebook_user_mapping #birbal ) return run_info_map except Exception as e: diff --git a/mlflow_export_import/bulk/import_models.py b/mlflow_export_import/bulk/import_models.py index 0859ddc9..008ec398 100644 --- a/mlflow_export_import/bulk/import_models.py +++ b/mlflow_export_import/bulk/import_models.py @@ -39,20 +39,27 @@ def import_models( model_renames = None, verbose = False, use_threads = False, - mlflow_client = None + mlflow_client = None, + target_model_catalog = None, #birbal added + target_model_schema = None, #birbal added + notebook_user_mapping_file = None #birbal added ): mlflow_client = mlflow_client or create_mlflow_client() + experiment_renames_original = experiment_renames #birbal experiment_renames = rename_utils.get_renames(experiment_renames) model_renames = rename_utils.get_renames(model_renames) + notebook_user_mapping = rename_utils.get_renames(notebook_user_mapping_file) #birbal start_time = time.time() exp_run_info_map, exp_info = _import_experiments( mlflow_client, input_dir, - experiment_renames, + # experiment_renames, + experiment_renames_original, #birbal import_permissions, import_source_tags, use_src_user_id, - use_threads + use_threads, + notebook_user_mapping #birbal ) run_info_map = _flatten_run_info_map(exp_run_info_map) model_res = _import_models( @@ -65,7 +72,9 @@ def import_models( model_renames, experiment_renames, verbose, - use_threads + use_threads, + target_model_catalog, #birbal added + target_model_schema #birbal added ) duration = round(time.time()-start_time, 1) dct = { "duration": duration, "experiments_import": exp_info, "models_import": model_res } @@ -88,7 +97,8 @@ def _import_experiments(mlflow_client, import_permissions, import_source_tags, use_src_user_id, - use_threads + use_threads, + notebook_user_mapping ): start_time = time.time() @@ -99,7 +109,8 @@ def _import_experiments(mlflow_client, use_src_user_id = use_src_user_id, experiment_renames = experiment_renames, use_threads = use_threads, - mlflow_client = mlflow_client + mlflow_client = mlflow_client, + notebook_user_mapping = notebook_user_mapping #birbal ) duration = round(time.time()-start_time, 1) @@ -132,7 +143,9 @@ def 
_import_models(mlflow_client, model_renames, experiment_renames, verbose, - use_threads + use_threads, + target_model_catalog = None, #birbal added + target_model_schema = None #birbal added ): max_workers = utils.get_threads(use_threads) start_time = time.time() @@ -140,6 +153,11 @@ def _import_models(mlflow_client, models_dir = os.path.join(input_dir, "models") models = io_utils.read_file_mlflow(os.path.join(models_dir,"models.json")) model_names = models["models"] + + if len(model_names) == 0: + _logger.warning(f"No models found in {os.path.join(models_dir,"models.json")}. NO MODELS TO IMPORT") + return {} + all_importer = BulkModelImporter( mlflow_client = mlflow_client, run_info_map = run_info_map, @@ -150,8 +168,14 @@ def _import_models(mlflow_client, with ThreadPoolExecutor(max_workers=max_workers) as executor: for model_name in model_names: + _logger.info(f"model name BEFORE rename : '{model_name}'") #birbal added dir = os.path.join(models_dir, model_name) model_name = rename_utils.rename(model_name, model_renames, "model") + + if target_model_catalog is not None and target_model_schema is not None: #birbal added + model_name=rename_utils.build_full_model_name(target_model_catalog, target_model_schema, model_name) + _logger.info(f"model name AFTER rename : '{model_name}'") #birbal added + executor.submit(all_importer.import_model, model_name = model_name, input_dir = dir, diff --git a/mlflow_export_import/bulk/model_utils.py b/mlflow_export_import/bulk/model_utils.py index cde43a6c..c6699b4b 100644 --- a/mlflow_export_import/bulk/model_utils.py +++ b/mlflow_export_import/bulk/model_utils.py @@ -7,29 +7,64 @@ _logger = utils.getLogger(__name__) -def get_experiments_runs_of_models(client, model_names, show_experiments=False, show_runs=False): +def get_experiments_runs_of_models(client, model_names, task_index=None, num_tasks=None, show_experiments=False, show_runs=False): """ Get experiments and runs to to export. """ - model_names = bulk_utils.get_model_names(client, model_names) - _logger.info(f"{len(model_names)} Models:") + model_names = bulk_utils.get_model_names(client, model_names, task_index, num_tasks) + _logger.info(f"TOTAL MODELS TO EXPORT FOR TASK_INDEX={task_index} : {len(model_names)}") for model_name in model_names: _logger.info(f" {model_name}") exps_and_runs = {} for model_name in model_names: - versions = SearchModelVersionsIterator(client, filter=f"name='{model_name}'") + versions = SearchModelVersionsIterator(client, filter=f"name='{model_name}'" ) + # versions = SearchModelVersionsIterator(client, filter=f""" name="{model_name}" """) #birbal.Changed from "name='{model_name}'" to handle models name with single quote for vr in versions: try: run = client.get_run(vr.run_id) exps_and_runs.setdefault(run.info.experiment_id,[]).append(run.info.run_id) - except mlflow.exceptions.MlflowException as e: - if e.error_code == "RESOURCE_DOES_NOT_EXIST": - _logger.warning(f"run '{vr.run_id}' of version {vr.version} of model '{model_name}' does not exist") - else: - _logger.warning(f"run '{vr.run_id}' of version {vr.version} of model '{model_name}': Error.code: {e.error_code}. 
Error.message: {e.message}") + except Exception as e: #birbal added + _logger.warning(f"Error with run '{vr.run_id}' of version {vr.version} of model '{model_name}': Error: {e}") + if show_experiments: show_experiments_runs_of_models(exps_and_runs, show_runs) return exps_and_runs +def get_experiment_runs_dict_from_names(client, experiment_names): #birbal added entire function + experiment_runs_dict = {} + for name in experiment_names: + experiment = client.get_experiment_by_name(name) + if experiment is not None: + experiment_id = experiment.experiment_id + runs = client.search_runs(experiment_ids=[experiment_id], max_results=1000) + run_ids = [run.info.run_id for run in runs] + experiment_runs_dict[experiment_id] = run_ids + else: + _logger.info(f"Experiment not found: {name}... in bulk->model_utils.py") + + return experiment_runs_dict + + +def get_experiments_name_of_models(client, model_names): + """ Get experiments name to export. """ + model_names = bulk_utils.get_model_names(client, model_names) + experiment_name_list = [] + for model_name in model_names: + versions = SearchModelVersionsIterator(client, filter=f"name='{model_name}'") + # versions = SearchModelVersionsIterator(client, filter=f""" name="{model_name}" """) #birbal. Fix for models name with single quote + for vr in versions: + try: + run = client.get_run(vr.run_id) + experiment_id = run.info.experiment_id + experiment = mlflow.get_experiment(experiment_id) + experiment_name = experiment.name + experiment_name_list.append(experiment_name) + except Exception as e: + _logger.warning(f"run '{vr.run_id}' of version {vr.version} of model '{model_name}': Error: {e}") + + return experiment_name_list + + + def show_experiments_runs_of_models(exps_and_runs, show_runs=False): _logger.info("Experiments for models:") for k,v in exps_and_runs.items(): diff --git a/mlflow_export_import/bulk/rename_utils.py b/mlflow_export_import/bulk/rename_utils.py index 4b381032..a8bcc77e 100644 --- a/mlflow_export_import/bulk/rename_utils.py +++ b/mlflow_export_import/bulk/rename_utils.py @@ -16,8 +16,10 @@ def read_rename_file(path): def rename(name, replacements, object_name="object"): if not replacements: - return name + return name ## birbal :: corrected to return name instead of None. returning None will cause failure for k,v in replacements.items(): + if object_name == "notebook": #birbal added + k = k.removeprefix("/Workspace") if k != "" and name.startswith(k): new_name = name.replace(k,v) _logger.info(f"Renaming {object_name} '{name}' to '{new_name}'") @@ -25,6 +27,7 @@ def rename(name, replacements, object_name="object"): return name + def get_renames(filename_or_dict): if filename_or_dict is None: return None @@ -34,3 +37,9 @@ def get_renames(filename_or_dict): return filename_or_dict else: raise MlflowExportImportException(f"Unknown name replacement type '{type(filename_or_dict)}'", http_status_code=400) + +def build_full_model_name(catalog, schema, model_name): #birbal added + if model_name.count('.') == 2: + return model_name + else: + return f"{catalog}.{schema}.{model_name}" diff --git a/mlflow_export_import/client/client_utils.py b/mlflow_export_import/client/client_utils.py index d8dacbf8..f412bbbd 100644 --- a/mlflow_export_import/client/client_utils.py +++ b/mlflow_export_import/client/client_utils.py @@ -1,5 +1,6 @@ import mlflow from . 
http_client import HttpClient, MlflowHttpClient, DatabricksHttpClient +from mlflow_export_import.bulk import config def create_http_client(mlflow_client, model_name=None): @@ -23,14 +24,27 @@ def create_dbx_client(mlflow_client): return DatabricksHttpClient(creds.host, creds.token) -def create_mlflow_client(): +# def create_mlflow_client(): ##birbal . This is original block. commented out +# """ +# Create MLflowClient. If MLFLOW_TRACKING_URI is UC, then set MlflowClient.tracking_uri to the non-UC variant. +# """ +# registry_uri = mlflow.get_registry_uri() +# if registry_uri: +# tracking_uri = mlflow.get_tracking_uri() +# nonuc_tracking_uri = tracking_uri.replace("databricks-uc","databricks") # NOTE: legacy +# return mlflow.MlflowClient(nonuc_tracking_uri, registry_uri) +# else: +# return mlflow.MlflowClient() + + +def create_mlflow_client(): ##birbal added. Modified version of above. """ Create MLflowClient. If MLFLOW_TRACKING_URI is UC, then set MlflowClient.tracking_uri to the non-UC variant. """ - registry_uri = mlflow.get_registry_uri() - if registry_uri: - tracking_uri = mlflow.get_tracking_uri() - nonuc_tracking_uri = tracking_uri.replace("databricks-uc","databricks") # NOTE: legacy - return mlflow.MlflowClient(nonuc_tracking_uri, registry_uri) - else: - return mlflow.MlflowClient() + target_model_registry=config.target_model_registry # Birbal- this is set at Import_Registered_Models.py + + if not target_model_registry or target_model_registry.lower() == "workspace_registry": + mlflow.set_registry_uri('databricks') + elif target_model_registry.lower() == "unity_catalog": + mlflow.set_registry_uri('databricks-uc') + return mlflow.MlflowClient() \ No newline at end of file diff --git a/mlflow_export_import/common/checkpoint_thread.py b/mlflow_export_import/common/checkpoint_thread.py new file mode 100644 index 00000000..18e87ce6 --- /dev/null +++ b/mlflow_export_import/common/checkpoint_thread.py @@ -0,0 +1,119 @@ +import threading +import time +from datetime import datetime +import os +import pandas as pd +import logging +import pyarrow.dataset as ds +from mlflow_export_import.common import utils +from mlflow_export_import.common import filesystem as _fs +from pyspark.sql import SparkSession +spark = SparkSession.builder.getOrCreate() + +_logger = utils.getLogger(__name__) + +class CheckpointThread(threading.Thread): #birbal added + def __init__(self, queue, checkpoint_dir, interval=300, batch_size=100): + super().__init__() + self.queue = queue + self.checkpoint_dir = checkpoint_dir + self.interval = interval + self.batch_size = batch_size + self._stop_event = threading.Event() + self._buffer = [] + self._last_flush_time = time.time() + + + def run(self): + max_drain_batch = 50 # Max items to pull per loop iteration + while not self._stop_event.is_set() or not self.queue.empty(): + items_fetched = False + drain_count = 0 + + try: + while not self.queue.empty(): + _logger.debug(f"drain_count is {drain_count} and buffer len is {len(self._buffer)}") + item = self.queue.get() + self._buffer.append(item) + drain_count += 1 + if drain_count > max_drain_batch: + _logger.info(f" drain_count > max_drain_batch is TRUE") + items_fetched = True + break + + except Exception: + pass # Queue is empty or bounded + + if items_fetched: + _logger.info(f"[Checkpoint] Fetched {drain_count} items from queue.") + + time_since_last_flush = time.time() - self._last_flush_time + if len(self._buffer) >= self.batch_size or time_since_last_flush >= self.interval: + _logger.info(f"ready to flush to delta") + 
self.flush_to_delta() + self._buffer.clear() + self._last_flush_time = time.time() + + # Final flush + if self._buffer: + self.flush_to_delta() + self._buffer.clear() + + + + def flush_to_delta(self): + _logger.info(f"flush_to_delta called") + try: + df = pd.DataFrame(self._buffer) + if df.empty: + _logger.info(f"[Checkpoint] DataFrame is empty. Skipping write to {self.checkpoint_dir}") + return + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + file_path = os.path.join(self.checkpoint_dir, f"checkpoint_{timestamp}.parquet") + df.to_parquet(file_path, index=False) + _logger.info(f"[Checkpoint] Saved len(df) {len(df)} records to {file_path}") + + except Exception as e: + _logger.error(f"[Checkpoint] Failed to write to {self.checkpoint_dir}: {e}", exc_info=True) + + def stop(self): + self._stop_event.set() + _logger.info("STOP event called.") + + @staticmethod + def load_processed_objects(checkpoint_dir, object_type= None): + try: + dataset = ds.dataset(checkpoint_dir, format="parquet") + df = dataset.to_table().to_pandas() + result_list = [] + + if df.empty: + _logger.warning(f"[Checkpoint] Parquet data is empty in {checkpoint_dir}") + return {} + + if object_type == "experiments": + result_list = df["experiment_id"].dropna().unique().tolist() + + if object_type == "models": + result_list = df["model"].dropna().unique().tolist() + + return result_list + + except Exception as e: + _logger.warning(f"[Checkpoint] Failed to load checkpoint data from {checkpoint_dir}: {e}", exc_info=True) + return None + +def filter_unprocessed_objects(checkpoint_dir,object_type,to_be_processed_objects): #birbal added + processed_objects = CheckpointThread.load_processed_objects(checkpoint_dir,object_type) + if isinstance(to_be_processed_objects, dict): + unprocessed_objects = {k: v for k, v in to_be_processed_objects.items() if k not in processed_objects} + return unprocessed_objects + + if isinstance(to_be_processed_objects, list): + unprocessed_objects = list(set(to_be_processed_objects) - set(processed_objects)) + return unprocessed_objects + + return None + + \ No newline at end of file diff --git a/mlflow_export_import/common/default_logging_config.py b/mlflow_export_import/common/default_logging_config.py index 67a2782e..881c5d6f 100644 --- a/mlflow_export_import/common/default_logging_config.py +++ b/mlflow_export_import/common/default_logging_config.py @@ -9,7 +9,7 @@ "handlers": { "console": { "class": "logging.StreamHandler", - "level": "DEBUG", + "level": "INFO", "formatter": "simple", "stream": "ext://sys.stdout" }, @@ -22,7 +22,7 @@ }, "loggers": { "sampleLogger": { - "level": "DEBUG", + "level": "INFO", "handlers": [ "console" ], @@ -30,7 +30,7 @@ } }, "root": { - "level": "DEBUG", + "level": "INFO", "handlers": [ "console", "file" diff --git a/mlflow_export_import/common/logging_utils.py b/mlflow_export_import/common/logging_utils.py index b656d870..efd038c3 100644 --- a/mlflow_export_import/common/logging_utils.py +++ b/mlflow_export_import/common/logging_utils.py @@ -1,6 +1,8 @@ import os import yaml import logging.config +from mlflow_export_import.bulk import config #birbal added +log_path=config.log_path #birbal added _have_loaded_logging_config = False @@ -10,8 +12,10 @@ def get_logger(name): return logging.getLogger(name) config_path = os.environ.get("MLFLOW_EXPORT_IMPORT_LOG_CONFIG_FILE", None) - output_path = os.environ.get("MLFLOW_EXPORT_IMPORT_LOG_OUTPUT_FILE", None) - log_format = os.environ.get("MLFLOW_EXPORT_IMPORT_LOG_FORMAT", None) + output_path = 
os.environ.get("MLFLOW_EXPORT_IMPORT_LOG_OUTPUT_FILE", log_path) + log_format = os.environ.get("MLFLOW_EXPORT_IMPORT_LOG_FORMAT", "%(asctime)s - %(levelname)s - [%(name)s:%(lineno)d] - %(message)s") #birbal updated + + #print(f"logging_utils.get_logger: config_path: {config_path}") #print(f"logging_utils.get_logger: output_path: {output_path}") #print(f"logging_utils.get_logger: log_format: {log_format}") diff --git a/mlflow_export_import/common/mlflow_utils.py b/mlflow_export_import/common/mlflow_utils.py index d0840969..4f0eeccf 100644 --- a/mlflow_export_import/common/mlflow_utils.py +++ b/mlflow_export_import/common/mlflow_utils.py @@ -31,6 +31,9 @@ def set_experiment(mlflow_client, dbx_client, exp_name, tags=None): if not exp_name.startswith("/"): raise MlflowExportImportException(f"Cannot create experiment '{exp_name}'. Databricks experiment must start with '/'.") create_workspace_dir(dbx_client, os.path.dirname(exp_name)) + + else: ##birbal + _logger.error("utils.calling_databricks is false") try: if not tags: tags = {} tags = utils.create_mlflow_tags_for_databricks_import(tags) diff --git a/mlflow_export_import/common/model_utils.py b/mlflow_export_import/common/model_utils.py index 227a6310..93e6ccbe 100644 --- a/mlflow_export_import/common/model_utils.py +++ b/mlflow_export_import/common/model_utils.py @@ -26,6 +26,11 @@ def model_names_same_registry(name1, name2): not is_unity_catalog_model(name1) and not is_unity_catalog_model(name2) +def model_names_same_registry_nonucsrc_uctgt(name1, name2): + return \ + not is_unity_catalog_model(name1) and is_unity_catalog_model(name2) + + def create_model(client, model_name, model_dct, import_metadata): """ Creates a registered model if it does not exist, and returns the model in either case. @@ -38,6 +43,9 @@ def create_model(client, model_name, model_dct, import_metadata): client.create_registered_model(model_name) _logger.info(f"Created new registered model '{model_name}'") return True + except Exception as e: + # _logger.info(f"except Exception trigger, error for '{model_name}': {e}") + _logger.error(f"FAILED TO CREATE MODEL: '{model_name}': ERROR- {e}") #birbal except RestException as e: if e.error_code != "RESOURCE_ALREADY_EXISTS": raise e @@ -51,6 +59,7 @@ def delete_model(client, model_name, sleep_time=5): """ try: versions = SearchModelVersionsIterator(client, filter=f"name='{model_name}'") + # versions = SearchModelVersionsIterator(client, filter=f""" name="{model_name}" """) #birbal added _logger.info(f"Deleting model '{model_name}' and its versions") for vr in versions: msg = utils.get_obj_key_values(vr, [ "name", "version", "current_stage", "status", "run_id" ]) @@ -60,8 +69,10 @@ def delete_model(client, model_name, sleep_time=5): time.sleep(sleep_time) # Wait until stage transition takes hold client.delete_model_version(model_name, vr.version) client.delete_registered_model(model_name) - except RestException: - pass + # except RestException: #birbal commented out + except Exception as e: + _logger.error(f"Error deleting modfel {model_name}. 
Error: {e}") + def list_model_versions(client, model_name, get_latest_versions=False): @@ -70,6 +81,7 @@ def list_model_versions(client, model_name, get_latest_versions=False): """ if is_unity_catalog_model(model_name): versions = SearchModelVersionsIterator(client, filter=f"name='{model_name}'") + # versions = SearchModelVersionsIterator(client, filter=f""" name="{model_name}" """) #birbal added # JIRA: ES-834105 - UC-ML MLflow search_registered_models and search_model_versions do not return tags and aliases - 2023-08-21 return [ client.get_model_version(vr.name, vr.version) for vr in versions ] else: @@ -77,6 +89,7 @@ def list_model_versions(client, model_name, get_latest_versions=False): return client.get_latest_versions(model_name) else: return list(SearchModelVersionsIterator(client, filter=f"name='{model_name}'")) + # return list(SearchModelVersionsIterator(client, filter=f""" name="{model_name}" """)) #birbal added def search_model_versions(client, filter): @@ -201,11 +214,13 @@ def get_registered_model(mlflow_client, model_name, get_permissions=False): return model -def update_model_permissions(mlflow_client, dbx_client, model_name, perms): +def update_model_permissions(mlflow_client, dbx_client, model_name, perms, nonucsrc_uctgt = False): #birbal added nonucsrc_uctgt parameter if perms: _logger.info(f"Updating permissions for registered model '{model_name}'") - if is_unity_catalog_model(model_name): + if is_unity_catalog_model(model_name) and not nonucsrc_uctgt: #birbal added uc_permissions_utils.update_permissions(mlflow_client, model_name, perms) + elif is_unity_catalog_model(model_name) and nonucsrc_uctgt: #birbal added + uc_permissions_utils.update_permissions_nonucsrc_uctgt(mlflow_client, model_name, perms) else: _model = dbx_client.get("mlflow/databricks/registered-models/get", { "name": model_name }) _model = _model["registered_model_databricks"] diff --git a/mlflow_export_import/common/uc_permissions_utils.py b/mlflow_export_import/common/uc_permissions_utils.py index 2812e04d..0ed2a7ea 100644 --- a/mlflow_export_import/common/uc_permissions_utils.py +++ b/mlflow_export_import/common/uc_permissions_utils.py @@ -26,13 +26,23 @@ def get_effective_permissions(self, model_name): resource = f"unity-catalog/effective-permissions/function/{model_name}" return self.client.get(resource) - def update_permissions(self, model_name, changes): + # def update_permissions(self, model_name, changes): #birbal commented out entire func def + # """ + # https://docs.databricks.com/api/workspace/grants/update + # PATCH /api/2.1/unity-catalog/permissions/{securable_type}/{full_name} + # """ + # resource = f"unity-catalog/permissions/function/{model_name}" + # _logger.info(f"Updating {len(changes.get('changes',[]))} permissions for model '{model_name}'. Resource: {resource}") + # return self.client.patch(resource, changes) + + + def update_permissions(self, object_type, object_name , changes): #birbal modified the above block """ https://docs.databricks.com/api/workspace/grants/update PATCH /api/2.1/unity-catalog/permissions/{securable_type}/{full_name} """ - resource = f"unity-catalog/permissions/function/{model_name}" - _logger.info(f"Updating {len(changes.get('changes',[]))} permissions for model '{model_name}'. Resource: {resource}") + resource = f"unity-catalog/permissions/{object_type}/{object_name}" + _logger.info(f"Updating {len(changes.get('changes',[]))} permissions for {object_type} '{object_name}'. 
Resource: {resource}") return self.client.patch(resource, changes) @@ -51,7 +61,77 @@ def get_permissions(mlflow_client, model_name): return {} -def update_permissions(mlflow_client, model_name, perms, unroll_changes=True): + +def update_permissions_nonucsrc_uctgt(mlflow_client, model_name, perms): ##birbal added this entire func. + + try: + _logger.info(f"BEFORE perms is {perms}") + uc_client = UcPermissionsClient(mlflow_client) + model_perm_dict, catalog_perm_dict, schema_perm_dict = format_perms(perms) + _logger.info(f"AFTER model_perm_dict is {model_perm_dict}") + _logger.info(f"AFTER catalog_perm_dict is {catalog_perm_dict}") + _logger.info(f"AFTER schema_perm_dict is {schema_perm_dict}") + + catalog, schema, model = model_name.split(".") + + uc_client.update_permissions("catalog", catalog, catalog_perm_dict) + uc_client.update_permissions("schema", catalog+"."+schema, schema_perm_dict) + uc_client.update_permissions("function", model_name, model_perm_dict) + except Exception as e: + _logger.error(f"error with update_permissions for model '{model_name}'. Error: {e}") + + +def format_perms(perms): ##birbal added this entire func. + model_perm = [] + catalog_perm=[] + schema_perm=[] + + for acl in perms['permissions']['access_control_list']: + permission_type = "EXECUTE" + for perm in acl['all_permissions']: + if perm.get('permission_level') == 'CAN_MANAGE': + permission_type = "MANAGE" + break + if 'user_name' in acl: + model_perm.append({ + "add": [permission_type], + "principal": acl['user_name'] + }) + catalog_perm.append({ + "add": ["USE_CATALOG"], + "principal": acl['user_name'] + }) + schema_perm.append({ + "add": ["USE_SCHEMA"], + "principal": acl['user_name'] + }) + + if 'group_name' in acl: + group_name = acl['group_name'] + if group_name == "admins": + continue + model_perm.append({ + "add": [permission_type], + "principal": group_name + }) + catalog_perm.append({ + "add": ["USE_CATALOG"], + "principal": group_name + }) + schema_perm.append({ + "add": ["USE_SCHEMA"], + "principal": group_name + }) + + model_perm_dict = {"changes": model_perm} + catalog_perm_dict = {"changes": catalog_perm} + schema_perm_dict = {"changes": schema_perm} + return model_perm_dict,catalog_perm_dict,schema_perm_dict + + + + +def update_permissions(mlflow_client, model_name, perms, unroll_changes=True): uc_client = UcPermissionsClient(mlflow_client) changes = _mk_update_changes(perms) if unroll_changes: # NOTE: in order to prevent batch update to fail because one individual update failed diff --git a/mlflow_export_import/common/ws_permissions_utils.py b/mlflow_export_import/common/ws_permissions_utils.py index 30de5828..b7f5e963 100644 --- a/mlflow_export_import/common/ws_permissions_utils.py +++ b/mlflow_export_import/common/ws_permissions_utils.py @@ -32,7 +32,7 @@ def _call_get(dbx_client, resource): try: return dbx_client.get(resource) except MlflowExportImportException as e: - _logger.error(e.kwargs) + _logger.error(f"Error while retrieving permissions with endpoint {resource}. Most probably due to notebook scoped experiment. Valid experiment type (mlflow.experimentType) should be MLFLOW_EXPERIMENT, not NOTEBOOK. Verify experimentType using SDK. 
Error: {e.kwargs}") #birbal added return {} diff --git a/mlflow_export_import/experiment/export_experiment.py b/mlflow_export_import/experiment/export_experiment.py index ac1656e4..308e7aad 100644 --- a/mlflow_export_import/experiment/export_experiment.py +++ b/mlflow_export_import/experiment/export_experiment.py @@ -35,7 +35,8 @@ def export_experiment( export_deleted_runs = False, check_nested_runs = False, notebook_formats = None, - mlflow_client = None + mlflow_client = None, + result_queue = None #birbal added ): """ :param: experiment_id_or_name: Experiment ID or name. @@ -81,7 +82,7 @@ def export_experiment( for run in runs: _export_run(mlflow_client, run, output_dir, ok_run_ids, failed_run_ids, - run_start_time, run_start_time_str, export_deleted_runs, notebook_formats) + run_start_time, run_start_time_str, export_deleted_runs, notebook_formats, result_queue) #birbal added result_queue num_runs_exported += 1 info_attr = { @@ -114,7 +115,7 @@ def export_experiment( def _export_run(mlflow_client, run, output_dir, ok_run_ids, failed_run_ids, run_start_time, run_start_time_str, - export_deleted_runs, notebook_formats + export_deleted_runs, notebook_formats, result_queue = None #birbal added result_queue ): if run_start_time and run.info.start_time < run_start_time: msg = { @@ -123,14 +124,15 @@ def _export_run(mlflow_client, run, output_dir, "start_time": fmt_ts_millis(run.info.start_time), "run_start_time": run_start_time_str } - _logger.info(f"Not exporting run: {msg}") + _logger.info(f"Not exporting run: {msg} as run.info.start_time < run_start_time ") #birbal updated return is_success = export_run( run_id = run.info.run_id, output_dir = os.path.join(output_dir, run.info.run_id), export_deleted_runs = export_deleted_runs, notebook_formats = notebook_formats, - mlflow_client = mlflow_client + mlflow_client = mlflow_client, + result_queue = result_queue #birbal added ) if is_success: ok_run_ids.append(run.info.run_id) diff --git a/mlflow_export_import/experiment/import_experiment.py b/mlflow_export_import/experiment/import_experiment.py index 72a14c5a..a062dced 100644 --- a/mlflow_export_import/experiment/import_experiment.py +++ b/mlflow_export_import/experiment/import_experiment.py @@ -33,7 +33,8 @@ def import_experiment( import_permissions = False, use_src_user_id = False, dst_notebook_dir = None, - mlflow_client = None + mlflow_client = None, + notebook_user_mapping = None #birbal ): """ :param experiment_name: Destination experiment name. 
@@ -85,7 +86,9 @@ def import_experiment( input_dir = os.path.join(input_dir, src_run_id), dst_notebook_dir = dst_notebook_dir, import_source_tags = import_source_tags, - use_src_user_id = use_src_user_id + use_src_user_id = use_src_user_id, + exp = exp, #birbal added + notebook_user_mapping = notebook_user_mapping #birbal ) dst_run_id = dst_run.info.run_id run_ids_map[src_run_id] = { "dst_run_id": dst_run_id, "src_parent_run_id": src_parent_run_id } diff --git a/mlflow_export_import/model/export_model.py b/mlflow_export_import/model/export_model.py index f965be7e..4be3f64e 100644 --- a/mlflow_export_import/model/export_model.py +++ b/mlflow_export_import/model/export_model.py @@ -23,6 +23,7 @@ from mlflow_export_import.common.timestamp_utils import adjust_timestamps from mlflow_export_import.common import MlflowExportImportException from mlflow_export_import.run.export_run import export_run +import ast #birbal added _logger = utils.getLogger(__name__) @@ -47,7 +48,8 @@ def export_model( export_permissions = False, export_deleted_runs = False, notebook_formats = None, - mlflow_client = None + mlflow_client = None, + result_queue = None #birbal added ): """ :param model_name: Registered model name. @@ -74,31 +76,39 @@ def export_model( opts = Options(stages, versions, export_latest_versions, export_deleted_runs, export_version_model, export_permissions, notebook_formats) try: - _export_model(mlflow_client, model_name, output_dir, opts) + _export_model(mlflow_client, model_name, output_dir, opts, result_queue) #birbal added result_queue return True, model_name except RestException as e: - err_msg = { "model": model_name, "RestException": e.json } + err_msg = { "model": model_name, "RestException": str(e.json) } #birbal string casted if e.json.get("error_code") == "RESOURCE_DOES_NOT_EXIST": _logger.error({ **{"message": "Model does not exist"}, **err_msg}) else: _logger.error({**{"message": "Model cannot be exported"}, **err_msg}) import traceback traceback.print_exc() + err_msg["status"] = "failed" #birbal added + result_queue.put(err_msg) #birbal added return False, model_name except Exception as e: - _logger.error({ "model": model_name, "Exception": e }) - import traceback - traceback.print_exc() + _logger.error({ "model": model_name, "Exception": str(e) }) + err_msg = { "model": model_name, "status": "failed","Exception": str(e) } #birbal string casted + result_queue.put(err_msg) #birbal added + # import traceback + # traceback.print_exc() return False, model_name -def _export_model(mlflow_client, model_name, output_dir, opts): +def _export_model(mlflow_client, model_name, output_dir, opts, result_queue = None): #birbal added result_queue ori_versions = model_utils.list_model_versions(mlflow_client, model_name, opts.export_latest_versions) + _logger.info(f"TOTAL MODELS VERSIONS TO EXPORT FOR MODEL {model_name}: {len(ori_versions)}") #birbal added + msg = "latest" if opts.export_latest_versions else "all" _logger.info(f"Exporting model '{model_name}': found {len(ori_versions)} '{msg}' versions") model = model_utils.get_registered_model(mlflow_client, model_name, opts.export_permissions) - versions, failed_versions = _export_versions(mlflow_client, model, ori_versions, output_dir, opts) + + versions, failed_versions = _export_versions(mlflow_client, model, ori_versions, output_dir, opts, result_queue) #birbal added result_queue + _adjust_model(model, versions) info_attr = { @@ -110,12 +120,23 @@ def _export_model(mlflow_client, model_name, output_dir, opts): "export_latest_versions": 
opts.export_latest_versions, "export_permissions": opts.export_permissions } - _model = { "registered_model": model } - io_utils.write_export_file(output_dir, "model.json", __file__, _model, info_attr) - _logger.info(f"Exported {len(versions)}/{len(ori_versions)} '{msg}' versions for model '{model_name}'") -def _export_versions(mlflow_client, model_dct, versions, output_dir, opts): + try: #birbal added + _model = { "registered_model": model } + io_utils.write_export_file(output_dir, "model.json", __file__, _model, info_attr) + _logger.info(f"Exported {len(versions)}/{len(ori_versions)} '{msg}' versions for model '{model_name}'") + except Exception as e: + ##birbal added this block to resolve ""Object of type ModelVersionDeploymentJobState is not JSON" error + model = str(model).replace("<", "\"").replace(">", "\"") + model = ast.literal_eval(model) + #birbal below end + _model = { "registered_model": model } + io_utils.write_export_file(output_dir, "model.json", __file__, _model, info_attr) + _logger.warning(f"Exported {len(versions)}/{len(ori_versions)} '{msg}' versions for model '{model_name}' AFTER applying the FIX(replaced < and > with double quote). Else it will throw this exception due to the presence of < and > in the dict value of key deployment_job_state. Exception : {str(e)} which will cause issues during MODEL IMPORT") + + +def _export_versions(mlflow_client, model_dct, versions, output_dir, opts, result_queue = None): #birbal added result_queue aliases = model_dct.get("aliases", []) version_aliases = {} [ version_aliases.setdefault(x["version"], []).append(x["alias"]) for x in aliases ] # map of version => its aliases @@ -123,15 +144,16 @@ def _export_versions(mlflow_client, model_dct, versions, output_dir, opts): output_versions, failed_versions = ([], []) for j,vr in enumerate(versions): if not model_utils.is_unity_catalog_model(model_dct["name"]) and vr.current_stage and (len(opts.stages) > 0 and not vr.current_stage.lower() in opts.stages): + _logger.warning(f"MODEL VERSION EXPORT SKIPPED. 
Current model stage:{vr.current_stage} does not match with Input stages passed:{opts.stages} ") #birbal continue if len(opts.versions) > 0 and not vr.version in opts.versions: continue - _export_version(mlflow_client, vr, output_dir, version_aliases.get(vr.version,[]), output_versions, failed_versions, j, len(versions), opts) + _export_version(mlflow_client, vr, output_dir, version_aliases.get(vr.version,[]), output_versions, failed_versions, j, len(versions), opts, result_queue) #birbal added result_queue output_versions.sort(key=lambda x: x["version"], reverse=False) return output_versions, failed_versions -def _export_version(mlflow_client, vr, output_dir, aliases, output_versions, failed_versions, j, num_versions, opts): +def _export_version(mlflow_client, vr, output_dir, aliases, output_versions, failed_versions, j, num_versions, opts, result_queue = None): #birbal added result_queue _output_dir = os.path.join(output_dir, vr.run_id) msg = { "name": vr.name, "version": vr.version, "stage": vr.current_stage, "aliases": aliases } _logger.info(f"Exporting model verson {j+1}/{num_versions}: {msg} to '{_output_dir}'") @@ -148,7 +170,9 @@ def _export_version(mlflow_client, vr, output_dir, aliases, output_versions, fai export_deleted_runs = opts.export_deleted_runs, notebook_formats = opts.notebook_formats, mlflow_client = mlflow_client, - raise_exception = True + raise_exception = True, + result_queue = result_queue, #birbal added + vr = vr #birbal added ) if not run and not opts.export_deleted_runs: failed_msg = { "message": "deleted run", "version": vr_dct } @@ -158,7 +182,7 @@ def _export_version(mlflow_client, vr, output_dir, aliases, output_versions, fai output_versions.append(vr_dct) except RestException as e: - err_msg = { "model": vr.name, "version": vr.version, "run_id": vr.run_id, "RestException": e.json } + err_msg = { "model": vr.name, "version": vr.version, "run_id": vr.run_id, "RestException": str(e.json) } #birbal string casted if e.json.get("error_code") == "RESOURCE_DOES_NOT_EXIST": err_msg = { **{"message": "Version run probably does not exist"}, **err_msg} _logger.error(f"Version export failed (1): {err_msg}") @@ -171,6 +195,16 @@ def _export_version(mlflow_client, vr, output_dir, aliases, output_versions, fai failed_msg = { "version": vr_dct, "RestException": e.json } failed_versions.append(failed_msg) + err_msg["status"] = "failed" #birbal added + if result_queue: + result_queue.put(err_msg) #birbal added + + except Exception as e: + err_msg = { "model": vr.name, "version": vr.version, "run_id": vr.run_id, "status":"failed", "Exception": str(e) } #birbal string casted + if result_queue: + result_queue.put(err_msg) #birbal added + + def _add_metadata_to_version(mlflow_client, vr_dct, run): vr_dct["_run_artifact_uri"] = run.info.artifact_uri diff --git a/mlflow_export_import/model/import_model.py b/mlflow_export_import/model/import_model.py index 451f0e45..a6564cba 100644 --- a/mlflow_export_import/model/import_model.py +++ b/mlflow_export_import/model/import_model.py @@ -116,11 +116,17 @@ def _import_model(self, created_model = model_utils.create_model(self.mlflow_client, model_name, model_dct, True) perms = model_dct.get("permissions") if created_model and self.import_permissions and perms: - if model_utils.model_names_same_registry(model_dct["name"], model_name): - model_utils.update_model_permissions(self.mlflow_client, self.dbx_client, model_name, perms) - else: - _logger.warning(f'Cannot import permissions since models \'{model_dct["name"]}\' and \'{model_name}\' 
must be either both Unity Catalog model names or both Workspace model names.') - + try: #birbal added + if model_utils.model_names_same_registry(model_dct["name"], model_name): + model_utils.update_model_permissions(self.mlflow_client, self.dbx_client, model_name, perms) + elif model_utils.model_names_same_registry_nonucsrc_uctgt(model_dct["name"], model_name): #birbal added + model_utils.update_model_permissions(self.mlflow_client, self.dbx_client, model_name, perms, True) + else: + _logger.warning(f'Cannot import permissions since models \'{model_dct["name"]}\' and \'{model_name}\' must be either both Unity Catalog model names or both Workspace model names.') + except Exception as e: #birbal added + _logger.error(f"Error updating model permission for model {model_name} . Error: {e}") + else: ##birbal added + _logger.info(f"Model permission update skipped for model {model_name}") return model_dct @@ -157,14 +163,17 @@ def import_model(self, :param verbose: Verbose. :return: Model import manifest. """ + model_dct = self._import_model(model_name, input_dir, delete_model) + _logger.info("Importing versions:") for vr in model_dct.get("versions",[]): try: run_id = self._import_run(input_dir, experiment_name, vr) if run_id: self.import_version(model_name, vr, run_id) - except RestException as e: + # except RestException as e: #birbal commented out + except Exception as e: #birbal added msg = { "model": model_name, "version": vr["version"], "src_run_id": vr["run_id"], "experiment": experiment_name, "RestException": str(e) } _logger.error(f"Failed to import model version: {msg}") import traceback @@ -267,6 +276,7 @@ def import_model(self, else: dst_run_id = dst_run_info.run_id exp_name = rename_utils.rename(vr["_experiment_name"], self.experiment_renames, "experiment") + _logger.info(f"RENAMED EXPERIMENT FROM {vr["_experiment_name"]} TO {exp_name}") # birbal try: with MlflowTrackingUriTweak(self.mlflow_client): mlflow.set_experiment(exp_name) diff --git a/mlflow_export_import/model_version/import_model_version.py b/mlflow_export_import/model_version/import_model_version.py index 6d720a53..26d8f574 100644 --- a/mlflow_export_import/model_version/import_model_version.py +++ b/mlflow_export_import/model_version/import_model_version.py @@ -86,6 +86,7 @@ def import_model_version( model_path = _get_model_path(src_vr) dst_source = f"{dst_run.info.artifact_uri}/{model_path}" + dst_vr = _import_model_version( mlflow_client, model_name = model_name, @@ -103,7 +104,7 @@ def _import_model_version( src_vr, dst_run_id, dst_source, - import_stages_and_aliases = True, + import_stages_and_aliases = True, import_source_tags = False ): start_time = time.time() @@ -119,27 +120,32 @@ def _import_model_version( # The client's tracking_uri is not honored. Instead MlflowClient.create_model_version() # seems to use mlflow.tracking_uri internally to download run artifacts for UC models. 
_logger.info(f"Importing model version '{model_name}'") - with MlflowTrackingUriTweak(mlflow_client): - dst_vr = mlflow_client.create_model_version( - name = model_name, - source = dst_source, - run_id = dst_run_id, - description = src_vr.get("description"), - tags = tags - ) - - if import_stages_and_aliases: - for alias in src_vr.get("aliases",[]): - mlflow_client.set_registered_model_alias(dst_vr.name, alias, dst_vr.version) - - if not model_utils.is_unity_catalog_model(model_name): - src_current_stage = src_vr["current_stage"] - if src_current_stage and src_current_stage != "None": # fails for Databricks but not OSS - mlflow_client.transition_model_version_stage(model_name, dst_vr.version, src_current_stage) - - dur = format_seconds(time.time()-start_time) - _logger.info(f"Imported model version '{model_name}/{dst_vr.version}' in {dur}") - return mlflow_client.get_model_version(dst_vr.name, dst_vr.version) + + try: #birbal added + with MlflowTrackingUriTweak(mlflow_client): + dst_vr = mlflow_client.create_model_version( + name = model_name, + source = dst_source, + run_id = dst_run_id, + description = src_vr.get("description"), + tags = tags + ) + + if import_stages_and_aliases: + for alias in src_vr.get("aliases",[]): + mlflow_client.set_registered_model_alias(dst_vr.name, alias, dst_vr.version) + + if not model_utils.is_unity_catalog_model(model_name): + src_current_stage = src_vr["current_stage"] + if src_current_stage and src_current_stage != "None": # fails for Databricks but not OSS + mlflow_client.transition_model_version_stage(model_name, dst_vr.version, src_current_stage) + + dur = format_seconds(time.time()-start_time) + _logger.info(f"Imported model version '{model_name}/{dst_vr.version}' in {dur}") + return mlflow_client.get_model_version(dst_vr.name, dst_vr.version) + + except Exception as e: + _logger.error(f"Error creating model version of {model_name}. Error: {str(e)}") def _get_model_path(src_vr): diff --git a/mlflow_export_import/notebook/download_notebook.py b/mlflow_export_import/notebook/download_notebook.py index 31cdb7a0..1ada3829 100644 --- a/mlflow_export_import/notebook/download_notebook.py +++ b/mlflow_export_import/notebook/download_notebook.py @@ -28,14 +28,15 @@ def _download_notebook(notebook_workspace_path, output_dir, format, extension, r "format": format } if revision_id: - params ["revision"] = { "revision_timestamp": revision_id } # NOTE: not publicly documented + # params ["revision"] = { "revision_timestamp": revision_id } # NOTE: not publicly documented + params ["revision.revision_timestamp"] = revision_id # Birbal. Above longer supports due to change in format. Format change was done sometime in August 2025. notebook_name = os.path.basename(notebook_workspace_path) try: rsp = dbx_client._get("workspace/export", json.dumps(params)) notebook_path = os.path.join(output_dir, f"{notebook_name}.{extension}") io_utils.write_file(notebook_path, rsp.content) except MlflowExportImportException as e: - _logger.warning(f"Cannot download notebook '{notebook_workspace_path}'. {e}") + _logger.error(f"Cannot download notebook '{notebook_workspace_path}'. 
{e}") @click.command() diff --git a/mlflow_export_import/proj.py b/mlflow_export_import/proj.py new file mode 100644 index 00000000..fda30490 --- /dev/null +++ b/mlflow_export_import/proj.py @@ -0,0 +1,23 @@ +from pyspark.sql import SparkSession +from pyspark.sql.functions import col, udf +from pyspark.sql.types import StringType + +import time + +df = spark.read.option("header", "true").csv("dbfs:/databricks-datasets/retail-org/customers/customers.csv") + +@udf(StringType()) +def to_upper_udf(name): + time.sleep(0.01) + return name.upper() if name else None + +df = df.withColumn("name_upper", to_upper_udf(col("customer_name"))) + +df = df.repartition(200) + +df.cache() +df = df.filter(col("state") == "CA") + +results = df.collect() + +df.write.mode("overwrite").csv("dbfs:/tmp/inefficient_output") \ No newline at end of file diff --git a/mlflow_export_import/run/export_run.py b/mlflow_export_import/run/export_run.py index c0549888..dd6a9b90 100644 --- a/mlflow_export_import/run/export_run.py +++ b/mlflow_export_import/run/export_run.py @@ -21,7 +21,7 @@ from mlflow_export_import.client.client_utils import create_mlflow_client, create_dbx_client from mlflow_export_import.notebook.download_notebook import download_notebook -from mlflow.utils.mlflow_tags import MLFLOW_DATABRICKS_NOTEBOOK_PATH +from mlflow.utils.mlflow_tags import MLFLOW_DATABRICKS_NOTEBOOK_PATH, MLFLOW_DATABRICKS_NOTEBOOK_ID #birbal added MLFLOW_DATABRICKS_NOTEBOOK_ID MLFLOW_DATABRICKS_NOTEBOOK_REVISION_ID = "mlflow.databricks.notebookRevisionID" # NOTE: not in mlflow/utils/mlflow_tags.py _logger = utils.getLogger(__name__) @@ -34,7 +34,9 @@ def export_run( skip_download_run_artifacts = False, notebook_formats = None, raise_exception = False, - mlflow_client = None + mlflow_client = None, + result_queue = None, #birbal addedbirbal added + vr = None # ): """ :param run_id: Run ID. 
@@ -56,13 +58,15 @@ def export_run( experiment_id = None try: run = mlflow_client.get_run(run_id) + dst_path = os.path.join(output_dir, "artifacts") msg = { "run_id": run.info.run_id, "dst_path": dst_path } if run.info.lifecycle_stage == "deleted" and not export_deleted_runs: _logger.warning(f"Not exporting run '{run.info.run_id} because its lifecycle_stage is '{run.info.lifecycle_stage}'") return None experiment_id = run.info.experiment_id - msg = { "run_id": run.info.run_id, "lifecycle_stage": run.info.lifecycle_stage, "experiment_id": run.info.experiment_id } + experiment = mlflow_client.get_experiment(experiment_id) + msg = { "run_id": run.info.run_id, "experiment_id": run.info.experiment_id, "experiment_name": experiment.name} #birbal removed lifecycle_stage tags = run.data.tags tags = dict(sorted(tags.items())) @@ -91,6 +95,7 @@ def export_run( run_id = run.info.run_id, dst_path = _fs.mk_local_path(dst_path), tracking_uri = mlflow_client._tracking_client.tracking_uri) + notebook = tags.get(MLFLOW_DATABRICKS_NOTEBOOK_PATH) # export notebook as artifact @@ -98,22 +103,48 @@ def export_run( if len(notebook_formats) > 0: _export_notebook(dbx_client, output_dir, notebook, notebook_formats, run, fs) elif len(notebook_formats) > 0: - _logger.warning(f"No notebooks to export for run '{run_id}' since tag '{MLFLOW_DATABRICKS_NOTEBOOK_PATH}' is not set.") + _logger.error(f"No notebooks to export for run '{run_id}' since run tag '{MLFLOW_DATABRICKS_NOTEBOOK_PATH}' is not set.") dur = format_seconds(time.time()-start_time) _logger.info(f"Exported run in {dur}: {msg}") + + msg["status"] = "success" #birbal added + if vr: + msg["model"] = vr.name #birbal added + msg["version"] = vr.version #birbal added + msg["stage"] = vr.current_stage #birbal added + if result_queue: + result_queue.put(msg) #birbal added return run except RestException as e: if raise_exception: raise e - err_msg = { "run_id": run_id, "experiment_id": experiment_id, "RestException": e.json } + err_msg = { "run_id": run_id, "experiment_id": experiment_id, "RestException": str(e.json) } #birbal string casted + + err_msg["status"] = "failed" #birbal added + if vr: + err_msg["model"] = vr.name #birbal added + err_msg["version"] = vr.version #birbal added + err_msg["stage"] = vr.current_stage #birbal added _logger.error(f"Run export failed (1): {err_msg}") + if result_queue: + result_queue.put(err_msg) #birbal added + return None except Exception as e: if raise_exception: raise e - err_msg = { "run_id": run_id, "experiment_id": experiment_id, "Exception": e } - _logger.error(f"Run export failed (2): {err_msg}") + err_msg = { "run_id": run_id, "experiment_id": experiment_id, "Exception": str(e) } #birbal string casted + + err_msg["status"] = "failed" #birbal added + if vr: + err_msg["model"] = vr.name #birbal added + err_msg["version"] = vr.version #birbal added + err_msg["stage"] = vr.current_stage #birbal added + _logger.error(f"Run export failed (2): {err_msg}") + if result_queue: + result_queue.put(err_msg) #birbal added + traceback.print_exc() return None @@ -133,9 +164,9 @@ def _export_notebook(dbx_client, output_dir, notebook, notebook_formats, run, fs notebook_dir = os.path.join(output_dir, "artifacts", "notebooks") fs.mkdirs(notebook_dir) revision_id = run.data.tags.get(MLFLOW_DATABRICKS_NOTEBOOK_REVISION_ID) - if not revision_id: - _logger.warning(f"Cannot download notebook '{notebook}' for run '{run.info.run_id}' since tag '{MLFLOW_DATABRICKS_NOTEBOOK_REVISION_ID}' does not exist. 
Notebook is probably a Git Repo notebook.") - return + # if not revision_id: #birbal commented out. If not, it simply skips the notebook download due to missing notebook versionID which shouldn't be the case. + # _logger.warning(f"Cannot download notebook '{notebook}' for run '{run.info.run_id}' since tag '{MLFLOW_DATABRICKS_NOTEBOOK_REVISION_ID}' does not exist. Notebook is probably a Git Repo notebook.") + # return manifest = { MLFLOW_DATABRICKS_NOTEBOOK_PATH: run.data.tags[MLFLOW_DATABRICKS_NOTEBOOK_PATH], MLFLOW_DATABRICKS_NOTEBOOK_REVISION_ID: revision_id, diff --git a/mlflow_export_import/run/import_run.py b/mlflow_export_import/run/import_run.py index ad66a1ae..193711fe 100644 --- a/mlflow_export_import/run/import_run.py +++ b/mlflow_export_import/run/import_run.py @@ -23,6 +23,10 @@ from mlflow_export_import.client.client_utils import create_mlflow_client, create_dbx_client, create_http_client from . import run_data_importer from . import run_utils +import mlflow.utils.databricks_utils as db_utils #birbal added +import requests #birbal added +import json +from mlflow_export_import.bulk import rename_utils #birbal added _logger = utils.getLogger(__name__) @@ -33,7 +37,9 @@ def import_run( dst_notebook_dir = None, use_src_user_id = False, mlmodel_fix = True, - mlflow_client = None + mlflow_client = None, + exp = None, + notebook_user_mapping = None #birbal ): """ Imports a run into the specified experiment. @@ -64,7 +70,10 @@ def _mk_ex(src_run_dct, dst_run_id, exp_name): _logger.info(f"Importing run from '{input_dir}'") - exp = mlflow_utils.set_experiment(mlflow_client, dbx_client, experiment_name) + # exp = mlflow_utils.set_experiment(mlflow_client, dbx_client, experiment_name) + if not exp: #birbal added + exp = mlflow_utils.set_experiment(mlflow_client, dbx_client, experiment_name) + src_run_path = os.path.join(input_dir, "run.json") src_run_dct = io_utils.read_file_mlflow(src_run_path) in_databricks = "DATABRICKS_RUNTIME_VERSION" in os.environ @@ -72,7 +81,7 @@ def _mk_ex(src_run_dct, dst_run_id, exp_name): run = mlflow_client.create_run(exp.experiment_id) run_id = run.info.run_id try: - run_data_importer.import_run_data( + run_data_importer.import_run_data( mlflow_client, src_run_dct, run_id, @@ -99,47 +108,116 @@ def _mk_ex(src_run_dct, dst_run_id, exp_name): traceback.print_exc() raise MlflowExportImportException(e, f"Importing run {run_id} of experiment '{exp.name}' failed") - if utils.calling_databricks() and dst_notebook_dir: - _upload_databricks_notebook(dbx_client, input_dir, src_run_dct, dst_notebook_dir) - + if utils.calling_databricks(): #birbal added entire block + creds = mlflow_client._tracking_client.store.get_host_creds() + src_webappURL = src_run_dct.get("tags", {}).get("mlflow.databricks.webappURL", None) + _logger.info(f"src_webappURL is {src_webappURL} and target webappURL is {creds.host}") + if creds.host != src_webappURL: + _logger.info(f"NOTEBOOK IMPORT STARTED") + _upload_databricks_notebook(mlflow_client, dbx_client, input_dir, src_run_dct, dst_notebook_dir,run_id,notebook_user_mapping) #birbal added.. 
passed mlflow_client + else: + _logger.info(f"NOTEBOOK IMPORT SKIPPED DUE TO SAME WORKSPACE") res = (run, src_run_dct["tags"].get(MLFLOW_PARENT_RUN_ID, None)) _logger.info(f"Imported run '{run.info.run_id}' into experiment '{experiment_name}'") return res -def _upload_databricks_notebook(dbx_client, input_dir, src_run_dct, dst_notebook_dir): - run_id = src_run_dct["info"]["run_id"] +def _upload_databricks_notebook(mlflow_client, dbx_client, input_dir, src_run_dct, dst_notebook_dir,run_id,notebook_user_mapping): #birbal added tag_key = "mlflow.databricks.notebookPath" src_notebook_path = src_run_dct["tags"].get(tag_key,None) if not src_notebook_path: - _logger.warning(f"No tag '{tag_key}' for run_id '{run_id}'") + _logger.warning(f"No tag '{tag_key}' for run_id '{run_id}'. NOTEBOOK IMPORT SKIPPED") return + notebook_name = os.path.basename(src_notebook_path) + try: #birbal added entire try/except block to solve the issue where the source user doesn't exists in target workspace + dst_notebook_dir = os.path.dirname(src_notebook_path) + mlflow_utils.create_workspace_dir(dbx_client, dst_notebook_dir) + + except Exception as e: + _logger.warning(f"Failed to create directory '{dst_notebook_dir}'. This is most probably because the user doesn't exist in target workspace. Checking notebook user mapping file...") + if notebook_user_mapping: + _logger.info(f"notebook_user_mapping is {notebook_user_mapping}") + _logger.info(f"src_notebook_path BEFORE RENAME {src_notebook_path}") + src_notebook_path = rename_utils.rename(src_notebook_path, notebook_user_mapping, "notebook") + _logger.info(f"src_notebook_path AFTER RENAME {src_notebook_path}") + dst_notebook_dir = os.path.dirname(src_notebook_path) + mlflow_utils.create_workspace_dir(dbx_client, dst_notebook_dir) + else: + _logger.error(f"Notebook couldn't be imported because the target directory '{dst_notebook_dir}' could not be created, and no notebook user mapping file was provided as input") + raise e + format = "source" - notebook_path = _fs.make_local_path(os.path.join(input_dir,"artifacts","notebooks",f"{notebook_name}.{format}")) - if not _fs.exists(notebook_path): - _logger.warning(f"Source '{notebook_path}' does not exist for run_id '{run_id}'") - return + + notebook_path = os.path.join(input_dir,"artifacts","notebooks",f"{notebook_name}.{format}") #birbal added with open(notebook_path, "r", encoding="utf-8") as f: content = f.read() - dst_notebook_path = os.path.join(dst_notebook_dir, notebook_name) + dst_notebook_path = src_notebook_path + "_notebook" ##birbal added _notebook to fix issue with Notebook scoped experiment + + content = base64.b64encode(content.encode()).decode("utf-8") - data = { + payload = { "path": dst_notebook_path, "language": "PYTHON", "format": format, "overwrite": True, "content": content } - mlflow_utils.create_workspace_dir(dbx_client, dst_notebook_dir) try: _logger.info(f"Importing notebook '{dst_notebook_path}' for run {run_id}") - dbx_client._post("workspace/import", data) - except MlflowExportImportException as e: - _logger.warning(f"Cannot save notebook '{dst_notebook_path}'. {e}") + create_notebook(mlflow_client,payload,run_id) #birbal added + update_notebook_lineage(mlflow_client,run_id,dst_notebook_path) #birbal added + + except Exception as e: #birbal added + _logger.error(f"Error importing notebook '{dst_notebook_path}' for run_id {run_id}. 
Error - {e}") + + +def create_notebook(mlflow_client,payload,run_id): #birbal added this entire block + + creds = mlflow_client._tracking_client.store.get_host_creds() + host = creds.host + token = creds.token + + headers = { + "Authorization": f"Bearer {token}" + } + response = requests.post( + f"{host}/api/2.0/workspace/import", + headers=headers, + json=payload + ) + if response.status_code == 200: + _logger.info(f"Imported notebook for run_id {run_id} using workspace/import api") + else: + _logger.error(f"workspace/import api failed to import notebook for run_id {run_id}. response.text is {response.text}") + + + +def update_notebook_lineage(mlflow_client,run_id,dst_notebook_path): #birbal added this entire block + host=db_utils.get_workspace_url() + token=db_utils.get_databricks_host_creds().token + + HEADERS = { + 'Authorization': f'Bearer {token}' + } + + get_url = f'{host}/api/2.0/workspace/get-status' + params = {'path': dst_notebook_path} + response = requests.get(get_url, headers=HEADERS, params=params) + response.raise_for_status() + notebook_object = response.json() + notebook_id = notebook_object.get("object_id") + + creds = mlflow_client._tracking_client.store.get_host_creds() + mlflow_client.set_tag(run_id, "mlflow.source.name", dst_notebook_path) + mlflow_client.set_tag(run_id, "mlflow.source.type", "NOTEBOOK") + mlflow_client.set_tag(run_id, "mlflow.databricks.notebookID", notebook_id) + mlflow_client.set_tag(run_id, "mlflow.databricks.workspaceURL", host) + mlflow_client.set_tag(run_id, "mlflow.databricks.notebookPath", dst_notebook_path) + mlflow_client.set_tag(run_id, "mlflow.databricks.webappURL", creds.host) def _import_inputs(http_client, src_run_dct, run_id): inputs = src_run_dct.get("inputs") diff --git a/mlflow_export_import/run/run_data_importer.py b/mlflow_export_import/run/run_data_importer.py index bc6cec1a..e6876457 100644 --- a/mlflow_export_import/run/run_data_importer.py +++ b/mlflow_export_import/run/run_data_importer.py @@ -13,6 +13,7 @@ from mlflow_export_import.common.source_tags import mk_source_tags_mlflow_tag, mk_source_tags + def _log_data(run_dct, run_id, batch_size, get_data, log_data, args_get_data=None): metadata = get_data(run_dct, args_get_data) num_batches = int(math.ceil(len(metadata) / batch_size)) @@ -24,7 +25,6 @@ def _log_data(run_dct, run_id, batch_size, get_data, log_data, args_get_data=Non log_data(run_id, batch) res = res + batch - def _log_params(client, run_dct, run_id, batch_size): def get_data(run_dct, args): return [ Param(k,v) for k,v in run_dct["params"].items() ] @@ -59,6 +59,7 @@ def get_data(run_dct, args): tags = { **tags, **source_mlflow_tags, **source_info_tags } tags = utils.create_mlflow_tags_for_databricks_import(tags) # remove "mlflow" tags that cannot be imported into Databricks tags = [ RunTag(k,v) for k,v in tags.items() ] + if not in_databricks: utils.set_dst_user_id(tags, args["src_user_id"], args["use_src_user_id"]) return tags @@ -73,13 +74,14 @@ def log_data(run_id, tags): } _log_data(run_dct, run_id, batch_size, get_data, log_data, args_get) + def import_run_data(mlflow_client, run_dct, run_id, import_source_tags, src_user_id, use_src_user_id, in_databricks): from mlflow.utils.validation import MAX_PARAMS_TAGS_PER_BATCH, MAX_METRICS_PER_BATCH _log_params(mlflow_client, run_dct, run_id, MAX_PARAMS_TAGS_PER_BATCH) _log_metrics(mlflow_client, run_dct, run_id, MAX_METRICS_PER_BATCH) - _log_tags( + _log_tags( mlflow_client, run_dct, run_id, @@ -88,8 +90,7 @@ def import_run_data(mlflow_client, run_dct, run_id, 
diff --git a/mlflow_export_import/run/run_data_importer.py b/mlflow_export_import/run/run_data_importer.py
index bc6cec1a..e6876457 100644
--- a/mlflow_export_import/run/run_data_importer.py
+++ b/mlflow_export_import/run/run_data_importer.py
@@ -13,6 +13,7 @@
 from mlflow_export_import.common.source_tags import mk_source_tags_mlflow_tag, mk_source_tags
 
 
+
 def _log_data(run_dct, run_id, batch_size, get_data, log_data, args_get_data=None):
     metadata = get_data(run_dct, args_get_data)
     num_batches = int(math.ceil(len(metadata) / batch_size))
@@ -24,7 +25,6 @@
         log_data(run_id, batch)
         res = res + batch
 
-
 def _log_params(client, run_dct, run_id, batch_size):
     def get_data(run_dct, args):
         return [ Param(k,v) for k,v in run_dct["params"].items() ]
@@ -59,6 +59,7 @@ def get_data(run_dct, args):
         tags = { **tags, **source_mlflow_tags, **source_info_tags }
         tags = utils.create_mlflow_tags_for_databricks_import(tags) # remove "mlflow" tags that cannot be imported into Databricks
         tags = [ RunTag(k,v) for k,v in tags.items() ]
+
         if not in_databricks:
             utils.set_dst_user_id(tags, args["src_user_id"], args["use_src_user_id"])
         return tags
@@ -73,13 +74,14 @@ def log_data(run_id, tags):
     }
     _log_data(run_dct, run_id, batch_size, get_data, log_data, args_get)
 
+
 def import_run_data(mlflow_client, run_dct, run_id, import_source_tags, src_user_id, use_src_user_id, in_databricks):
     from mlflow.utils.validation import MAX_PARAMS_TAGS_PER_BATCH, MAX_METRICS_PER_BATCH
     _log_params(mlflow_client, run_dct, run_id, MAX_PARAMS_TAGS_PER_BATCH)
     _log_metrics(mlflow_client, run_dct, run_id, MAX_METRICS_PER_BATCH)
-    _log_tags(
+    _log_tags( 
         mlflow_client,
         run_dct,
         run_id,
@@ -88,8 +90,7 @@ def import_run_data(mlflow_client, run_dct, run_id, import_source_tags, src_user
         in_databricks,
         src_user_id,
         use_src_user_id
-)
-
+    )
 
 if __name__ == "__main__":
     import sys
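Aside (not part of the patch): the run_data_importer changes above are whitespace-only; the module's behavior of logging params, metrics, and tags in chunks that respect MLflow's per-batch limits is unchanged. A minimal standalone sketch of that batching pattern, with an illustrative helper name:

import math
from mlflow import MlflowClient
from mlflow.entities import Param
from mlflow.utils.validation import MAX_PARAMS_TAGS_PER_BATCH

def log_params_in_batches(client: MlflowClient, run_id: str, params: dict):
    # Split the params into chunks no larger than MLflow's per-batch limit
    # and log each chunk with log_batch(), mirroring the _log_data() pattern.
    items = [Param(k, str(v)) for k, v in params.items()]
    num_batches = math.ceil(len(items) / MAX_PARAMS_TAGS_PER_BATCH)
    for i in range(num_batches):
        batch = items[i * MAX_PARAMS_TAGS_PER_BATCH:(i + 1) * MAX_PARAMS_TAGS_PER_BATCH]
        client.log_batch(run_id, params=batch)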