From 3f60ed2092a2fc6739bb73444bf8c59c8da50467 Mon Sep 17 00:00:00 2001 From: Arjun Dinesh Jagdale <142811259+ArjunJagdale@users.noreply.github.com> Date: Sat, 28 Jun 2025 14:09:02 +0530 Subject: [PATCH] feat: fallback to `load_from_disk` when loading saved directories in `load_dataset` ### Related Issue Fixes #7503 ### What does this PR do? This PR introduces a fallback mechanism in `load_dataset()` that detects when the input `path` points to a dataset previously saved using `save_to_disk()`, and automatically redirects to `load_from_disk(path)`. Previously, calling `load_dataset("/path/to/saved/dataset")` would misinterpret the local structure and return incorrect metadata rows. Now: ```python # Before: unexpected result ds = load_dataset("my_saved_dataset") # Misinterprets metadata # After: correct behavior ds = load_dataset("my_saved_dataset") # Auto-switches to load_from_disk() ``` --- src/datasets/load.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/datasets/load.py b/src/datasets/load.py index bc2b0e679b6..f343bed0a58 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -63,6 +63,7 @@ from .fingerprint import Hasher from .info import DatasetInfo, DatasetInfosDict from .iterable_dataset import IterableDataset +from .load_from_disk import load_from_disk from .naming import camelcase_to_snakecase, snakecase_to_camelcase from .packaged_modules import ( _EXTENSION_TO_MODULE, @@ -1362,6 +1363,18 @@ def load_dataset( >>> ds = load_dataset('imagefolder', data_dir='/path/to/images', split='train') ``` """ + # Fallback: auto-detect save_to_disk-style folders + if ( + os.path.isdir(path) + and os.path.exists(os.path.join(path, "dataset_info.json")) # for Dataset + or os.path.exists(os.path.join(path, "dataset_dict.json")) # for DatasetDict + ): + logger.warning( + "Detected a directory saved via `save_to_disk()`. 
Redirecting to `load_from_disk('%s')` for compatibility.", + path, + ) + return load_from_disk(path) + if "trust_remote_code" in config_kwargs: if config_kwargs.pop("trust_remote_code"): logger.error(