From 0acd4771059293e7699c15824ad61e499b6c55cb Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 11 Jul 2025 06:42:56 +0100 Subject: [PATCH 1/8] amend --- .../unittest/linux/scripts/environment.yml | 1 + docs/source/reference/data.rst | 62 ++++ test/test_rb.py | 266 +++++++++++++++ torchrl/data/__init__.py | 10 +- torchrl/data/replay_buffers/__init__.py | 4 + torchrl/data/replay_buffers/checkpointers.py | 112 ++++++ torchrl/data/replay_buffers/storages.py | 319 +++++++++++++++++- 7 files changed, 768 insertions(+), 6 deletions(-) diff --git a/.github/unittest/linux/scripts/environment.yml b/.github/unittest/linux/scripts/environment.yml index e3c79e4569e..b7ca29ff0d4 100644 --- a/.github/unittest/linux/scripts/environment.yml +++ b/.github/unittest/linux/scripts/environment.yml @@ -35,3 +35,4 @@ dependencies: - transformers - ninja - timm + - zstandard diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index e9d08822239..48fb5f1e06f 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -144,6 +144,8 @@ using the following components: :template: rl_template.rst + CompressedStorage + CompressedStorageCheckpointer FlatStorageCheckpointer H5StorageCheckpointer ImmutableDatasetWriter @@ -191,6 +193,66 @@ were found from rough benchmarking in https://github.com/pytorch/rl/tree/main/be | :class:`LazyMemmapStorage` | 3.44x | +-------------------------------+-----------+ +Compressed Storage for Memory Efficiency +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +For applications where memory usage is a primary concern, especially when storing +large sensory observations like images or audio, the :class:`~torchrl.data.replay_buffers.storages.CompressedStorage` +provides significant memory savings through compression. + +The `CompressedStorage`` compresses data when storing and decompresses when retrieving, +achieving compression ratios of 2-10x for image data while maintaining full data fidelity. +It uses zstd compression by default but supports custom compression algorithms. + +Key features: +- **Memory Efficiency**: Achieves significant memory savings through compression +- **Data Integrity**: Maintains full data fidelity through lossless compression +- **Flexible Compression**: Supports custom compression algorithms or uses zstd by default +- **TensorDict Support**: Seamlessly works with TensorDict structures +- **Checkpointing**: Full support for saving and loading compressed data + +Example usage: + + >>> import torch + >>> from torchrl.data import ReplayBuffer, CompressedStorage + >>> from tensordict import TensorDict + >>> + >>> # Create a compressed storage for image data + >>> storage = CompressedStorage(max_size=1000, compression_level=3) + >>> rb = ReplayBuffer(storage=storage, batch_size=32) + >>> + >>> # Add image data + >>> images = torch.randn(100, 3, 84, 84) # Atari-like frames + >>> data = TensorDict({"obs": images}, batch_size=[100]) + >>> rb.extend(data) + >>> + >>> # Sample data (automatically decompressed) + >>> sample = rb.sample(16) + >>> print(sample["obs"].shape) # torch.Size([16, 3, 84, 84]) + +The compression level can be adjusted from 1 (fast, less compression) to 22 (slow, more compression), +with level 3 being a good default for most use cases. + +For custom compression algorithms: + + >>> def my_compress(tensor): + ... return tensor.to(torch.uint8) # Simple example + >>> + >>> def my_decompress(compressed_tensor, metadata): + ... return compressed_tensor.to(metadata["dtype"]) + >>> + >>> storage = CompressedStorage( + ... max_size=1000, + ... compression_fn=my_compress, + ... decompression_fn=my_decompress + ... ) + +.. note:: The CompressedStorage requires the `zstandard` library for default compression. + Install with: ``pip install zstandard`` + +.. note:: An example of how to use the CompressedStorage is available in the + `examples/replay-buffers/compressed_replay_buffer_example.py `_ file. + Sharing replay buffers across processes ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/test/test_rb.py b/test/test_rb.py index ca2bb121d65..a52295d9695 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -11,10 +11,13 @@ import os import pickle import sys +import tempfile from functools import partial +from pathlib import Path from unittest import mock import numpy as np + import pytest import torch from packaging import version @@ -35,6 +38,7 @@ from torchrl.collectors import RandomPolicy, SyncDataCollector from torchrl.collectors.utils import split_trajectories from torchrl.data import ( + CompressedStorage, FlatStorageCheckpointer, MultiStep, NestedStorageCheckpointer, @@ -129,6 +133,7 @@ _os_is_windows = sys.platform == "win32" _has_transformers = importlib.util.find_spec("transformers") is not None _has_ray = importlib.util.find_spec("ray") is not None +_has_zstandard = importlib.util.find_spec("zstandard") is not None TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) @@ -4027,6 +4032,267 @@ def test_ray_rb_iter(self): rb.close() +@pytest.mark.skipif(not _has_zstandard, reason="zstandard required for this test.") +class TestCompressedStorage: + """Test cases for CompressedStorage.""" + + def test_compressed_storage_initialization(self): + """Test that CompressedStorage initializes correctly.""" + storage = CompressedStorage(max_size=100, compression_level=3) + assert storage.max_size == 100 + assert storage.compression_level == 3 + assert len(storage) == 0 + + def test_compressed_storage_tensor(self): + """Test compression and decompression of tensor data.""" + storage = CompressedStorage(max_size=10, compression_level=3) + + # Create test tensor + test_tensor = torch.randn(3, 84, 84, dtype=torch.float32) + + # Store tensor + storage.set(0, test_tensor) + + # Retrieve tensor + retrieved_tensor = storage.get(0) + + # Verify data integrity + assert torch.allclose(test_tensor, retrieved_tensor, atol=1e-6) + assert test_tensor.shape == retrieved_tensor.shape + assert test_tensor.dtype == retrieved_tensor.dtype + + def test_compressed_storage_tensordict(self): + """Test compression and decompression of TensorDict data.""" + storage = CompressedStorage(max_size=10, compression_level=3) + + # Create test TensorDict + test_td = TensorDict( + { + "obs": torch.randn(3, 84, 84, dtype=torch.float32), + "action": torch.tensor([1, 2, 3]), + "reward": torch.randn(3), + "done": torch.tensor([False, True, False]), + }, + batch_size=[3], + ) + + # Store TensorDict + storage.set(0, test_td) + + # Retrieve TensorDict + retrieved_td = storage.get(0) + + # Verify data integrity + assert torch.allclose(test_td["obs"], retrieved_td["obs"], atol=1e-6) + assert torch.allclose(test_td["action"], retrieved_td["action"]) + assert torch.allclose(test_td["reward"], retrieved_td["reward"], atol=1e-6) + assert torch.allclose(test_td["done"], retrieved_td["done"]) + + def test_compressed_storage_multiple_indices(self): + """Test storing and retrieving multiple items.""" + storage = CompressedStorage(max_size=10, compression_level=3) + + # Store multiple tensors + tensors = [ + torch.randn(2, 2, dtype=torch.float32), + torch.randn(3, 3, dtype=torch.float32), + torch.randn(4, 4, dtype=torch.float32), + ] + + for i, tensor in enumerate(tensors): + storage.set(i, tensor) + + # Retrieve multiple tensors + retrieved = storage.get([0, 1, 2]) + + # Verify data integrity + for original, retrieved_tensor in zip(tensors, retrieved): + assert torch.allclose(original, retrieved_tensor, atol=1e-6) + + def test_compressed_storage_with_replay_buffer(self): + """Test CompressedStorage with ReplayBuffer.""" + storage = CompressedStorage(max_size=100, compression_level=3) + rb = ReplayBuffer(storage=storage, batch_size=5) + + # Create test data + data = TensorDict( + { + "obs": torch.randn(10, 3, 84, 84, dtype=torch.float32), + "action": torch.randint(0, 4, (10,)), + "reward": torch.randn(10), + }, + batch_size=[10], + ) + + # Add data to replay buffer + print("extending") + rb.extend(data) + + # Sample from replay buffer + sample = rb.sample(5) + + # Verify sample has correct shape + assert is_tensor_collection(sample), sample + assert sample["obs"].shape[0] == 5 + assert sample["obs"].shape[1:] == (3, 84, 84) + assert sample["action"].shape[0] == 5 + assert sample["reward"].shape[0] == 5 + + def test_compressed_storage_state_dict(self): + """Test saving and loading state dict.""" + storage = CompressedStorage(max_size=10, compression_level=3) + + # Add some data + test_tensor = torch.randn(3, 3, dtype=torch.float32) + storage.set(0, test_tensor) + + # Save state dict + state_dict = storage.state_dict() + + # Create new storage and load state dict + new_storage = CompressedStorage(max_size=10, compression_level=3) + new_storage.load_state_dict(state_dict) + + # Verify data integrity + retrieved_tensor = new_storage.get(0) + assert torch.allclose(test_tensor, retrieved_tensor, atol=1e-6) + + def test_compressed_storage_checkpointing(self): + """Test checkpointing functionality.""" + storage = CompressedStorage(max_size=10, compression_level=3) + + # Add some data + test_td = TensorDict( + { + "obs": torch.randn(3, 84, 84, dtype=torch.float32), + "action": torch.tensor([1, 2, 3]), + }, + batch_size=[3], + ) + storage.set(0, test_td) + + # Create temporary directory for checkpointing + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_path = Path(tmpdir) / "checkpoint" + + # Save checkpoint + storage.dumps(checkpoint_path) + + # Create new storage and load checkpoint + new_storage = CompressedStorage(max_size=10, compression_level=3) + new_storage.loads(checkpoint_path) + + # Verify data integrity + retrieved_td = new_storage.get(0) + assert torch.allclose(test_td["obs"], retrieved_td["obs"], atol=1e-6) + assert torch.allclose(test_td["action"], retrieved_td["action"]) + + def test_compressed_storage_length(self): + """Test that length is calculated correctly.""" + storage = CompressedStorage(max_size=10, compression_level=3) + + # Initially empty + assert len(storage) == 0 + + # Add some data + storage.set(0, torch.randn(2, 2)) + assert len(storage) == 1 + + storage.set(2, torch.randn(2, 2)) + assert len(storage) == 2 + + storage.set(1, torch.randn(2, 2)) + assert len(storage) == 3 + + def test_compressed_storage_contains(self): + """Test the contains method.""" + storage = CompressedStorage(max_size=10, compression_level=3) + + # Initially empty + assert not storage.contains(0) + + # Add data + storage.set(0, torch.randn(2, 2)) + assert storage.contains(0) + assert not storage.contains(1) + + def test_compressed_storage_empty(self): + """Test emptying the storage.""" + storage = CompressedStorage(max_size=10, compression_level=3) + + # Add some data + storage.set(0, torch.randn(2, 2)) + storage.set(1, torch.randn(2, 2)) + assert len(storage) == 2 + + # Empty storage + storage._empty() + assert len(storage) == 0 + + def test_compressed_storage_custom_compression(self): + """Test custom compression functions.""" + + def custom_compress(tensor): + # Simple compression: just convert to uint8 + return tensor.to(torch.uint8) + + def custom_decompress(compressed_tensor, metadata): + # Simple decompression: convert back to original dtype + return compressed_tensor.to(metadata["dtype"]) + + storage = CompressedStorage( + max_size=10, + compression_fn=custom_compress, + decompression_fn=custom_decompress, + ) + + # Test with tensor + test_tensor = torch.randn(2, 2, dtype=torch.float32) + storage.set(0, test_tensor) + retrieved_tensor = storage.get(0) + + # Note: This will lose precision due to uint8 conversion + # but should still work + assert retrieved_tensor.shape == test_tensor.shape + + def test_compressed_storage_error_handling(self): + """Test error handling for invalid operations.""" + storage = CompressedStorage(max_size=5, compression_level=3) + + # Test setting data beyond max_size + with pytest.raises(RuntimeError): + storage.set(10, torch.randn(2, 2)) + + # Test getting non-existent data + with pytest.raises(IndexError): + storage.get(0) + + def test_compressed_storage_memory_efficiency(self): + """Test that compression actually reduces memory usage.""" + storage = CompressedStorage(max_size=100, compression_level=3) + + # Create large tensor data + large_tensor = torch.zeros(100, 3, 84, 84, dtype=torch.int64) + large_tensor.copy_( + torch.arange(large_tensor.numel(), dtype=torch.int32).view_as(large_tensor) + // (3 * 84 * 84) + ) + original_size = large_tensor.numel() * large_tensor.element_size() + + # Store in compressed storage + storage.set(0, large_tensor) + + # Estimate compressed size + compressed_data = storage._compressed_data[0] + compressed_size = compressed_data.numel() # uint8 bytes + + # Verify compression ratio is reasonable (at least 2x for random data) + compression_ratio = original_size / compressed_size + assert ( + compression_ratio > 1.5 + ), f"Compression ratio {compression_ratio} is too low" + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index b3fafd16ee1..5c8bff8a3af 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -32,6 +32,8 @@ ) from .postprocs import DensifyReward, MultiStep from .replay_buffers import ( + CompressedStorage, + CompressedStorageCheckpointer, Flat2TED, FlatStorageCheckpointer, H5Combine, @@ -116,21 +118,22 @@ "BoundedTensorSpec", "Categorical", "Choice", - "ContentBase", - "TopKRewardSelector", "Composite", "CompositeSpec", + "CompressedStorage", + "CompressedStorageCheckpointer", "ConstantKLController", + "ContentBase", "DEVICE_TYPING", "DensifyReward", "DiscreteTensorSpec", "Flat2TED", "FlatStorageCheckpointer", - "History", "H5Combine", "H5Split", "H5StorageCheckpointer", "HashToInt", + "History", "ImmutableDatasetWriter", "LazyMemmapStorage", "LazyStackStorage", @@ -191,6 +194,7 @@ "TensorStorage", "TensorStorageCheckpointer", "TokenizedDatasetLoader", + "TopKRewardSelector", "Tree", "Unbounded", "UnboundedContinuous", diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py index 6e7bff8eac0..71c1dd20795 100644 --- a/torchrl/data/replay_buffers/__init__.py +++ b/torchrl/data/replay_buffers/__init__.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from .checkpointers import ( + CompressedStorageCheckpointer, FlatStorageCheckpointer, H5StorageCheckpointer, ListStorageCheckpointer, @@ -32,6 +33,7 @@ SliceSamplerWithoutReplacement, ) from .storages import ( + CompressedStorage, LazyMemmapStorage, LazyStackStorage, LazyTensorStorage, @@ -51,6 +53,8 @@ ) __all__ = [ + "CompressedStorage", + "CompressedStorageCheckpointer", "FlatStorageCheckpointer", "H5StorageCheckpointer", "ListStorageCheckpointer", diff --git a/torchrl/data/replay_buffers/checkpointers.py b/torchrl/data/replay_buffers/checkpointers.py index 6328857292c..346dddf28af 100644 --- a/torchrl/data/replay_buffers/checkpointers.py +++ b/torchrl/data/replay_buffers/checkpointers.py @@ -68,6 +68,118 @@ def loads(storage, path): ) +class CompressedStorageCheckpointer(StorageCheckpointerBase): + """A storage checkpointer for CompressedStorage. + + This checkpointer saves compressed data and metadata separately for efficient storage. + + """ + + def dumps(self, storage, path): + path = Path(path) + path.mkdir(exist_ok=True) + + if not hasattr(storage, "_compressed_data") or not storage._compressed_data: + raise RuntimeError( + "Cannot save an empty or non-initialized CompressedStorage." + ) + + # Save compressed data and metadata + state_dict = storage.state_dict() + + # Save compressed data as separate files for efficiency + compressed_data = state_dict["_compressed_data"] + metadata = state_dict["_metadata"] + + # Save metadata + with open(path / "compressed_metadata.json", "w") as f: + json.dump(metadata, f, default=str) + + # Save compressed data + for i, (compressed_item, item_metadata) in enumerate( + zip(compressed_data, metadata) + ): + if compressed_item is not None: + if item_metadata["type"] == "tensor": + # Save as numpy array + np.save( + path / f"compressed_data_{i}.npy", compressed_item.cpu().numpy() + ) + elif item_metadata["type"] == "tensordict": + # Save each field separately + item_dir = path / f"compressed_data_{i}" + item_dir.mkdir(exist_ok=True) + + for key, value in compressed_item.items(): + if isinstance(value, torch.Tensor): + np.save(item_dir / f"{key}.npy", value.cpu().numpy()) + else: + # Save non-tensor data as pickle + import pickle + + with open(item_dir / f"{key}.pkl", "wb") as f: + pickle.dump(value, f) + else: + # Save other types as pickle + import pickle + + with open(path / f"compressed_data_{i}.pkl", "wb") as f: + pickle.dump(compressed_item, f) + + def loads(self, storage, path): + path = Path(path) + + # Load metadata + with open(path / "compressed_metadata.json") as f: + metadata = json.load(f) + + # Load compressed data + compressed_data = [] + i = 0 + + while True: + if (path / f"compressed_data_{i}.npy").exists(): + # Load tensor data + data = np.load(path / f"compressed_data_{i}.npy") + compressed_data.append(torch.from_numpy(data)) + elif (path / f"compressed_data_{i}.pkl").exists(): + # Load other data + import pickle + + with open(path / f"compressed_data_{i}.pkl", "rb") as f: + data = pickle.load(f) + compressed_data.append(data) + elif (path / f"compressed_data_{i}").exists(): + # Load tensordict data + item_dir = path / f"compressed_data_{i}" + item_data = {} + + for key in metadata[i]["fields"].keys(): + if (item_dir / f"{key}.npy").exists(): + data = np.load(item_dir / f"{key}.npy") + item_data[key] = torch.from_numpy(data) + elif (item_dir / f"{key}.pkl").exists(): + import pickle + + with open(item_dir / f"{key}.pkl", "rb") as f: + data = pickle.load(f) + item_data[key] = data + + compressed_data.append(item_data) + else: + break + + i += 1 + + # Pad with None to match metadata length + while len(compressed_data) < len(metadata): + compressed_data.append(None) + + # Load into storage + storage._compressed_data = compressed_data + storage._metadata = metadata + + class TensorStorageCheckpointer(StorageCheckpointerBase): """A storage checkpointer for TensorStorages. diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index e5ca2367be5..b6c17769659 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -12,7 +12,7 @@ from collections import OrderedDict from copy import copy from multiprocessing.context import get_spawning_popen -from typing import Any, Sequence +from typing import Any, Callable, Mapping, Sequence import numpy as np import tensordict @@ -32,6 +32,7 @@ from torchrl._utils import _make_ordinal_device, implement_for, logger as torchrl_logger from torchrl.data.replay_buffers.checkpointers import ( + CompressedStorageCheckpointer, ListStorageCheckpointer, StorageCheckpointerBase, StorageEnsembleCheckpointer, @@ -1360,6 +1361,316 @@ def get(self, index: int | Sequence[int] | slice) -> Any: return result +class CompressedStorage(Storage): + """A storage that compresses and decompresses data. + + This storage compresses data when storing and decompresses when retrieving. + It's particularly useful for storing raw sensory observations like images + that can be compressed significantly to save memory. + + Args: + max_size (int): size of the storage, i.e. maximum number of elements stored + in the buffer. + compression_fn (callable, optional): function to compress data. Should take + a tensor and return a compressed byte tensor. Defaults to zstd compression. + decompression_fn (callable, optional): function to decompress data. Should take + a compressed byte tensor and return the original tensor. Defaults to zstd decompression. + compression_level (int, optional): compression level (1-22 for zstd) when using the default compression function. + Defaults to 3. + device (torch.device, optional): device where the sampled tensors will be + stored and sent. Default is :obj:`torch.device("cpu")`. + compilable (bool, optional): whether the storage is compilable. + If ``True``, the writer cannot be shared between multiple processes. + Defaults to ``False``. + + Examples: + >>> import torch + >>> from torchrl.data import CompressedStorage, ReplayBuffer + >>> from tensordict import TensorDict + >>> + >>> # Create a compressed storage for image data + >>> storage = CompressedStorage(max_size=1000, compression_level=3) + >>> rb = ReplayBuffer(storage=storage, batch_size=5) + >>> + >>> # Add some image data + >>> images = torch.randn(10, 3, 84, 84) # Atari-like frames + >>> data = TensorDict({"obs": images}, batch_size=[10]) + >>> rb.extend(data) + >>> + >>> # Sample and verify data is decompressed correctly + >>> sample = rb.sample(3) + >>> print(sample["obs"].shape) # torch.Size([3, 3, 84, 84]) + + """ + + _default_checkpointer = CompressedStorageCheckpointer + + def __init__( + self, + max_size: int, + *, + compression_fn: Callable | None = None, + decompression_fn: Callable | None = None, + compression_level: int = 3, + device: torch.device = "cpu", + compilable: bool = False, + ): + super().__init__(max_size, compilable=compilable) + self.device = device + self.compression_level = compression_level + + # Set up compression functions + if compression_fn is None: + self.compression_fn = self._default_compression_fn + else: + self.compression_fn = compression_fn + + if decompression_fn is None: + self.decompression_fn = self._default_decompression_fn + else: + self.decompression_fn = decompression_fn + + # Store compressed data and metadata + self._compressed_data = [] + self._metadata = [] # Store shape, dtype, device info for each item + + def _default_compression_fn(self, tensor: torch.Tensor) -> torch.Tensor: + """Default compression using zstd.""" + try: + import zstandard as zstd + except ImportError: + raise ImportError( + "zstandard is required for default compression. " + "Install with: pip install zstandard" + ) + + # Convert tensor to bytes + tensor_bytes = tensor.cpu().numpy().tobytes() + + # Compress with zstd + compressor = zstd.ZstdCompressor(level=self.compression_level) + compressed_bytes = compressor.compress(tensor_bytes) + + # Convert to tensor + return torch.tensor(list(compressed_bytes), dtype=torch.uint8) + + def _default_decompression_fn( + self, compressed_tensor: torch.Tensor, metadata: dict + ) -> torch.Tensor: + """Default decompression using zstd.""" + try: + import zstandard as zstd + except ImportError: + raise ImportError( + "zstandard is required for default decompression. " + "Install with: pip install zstandard" + ) + + # Convert tensor to bytes + compressed_bytes = bytes(compressed_tensor.cpu().numpy()) + + # Decompress with zstd + decompressor = zstd.ZstdDecompressor() + decompressed_bytes = decompressor.decompress(compressed_bytes) + + # Convert back to tensor + tensor = torch.frombuffer(decompressed_bytes, dtype=metadata["dtype"]) + tensor = tensor.reshape(metadata["shape"]) + tensor = tensor.to(metadata["device"]) + + return tensor + + def set( + self, + cursor: int | Sequence[int] | slice, + data: Any, + *, + set_cursor: bool = True, + ): + """Set data in the storage with compression.""" + if isinstance(cursor, (INT_CLASSES, slice)): + cursor = [cursor] + + if isinstance(data, (list, tuple)): + data_list = data + elif isinstance(data, (torch.Tensor, TensorDictBase)): + data_list = data.unbind(0) + # determine if data is iterable and has length equal to cursor + elif hasattr(data, "__len__") and len(data) == len(cursor): + data_list = data + else: + data_list = [data] + + for idx, item in zip(cursor, data_list): + if idx >= self.max_size: + raise RuntimeError( + f"Cannot set data at index {idx}: storage max_size is {self.max_size}" + ) + + # Compress the data + compressed_data, metadata = self._compress_item(item) + + # Ensure we have enough space + while len(self._compressed_data) <= idx: + self._compressed_data.append(None) + self._metadata.append(None) + + # Store compressed data and metadata + self._compressed_data[idx] = compressed_data + self._metadata[idx] = metadata + + def _compress_item(self, item: Any) -> tuple[torch.Tensor, dict]: + """Compress a single item and return compressed data with metadata.""" + if isinstance(item, torch.Tensor): + metadata = { + "type": "tensor", + "shape": item.shape, + "dtype": item.dtype, + "device": item.device, + } + compressed = self.compression_fn(item) + elif is_tensor_collection(item): + # For TensorDict, compress each tensor field + compressed_fields = {} + metadata = {"type": "tensordict", "fields": {}} + + for key, value in item.items(): + if isinstance(value, torch.Tensor): + compressed_fields[key] = self.compression_fn(value) + metadata["fields"][key] = { + "type": "tensor", + "shape": value.shape, + "dtype": value.dtype, + "device": value.device, + } + else: + # For non-tensor data, store as-is + compressed_fields[key] = value + metadata["fields"][key] = {"type": "non_tensor", "value": value} + + compressed = compressed_fields + else: + # For other types, store as-is + compressed = item + metadata = {"type": "other", "value": item} + + return compressed, metadata + + def get(self, index: int | Sequence[int] | slice) -> Any: + """Get data from the storage with decompression.""" + if isinstance(index, (INT_CLASSES, slice)): + indices = [index] + else: + indices = index + + results = [] + for idx in indices: + if idx >= len(self._compressed_data) or self._compressed_data[idx] is None: + raise IndexError(f"Index {idx} out of bounds or not set") + + compressed_data = self._compressed_data[idx] + metadata = self._metadata[idx] + + # Decompress the data + decompressed = self._decompress_item(compressed_data, metadata) + results.append(decompressed) + + if len(results) == 1: + return results[0] + return results + + def _decompress_item(self, compressed_data: Any, metadata: dict) -> Any: + """Decompress a single item using its metadata.""" + if metadata["type"] == "tensor": + return self.decompression_fn(compressed_data, metadata) + elif metadata["type"] == "tensordict": + # Reconstruct TensorDict + result = TensorDict({}, batch_size=metadata.get("batch_size", [])) + + for key, field_metadata in metadata["fields"].items(): + if field_metadata["type"] == "non_tensor": + result[key] = field_metadata["value"] + else: + # Decompress tensor field + result[key] = self.decompression_fn( + compressed_data[key], field_metadata + ) + + return result + else: + # Return as-is for other types + return metadata["value"] + + def __len__(self): + """Return the number of items in the storage.""" + return len([item for item in self._compressed_data if item is not None]) + + def state_dict(self) -> dict[str, Any]: + """Save the storage state.""" + return { + "_compressed_data": self._compressed_data, + "_metadata": self._metadata, + } + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Load the storage state.""" + self._compressed_data = state_dict["_compressed_data"] + self._metadata = state_dict["_metadata"] + + def _empty(self): + """Empty the storage.""" + self._compressed_data = [] + self._metadata = [] + + def loads(self, path: str): + """Load the storage state from a file.""" + super().loads(path) + from tensordict.utils import _STR_DTYPE_TO_DTYPE + from torch.utils._pytree import tree_flatten, tree_unflatten + + leaves, spec = tree_flatten(self._metadata) + leaves = [ + _STR_DTYPE_TO_DTYPE.get(x, x) if isinstance(x, str) else x for x in leaves + ] + self._metadata = tree_unflatten(leaves, spec) + + def contains(self, item): + """Check if an item is in the storage.""" + if isinstance(item, int): + if item < 0: + item += len(self._compressed_data) + return ( + 0 <= item < len(self._compressed_data) + and self._compressed_data[item] is not None + ) + raise NotImplementedError(f"type {type(item)} is not supported yet.") + + def bytes(self): + """Return the number of bytes in the storage.""" + + def compressed_size_from_list(data: Any) -> int: + if data is None: + return 0 + elif isinstance(data, torch.Tensor): + return data.numel() + elif isinstance(data, (tuple, list, Sequence)): + return sum(compressed_size_from_list(item) for item in data) + elif isinstance(data, Mapping) or is_tensor_collection(data): + return sum(compressed_size_from_list(value) for value in data.values()) + else: + return 0 + + compressed_size_estimate = compressed_size_from_list(self._compressed_data) + if compressed_size_estimate == 0: + if len(self._compressed_data) > 0: + raise RuntimeError( + "Compressed storage is not empty but the compressed size is 0. This is a bug." + ) + warnings.warn("Compressed storage is empty, returning 0 bytes.") + + return compressed_size_estimate + + class StorageEnsemble(Storage): """An ensemble of storages. @@ -1567,9 +1878,11 @@ def _collate_id(x): def _get_default_collate(storage, _is_tensordict=False): - if isinstance(storage, LazyStackStorage) or isinstance(storage, TensorStorage): + if isinstance(storage, (LazyStackStorage, TensorStorage)): return _collate_id - elif isinstance(storage, ListStorage): + elif isinstance(storage, CompressedStorage): + return lazy_stack + elif isinstance(storage, (ListStorage, StorageEnsemble)): return _stack_anything else: raise NotImplementedError( From dce0fa9f11d6ff26ac038349e0797d2362c7d79e Mon Sep 17 00:00:00 2001 From: Adrian Orenstein Date: Sat, 12 Jul 2025 05:54:58 -0600 Subject: [PATCH 2/8] compression is helpful for data transfer --- docs/source/reference/data.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 48fb5f1e06f..812c19201f5 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -196,7 +196,7 @@ were found from rough benchmarking in https://github.com/pytorch/rl/tree/main/be Compressed Storage for Memory Efficiency ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -For applications where memory usage is a primary concern, especially when storing +For applications where memory usage or memory bandwidth is a primary concern, especially when storing or transferring large sensory observations like images or audio, the :class:`~torchrl.data.replay_buffers.storages.CompressedStorage` provides significant memory savings through compression. From 6a7b9e0104dd98aff4ea6d6b973d8661a769c1f0 Mon Sep 17 00:00:00 2001 From: Adrian Orenstein Date: Sat, 12 Jul 2025 12:54:21 -0600 Subject: [PATCH 3/8] Move CompressedStorage to CompressedListStorage. Moved out all of the cursor logic to a view class. Passing all tests now. --- .pre-commit-config.yaml | 2 +- docs/source/reference/data.rst | 18 +- test/test_rb.py | 187 ++++++----- torchrl/data/__init__.py | 8 +- torchrl/data/replay_buffers/__init__.py | 8 +- torchrl/data/replay_buffers/checkpointers.py | 41 ++- torchrl/data/replay_buffers/storages.py | 315 ++++++++++++------- 7 files changed, 374 insertions(+), 205 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 37adaef7979..43d9ad1a525 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,7 +20,7 @@ repos: - libcst == 0.4.7 - repo: https://github.com/pycqa/flake8 - rev: 4.0.1 + rev: 6.0.0 hooks: - id: flake8 args: [--config=setup.cfg] diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 812c19201f5..fdfa5405b56 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -144,8 +144,8 @@ using the following components: :template: rl_template.rst - CompressedStorage - CompressedStorageCheckpointer + CompressedListStorage + CompressedListStorageCheckpointer FlatStorageCheckpointer H5StorageCheckpointer ImmutableDatasetWriter @@ -197,10 +197,10 @@ Compressed Storage for Memory Efficiency ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ For applications where memory usage or memory bandwidth is a primary concern, especially when storing or transferring -large sensory observations like images or audio, the :class:`~torchrl.data.replay_buffers.storages.CompressedStorage` +large sensory observations like images, audio, or text. The :class:`~torchrl.data.replay_buffers.storages.CompressedListStorage` provides significant memory savings through compression. -The `CompressedStorage`` compresses data when storing and decompresses when retrieving, +The `CompressedListStorage`` compresses data when storing and decompresses when retrieving, achieving compression ratios of 2-10x for image data while maintaining full data fidelity. It uses zstd compression by default but supports custom compression algorithms. @@ -214,11 +214,11 @@ Key features: Example usage: >>> import torch - >>> from torchrl.data import ReplayBuffer, CompressedStorage + >>> from torchrl.data import ReplayBuffer, CompressedListStorage >>> from tensordict import TensorDict >>> >>> # Create a compressed storage for image data - >>> storage = CompressedStorage(max_size=1000, compression_level=3) + >>> storage = CompressedListStorage(max_size=1000, compression_level=3) >>> rb = ReplayBuffer(storage=storage, batch_size=32) >>> >>> # Add image data @@ -241,16 +241,16 @@ For custom compression algorithms: >>> def my_decompress(compressed_tensor, metadata): ... return compressed_tensor.to(metadata["dtype"]) >>> - >>> storage = CompressedStorage( + >>> storage = CompressedListStorage( ... max_size=1000, ... compression_fn=my_compress, ... decompression_fn=my_decompress ... ) -.. note:: The CompressedStorage requires the `zstandard` library for default compression. +.. note:: The CompressedListStorage requires the `zstandard` library for default compression. Install with: ``pip install zstandard`` -.. note:: An example of how to use the CompressedStorage is available in the +.. note:: An example of how to use the CompressedListStorage is available in the `examples/replay-buffers/compressed_replay_buffer_example.py `_ file. Sharing replay buffers across processes diff --git a/test/test_rb.py b/test/test_rb.py index a52295d9695..fbe3842b202 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -38,7 +38,7 @@ from torchrl.collectors import RandomPolicy, SyncDataCollector from torchrl.collectors.utils import split_trajectories from torchrl.data import ( - CompressedStorage, + CompressedListStorage, FlatStorageCheckpointer, MultiStep, NestedStorageCheckpointer, @@ -189,7 +189,6 @@ @pytest.mark.parametrize("size", [3, 5, 100]) class TestComposableBuffers: def _get_rb(self, rb_type, size, sampler, writer, storage, compilable=False): - if storage is not None: storage = storage(size, compilable=compilable) @@ -332,10 +331,14 @@ def test_cursor_position(self, rb_type, sampler, writer, storage, size, datatype writer.extend(batch1) return - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): writer.extend(batch1) # Added less data than storage max size @@ -383,10 +386,14 @@ def test_extend(self, rb_type, sampler, writer, storage, size, datatype): length = min(rb._storage.max_size, len(rb) + data_shape) if writer is TensorDictMaxValueWriter: data["next", "reward"][-length:] = 1_000_000 - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(data) length = len(rb) if is_tensor_collection(data): @@ -424,10 +431,14 @@ def data_iter(): and size < len(data2) and isinstance(rb._storage, TensorStorage) ) - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(data2) @pytest.mark.skipif( @@ -530,10 +541,14 @@ def test_sample(self, rb_type, sampler, writer, storage, size, datatype): ): rb.extend(data) return - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(data) rb_sample = rb.sample() # if not isinstance(new_data, (torch.Tensor, TensorDictBase)): @@ -550,7 +565,6 @@ def data_iter_func(maxval, data=data): rb_sample_iter = data_iter_func(rb._batch_size, rb_sample) for single_sample in rb_sample_iter: - if is_tensor_collection(data) or isinstance(data, torch.Tensor): data_iter = data else: @@ -600,10 +614,14 @@ def test_index(self, rb_type, sampler, writer, storage, size, datatype): ): rb.extend(data) return - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(data) d1 = rb[2] d2 = rb._storage[2] @@ -620,7 +638,6 @@ def test_index(self, rb_type, sampler, writer, storage, size, datatype): assert b def test_pickable(self, rb_type, sampler, writer, storage, size, datatype): - rb = self._get_rb( rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size ) @@ -879,7 +896,6 @@ def extend_and_sample(data): @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) def test_extend_lazystack(self, storage_type): - rb = ReplayBuffer( storage=storage_type(6), batch_size=2, @@ -1506,7 +1522,6 @@ def test_rng_dumps(self, tmpdir): @pytest.mark.parametrize("size", [3, 5, 100]) @pytest.mark.parametrize("prefetch", [0]) class TestBuffers: - default_constr = { ReplayBuffer: ReplayBuffer, PrioritizedReplayBuffer: functools.partial( @@ -1573,10 +1588,14 @@ def test_cursor_position2(self, rbtype, storage, size, prefetch): cond = ( OLD_TORCH and size < len(batch1) and isinstance(rb._storage, TensorStorage) ) - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(batch1) # Added fewer data than storage max size @@ -1631,10 +1650,14 @@ def test_extend(self, rbtype, storage, size, prefetch): rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) data = self._get_data(rbtype, size=5) cond = OLD_TORCH and size < len(data) and isinstance(rb._storage, TensorStorage) - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(data) length = len(rb) for d in data[-length:]: @@ -1658,10 +1681,14 @@ def test_sample(self, rbtype, storage, size, prefetch): rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) data = self._get_data(rbtype, size=5) cond = OLD_TORCH and size < len(data) and isinstance(rb._storage, TensorStorage) - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(data) new_data = rb.sample() if not isinstance(new_data, (torch.Tensor, TensorDictBase)): @@ -1688,10 +1715,14 @@ def test_index(self, rbtype, storage, size, prefetch): rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) data = self._get_data(rbtype, size=5) cond = OLD_TORCH and size < len(data) and isinstance(rb._storage, TensorStorage) - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(data) d1 = rb[2] d2 = rb._storage[2] @@ -2568,7 +2599,7 @@ def test_slice_sampler( break else: raise AssertionError( - f"Not all items can be sampled: {set(range(100))-count_unique} are missing" + f"Not all items can be sampled: {set(range(100)) - count_unique} are missing" ) if strict_length: @@ -2775,7 +2806,6 @@ def test_slice_sampler_left_right_ndim(self): assert curr_eps.unique().numel() == 1 def test_slice_sampler_strictlength(self): - torch.manual_seed(0) data = TensorDict( @@ -3435,7 +3465,6 @@ def _robust_stack(tensor_list): def test_rb( self, storage_type, sampler_type, data_type, p, num_buffer_sampled, batch_size ): - storages = [self._make_storage(storage_type, data_type) for _ in range(3)] collate_fn = self._make_collate(storage_type) data = [self._make_data(data_type) for _ in range(3)] @@ -4033,22 +4062,29 @@ def test_ray_rb_iter(self): @pytest.mark.skipif(not _has_zstandard, reason="zstandard required for this test.") -class TestCompressedStorage: - """Test cases for CompressedStorage.""" +class TestCompressedListStorage: + """Test cases for CompressedListStorage.""" def test_compressed_storage_initialization(self): - """Test that CompressedStorage initializes correctly.""" - storage = CompressedStorage(max_size=100, compression_level=3) + """Test that CompressedListStorage initializes correctly.""" + storage = CompressedListStorage(max_size=100, compression_level=3) assert storage.max_size == 100 assert storage.compression_level == 3 assert len(storage) == 0 - def test_compressed_storage_tensor(self): - """Test compression and decompression of tensor data.""" - storage = CompressedStorage(max_size=10, compression_level=3) - - # Create test tensor - test_tensor = torch.randn(3, 84, 84, dtype=torch.float32) + @pytest.mark.parametrize( + "test_tensor", + [ + torch.rand(1), # 0D scalar + torch.randn(84, dtype=torch.float32), # 1D tensor + torch.randn(84, 84, dtype=torch.float32), # 2D tensor + torch.randn(1, 84, 84, dtype=torch.float32), # 3D tensor + torch.randn(32, 84, 84, dtype=torch.float32), # 3D tensor + ], + ) + def test_compressed_storage_tensor(self, test_tensor): + """Test compression and decompression of tensor data of various shapes.""" + storage = CompressedListStorage(max_size=10, compression_level=3) # Store tensor storage.set(0, test_tensor) @@ -4057,13 +4093,17 @@ def test_compressed_storage_tensor(self): retrieved_tensor = storage.get(0) # Verify data integrity + assert ( + test_tensor.shape == retrieved_tensor.shape + ), f"Expected shape {test_tensor.shape}, got {retrieved_tensor.shape}" + assert ( + test_tensor.dtype == retrieved_tensor.dtype + ), f"Expected dtype {test_tensor.dtype}, got {retrieved_tensor.dtype}" assert torch.allclose(test_tensor, retrieved_tensor, atol=1e-6) - assert test_tensor.shape == retrieved_tensor.shape - assert test_tensor.dtype == retrieved_tensor.dtype def test_compressed_storage_tensordict(self): """Test compression and decompression of TensorDict data.""" - storage = CompressedStorage(max_size=10, compression_level=3) + storage = CompressedListStorage(max_size=10, compression_level=3) # Create test TensorDict test_td = TensorDict( @@ -4090,7 +4130,7 @@ def test_compressed_storage_tensordict(self): def test_compressed_storage_multiple_indices(self): """Test storing and retrieving multiple items.""" - storage = CompressedStorage(max_size=10, compression_level=3) + storage = CompressedListStorage(max_size=10, compression_level=3) # Store multiple tensors tensors = [ @@ -4110,8 +4150,8 @@ def test_compressed_storage_multiple_indices(self): assert torch.allclose(original, retrieved_tensor, atol=1e-6) def test_compressed_storage_with_replay_buffer(self): - """Test CompressedStorage with ReplayBuffer.""" - storage = CompressedStorage(max_size=100, compression_level=3) + """Test CompressedListStorage with ReplayBuffer.""" + storage = CompressedListStorage(max_size=100, compression_level=3) rb = ReplayBuffer(storage=storage, batch_size=5) # Create test data @@ -4125,7 +4165,6 @@ def test_compressed_storage_with_replay_buffer(self): ) # Add data to replay buffer - print("extending") rb.extend(data) # Sample from replay buffer @@ -4140,7 +4179,7 @@ def test_compressed_storage_with_replay_buffer(self): def test_compressed_storage_state_dict(self): """Test saving and loading state dict.""" - storage = CompressedStorage(max_size=10, compression_level=3) + storage = CompressedListStorage(max_size=10, compression_level=3) # Add some data test_tensor = torch.randn(3, 3, dtype=torch.float32) @@ -4150,7 +4189,7 @@ def test_compressed_storage_state_dict(self): state_dict = storage.state_dict() # Create new storage and load state dict - new_storage = CompressedStorage(max_size=10, compression_level=3) + new_storage = CompressedListStorage(max_size=10, compression_level=3) new_storage.load_state_dict(state_dict) # Verify data integrity @@ -4159,7 +4198,7 @@ def test_compressed_storage_state_dict(self): def test_compressed_storage_checkpointing(self): """Test checkpointing functionality.""" - storage = CompressedStorage(max_size=10, compression_level=3) + storage = CompressedListStorage(max_size=10, compression_level=3) # Add some data test_td = TensorDict( @@ -4179,7 +4218,7 @@ def test_compressed_storage_checkpointing(self): storage.dumps(checkpoint_path) # Create new storage and load checkpoint - new_storage = CompressedStorage(max_size=10, compression_level=3) + new_storage = CompressedListStorage(max_size=10, compression_level=3) new_storage.loads(checkpoint_path) # Verify data integrity @@ -4189,7 +4228,7 @@ def test_compressed_storage_checkpointing(self): def test_compressed_storage_length(self): """Test that length is calculated correctly.""" - storage = CompressedStorage(max_size=10, compression_level=3) + storage = CompressedListStorage(max_size=10, compression_level=3) # Initially empty assert len(storage) == 0 @@ -4198,15 +4237,15 @@ def test_compressed_storage_length(self): storage.set(0, torch.randn(2, 2)) assert len(storage) == 1 - storage.set(2, torch.randn(2, 2)) + storage.set(1, torch.randn(2, 2)) assert len(storage) == 2 - storage.set(1, torch.randn(2, 2)) + storage.set(2, torch.randn(2, 2)) assert len(storage) == 3 def test_compressed_storage_contains(self): """Test the contains method.""" - storage = CompressedStorage(max_size=10, compression_level=3) + storage = CompressedListStorage(max_size=10, compression_level=3) # Initially empty assert not storage.contains(0) @@ -4218,7 +4257,7 @@ def test_compressed_storage_contains(self): def test_compressed_storage_empty(self): """Test emptying the storage.""" - storage = CompressedStorage(max_size=10, compression_level=3) + storage = CompressedListStorage(max_size=10, compression_level=3) # Add some data storage.set(0, torch.randn(2, 2)) @@ -4240,7 +4279,7 @@ def custom_decompress(compressed_tensor, metadata): # Simple decompression: convert back to original dtype return compressed_tensor.to(metadata["dtype"]) - storage = CompressedStorage( + storage = CompressedListStorage( max_size=10, compression_fn=custom_compress, decompression_fn=custom_decompress, @@ -4257,7 +4296,7 @@ def custom_decompress(compressed_tensor, metadata): def test_compressed_storage_error_handling(self): """Test error handling for invalid operations.""" - storage = CompressedStorage(max_size=5, compression_level=3) + storage = CompressedListStorage(max_size=5, compression_level=3) # Test setting data beyond max_size with pytest.raises(RuntimeError): @@ -4269,7 +4308,7 @@ def test_compressed_storage_error_handling(self): def test_compressed_storage_memory_efficiency(self): """Test that compression actually reduces memory usage.""" - storage = CompressedStorage(max_size=100, compression_level=3) + storage = CompressedListStorage(max_size=100, compression_level=3) # Create large tensor data large_tensor = torch.zeros(100, 3, 84, 84, dtype=torch.int64) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 5c8bff8a3af..226ec4d5bb9 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -32,8 +32,8 @@ ) from .postprocs import DensifyReward, MultiStep from .replay_buffers import ( - CompressedStorage, - CompressedStorageCheckpointer, + CompressedListStorage, + CompressedListStorageCheckpointer, Flat2TED, FlatStorageCheckpointer, H5Combine, @@ -120,8 +120,8 @@ "Choice", "Composite", "CompositeSpec", - "CompressedStorage", - "CompressedStorageCheckpointer", + "CompressedListStorage", + "CompressedListStorageCheckpointer", "ConstantKLController", "ContentBase", "DEVICE_TYPING", diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py index 71c1dd20795..540d7c129be 100644 --- a/torchrl/data/replay_buffers/__init__.py +++ b/torchrl/data/replay_buffers/__init__.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. from .checkpointers import ( - CompressedStorageCheckpointer, + CompressedListStorageCheckpointer, FlatStorageCheckpointer, H5StorageCheckpointer, ListStorageCheckpointer, @@ -33,7 +33,7 @@ SliceSamplerWithoutReplacement, ) from .storages import ( - CompressedStorage, + CompressedListStorage, LazyMemmapStorage, LazyStackStorage, LazyTensorStorage, @@ -53,8 +53,8 @@ ) __all__ = [ - "CompressedStorage", - "CompressedStorageCheckpointer", + "CompressedListStorage", + "CompressedListStorageCheckpointer", "FlatStorageCheckpointer", "H5StorageCheckpointer", "ListStorageCheckpointer", diff --git a/torchrl/data/replay_buffers/checkpointers.py b/torchrl/data/replay_buffers/checkpointers.py index 346dddf28af..82869204ea6 100644 --- a/torchrl/data/replay_buffers/checkpointers.py +++ b/torchrl/data/replay_buffers/checkpointers.py @@ -68,8 +68,8 @@ def loads(storage, path): ) -class CompressedStorageCheckpointer(StorageCheckpointerBase): - """A storage checkpointer for CompressedStorage. +class CompressedListStorageCheckpointer(StorageCheckpointerBase): + """A storage checkpointer for CompressedListStorage. This checkpointer saves compressed data and metadata separately for efficient storage. @@ -79,9 +79,12 @@ def dumps(self, storage, path): path = Path(path) path.mkdir(exist_ok=True) - if not hasattr(storage, "_compressed_data") or not storage._compressed_data: + if ( + not hasattr(storage, "_compressed_data") + or len(storage._compressed_data) == 0 + ): raise RuntimeError( - "Cannot save an empty or non-initialized CompressedStorage." + "Cannot save an empty or non-initialized CompressedListStorage." ) # Save compressed data and metadata @@ -107,7 +110,7 @@ def dumps(self, storage, path): ) elif item_metadata["type"] == "tensordict": # Save each field separately - item_dir = path / f"compressed_data_{i}" + item_dir = path / f"compressed_data_{i}.td" item_dir.mkdir(exist_ok=True) for key, value in compressed_item.items(): @@ -133,10 +136,34 @@ def loads(self, storage, path): with open(path / "compressed_metadata.json") as f: metadata = json.load(f) + # Convert string dtypes back to torch.dtype objects + def convert_dtype(item): + if isinstance(item, dict): + if "dtype" in item and isinstance(item["dtype"], str): + # Convert string back to torch.dtype + dtype_str = item["dtype"] + if hasattr(torch, dtype_str.replace("torch.", "")): + item["dtype"] = getattr(torch, dtype_str.replace("torch.", "")) + else: + # Handle cases like 'torch.float32' -> torch.float32 + item["dtype"] = eval(dtype_str) + + # Recursively handle nested dictionaries + for _key, value in item.items(): + if isinstance(value, dict): + convert_dtype(value) + return item + + for item in metadata: + if item is not None: + convert_dtype(item) + # Load compressed data compressed_data = [] i = 0 + # TODO(adrian): Can we not know the serialised format beforehand? Then we can use glob to iterate through the files we know exist: + # `for path in glob.glob(path / f"compressed_data_*.{fmt}" for fmt in ["npy", "pkl", "td"]):`` while True: if (path / f"compressed_data_{i}.npy").exists(): # Load tensor data @@ -149,9 +176,9 @@ def loads(self, storage, path): with open(path / f"compressed_data_{i}.pkl", "rb") as f: data = pickle.load(f) compressed_data.append(data) - elif (path / f"compressed_data_{i}").exists(): + elif (path / f"compressed_data_{i}.td").exists(): # Load tensordict data - item_dir = path / f"compressed_data_{i}" + item_dir = path / f"compressed_data_{i}.td" item_data = {} for key in metadata[i]["fields"].keys(): diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index b6c17769659..e836df3a828 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -32,7 +32,7 @@ from torchrl._utils import _make_ordinal_device, implement_for, logger as torchrl_logger from torchrl.data.replay_buffers.checkpointers import ( - CompressedStorageCheckpointer, + CompressedListStorageCheckpointer, ListStorageCheckpointer, StorageCheckpointerBase, StorageEnsembleCheckpointer, @@ -873,7 +873,6 @@ def set( # noqa: F811 *, set_cursor: bool = True, ): - if set_cursor: self._last_cursor = cursor @@ -1361,13 +1360,107 @@ def get(self, index: int | Sequence[int] | slice) -> Any: return result -class CompressedStorage(Storage): +class CompressedStorageView: + """A view that makes compressed storage look like a regular list to ListStorage methods.""" + + def __init__(self, compressed_data, metadata, parent_storage): + self._compressed_data = compressed_data + self._metadata = metadata + self._parent_storage = parent_storage + + def __len__(self): + return len([item for item in self._compressed_data if item is not None]) + + def __getitem__(self, index): + if isinstance(index, INT_CLASSES): + if ( + index >= len(self._compressed_data) + or self._compressed_data[index] is None + ): + raise IndexError(f"Index {index} out of bounds or not set") + return self._parent_storage._decompress_item( + self._compressed_data[index], self._metadata[index] + ) + elif isinstance(index, slice): + start, stop, step = index.indices(len(self._compressed_data)) + results = [] + for i in range(start, stop, step): + if ( + i < len(self._compressed_data) + and self._compressed_data[i] is not None + ): + results.append( + self._parent_storage._decompress_item( + self._compressed_data[i], self._metadata[i] + ) + ) + return results + else: + # Handle lists, tensors, etc. + if isinstance(index, torch.Tensor): + if index.ndim == 0: + # Handle 0-dimensional tensor (scalar) + return self[index.item()] + if index.device.type != "cpu": + index = index.cpu().tolist() + else: + index = index.tolist() + + results = [] + for i in index: + if i >= len(self._compressed_data) or self._compressed_data[i] is None: + raise IndexError(f"Index {i} out of bounds or not set") + results.append( + self._parent_storage._decompress_item( + self._compressed_data[i], self._metadata[i] + ) + ) + return results + + def __setitem__(self, index, value): + if isinstance(index, INT_CLASSES): + # Ensure we have enough space + while len(self._compressed_data) <= index: + self._compressed_data.append(None) + self._metadata.append(None) + + # Compress and store + compressed_data, metadata = self._parent_storage._compress_item(value) + self._compressed_data[index] = compressed_data + self._metadata[index] = metadata + elif isinstance(index, slice): + # Handle slice assignment + if not hasattr(value, "__iter__"): + value = [value] + start, stop, step = index.indices(len(self._compressed_data)) + indices = list(range(start, stop, step)) + + for i, v in zip(indices, value): + self[i] = v + else: + # Handle multiple indices + if isinstance(index, torch.Tensor) and index.device.type != "cpu": + index = index.cpu().tolist() + if not hasattr(value, "__iter__"): + value = [value] * len(index) + for i, v in zip(index, value): + self[i] = v + + def append(self, value): + # Find the next available slot + index = len(self._compressed_data) + self[index] = value + + +class CompressedListStorage(ListStorage): """A storage that compresses and decompresses data. This storage compresses data when storing and decompresses when retrieving. It's particularly useful for storing raw sensory observations like images that can be compressed significantly to save memory. + pytest test/test_rb.py::TestCompressedListStorage::test_compressed_storage_tensor -v + Args: max_size (int): size of the storage, i.e. maximum number of elements stored in the buffer. @@ -1385,11 +1478,11 @@ class CompressedStorage(Storage): Examples: >>> import torch - >>> from torchrl.data import CompressedStorage, ReplayBuffer + >>> from torchrl.data import CompressedListStorage, ReplayBuffer >>> from tensordict import TensorDict >>> >>> # Create a compressed storage for image data - >>> storage = CompressedStorage(max_size=1000, compression_level=3) + >>> storage = CompressedListStorage(max_size=1000, compression_level=3) >>> rb = ReplayBuffer(storage=storage, batch_size=5) >>> >>> # Add some image data @@ -1403,7 +1496,7 @@ class CompressedStorage(Storage): """ - _default_checkpointer = CompressedStorageCheckpointer + _default_checkpointer = CompressedListStorageCheckpointer def __init__( self, @@ -1415,8 +1508,7 @@ def __init__( device: torch.device = "cpu", compilable: bool = False, ): - super().__init__(max_size, compilable=compilable) - self.device = device + super().__init__(max_size, compilable=compilable, device=device) self.compression_level = compression_level # Set up compression functions @@ -1474,50 +1566,61 @@ def _default_decompression_fn( decompressed_bytes = decompressor.decompress(compressed_bytes) # Convert back to tensor - tensor = torch.frombuffer(decompressed_bytes, dtype=metadata["dtype"]) + tensor = torch.frombuffer( + bytearray(decompressed_bytes), dtype=metadata["dtype"] + ) tensor = tensor.reshape(metadata["shape"]) tensor = tensor.to(metadata["device"]) return tensor - def set( - self, - cursor: int | Sequence[int] | slice, - data: Any, - *, - set_cursor: bool = True, - ): - """Set data in the storage with compression.""" - if isinstance(cursor, (INT_CLASSES, slice)): - cursor = [cursor] - - if isinstance(data, (list, tuple)): - data_list = data - elif isinstance(data, (torch.Tensor, TensorDictBase)): - data_list = data.unbind(0) - # determine if data is iterable and has length equal to cursor - elif hasattr(data, "__len__") and len(data) == len(cursor): - data_list = data + @property + def _storage(self): + """Virtual storage that handles compression/decompression transparently.""" + return CompressedStorageView(self._compressed_data, self._metadata, self) + + @_storage.setter + def _storage(self, value): + # This allows ListStorage.__init__ to set _storage initially + if hasattr(self, "_compressed_data"): + # If we already have compressed storage, ignore attempts to set _storage + pass else: - data_list = [data] + # During initialization, create our compressed storage + self._compressed_data = [] + self._metadata = [] - for idx, item in zip(cursor, data_list): - if idx >= self.max_size: - raise RuntimeError( - f"Cannot set data at index {idx}: storage max_size is {self.max_size}" - ) + def __len__(self): + """Return the number of items in the storage.""" + return len([item for item in self._compressed_data if item is not None]) - # Compress the data - compressed_data, metadata = self._compress_item(item) + def _empty(self): + """Empty the storage.""" + self._compressed_data = [] + self._metadata = [] - # Ensure we have enough space - while len(self._compressed_data) <= idx: - self._compressed_data.append(None) - self._metadata.append(None) + def state_dict(self) -> dict[str, Any]: + """Save the storage state.""" + return { + "_compressed_data": self._compressed_data, + "_metadata": self._metadata, + } + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Load the storage state.""" + self._compressed_data = state_dict["_compressed_data"] + self._metadata = state_dict["_metadata"] - # Store compressed data and metadata - self._compressed_data[idx] = compressed_data - self._metadata[idx] = metadata + def contains(self, item): + """Check if an item is in the storage.""" + if isinstance(item, int): + if item < 0: + item += len(self._compressed_data) + return ( + 0 <= item < len(self._compressed_data) + and self._compressed_data[item] is not None + ) + raise NotImplementedError(f"type {type(item)} is not supported yet.") def _compress_item(self, item: Any) -> tuple[torch.Tensor, dict]: """Compress a single item and return compressed data with metadata.""" @@ -1556,28 +1659,28 @@ def _compress_item(self, item: Any) -> tuple[torch.Tensor, dict]: return compressed, metadata - def get(self, index: int | Sequence[int] | slice) -> Any: - """Get data from the storage with decompression.""" - if isinstance(index, (INT_CLASSES, slice)): - indices = [index] - else: - indices = index + # def get(self, index: int | Sequence[int] | slice) -> Any: + # """Get data from the storage with decompression.""" + # if isinstance(index, (INT_CLASSES, slice)): + # indices = [index] + # else: + # indices = index - results = [] - for idx in indices: - if idx >= len(self._compressed_data) or self._compressed_data[idx] is None: - raise IndexError(f"Index {idx} out of bounds or not set") + # results = [] + # for idx in indices: + # if idx >= len(self._compressed_data) or self._compressed_data[idx] is None: + # raise IndexError(f"Index {idx} out of bounds or not set") - compressed_data = self._compressed_data[idx] - metadata = self._metadata[idx] + # compressed_data = self._compressed_data[idx] + # metadata = self._metadata[idx] - # Decompress the data - decompressed = self._decompress_item(compressed_data, metadata) - results.append(decompressed) + # # Decompress the data + # decompressed = self._decompress_item(compressed_data, metadata) + # results.append(decompressed) - if len(results) == 1: - return results[0] - return results + # if len(results) == 1: + # return results[0] + # return results def _decompress_item(self, compressed_data: Any, metadata: dict) -> Any: """Decompress a single item using its metadata.""" @@ -1601,49 +1704,49 @@ def _decompress_item(self, compressed_data: Any, metadata: dict) -> Any: # Return as-is for other types return metadata["value"] - def __len__(self): - """Return the number of items in the storage.""" - return len([item for item in self._compressed_data if item is not None]) - - def state_dict(self) -> dict[str, Any]: - """Save the storage state.""" - return { - "_compressed_data": self._compressed_data, - "_metadata": self._metadata, - } - - def load_state_dict(self, state_dict: dict[str, Any]) -> None: - """Load the storage state.""" - self._compressed_data = state_dict["_compressed_data"] - self._metadata = state_dict["_metadata"] - - def _empty(self): - """Empty the storage.""" - self._compressed_data = [] - self._metadata = [] - - def loads(self, path: str): - """Load the storage state from a file.""" - super().loads(path) - from tensordict.utils import _STR_DTYPE_TO_DTYPE - from torch.utils._pytree import tree_flatten, tree_unflatten - - leaves, spec = tree_flatten(self._metadata) - leaves = [ - _STR_DTYPE_TO_DTYPE.get(x, x) if isinstance(x, str) else x for x in leaves - ] - self._metadata = tree_unflatten(leaves, spec) - - def contains(self, item): - """Check if an item is in the storage.""" - if isinstance(item, int): - if item < 0: - item += len(self._compressed_data) - return ( - 0 <= item < len(self._compressed_data) - and self._compressed_data[item] is not None - ) - raise NotImplementedError(f"type {type(item)} is not supported yet.") + # def __len__(self): + # """Return the number of items in the storage.""" + # return len([item for item in self._compressed_data if item is not None]) + + # def state_dict(self) -> dict[str, Any]: + # """Save the storage state.""" + # return { + # "_compressed_data": self._compressed_data, + # "_metadata": self._metadata, + # } + + # def load_state_dict(self, state_dict: dict[str, Any]) -> None: + # """Load the storage state.""" + # self._compressed_data = state_dict["_compressed_data"] + # self._metadata = state_dict["_metadata"] + + # def _empty(self): + # """Empty the storage.""" + # self._compressed_data = [] + # self._metadata = [] + + # def loads(self, path: str): + # """Load the storage state from a file.""" + # super().loads(path) + # from tensordict.utils import _STR_DTYPE_TO_DTYPE + # from torch.utils._pytree import tree_flatten, tree_unflatten + + # leaves, spec = tree_flatten(self._metadata) + # leaves = [ + # _STR_DTYPE_TO_DTYPE.get(x, x) if isinstance(x, str) else x for x in leaves + # ] + # self._metadata = tree_unflatten(leaves, spec) + + # def contains(self, item): + # """Check if an item is in the storage.""" + # if isinstance(item, int): + # if item < 0: + # item += len(self._compressed_data) + # return ( + # 0 <= item < len(self._compressed_data) + # and self._compressed_data[item] is not None + # ) + # raise NotImplementedError(f"type {type(item)} is not supported yet.") def bytes(self): """Return the number of bytes in the storage.""" @@ -1706,7 +1809,7 @@ def __init__( self._transforms = transforms if transforms is not None and len(transforms) != len(storages): raise TypeError( - "transforms must have the same length as the storages " "provided." + "transforms must have the same length as the storages provided." ) @property @@ -1730,7 +1833,7 @@ def get(self, item): buffer_ids = item.get("buffer_ids") index = item.get("index") results = [] - for (buffer_id, sample) in zip(buffer_ids, index): + for buffer_id, sample in zip(buffer_ids, index): buffer_id = self._convert_id(buffer_id) results.append((buffer_id, self._get_storage(buffer_id).get(sample))) if self._transforms is not None: @@ -1880,7 +1983,7 @@ def _collate_id(x): def _get_default_collate(storage, _is_tensordict=False): if isinstance(storage, (LazyStackStorage, TensorStorage)): return _collate_id - elif isinstance(storage, CompressedStorage): + elif isinstance(storage, CompressedListStorage): return lazy_stack elif isinstance(storage, (ListStorage, StorageEnsemble)): return _stack_anything From 600077a482d583e6af7a263c3e397d6eb864cb18 Mon Sep 17 00:00:00 2001 From: Adrian Orenstein Date: Sat, 12 Jul 2025 21:39:00 -0600 Subject: [PATCH 4/8] Refactor out the storage view. Expose functions in the ListStorage class. --- test/test_rb.py | 2 +- torchrl/data/replay_buffers/checkpointers.py | 11 +- torchrl/data/replay_buffers/storages.py | 339 +++++++------------ 3 files changed, 123 insertions(+), 229 deletions(-) diff --git a/test/test_rb.py b/test/test_rb.py index fbe3842b202..bf6c2cc9fbd 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -4322,7 +4322,7 @@ def test_compressed_storage_memory_efficiency(self): storage.set(0, large_tensor) # Estimate compressed size - compressed_data = storage._compressed_data[0] + compressed_data = storage._storage[0] compressed_size = compressed_data.numel() # uint8 bytes # Verify compression ratio is reasonable (at least 2x for random data) diff --git a/torchrl/data/replay_buffers/checkpointers.py b/torchrl/data/replay_buffers/checkpointers.py index 82869204ea6..71c525d7963 100644 --- a/torchrl/data/replay_buffers/checkpointers.py +++ b/torchrl/data/replay_buffers/checkpointers.py @@ -79,10 +79,7 @@ def dumps(self, storage, path): path = Path(path) path.mkdir(exist_ok=True) - if ( - not hasattr(storage, "_compressed_data") - or len(storage._compressed_data) == 0 - ): + if not hasattr(storage, "_storage") or len(storage._storage) == 0: raise RuntimeError( "Cannot save an empty or non-initialized CompressedListStorage." ) @@ -90,8 +87,8 @@ def dumps(self, storage, path): # Save compressed data and metadata state_dict = storage.state_dict() - # Save compressed data as separate files for efficiency - compressed_data = state_dict["_compressed_data"] + # Save compressed data and metadata + compressed_data = state_dict["_storage"] metadata = state_dict["_metadata"] # Save metadata @@ -203,7 +200,7 @@ def convert_dtype(item): compressed_data.append(None) # Load into storage - storage._compressed_data = compressed_data + storage._storage = compressed_data storage._metadata = metadata diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index e836df3a828..6e0a904460d 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -279,7 +279,7 @@ def set( return if isinstance(cursor, slice): data = self._to_device(data) - self._storage[cursor] = data + self._set_slice(cursor, data) return if isinstance( data, @@ -303,7 +303,7 @@ def set( ) return else: - if cursor > len(self._storage): + if cursor > len(self): raise RuntimeError( "Cannot append data located more than one item away from " f"the storage size: the storage size is {len(self)} " @@ -316,14 +316,24 @@ def set( f"and the index of the item to be set is {cursor}." ) data = self._to_device(data) - if cursor == len(self._storage): - self._storage.append(data) - else: - self._storage[cursor] = data + self._set_item(cursor, data) + + def _set_item(self, cursor: int, data: Any) -> None: + """Set a single item in the storage.""" + if cursor == len(self._storage): + self._storage.append(data) + else: + self._storage[cursor] = data + + def _set_slice(self, cursor: slice, data: Any) -> None: + """Set a slice in the storage.""" + self._storage[cursor] = data def get(self, index: int | Sequence[int] | slice) -> Any: - if isinstance(index, (INT_CLASSES, slice)): - return self._storage[index] + if isinstance(index, INT_CLASSES): + return self._get_item(index) + elif isinstance(index, slice): + return self._get_slice(index) elif isinstance(index, tuple): if len(index) > 1: raise RuntimeError( @@ -333,9 +343,22 @@ def get(self, index: int | Sequence[int] | slice) -> Any: else: if isinstance(index, torch.Tensor) and index.device.type != "cpu": index = index.cpu().tolist() - return [self._storage[i] for i in index] + return self._get_list(index) + + def _get_item(self, index: int) -> Any: + """Get a single item from the storage.""" + return self._storage[index] + + def _get_slice(self, index: slice) -> Any: + """Get a slice from the storage.""" + return self._storage[index] + + def _get_list(self, index: list) -> list: + """Get a list of items from the storage.""" + return [self._storage[i] for i in index] def __len__(self): + """Get the length of the storage.""" return len(self._storage) def state_dict(self) -> dict[str, Any]: @@ -379,9 +402,8 @@ def __repr__(self): def contains(self, item): if isinstance(item, int): if item < 0: - item += len(self._storage) - - return 0 <= item < len(self._storage) + item += len(self) + return self._contains_int(item) if isinstance(item, torch.Tensor): return torch.tensor( [self.contains(elt) for elt in item.tolist()], @@ -390,6 +412,10 @@ def contains(self, item): ).reshape_as(item) raise NotImplementedError(f"type {type(item)} is not supported yet.") + def _contains_int(self, item: int) -> bool: + """Check if an integer index is contained in the storage.""" + return 0 <= item < len(self) + class LazyStackStorage(ListStorage): """A ListStorage that returns LazyStackTensorDict instances. @@ -1360,98 +1386,6 @@ def get(self, index: int | Sequence[int] | slice) -> Any: return result -class CompressedStorageView: - """A view that makes compressed storage look like a regular list to ListStorage methods.""" - - def __init__(self, compressed_data, metadata, parent_storage): - self._compressed_data = compressed_data - self._metadata = metadata - self._parent_storage = parent_storage - - def __len__(self): - return len([item for item in self._compressed_data if item is not None]) - - def __getitem__(self, index): - if isinstance(index, INT_CLASSES): - if ( - index >= len(self._compressed_data) - or self._compressed_data[index] is None - ): - raise IndexError(f"Index {index} out of bounds or not set") - return self._parent_storage._decompress_item( - self._compressed_data[index], self._metadata[index] - ) - elif isinstance(index, slice): - start, stop, step = index.indices(len(self._compressed_data)) - results = [] - for i in range(start, stop, step): - if ( - i < len(self._compressed_data) - and self._compressed_data[i] is not None - ): - results.append( - self._parent_storage._decompress_item( - self._compressed_data[i], self._metadata[i] - ) - ) - return results - else: - # Handle lists, tensors, etc. - if isinstance(index, torch.Tensor): - if index.ndim == 0: - # Handle 0-dimensional tensor (scalar) - return self[index.item()] - if index.device.type != "cpu": - index = index.cpu().tolist() - else: - index = index.tolist() - - results = [] - for i in index: - if i >= len(self._compressed_data) or self._compressed_data[i] is None: - raise IndexError(f"Index {i} out of bounds or not set") - results.append( - self._parent_storage._decompress_item( - self._compressed_data[i], self._metadata[i] - ) - ) - return results - - def __setitem__(self, index, value): - if isinstance(index, INT_CLASSES): - # Ensure we have enough space - while len(self._compressed_data) <= index: - self._compressed_data.append(None) - self._metadata.append(None) - - # Compress and store - compressed_data, metadata = self._parent_storage._compress_item(value) - self._compressed_data[index] = compressed_data - self._metadata[index] = metadata - elif isinstance(index, slice): - # Handle slice assignment - if not hasattr(value, "__iter__"): - value = [value] - start, stop, step = index.indices(len(self._compressed_data)) - indices = list(range(start, stop, step)) - - for i, v in zip(indices, value): - self[i] = v - else: - # Handle multiple indices - if isinstance(index, torch.Tensor) and index.device.type != "cpu": - index = index.cpu().tolist() - if not hasattr(value, "__iter__"): - value = [value] * len(index) - for i, v in zip(index, value): - self[i] = v - - def append(self, value): - # Find the next available slot - index = len(self._compressed_data) - self[index] = value - - class CompressedListStorage(ListStorage): """A storage that compresses and decompresses data. @@ -1523,7 +1457,7 @@ def __init__( self.decompression_fn = decompression_fn # Store compressed data and metadata - self._compressed_data = [] + self._storage = [] self._metadata = [] # Store shape, dtype, device info for each item def _default_compression_fn(self, tensor: torch.Tensor) -> torch.Tensor: @@ -1574,54 +1508,6 @@ def _default_decompression_fn( return tensor - @property - def _storage(self): - """Virtual storage that handles compression/decompression transparently.""" - return CompressedStorageView(self._compressed_data, self._metadata, self) - - @_storage.setter - def _storage(self, value): - # This allows ListStorage.__init__ to set _storage initially - if hasattr(self, "_compressed_data"): - # If we already have compressed storage, ignore attempts to set _storage - pass - else: - # During initialization, create our compressed storage - self._compressed_data = [] - self._metadata = [] - - def __len__(self): - """Return the number of items in the storage.""" - return len([item for item in self._compressed_data if item is not None]) - - def _empty(self): - """Empty the storage.""" - self._compressed_data = [] - self._metadata = [] - - def state_dict(self) -> dict[str, Any]: - """Save the storage state.""" - return { - "_compressed_data": self._compressed_data, - "_metadata": self._metadata, - } - - def load_state_dict(self, state_dict: dict[str, Any]) -> None: - """Load the storage state.""" - self._compressed_data = state_dict["_compressed_data"] - self._metadata = state_dict["_metadata"] - - def contains(self, item): - """Check if an item is in the storage.""" - if isinstance(item, int): - if item < 0: - item += len(self._compressed_data) - return ( - 0 <= item < len(self._compressed_data) - and self._compressed_data[item] is not None - ) - raise NotImplementedError(f"type {type(item)} is not supported yet.") - def _compress_item(self, item: Any) -> tuple[torch.Tensor, dict]: """Compress a single item and return compressed data with metadata.""" if isinstance(item, torch.Tensor): @@ -1659,29 +1545,6 @@ def _compress_item(self, item: Any) -> tuple[torch.Tensor, dict]: return compressed, metadata - # def get(self, index: int | Sequence[int] | slice) -> Any: - # """Get data from the storage with decompression.""" - # if isinstance(index, (INT_CLASSES, slice)): - # indices = [index] - # else: - # indices = index - - # results = [] - # for idx in indices: - # if idx >= len(self._compressed_data) or self._compressed_data[idx] is None: - # raise IndexError(f"Index {idx} out of bounds or not set") - - # compressed_data = self._compressed_data[idx] - # metadata = self._metadata[idx] - - # # Decompress the data - # decompressed = self._decompress_item(compressed_data, metadata) - # results.append(decompressed) - - # if len(results) == 1: - # return results[0] - # return results - def _decompress_item(self, compressed_data: Any, metadata: dict) -> Any: """Decompress a single item using its metadata.""" if metadata["type"] == "tensor": @@ -1704,49 +1567,83 @@ def _decompress_item(self, compressed_data: Any, metadata: dict) -> Any: # Return as-is for other types return metadata["value"] - # def __len__(self): - # """Return the number of items in the storage.""" - # return len([item for item in self._compressed_data if item is not None]) - - # def state_dict(self) -> dict[str, Any]: - # """Save the storage state.""" - # return { - # "_compressed_data": self._compressed_data, - # "_metadata": self._metadata, - # } - - # def load_state_dict(self, state_dict: dict[str, Any]) -> None: - # """Load the storage state.""" - # self._compressed_data = state_dict["_compressed_data"] - # self._metadata = state_dict["_metadata"] - - # def _empty(self): - # """Empty the storage.""" - # self._compressed_data = [] - # self._metadata = [] - - # def loads(self, path: str): - # """Load the storage state from a file.""" - # super().loads(path) - # from tensordict.utils import _STR_DTYPE_TO_DTYPE - # from torch.utils._pytree import tree_flatten, tree_unflatten - - # leaves, spec = tree_flatten(self._metadata) - # leaves = [ - # _STR_DTYPE_TO_DTYPE.get(x, x) if isinstance(x, str) else x for x in leaves - # ] - # self._metadata = tree_unflatten(leaves, spec) - - # def contains(self, item): - # """Check if an item is in the storage.""" - # if isinstance(item, int): - # if item < 0: - # item += len(self._compressed_data) - # return ( - # 0 <= item < len(self._compressed_data) - # and self._compressed_data[item] is not None - # ) - # raise NotImplementedError(f"type {type(item)} is not supported yet.") + def _set_item(self, cursor: int, data: Any) -> None: + """Set a single item in the compressed storage.""" + # Ensure we have enough space + while len(self._storage) <= cursor: + self._storage.append(None) + self._metadata.append(None) + + # Compress and store + compressed_data, metadata = self._compress_item(data) + self._storage[cursor] = compressed_data + self._metadata[cursor] = metadata + + def _set_slice(self, cursor: slice, data: Any) -> None: + """Set a slice in the compressed storage.""" + # Handle slice assignment + if not hasattr(data, "__iter__"): + data = [data] + start, stop, step = cursor.indices(len(self._storage)) + indices = list(range(start, stop, step)) + + for i, value in zip(indices, data): + self._set_item(i, value) + + def _get_item(self, index: int) -> Any: + """Get a single item from the compressed storage.""" + if index >= len(self._storage) or self._storage[index] is None: + raise IndexError(f"Index {index} out of bounds or not set") + + compressed_data = self._storage[index] + metadata = self._metadata[index] + return self._decompress_item(compressed_data, metadata) + + def _get_slice(self, index: slice) -> list: + """Get a slice from the compressed storage.""" + start, stop, step = index.indices(len(self._storage)) + results = [] + for i in range(start, stop, step): + if i < len(self._storage) and self._storage[i] is not None: + results.append(self._get_item(i)) + return results + + def _get_list(self, index: list) -> list: + """Get a list of items from the compressed storage.""" + if isinstance(index, torch.Tensor) and index.device.type != "cpu": + index = index.cpu().tolist() + + results = [] + for i in index: + if i >= len(self._storage) or self._storage[i] is None: + raise IndexError(f"Index {i} out of bounds or not set") + results.append(self._get_item(i)) + return results + + def __len__(self) -> int: + """Get the length of the compressed storage.""" + return len([item for item in self._storage if item is not None]) + + def _contains_int(self, item: int) -> bool: + """Check if an integer index is contained in the compressed storage.""" + return 0 <= item < len(self._storage) and self._storage[item] is not None + + def _empty(self): + """Empty the storage.""" + self._storage = [] + self._metadata = [] + + def state_dict(self) -> dict[str, Any]: + """Save the storage state.""" + return { + "_storage": self._storage, + "_metadata": self._metadata, + } + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Load the storage state.""" + self._storage = state_dict["_storage"] + self._metadata = state_dict["_metadata"] def bytes(self): """Return the number of bytes in the storage.""" @@ -1763,9 +1660,9 @@ def compressed_size_from_list(data: Any) -> int: else: return 0 - compressed_size_estimate = compressed_size_from_list(self._compressed_data) + compressed_size_estimate = compressed_size_from_list(self._storage) if compressed_size_estimate == 0: - if len(self._compressed_data) > 0: + if len(self._storage) > 0: raise RuntimeError( "Compressed storage is not empty but the compressed size is 0. This is a bug." ) From 5581cf6dc7453fde02e4f7420f6f66706e55f7bd Mon Sep 17 00:00:00 2001 From: Adrian Orenstein Date: Sun, 13 Jul 2025 15:51:47 -0600 Subject: [PATCH 5/8] Using python's default compressor. Created to_bytestream. Created a to_bytestream speed test. --- .../unittest/linux/scripts/environment.yml | 2 +- test/test_rb.py | 143 ++++++++++ torchrl/data/replay_buffers/checkpointers.py | 265 +++++++++++------- torchrl/data/replay_buffers/storages.py | 94 +++++-- 4 files changed, 381 insertions(+), 123 deletions(-) diff --git a/.github/unittest/linux/scripts/environment.yml b/.github/unittest/linux/scripts/environment.yml index b7ca29ff0d4..d1debbdd510 100644 --- a/.github/unittest/linux/scripts/environment.yml +++ b/.github/unittest/linux/scripts/environment.yml @@ -35,4 +35,4 @@ dependencies: - transformers - ninja - timm - - zstandard + - safetensors diff --git a/test/test_rb.py b/test/test_rb.py index bf6c2cc9fbd..ad6d0411923 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -4210,6 +4210,18 @@ def test_compressed_storage_checkpointing(self): ) storage.set(0, test_td) + # second batch, different shape + test_td2 = TensorDict( + { + "obs": torch.randn(3, 85, 83, dtype=torch.float32), + "action": torch.tensor([1, 2, 3]), + "meta": torch.randn(3), + "astring": "a string!", + }, + batch_size=[3], + ) + storage.set(1, test_td) + # Create temporary directory for checkpointing with tempfile.TemporaryDirectory() as tmpdir: checkpoint_path = Path(tmpdir) / "checkpoint" @@ -4331,6 +4343,137 @@ def test_compressed_storage_memory_efficiency(self): compression_ratio > 1.5 ), f"Compression ratio {compression_ratio} is too low" + @staticmethod + def make_compressible_mock_data(num_experiences: int, device=None) -> dict: + """Easily compressible data for testing.""" + if device is None: + device = torch.device("cpu") + + return { + "observations": torch.zeros( + (num_experiences, 4, 84, 84), + dtype=torch.uint8, + device=device, + ), + "actions": torch.zeros((num_experiences,), device=device), + "rewards": torch.zeros((num_experiences,), device=device), + "next_observations": torch.zeros( + (num_experiences, 4, 84, 84), + dtype=torch.uint8, + device=device, + ), + "terminations": torch.zeros( + (num_experiences,), dtype=torch.bool, device=device + ), + "truncations": torch.zeros( + (num_experiences,), dtype=torch.bool, device=device + ), + "batch_size": [num_experiences], + } + + @staticmethod + def make_uncompressible_mock_data(num_experiences: int, device=None) -> dict: + """Uncompressible data for testing.""" + if device is None: + device = torch.device("cpu") + return { + "observations": torch.randn( + (num_experiences, 4, 84, 84), + dtype=torch.float32, + device=device, + ), + "actions": torch.randint(0, 10, (num_experiences,), device=device), + "rewards": torch.randn( + (num_experiences,), dtype=torch.float32, device=device + ), + "next_observations": torch.randn( + (num_experiences, 4, 84, 84), + dtype=torch.float32, + device=device, + ), + "terminations": torch.rand((num_experiences,), device=device) + < 0.2, # ~20% True + "truncations": torch.rand((num_experiences,), device=device) + < 0.1, # ~10% True + "batch_size": [num_experiences], + } + + @pytest.mark.benchmark( + group="tensor_serialization_speed", + min_time=0.1, + max_time=0.5, + min_rounds=5, + disable_gc=True, + warmup=False, + ) + @pytest.mark.parametrize( + "serialization_method", + ["pickle", "torch.save", "untyped_storage", "numpy", "safetensors"], + ) + def test_tensor_to_bytestream_speed(self, benchmark, serialization_method: str): + """Benchmark the speed of different tensor serialization methods. + + TODO: we might need to also test which methods work on the gpu. + pytest test/test_rb.py::TestCompressedListStorage::test_tensor_to_bytestream_speed -v --benchmark-only --benchmark-sort='mean' --benchmark-columns='mean, ops' + + ------------------------ benchmark 'tensor_to_bytestream_speed': 5 tests ------------------------- + Name (time in us) Mean (smaller is better) OPS (bigger is better) + -------------------------------------------------------------------------------------------------- + test_tensor_serialization_speed[numpy] 2.3520 (1.0) 425,162.1779 (1.0) + test_tensor_serialization_speed[safetensors] 14.7170 (6.26) 67,948.7129 (0.16) + test_tensor_serialization_speed[pickle] 19.0711 (8.11) 52,435.3333 (0.12) + test_tensor_serialization_speed[torch.save] 32.0648 (13.63) 31,186.8261 (0.07) + test_tensor_serialization_speed[untyped_storage] 38,227.0224 (>1000.0) 26.1595 (0.00) + -------------------------------------------------------------------------------------------------- + """ + import io + import pickle + + import torch + from safetensors.torch import save + + def serialize_with_pickle(data: torch.Tensor) -> bytes: + """Serialize tensor using pickle.""" + buffer = io.BytesIO() + pickle.dump(data, buffer) + return buffer.getvalue() + + def serialize_with_untyped_storage(data: torch.Tensor) -> bytes: + """Serialize tensor using torch's built-in method.""" + return bytes(data.untyped_storage()) + + def serialize_with_numpy(data: torch.Tensor) -> bytes: + """Serialize tensor using numpy.""" + return data.numpy().tobytes() + + def serialize_with_safetensors(data: torch.Tensor) -> bytes: + return save({"0": data}) + + def serialize_with_torch(data: torch.Tensor) -> bytes: + """Serialize tensor using torch's built-in method.""" + buffer = io.BytesIO() + torch.save(data, buffer) + return buffer.getvalue() + + # Benchmark each serialization method + if serialization_method == "pickle": + serialize_fn = serialize_with_pickle + elif serialization_method == "torch.save": + serialize_fn = serialize_with_torch + elif serialization_method == "untyped_storage": + serialize_fn = serialize_with_untyped_storage + elif serialization_method == "numpy": + serialize_fn = serialize_with_numpy + elif serialization_method == "safetensors": + serialize_fn = serialize_with_safetensors + else: + raise ValueError(f"Unknown serialization method: {serialization_method}") + + data = self.make_compressible_mock_data(1).get("observations") + + # Run the actual benchmark + benchmark(serialize_fn, data) + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/data/replay_buffers/checkpointers.py b/torchrl/data/replay_buffers/checkpointers.py index 71c525d7963..267ab362e79 100644 --- a/torchrl/data/replay_buffers/checkpointers.py +++ b/torchrl/data/replay_buffers/checkpointers.py @@ -6,6 +6,7 @@ import abc import json +import tempfile import warnings from pathlib import Path @@ -13,11 +14,14 @@ import torch from tensordict import ( is_tensor_collection, + lazy_stack, NonTensorData, PersistentTensorDict, TensorDict, ) from tensordict.memmap import MemoryMappedTensor +from tensordict.utils import _zip_strict +from torch.utils._pytree import tree_map from torchrl._utils import _STRDTYPE2DTYPE from torchrl.data.replay_buffers.utils import ( @@ -71,11 +75,18 @@ def loads(storage, path): class CompressedListStorageCheckpointer(StorageCheckpointerBase): """A storage checkpointer for CompressedListStorage. - This checkpointer saves compressed data and metadata separately for efficient storage. + This checkpointer saves compressed data and metadata using memory-mapped storage + for efficient disk I/O and memory usage. """ def dumps(self, storage, path): + """Save compressed storage to disk using memory-mapped storage. + + Args: + storage: The CompressedListStorage instance to save + path: Directory path where to save the storage + """ path = Path(path) path.mkdir(exist_ok=True) @@ -84,120 +95,182 @@ def dumps(self, storage, path): "Cannot save an empty or non-initialized CompressedListStorage." ) - # Save compressed data and metadata + # Get state dict from storage state_dict = storage.state_dict() - - # Save compressed data and metadata compressed_data = state_dict["_storage"] metadata = state_dict["_metadata"] - # Save metadata - with open(path / "compressed_metadata.json", "w") as f: - json.dump(metadata, f, default=str) - - # Save compressed data - for i, (compressed_item, item_metadata) in enumerate( - zip(compressed_data, metadata) - ): - if compressed_item is not None: - if item_metadata["type"] == "tensor": - # Save as numpy array - np.save( - path / f"compressed_data_{i}.npy", compressed_item.cpu().numpy() - ) - elif item_metadata["type"] == "tensordict": - # Save each field separately - item_dir = path / f"compressed_data_{i}.td" - item_dir.mkdir(exist_ok=True) - - for key, value in compressed_item.items(): - if isinstance(value, torch.Tensor): - np.save(item_dir / f"{key}.npy", value.cpu().numpy()) - else: - # Save non-tensor data as pickle - import pickle - - with open(item_dir / f"{key}.pkl", "wb") as f: - pickle.dump(value, f) + # Create a temporary directory for processing + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Process compressed data for memmap storage + processed_data = [] + for item in compressed_data: + if item is None: + processed_data.append(None) + continue + + if isinstance(item, torch.Tensor): + # For tensor data, create a TensorDict with the tensor + processed_item = TensorDict({"data": item}, batch_size=[]) + elif isinstance(item, dict): + # For dict data (tensordict fields), convert to TensorDict + processed_item = TensorDict(item, batch_size=[]) else: - # Save other types as pickle - import pickle + # For other types, wrap in TensorDict + processed_item = TensorDict({"data": item}, batch_size=[]) + + processed_data.append(processed_item) + + # Stack all non-None items into a single TensorDict for memmap + non_none_data = [item for item in processed_data if item is not None] + if non_none_data: + # Use lazy_stack to handle heterogeneous structures + stacked_data = lazy_stack(non_none_data) + + # Save to memmap + stacked_data.memmap_(tmp_path / "compressed_data") + + # Create index mapping for None values + data_indices = [] + current_idx = 0 + for item in processed_data: + if item is None: + data_indices.append(None) + else: + data_indices.append(current_idx) + current_idx += 1 + else: + # No data to save + data_indices = [] + + # Process metadata for JSON serialization + def is_leaf(item): + return isinstance( + item, + ( + torch.Size, + torch.dtype, + torch.device, + str, + int, + float, + bool, + torch.Tensor, + NonTensorData, + ), + ) + + def map_to_json_serializable(item): + if isinstance(item, torch.Size): + return {"__type__": "torch.Size", "value": list(item)} + elif isinstance(item, torch.dtype): + return {"__type__": "torch.dtype", "value": str(item)} + elif isinstance(item, torch.device): + return {"__type__": "torch.device", "value": str(item)} + elif isinstance(item, torch.Tensor): + return {"__type__": "torch.Tensor", "value": item.tolist()} + elif isinstance(item, NonTensorData): + return {"__type__": "NonTensorData", "value": item.data} + return item + + serializable_metadata = tree_map( + map_to_json_serializable, metadata, is_leaf=is_leaf + ) + + # Save metadata and indices + metadata_file = tmp_path / "metadata.json" + with open(metadata_file, "w") as f: + json.dump(serializable_metadata, f, indent=2) + + indices_file = tmp_path / "data_indices.json" + with open(indices_file, "w") as f: + json.dump(data_indices, f, indent=2) - with open(path / f"compressed_data_{i}.pkl", "wb") as f: - pickle.dump(compressed_item, f) + # Copy all files from temp directory to final destination + import shutil + + for item in tmp_path.iterdir(): + if item.is_file(): + shutil.copy2(item, path / item.name) + elif item.is_dir(): + shutil.copytree(item, path / item.name, dirs_exist_ok=True) def loads(self, storage, path): + """Load compressed storage from disk. + + Args: + storage: The CompressedListStorage instance to load into + path: Directory path where the storage was saved + """ path = Path(path) # Load metadata - with open(path / "compressed_metadata.json") as f: - metadata = json.load(f) - - # Convert string dtypes back to torch.dtype objects - def convert_dtype(item): - if isinstance(item, dict): - if "dtype" in item and isinstance(item["dtype"], str): - # Convert string back to torch.dtype - dtype_str = item["dtype"] + metadata_file = path / "metadata.json" + if not metadata_file.exists(): + raise FileNotFoundError(f"Metadata file not found at {metadata_file}") + + with open(metadata_file) as f: + serializable_metadata = json.load(f) + + # Load data indices + indices_file = path / "data_indices.json" + if not indices_file.exists(): + raise FileNotFoundError(f"Data indices file not found at {indices_file}") + + with open(indices_file) as f: + data_indices = json.load(f) + + # Convert serializable metadata back to original format + def is_leaf(item): + return isinstance(item, dict) and "__type__" in item + + def map_from_json_serializable(item): + if isinstance(item, dict) and "__type__" in item: + if item["__type__"] == "torch.Size": + return torch.Size(item["value"]) + elif item["__type__"] == "torch.dtype": + # Handle torch.dtype conversion + dtype_str = item["value"] if hasattr(torch, dtype_str.replace("torch.", "")): - item["dtype"] = getattr(torch, dtype_str.replace("torch.", "")) + return getattr(torch, dtype_str.replace("torch.", "")) else: # Handle cases like 'torch.float32' -> torch.float32 - item["dtype"] = eval(dtype_str) - - # Recursively handle nested dictionaries - for _key, value in item.items(): - if isinstance(value, dict): - convert_dtype(value) + return eval(dtype_str) + elif item["__type__"] == "torch.device": + return torch.device(item["value"]) + elif item["__type__"] == "torch.Tensor": + return torch.tensor(item["value"]) + elif item["__type__"] == "NonTensorData": + return NonTensorData(item["value"]) return item - for item in metadata: - if item is not None: - convert_dtype(item) + metadata = tree_map( + map_from_json_serializable, serializable_metadata, is_leaf=is_leaf + ) - # Load compressed data + # Load compressed data from memmap compressed_data = [] - i = 0 - - # TODO(adrian): Can we not know the serialised format beforehand? Then we can use glob to iterate through the files we know exist: - # `for path in glob.glob(path / f"compressed_data_*.{fmt}" for fmt in ["npy", "pkl", "td"]):`` - while True: - if (path / f"compressed_data_{i}.npy").exists(): - # Load tensor data - data = np.load(path / f"compressed_data_{i}.npy") - compressed_data.append(torch.from_numpy(data)) - elif (path / f"compressed_data_{i}.pkl").exists(): - # Load other data - import pickle - - with open(path / f"compressed_data_{i}.pkl", "rb") as f: - data = pickle.load(f) - compressed_data.append(data) - elif (path / f"compressed_data_{i}.td").exists(): - # Load tensordict data - item_dir = path / f"compressed_data_{i}.td" - item_data = {} - - for key in metadata[i]["fields"].keys(): - if (item_dir / f"{key}.npy").exists(): - data = np.load(item_dir / f"{key}.npy") - item_data[key] = torch.from_numpy(data) - elif (item_dir / f"{key}.pkl").exists(): - import pickle - - with open(item_dir / f"{key}.pkl", "rb") as f: - data = pickle.load(f) - item_data[key] = data - - compressed_data.append(item_data) - else: - break - - i += 1 + memmap_path = path / "compressed_data" + + if memmap_path.exists(): + # Load the memmapped data + stacked_data = TensorDict.load_memmap(memmap_path) + compressed_data = stacked_data.tolist() + if len(compressed_data) != len(data_indices): + raise ValueError( + f"Length of compressed data ({len(compressed_data)}) does not match length of data indices ({len(data_indices)})" + ) + for i, (data, mtdt) in enumerate(_zip_strict(compressed_data, metadata)): + if mtdt["type"] == "tensor": + compressed_data[i] = data["data"] + else: + compressed_data[i] = data - # Pad with None to match metadata length - while len(compressed_data) < len(metadata): - compressed_data.append(None) + else: + # No data to load + compressed_data = [None] * len(data_indices) # Load into storage storage._storage = compressed_data diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 6e0a904460d..def5c64bd3e 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -7,6 +7,7 @@ import abc import logging import os +import sys import textwrap import warnings from collections import OrderedDict @@ -303,10 +304,10 @@ def set( ) return else: - if cursor > len(self): + if cursor > len(self._storage): raise RuntimeError( "Cannot append data located more than one item away from " - f"the storage size: the storage size is {len(self)} " + f"the storage size: the storage size is {len(self._storage)} " f"and the index of the item to be set is {cursor}." ) if cursor >= self.max_size: @@ -402,7 +403,7 @@ def __repr__(self): def contains(self, item): if isinstance(item, int): if item < 0: - item += len(self) + item += len(self._storage) return self._contains_int(item) if isinstance(item, torch.Tensor): return torch.tensor( @@ -414,7 +415,7 @@ def contains(self, item): def _contains_int(self, item: int) -> bool: """Check if an integer index is contained in the storage.""" - return 0 <= item < len(self) + return 0 <= item < len(self._storage) class LazyStackStorage(ListStorage): @@ -1462,42 +1463,44 @@ def __init__( def _default_compression_fn(self, tensor: torch.Tensor) -> torch.Tensor: """Default compression using zstd.""" - try: - import zstandard as zstd - except ImportError: - raise ImportError( - "zstandard is required for default compression. " - "Install with: pip install zstandard" - ) + if sys.version_info >= (3, 14): + from compression import zstd + + compressor_fn = zstd.compress + + else: + import zlib + + compressor_fn = zlib.compress # Convert tensor to bytes - tensor_bytes = tensor.cpu().numpy().tobytes() + tensor_bytes = self.to_bytestream(tensor) # Compress with zstd - compressor = zstd.ZstdCompressor(level=self.compression_level) - compressed_bytes = compressor.compress(tensor_bytes) + compressed_bytes = compressor_fn(tensor_bytes, level=self.compression_level) # Convert to tensor - return torch.tensor(list(compressed_bytes), dtype=torch.uint8) + return torch.frombuffer(bytearray(compressed_bytes), dtype=torch.uint8) def _default_decompression_fn( self, compressed_tensor: torch.Tensor, metadata: dict ) -> torch.Tensor: """Default decompression using zstd.""" - try: - import zstandard as zstd - except ImportError: - raise ImportError( - "zstandard is required for default decompression. " - "Install with: pip install zstandard" - ) + if sys.version_info >= (3, 14): + from compression import zstd + + decompressor_fn = zstd.decompress + + else: + import zlib + + decompressor_fn = zlib.decompress # Convert tensor to bytes - compressed_bytes = bytes(compressed_tensor.cpu().numpy()) + compressed_bytes = self.to_bytestream(compressed_tensor.cpu()) # Decompress with zstd - decompressor = zstd.ZstdDecompressor() - decompressed_bytes = decompressor.decompress(compressed_bytes) + decompressed_bytes = decompressor_fn(compressed_bytes) # Convert back to tensor tensor = torch.frombuffer( @@ -1645,7 +1648,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: self._storage = state_dict["_storage"] self._metadata = state_dict["_metadata"] - def bytes(self): + def num_bytes(self): """Return the number of bytes in the storage.""" def compressed_size_from_list(data: Any) -> int: @@ -1670,6 +1673,45 @@ def compressed_size_from_list(data: Any) -> int: return compressed_size_estimate + def to_bytestream(self, data_to_bytestream: torch.Tensor | np.array | Any) -> bytes: + """Convert data to a byte stream.""" + if isinstance(data_to_bytestream, torch.Tensor): + byte_stream = data_to_bytestream.cpu().numpy().tobytes() + + elif isinstance(data_to_bytestream, np.array): + byte_stream = bytes(data_to_bytestream.tobytes()) + + else: + import io + import pickle + + buffer = io.BytesIO() + pickle.dump(data_to_bytestream, buffer) + buffer.seek(0) + byte_stream = bytes(buffer.read()) + + return byte_stream + + # def to_bytestream( + # self, data_to_bytestream: Union[torch.Tensor, np.array, Any] + # ) -> bytes: + # """Convert data to a byte stream.""" + # if isinstance(data_to_bytestream, torch.Tensor): + # from safetensors.torch import save + # byte_stream = save({"0": data_to_bytestream}) + + # elif isinstance(data_to_bytestream, np.array): + # from safetensors.numpy import save + # byte_stream = bytes(data_to_bytestream.tobytes()) + + # else: + # buffer = io.BytesIO() + # pickle.dump(data_to_bytestream, buffer) + # buffer.seek(0) + # byte_stream = bytes(buffer.read()) + + # return byte_stream + class StorageEnsemble(Storage): """An ensemble of storages. From 77059a0bfd56b09eb109daadeb9d53d3c70a6bfb Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 16 Jul 2025 23:47:11 +0100 Subject: [PATCH 6/8] add-examples --- .../compressed_replay_buffer.py | 143 ++++++++++++++++ .../compressed_replay_buffer_checkpoint.py | 156 ++++++++++++++++++ 2 files changed, 299 insertions(+) create mode 100644 examples/replay-buffers/compressed_replay_buffer.py create mode 100644 examples/replay-buffers/compressed_replay_buffer_checkpoint.py diff --git a/examples/replay-buffers/compressed_replay_buffer.py b/examples/replay-buffers/compressed_replay_buffer.py new file mode 100644 index 00000000000..3e68e76468a --- /dev/null +++ b/examples/replay-buffers/compressed_replay_buffer.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 +""" +Example demonstrating the use of CompressedStorage for memory-efficient replay buffers. + +This example shows how to use the new CompressedStorage to store image observations +with significant memory savings through compression. +""" + +import time + +import torch +from tensordict import TensorDict +from torchrl.data import CompressedStorage, ReplayBuffer + + +def main(): + print("=== Compressed Replay Buffer Example ===\n") + + # Create a compressed storage with zstd compression + print("Creating compressed storage...") + storage = CompressedStorage( + max_size=1000, + compression_level=3, # zstd compression level (1-22) + device="cpu", + ) + + # Create replay buffer with compressed storage + rb = ReplayBuffer(storage=storage, batch_size=32) + + # Simulate Atari-like image data (84x84 RGB frames) + print("Generating sample image data...") + num_frames = 100 + image_data = torch.zeros(num_frames, 3, 84, 84, dtype=torch.float32) + image_data.copy_( + torch.arange(num_frames * 3 * 84 * 84).reshape(num_frames, 3, 84, 84) + // (3 * 84 * 84) + ) + + # Create TensorDict with image observations + data = TensorDict( + { + "obs": image_data, + "action": torch.randint(0, 4, (num_frames,)), # 4 possible actions + "reward": torch.randn(num_frames), + "done": torch.randint(0, 2, (num_frames,), dtype=torch.bool), + }, + batch_size=[num_frames], + ) + + # Measure memory usage before adding data + print(f"Original data size: {data.bytes() / 1024 / 1024:.2f} MB") + + # Add data to replay buffer + print("Adding data to replay buffer...") + start_time = time.time() + rb.extend(data) + add_time = time.time() - start_time + print(f"Time to add data: {add_time:.3f} seconds") + + # Sample from replay buffer + print("Sampling from replay buffer...") + start_time = time.time() + sample = rb.sample(32) + sample_time = time.time() - start_time + print(f"Time to sample: {sample_time:.3f} seconds") + + # Verify data integrity + print("\nVerifying data integrity...") + original_shape = image_data.shape + sampled_shape = sample["obs"].shape + print(f"Original shape: {original_shape}") + print(f"Sampled shape: {sampled_shape}") + + # Check that shapes match (accounting for batch size) + assert sampled_shape[1:] == original_shape[1:], "Shape mismatch!" + print("āœ“ Data integrity verified!") + + # Demonstrate compression ratio + print("\n=== Compression Analysis ===") + + # Estimate compressed size (this is approximate) + compressed_size_estimate = storage.bytes() + + original_size = data.bytes() + compression_ratio = ( + original_size / compressed_size_estimate if compressed_size_estimate > 0 else 1 + ) + + print(f"Original size: {original_size / 1024 / 1024:.2f} MB") + print( + f"Compressed size (estimate): {compressed_size_estimate / 1024 / 1024:.2f} MB" + ) + print(f"Compression ratio: {compression_ratio:.1f}x") + + # Test with different compression levels + print("\n=== Testing Different Compression Levels ===") + + for level in [1, 3, 6, 9]: + print(f"\nTesting compression level {level}...") + + # Create new storage with different compression level + test_storage = CompressedStorage( + max_size=100, compression_level=level, device="cpu" + ) + + # Test with a smaller dataset + N = 100 + obs = torch.zeros(N, 3, 84, 84, dtype=torch.float32) + obs.copy_(torch.arange(N * 3 * 84 * 84).reshape(N, 3, 84, 84) // (3 * 84 * 84)) + test_data = TensorDict( + { + "obs": obs, + }, + batch_size=[N], + ) + + test_rb = ReplayBuffer(storage=test_storage, batch_size=5) + + # Measure compression time + start_time = time.time() + test_rb.extend(test_data) + compress_time = time.time() - start_time + + # Measure decompression time + start_time = time.time() + test_rb.sample(5) + decompress_time = time.time() - start_time + + print(f" Compression time: {compress_time:.3f}s") + print(f" Decompression time: {decompress_time:.3f}s") + + # Estimate compression ratio + test_ratio = test_data.bytes() / test_storage.bytes() + print(f" Compression ratio: {test_ratio:.1f}x") + + print("\n=== Example Complete ===") + print( + "The CompressedStorage successfully reduces memory usage while maintaining data integrity!" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/replay-buffers/compressed_replay_buffer_checkpoint.py b/examples/replay-buffers/compressed_replay_buffer_checkpoint.py new file mode 100644 index 00000000000..9242c2ad4cd --- /dev/null +++ b/examples/replay-buffers/compressed_replay_buffer_checkpoint.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 +""" +Example demonstrating the improved CompressedListStorage with memmap functionality. + +This example shows how to use the new checkpointing capabilities that leverage +memory-mapped storage for efficient disk I/O and memory usage. +""" + +import tempfile +from pathlib import Path + +import torch +from tensordict import TensorDict + +from torchrl.data import CompressedListStorage, ReplayBuffer + + +def main(): + """Demonstrate compressed storage with memmap checkpointing.""" + + # Create a compressed storage with high compression level + storage = CompressedListStorage(max_size=1000, compression_level=6) + + # Create some sample data with different shapes and types + print("Creating sample data...") + + # Add tensor data + tensor_data = torch.randn(10, 3, 84, 84, dtype=torch.float32) # Image-like data + storage.set(0, tensor_data) + + # Add TensorDict data with mixed content + td_data = TensorDict( + { + "obs": torch.randn(5, 4, 84, 84, dtype=torch.float32), + "action": torch.randint(0, 18, (5,), dtype=torch.long), + "reward": torch.randn(5, dtype=torch.float32), + "done": torch.zeros(5, dtype=torch.bool), + "meta": "some metadata string", + }, + batch_size=[5], + ) + storage.set(1, td_data) + + # Add another tensor with different shape + tensor_data2 = torch.randn(8, 64, dtype=torch.float32) + storage.set(2, tensor_data2) + + print(f"Storage length: {len(storage)}") + print(f"Storage contains index 0: {storage.contains(0)}") + print(f"Storage contains index 3: {storage.contains(3)}") + + # Demonstrate data retrieval + print("\nRetrieving data...") + retrieved_tensor = storage.get(0) + retrieved_td = storage.get(1) + retrieved_tensor2 = storage.get(2) + + print(f"Retrieved tensor shape: {retrieved_tensor.shape}") + print(f"Retrieved TensorDict keys: {list(retrieved_td.keys())}") + print(f"Retrieved tensor2 shape: {retrieved_tensor2.shape}") + + # Verify data integrity + assert torch.allclose(tensor_data, retrieved_tensor, atol=1e-6) + assert torch.allclose(td_data["obs"], retrieved_td["obs"], atol=1e-6) + assert torch.allclose(tensor_data2, retrieved_tensor2, atol=1e-6) + print("āœ“ Data integrity verified!") + + # Demonstrate memmap checkpointing + print("\nDemonstrating memmap checkpointing...") + + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_path = Path(tmpdir) / "compressed_storage_checkpoint" + + # Save to disk using memmap + print(f"Saving to {checkpoint_path}...") + storage.dumps(checkpoint_path) + + # Check what files were created + print("Files created:") + for file_path in checkpoint_path.rglob("*"): + if file_path.is_file(): + size_mb = file_path.stat().st_size / (1024 * 1024) + print(f" {file_path.relative_to(checkpoint_path)}: {size_mb:.2f} MB") + + # Create new storage and load from checkpoint + print("\nLoading from checkpoint...") + new_storage = CompressedListStorage(max_size=1000, compression_level=6) + new_storage.loads(checkpoint_path) + + # Verify data integrity after checkpointing + print("Verifying data integrity after checkpointing...") + new_retrieved_tensor = new_storage.get(0) + new_retrieved_td = new_storage.get(1) + new_retrieved_tensor2 = new_storage.get(2) + + assert torch.allclose(tensor_data, new_retrieved_tensor, atol=1e-6) + assert torch.allclose(td_data["obs"], new_retrieved_td["obs"], atol=1e-6) + assert torch.allclose(tensor_data2, new_retrieved_tensor2, atol=1e-6) + print("āœ“ Data integrity after checkpointing verified!") + + print(f"New storage length: {len(new_storage)}") + + # Demonstrate with ReplayBuffer + print("\nDemonstrating with ReplayBuffer...") + + rb = ReplayBuffer(storage=CompressedListStorage(max_size=100, compression_level=4)) + + # Add some data to the replay buffer + for _ in range(5): + data = TensorDict( + { + "obs": torch.randn(3, 84, 84, dtype=torch.float32), + "action": torch.randint(0, 18, (3,), dtype=torch.long), + "reward": torch.randn(3, dtype=torch.float32), + }, + batch_size=[3], + ) + rb.extend(data) + + print(f"ReplayBuffer size: {len(rb)}") + + # Sample from the buffer + sample = rb.sample(2) + print(f"Sampled data keys: {list(sample.keys())}") + print(f"Sampled obs shape: {sample['obs'].shape}") + + # Checkpoint the replay buffer + with tempfile.TemporaryDirectory() as tmpdir: + rb_checkpoint_path = Path(tmpdir) / "rb_checkpoint" + print(f"\nCheckpointing ReplayBuffer to {rb_checkpoint_path}...") + rb.dumps(rb_checkpoint_path) + + # Create new replay buffer and load + new_rb = ReplayBuffer( + storage=CompressedListStorage(max_size=100, compression_level=4) + ) + new_rb.loads(rb_checkpoint_path) + + print(f"New ReplayBuffer size: {len(new_rb)}") + + # Verify sampling works + new_sample = new_rb.sample(2) + assert new_sample["obs"].shape == sample["obs"].shape + print("āœ“ ReplayBuffer checkpointing verified!") + + print("\nšŸŽ‰ All demonstrations completed successfully!") + print("\nKey benefits of the new memmap implementation:") + print("1. Efficient disk I/O using memory-mapped storage") + print("2. Reduced memory usage for large datasets") + print("3. Fast loading and saving of compressed data") + print("4. Support for heterogeneous data structures") + print("5. Seamless integration with ReplayBuffer") + + +if __name__ == "__main__": + main() From 6f972907390ffd209aacb6a8c711b5d9413c2c01 Mon Sep 17 00:00:00 2001 From: Adrian Orenstein Date: Thu, 17 Jul 2025 14:17:27 -0600 Subject: [PATCH 7/8] zstandard no longer needed for default compressor function --- test/test_rb.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/test_rb.py b/test/test_rb.py index ad6d0411923..02676743bb4 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -133,7 +133,6 @@ _os_is_windows = sys.platform == "win32" _has_transformers = importlib.util.find_spec("transformers") is not None _has_ray = importlib.util.find_spec("ray") is not None -_has_zstandard = importlib.util.find_spec("zstandard") is not None TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) @@ -4061,7 +4060,6 @@ def test_ray_rb_iter(self): rb.close() -@pytest.mark.skipif(not _has_zstandard, reason="zstandard required for this test.") class TestCompressedListStorage: """Test cases for CompressedListStorage.""" From 0dbd2335eb35a0542725d6bd07bd04429f42dd13 Mon Sep 17 00:00:00 2001 From: Adrian Orenstein Date: Wed, 23 Jul 2025 15:54:39 -0600 Subject: [PATCH 8/8] added a rollout example for either cpu and gpu compression with batched decompression with nvcomp. --- ...ssed_cpu_decompressed_gpu_replay_buffer.py | 199 ++++++++++++++++++ ...ssed_gpu_decompressed_gpu_replay_buffer.py | 182 ++++++++++++++++ .../compressed_replay_buffer.py | 24 +-- setup.cfg | 2 +- test/test_rb.py | 13 +- torchrl/data/replay_buffers/storages.py | 47 +---- 6 files changed, 407 insertions(+), 60 deletions(-) create mode 100644 examples/replay-buffers/compressed_cpu_decompressed_gpu_replay_buffer.py create mode 100644 examples/replay-buffers/compressed_gpu_decompressed_gpu_replay_buffer.py diff --git a/examples/replay-buffers/compressed_cpu_decompressed_gpu_replay_buffer.py b/examples/replay-buffers/compressed_cpu_decompressed_gpu_replay_buffer.py new file mode 100644 index 00000000000..ae0822c3bb4 --- /dev/null +++ b/examples/replay-buffers/compressed_cpu_decompressed_gpu_replay_buffer.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 +""" +Example demonstrating the use of CompressedStorage for memory-efficient replay buffers on the GPU. +""" + +import sys +from typing import Any, Dict, List, NamedTuple + +import ale_py +import gymnasium as gym + +import numpy as np + +from torchrl.data.replay_buffers.storages import ListStorage + +gym.register_envs(ale_py) +import torch +from torchrl.data import ReplayBuffer + + +class AtariTransition(NamedTuple): + """Transition tuple generated by the Atari gymnasium environment.""" + + observations: np.uint8 + actions: np.uint8 + next_observations: np.uint8 + rewards: np.float32 + terminated: np.bool + truncated: np.bool + info: Dict[str, Any] + + +def main(): + # pip install gymnasium ale-py opencv-python-headless lz4 zstd + import time + + import nvidia.nvcomp as nvcomp + + algo = "Zstd" + bitstream = nvcomp.BitstreamKind.RAW + + if algo == "Zstd": + if sys.version_info >= (3, 14): + from compression import zstd + + else: + import zstd + + compressor_fn = zstd.compress + + elif algo == "LZ4": + import lz4 + + compressor_fn = lz4.compress + + else: + raise ValueError(f"Unsupported algo {algo}") + + # Create Pong environment and get a frame + seed = 42 + env = gym.make("ALE/Pong-v5", frameskip=1) + env = gym.wrappers.AtariPreprocessing(env, frame_skip=5) + env = gym.wrappers.RecordEpisodeStatistics(env) + env = gym.wrappers.TransformReward(env, np.sign) + env = gym.wrappers.FrameStackObservation(env, 4) + env.action_space.seed(seed) + env.observation_space.seed(seed) + + codec = nvcomp.Codec(algorithm=algo, bitstream_kind=bitstream) + + obs, _ = env.reset(seed=0) + compressed_obs: bytes = compressor_fn(obs.tobytes()) + compressed_nv_obs = nvcomp.as_array(compressed_obs).cuda(synchronize=False) + + decompressed_obs = codec.decode(compressed_nv_obs, data_type="|u1") + pt_obs = torch.from_dlpack(decompressed_obs).clone().view(4, 84, 84) + assert np.allclose(obs, pt_obs.cpu().numpy()) + + print("passed correctness checks") + + # === CompressedListStorage + ReplayBuffer with GPU compression === + print( + "\n=== ListStorage + ReplayBuffer (CPU compress, GPU decompress) Example ===\n" + ) + + print("Creating compressed storage...") + storage = ListStorage( + max_size=1000, + device="cuda", + ) + + def collate_compressed_data_and_batch_decompress( + data: List[AtariTransition], + ) -> List[AtariTransition]: + transitions = data + + # gather compressed data + compressed_obs: List[nvcomp.nvcomp_impl.Array] = [ + transition.observations for transition in transitions + ] + compressed_next_obs: List[nvcomp.nvcomp_impl.Array] = [ + transition.next_observations for transition in transitions + ] + + # batched decompress is faster + decompressed_data = codec.decode( + compressed_obs + compressed_next_obs, data_type="|u1" + ) + + # gather decompressed data + decompressed_obses = decompressed_data[: len(compressed_obs)] + decompressed_next_obses = decompressed_data[len(compressed_obs) :] + + # repack data + for i, (transition, obs, next_obs) in enumerate( + zip(transitions, decompressed_obses, decompressed_next_obses) + ): + transitions[i] = transition._replace( + observations=torch.from_dlpack(obs).clone().view(4, 84, 84), + next_observations=torch.from_dlpack(next_obs).clone().view(4, 84, 84), + ) + + return transitions + + rb = ReplayBuffer( + storage=storage, + batch_size=32, + collate_fn=collate_compressed_data_and_batch_decompress, + ) + + print("Starting rollout benchmark") + + compression_ratios = [] + num_transitions_in_rollout = 2000 + + print(f"...adding {num_transitions_in_rollout} transitions to replay buffer") + torch.cuda.synchronize() + + obs, _ = env.reset(seed=0) + nv_obs = nvcomp.as_array(obs).cuda(synchronize=False) + compressed_obs: bytes = compressor_fn(obs.tobytes()) + compressed_nv_obs = nvcomp.as_array(compressed_obs).cuda(synchronize=False) + + start_time = time.time() + for _ in range(num_transitions_in_rollout): + # get the torch observation onto the gpu as we would normally do inference here... + pt_obs = torch.from_numpy(obs).cuda(non_blocking=True) + action = env.action_space.sample() + + next_obs, reward, terminated, truncated, info = env.step(action) + + # replay buffer + compressed_next_obs: bytes = compressor_fn(next_obs.tobytes()) + compressed_nv_next_obs = nvcomp.as_array(compressed_next_obs).cuda( + synchronize=False + ) + + transition = AtariTransition( + compressed_nv_obs, + action, + compressed_nv_next_obs, + reward, + terminated, + truncated, + info, + ) + rb.add(transition) + + # logging + nv_obs = nvcomp.as_array(obs).cuda(synchronize=False) + compression_ratios.append(nv_obs.buffer_size / compressed_nv_obs.buffer_size) + + # reset + if terminated or truncated: + obs, _ = env.reset() + nv_obs = nvcomp.as_array(obs).cuda() + else: + obs: np.ndarray = next_obs + + rollout_time = time.time() - start_time + print( + f"done rollout with {algo} and {bitstream}, " + + f"avg_compression_ratio={np.array(compression_ratios).mean():0.0f} " + + f"@ transitions/s={num_transitions_in_rollout / rollout_time:0.0f}\n" + ) + + print("Sampling from replay buffer...") + batch_size = 32 + torch.cuda.synchronize() + start_time = time.time() + _ = rb.sample(batch_size) + sample_time = time.time() - start_time + print( + f"done batch sampling and decompression with {algo} and {bitstream} @ transitions/s={batch_size / sample_time:0.0f}" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/replay-buffers/compressed_gpu_decompressed_gpu_replay_buffer.py b/examples/replay-buffers/compressed_gpu_decompressed_gpu_replay_buffer.py new file mode 100644 index 00000000000..db8826c8080 --- /dev/null +++ b/examples/replay-buffers/compressed_gpu_decompressed_gpu_replay_buffer.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +""" +Example demonstrating the use of CompressedStorage for memory-efficient replay buffers on the GPU. +""" + +from typing import Any, Dict, List, NamedTuple + +import ale_py +import gymnasium as gym + +import numpy as np + +gym.register_envs(ale_py) +import torch +from torchrl.data import ListStorage, ReplayBuffer + + +class AtariTransition(NamedTuple): + """Transition tuple generated by the Atari gymnasium environment.""" + + observations: np.uint8 + actions: np.uint8 + next_observations: np.uint8 + rewards: np.float32 + terminated: np.bool + truncated: np.bool + info: Dict[str, Any] + + +def main(): + # pip install gymnasium ale-py opencv-python-headless + import time + + import nvidia.nvcomp as nvcomp + + algos = ["Zstd"] # "LZ4", + bitstreams = [ + nvcomp.BitstreamKind.RAW, + ] + + # Create Pong environment and get a frame + seed = 42 + env = gym.make("ALE/Pong-v5", frameskip=1) + env = gym.wrappers.AtariPreprocessing(env, frame_skip=5) + env = gym.wrappers.RecordEpisodeStatistics(env) + env = gym.wrappers.TransformReward(env, np.sign) + env = gym.wrappers.FrameStackObservation(env, 4) + env.action_space.seed(seed) + env.observation_space.seed(seed) + + for algorithm in algos: + for bitstream_kind in bitstreams: + codec = nvcomp.Codec(algorithm=algorithm, bitstream_kind=bitstream_kind) + + obs, _ = env.reset(seed=0) + nv_obs = nvcomp.as_array(obs).cuda(synchronize=False) + compressed_obs = codec.encode(nv_obs) + decompressed_obs = codec.decode(compressed_obs, data_type="|u1") + pt_obs = torch.from_dlpack(decompressed_obs).clone().view(4, 84, 84) + assert np.allclose(obs, pt_obs.cpu().numpy()) + + print("passed correctness checks") + + # === CompressedListStorage + ReplayBuffer with GPU compression === + print("\n=== ListStorage + ReplayBuffer (GPU) Example ===\n") + + codec = nvcomp.Codec(algorithm=algorithm, bitstream_kind=bitstream_kind) + + print("Creating compressed storage...") + storage = ListStorage( + max_size=1000, + device="cuda", + ) + + def collate_compressed_data_and_batch_decompress( + data: List[AtariTransition], + ) -> List[AtariTransition]: + transitions = data + + # gather compressed data + compressed_obs = [transition.observations for transition in transitions] + compressed_next_obs = [ + transition.next_observations for transition in transitions + ] + + # optional checks + assert all(isinstance(arr, nvcomp.nvcomp_impl.Array) for arr in compressed_obs) + assert all( + isinstance(arr, nvcomp.nvcomp_impl.Array) for arr in compressed_next_obs + ) + + # batched decompress is faster + decompressed_data = codec.decode( + compressed_obs + compressed_next_obs, data_type="|u1" + ) + + # gather decompressed data + decompressed_obses = decompressed_data[: len(compressed_obs)] + decompressed_next_obses = decompressed_data[len(compressed_obs) :] + + # repack data + for i, (transition, obs, next_obs) in enumerate( + zip(transitions, decompressed_obses, decompressed_next_obses) + ): + transitions[i] = transition._replace( + observations=torch.from_dlpack(obs).clone().view(4, 84, 84), + next_observations=torch.from_dlpack(next_obs).clone().view(4, 84, 84), + ) + + return transitions + + rb = ReplayBuffer( + storage=storage, + batch_size=32, + collate_fn=collate_compressed_data_and_batch_decompress, + ) + + print("Starting rollout benchmark") + obs, _ = env.reset(seed=0) + nv_obs = nvcomp.as_array(obs).cuda(synchronize=False) + compressed_obs = codec.encode(nv_obs) + + compression_ratios = [] + num_transitions_in_rollout = 2000 + + print(f"...adding {num_transitions_in_rollout} transitions to replay buffer") + torch.cuda.synchronize() + start_time = time.time() + for _ in range(num_transitions_in_rollout): + pt_obs = torch.from_dlpack(nv_obs).clone() + action = env.action_space.sample() + + next_obs, reward, terminated, truncated, info = env.step(action) + nv_next_obs = nvcomp.as_array(next_obs).cuda(synchronize=False) + compressed_next_obs = codec.encode(nv_next_obs) + + compression_ratios.append( + nv_next_obs.buffer_size / compressed_next_obs.buffer_size + ) + + transition = AtariTransition( + compressed_obs, + action, + compressed_next_obs, + reward, + terminated, + truncated, + info, + ) + rb.add(transition) + + if terminated or truncated: + obs, _ = env.reset() + nv_obs = nvcomp.as_array(obs).cuda() + else: + nv_obs = nv_next_obs + compressed_obs = compressed_next_obs + rollout_time = time.time() - start_time + print( + f"done rollout with {algorithm} and {bitstream_kind}, " + + f"avg_compression_ratio={np.array(compression_ratios).mean():0.0f} " + + f"@ transitions/s={num_transitions_in_rollout / rollout_time:0.0f}\n" + ) + + batched_sampling_and_decompression_duration = 1000 + assert (batched_sampling_and_decompression_duration * 2) <= ( + num_transitions_in_rollout + ) + + print("Sampling from replay buffer...") + batch_size = 32 + torch.cuda.synchronize() + start_time = time.time() + rb.sample(batch_size) + sample_time = time.time() - start_time + print( + f"done batch sampling and decompression with {algorithm} and {bitstream_kind} @ transitions/s={batch_size / sample_time:0.0f}" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/replay-buffers/compressed_replay_buffer.py b/examples/replay-buffers/compressed_replay_buffer.py index 3e68e76468a..bc735d67f5f 100644 --- a/examples/replay-buffers/compressed_replay_buffer.py +++ b/examples/replay-buffers/compressed_replay_buffer.py @@ -10,7 +10,7 @@ import torch from tensordict import TensorDict -from torchrl.data import CompressedStorage, ReplayBuffer +from torchrl.data import CompressedListStorage, ReplayBuffer def main(): @@ -18,7 +18,7 @@ def main(): # Create a compressed storage with zstd compression print("Creating compressed storage...") - storage = CompressedStorage( + storage = CompressedListStorage( max_size=1000, compression_level=3, # zstd compression level (1-22) device="cpu", @@ -48,21 +48,21 @@ def main(): ) # Measure memory usage before adding data - print(f"Original data size: {data.bytes() / 1024 / 1024:.2f} MB") + print(f"Original data size: {data.bytes() / 1024 / 1024: .2f} MB") # Add data to replay buffer print("Adding data to replay buffer...") start_time = time.time() rb.extend(data) add_time = time.time() - start_time - print(f"Time to add data: {add_time:.3f} seconds") + print(f"Time to add data: {add_time: .3f} seconds") # Sample from replay buffer print("Sampling from replay buffer...") start_time = time.time() sample = rb.sample(32) sample_time = time.time() - start_time - print(f"Time to sample: {sample_time:.3f} seconds") + print(f"Time to sample: {sample_time: .3f} seconds") # Verify data integrity print("\nVerifying data integrity...") @@ -86,11 +86,11 @@ def main(): original_size / compressed_size_estimate if compressed_size_estimate > 0 else 1 ) - print(f"Original size: {original_size / 1024 / 1024:.2f} MB") + print(f"Original size: {original_size / 1024 / 1024: .2f} MB") print( - f"Compressed size (estimate): {compressed_size_estimate / 1024 / 1024:.2f} MB" + f"Compressed size (estimate): {compressed_size_estimate / 1024 / 1024: .2f} MB" ) - print(f"Compression ratio: {compression_ratio:.1f}x") + print(f"Compression ratio: {compression_ratio: .1f}x") # Test with different compression levels print("\n=== Testing Different Compression Levels ===") @@ -99,7 +99,7 @@ def main(): print(f"\nTesting compression level {level}...") # Create new storage with different compression level - test_storage = CompressedStorage( + test_storage = CompressedListStorage( max_size=100, compression_level=level, device="cpu" ) @@ -126,12 +126,12 @@ def main(): test_rb.sample(5) decompress_time = time.time() - start_time - print(f" Compression time: {compress_time:.3f}s") - print(f" Decompression time: {decompress_time:.3f}s") + print(f" Compression time: {compress_time: .3f}s") + print(f" Decompression time: {decompress_time: .3f}s") # Estimate compression ratio test_ratio = test_data.bytes() / test_storage.bytes() - print(f" Compression ratio: {test_ratio:.1f}x") + print(f" Compression ratio: {test_ratio: .1f}x") print("\n=== Example Complete ===") print( diff --git a/setup.cfg b/setup.cfg index 0649a97497f..98d45350291 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,7 +7,7 @@ max-line-length = 120 [flake8] # note: we ignore all 501s (line too long) anyway as they're taken care of by black max-line-length = 79 -ignore = E203, E402, W503, W504, E501 +ignore = E203, E402, W503, W504, E501, E231 per-file-ignores = __init__.py: F401, F403, F405 ./hubconf.py: F401 diff --git a/test/test_rb.py b/test/test_rb.py index 02676743bb4..3c58d3e4a0e 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -4406,7 +4406,7 @@ def make_uncompressible_mock_data(num_experiences: int, device=None) -> dict: ) @pytest.mark.parametrize( "serialization_method", - ["pickle", "torch.save", "untyped_storage", "numpy", "safetensors"], + ["pickle", "torch.save", "untyped_storage", "numpy"], # "safetensors", "nvcomp" ) def test_tensor_to_bytestream_speed(self, benchmark, serialization_method: str): """Benchmark the speed of different tensor serialization methods. @@ -4453,6 +4453,15 @@ def serialize_with_torch(data: torch.Tensor) -> bytes: torch.save(data, buffer) return buffer.getvalue() + def serialize_with_nvcomp(data: torch.Tensor) -> bytes: + """Nvcomp just wants the data in a NVcomp array format, no bytestream needed. + pip install nvidia-nvcomp-cu12 + + """ + import nvidia.nvcomp as nvcomp + + return nvcomp.from_dlpack(torch.utils.dlpack.to_dlpack(data.cuda())) + # Benchmark each serialization method if serialization_method == "pickle": serialize_fn = serialize_with_pickle @@ -4464,6 +4473,8 @@ def serialize_with_torch(data: torch.Tensor) -> bytes: serialize_fn = serialize_with_numpy elif serialization_method == "safetensors": serialize_fn = serialize_with_safetensors + elif serialization_method == "nvcomp": + serialize_fn = serialize_with_nvcomp else: raise ValueError(f"Unknown serialization method: {serialization_method}") diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index def5c64bd3e..a7eb726a6dc 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -13,7 +13,7 @@ from collections import OrderedDict from copy import copy from multiprocessing.context import get_spawning_popen -from typing import Any, Callable, Mapping, Sequence +from typing import Any, Callable, Sequence import numpy as np import tensordict @@ -1648,31 +1648,6 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: self._storage = state_dict["_storage"] self._metadata = state_dict["_metadata"] - def num_bytes(self): - """Return the number of bytes in the storage.""" - - def compressed_size_from_list(data: Any) -> int: - if data is None: - return 0 - elif isinstance(data, torch.Tensor): - return data.numel() - elif isinstance(data, (tuple, list, Sequence)): - return sum(compressed_size_from_list(item) for item in data) - elif isinstance(data, Mapping) or is_tensor_collection(data): - return sum(compressed_size_from_list(value) for value in data.values()) - else: - return 0 - - compressed_size_estimate = compressed_size_from_list(self._storage) - if compressed_size_estimate == 0: - if len(self._storage) > 0: - raise RuntimeError( - "Compressed storage is not empty but the compressed size is 0. This is a bug." - ) - warnings.warn("Compressed storage is empty, returning 0 bytes.") - - return compressed_size_estimate - def to_bytestream(self, data_to_bytestream: torch.Tensor | np.array | Any) -> bytes: """Convert data to a byte stream.""" if isinstance(data_to_bytestream, torch.Tensor): @@ -1692,26 +1667,6 @@ def to_bytestream(self, data_to_bytestream: torch.Tensor | np.array | Any) -> by return byte_stream - # def to_bytestream( - # self, data_to_bytestream: Union[torch.Tensor, np.array, Any] - # ) -> bytes: - # """Convert data to a byte stream.""" - # if isinstance(data_to_bytestream, torch.Tensor): - # from safetensors.torch import save - # byte_stream = save({"0": data_to_bytestream}) - - # elif isinstance(data_to_bytestream, np.array): - # from safetensors.numpy import save - # byte_stream = bytes(data_to_bytestream.tobytes()) - - # else: - # buffer = io.BytesIO() - # pickle.dump(data_to_bytestream, buffer) - # buffer.seek(0) - # byte_stream = bytes(buffer.read()) - - # return byte_stream - class StorageEnsemble(Storage): """An ensemble of storages.