diff --git a/test/test_libs.py b/test/test_libs.py
index 1a92eb671c3..b3609bb7ccd 100644
--- a/test/test_libs.py
+++ b/test/test_libs.py
@@ -3341,6 +3341,39 @@ def test_d4rl_iteration(self, task, split_trajs):
 
 _MINARI_DATASETS = []
 
+MUJOCO_ENVIRONMENTS = [
+    "Hopper-v5",
+    "Pusher-v5",
+    "Humanoid-v5",
+    "InvertedDoublePendulum-v5",
+    "HalfCheetah-v5",
+    "Swimmer-v5",
+    "Walker2d-v5",
+    "Ant-v5",
+    "Reacher-v5",
+]
+
+D4RL_ENVIRONMENTS = [
+    "AntMaze_UMaze-v5",
+    "AdroitHandPen-v1",
+    "AntMaze_Medium-v4",
+    "AntMaze_Large_Diverse_GR-v4",
+    "AntMaze_Large-v4",
+    "AntMaze_Medium_Diverse_GR-v4",
+    "PointMaze_OpenDense-v3",
+    "PointMaze_UMaze-v3",
+    "PointMaze_LargeDense-v3",
+    "PointMaze_Medium-v3",
+    "PointMaze_UMazeDense-v3",
+    "PointMaze_MediumDense-v3",
+    "PointMaze_Large-v3",
+    "PointMaze_Open-v3",
+    "FrankaKitchen-v1",
+    "AdroitHandDoor-v1",
+    "AdroitHandHammer-v1",
+    "AdroitHandRelocate-v1",
+]
+
 
 def _minari_init():
     """Initialize Minari datasets list. Returns True if already initialized."""
@@ -3373,30 +3406,155 @@ def _minari_init():
     return False
 
 
-# Initialize with placeholder values for parametrization
-# These will be replaced with actual dataset names when the first Minari test runs
-_MINARI_DATASETS = [str(i) for i in range(20)]
+def get_random_minigrid_datasets():
+    """
+    Fetch 5 random Minigrid datasets from the Minari server.
+    """
+    import minari
+
+    all_minigrid = [
+        dataset
+        for dataset in minari.list_remote_datasets(
+            latest_version=True, compatible_minari_version=True
+        ).keys()
+        if dataset.startswith("minigrid/")
+    ]
+
+    if len(all_minigrid) < 5:
+        raise RuntimeError("Not enough minigrid datasets found on Minari server.")
+    indices = torch.randperm(len(all_minigrid))[:5]
+    return [all_minigrid[idx] for idx in indices]
+
+
+def get_random_atari_envs():
+    """
+    Fetch 10 random Atari environments using ale_py and torch.
+ """ + import ale_py + import gymnasium as gym + + gym.register_envs(ale_py) + + env_specs = gym.envs.registry.values() + all_env_ids = [env_spec.id for env_spec in env_specs] + atari_env_ids = [env_id for env_id in all_env_ids if env_id.startswith("ALE")] + if len(atari_env_ids) < 10: + raise RuntimeError("Not enough Atari environments found.") + indices = torch.randperm(len(atari_env_ids))[:10] + return [atari_env_ids[idx] for idx in indices] + + +def custom_minari_init(custom_envs, num_episodes=5): + """ + Initialize custom Minari datasets for the given environments. + """ + import gymnasium + import gymnasium_robotics + from minari import DataCollector + + gymnasium.register_envs(gymnasium_robotics) + + custom_dataset_ids = [] + for env_id in custom_envs: + dataset_id = f"{env_id.lower()}/test-custom-local-v1" + env = gymnasium.make(env_id) + collector = DataCollector(env) + + for ep in range(num_episodes): + collector.reset(seed=123 + ep) + + while True: + action = collector.action_space.sample() + _, _, terminated, truncated, _ = collector.step(action) + if terminated or truncated: + break + + collector.create_dataset( + dataset_id=dataset_id, + algorithm_name="RandomPolicy", + code_permalink="https://github.com/Farama-Foundation/Minari", + author="Farama", + author_email="contact@farama.org", + eval_env=env_id, + ) + custom_dataset_ids.append(dataset_id) + + return custom_dataset_ids @pytest.mark.skipif(not _has_minari or not _has_gymnasium, reason="Minari not found") @pytest.mark.slow class TestMinari: @pytest.mark.parametrize("split", [False, True]) - @pytest.mark.parametrize("dataset_idx", range(20)) + @pytest.mark.parametrize( + "dataset_idx", + # Only use a static upper bound; do not call any function that imports minari globally. 
+        range(50)
+    )
     def test_load(self, dataset_idx, split):
-        # Initialize Minari datasets if not already done
-        if not _minari_init():
-            pytest.skip("Failed to initialize Minari datasets")
+        """
+        Test loading from custom datasets for Mujoco and D4RL,
+        Minari remote datasets for Minigrid, and random Atari environments.
+        """
+        import minari
 
-        # Get the actual dataset name from the initialized list
-        if dataset_idx >= len(_MINARI_DATASETS):
-            pytest.skip(f"Dataset index {dataset_idx} out of range")
+        custom_envs = MUJOCO_ENVIRONMENTS + D4RL_ENVIRONMENTS
+        num_custom = len(custom_envs)
+        try:
+            minigrid_datasets = get_random_minigrid_datasets()
+        except Exception:
+            minigrid_datasets = []
+        num_minigrid = len(minigrid_datasets)
+        try:
+            atari_envs = get_random_atari_envs()
+        except Exception:
+            atari_envs = []
+        num_atari = len(atari_envs)
+        total_datasets = num_custom + num_minigrid + num_atari
+
+        if dataset_idx >= total_datasets:
+            pytest.skip("Index out of range for available datasets")
+
+        if dataset_idx < num_custom:
+            # Custom dataset for Mujoco/D4RL
+            custom_dataset_ids = custom_minari_init(
+                [custom_envs[dataset_idx]], num_episodes=5
+            )
+            dataset_id = custom_dataset_ids[0]
+            data = MinariExperienceReplay(
+                dataset_id=dataset_id,
+                split_trajs=split,
+                batch_size=32,
+                load_from_local_minari=True,
+            )
+            cleanup_needed = True
+
+        elif dataset_idx < num_custom + num_minigrid:
+            # Minigrid datasets from Minari server
+            minigrid_idx = dataset_idx - num_custom
+            dataset_id = minigrid_datasets[minigrid_idx]
+            data = MinariExperienceReplay(
+                dataset_id=dataset_id,
+                batch_size=32,
+                split_trajs=split,
+                download="force",
+            )
+            cleanup_needed = False
+
+        else:
+            # Atari environment datasets
+            atari_idx = dataset_idx - num_custom - num_minigrid
+            env_id = atari_envs[atari_idx]
+            custom_dataset_ids = custom_minari_init([env_id], num_episodes=5)
+            dataset_id = custom_dataset_ids[0]
+            data = MinariExperienceReplay(
+                dataset_id=dataset_id,
+                split_trajs=split,
+                batch_size=32,
+                load_from_local_minari=True,
+            )
+            cleanup_needed = True
 
-        selected_dataset = _MINARI_DATASETS[dataset_idx]
-        torchrl_logger.info(f"dataset {selected_dataset}")
-        data = MinariExperienceReplay(
-            selected_dataset, batch_size=32, split_trajs=split
-        )
         t0 = time.time()
         for i, sample in enumerate(data):
             t1 = time.time()
@@ -3407,6 +3565,10 @@ def test_load(self, dataset_idx, split):
             if i == 10:
                 break
 
+        # Clean up custom datasets after running local dataset tests
+        if cleanup_needed:
+            minari.delete_dataset(dataset_id=dataset_id)
+
     def test_minari_preproc(self, tmpdir):
         dataset = MinariExperienceReplay(
             "D4RL/pointmaze/large-v2",
@@ -3453,6 +3615,66 @@ def fn(data):
         assert sample["data"].shape == torch.Size([32, 8])
         assert sample["next", "data"].shape == torch.Size([32, 8])
 
+    @pytest.mark.skipif(
+        not _has_minari or not _has_gymnasium, reason="Minari or Gym not available"
+    )
+    def test_local_minari_dataset_loading(self):
+        import minari
+        from minari import DataCollector
+
+        if not _minari_init():
+            pytest.skip("Failed to initialize Minari datasets")
+
+        dataset_id = "cartpole/test-local-v1"
+
+        # Create dataset using Gym + DataCollector
+        env = gymnasium.make("CartPole-v1")
+        env = DataCollector(env, record_infos=True)
+        for _ in range(50):
+            env.reset(seed=123)
+            while True:
+                action = env.action_space.sample()
+                obs, rew, terminated, truncated, info = env.step(action)
+                if terminated or truncated:
+                    break
+
+        env.create_dataset(
+            dataset_id=dataset_id,
+            algorithm_name="RandomPolicy",
+            code_permalink="https://github.com/Farama-Foundation/Minari",
+            author="Farama",
+            author_email="contact@farama.org",
+            eval_env="CartPole-v1",
+        )
+
+        # Load from local cache
+        data = MinariExperienceReplay(
+            dataset_id=dataset_id,
+            split_trajs=False,
+            batch_size=32,
+            download=False,
+            sampler=SamplerWithoutReplacement(drop_last=True),
+            prefetch=2,
+            load_from_local_minari=True,
+        )
+
+        t0 = time.time()
+        for i, sample in enumerate(data):
+            t1 = time.time()
+            torchrl_logger.info(
+                f"[Local Minari] Sampling time {1000 * (t1 - t0):4.4f} ms"
+            )
+            assert data.metadata["action_space"].is_in(
+                sample["action"]
+            ), "Invalid action sample"
+            assert data.metadata["observation_space"].is_in(
+                sample["observation"]
+            ), "Invalid observation sample"
+            t0 = time.time()
+            if i == 10:
+                break
+
+        minari.delete_dataset(dataset_id="cartpole/test-local-v1")
 
 
 @pytest.mark.slow
 class TestRoboset:
diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py
index 3d0d241bd99..d8c63873497 100644
--- a/torchrl/data/datasets/minari_data.py
+++ b/torchrl/data/datasets/minari_data.py
@@ -88,6 +88,14 @@ class MinariExperienceReplay(BaseDatasetExperienceReplay):
             it is assumed that any ``truncated`` or ``terminated`` signal is
             equivalent to the end of a trajectory. Defaults to ``False``.
+        load_from_local_minari (bool, optional): if ``True``, the dataset will be loaded directly
+            from the local Minari cache (typically located at ``~/.minari/datasets``),
+            bypassing any remote download. This is useful when working with custom
+            Minari datasets previously generated and stored locally, or when network
+            access should be avoided. If the dataset is not found in the expected
+            cache directory, a ``FileNotFoundError`` will be raised.
+            Defaults to ``False``.
+
     Attributes:
         available_datasets: a list of accepted entries to be downloaded.
@@ -167,6 +175,7 @@ def __init__(
         prefetch: int | None = None,
         transform: torchrl.envs.Transform | None = None,  # noqa-F821
         split_trajs: bool = False,
+        load_from_local_minari: bool = False,
     ):
         self.dataset_id = dataset_id
         if root is None:
@@ -175,7 +184,13 @@ def __init__(
         self.root = root
         self.split_trajs = split_trajs
         self.download = download
-        if self.download == "force" or (self.download and not self._is_downloaded()):
+        self.load_from_local_minari = load_from_local_minari
+
+        if (
+            self.download == "force"
+            or (self.download and not self._is_downloaded())
+            or self.load_from_local_minari
+        ):
             if self.download == "force":
                 try:
                     if os.path.exists(self.data_path_root):
@@ -240,13 +255,38 @@ def _download_and_preproc(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             os.environ["MINARI_DATASETS_PATH"] = tmpdir
-            minari.download_dataset(dataset_id=self.dataset_id)
-            parent_dir = Path(tmpdir) / self.dataset_id / "data"
-
-            td_data = TensorDict()
             total_steps = 0
-            torchrl_logger.info("first read through data to create data structure...")
-            h5_data = PersistentTensorDict.from_h5(parent_dir / "main_data.hdf5")
+            td_data = TensorDict()
+
+            if self.load_from_local_minari:
+                # Load minari dataset from user's local Minari cache
+
+                minari_cache_dir = os.path.expanduser("~/.minari/datasets")
+                os.environ["MINARI_DATASETS_PATH"] = minari_cache_dir
+                parent_dir = Path(minari_cache_dir) / self.dataset_id / "data"
+                h5_path = parent_dir / "main_data.hdf5"
+
+                if not h5_path.exists():
+                    raise FileNotFoundError(
+                        f"{h5_path} does not exist in local Minari cache!"
+                    )
+
+                torchrl_logger.info(
+                    f"loading dataset from local Minari cache at {h5_path}"
+                )
+                h5_data = PersistentTensorDict.from_h5(h5_path)
+
+            else:
+                minari.download_dataset(dataset_id=self.dataset_id)
+
+                parent_dir = Path(tmpdir) / self.dataset_id / "data"
+
+                torchrl_logger.info(
+                    "first read through data to create data structure..."
+                )
+                h5_data = PersistentTensorDict.from_h5(parent_dir / "main_data.hdf5")
+
             # populate the tensordict
             episode_dict = {}
             for i, (episode_key, episode) in enumerate(h5_data.items()):