From 26aa0ffd9a99b23da7da0d6134bcfd6fab4ef5f5 Mon Sep 17 00:00:00 2001 From: Ibinarriaga8 <202206789@alu.comillas.edu> Date: Mon, 14 Jul 2025 11:29:15 +0200 Subject: [PATCH 01/24] [Feature] Add support for loading datasets from local Minari cache --- test/test_libs.py | 54 ++++++++++++++++++++++++++++ torchrl/data/datasets/minari_data.py | 45 +++++++++++++++++++---- 2 files changed, 92 insertions(+), 7 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 1a92eb671c3..fb49f3778e3 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -32,6 +32,9 @@ import numpy as np import pytest import torch +from minari import DataCollector +import gymnasium as gym +import minari from packaging import version from tensordict import ( @@ -3452,7 +3455,58 @@ def fn(data): assert len(dataset) == 100 assert sample["data"].shape == torch.Size([32, 8]) assert sample["next", "data"].shape == torch.Size([32, 8]) + + + @pytest.mark.skipif(not _has_minari or not _has_gymnasium, reason="Minari or Gym not available") + def test_local_minari_dataset_loading(self): + if not _minari_init(): + pytest.skip("Failed to initialize Minari datasets") + + dataset_id = "cartpole/test-local-v1" + + # Create dataset using Gym + DataCollector + env = gym.make("CartPole-v1") + env = DataCollector(env, record_infos=True) + for _ in range(50): + env.reset(seed=123) + while True: + action = env.action_space.sample() + obs, rew, terminated, truncated, info = env.step(action) + if terminated or truncated: + break + + env.create_dataset( + dataset_id=dataset_id, + algorithm_name="RandomPolicy", + code_permalink="https://github.com/Farama-Foundation/Minari", + author="Farama", + author_email="contact@farama.org", + eval_env="CartPole-v1" + ) + + # Load from local cache + data = MinariExperienceReplay( + dataset_id=dataset_id, + split_trajs=False, + batch_size=32, + download=False, + sampler=SamplerWithoutReplacement(drop_last=True), + prefetch=2, + load_from_local_minari=True, + ) + + t0 = time.time() + for i, sample in enumerate(data): + t1 = time.time() + torchrl_logger.info(f"[Local Minari] Sampling time {1000 * (t1 - t0):4.4f} ms") + assert data.metadata["action_space"].is_in(sample["action"]), "Invalid action sample" + assert data.metadata["observation_space"].is_in(sample["observation"]), "Invalid observation sample" + t0 = time.time() + if i == 10: + break + + minari.delete_dataset(dataset_id="cartpole/test-local-v1") @pytest.mark.slow class TestRoboset: diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index 3d0d241bd99..166eaefced6 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -88,6 +88,14 @@ class MinariExperienceReplay(BaseDatasetExperienceReplay): it is assumed that any ``truncated`` or ``terminated`` signal is equivalent to the end of a trajectory. Defaults to ``False``. + load_from_local_minari (bool, optional): if ``True``, the dataset will be loaded directly + from the local Minari cache (typically located at ``~/.minari/datasets``), + bypassing any remote download. This is useful when working with custom + Minari datasets previously generated and stored locally, or when network + access should be avoided. If the dataset is not found in the expected + cache directory, a ``FileNotFoundError`` will be raised. + Defaults to ``False``. + Attributes: available_datasets: a list of accepted entries to be downloaded. 
@@ -167,6 +175,7 @@ def __init__( prefetch: int | None = None, transform: torchrl.envs.Transform | None = None, # noqa-F821 split_trajs: bool = False, + load_from_local_minari: bool = False, ): self.dataset_id = dataset_id if root is None: @@ -175,7 +184,9 @@ def __init__( self.root = root self.split_trajs = split_trajs self.download = download - if self.download == "force" or (self.download and not self._is_downloaded()): + self.load_from_local_minari = load_from_local_minari + + if self.download == "force" or (self.download and not self._is_downloaded()) or self.load_from_local_minari: if self.download == "force": try: if os.path.exists(self.data_path_root): @@ -240,13 +251,34 @@ def _download_and_preproc(self): with tempfile.TemporaryDirectory() as tmpdir: os.environ["MINARI_DATASETS_PATH"] = tmpdir - minari.download_dataset(dataset_id=self.dataset_id) - parent_dir = Path(tmpdir) / self.dataset_id / "data" - td_data = TensorDict() total_steps = 0 - torchrl_logger.info("first read through data to create data structure...") - h5_data = PersistentTensorDict.from_h5(parent_dir / "main_data.hdf5") + td_data = TensorDict() + + if self.load_from_local_minari: + # Load minari dataset from user's local Minari cache + + minari_cache_dir = os.path.expanduser("~/.minari/datasets") + os.environ["MINARI_DATASETS_PATH"] = minari_cache_dir + parent_dir = Path(minari_cache_dir) / self.dataset_id / "data" + h5_path = parent_dir / "main_data.hdf5" + + if not h5_path.exists(): + raise FileNotFoundError(f"{h5_path} does not exist in local Minari cache!") + + torchrl_logger.info( + f"loading dataset from local Minari cache at {h5_path}" + ) + h5_data = PersistentTensorDict.from_h5(h5_path) + + else: + minari.download_dataset(dataset_id=self.dataset_id) + + parent_dir = Path(tmpdir) / self.dataset_id / "data" + + torchrl_logger.info("first read through data to create data structure...") + h5_data = PersistentTensorDict.from_h5(parent_dir / "main_data.hdf5") + # populate the tensordict episode_dict = {} for i, (episode_key, episode) in enumerate(h5_data.items()): @@ -365,7 +397,6 @@ def _download_and_preproc(self): json.dump(self.metadata, metadata_file) self._load_and_proc_metadata() return td_data - def _make_split(self): from torchrl.collectors.utils import split_trajectories From a56c5086bfbf9f5e67d997713a86f692d03fa82c Mon Sep 17 00:00:00 2001 From: Ibinarriaga8 <202206789@alu.comillas.edu> Date: Mon, 14 Jul 2025 13:12:13 +0200 Subject: [PATCH 02/24] [Refactor] Fixed linting errors --- test/test_libs.py | 30 ++++++++++++++++++---------- torchrl/data/datasets/minari_data.py | 21 +++++++++++++------ 2 files changed, 34 insertions(+), 17 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index fb49f3778e3..1ba5581ab93 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -29,12 +29,12 @@ from sys import platform from unittest import mock +import minari + import numpy as np import pytest import torch from minari import DataCollector -import gymnasium as gym -import minari from packaging import version from tensordict import ( @@ -3455,9 +3455,10 @@ def fn(data): assert len(dataset) == 100 assert sample["data"].shape == torch.Size([32, 8]) assert sample["next", "data"].shape == torch.Size([32, 8]) - - - @pytest.mark.skipif(not _has_minari or not _has_gymnasium, reason="Minari or Gym not available") + + @pytest.mark.skipif( + not _has_minari or not _has_gymnasium, reason="Minari or Gym not available" + ) def test_local_minari_dataset_loading(self): if not _minari_init(): @@ -3466,7 +3467,7 @@ 
def test_local_minari_dataset_loading(self): dataset_id = "cartpole/test-local-v1" # Create dataset using Gym + DataCollector - env = gym.make("CartPole-v1") + env = gymnasium.make("CartPole-v1") env = DataCollector(env, record_infos=True) for _ in range(50): env.reset(seed=123) @@ -3482,7 +3483,7 @@ def test_local_minari_dataset_loading(self): code_permalink="https://github.com/Farama-Foundation/Minari", author="Farama", author_email="contact@farama.org", - eval_env="CartPole-v1" + eval_env="CartPole-v1", ) # Load from local cache @@ -3499,15 +3500,22 @@ def test_local_minari_dataset_loading(self): t0 = time.time() for i, sample in enumerate(data): t1 = time.time() - torchrl_logger.info(f"[Local Minari] Sampling time {1000 * (t1 - t0):4.4f} ms") - assert data.metadata["action_space"].is_in(sample["action"]), "Invalid action sample" - assert data.metadata["observation_space"].is_in(sample["observation"]), "Invalid observation sample" + torchrl_logger.info( + f"[Local Minari] Sampling time {1000 * (t1 - t0):4.4f} ms" + ) + assert data.metadata["action_space"].is_in( + sample["action"] + ), "Invalid action sample" + assert data.metadata["observation_space"].is_in( + sample["observation"] + ), "Invalid observation sample" t0 = time.time() if i == 10: break - + minari.delete_dataset(dataset_id="cartpole/test-local-v1") + @pytest.mark.slow class TestRoboset: def test_load(self): diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index 166eaefced6..d8c63873497 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -93,7 +93,7 @@ class MinariExperienceReplay(BaseDatasetExperienceReplay): bypassing any remote download. This is useful when working with custom Minari datasets previously generated and stored locally, or when network access should be avoided. If the dataset is not found in the expected - cache directory, a ``FileNotFoundError`` will be raised. + cache directory, a ``FileNotFoundError`` will be raised. Defaults to ``False``. @@ -186,7 +186,11 @@ def __init__( self.download = download self.load_from_local_minari = load_from_local_minari - if self.download == "force" or (self.download and not self._is_downloaded()) or self.load_from_local_minari: + if ( + self.download == "force" + or (self.download and not self._is_downloaded()) + or self.load_from_local_minari + ): if self.download == "force": try: if os.path.exists(self.data_path_root): @@ -264,8 +268,10 @@ def _download_and_preproc(self): h5_path = parent_dir / "main_data.hdf5" if not h5_path.exists(): - raise FileNotFoundError(f"{h5_path} does not exist in local Minari cache!") - + raise FileNotFoundError( + f"{h5_path} does not exist in local Minari cache!" + ) + torchrl_logger.info( f"loading dataset from local Minari cache at {h5_path}" ) @@ -273,10 +279,12 @@ def _download_and_preproc(self): else: minari.download_dataset(dataset_id=self.dataset_id) - + parent_dir = Path(tmpdir) / self.dataset_id / "data" - torchrl_logger.info("first read through data to create data structure...") + torchrl_logger.info( + "first read through data to create data structure..." 
+ ) h5_data = PersistentTensorDict.from_h5(parent_dir / "main_data.hdf5") # populate the tensordict @@ -397,6 +405,7 @@ def _download_and_preproc(self): json.dump(self.metadata, metadata_file) self._load_and_proc_metadata() return td_data + def _make_split(self): from torchrl.collectors.utils import split_trajectories From e90f4d0599c017f4ddf6d526020d2de4bb2e29f4 Mon Sep 17 00:00:00 2001 From: Ibinarriaga8 <202206789@alu.comillas.edu> Date: Mon, 14 Jul 2025 16:19:28 +0200 Subject: [PATCH 03/24] [Refactor] moved minari import to local scope --- test/test_libs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 1ba5581ab93..894ed79ff43 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -29,7 +29,6 @@ from sys import platform from unittest import mock -import minari import numpy as np import pytest @@ -3460,7 +3459,8 @@ def fn(data): not _has_minari or not _has_gymnasium, reason="Minari or Gym not available" ) def test_local_minari_dataset_loading(self): - + import minari + if not _minari_init(): pytest.skip("Failed to initialize Minari datasets") From cc43f9a5d4491ca5bcef287893cc64946035491e Mon Sep 17 00:00:00 2001 From: Ibinarriaga8 <202206789@alu.comillas.edu> Date: Tue, 15 Jul 2025 15:44:41 +0200 Subject: [PATCH 04/24] [Refactor] Refactor Minari testing to use custom datasets --- test/test_libs.py | 192 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 173 insertions(+), 19 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 894ed79ff43..532140466f9 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -29,11 +29,9 @@ from sys import platform from unittest import mock - import numpy as np import pytest import torch -from minari import DataCollector from packaging import version from tensordict import ( @@ -3343,6 +3341,39 @@ def test_d4rl_iteration(self, task, split_trajs): _MINARI_DATASETS = [] +MUJOCO_ENVIRONMENTS = [ + "Hopper-v5", + "Pusher-v5", + "Humanoid-v5", + "InvertedDoublePendulum-v5", + "HalfCheetah-v5", + "Swimmer-v5", + "Walker2d-v5", + "Ant-v5", + "Reacher-v5", +] + +D4RL_ENVIRONMENTS = [ + "AntMaze_UMaze-v5", + "AdroitHandPen-v1", + "AntMaze_Medium-v4", + "AntMaze_Large_Diverse_GR-v4", + "AntMaze_Large-v4", + "AntMaze_Medium_Diverse_GR-v4", + "PointMaze_OpenDense-v3", + "PointMaze_UMaze-v3", + "PointMaze_LargeDense-v3", + "PointMaze_Medium-v3", + "PointMaze_UMazeDense-v3", + "PointMaze_MediumDense-v3", + "PointMaze_Large-v3", + "PointMaze_Open-v3", + "FrankaKitchen-v1", + "AdroitHandDoor-v1", + "AdroitHandHammer-v1", + "AdroitHandRelocate-v1", +] + def _minari_init(): """Initialize Minari datasets list. Returns True if already initialized.""" @@ -3375,30 +3406,148 @@ def _minari_init(): return False -# Initialize with placeholder values for parametrization -# These will be replaced with actual dataset names when the first Minari test runs -_MINARI_DATASETS = [str(i) for i in range(20)] +def get_random_minigrid_datasets(): + """ + Fetch 5 random Minigrid datasets from the Minari server. 
+ """ + import minari + + all_minigrid = [ + dataset + for dataset in minari.list_remote_datasets( + latest_version=True, compatible_minari_version=True + ).keys() + if dataset.startswith("minigrid/") + ] + + if len(all_minigrid) < 5: + raise RuntimeError("Not enough minigrid datasets found on Minari server.") + indices = torch.randperm(len(all_minigrid))[:5] + return [all_minigrid[idx] for idx in indices] + + +def get_random_atari_envs(): + """ + Fetch 10 random Atari environments using ale_py and torch. + """ + import ale_py + import gymnasium as gym + + gym.register_envs(ale_py) + + env_specs = gym.envs.registry.values() + all_env_ids = [env_spec.id for env_spec in env_specs] + atari_env_ids = [env_id for env_id in all_env_ids if env_id.startswith("ALE")] + if len(atari_env_ids) < 10: + raise RuntimeError("Not enough Atari environments found.") + indices = torch.randperm(len(atari_env_ids))[:10] + return [atari_env_ids[idx] for idx in indices] + + +def custom_minari_init(custom_envs, num_episodes=5): + """ + Initialize custom Minari datasets for the given environments. + """ + import gymnasium + import gymnasium_robotics + from minari import DataCollector + + gymnasium.register_envs(gymnasium_robotics) + + custom_dataset_ids = [] + for env_id in custom_envs: + dataset_id = f"{env_id.lower()}/test-custom-local-v1" + env = gymnasium.make(env_id) + collector = DataCollector(env) + + for ep in range(num_episodes): + collector.reset(seed=123 + ep) + + while True: + action = collector.action_space.sample() + _, _, terminated, truncated, _ = collector.step(action) + if terminated or truncated: + break + + collector.create_dataset( + dataset_id=dataset_id, + algorithm_name="RandomPolicy", + code_permalink="https://github.com/Farama-Foundation/Minari", + author="Farama", + author_email="contact@farama.org", + eval_env=env_id, + ) + custom_dataset_ids.append(dataset_id) + + return custom_dataset_ids @pytest.mark.skipif(not _has_minari or not _has_gymnasium, reason="Minari not found") @pytest.mark.slow class TestMinari: @pytest.mark.parametrize("split", [False, True]) - @pytest.mark.parametrize("dataset_idx", range(20)) + @pytest.mark.parametrize( + "dataset_idx", + range( + len(MUJOCO_ENVIRONMENTS) + + len(D4RL_ENVIRONMENTS) + + len(get_random_minigrid_datasets()) + + len(get_random_atari_envs()) + ), + ) def test_load(self, dataset_idx, split): - # Initialize Minari datasets if not already done - if not _minari_init(): - pytest.skip("Failed to initialize Minari datasets") + """ + Test loading from custom datasets for Mujoco and D4RL, + Minari remote datasets for Minigrid, and random Atari environments. 
+ """ + import minari - # Get the actual dataset name from the initialized list - if dataset_idx >= len(_MINARI_DATASETS): - pytest.skip(f"Dataset index {dataset_idx} out of range") + custom_envs = MUJOCO_ENVIRONMENTS + D4RL_ENVIRONMENTS + num_custom = len(custom_envs) + minigrid_datasets = get_random_minigrid_datasets() + num_minigrid = len(minigrid_datasets) + atari_envs = get_random_atari_envs() + + if dataset_idx < num_custom: + # Custom dataset for Mujoco/D4RL + custom_dataset_ids = custom_minari_init( + [custom_envs[dataset_idx]], num_episodes=5 + ) + dataset_id = custom_dataset_ids[0] + data = MinariExperienceReplay( + dataset_id=dataset_id, + split_trajs=split, + batch_size=32, + load_from_local_minari=True, + ) + cleanup_needed = True + + elif dataset_idx < num_custom + num_minigrid: + # Minigrid datasets from Minari server + minigrid_idx = dataset_idx - num_custom + dataset_id = minigrid_datasets[minigrid_idx] + data = MinariExperienceReplay( + dataset_id=dataset_id, + batch_size=32, + split_trajs=split, + download="force", + ) + cleanup_needed = False + + else: + # Atari environment datasets + atari_idx = dataset_idx - num_custom - num_minigrid + env_id = atari_envs[atari_idx] + custom_dataset_ids = custom_minari_init([env_id], num_episodes=5) + dataset_id = custom_dataset_ids[0] + data = MinariExperienceReplay( + dataset_id=dataset_id, + split_trajs=split, + batch_size=32, + load_from_local_minari=True, + ) + cleanup_needed = True - selected_dataset = _MINARI_DATASETS[dataset_idx] - torchrl_logger.info(f"dataset {selected_dataset}") - data = MinariExperienceReplay( - selected_dataset, batch_size=32, split_trajs=split - ) t0 = time.time() for i, sample in enumerate(data): t1 = time.time() @@ -3409,6 +3558,10 @@ def test_load(self, dataset_idx, split): if i == 10: break + # Clean up custom datasets after running local dataset tests + if cleanup_needed: + minari.delete_dataset(dataset_id=dataset_id) + def test_minari_preproc(self, tmpdir): dataset = MinariExperienceReplay( "D4RL/pointmaze/large-v2", @@ -3459,8 +3612,9 @@ def fn(data): not _has_minari or not _has_gymnasium, reason="Minari or Gym not available" ) def test_local_minari_dataset_loading(self): - import minari - + import minari + from minari import DataCollector + if not _minari_init(): pytest.skip("Failed to initialize Minari datasets") From 23aececfc7f474514bc8288a511e56f8263feaaa Mon Sep 17 00:00:00 2001 From: Ibinarriaga8 <202206789@alu.comillas.edu> Date: Fri, 18 Jul 2025 15:47:14 +0200 Subject: [PATCH 05/24] [Fix] eliminate global imports for test_libs --- test/test_libs.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 532140466f9..d4ac6bf37e6 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3488,12 +3488,7 @@ class TestMinari: @pytest.mark.parametrize("split", [False, True]) @pytest.mark.parametrize( "dataset_idx", - range( - len(MUJOCO_ENVIRONMENTS) - + len(D4RL_ENVIRONMENTS) - + len(get_random_minigrid_datasets()) - + len(get_random_atari_envs()) - ), + range(100) ) def test_load(self, dataset_idx, split): """ @@ -3504,9 +3499,20 @@ def test_load(self, dataset_idx, split): custom_envs = MUJOCO_ENVIRONMENTS + D4RL_ENVIRONMENTS num_custom = len(custom_envs) - minigrid_datasets = get_random_minigrid_datasets() + try: + minigrid_datasets = get_random_minigrid_datasets() + except Exception: + minigrid_datasets = [] num_minigrid = len(minigrid_datasets) - atari_envs = get_random_atari_envs() + try: + atari_envs = 
get_random_atari_envs() + except Exception: + atari_envs = [] + num_atari = len(atari_envs) + total_datasets = num_custom + num_minigrid + num_atari + + if dataset_idx >= total_datasets: + pytest.skip("Index out of range for available datasets") if dataset_idx < num_custom: # Custom dataset for Mujoco/D4RL @@ -3669,7 +3675,6 @@ def test_local_minari_dataset_loading(self): minari.delete_dataset(dataset_id="cartpole/test-local-v1") - @pytest.mark.slow class TestRoboset: def test_load(self): From 9321adbae9df609388e85fd0e62ae5eeb0c12b0a Mon Sep 17 00:00:00 2001 From: Ibinarriaga8 <202206789@alu.comillas.edu> Date: Fri, 18 Jul 2025 16:00:27 +0200 Subject: [PATCH 06/24] [Refactor] Reduce dataset index range in Minari tests --- test/test_libs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_libs.py b/test/test_libs.py index d4ac6bf37e6..b3609bb7ccd 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3488,7 +3488,8 @@ class TestMinari: @pytest.mark.parametrize("split", [False, True]) @pytest.mark.parametrize( "dataset_idx", - range(100) + # Only use a static upper bound; do not call any function that imports minari globally. + range(50) ) def test_load(self, dataset_idx, split): """ From 0ab8f813d616a9843086894ec05b923e63ab5efe Mon Sep 17 00:00:00 2001 From: Ibinarriaga8 <202206789@alu.comillas.edu> Date: Mon, 21 Jul 2025 23:51:13 +0200 Subject: [PATCH 07/24] [Enhancement] Update minari environment.yml and refactored minari test --- .github/unittest/linux_libs/scripts_minari/environment.yml | 2 ++ pytest.ini | 1 + 2 files changed, 3 insertions(+) diff --git a/.github/unittest/linux_libs/scripts_minari/environment.yml b/.github/unittest/linux_libs/scripts_minari/environment.yml index ce6bdba3c24..98ddb6aec79 100644 --- a/.github/unittest/linux_libs/scripts_minari/environment.yml +++ b/.github/unittest/linux_libs/scripts_minari/environment.yml @@ -21,3 +21,5 @@ dependencies: - hydra-core - minari[gcs,hdf5,hf] - gymnasium<1.0.0 + - ale-py + - gymnasium-robotics diff --git a/pytest.ini b/pytest.ini index 39fe36617a1..07bde00eb7d 100644 --- a/pytest.ini +++ b/pytest.ini @@ -4,6 +4,7 @@ addopts = -ra # Make tracebacks shorter --tb=native + --runslow markers = unity_editor testpaths = From 01965f4a0c9ec34cdf468417d7d137fe3a68b56d Mon Sep 17 00:00:00 2001 From: Ibinarriaga8 <202206789@alu.comillas.edu> Date: Mon, 21 Jul 2025 23:52:52 +0200 Subject: [PATCH 08/24] [Refactor] Reduced datasets for minari tests --- pytest.ini | 1 - test/test_libs.py | 23 ++++++++++++++--------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/pytest.ini b/pytest.ini index 07bde00eb7d..39fe36617a1 100644 --- a/pytest.ini +++ b/pytest.ini @@ -4,7 +4,6 @@ addopts = -ra # Make tracebacks shorter --tb=native - --runslow markers = unity_editor testpaths = diff --git a/test/test_libs.py b/test/test_libs.py index b3609bb7ccd..3ab8a467454 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3420,15 +3420,14 @@ def get_random_minigrid_datasets(): if dataset.startswith("minigrid/") ] - if len(all_minigrid) < 5: - raise RuntimeError("Not enough minigrid datasets found on Minari server.") - indices = torch.randperm(len(all_minigrid))[:5] + # 3 random datasets + indices = torch.randperm(len(all_minigrid))[:3] return [all_minigrid[idx] for idx in indices] def get_random_atari_envs(): """ - Fetch 10 random Atari environments using ale_py and torch. + Fetch 3 random Atari environments using ale_py and torch. 
""" import ale_py import gymnasium as gym @@ -3438,9 +3437,9 @@ def get_random_atari_envs(): env_specs = gym.envs.registry.values() all_env_ids = [env_spec.id for env_spec in env_specs] atari_env_ids = [env_id for env_id in all_env_ids if env_id.startswith("ALE")] - if len(atari_env_ids) < 10: + if len(atari_env_ids) < 3: raise RuntimeError("Not enough Atari environments found.") - indices = torch.randperm(len(atari_env_ids))[:10] + indices = torch.randperm(len(atari_env_ids))[:3] return [atari_env_ids[idx] for idx in indices] @@ -3489,7 +3488,7 @@ class TestMinari: @pytest.mark.parametrize( "dataset_idx", # Only use a static upper bound; do not call any function that imports minari globally. - range(50) + range(7) ) def test_load(self, dataset_idx, split): """ @@ -3498,8 +3497,14 @@ def test_load(self, dataset_idx, split): """ import minari + num_custom_to_select = 4 custom_envs = MUJOCO_ENVIRONMENTS + D4RL_ENVIRONMENTS - num_custom = len(custom_envs) + + # Randomly select a subset of custom environments + indices = torch.randperm(len(custom_envs))[:num_custom_to_select] + custom_envs_subset = [custom_envs[i] for i in indices] + + num_custom = len(custom_envs_subset) try: minigrid_datasets = get_random_minigrid_datasets() except Exception: @@ -3518,7 +3523,7 @@ def test_load(self, dataset_idx, split): if dataset_idx < num_custom: # Custom dataset for Mujoco/D4RL custom_dataset_ids = custom_minari_init( - [custom_envs[dataset_idx]], num_episodes=5 + [custom_envs_subset[dataset_idx]], num_episodes=5 ) dataset_id = custom_dataset_ids[0] data = MinariExperienceReplay( From dcd06b2ffeb94a0eed5dc81fc12801eb264bf29e Mon Sep 17 00:00:00 2001 From: Ibinarriaga8 <202206789@alu.comillas.edu> Date: Tue, 22 Jul 2025 08:48:59 +0200 Subject: [PATCH 09/24] [Enhancement] Update environment.yml minari dependencies --- .github/unittest/linux_libs/scripts_minari/environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/unittest/linux_libs/scripts_minari/environment.yml b/.github/unittest/linux_libs/scripts_minari/environment.yml index 98ddb6aec79..e45a1b2ab8e 100644 --- a/.github/unittest/linux_libs/scripts_minari/environment.yml +++ b/.github/unittest/linux_libs/scripts_minari/environment.yml @@ -23,3 +23,4 @@ dependencies: - gymnasium<1.0.0 - ale-py - gymnasium-robotics + - minari[create] From 88b3dbe0de93bb0d57778eb9105f05fbf8a1dcc9 Mon Sep 17 00:00:00 2001 From: Ibinarriaga8 <202206789@alu.comillas.edu> Date: Tue, 22 Jul 2025 09:03:26 +0200 Subject: [PATCH 10/24] [Enhancement] Added minari environment.yml dependencies --- .github/unittest/linux_libs/scripts_minari/environment.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/unittest/linux_libs/scripts_minari/environment.yml b/.github/unittest/linux_libs/scripts_minari/environment.yml index e45a1b2ab8e..ef5dc520245 100644 --- a/.github/unittest/linux_libs/scripts_minari/environment.yml +++ b/.github/unittest/linux_libs/scripts_minari/environment.yml @@ -24,3 +24,6 @@ dependencies: - ale-py - gymnasium-robotics - minari[create] + - jax + - mujoco + - mujoco-py \ No newline at end of file From 598300841a770d309d918d03e03152fae40ce29c Mon Sep 17 00:00:00 2001 From: Ibinarriaga8 <202206789@alu.comillas.edu> Date: Tue, 22 Jul 2025 10:21:05 +0200 Subject: [PATCH 11/24] [Enhancement] Update dataset fetching to retrieve 2 random datasets and added requirements --- .../linux_libs/scripts_minari/environment.yml | 3 ++- test/test_libs.py | 14 ++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git 
a/.github/unittest/linux_libs/scripts_minari/environment.yml b/.github/unittest/linux_libs/scripts_minari/environment.yml index ef5dc520245..6d03c4e0275 100644 --- a/.github/unittest/linux_libs/scripts_minari/environment.yml +++ b/.github/unittest/linux_libs/scripts_minari/environment.yml @@ -26,4 +26,5 @@ dependencies: - minari[create] - jax - mujoco - - mujoco-py \ No newline at end of file + - mujoco-py + - minigrid \ No newline at end of file diff --git a/test/test_libs.py b/test/test_libs.py index 3ab8a467454..971f95a3d3d 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3408,7 +3408,7 @@ def _minari_init(): def get_random_minigrid_datasets(): """ - Fetch 5 random Minigrid datasets from the Minari server. + Fetch 2 random Minigrid datasets from the Minari server. """ import minari @@ -3421,13 +3421,13 @@ def get_random_minigrid_datasets(): ] # 3 random datasets - indices = torch.randperm(len(all_minigrid))[:3] + indices = torch.randperm(len(all_minigrid))[:2] return [all_minigrid[idx] for idx in indices] def get_random_atari_envs(): """ - Fetch 3 random Atari environments using ale_py and torch. + Fetch 2 random Atari environments using ale_py and torch. """ import ale_py import gymnasium as gym @@ -3437,9 +3437,9 @@ def get_random_atari_envs(): env_specs = gym.envs.registry.values() all_env_ids = [env_spec.id for env_spec in env_specs] atari_env_ids = [env_id for env_id in all_env_ids if env_id.startswith("ALE")] - if len(atari_env_ids) < 3: + if len(atari_env_ids) < 2: raise RuntimeError("Not enough Atari environments found.") - indices = torch.randperm(len(atari_env_ids))[:3] + indices = torch.randperm(len(atari_env_ids))[:2] return [atari_env_ids[idx] for idx in indices] @@ -3488,7 +3488,7 @@ class TestMinari: @pytest.mark.parametrize( "dataset_idx", # Only use a static upper bound; do not call any function that imports minari globally. 
-        range(7)
+        range(8)
     )
     def test_load(self, dataset_idx, split):
         """
@@ -3510,10 +3510,12 @@ def test_load(self, dataset_idx, split):
         except Exception:
             minigrid_datasets = []
         num_minigrid = len(minigrid_datasets)
+
         try:
             atari_envs = get_random_atari_envs()
         except Exception:
             atari_envs = []
+
         num_atari = len(atari_envs)
         total_datasets = num_custom + num_minigrid + num_atari

From c532070bf5c4329e61423056d61a7bb48229d42c Mon Sep 17 00:00:00 2001
From: Ibinarriaga8 <202206789@alu.comillas.edu>
Date: Wed, 23 Jul 2025 09:46:17 +0200
Subject: [PATCH 12/24] [Fix] Resolved mujoco-py version

---
 .github/unittest/linux_libs/scripts_minari/environment.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/unittest/linux_libs/scripts_minari/environment.yml b/.github/unittest/linux_libs/scripts_minari/environment.yml
index 6d03c4e0275..5b9e34437ee 100644
--- a/.github/unittest/linux_libs/scripts_minari/environment.yml
+++ b/.github/unittest/linux_libs/scripts_minari/environment.yml
@@ -26,5 +26,5 @@ dependencies:
   - minari[create]
   - jax
   - mujoco
-  - mujoco-py
+  - mujoco-py<2.2,>=2.1
   - minigrid
\ No newline at end of file
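Taken together, the preceding patches enable the following usage pattern. A minimal
sketch, assuming a dataset has already been written to the local Minari cache under
the id the test suite uses (any locally cached dataset id works the same way):

    from torchrl.data.datasets.minari_data import MinariExperienceReplay
    from torchrl.data.replay_buffers import SamplerWithoutReplacement

    # Read straight from ~/.minari/datasets; no remote download is attempted,
    # and a FileNotFoundError is raised if the id is absent from the cache.
    data = MinariExperienceReplay(
        dataset_id="cartpole/test-local-v1",
        batch_size=32,
        download=False,
        load_from_local_minari=True,
        sampler=SamplerWithoutReplacement(drop_last=True),
    )
    sample = data.sample()  # a batch of 32 transitions as a TensorDict
    assert data.metadata["action_space"].is_in(sample["action"])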
""" import ale_py import gymnasium as gym @@ -3438,9 +3437,9 @@ def get_random_atari_envs(): env_specs = gym.envs.registry.values() all_env_ids = [env_spec.id for env_spec in env_specs] atari_env_ids = [env_id for env_id in all_env_ids if env_id.startswith("ALE")] - if len(atari_env_ids) < 10: + if len(atari_env_ids) < 3: raise RuntimeError("Not enough Atari environments found.") - indices = torch.randperm(len(atari_env_ids))[:10] + indices = torch.randperm(len(atari_env_ids))[:3] return [atari_env_ids[idx] for idx in indices] @@ -3489,7 +3488,7 @@ class TestMinari: @pytest.mark.parametrize( "dataset_idx", # Only use a static upper bound; do not call any function that imports minari globally. - range(50) + range(7) ) def test_load(self, dataset_idx, split): """ @@ -3498,8 +3497,14 @@ def test_load(self, dataset_idx, split): """ import minari + num_custom_to_select = 4 custom_envs = MUJOCO_ENVIRONMENTS + D4RL_ENVIRONMENTS - num_custom = len(custom_envs) + + # Randomly select a subset of custom environments + indices = torch.randperm(len(custom_envs))[:num_custom_to_select] + custom_envs_subset = [custom_envs[i] for i in indices] + + num_custom = len(custom_envs_subset) try: minigrid_datasets = get_random_minigrid_datasets() except Exception: @@ -3518,7 +3523,7 @@ def test_load(self, dataset_idx, split): if dataset_idx < num_custom: # Custom dataset for Mujoco/D4RL custom_dataset_ids = custom_minari_init( - [custom_envs[dataset_idx]], num_episodes=5 + [custom_envs_subset[dataset_idx]], num_episodes=5 ) dataset_id = custom_dataset_ids[0] data = MinariExperienceReplay( From d81b32f39e5530633880abc0987facd7c0d25c33 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 23 Jul 2025 18:58:09 -0700 Subject: [PATCH 21/24] conda -> uv --- .../linux_libs/scripts_minari/environment.yml | 25 ---------- .../linux_libs/scripts_minari/install.sh | 16 +++---- .../linux_libs/scripts_minari/post_process.sh | 3 +- .../scripts_minari/requirements.txt | 20 ++++++++ .../linux_libs/scripts_minari/run_all.sh | 16 +++++++ .../linux_libs/scripts_minari/run_test.sh | 5 +- .../linux_libs/scripts_minari/setup_env.sh | 47 ++++++++++--------- .github/workflows/test-linux-libs.yml | 5 +- packaging/build_wheels.sh | 2 +- test/test_libs.py | 13 ++--- 10 files changed, 79 insertions(+), 73 deletions(-) delete mode 100644 .github/unittest/linux_libs/scripts_minari/environment.yml create mode 100644 .github/unittest/linux_libs/scripts_minari/requirements.txt create mode 100755 .github/unittest/linux_libs/scripts_minari/run_all.sh diff --git a/.github/unittest/linux_libs/scripts_minari/environment.yml b/.github/unittest/linux_libs/scripts_minari/environment.yml deleted file mode 100644 index 98ddb6aec79..00000000000 --- a/.github/unittest/linux_libs/scripts_minari/environment.yml +++ /dev/null @@ -1,25 +0,0 @@ -channels: - - pytorch - - defaults -dependencies: - - pip - - pip: - - hypothesis - - future - - cloudpickle - - pytest - - pytest-cov - - pytest-mock - - pytest-instafail - - pytest-rerunfailures - - pytest-error-for-skips - - pytest-asyncio - - expecttest - - pybind11[global] - - pyyaml - - scipy - - hydra-core - - minari[gcs,hdf5,hf] - - gymnasium<1.0.0 - - ale-py - - gymnasium-robotics diff --git a/.github/unittest/linux_libs/scripts_minari/install.sh b/.github/unittest/linux_libs/scripts_minari/install.sh index 0094f045932..0fc3e9076db 100755 --- a/.github/unittest/linux_libs/scripts_minari/install.sh +++ b/.github/unittest/linux_libs/scripts_minari/install.sh @@ -7,8 +7,7 @@ unset PYTORCH_VERSION set -e -eval 
"$(./conda/bin/conda shell.bash hook)" -conda activate ./env +# Note: This script is sourced by run_all.sh, so the environment is already active if [ "${CU_VERSION:-}" == cpu ] ; then version="cpu" @@ -22,22 +21,21 @@ else version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" fi - # submodules git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with cu128" if [[ "$TORCH_VERSION" == "nightly" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U + uv pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U else - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 -U + uv pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 -U fi elif [[ "$TORCH_VERSION" == "stable" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install torch --index-url https://download.pytorch.org/whl/cpu + uv pip install torch --index-url https://download.pytorch.org/whl/cpu else - pip3 install torch --index-url https://download.pytorch.org/whl/cu128 + uv pip install torch --index-url https://download.pytorch.org/whl/cu128 fi else printf "Failed to install pytorch" @@ -46,9 +44,9 @@ fi # install tensordict if [[ "$RELEASE" == 0 ]]; then - pip3 install git+https://github.com/pytorch/tensordict.git + uv pip install git+https://github.com/pytorch/tensordict.git else - pip3 install tensordict + uv pip install tensordict fi # smoke test diff --git a/.github/unittest/linux_libs/scripts_minari/post_process.sh b/.github/unittest/linux_libs/scripts_minari/post_process.sh index e97bf2a7b1b..a9a59242630 100755 --- a/.github/unittest/linux_libs/scripts_minari/post_process.sh +++ b/.github/unittest/linux_libs/scripts_minari/post_process.sh @@ -2,5 +2,4 @@ set -e -eval "$(./conda/bin/conda shell.bash hook)" -conda activate ./env +# Note: This script is sourced by run_all.sh, so the environment is already active diff --git a/.github/unittest/linux_libs/scripts_minari/requirements.txt b/.github/unittest/linux_libs/scripts_minari/requirements.txt new file mode 100644 index 00000000000..8a54f765e12 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_minari/requirements.txt @@ -0,0 +1,20 @@ +hypothesis +future +cloudpickle +pytest +pytest-cov +pytest-mock +pytest-instafail +pytest-rerunfailures +pytest-error-for-skips +pytest-asyncio +expecttest +pybind11[global] +pyyaml +scipy +hydra-core +minari[gcs,hdf5,hf,create] +gymnasium<1.0.0 +ale-py +gymnasium-robotics +mujoco \ No newline at end of file diff --git a/.github/unittest/linux_libs/scripts_minari/run_all.sh b/.github/unittest/linux_libs/scripts_minari/run_all.sh new file mode 100755 index 00000000000..2fb2505f7f8 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_minari/run_all.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash + +set -euxo pipefail + +# Run all minari test scripts in sequence, sourcing each one to maintain environment state +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + +echo "Running minari tests with uv-based setup..." + +# Source each script in sequence to maintain environment state +source "${this_dir}/setup_env.sh" +source "${this_dir}/install.sh" +source "${this_dir}/run_test.sh" +source "${this_dir}/post_process.sh" + +echo "Minari tests completed successfully!" 
\ No newline at end of file diff --git a/.github/unittest/linux_libs/scripts_minari/run_test.sh b/.github/unittest/linux_libs/scripts_minari/run_test.sh index 6282e7d57eb..e75868fb41e 100755 --- a/.github/unittest/linux_libs/scripts_minari/run_test.sh +++ b/.github/unittest/linux_libs/scripts_minari/run_test.sh @@ -2,8 +2,7 @@ set -e -eval "$(./conda/bin/conda shell.bash hook)" -conda activate ./env +# Note: This script is sourced by run_all.sh, so the environment is already active apt-get update && apt-get remove swig -y && apt-get install -y git gcc patchelf libosmesa6-dev libgl1-mesa-glx libglfw3 swig3.0 cmake ln -s /usr/bin/swig3.0 /usr/bin/swig @@ -18,8 +17,6 @@ root_dir="$(git rev-parse --show-toplevel)" env_dir="${root_dir}/env" lib_dir="${env_dir}/lib" -conda deactivate && conda activate ./env - # this workflow only tests the libs python -c "import minari" diff --git a/.github/unittest/linux_libs/scripts_minari/setup_env.sh b/.github/unittest/linux_libs/scripts_minari/setup_env.sh index a7afd646646..a7133a9fc47 100755 --- a/.github/unittest/linux_libs/scripts_minari/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_minari/setup_env.sh @@ -33,37 +33,40 @@ apt-get upgrade -y libstdc++6 this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" root_dir="$(git rev-parse --show-toplevel)" -conda_dir="${root_dir}/conda" env_dir="${root_dir}/env" cd "${root_dir}" -case "$(uname -s)" in - Darwin*) os=MacOSX;; - *) os=Linux -esac - -# 1. Install conda at ./conda -if [ ! -d "${conda_dir}" ]; then - printf "* Installing conda\n" - wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" - bash ./miniconda.sh -b -f -p "${conda_dir}" +# Install uv if not already installed +if ! command -v uv &> /dev/null; then + printf "* Installing uv\n" + # Try different Python commands + if command -v python3 &> /dev/null; then + python3 -m pip install uv + elif command -v python &> /dev/null; then + python -m pip install uv + else + # Fallback to curl installation + curl -LsSf https://astral.sh/uv/install.sh | sh + export PATH="$HOME/.local/bin:$PATH" + fi fi -eval "$(${conda_dir}/bin/conda shell.bash hook)" -# 2. Create test environment at ./env +# Create virtual environment using uv printf "python: ${PYTHON_VERSION}\n" if [ ! -d "${env_dir}" ]; then - printf "* Creating a test environment\n" - conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" + printf "* Creating a test environment with uv\n" + uv venv "${env_dir}" --python "${PYTHON_VERSION}" fi -conda activate "${env_dir}" -# 3. 
Install Conda dependencies -printf "* Installing dependencies (except PyTorch)\n" -echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" -cat "${this_dir}/environment.yml" +# Activate the virtual environment +source "${env_dir}/bin/activate" -pip3 install pip --upgrade +# Upgrade pip +uv pip install --upgrade pip -conda env update --file "${this_dir}/environment.yml" --prune +# Install dependencies from requirements.txt (we'll create this) +printf "* Installing dependencies (except PyTorch)\n" +if [ -f "${this_dir}/requirements.txt" ]; then + uv pip install -r "${this_dir}/requirements.txt" +fi diff --git a/.github/workflows/test-linux-libs.yml b/.github/workflows/test-linux-libs.yml index 4118c3e0d29..9026ba84d34 100644 --- a/.github/workflows/test-linux-libs.yml +++ b/.github/workflows/test-linux-libs.yml @@ -466,10 +466,7 @@ jobs: export BATCHED_PIPE_TIMEOUT=60 export TD_GET_DEFAULTS_TO_NONE=1 - bash .github/unittest/linux_libs/scripts_minari/setup_env.sh - bash .github/unittest/linux_libs/scripts_minari/install.sh - bash .github/unittest/linux_libs/scripts_minari/run_test.sh - bash .github/unittest/linux_libs/scripts_minari/post_process.sh + bash .github/unittest/linux_libs/scripts_minari/run_all.sh unittests-openx: strategy: diff --git a/packaging/build_wheels.sh b/packaging/build_wheels.sh index fd3c228fd5b..0f4ea255452 100755 --- a/packaging/build_wheels.sh +++ b/packaging/build_wheels.sh @@ -7,7 +7,7 @@ script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" export BUILD_TYPE=wheel setup_env setup_wheel_python -pip_install numpy pyyaml future ninja +pip_install numpy pyyaml future ninja "pybind11[global]" pip_install --upgrade setuptools setup_pip_pytorch_version python setup.py clean diff --git a/test/test_libs.py b/test/test_libs.py index 3ab8a467454..a0f626178f1 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3343,13 +3343,13 @@ def test_d4rl_iteration(self, task, split_trajs): MUJOCO_ENVIRONMENTS = [ "Hopper-v5", - "Pusher-v5", + "Pusher-v4", "Humanoid-v5", "InvertedDoublePendulum-v5", "HalfCheetah-v5", "Swimmer-v5", "Walker2d-v5", - "Ant-v5", + "ALE/Ant-v5", "Reacher-v5", ] @@ -3421,7 +3421,7 @@ def get_random_minigrid_datasets(): ] # 3 random datasets - indices = torch.randperm(len(all_minigrid))[:3] + indices = torch.randperm(len(all_minigrid))[:3] return [all_minigrid[idx] for idx in indices] @@ -3488,7 +3488,7 @@ class TestMinari: @pytest.mark.parametrize( "dataset_idx", # Only use a static upper bound; do not call any function that imports minari globally. 
- range(7) + range(4), ) def test_load(self, dataset_idx, split): """ @@ -3499,11 +3499,11 @@ def test_load(self, dataset_idx, split): num_custom_to_select = 4 custom_envs = MUJOCO_ENVIRONMENTS + D4RL_ENVIRONMENTS - + # Randomly select a subset of custom environments indices = torch.randperm(len(custom_envs))[:num_custom_to_select] custom_envs_subset = [custom_envs[i] for i in indices] - + num_custom = len(custom_envs_subset) try: minigrid_datasets = get_random_minigrid_datasets() @@ -3681,6 +3681,7 @@ def test_local_minari_dataset_loading(self): minari.delete_dataset(dataset_id="cartpole/test-local-v1") + @pytest.mark.slow class TestRoboset: def test_load(self): From 2ddce39449247a8d0f9a8946bae6f1e993c9fe50 Mon Sep 17 00:00:00 2001 From: Ibinarriaga8 <202206789@alu.comillas.edu> Date: Thu, 24 Jul 2025 14:19:24 +0200 Subject: [PATCH 22/24] Updated gym version --- .github/unittest/linux_libs/scripts_minari/environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/unittest/linux_libs/scripts_minari/environment.yml b/.github/unittest/linux_libs/scripts_minari/environment.yml index 5b9e34437ee..39bf8246b21 100644 --- a/.github/unittest/linux_libs/scripts_minari/environment.yml +++ b/.github/unittest/linux_libs/scripts_minari/environment.yml @@ -20,7 +20,7 @@ dependencies: - scipy - hydra-core - minari[gcs,hdf5,hf] - - gymnasium<1.0.0 + - gymnasium>=1.2.0 - ale-py - gymnasium-robotics - minari[create] From 51c01674bdf9e18d5a8407397a91c1c3e44583bb Mon Sep 17 00:00:00 2001 From: Ibinarriaga8 <202206789@alu.comillas.edu> Date: Thu, 24 Jul 2025 14:35:16 +0200 Subject: [PATCH 23/24] Updated requirements --- .github/unittest/linux_libs/scripts_minari/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/unittest/linux_libs/scripts_minari/requirements.txt b/.github/unittest/linux_libs/scripts_minari/requirements.txt index 8a54f765e12..a6e1f45bfec 100644 --- a/.github/unittest/linux_libs/scripts_minari/requirements.txt +++ b/.github/unittest/linux_libs/scripts_minari/requirements.txt @@ -14,7 +14,7 @@ pyyaml scipy hydra-core minari[gcs,hdf5,hf,create] -gymnasium<1.0.0 +gymnasium>=1.2.0 ale-py gymnasium-robotics mujoco \ No newline at end of file From be40b1c08c56c415996c63811b8d2e859a4c1c4b Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 24 Jul 2025 15:45:57 +0100 Subject: [PATCH 24/24] Fixes --- .../linux_libs/scripts_minari/environment.yml | 2 +- .../scripts_minari/requirements.txt | 2 +- .../linux_libs/scripts_minari/run_all.sh | 2 +- test/_utils_internal.py | 8 +- test/test_libs.py | 126 ++++---- torchrl/data/datasets/minari_data.py | 300 ++++++++++-------- version.txt | 2 +- 7 files changed, 241 insertions(+), 201 deletions(-) diff --git a/.github/unittest/linux_libs/scripts_minari/environment.yml b/.github/unittest/linux_libs/scripts_minari/environment.yml index 39bf8246b21..e1362721b12 100644 --- a/.github/unittest/linux_libs/scripts_minari/environment.yml +++ b/.github/unittest/linux_libs/scripts_minari/environment.yml @@ -27,4 +27,4 @@ dependencies: - jax - mujoco - mujoco-py<2.2,>=2.1 - - minigrid \ No newline at end of file + - minigrid diff --git a/.github/unittest/linux_libs/scripts_minari/requirements.txt b/.github/unittest/linux_libs/scripts_minari/requirements.txt index a6e1f45bfec..ae21314f263 100644 --- a/.github/unittest/linux_libs/scripts_minari/requirements.txt +++ b/.github/unittest/linux_libs/scripts_minari/requirements.txt @@ -17,4 +17,4 @@ minari[gcs,hdf5,hf,create] gymnasium>=1.2.0 ale-py gymnasium-robotics 
-mujoco \ No newline at end of file +mujoco diff --git a/.github/unittest/linux_libs/scripts_minari/run_all.sh b/.github/unittest/linux_libs/scripts_minari/run_all.sh index 2fb2505f7f8..f0921978df8 100755 --- a/.github/unittest/linux_libs/scripts_minari/run_all.sh +++ b/.github/unittest/linux_libs/scripts_minari/run_all.sh @@ -13,4 +13,4 @@ source "${this_dir}/install.sh" source "${this_dir}/run_test.sh" source "${this_dir}/post_process.sh" -echo "Minari tests completed successfully!" \ No newline at end of file +echo "Minari tests completed successfully!" diff --git a/test/_utils_internal.py b/test/_utils_internal.py index d0bc4242040..2f8b67d91a8 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -13,6 +13,7 @@ import unittest import warnings from functools import wraps +from typing import Callable import pytest import torch @@ -214,7 +215,12 @@ def generate_seeds(seed, repeat): # Decorator to retry upon certain Exceptions. -def retry(ExceptionToCheck, tries=3, delay=3, skip_after_retries=False): +def retry( + ExceptionToCheck: type[Exception], + tries: int = 3, + delay: int = 3, + skip_after_retries: bool = False, +) -> Callable[[Callable], Callable]: def deco_retry(f): @wraps(f) def f_retry(*args, **kwargs): diff --git a/test/test_libs.py b/test/test_libs.py index 1cf94e95e3a..054d6ca0240 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3408,16 +3408,16 @@ def test_d4rl_iteration(self, task, split_trajs): ] -def _minari_init(): +def _minari_init() -> tuple[bool, Exception | None]: """Initialize Minari datasets list. Returns True if already initialized.""" global _MINARI_DATASETS if _MINARI_DATASETS and not all( isinstance(x, str) and x.isdigit() for x in _MINARI_DATASETS ): - return True # Already initialized with real dataset names + return True, None # Already initialized with real dataset names if not _has_minari or not _has_gymnasium: - return False + return False, ImportError("Minari or Gymnasium not found") try: import minari @@ -3434,9 +3434,9 @@ def _minari_init(): assert len(keys) > 5, keys _MINARI_DATASETS[:] = keys # Replace the placeholder values - return True - except Exception: - return False + return True, None + except Exception as err: + return False, err def get_random_minigrid_datasets(): @@ -3607,6 +3607,7 @@ def test_load(self, dataset_idx, split): if cleanup_needed: minari.delete_dataset(dataset_id=dataset_id) + @retry(Exception, tries=3, delay=1) def test_minari_preproc(self, tmpdir): dataset = MinariExperienceReplay( "D4RL/pointmaze/large-v2", @@ -3656,63 +3657,70 @@ def fn(data): @pytest.mark.skipif( not _has_minari or not _has_gymnasium, reason="Minari or Gym not available" ) - def test_local_minari_dataset_loading(self): - import minari - from minari import DataCollector - - if not _minari_init(): - pytest.skip("Failed to initialize Minari datasets") - - dataset_id = "cartpole/test-local-v1" - - # Create dataset using Gym + DataCollector - env = gymnasium.make("CartPole-v1") - env = DataCollector(env, record_infos=True) - for _ in range(50): - env.reset(seed=123) - while True: - action = env.action_space.sample() - obs, rew, terminated, truncated, info = env.step(action) - if terminated or truncated: - break - - env.create_dataset( - dataset_id=dataset_id, - algorithm_name="RandomPolicy", - code_permalink="https://github.com/Farama-Foundation/Minari", - author="Farama", - author_email="contact@farama.org", - eval_env="CartPole-v1", - ) - - # Load from local cache - data = MinariExperienceReplay( - dataset_id=dataset_id, - 
split_trajs=False, - batch_size=32, - download=False, - sampler=SamplerWithoutReplacement(drop_last=True), - prefetch=2, - load_from_local_minari=True, - ) + def test_local_minari_dataset_loading(self, tmpdir): + MINARI_DATASETS_PATH = os.environ.get("MINARI_DATASETS_PATH") + os.environ["MINARI_DATASETS_PATH"] = str(tmpdir) + try: + import minari + from minari import DataCollector + + success, err = _minari_init() + if not success: + pytest.skip(f"Failed to initialize Minari datasets: {err}") + + dataset_id = "cartpole/test-local-v1" + + # Create dataset using Gym + DataCollector + env = gymnasium.make("CartPole-v1") + env = DataCollector(env, record_infos=True) + for _ in range(50): + env.reset(seed=123) + while True: + action = env.action_space.sample() + obs, rew, terminated, truncated, info = env.step(action) + if terminated or truncated: + break + + env.create_dataset( + dataset_id=dataset_id, + algorithm_name="RandomPolicy", + code_permalink="https://github.com/Farama-Foundation/Minari", + author="Farama", + author_email="contact@farama.org", + eval_env="CartPole-v1", + ) - t0 = time.time() - for i, sample in enumerate(data): - t1 = time.time() - torchrl_logger.info( - f"[Local Minari] Sampling time {1000 * (t1 - t0):4.4f} ms" + # Load from local cache + data = MinariExperienceReplay( + dataset_id=dataset_id, + split_trajs=False, + batch_size=32, + download=False, + sampler=SamplerWithoutReplacement(drop_last=True), + prefetch=2, + load_from_local_minari=True, ) - assert data.metadata["action_space"].is_in( - sample["action"] - ), "Invalid action sample" - assert data.metadata["observation_space"].is_in( - sample["observation"] - ), "Invalid observation sample" + t0 = time.time() - if i == 10: - break + for i, sample in enumerate(data): + t1 = time.time() + torchrl_logger.info( + f"[Local Minari] Sampling time {1000 * (t1 - t0):4.4f} ms" + ) + assert data.metadata["action_space"].is_in( + sample["action"] + ), "Invalid action sample" + assert data.metadata["observation_space"].is_in( + sample["observation"] + ), "Invalid observation sample" + t0 = time.time() + if i == 10: + break - minari.delete_dataset(dataset_id="cartpole/test-local-v1") + minari.delete_dataset(dataset_id="cartpole/test-local-v1") + finally: + if MINARI_DATASETS_PATH: + os.environ["MINARI_DATASETS_PATH"] = MINARI_DATASETS_PATH @pytest.mark.slow diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index d8c63873497..3d0e660d392 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -253,158 +253,184 @@ def _download_and_preproc(self): if _has_tqdm: from tqdm import tqdm - with tempfile.TemporaryDirectory() as tmpdir: - os.environ["MINARI_DATASETS_PATH"] = tmpdir - - total_steps = 0 - td_data = TensorDict() + prev_minari_datasets_path_save = prev_minari_datasets_path = os.environ.get( + "MINARI_DATASETS_PATH" + ) + try: + if prev_minari_datasets_path is None: + prev_minari_datasets_path = os.path.expanduser("~/.minari/datasets") + with tempfile.TemporaryDirectory() as tmpdir: - if self.load_from_local_minari: - # Load minari dataset from user's local Minari cache + total_steps = 0 + td_data = TensorDict() - minari_cache_dir = os.path.expanduser("~/.minari/datasets") - os.environ["MINARI_DATASETS_PATH"] = minari_cache_dir - parent_dir = Path(minari_cache_dir) / self.dataset_id / "data" - h5_path = parent_dir / "main_data.hdf5" + if self.load_from_local_minari: + # Load minari dataset from user's local Minari cache - if not h5_path.exists(): 
- raise FileNotFoundError( - f"{h5_path} does not exist in local Minari cache!" + parent_dir = ( + Path(prev_minari_datasets_path) / self.dataset_id / "data" ) + h5_path = parent_dir / "main_data.hdf5" - torchrl_logger.info( - f"loading dataset from local Minari cache at {h5_path}" - ) - h5_data = PersistentTensorDict.from_h5(h5_path) + if not h5_path.exists(): + raise FileNotFoundError( + f"{h5_path} does not exist in local Minari cache!" + ) + + torchrl_logger.info( + f"loading dataset from local Minari cache at {h5_path}" + ) + h5_data = PersistentTensorDict.from_h5(h5_path) - else: - minari.download_dataset(dataset_id=self.dataset_id) + else: + # temporarily change the minari cache path + prev_minari_datasets_path_save2 = os.environ.get( + "MINARI_DATASETS_PATH" + ) + os.environ["MINARI_DATASETS_PATH"] = tmpdir + try: + minari.download_dataset(dataset_id=self.dataset_id) + finally: + if prev_minari_datasets_path_save2 is not None: + os.environ[ + "MINARI_DATASETS_PATH" + ] = prev_minari_datasets_path_save2 + + parent_dir = Path(tmpdir) / self.dataset_id / "data" + + torchrl_logger.info( + "first read through data to create data structure..." + ) + h5_data = PersistentTensorDict.from_h5( + parent_dir / "main_data.hdf5" + ) - parent_dir = Path(tmpdir) / self.dataset_id / "data" + # populate the tensordict + episode_dict = {} + for i, (episode_key, episode) in enumerate(h5_data.items()): + episode_num = int(episode_key[len("episode_") :]) + episode_len = episode["actions"].shape[0] + episode_dict[episode_num] = (episode_key, episode_len) + # Get the total number of steps for the dataset + total_steps += episode_len + if i == 0: + td_data.set("episode", 0) + for key, val in episode.items(): + match = _NAME_MATCH[key] + if key in ("observations", "state", "infos"): + if ( + not val.shape + ): # no need for this, we don't need the proper length: or steps != val.shape[0] - 1: + if val.is_empty(): + continue + val = _patch_info(val) + td_data.set(("next", match), torch.zeros_like(val[0])) + td_data.set(match, torch.zeros_like(val[0])) + if key not in ("terminations", "truncations", "rewards"): + td_data.set(match, torch.zeros_like(val[0])) + else: + td_data.set( + ("next", match), + torch.zeros_like(val[0].unsqueeze(-1)), + ) + # give it the proper size + td_data["next", "done"] = ( + td_data["next", "truncated"] | td_data["next", "terminated"] + ) + if "terminated" in td_data.keys(): + td_data["done"] = td_data["truncated"] | td_data["terminated"] + td_data = td_data.expand(total_steps) + # save to designated location torchrl_logger.info( - "first read through data to create data structure..." 
+ f"creating tensordict data in {self.data_path_root}: " ) - h5_data = PersistentTensorDict.from_h5(parent_dir / "main_data.hdf5") - - # populate the tensordict - episode_dict = {} - for i, (episode_key, episode) in enumerate(h5_data.items()): - episode_num = int(episode_key[len("episode_") :]) - episode_len = episode["actions"].shape[0] - episode_dict[episode_num] = (episode_key, episode_len) - # Get the total number of steps for the dataset - total_steps += episode_len - if i == 0: - td_data.set("episode", 0) - for key, val in episode.items(): - match = _NAME_MATCH[key] - if key in ("observations", "state", "infos"): - if ( - not val.shape - ): # no need for this, we don't need the proper length: or steps != val.shape[0] - 1: - if val.is_empty(): - continue - val = _patch_info(val) - td_data.set(("next", match), torch.zeros_like(val[0])) - td_data.set(match, torch.zeros_like(val[0])) - if key not in ("terminations", "truncations", "rewards"): - td_data.set(match, torch.zeros_like(val[0])) - else: - td_data.set( - ("next", match), - torch.zeros_like(val[0].unsqueeze(-1)), - ) + td_data = td_data.memmap_like(self.data_path_root) + torchrl_logger.info(f"tensordict structure: {td_data}") - # give it the proper size - td_data["next", "done"] = ( - td_data["next", "truncated"] | td_data["next", "terminated"] - ) - if "terminated" in td_data.keys(): - td_data["done"] = td_data["truncated"] | td_data["terminated"] - td_data = td_data.expand(total_steps) - # save to designated location - torchrl_logger.info(f"creating tensordict data in {self.data_path_root}: ") - td_data = td_data.memmap_like(self.data_path_root) - torchrl_logger.info(f"tensordict structure: {td_data}") - - torchrl_logger.info(f"Reading data from {max(*episode_dict) + 1} episodes") - index = 0 - with tqdm(total=total_steps) if _has_tqdm else nullcontext() as pbar: - # iterate over episodes and populate the tensordict - for episode_num in sorted(episode_dict): - episode_key, steps = episode_dict[episode_num] - episode = h5_data.get(episode_key) - idx = slice(index, (index + steps)) - data_view = td_data[idx] - data_view.fill_("episode", episode_num) - for key, val in episode.items(): - match = _NAME_MATCH[key] - if key in ( - "observations", - "state", - "infos", - ): - if not val.shape or steps != val.shape[0] - 1: - if val.is_empty(): - continue - val = _patch_info(val) - if steps != val.shape[0] - 1: - raise RuntimeError( - f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0] - 1}." 
- ) - data_view["next", match].copy_(val[1:]) - data_view[match].copy_(val[:-1]) - elif key not in ("terminations", "truncations", "rewards"): - if steps is None: - steps = val.shape[0] - else: - if steps != val.shape[0]: + torchrl_logger.info( + f"Reading data from {max(*episode_dict) + 1} episodes" + ) + index = 0 + with tqdm(total=total_steps) if _has_tqdm else nullcontext() as pbar: + # iterate over episodes and populate the tensordict + for episode_num in sorted(episode_dict): + episode_key, steps = episode_dict[episode_num] + episode = h5_data.get(episode_key) + idx = slice(index, (index + steps)) + data_view = td_data[idx] + data_view.fill_("episode", episode_num) + for key, val in episode.items(): + match = _NAME_MATCH[key] + if key in ( + "observations", + "state", + "infos", + ): + if not val.shape or steps != val.shape[0] - 1: + if val.is_empty(): + continue + val = _patch_info(val) + if steps != val.shape[0] - 1: raise RuntimeError( - f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0]}." + f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0] - 1}." ) - data_view[match].copy_(val) - else: - if steps is None: - steps = val.shape[0] + data_view["next", match].copy_(val[1:]) + data_view[match].copy_(val[:-1]) + elif key not in ("terminations", "truncations", "rewards"): + if steps is None: + steps = val.shape[0] + else: + if steps != val.shape[0]: + raise RuntimeError( + f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0]}." + ) + data_view[match].copy_(val) else: - if steps != val.shape[0]: - raise RuntimeError( - f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0]}." - ) - data_view[("next", match)].copy_(val.unsqueeze(-1)) - data_view["next", "done"].copy_( - data_view["next", "terminated"] | data_view["next", "truncated"] - ) - if "done" in data_view.keys(): - data_view["done"].copy_( - data_view["terminated"] | data_view["truncated"] + if steps is None: + steps = val.shape[0] + else: + if steps != val.shape[0]: + raise RuntimeError( + f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0]}." 
+ ) + data_view[("next", match)].copy_(val.unsqueeze(-1)) + data_view["next", "done"].copy_( + data_view["next", "terminated"] + | data_view["next", "truncated"] ) - if pbar is not None: - pbar.update(steps) - pbar.set_description( - f"index={index} - episode num {episode_num}" - ) - index += steps - h5_data.close() - # Add a "done" entry - if self.split_trajs: - with td_data.unlock_(): - from torchrl.collectors.utils import split_trajectories - - td_data = split_trajectories(td_data).memmap_(self.data_path) - with open(self.metadata_path, "w") as metadata_file: - dataset = minari.load_dataset(self.dataset_id) - self.metadata = asdict(dataset.spec) - self.metadata["observation_space"] = _spec_to_dict( - self.metadata["observation_space"] - ) - self.metadata["action_space"] = _spec_to_dict( - self.metadata["action_space"] - ) - json.dump(self.metadata, metadata_file) - self._load_and_proc_metadata() - return td_data + if "done" in data_view.keys(): + data_view["done"].copy_( + data_view["terminated"] | data_view["truncated"] + ) + if pbar is not None: + pbar.update(steps) + pbar.set_description( + f"index={index} - episode num {episode_num}" + ) + index += steps + h5_data.close() + # Add a "done" entry + if self.split_trajs: + with td_data.unlock_(): + from torchrl.collectors.utils import split_trajectories + + td_data = split_trajectories(td_data).memmap_(self.data_path) + with open(self.metadata_path, "w") as metadata_file: + dataset = minari.load_dataset(self.dataset_id) + self.metadata = asdict(dataset.spec) + self.metadata["observation_space"] = _spec_to_dict( + self.metadata["observation_space"] + ) + self.metadata["action_space"] = _spec_to_dict( + self.metadata["action_space"] + ) + json.dump(self.metadata, metadata_file) + self._load_and_proc_metadata() + return td_data + finally: + if prev_minari_datasets_path_save is not None: + os.environ["MINARI_DATASETS_PATH"] = prev_minari_datasets_path_save def _make_split(self): from torchrl.collectors.utils import split_trajectories diff --git a/version.txt b/version.txt index ac39a106c48..4ca8929845d 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.9.0 +2025.7.24
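
For reference, a minimal end-to-end sketch of the workflow this series enables: loading a dataset straight from the local Minari cache through the new ``load_from_local_minari`` flag. The dataset id below is the one created by the test above; everything else is ordinary ``MinariExperienceReplay`` usage, so this is a usage illustration rather than part of the patch:

    from torchrl.data.datasets.minari_data import MinariExperienceReplay
    from torchrl.data.replay_buffers import SamplerWithoutReplacement

    # Assumes "cartpole/test-local-v1" already exists in the local Minari
    # cache (~/.minari/datasets, or $MINARI_DATASETS_PATH if set), e.g.
    # created beforehand with minari.DataCollector.create_dataset().
    data = MinariExperienceReplay(
        dataset_id="cartpole/test-local-v1",
        split_trajs=False,
        batch_size=32,
        download=False,  # never hit the network
        sampler=SamplerWithoutReplacement(drop_last=True),
        prefetch=2,
        load_from_local_minari=True,
    )

    for sample in data:
        # each sample is a TensorDict batch of 32 transitions
        assert sample["action"].shape[0] == 32
        break

If the dataset is absent from the cache, ``_download_and_preproc`` raises a ``FileNotFoundError`` instead of silently falling back to a download.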
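
A recurring pattern in the final commit is saving ``MINARI_DATASETS_PATH``, pointing it elsewhere for the duration of an operation, and restoring it in a ``finally`` block — both in ``_download_and_preproc`` and in the test. The same dance can be packaged as a context manager; the sketch below is illustrative only (the helper name is not part of the patch):

    import os
    from contextlib import contextmanager

    @contextmanager
    def temp_env_var(key: str, value: str):
        # Temporarily override one environment variable, restoring the
        # previous value on exit (or unsetting it if it was unset before).
        prev = os.environ.get(key)
        os.environ[key] = value
        try:
            yield
        finally:
            if prev is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = prev

    # e.g. around the remote download:
    # with temp_env_var("MINARI_DATASETS_PATH", tmpdir):
    #     minari.download_dataset(dataset_id=dataset_id)

One behavioral difference worth noting: the ``finally`` blocks in the patch restore the variable only when a previous value existed, so a variable that was unset beforehand remains set to the temporary value afterwards; the context-manager form unsets it again.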