Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/pytorch/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@ fastapi # for `ServableModuleValidator` # not setting version as re-defined in
uvicorn # for `ServableModuleValidator` # not setting version as re-defined in App

tensorboard >=2.11, <2.21.0 # for `TensorBoardLogger`
mlflow >=3.0.0, <4.0.0 # for `MLFlowLogger`

torch-tensorrt; platform_system == "Linux" and python_version >= "3.12"
5 changes: 2 additions & 3 deletions src/lightning/pytorch/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,9 @@ def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None:
params = _flatten_dict(params)

from mlflow.entities import Param
from mlflow.utils.validation import MAX_PARAM_VAL_LENGTH

# Truncate parameter values to 250 characters.
# TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()]
params_list = [Param(key=k, value=str(v)[:MAX_PARAM_VAL_LENGTH]) for k, v in params.items()]

# Log in chunks of 100 parameters (the maximum allowed by MLflow).
for idx in range(0, len(params_list), 100):
Expand Down
7 changes: 4 additions & 3 deletions tests/tests_pytorch/loggers/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,16 +317,17 @@ def test_mlflow_logger_no_synchronous_support(mlflow_mock, tmp_path):

@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_with_long_param_value(mlflow_mock, tmp_path):
"""Test that long parameter values are truncated to 250 characters."""
"""Test that long parameter values are truncated using MLflow's MAX_PARAM_VAL_LENGTH."""
from mlflow.utils.validation import MAX_PARAM_VAL_LENGTH

def _check_value_length(value, *args, **kwargs):
assert len(value) <= 250
assert len(value) <= MAX_PARAM_VAL_LENGTH

mlflow_mock.entities.Param.side_effect = _check_value_length

logger = MLFlowLogger("test", save_dir=str(tmp_path))

params = {"test": "test_param" * 50}
params = {"test": "test_param" * 1000}
logger.log_hyperparams(params)

# assert_called_once_with() won't properly check the parameter value.
Expand Down
Loading