diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 37adaef7979..43d9ad1a525 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -20,7 +20,7 @@ repos:
       - libcst == 0.4.7
   - repo: https://github.com/pycqa/flake8
-    rev: 4.0.1
+    rev: 6.0.0
     hooks:
       - id: flake8
         args: [--config=setup.cfg]
diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst
index e9d08822239..319eb19d0dc 100644
--- a/docs/source/reference/data.rst
+++ b/docs/source/reference/data.rst
@@ -144,6 +144,8 @@ using the following components:
     :template: rl_template.rst

+    CompressedListStorage
+    CompressedListStorageCheckpointer
     FlatStorageCheckpointer
     H5StorageCheckpointer
     ImmutableDatasetWriter
@@ -191,6 +193,70 @@ were found from rough benchmarking in https://github.com/pytorch/rl/tree/main/be
 | :class:`LazyMemmapStorage`    | 3.44x     |
 +-------------------------------+-----------+

+Compressed Storage for Memory Efficiency
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+For applications where memory usage or memory bandwidth is a primary concern, especially when storing or transferring large sensory observations such as images, audio, or text, the :class:`~torchrl.data.replay_buffers.storages.CompressedListStorage` provides significant memory savings through compression.
+
+**Key features:**
+
+- **Memory Efficiency:** Achieves substantial memory savings via compression.
+- **Data Integrity:** Maintains full data fidelity through lossless compression.
+- **Flexible Compression:** Uses zstd compression by default, with support for custom compression algorithms.
+- **TensorDict Support:** Seamlessly integrates with TensorDict structures.
+- **Checkpointing:** Fully supports saving and loading compressed data.
+- **Batched GPU Compression/Decompression:** Enables efficient replay buffer sampling directly from VRAM.
+
+:class:`CompressedListStorage` compresses data when storing and decompresses it when retrieving, achieving compression ratios of 95x-122x on Atari images while maintaining full data fidelity.
+The ratios below were measured on a single episode of Pong in the Arcade Learning Environment (ALE), collected with a random policy, at several zstd compression levels:
+
++-------------------------------+--------+--------+--------+--------+--------+
+| Compression level of zstd     | 1      | 3      | 8      | 12     | 22     |
++===============================+========+========+========+========+========+
+| Compression ratio in ALE Pong | 95x    | 99x    | 106x   | 111x   | 122x   |
++-------------------------------+--------+--------+--------+--------+--------+
+
+Example usage:
+
+    >>> import torch
+    >>> from torchrl.data import ReplayBuffer, CompressedListStorage
+    >>> from tensordict import TensorDict
+    >>>
+    >>> # Create a compressed storage for image data
+    >>> storage = CompressedListStorage(max_size=1000, compression_level=3)
+    >>> rb = ReplayBuffer(storage=storage, batch_size=32)
+    >>>
+    >>> # Add image data
+    >>> images = torch.randn(100, 3, 84, 84)  # Atari-like frames
+    >>> data = TensorDict({"obs": images}, batch_size=[100])
+    >>> rb.extend(data)
+    >>>
+    >>> # Sample data (automatically decompressed)
+    >>> sample = rb.sample(32)
+    >>> print(sample["obs"].shape)  # torch.Size([32, 3, 84, 84])
+
+The compression level can be adjusted from 1 (fast, less compression) to 22 (slow, more compression),
+with level 3 being a good default for most use cases.
+
+For custom compression algorithms:
+
+    >>> def my_compress(tensor):
+    ...     return tensor.to(torch.uint8)  # Simple example
+    >>>
+    >>> def my_decompress(compressed_tensor, metadata):
+    ...     return compressed_tensor.to(metadata["dtype"])
+    >>>
+    >>> storage = CompressedListStorage(
+    ...     max_size=1000,
+    ...     compression_fn=my_compress,
+    ...     decompression_fn=my_decompress
+    ... )
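+
+Compressed storages can also be checkpointed like any other storage. The snippet below is a minimal
+sketch (the paths are illustrative) mirroring the checkpointing example added in
+``examples/replay-buffers/compressed_replay_buffer_checkpoint.py``:
+
+    >>> import tempfile
+    >>> from pathlib import Path
+    >>>
+    >>> with tempfile.TemporaryDirectory() as tmpdir:
+    ...     path = Path(tmpdir) / "checkpoint"
+    ...     rb.dumps(path)  # save compressed data and metadata to disk
+    ...     new_rb = ReplayBuffer(
+    ...         storage=CompressedListStorage(max_size=1000, compression_level=3)
+    ...     )
+    ...     new_rb.loads(path)  # restore the compressed entries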
+
+.. note:: ``CompressedListStorage`` uses ``zstd`` compression on Python 3.14 and above (via the standard-library ``compression.zstd`` module) and falls back to ``zlib`` otherwise.
+
+.. note:: Batched GPU compression relies on ``nvidia.nvcomp``; see the example code in
+    ``examples/replay-buffers/compressed_replay_buffer.py``.
+
 Sharing replay buffers across processes
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

diff --git a/examples/replay-buffers/compressed_replay_buffer.py b/examples/replay-buffers/compressed_replay_buffer.py
new file mode 100644
index 00000000000..108165c2b51
--- /dev/null
+++ b/examples/replay-buffers/compressed_replay_buffer.py
@@ -0,0 +1,469 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Demonstrates compressing a rollout of Atari transitions and batch-decompressing them on the GPU.
+This may be helpful in the multi-environment case, or when multiple agents and environments are vmapped on the GPU.
+The collate function lets us decompress an entire batch on the GPU in a single call.
+
+Below are the results of running this example with different compression levels on an Atari rollout of Pong.
++---------------------+--------+--------+--------+--------+--------+
+| Compressor Level    | 1      | 3      | 8      | 12     | 22     |
++=====================+========+========+========+========+========+
+| Compression Ratio   | 95x    | 99x    | 106x   | 111x   | 122x   |
++---------------------+--------+--------+--------+--------+--------+
+
+"""
+
+from __future__ import annotations
+
+import importlib
+
+import sys
+
+import time
+from typing import Any, NamedTuple
+
+import gymnasium as gym
+
+import numpy as np
+import torch
+from tensordict import TensorDict
+from torchrl import torchrl_logger as logger
+from torchrl.data import CompressedListStorage, ListStorage, ReplayBuffer
+
+# check if nvidia.nvcomp is available
+has_nvcomp = importlib.util.find_spec("nvidia.nvcomp") is not None
+if not has_nvcomp:
+    raise ImportError(
+        "Please pip install nvidia-nvcomp to use this example with GPU compression."
+ ) +else: + import nvidia.nvcomp as nvcomp + + +class AtariTransition(NamedTuple): + observations: np.uint8 + actions: np.uint8 + next_observations: np.uint8 + rewards: np.float32 + terminated: np.bool + truncated: np.bool + info: dict[str, Any] + + +def setup_atari_environment(seed: int = 42) -> gym.Env: + import ale_py + + gym.register_envs(ale_py) + env = gym.make("ALE/Pong-v5", frameskip=1) + env = gym.wrappers.AtariPreprocessing(env, frame_skip=5) + env = gym.wrappers.RecordEpisodeStatistics(env) + env = gym.wrappers.TransformReward(env, np.sign) + env = gym.wrappers.FrameStackObservation(env, 4) + env.action_space.seed(seed) + env.observation_space.seed(seed) + return env + + +def run_rollout_benchmark( + rb: ReplayBuffer, + env: gym.Env, + calculate_compression_ratio_fn: callable, + create_and_add_transition_fn: callable, + compress_obs: callable, + num_transitions: int = 2000, +): + """Run a rollout benchmark collecting transitions and measuring steps per second and compression ratios.""" + compression_ratios = [] + terminated = truncated = True + next_obs = compressed_next_obs = None + + start_time = time.time() + for _ in range(num_transitions): + if terminated or truncated: + obs, _ = env.reset() + compressed_obs = compress_obs(obs) + else: + obs = next_obs + compressed_obs = compressed_next_obs + + # perform some fake inference with the obs + obs = torch.from_numpy(obs).cuda(non_blocking=True) + action = env.action_space.sample() + + next_obs, reward, terminated, truncated, info = env.step(action) + + # Compress next observation + compressed_next_obs = compress_obs(next_obs) + compression_ratios.append( + calculate_compression_ratio_fn(next_obs, compressed_next_obs) + ) + + # Create and add transition + create_and_add_transition_fn( + rb, + compressed_obs, + action, + compressed_next_obs, + reward, + terminated, + truncated, + info, + ) + + rollout_time = time.time() - start_time + return rollout_time, compression_ratios + + +def run_sampling_benchmark(rb: ReplayBuffer, num_samples=100, batch_size=32) -> float: + """Run a sampling replaybuffer benchmark measuring decompression speed.""" + start_time = time.time() + for _ in range(num_samples): + rb.sample(batch_size) + sampling_time = time.time() - start_time + return sampling_time + + +def get_cpu_codec(level=1): + """Returns compression and decompression functions for CPU.""" + if sys.version_info >= (3, 14): + from compression import zstd + + def compress_fn(data): + return zstd.compress(data, level) + + return compress_fn, zstd.decompress + else: + try: + import zstd + + def compress_fn(data): + return zstd.compress(data, level) + + return compress_fn, zstd.decompress + except ImportError: + raise ImportError( + "Please `pip install zstd` to use this example with CPU compression." + ) + + +def get_gpu_codec(level=1): + """Returns compression and decompression functions for GPU using NVIDIA NVCOMP. 
+ + See the python API docs here: https://docs.nvidia.com/cuda/nvcomp/py_api.html + """ + # RAW = Does not add header with nvCOMP metadata, so that the codec can read compressed data from the CPU library + bitstream_kind = nvcomp.BitstreamKind.RAW + # Note: NVCOMP may not support all compression levels the same way as CPU zstd + codec = nvcomp.Codec(algorithm="Zstd", bitstream_kind=bitstream_kind) + + def compressor_fn(data: nvcomp.Array) -> nvcomp.Array: + return codec.encode(data) + + def decompressor_fn(compressed_data: nvcomp.Array) -> nvcomp.Array: + return codec.decode(compressed_data, data_type="|u1") + + return compressor_fn, decompressor_fn + + +def make_batch_decompressing_replay_buffer(decompressor_fn) -> ReplayBuffer: + """ + Creates a ReplayBuffer with batched decompression on the GPU. + """ + storage = ListStorage( + max_size=1000, + device="cuda", + ) + + def collate_compressed_data_and_batch_decompress( + data: list[AtariTransition], + ) -> list[AtariTransition]: + """We collate the compressed data together so that we can decompress it in a single batch operation.""" + transitions = data + + # gather compressed data + compressed_obs = [transition.observations for transition in transitions] + compressed_next_obs = [ + transition.next_observations for transition in transitions + ] + + # optional checks + assert all(isinstance(arr, nvcomp.Array) for arr in compressed_obs) + assert all(isinstance(arr, nvcomp.Array) for arr in compressed_next_obs) + + # batched decompress is faster + decompressed_data = decompressor_fn(compressed_obs + compressed_next_obs) + + # gather decompressed data + decompressed_obses = decompressed_data[: len(compressed_obs)] + decompressed_next_obses = decompressed_data[len(compressed_obs) :] + + # repack data + for i, (transition, obs, next_obs) in enumerate( + zip(transitions, decompressed_obses, decompressed_next_obses) + ): + transitions[i] = transition._replace( + observations=torch.from_dlpack(obs).view(4, 84, 84), + next_observations=torch.from_dlpack(next_obs).view(4, 84, 84), + ) + + return transitions + + return ReplayBuffer( + storage=storage, + batch_size=32, + collate_fn=collate_compressed_data_and_batch_decompress, + ) + + +def cpu_compress_to_gpu_decompress(level=1): + env = setup_atari_environment(seed=0) + + compressor_fn, _ = get_cpu_codec(level) + _, decompressor_fn = get_gpu_codec(level) + + obs, _ = env.reset(seed=0) + compressed_obs = compressor_fn(obs.tobytes()) + decompressed_obs = decompressor_fn(compressed_obs) + pt_obs = torch.from_dlpack(decompressed_obs).clone().view(4, 84, 84) + assert np.allclose(obs, pt_obs.cpu().numpy()) + + rb = make_batch_decompressing_replay_buffer(decompressor_fn) + + def calculate_compression_ratio(obs, compressed_next_obs): + return len(obs.tobytes()) / len(compressor_fn(obs.tobytes())) + + def compress_obs(obs): + return compressor_fn(obs.tobytes()) + + def create_and_add_transition( + rb, + compressed_obs, + action, + compressed_next_obs, + reward, + terminated, + truncated, + info, + ): + # Convert compressed bytes to nvcomp arrays for GPU storage + compressed_obs_data = nvcomp.as_array(compressed_obs).cuda(synchronize=False) + compressed_nv_next_obs = nvcomp.as_array(compressed_next_obs).cuda( + synchronize=False + ) + + transition = AtariTransition( + compressed_obs_data, + action, + compressed_nv_next_obs, + reward, + terminated, + truncated, + info, + ) + rb.add(transition) + + torch.cuda.synchronize() + rollout_time, compression_ratios = run_rollout_benchmark( + rb, + env, + 
calculate_compression_ratio, + create_and_add_transition, + compress_obs, + 2000, + ) + + torch.cuda.synchronize() + sample_time = run_sampling_benchmark(rb, 100, 32) + + output = [ + "\nListStorage + ReplayBuffer (CPU compress, GPU decompress, storage on GPU) Example:", + f"avg_compression_ratio={np.array(compression_ratios).mean():0.0f}", + f"rollout with zstd, @ transitions/s={2000 / rollout_time:0.0f}", + "batch sampling and decompression with zstd @ transitions/s={:0.0f}".format( + (100 * 32) / sample_time + ), + ] + + logger.info("\n\t".join(output)) + + +def cpu_only(level=1): + env = setup_atari_environment(seed=0) + + compressor_fn, decompressor_fn = get_cpu_codec(level) + + # Test compression/decompression works correctly + obs, _ = env.reset(seed=0) + compressed_obs = compressor_fn(obs.tobytes()) + decompressed_obs = decompressor_fn(compressed_obs) + recovered_obs = np.frombuffer(decompressed_obs, dtype=np.uint8).reshape(obs.shape) + assert np.allclose(obs, recovered_obs) + + def compress_from_torch(data: torch.Tensor) -> bytes: + """ + Convert a tensor to a byte stream for compression. + """ + return compressor_fn(data.cpu().numpy().tobytes()) + + def decompress_from_bytes(data: bytes, metadata: dict) -> torch.Tensor: + """ + Convert a byte stream back to a tensor. + """ + decompressed_data = bytearray(decompressor_fn(data)) + dtype = metadata.get("dtype", torch.float32) + device = metadata.get("device", "cpu") + shape = metadata.get("shape", ()) + + return ( + torch.frombuffer( + decompressed_data, + dtype=dtype, + ) + .view(shape) + .to(device) + ) + + storage = CompressedListStorage( + max_size=1000, + compression_level=level, # Use the passed compression level + device="cpu", + compression_fn=compress_from_torch, + decompression_fn=decompress_from_bytes, + ) + + rb = ReplayBuffer(storage=storage, batch_size=32) + + def calculate_compression_ratio(obs, compressed_obs): + # For cpu_only, the CompressedListStorage handles compression internally + # so we calculate the ratio based on the original observation size vs compressed bytes + original_size = obs.nbytes + compressed_size = len(compressor_fn(obs.tobytes())) + return original_size / compressed_size + + def compress_obs(obs): + return torch.from_numpy(obs).clone() + + def create_and_add_transition( + rb, + compressed_obs, + action, + compressed_next_obs, + reward, + terminated, + truncated, + info, + ): + transition_tuple = AtariTransition( + observations=compressed_obs, + actions=torch.tensor(action), + next_observations=compressed_next_obs, + rewards=torch.tensor(reward, dtype=torch.float32), + terminated=torch.tensor(terminated), + truncated=torch.tensor(truncated), + info=info, + ) + transition = TensorDict.from_namedtuple(transition_tuple, batch_size=[]) + rb.add(transition) + + # Run rollout benchmark + rollout_time, compression_ratios = run_rollout_benchmark( + rb, + env, + calculate_compression_ratio, + create_and_add_transition, + compress_obs, + 2000, + ) + + sample_time = run_sampling_benchmark(rb, 100, 32) + + output = [ + "\nCompressedListStorage + ReplayBuffer (CPU compress, CPU decompress, storage on CPU) Example:", + f"avg_compression_ratio={np.array(compression_ratios).mean():0.0f}", + f"rollout with zstd, @ transitions/s={2000 / rollout_time:0.0f}", + "batch sampling and decompression with zstd @ transitions/s={:0.0f}".format( + (100 * 32) / sample_time + ), + ] + + logger.info("\n\t".join(output)) + + +def gpu_only(level=1): + env = setup_atari_environment(seed=0) + + compressor_fn, decompressor_fn = 
get_gpu_codec(level) + + obs, _ = env.reset(seed=0) + nv_obs = nvcomp.as_array(obs).cuda(synchronize=False) + compressed_obs = compressor_fn(nv_obs) + decompressed_obs = decompressor_fn(compressed_obs) + pt_obs = torch.from_dlpack(decompressed_obs).clone().view(4, 84, 84) + assert np.allclose(obs, pt_obs.cpu().numpy()) + + rb = make_batch_decompressing_replay_buffer(decompressor_fn) + + # State for tracking GPU observations between transitions + def calculate_compression_ratio(obs, compressed_obs): + nv_obs_temp = nvcomp.as_array(obs).cuda(synchronize=False) + return nv_obs_temp.buffer_size / compressed_obs.buffer_size + + def compress_obs(obs): + nv_obs = nvcomp.as_array(obs).cuda(synchronize=False) + return compressor_fn(nv_obs) + + def create_and_add_transition( + rb, + compressed_obs, + action, + compressed_next_obs, + reward, + terminated, + truncated, + info, + ): + transition = AtariTransition( + compressed_obs, + action, + compressed_next_obs, + reward, + terminated, + truncated, + info, + ) + rb.add(transition) + + torch.cuda.synchronize() + rollout_time, compression_ratios = run_rollout_benchmark( + rb, + env, + calculate_compression_ratio, + create_and_add_transition, + compress_obs, + 2000, + ) + + torch.cuda.synchronize() + sample_time = run_sampling_benchmark(rb, 100, 32) + + output = [ + "\nListStorage + ReplayBuffer (GPU compress, GPU decompress, storage on GPU) Example:", + f"avg_compression_ratio={np.array(compression_ratios).mean():0.0f}", + f"rollout with zstd, @ transitions/s={2000 / rollout_time:0.0f}", + "batch sampling and decompression with zstd @ transitions/s={:0.0f}".format( + (100 * 32) / sample_time + ), + ] + + logger.info("\n\t".join(output)) + + +if __name__ == "__main__": + for level in [1, 3, 8, 12, 22]: + print(f"Running with compression level {level}...") + cpu_only(level) + gpu_only(level) + cpu_compress_to_gpu_decompress(level) diff --git a/examples/replay-buffers/compressed_replay_buffer_checkpoint.py b/examples/replay-buffers/compressed_replay_buffer_checkpoint.py new file mode 100644 index 00000000000..9242c2ad4cd --- /dev/null +++ b/examples/replay-buffers/compressed_replay_buffer_checkpoint.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 +""" +Example demonstrating the improved CompressedListStorage with memmap functionality. + +This example shows how to use the new checkpointing capabilities that leverage +memory-mapped storage for efficient disk I/O and memory usage. 
+""" + +import tempfile +from pathlib import Path + +import torch +from tensordict import TensorDict + +from torchrl.data import CompressedListStorage, ReplayBuffer + + +def main(): + """Demonstrate compressed storage with memmap checkpointing.""" + + # Create a compressed storage with high compression level + storage = CompressedListStorage(max_size=1000, compression_level=6) + + # Create some sample data with different shapes and types + print("Creating sample data...") + + # Add tensor data + tensor_data = torch.randn(10, 3, 84, 84, dtype=torch.float32) # Image-like data + storage.set(0, tensor_data) + + # Add TensorDict data with mixed content + td_data = TensorDict( + { + "obs": torch.randn(5, 4, 84, 84, dtype=torch.float32), + "action": torch.randint(0, 18, (5,), dtype=torch.long), + "reward": torch.randn(5, dtype=torch.float32), + "done": torch.zeros(5, dtype=torch.bool), + "meta": "some metadata string", + }, + batch_size=[5], + ) + storage.set(1, td_data) + + # Add another tensor with different shape + tensor_data2 = torch.randn(8, 64, dtype=torch.float32) + storage.set(2, tensor_data2) + + print(f"Storage length: {len(storage)}") + print(f"Storage contains index 0: {storage.contains(0)}") + print(f"Storage contains index 3: {storage.contains(3)}") + + # Demonstrate data retrieval + print("\nRetrieving data...") + retrieved_tensor = storage.get(0) + retrieved_td = storage.get(1) + retrieved_tensor2 = storage.get(2) + + print(f"Retrieved tensor shape: {retrieved_tensor.shape}") + print(f"Retrieved TensorDict keys: {list(retrieved_td.keys())}") + print(f"Retrieved tensor2 shape: {retrieved_tensor2.shape}") + + # Verify data integrity + assert torch.allclose(tensor_data, retrieved_tensor, atol=1e-6) + assert torch.allclose(td_data["obs"], retrieved_td["obs"], atol=1e-6) + assert torch.allclose(tensor_data2, retrieved_tensor2, atol=1e-6) + print("✓ Data integrity verified!") + + # Demonstrate memmap checkpointing + print("\nDemonstrating memmap checkpointing...") + + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_path = Path(tmpdir) / "compressed_storage_checkpoint" + + # Save to disk using memmap + print(f"Saving to {checkpoint_path}...") + storage.dumps(checkpoint_path) + + # Check what files were created + print("Files created:") + for file_path in checkpoint_path.rglob("*"): + if file_path.is_file(): + size_mb = file_path.stat().st_size / (1024 * 1024) + print(f" {file_path.relative_to(checkpoint_path)}: {size_mb:.2f} MB") + + # Create new storage and load from checkpoint + print("\nLoading from checkpoint...") + new_storage = CompressedListStorage(max_size=1000, compression_level=6) + new_storage.loads(checkpoint_path) + + # Verify data integrity after checkpointing + print("Verifying data integrity after checkpointing...") + new_retrieved_tensor = new_storage.get(0) + new_retrieved_td = new_storage.get(1) + new_retrieved_tensor2 = new_storage.get(2) + + assert torch.allclose(tensor_data, new_retrieved_tensor, atol=1e-6) + assert torch.allclose(td_data["obs"], new_retrieved_td["obs"], atol=1e-6) + assert torch.allclose(tensor_data2, new_retrieved_tensor2, atol=1e-6) + print("✓ Data integrity after checkpointing verified!") + + print(f"New storage length: {len(new_storage)}") + + # Demonstrate with ReplayBuffer + print("\nDemonstrating with ReplayBuffer...") + + rb = ReplayBuffer(storage=CompressedListStorage(max_size=100, compression_level=4)) + + # Add some data to the replay buffer + for _ in range(5): + data = TensorDict( + { + "obs": torch.randn(3, 84, 84, 
dtype=torch.float32), + "action": torch.randint(0, 18, (3,), dtype=torch.long), + "reward": torch.randn(3, dtype=torch.float32), + }, + batch_size=[3], + ) + rb.extend(data) + + print(f"ReplayBuffer size: {len(rb)}") + + # Sample from the buffer + sample = rb.sample(2) + print(f"Sampled data keys: {list(sample.keys())}") + print(f"Sampled obs shape: {sample['obs'].shape}") + + # Checkpoint the replay buffer + with tempfile.TemporaryDirectory() as tmpdir: + rb_checkpoint_path = Path(tmpdir) / "rb_checkpoint" + print(f"\nCheckpointing ReplayBuffer to {rb_checkpoint_path}...") + rb.dumps(rb_checkpoint_path) + + # Create new replay buffer and load + new_rb = ReplayBuffer( + storage=CompressedListStorage(max_size=100, compression_level=4) + ) + new_rb.loads(rb_checkpoint_path) + + print(f"New ReplayBuffer size: {len(new_rb)}") + + # Verify sampling works + new_sample = new_rb.sample(2) + assert new_sample["obs"].shape == sample["obs"].shape + print("✓ ReplayBuffer checkpointing verified!") + + print("\n🎉 All demonstrations completed successfully!") + print("\nKey benefits of the new memmap implementation:") + print("1. Efficient disk I/O using memory-mapped storage") + print("2. Reduced memory usage for large datasets") + print("3. Fast loading and saving of compressed data") + print("4. Support for heterogeneous data structures") + print("5. Seamless integration with ReplayBuffer") + + +if __name__ == "__main__": + main() diff --git a/setup.cfg b/setup.cfg index 0649a97497f..98d45350291 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,7 +7,7 @@ max-line-length = 120 [flake8] # note: we ignore all 501s (line too long) anyway as they're taken care of by black max-line-length = 79 -ignore = E203, E402, W503, W504, E501 +ignore = E203, E402, W503, W504, E501, E231 per-file-ignores = __init__.py: F401, F403, F405 ./hubconf.py: F401 diff --git a/test/test_rb.py b/test/test_rb.py index ca2bb121d65..3c58d3e4a0e 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -11,10 +11,13 @@ import os import pickle import sys +import tempfile from functools import partial +from pathlib import Path from unittest import mock import numpy as np + import pytest import torch from packaging import version @@ -35,6 +38,7 @@ from torchrl.collectors import RandomPolicy, SyncDataCollector from torchrl.collectors.utils import split_trajectories from torchrl.data import ( + CompressedListStorage, FlatStorageCheckpointer, MultiStep, NestedStorageCheckpointer, @@ -184,7 +188,6 @@ @pytest.mark.parametrize("size", [3, 5, 100]) class TestComposableBuffers: def _get_rb(self, rb_type, size, sampler, writer, storage, compilable=False): - if storage is not None: storage = storage(size, compilable=compilable) @@ -327,10 +330,14 @@ def test_cursor_position(self, rb_type, sampler, writer, storage, size, datatype writer.extend(batch1) return - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): writer.extend(batch1) # Added less data than storage max size @@ -378,10 +385,14 @@ def test_extend(self, rb_type, sampler, writer, storage, size, datatype): length = min(rb._storage.max_size, len(rb) + data_shape) if writer is TensorDictMaxValueWriter: data["next", "reward"][-length:] = 1_000_000 - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage 
capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(data) length = len(rb) if is_tensor_collection(data): @@ -419,10 +430,14 @@ def data_iter(): and size < len(data2) and isinstance(rb._storage, TensorStorage) ) - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(data2) @pytest.mark.skipif( @@ -525,10 +540,14 @@ def test_sample(self, rb_type, sampler, writer, storage, size, datatype): ): rb.extend(data) return - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(data) rb_sample = rb.sample() # if not isinstance(new_data, (torch.Tensor, TensorDictBase)): @@ -545,7 +564,6 @@ def data_iter_func(maxval, data=data): rb_sample_iter = data_iter_func(rb._batch_size, rb_sample) for single_sample in rb_sample_iter: - if is_tensor_collection(data) or isinstance(data, torch.Tensor): data_iter = data else: @@ -595,10 +613,14 @@ def test_index(self, rb_type, sampler, writer, storage, size, datatype): ): rb.extend(data) return - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(data) d1 = rb[2] d2 = rb._storage[2] @@ -615,7 +637,6 @@ def test_index(self, rb_type, sampler, writer, storage, size, datatype): assert b def test_pickable(self, rb_type, sampler, writer, storage, size, datatype): - rb = self._get_rb( rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size ) @@ -874,7 +895,6 @@ def extend_and_sample(data): @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) def test_extend_lazystack(self, storage_type): - rb = ReplayBuffer( storage=storage_type(6), batch_size=2, @@ -1501,7 +1521,6 @@ def test_rng_dumps(self, tmpdir): @pytest.mark.parametrize("size", [3, 5, 100]) @pytest.mark.parametrize("prefetch", [0]) class TestBuffers: - default_constr = { ReplayBuffer: ReplayBuffer, PrioritizedReplayBuffer: functools.partial( @@ -1568,10 +1587,14 @@ def test_cursor_position2(self, rbtype, storage, size, prefetch): cond = ( OLD_TORCH and size < len(batch1) and isinstance(rb._storage, TensorStorage) ) - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(batch1) # Added fewer data than storage max size @@ -1626,10 +1649,14 @@ def test_extend(self, rbtype, storage, size, prefetch): rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) data = 
self._get_data(rbtype, size=5) cond = OLD_TORCH and size < len(data) and isinstance(rb._storage, TensorStorage) - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(data) length = len(rb) for d in data[-length:]: @@ -1653,10 +1680,14 @@ def test_sample(self, rbtype, storage, size, prefetch): rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) data = self._get_data(rbtype, size=5) cond = OLD_TORCH and size < len(data) and isinstance(rb._storage, TensorStorage) - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(data) new_data = rb.sample() if not isinstance(new_data, (torch.Tensor, TensorDictBase)): @@ -1683,10 +1714,14 @@ def test_index(self, rbtype, storage, size, prefetch): rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) data = self._get_data(rbtype, size=5) cond = OLD_TORCH and size < len(data) and isinstance(rb._storage, TensorStorage) - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(data) d1 = rb[2] d2 = rb._storage[2] @@ -2563,7 +2598,7 @@ def test_slice_sampler( break else: raise AssertionError( - f"Not all items can be sampled: {set(range(100))-count_unique} are missing" + f"Not all items can be sampled: {set(range(100)) - count_unique} are missing" ) if strict_length: @@ -2770,7 +2805,6 @@ def test_slice_sampler_left_right_ndim(self): assert curr_eps.unique().numel() == 1 def test_slice_sampler_strictlength(self): - torch.manual_seed(0) data = TensorDict( @@ -3430,7 +3464,6 @@ def _robust_stack(tensor_list): def test_rb( self, storage_type, sampler_type, data_type, p, num_buffer_sampled, batch_size ): - storages = [self._make_storage(storage_type, data_type) for _ in range(3)] collate_fn = self._make_collate(storage_type) data = [self._make_data(data_type) for _ in range(3)] @@ -4027,6 +4060,430 @@ def test_ray_rb_iter(self): rb.close() +class TestCompressedListStorage: + """Test cases for CompressedListStorage.""" + + def test_compressed_storage_initialization(self): + """Test that CompressedListStorage initializes correctly.""" + storage = CompressedListStorage(max_size=100, compression_level=3) + assert storage.max_size == 100 + assert storage.compression_level == 3 + assert len(storage) == 0 + + @pytest.mark.parametrize( + "test_tensor", + [ + torch.rand(1), # 0D scalar + torch.randn(84, dtype=torch.float32), # 1D tensor + torch.randn(84, 84, dtype=torch.float32), # 2D tensor + torch.randn(1, 84, 84, dtype=torch.float32), # 3D tensor + torch.randn(32, 84, 84, dtype=torch.float32), # 3D tensor + ], + ) + def test_compressed_storage_tensor(self, test_tensor): + """Test compression and decompression of tensor data of various shapes.""" + storage = CompressedListStorage(max_size=10, compression_level=3) + + # 
Store tensor + storage.set(0, test_tensor) + + # Retrieve tensor + retrieved_tensor = storage.get(0) + + # Verify data integrity + assert ( + test_tensor.shape == retrieved_tensor.shape + ), f"Expected shape {test_tensor.shape}, got {retrieved_tensor.shape}" + assert ( + test_tensor.dtype == retrieved_tensor.dtype + ), f"Expected dtype {test_tensor.dtype}, got {retrieved_tensor.dtype}" + assert torch.allclose(test_tensor, retrieved_tensor, atol=1e-6) + + def test_compressed_storage_tensordict(self): + """Test compression and decompression of TensorDict data.""" + storage = CompressedListStorage(max_size=10, compression_level=3) + + # Create test TensorDict + test_td = TensorDict( + { + "obs": torch.randn(3, 84, 84, dtype=torch.float32), + "action": torch.tensor([1, 2, 3]), + "reward": torch.randn(3), + "done": torch.tensor([False, True, False]), + }, + batch_size=[3], + ) + + # Store TensorDict + storage.set(0, test_td) + + # Retrieve TensorDict + retrieved_td = storage.get(0) + + # Verify data integrity + assert torch.allclose(test_td["obs"], retrieved_td["obs"], atol=1e-6) + assert torch.allclose(test_td["action"], retrieved_td["action"]) + assert torch.allclose(test_td["reward"], retrieved_td["reward"], atol=1e-6) + assert torch.allclose(test_td["done"], retrieved_td["done"]) + + def test_compressed_storage_multiple_indices(self): + """Test storing and retrieving multiple items.""" + storage = CompressedListStorage(max_size=10, compression_level=3) + + # Store multiple tensors + tensors = [ + torch.randn(2, 2, dtype=torch.float32), + torch.randn(3, 3, dtype=torch.float32), + torch.randn(4, 4, dtype=torch.float32), + ] + + for i, tensor in enumerate(tensors): + storage.set(i, tensor) + + # Retrieve multiple tensors + retrieved = storage.get([0, 1, 2]) + + # Verify data integrity + for original, retrieved_tensor in zip(tensors, retrieved): + assert torch.allclose(original, retrieved_tensor, atol=1e-6) + + def test_compressed_storage_with_replay_buffer(self): + """Test CompressedListStorage with ReplayBuffer.""" + storage = CompressedListStorage(max_size=100, compression_level=3) + rb = ReplayBuffer(storage=storage, batch_size=5) + + # Create test data + data = TensorDict( + { + "obs": torch.randn(10, 3, 84, 84, dtype=torch.float32), + "action": torch.randint(0, 4, (10,)), + "reward": torch.randn(10), + }, + batch_size=[10], + ) + + # Add data to replay buffer + rb.extend(data) + + # Sample from replay buffer + sample = rb.sample(5) + + # Verify sample has correct shape + assert is_tensor_collection(sample), sample + assert sample["obs"].shape[0] == 5 + assert sample["obs"].shape[1:] == (3, 84, 84) + assert sample["action"].shape[0] == 5 + assert sample["reward"].shape[0] == 5 + + def test_compressed_storage_state_dict(self): + """Test saving and loading state dict.""" + storage = CompressedListStorage(max_size=10, compression_level=3) + + # Add some data + test_tensor = torch.randn(3, 3, dtype=torch.float32) + storage.set(0, test_tensor) + + # Save state dict + state_dict = storage.state_dict() + + # Create new storage and load state dict + new_storage = CompressedListStorage(max_size=10, compression_level=3) + new_storage.load_state_dict(state_dict) + + # Verify data integrity + retrieved_tensor = new_storage.get(0) + assert torch.allclose(test_tensor, retrieved_tensor, atol=1e-6) + + def test_compressed_storage_checkpointing(self): + """Test checkpointing functionality.""" + storage = CompressedListStorage(max_size=10, compression_level=3) + + # Add some data + test_td = TensorDict( + { 
+ "obs": torch.randn(3, 84, 84, dtype=torch.float32), + "action": torch.tensor([1, 2, 3]), + }, + batch_size=[3], + ) + storage.set(0, test_td) + + # second batch, different shape + test_td2 = TensorDict( + { + "obs": torch.randn(3, 85, 83, dtype=torch.float32), + "action": torch.tensor([1, 2, 3]), + "meta": torch.randn(3), + "astring": "a string!", + }, + batch_size=[3], + ) + storage.set(1, test_td) + + # Create temporary directory for checkpointing + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_path = Path(tmpdir) / "checkpoint" + + # Save checkpoint + storage.dumps(checkpoint_path) + + # Create new storage and load checkpoint + new_storage = CompressedListStorage(max_size=10, compression_level=3) + new_storage.loads(checkpoint_path) + + # Verify data integrity + retrieved_td = new_storage.get(0) + assert torch.allclose(test_td["obs"], retrieved_td["obs"], atol=1e-6) + assert torch.allclose(test_td["action"], retrieved_td["action"]) + + def test_compressed_storage_length(self): + """Test that length is calculated correctly.""" + storage = CompressedListStorage(max_size=10, compression_level=3) + + # Initially empty + assert len(storage) == 0 + + # Add some data + storage.set(0, torch.randn(2, 2)) + assert len(storage) == 1 + + storage.set(1, torch.randn(2, 2)) + assert len(storage) == 2 + + storage.set(2, torch.randn(2, 2)) + assert len(storage) == 3 + + def test_compressed_storage_contains(self): + """Test the contains method.""" + storage = CompressedListStorage(max_size=10, compression_level=3) + + # Initially empty + assert not storage.contains(0) + + # Add data + storage.set(0, torch.randn(2, 2)) + assert storage.contains(0) + assert not storage.contains(1) + + def test_compressed_storage_empty(self): + """Test emptying the storage.""" + storage = CompressedListStorage(max_size=10, compression_level=3) + + # Add some data + storage.set(0, torch.randn(2, 2)) + storage.set(1, torch.randn(2, 2)) + assert len(storage) == 2 + + # Empty storage + storage._empty() + assert len(storage) == 0 + + def test_compressed_storage_custom_compression(self): + """Test custom compression functions.""" + + def custom_compress(tensor): + # Simple compression: just convert to uint8 + return tensor.to(torch.uint8) + + def custom_decompress(compressed_tensor, metadata): + # Simple decompression: convert back to original dtype + return compressed_tensor.to(metadata["dtype"]) + + storage = CompressedListStorage( + max_size=10, + compression_fn=custom_compress, + decompression_fn=custom_decompress, + ) + + # Test with tensor + test_tensor = torch.randn(2, 2, dtype=torch.float32) + storage.set(0, test_tensor) + retrieved_tensor = storage.get(0) + + # Note: This will lose precision due to uint8 conversion + # but should still work + assert retrieved_tensor.shape == test_tensor.shape + + def test_compressed_storage_error_handling(self): + """Test error handling for invalid operations.""" + storage = CompressedListStorage(max_size=5, compression_level=3) + + # Test setting data beyond max_size + with pytest.raises(RuntimeError): + storage.set(10, torch.randn(2, 2)) + + # Test getting non-existent data + with pytest.raises(IndexError): + storage.get(0) + + def test_compressed_storage_memory_efficiency(self): + """Test that compression actually reduces memory usage.""" + storage = CompressedListStorage(max_size=100, compression_level=3) + + # Create large tensor data + large_tensor = torch.zeros(100, 3, 84, 84, dtype=torch.int64) + large_tensor.copy_( + torch.arange(large_tensor.numel(), 
dtype=torch.int32).view_as(large_tensor) + // (3 * 84 * 84) + ) + original_size = large_tensor.numel() * large_tensor.element_size() + + # Store in compressed storage + storage.set(0, large_tensor) + + # Estimate compressed size + compressed_data = storage._storage[0] + compressed_size = compressed_data.numel() # uint8 bytes + + # Verify compression ratio is reasonable (at least 2x for random data) + compression_ratio = original_size / compressed_size + assert ( + compression_ratio > 1.5 + ), f"Compression ratio {compression_ratio} is too low" + + @staticmethod + def make_compressible_mock_data(num_experiences: int, device=None) -> dict: + """Easily compressible data for testing.""" + if device is None: + device = torch.device("cpu") + + return { + "observations": torch.zeros( + (num_experiences, 4, 84, 84), + dtype=torch.uint8, + device=device, + ), + "actions": torch.zeros((num_experiences,), device=device), + "rewards": torch.zeros((num_experiences,), device=device), + "next_observations": torch.zeros( + (num_experiences, 4, 84, 84), + dtype=torch.uint8, + device=device, + ), + "terminations": torch.zeros( + (num_experiences,), dtype=torch.bool, device=device + ), + "truncations": torch.zeros( + (num_experiences,), dtype=torch.bool, device=device + ), + "batch_size": [num_experiences], + } + + @staticmethod + def make_uncompressible_mock_data(num_experiences: int, device=None) -> dict: + """Uncompressible data for testing.""" + if device is None: + device = torch.device("cpu") + return { + "observations": torch.randn( + (num_experiences, 4, 84, 84), + dtype=torch.float32, + device=device, + ), + "actions": torch.randint(0, 10, (num_experiences,), device=device), + "rewards": torch.randn( + (num_experiences,), dtype=torch.float32, device=device + ), + "next_observations": torch.randn( + (num_experiences, 4, 84, 84), + dtype=torch.float32, + device=device, + ), + "terminations": torch.rand((num_experiences,), device=device) + < 0.2, # ~20% True + "truncations": torch.rand((num_experiences,), device=device) + < 0.1, # ~10% True + "batch_size": [num_experiences], + } + + @pytest.mark.benchmark( + group="tensor_serialization_speed", + min_time=0.1, + max_time=0.5, + min_rounds=5, + disable_gc=True, + warmup=False, + ) + @pytest.mark.parametrize( + "serialization_method", + ["pickle", "torch.save", "untyped_storage", "numpy"], # "safetensors", "nvcomp" + ) + def test_tensor_to_bytestream_speed(self, benchmark, serialization_method: str): + """Benchmark the speed of different tensor serialization methods. + + TODO: we might need to also test which methods work on the gpu. 
+ pytest test/test_rb.py::TestCompressedListStorage::test_tensor_to_bytestream_speed -v --benchmark-only --benchmark-sort='mean' --benchmark-columns='mean, ops' + + ------------------------ benchmark 'tensor_to_bytestream_speed': 5 tests ------------------------- + Name (time in us) Mean (smaller is better) OPS (bigger is better) + -------------------------------------------------------------------------------------------------- + test_tensor_serialization_speed[numpy] 2.3520 (1.0) 425,162.1779 (1.0) + test_tensor_serialization_speed[safetensors] 14.7170 (6.26) 67,948.7129 (0.16) + test_tensor_serialization_speed[pickle] 19.0711 (8.11) 52,435.3333 (0.12) + test_tensor_serialization_speed[torch.save] 32.0648 (13.63) 31,186.8261 (0.07) + test_tensor_serialization_speed[untyped_storage] 38,227.0224 (>1000.0) 26.1595 (0.00) + -------------------------------------------------------------------------------------------------- + """ + import io + import pickle + + import torch + from safetensors.torch import save + + def serialize_with_pickle(data: torch.Tensor) -> bytes: + """Serialize tensor using pickle.""" + buffer = io.BytesIO() + pickle.dump(data, buffer) + return buffer.getvalue() + + def serialize_with_untyped_storage(data: torch.Tensor) -> bytes: + """Serialize tensor using torch's built-in method.""" + return bytes(data.untyped_storage()) + + def serialize_with_numpy(data: torch.Tensor) -> bytes: + """Serialize tensor using numpy.""" + return data.numpy().tobytes() + + def serialize_with_safetensors(data: torch.Tensor) -> bytes: + return save({"0": data}) + + def serialize_with_torch(data: torch.Tensor) -> bytes: + """Serialize tensor using torch's built-in method.""" + buffer = io.BytesIO() + torch.save(data, buffer) + return buffer.getvalue() + + def serialize_with_nvcomp(data: torch.Tensor) -> bytes: + """Nvcomp just wants the data in a NVcomp array format, no bytestream needed. 
+ pip install nvidia-nvcomp-cu12 + + """ + import nvidia.nvcomp as nvcomp + + return nvcomp.from_dlpack(torch.utils.dlpack.to_dlpack(data.cuda())) + + # Benchmark each serialization method + if serialization_method == "pickle": + serialize_fn = serialize_with_pickle + elif serialization_method == "torch.save": + serialize_fn = serialize_with_torch + elif serialization_method == "untyped_storage": + serialize_fn = serialize_with_untyped_storage + elif serialization_method == "numpy": + serialize_fn = serialize_with_numpy + elif serialization_method == "safetensors": + serialize_fn = serialize_with_safetensors + elif serialization_method == "nvcomp": + serialize_fn = serialize_with_nvcomp + else: + raise ValueError(f"Unknown serialization method: {serialization_method}") + + data = self.make_compressible_mock_data(1).get("observations") + + # Run the actual benchmark + benchmark(serialize_fn, data) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index b3fafd16ee1..226ec4d5bb9 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -32,6 +32,8 @@ ) from .postprocs import DensifyReward, MultiStep from .replay_buffers import ( + CompressedListStorage, + CompressedListStorageCheckpointer, Flat2TED, FlatStorageCheckpointer, H5Combine, @@ -116,21 +118,22 @@ "BoundedTensorSpec", "Categorical", "Choice", - "ContentBase", - "TopKRewardSelector", "Composite", "CompositeSpec", + "CompressedListStorage", + "CompressedListStorageCheckpointer", "ConstantKLController", + "ContentBase", "DEVICE_TYPING", "DensifyReward", "DiscreteTensorSpec", "Flat2TED", "FlatStorageCheckpointer", - "History", "H5Combine", "H5Split", "H5StorageCheckpointer", "HashToInt", + "History", "ImmutableDatasetWriter", "LazyMemmapStorage", "LazyStackStorage", @@ -191,6 +194,7 @@ "TensorStorage", "TensorStorageCheckpointer", "TokenizedDatasetLoader", + "TopKRewardSelector", "Tree", "Unbounded", "UnboundedContinuous", diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py index 6e7bff8eac0..540d7c129be 100644 --- a/torchrl/data/replay_buffers/__init__.py +++ b/torchrl/data/replay_buffers/__init__.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. 
from .checkpointers import ( + CompressedListStorageCheckpointer, FlatStorageCheckpointer, H5StorageCheckpointer, ListStorageCheckpointer, @@ -32,6 +33,7 @@ SliceSamplerWithoutReplacement, ) from .storages import ( + CompressedListStorage, LazyMemmapStorage, LazyStackStorage, LazyTensorStorage, @@ -51,6 +53,8 @@ ) __all__ = [ + "CompressedListStorage", + "CompressedListStorageCheckpointer", "FlatStorageCheckpointer", "H5StorageCheckpointer", "ListStorageCheckpointer", diff --git a/torchrl/data/replay_buffers/checkpointers.py b/torchrl/data/replay_buffers/checkpointers.py index 6328857292c..267ab362e79 100644 --- a/torchrl/data/replay_buffers/checkpointers.py +++ b/torchrl/data/replay_buffers/checkpointers.py @@ -6,6 +6,7 @@ import abc import json +import tempfile import warnings from pathlib import Path @@ -13,11 +14,14 @@ import torch from tensordict import ( is_tensor_collection, + lazy_stack, NonTensorData, PersistentTensorDict, TensorDict, ) from tensordict.memmap import MemoryMappedTensor +from tensordict.utils import _zip_strict +from torch.utils._pytree import tree_map from torchrl._utils import _STRDTYPE2DTYPE from torchrl.data.replay_buffers.utils import ( @@ -68,6 +72,211 @@ def loads(storage, path): ) +class CompressedListStorageCheckpointer(StorageCheckpointerBase): + """A storage checkpointer for CompressedListStorage. + + This checkpointer saves compressed data and metadata using memory-mapped storage + for efficient disk I/O and memory usage. + + """ + + def dumps(self, storage, path): + """Save compressed storage to disk using memory-mapped storage. + + Args: + storage: The CompressedListStorage instance to save + path: Directory path where to save the storage + """ + path = Path(path) + path.mkdir(exist_ok=True) + + if not hasattr(storage, "_storage") or len(storage._storage) == 0: + raise RuntimeError( + "Cannot save an empty or non-initialized CompressedListStorage." 
+ ) + + # Get state dict from storage + state_dict = storage.state_dict() + compressed_data = state_dict["_storage"] + metadata = state_dict["_metadata"] + + # Create a temporary directory for processing + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Process compressed data for memmap storage + processed_data = [] + for item in compressed_data: + if item is None: + processed_data.append(None) + continue + + if isinstance(item, torch.Tensor): + # For tensor data, create a TensorDict with the tensor + processed_item = TensorDict({"data": item}, batch_size=[]) + elif isinstance(item, dict): + # For dict data (tensordict fields), convert to TensorDict + processed_item = TensorDict(item, batch_size=[]) + else: + # For other types, wrap in TensorDict + processed_item = TensorDict({"data": item}, batch_size=[]) + + processed_data.append(processed_item) + + # Stack all non-None items into a single TensorDict for memmap + non_none_data = [item for item in processed_data if item is not None] + if non_none_data: + # Use lazy_stack to handle heterogeneous structures + stacked_data = lazy_stack(non_none_data) + + # Save to memmap + stacked_data.memmap_(tmp_path / "compressed_data") + + # Create index mapping for None values + data_indices = [] + current_idx = 0 + for item in processed_data: + if item is None: + data_indices.append(None) + else: + data_indices.append(current_idx) + current_idx += 1 + else: + # No data to save + data_indices = [] + + # Process metadata for JSON serialization + def is_leaf(item): + return isinstance( + item, + ( + torch.Size, + torch.dtype, + torch.device, + str, + int, + float, + bool, + torch.Tensor, + NonTensorData, + ), + ) + + def map_to_json_serializable(item): + if isinstance(item, torch.Size): + return {"__type__": "torch.Size", "value": list(item)} + elif isinstance(item, torch.dtype): + return {"__type__": "torch.dtype", "value": str(item)} + elif isinstance(item, torch.device): + return {"__type__": "torch.device", "value": str(item)} + elif isinstance(item, torch.Tensor): + return {"__type__": "torch.Tensor", "value": item.tolist()} + elif isinstance(item, NonTensorData): + return {"__type__": "NonTensorData", "value": item.data} + return item + + serializable_metadata = tree_map( + map_to_json_serializable, metadata, is_leaf=is_leaf + ) + + # Save metadata and indices + metadata_file = tmp_path / "metadata.json" + with open(metadata_file, "w") as f: + json.dump(serializable_metadata, f, indent=2) + + indices_file = tmp_path / "data_indices.json" + with open(indices_file, "w") as f: + json.dump(data_indices, f, indent=2) + + # Copy all files from temp directory to final destination + import shutil + + for item in tmp_path.iterdir(): + if item.is_file(): + shutil.copy2(item, path / item.name) + elif item.is_dir(): + shutil.copytree(item, path / item.name, dirs_exist_ok=True) + + def loads(self, storage, path): + """Load compressed storage from disk. 
+ + Args: + storage: The CompressedListStorage instance to load into + path: Directory path where the storage was saved + """ + path = Path(path) + + # Load metadata + metadata_file = path / "metadata.json" + if not metadata_file.exists(): + raise FileNotFoundError(f"Metadata file not found at {metadata_file}") + + with open(metadata_file) as f: + serializable_metadata = json.load(f) + + # Load data indices + indices_file = path / "data_indices.json" + if not indices_file.exists(): + raise FileNotFoundError(f"Data indices file not found at {indices_file}") + + with open(indices_file) as f: + data_indices = json.load(f) + + # Convert serializable metadata back to original format + def is_leaf(item): + return isinstance(item, dict) and "__type__" in item + + def map_from_json_serializable(item): + if isinstance(item, dict) and "__type__" in item: + if item["__type__"] == "torch.Size": + return torch.Size(item["value"]) + elif item["__type__"] == "torch.dtype": + # Handle torch.dtype conversion + dtype_str = item["value"] + if hasattr(torch, dtype_str.replace("torch.", "")): + return getattr(torch, dtype_str.replace("torch.", "")) + else: + # Handle cases like 'torch.float32' -> torch.float32 + return eval(dtype_str) + elif item["__type__"] == "torch.device": + return torch.device(item["value"]) + elif item["__type__"] == "torch.Tensor": + return torch.tensor(item["value"]) + elif item["__type__"] == "NonTensorData": + return NonTensorData(item["value"]) + return item + + metadata = tree_map( + map_from_json_serializable, serializable_metadata, is_leaf=is_leaf + ) + + # Load compressed data from memmap + compressed_data = [] + memmap_path = path / "compressed_data" + + if memmap_path.exists(): + # Load the memmapped data + stacked_data = TensorDict.load_memmap(memmap_path) + compressed_data = stacked_data.tolist() + if len(compressed_data) != len(data_indices): + raise ValueError( + f"Length of compressed data ({len(compressed_data)}) does not match length of data indices ({len(data_indices)})" + ) + for i, (data, mtdt) in enumerate(_zip_strict(compressed_data, metadata)): + if mtdt["type"] == "tensor": + compressed_data[i] = data["data"] + else: + compressed_data[i] = data + + else: + # No data to load + compressed_data = [None] * len(data_indices) + + # Load into storage + storage._storage = compressed_data + storage._metadata = metadata + + class TensorStorageCheckpointer(StorageCheckpointerBase): """A storage checkpointer for TensorStorages. 
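
For orientation, ``CompressedListStorageCheckpointer`` is wired in as ``CompressedListStorage._default_checkpointer`` (see the storages diff below), so in practice it is exercised through ``storage.dumps(path)`` / ``storage.loads(path)`` rather than being called directly. A minimal round-trip sketch, with an illustrative path:

    >>> import torch
    >>> from torchrl.data import CompressedListStorage
    >>> storage = CompressedListStorage(max_size=10, compression_level=3)
    >>> storage.set(0, torch.randn(3, 84, 84))
    >>> storage.dumps("/tmp/compressed_ckpt")  # writes the memmapped data plus metadata.json and data_indices.json
    >>> restored = CompressedListStorage(max_size=10, compression_level=3)
    >>> restored.loads("/tmp/compressed_ckpt")
    >>> torch.allclose(storage.get(0), restored.get(0))
    True
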
diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index e5ca2367be5..d89aad0fc9c 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -7,12 +7,13 @@ import abc import logging import os +import sys import textwrap import warnings from collections import OrderedDict from copy import copy from multiprocessing.context import get_spawning_popen -from typing import Any, Sequence +from typing import Any, Callable, Mapping, Sequence import numpy as np import tensordict @@ -32,6 +33,7 @@ from torchrl._utils import _make_ordinal_device, implement_for, logger as torchrl_logger from torchrl.data.replay_buffers.checkpointers import ( + CompressedListStorageCheckpointer, ListStorageCheckpointer, StorageCheckpointerBase, StorageEnsembleCheckpointer, @@ -278,7 +280,7 @@ def set( return if isinstance(cursor, slice): data = self._to_device(data) - self._storage[cursor] = data + self._set_slice(cursor, data) return if isinstance( data, @@ -305,7 +307,7 @@ def set( if cursor > len(self._storage): raise RuntimeError( "Cannot append data located more than one item away from " - f"the storage size: the storage size is {len(self)} " + f"the storage size: the storage size is {len(self._storage)} " f"and the index of the item to be set is {cursor}." ) if cursor >= self.max_size: @@ -315,14 +317,24 @@ def set( f"and the index of the item to be set is {cursor}." ) data = self._to_device(data) - if cursor == len(self._storage): - self._storage.append(data) - else: - self._storage[cursor] = data + self._set_item(cursor, data) + + def _set_item(self, cursor: int, data: Any) -> None: + """Set a single item in the storage.""" + if cursor == len(self._storage): + self._storage.append(data) + else: + self._storage[cursor] = data + + def _set_slice(self, cursor: slice, data: Any) -> None: + """Set a slice in the storage.""" + self._storage[cursor] = data def get(self, index: int | Sequence[int] | slice) -> Any: - if isinstance(index, (INT_CLASSES, slice)): - return self._storage[index] + if isinstance(index, INT_CLASSES): + return self._get_item(index) + elif isinstance(index, slice): + return self._get_slice(index) elif isinstance(index, tuple): if len(index) > 1: raise RuntimeError( @@ -332,9 +344,22 @@ def get(self, index: int | Sequence[int] | slice) -> Any: else: if isinstance(index, torch.Tensor) and index.device.type != "cpu": index = index.cpu().tolist() - return [self._storage[i] for i in index] + return self._get_list(index) + + def _get_item(self, index: int) -> Any: + """Get a single item from the storage.""" + return self._storage[index] + + def _get_slice(self, index: slice) -> Any: + """Get a slice from the storage.""" + return self._storage[index] + + def _get_list(self, index: list) -> list: + """Get a list of items from the storage.""" + return [self._storage[i] for i in index] def __len__(self): + """Get the length of the storage.""" return len(self._storage) def state_dict(self) -> dict[str, Any]: @@ -379,8 +404,7 @@ def contains(self, item): if isinstance(item, int): if item < 0: item += len(self._storage) - - return 0 <= item < len(self._storage) + return self._contains_int(item) if isinstance(item, torch.Tensor): return torch.tensor( [self.contains(elt) for elt in item.tolist()], @@ -389,6 +413,10 @@ def contains(self, item): ).reshape_as(item) raise NotImplementedError(f"type {type(item)} is not supported yet.") + def _contains_int(self, item: int) -> bool: + """Check if an integer index is contained in the 
@@ -872,7 +900,6 @@ def set(  # noqa: F811
         *,
         set_cursor: bool = True,
     ):
-
         if set_cursor:
             self._last_cursor = cursor
 
@@ -1360,6 +1387,314 @@ def get(self, index: int | Sequence[int] | slice) -> Any:
         return result
 
+
+class CompressedListStorage(ListStorage):
+    """A storage that compresses and decompresses data.
+
+    This storage compresses data when storing and decompresses when retrieving.
+    It's particularly useful for storing raw sensory observations like images
+    that can be compressed significantly to save memory.
+
+    Args:
+        max_size (int): size of the storage, i.e. maximum number of elements stored
+            in the buffer.
+        compression_fn (callable, optional): function to compress data. Should take
+            a tensor and return a compressed byte tensor. Defaults to zstd compression.
+        decompression_fn (callable, optional): function to decompress data. Should take
+            a compressed byte tensor and return the original tensor. Defaults to zstd decompression.
+        compression_level (int, optional): compression level (1-22 for zstd, 0-9 for the zlib fallback)
+            when using the default compression function. Defaults to 3.
+        device (torch.device, optional): device where the sampled tensors will be
+            stored and sent. Default is :obj:`torch.device("cpu")`.
+        compilable (bool, optional): whether the storage is compilable.
+            If ``True``, the writer cannot be shared between multiple processes.
+            Defaults to ``False``.
+
+    Examples:
+        >>> import torch
+        >>> from torchrl.data import CompressedListStorage, ReplayBuffer
+        >>> from tensordict import TensorDict
+        >>>
+        >>> # Create a compressed storage for image data
+        >>> storage = CompressedListStorage(max_size=1000, compression_level=3)
+        >>> rb = ReplayBuffer(storage=storage, batch_size=5)
+        >>>
+        >>> # Add some image data
+        >>> images = torch.randn(10, 3, 84, 84)  # Atari-like frames
+        >>> data = TensorDict({"obs": images}, batch_size=[10])
+        >>> rb.extend(data)
+        >>>
+        >>> # Sample and verify data is decompressed correctly
+        >>> sample = rb.sample(3)
+        >>> print(sample["obs"].shape)  # torch.Size([3, 3, 84, 84])
+
+    """
+
+    _default_checkpointer = CompressedListStorageCheckpointer
+
+    def __init__(
+        self,
+        max_size: int,
+        *,
+        compression_fn: Callable | None = None,
+        decompression_fn: Callable | None = None,
+        compression_level: int = 3,
+        device: torch.device = "cpu",
+        compilable: bool = False,
+    ):
+        super().__init__(max_size, compilable=compilable, device=device)
+        self.compression_level = compression_level
+
+        # Set up compression functions
+        if compression_fn is None:
+            self.compression_fn = self._default_compression_fn
+        else:
+            self.compression_fn = compression_fn
+
+        if decompression_fn is None:
+            self.decompression_fn = self._default_decompression_fn
+        else:
+            self.decompression_fn = decompression_fn
+
+        # Store compressed data and metadata
+        self._storage = []
+        self._metadata = []  # Store shape, dtype, device info for each item
+
+    def _default_compression_fn(self, tensor: torch.Tensor) -> torch.Tensor:
+        """Default compression using zstd (falls back to zlib on Python < 3.14)."""
+        if sys.version_info >= (3, 14):
+            from compression import zstd
+
+            compressor_fn = zstd.compress
+
+        else:
+            import zlib
+
+            compressor_fn = zlib.compress
+
+        # Convert tensor to bytes
+        tensor_bytes = self.to_bytestream(tensor)
+
+        # Compress the byte stream
+        compressed_bytes = compressor_fn(tensor_bytes, level=self.compression_level)
+
+        # Convert to tensor
+        return torch.frombuffer(bytearray(compressed_bytes), dtype=torch.uint8)
+
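The default compressor above picks ``compression.zstd`` on Python 3.14+ and falls back to ``zlib`` otherwise. A minimal standalone sketch of the same byte-level round trip, written against the ``zlib`` fallback (which, unlike zstd, only accepts levels 0-9) and assuming a uint8 frame:

    >>> import zlib
    >>> import torch
    >>> frame = torch.zeros(3, 84, 84, dtype=torch.uint8)  # a highly compressible "image"
    >>> raw = frame.numpy().tobytes()
    >>> compressed = zlib.compress(raw, level=3)  # zlib accepts levels 0-9 only
    >>> len(compressed) < len(raw)
    True
    >>> restored = torch.frombuffer(bytearray(zlib.decompress(compressed)), dtype=torch.uint8).reshape(3, 84, 84)
    >>> torch.equal(restored, frame)
    True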
+    def _default_decompression_fn(
+        self, compressed_tensor: torch.Tensor, metadata: dict
+    ) -> torch.Tensor:
+        """Default decompression using zstd (falls back to zlib on Python < 3.14)."""
+        if sys.version_info >= (3, 14):
+            from compression import zstd
+
+            decompressor_fn = zstd.decompress
+
+        else:
+            import zlib
+
+            decompressor_fn = zlib.decompress
+
+        # Convert tensor to bytes
+        compressed_bytes = self.to_bytestream(compressed_tensor.cpu())
+
+        # Decompress the byte stream
+        decompressed_bytes = decompressor_fn(compressed_bytes)
+
+        # Convert back to tensor
+        tensor = torch.frombuffer(
+            bytearray(decompressed_bytes), dtype=metadata["dtype"]
+        )
+        tensor = tensor.reshape(metadata["shape"])
+        tensor = tensor.to(metadata["device"])
+
+        return tensor
+
+    def _compress_item(self, item: Any) -> tuple[torch.Tensor, dict]:
+        """Compress a single item and return compressed data with metadata."""
+        if isinstance(item, torch.Tensor):
+            metadata = {
+                "type": "tensor",
+                "shape": item.shape,
+                "dtype": item.dtype,
+                "device": item.device,
+            }
+            compressed = self.compression_fn(item)
+        elif is_tensor_collection(item):
+            # For TensorDict, compress each tensor field
+            compressed_fields = {}
+            metadata = {"type": "tensordict", "fields": {}}
+
+            for key, value in item.items():
+                if isinstance(value, torch.Tensor):
+                    compressed_fields[key] = self.compression_fn(value)
+                    metadata["fields"][key] = {
+                        "type": "tensor",
+                        "shape": value.shape,
+                        "dtype": value.dtype,
+                        "device": value.device,
+                    }
+                else:
+                    # For non-tensor data, store as-is
+                    compressed_fields[key] = value
+                    metadata["fields"][key] = {"type": "non_tensor", "value": value}
+
+            compressed = compressed_fields
+        else:
+            # For other types, store as-is
+            compressed = item
+            metadata = {"type": "other", "value": item}
+
+        return compressed, metadata
+
+    def _decompress_item(self, compressed_data: Any, metadata: dict) -> Any:
+        """Decompress a single item using its metadata."""
+        if metadata["type"] == "tensor":
+            return self.decompression_fn(compressed_data, metadata)
+        elif metadata["type"] == "tensordict":
+            # Reconstruct TensorDict
+            result = TensorDict({}, batch_size=metadata.get("batch_size", []))
+
+            for key, field_metadata in metadata["fields"].items():
+                if field_metadata["type"] == "non_tensor":
+                    result[key] = field_metadata["value"]
+                else:
+                    # Decompress tensor field
+                    result[key] = self.decompression_fn(
+                        compressed_data[key], field_metadata
+                    )
+
+            return result
+        else:
+            # Return as-is for other types
+            return metadata["value"]
+
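To make the per-field handling in ``_compress_item``/``_decompress_item`` above concrete, here is a short hedged sketch (shapes and values are illustrative) that stores one TensorDict item through the public API and checks that every field is restored bit-exactly:

    >>> import torch
    >>> from tensordict import TensorDict
    >>> from torchrl.data import CompressedListStorage
    >>> storage = CompressedListStorage(max_size=4)
    >>> item = TensorDict(
    ...     {"obs": torch.randint(0, 256, (3, 84, 84), dtype=torch.uint8), "reward": torch.tensor(1.0)},
    ...     batch_size=[],
    ... )
    >>> storage.set(0, item)
    >>> restored = storage.get(0)
    >>> torch.equal(restored["obs"], item["obs"]), restored["reward"].item()
    (True, 1.0)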
+    def _set_item(self, cursor: int, data: Any) -> None:
+        """Set a single item in the compressed storage."""
+        # Ensure we have enough space
+        while len(self._storage) <= cursor:
+            self._storage.append(None)
+            self._metadata.append(None)
+
+        # Compress and store
+        compressed_data, metadata = self._compress_item(data)
+        self._storage[cursor] = compressed_data
+        self._metadata[cursor] = metadata
+
+    def _set_slice(self, cursor: slice, data: Any) -> None:
+        """Set a slice in the compressed storage."""
+        # Handle slice assignment
+        if not hasattr(data, "__iter__"):
+            data = [data]
+        start, stop, step = cursor.indices(len(self._storage))
+        indices = list(range(start, stop, step))
+
+        for i, value in zip(indices, data):
+            self._set_item(i, value)
+
+    def _get_item(self, index: int) -> Any:
+        """Get a single item from the compressed storage."""
+        if index >= len(self._storage) or self._storage[index] is None:
+            raise IndexError(f"Index {index} out of bounds or not set")
+
+        compressed_data = self._storage[index]
+        metadata = self._metadata[index]
+        return self._decompress_item(compressed_data, metadata)
+
+    def _get_slice(self, index: slice) -> list:
+        """Get a slice from the compressed storage."""
+        start, stop, step = index.indices(len(self._storage))
+        results = []
+        for i in range(start, stop, step):
+            if i < len(self._storage) and self._storage[i] is not None:
+                results.append(self._get_item(i))
+        return results
+
+    def _get_list(self, index: list) -> list:
+        """Get a list of items from the compressed storage."""
+        if isinstance(index, torch.Tensor) and index.device.type != "cpu":
+            index = index.cpu().tolist()
+
+        results = []
+        for i in index:
+            if i >= len(self._storage) or self._storage[i] is None:
+                raise IndexError(f"Index {i} out of bounds or not set")
+            results.append(self._get_item(i))
+        return results
+
+    def __len__(self) -> int:
+        """Get the length of the compressed storage."""
+        return len([item for item in self._storage if item is not None])
+
+    def _contains_int(self, item: int) -> bool:
+        """Check if an integer index is contained in the compressed storage."""
+        return 0 <= item < len(self._storage) and self._storage[item] is not None
+
+    def _empty(self):
+        """Empty the storage."""
+        self._storage = []
+        self._metadata = []
+
+    def state_dict(self) -> dict[str, Any]:
+        """Save the storage state."""
+        return {
+            "_storage": self._storage,
+            "_metadata": self._metadata,
+        }
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        """Load the storage state."""
+        self._storage = state_dict["_storage"]
+        self._metadata = state_dict["_metadata"]
+
+    def to_bytestream(self, data_to_bytestream: torch.Tensor | np.ndarray | Any) -> bytes:
+        """Convert data to a byte stream."""
+        if isinstance(data_to_bytestream, torch.Tensor):
+            byte_stream = data_to_bytestream.cpu().numpy().tobytes()
+
+        elif isinstance(data_to_bytestream, np.ndarray):
+            byte_stream = bytes(data_to_bytestream.tobytes())
+
+        else:
+            import io
+            import pickle
+
+            buffer = io.BytesIO()
+            pickle.dump(data_to_bytestream, buffer)
+            buffer.seek(0)
+            byte_stream = bytes(buffer.read())
+
+        return byte_stream
+
+    def bytes(self):
+        """Return the number of bytes in the storage."""
+
+        def compressed_size_from_list(data: Any) -> int:
+            if data is None:
+                return 0
+            elif isinstance(data, (bytes,)):
+                return len(data)
+            elif isinstance(data, (np.ndarray,)):
+                return data.nbytes
+            elif isinstance(data, (torch.Tensor)):
+                return compressed_size_from_list(data.cpu().numpy())
+            elif isinstance(data, (tuple, list, Sequence)):
+                return sum(compressed_size_from_list(item) for item in data)
+            elif isinstance(data, Mapping) or is_tensor_collection(data):
+                return sum(compressed_size_from_list(value) for value in data.values())
+            else:
+                return 0
+
+        compressed_size_estimate = compressed_size_from_list(self._storage)
+        if compressed_size_estimate == 0:
+            if len(self._storage) > 0:
+                raise RuntimeError(
+                    "Compressed storage is not empty but the compressed size is 0. This is a bug."
+                )
+            warnings.warn("Compressed storage is empty, returning 0 bytes.")
+
+        return compressed_size_estimate
+
 
 class StorageEnsemble(Storage):
     """An ensemble of storages.
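The ``bytes()`` helper above can be used to sanity-check memory savings at runtime. A hedged sketch follows; the exact figure depends on the data and on the compressor, and the constant frames used here compress far better than real observations:

    >>> import torch
    >>> from tensordict import TensorDict
    >>> from torchrl.data import CompressedListStorage, ReplayBuffer
    >>> storage = CompressedListStorage(max_size=128, compression_level=3)
    >>> rb = ReplayBuffer(storage=storage, batch_size=8)
    >>> frames = torch.zeros(64, 3, 84, 84, dtype=torch.uint8)  # trivially compressible
    >>> indices = rb.extend(TensorDict({"obs": frames}, batch_size=[64]))
    >>> raw_bytes = frames.numel() * frames.element_size()
    >>> storage.bytes() < raw_bytes  # compressed footprint is much smaller than the raw frames
    True

Note that sampling from a ``CompressedListStorage``-backed buffer goes through the ``lazy_stack`` collate function registered in the change below, so decompressed items are stacked lazily rather than eagerly copied.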
@@ -1395,7 +1730,7 @@ def __init__(
         self._transforms = transforms
         if transforms is not None and len(transforms) != len(storages):
             raise TypeError(
-                "transforms must have the same length as the storages " "provided."
+                "transforms must have the same length as the storages provided."
             )
 
     @property
@@ -1419,7 +1754,7 @@ def get(self, item):
         buffer_ids = item.get("buffer_ids")
         index = item.get("index")
         results = []
-        for (buffer_id, sample) in zip(buffer_ids, index):
+        for buffer_id, sample in zip(buffer_ids, index):
            buffer_id = self._convert_id(buffer_id)
            results.append((buffer_id, self._get_storage(buffer_id).get(sample)))
        if self._transforms is not None:
@@ -1567,9 +1902,11 @@ def _collate_id(x):
 
 
 def _get_default_collate(storage, _is_tensordict=False):
-    if isinstance(storage, LazyStackStorage) or isinstance(storage, TensorStorage):
+    if isinstance(storage, (LazyStackStorage, TensorStorage)):
         return _collate_id
-    elif isinstance(storage, ListStorage):
+    elif isinstance(storage, CompressedListStorage):
+        return lazy_stack
+    elif isinstance(storage, (ListStorage, StorageEnsemble)):
         return _stack_anything
     else:
         raise NotImplementedError(