From ff88f736b8e6deb4289df8476a645c32b8216d00 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 11 Jul 2025 06:42:56 +0100 Subject: [PATCH 1/7] amend --- .../unittest/linux/scripts/environment.yml | 1 + docs/source/reference/data.rst | 62 ++++ test/test_rb.py | 266 +++++++++++++++ torchrl/data/__init__.py | 10 +- torchrl/data/replay_buffers/__init__.py | 4 + torchrl/data/replay_buffers/checkpointers.py | 112 ++++++ torchrl/data/replay_buffers/storages.py | 319 +++++++++++++++++- 7 files changed, 768 insertions(+), 6 deletions(-) diff --git a/.github/unittest/linux/scripts/environment.yml b/.github/unittest/linux/scripts/environment.yml index e3c79e4569e..b7ca29ff0d4 100644 --- a/.github/unittest/linux/scripts/environment.yml +++ b/.github/unittest/linux/scripts/environment.yml @@ -35,3 +35,4 @@ dependencies: - transformers - ninja - timm + - zstandard diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index e9d08822239..48fb5f1e06f 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -144,6 +144,8 @@ using the following components: :template: rl_template.rst + CompressedStorage + CompressedStorageCheckpointer FlatStorageCheckpointer H5StorageCheckpointer ImmutableDatasetWriter @@ -191,6 +193,66 @@ were found from rough benchmarking in https://github.com/pytorch/rl/tree/main/be | :class:`LazyMemmapStorage` | 3.44x | +-------------------------------+-----------+ +Compressed Storage for Memory Efficiency +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +For applications where memory usage is a primary concern, especially when storing +large sensory observations like images or audio, the :class:`~torchrl.data.replay_buffers.storages.CompressedStorage` +provides significant memory savings through compression. + +The `CompressedStorage`` compresses data when storing and decompresses when retrieving, +achieving compression ratios of 2-10x for image data while maintaining full data fidelity. +It uses zstd compression by default but supports custom compression algorithms. + +Key features: +- **Memory Efficiency**: Achieves significant memory savings through compression +- **Data Integrity**: Maintains full data fidelity through lossless compression +- **Flexible Compression**: Supports custom compression algorithms or uses zstd by default +- **TensorDict Support**: Seamlessly works with TensorDict structures +- **Checkpointing**: Full support for saving and loading compressed data + +Example usage: + + >>> import torch + >>> from torchrl.data import ReplayBuffer, CompressedStorage + >>> from tensordict import TensorDict + >>> + >>> # Create a compressed storage for image data + >>> storage = CompressedStorage(max_size=1000, compression_level=3) + >>> rb = ReplayBuffer(storage=storage, batch_size=32) + >>> + >>> # Add image data + >>> images = torch.randn(100, 3, 84, 84) # Atari-like frames + >>> data = TensorDict({"obs": images}, batch_size=[100]) + >>> rb.extend(data) + >>> + >>> # Sample data (automatically decompressed) + >>> sample = rb.sample(16) + >>> print(sample["obs"].shape) # torch.Size([16, 3, 84, 84]) + +The compression level can be adjusted from 1 (fast, less compression) to 22 (slow, more compression), +with level 3 being a good default for most use cases. + +For custom compression algorithms: + + >>> def my_compress(tensor): + ... return tensor.to(torch.uint8) # Simple example + >>> + >>> def my_decompress(compressed_tensor, metadata): + ... return compressed_tensor.to(metadata["dtype"]) + >>> + >>> storage = CompressedStorage( + ... 
max_size=1000, + ... compression_fn=my_compress, + ... decompression_fn=my_decompress + ... ) + +.. note:: The CompressedStorage requires the `zstandard` library for default compression. + Install with: ``pip install zstandard`` + +.. note:: An example of how to use the CompressedStorage is available in the + `examples/replay-buffers/compressed_replay_buffer_example.py `_ file. + Sharing replay buffers across processes ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/test/test_rb.py b/test/test_rb.py index ca2bb121d65..a52295d9695 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -11,10 +11,13 @@ import os import pickle import sys +import tempfile from functools import partial +from pathlib import Path from unittest import mock import numpy as np + import pytest import torch from packaging import version @@ -35,6 +38,7 @@ from torchrl.collectors import RandomPolicy, SyncDataCollector from torchrl.collectors.utils import split_trajectories from torchrl.data import ( + CompressedStorage, FlatStorageCheckpointer, MultiStep, NestedStorageCheckpointer, @@ -129,6 +133,7 @@ _os_is_windows = sys.platform == "win32" _has_transformers = importlib.util.find_spec("transformers") is not None _has_ray = importlib.util.find_spec("ray") is not None +_has_zstandard = importlib.util.find_spec("zstandard") is not None TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) @@ -4027,6 +4032,267 @@ def test_ray_rb_iter(self): rb.close() +@pytest.mark.skipif(not _has_zstandard, reason="zstandard required for this test.") +class TestCompressedStorage: + """Test cases for CompressedStorage.""" + + def test_compressed_storage_initialization(self): + """Test that CompressedStorage initializes correctly.""" + storage = CompressedStorage(max_size=100, compression_level=3) + assert storage.max_size == 100 + assert storage.compression_level == 3 + assert len(storage) == 0 + + def test_compressed_storage_tensor(self): + """Test compression and decompression of tensor data.""" + storage = CompressedStorage(max_size=10, compression_level=3) + + # Create test tensor + test_tensor = torch.randn(3, 84, 84, dtype=torch.float32) + + # Store tensor + storage.set(0, test_tensor) + + # Retrieve tensor + retrieved_tensor = storage.get(0) + + # Verify data integrity + assert torch.allclose(test_tensor, retrieved_tensor, atol=1e-6) + assert test_tensor.shape == retrieved_tensor.shape + assert test_tensor.dtype == retrieved_tensor.dtype + + def test_compressed_storage_tensordict(self): + """Test compression and decompression of TensorDict data.""" + storage = CompressedStorage(max_size=10, compression_level=3) + + # Create test TensorDict + test_td = TensorDict( + { + "obs": torch.randn(3, 84, 84, dtype=torch.float32), + "action": torch.tensor([1, 2, 3]), + "reward": torch.randn(3), + "done": torch.tensor([False, True, False]), + }, + batch_size=[3], + ) + + # Store TensorDict + storage.set(0, test_td) + + # Retrieve TensorDict + retrieved_td = storage.get(0) + + # Verify data integrity + assert torch.allclose(test_td["obs"], retrieved_td["obs"], atol=1e-6) + assert torch.allclose(test_td["action"], retrieved_td["action"]) + assert torch.allclose(test_td["reward"], retrieved_td["reward"], atol=1e-6) + assert torch.allclose(test_td["done"], retrieved_td["done"]) + + def test_compressed_storage_multiple_indices(self): + """Test storing and retrieving multiple items.""" + storage = CompressedStorage(max_size=10, compression_level=3) + + # Store multiple tensors + tensors = [ + torch.randn(2, 2, 
dtype=torch.float32), + torch.randn(3, 3, dtype=torch.float32), + torch.randn(4, 4, dtype=torch.float32), + ] + + for i, tensor in enumerate(tensors): + storage.set(i, tensor) + + # Retrieve multiple tensors + retrieved = storage.get([0, 1, 2]) + + # Verify data integrity + for original, retrieved_tensor in zip(tensors, retrieved): + assert torch.allclose(original, retrieved_tensor, atol=1e-6) + + def test_compressed_storage_with_replay_buffer(self): + """Test CompressedStorage with ReplayBuffer.""" + storage = CompressedStorage(max_size=100, compression_level=3) + rb = ReplayBuffer(storage=storage, batch_size=5) + + # Create test data + data = TensorDict( + { + "obs": torch.randn(10, 3, 84, 84, dtype=torch.float32), + "action": torch.randint(0, 4, (10,)), + "reward": torch.randn(10), + }, + batch_size=[10], + ) + + # Add data to replay buffer + print("extending") + rb.extend(data) + + # Sample from replay buffer + sample = rb.sample(5) + + # Verify sample has correct shape + assert is_tensor_collection(sample), sample + assert sample["obs"].shape[0] == 5 + assert sample["obs"].shape[1:] == (3, 84, 84) + assert sample["action"].shape[0] == 5 + assert sample["reward"].shape[0] == 5 + + def test_compressed_storage_state_dict(self): + """Test saving and loading state dict.""" + storage = CompressedStorage(max_size=10, compression_level=3) + + # Add some data + test_tensor = torch.randn(3, 3, dtype=torch.float32) + storage.set(0, test_tensor) + + # Save state dict + state_dict = storage.state_dict() + + # Create new storage and load state dict + new_storage = CompressedStorage(max_size=10, compression_level=3) + new_storage.load_state_dict(state_dict) + + # Verify data integrity + retrieved_tensor = new_storage.get(0) + assert torch.allclose(test_tensor, retrieved_tensor, atol=1e-6) + + def test_compressed_storage_checkpointing(self): + """Test checkpointing functionality.""" + storage = CompressedStorage(max_size=10, compression_level=3) + + # Add some data + test_td = TensorDict( + { + "obs": torch.randn(3, 84, 84, dtype=torch.float32), + "action": torch.tensor([1, 2, 3]), + }, + batch_size=[3], + ) + storage.set(0, test_td) + + # Create temporary directory for checkpointing + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_path = Path(tmpdir) / "checkpoint" + + # Save checkpoint + storage.dumps(checkpoint_path) + + # Create new storage and load checkpoint + new_storage = CompressedStorage(max_size=10, compression_level=3) + new_storage.loads(checkpoint_path) + + # Verify data integrity + retrieved_td = new_storage.get(0) + assert torch.allclose(test_td["obs"], retrieved_td["obs"], atol=1e-6) + assert torch.allclose(test_td["action"], retrieved_td["action"]) + + def test_compressed_storage_length(self): + """Test that length is calculated correctly.""" + storage = CompressedStorage(max_size=10, compression_level=3) + + # Initially empty + assert len(storage) == 0 + + # Add some data + storage.set(0, torch.randn(2, 2)) + assert len(storage) == 1 + + storage.set(2, torch.randn(2, 2)) + assert len(storage) == 2 + + storage.set(1, torch.randn(2, 2)) + assert len(storage) == 3 + + def test_compressed_storage_contains(self): + """Test the contains method.""" + storage = CompressedStorage(max_size=10, compression_level=3) + + # Initially empty + assert not storage.contains(0) + + # Add data + storage.set(0, torch.randn(2, 2)) + assert storage.contains(0) + assert not storage.contains(1) + + def test_compressed_storage_empty(self): + """Test emptying the storage.""" + storage = 
CompressedStorage(max_size=10, compression_level=3) + + # Add some data + storage.set(0, torch.randn(2, 2)) + storage.set(1, torch.randn(2, 2)) + assert len(storage) == 2 + + # Empty storage + storage._empty() + assert len(storage) == 0 + + def test_compressed_storage_custom_compression(self): + """Test custom compression functions.""" + + def custom_compress(tensor): + # Simple compression: just convert to uint8 + return tensor.to(torch.uint8) + + def custom_decompress(compressed_tensor, metadata): + # Simple decompression: convert back to original dtype + return compressed_tensor.to(metadata["dtype"]) + + storage = CompressedStorage( + max_size=10, + compression_fn=custom_compress, + decompression_fn=custom_decompress, + ) + + # Test with tensor + test_tensor = torch.randn(2, 2, dtype=torch.float32) + storage.set(0, test_tensor) + retrieved_tensor = storage.get(0) + + # Note: This will lose precision due to uint8 conversion + # but should still work + assert retrieved_tensor.shape == test_tensor.shape + + def test_compressed_storage_error_handling(self): + """Test error handling for invalid operations.""" + storage = CompressedStorage(max_size=5, compression_level=3) + + # Test setting data beyond max_size + with pytest.raises(RuntimeError): + storage.set(10, torch.randn(2, 2)) + + # Test getting non-existent data + with pytest.raises(IndexError): + storage.get(0) + + def test_compressed_storage_memory_efficiency(self): + """Test that compression actually reduces memory usage.""" + storage = CompressedStorage(max_size=100, compression_level=3) + + # Create large tensor data + large_tensor = torch.zeros(100, 3, 84, 84, dtype=torch.int64) + large_tensor.copy_( + torch.arange(large_tensor.numel(), dtype=torch.int32).view_as(large_tensor) + // (3 * 84 * 84) + ) + original_size = large_tensor.numel() * large_tensor.element_size() + + # Store in compressed storage + storage.set(0, large_tensor) + + # Estimate compressed size + compressed_data = storage._compressed_data[0] + compressed_size = compressed_data.numel() # uint8 bytes + + # Verify the compression ratio is reasonable (at least 1.5x for this highly regular, compressible data) + compression_ratio = original_size / compressed_size + assert ( + compression_ratio > 1.5 + ), f"Compression ratio {compression_ratio} is too low" + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index b3fafd16ee1..5c8bff8a3af 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -32,6 +32,8 @@ ) from .postprocs import DensifyReward, MultiStep from .replay_buffers import ( + CompressedStorage, + CompressedStorageCheckpointer, Flat2TED, FlatStorageCheckpointer, H5Combine, @@ -116,21 +118,22 @@ "BoundedTensorSpec", "Categorical", "Choice", - "ContentBase", - "TopKRewardSelector", "Composite", "CompositeSpec", + "CompressedStorage", + "CompressedStorageCheckpointer", "ConstantKLController", + "ContentBase", "DEVICE_TYPING", "DensifyReward", "DiscreteTensorSpec", "Flat2TED", "FlatStorageCheckpointer", - "History", "H5Combine", "H5Split", "H5StorageCheckpointer", "HashToInt", + "History", "ImmutableDatasetWriter", "LazyMemmapStorage", "LazyStackStorage", @@ -191,6 +194,7 @@ "TensorStorage", "TensorStorageCheckpointer", "TokenizedDatasetLoader", + "TopKRewardSelector", "Tree", "Unbounded", "UnboundedContinuous", diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py 
index 6e7bff8eac0..71c1dd20795 100644 --- a/torchrl/data/replay_buffers/__init__.py +++ b/torchrl/data/replay_buffers/__init__.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from .checkpointers import ( + CompressedStorageCheckpointer, FlatStorageCheckpointer, H5StorageCheckpointer, ListStorageCheckpointer, @@ -32,6 +33,7 @@ SliceSamplerWithoutReplacement, ) from .storages import ( + CompressedStorage, LazyMemmapStorage, LazyStackStorage, LazyTensorStorage, @@ -51,6 +53,8 @@ ) __all__ = [ + "CompressedStorage", + "CompressedStorageCheckpointer", "FlatStorageCheckpointer", "H5StorageCheckpointer", "ListStorageCheckpointer", diff --git a/torchrl/data/replay_buffers/checkpointers.py b/torchrl/data/replay_buffers/checkpointers.py index 6328857292c..346dddf28af 100644 --- a/torchrl/data/replay_buffers/checkpointers.py +++ b/torchrl/data/replay_buffers/checkpointers.py @@ -68,6 +68,118 @@ def loads(storage, path): ) +class CompressedStorageCheckpointer(StorageCheckpointerBase): + """A storage checkpointer for CompressedStorage. + + This checkpointer saves compressed data and metadata separately for efficient storage. + + """ + + def dumps(self, storage, path): + path = Path(path) + path.mkdir(exist_ok=True) + + if not hasattr(storage, "_compressed_data") or not storage._compressed_data: + raise RuntimeError( + "Cannot save an empty or non-initialized CompressedStorage." + ) + + # Save compressed data and metadata + state_dict = storage.state_dict() + + # Save compressed data as separate files for efficiency + compressed_data = state_dict["_compressed_data"] + metadata = state_dict["_metadata"] + + # Save metadata + with open(path / "compressed_metadata.json", "w") as f: + json.dump(metadata, f, default=str) + + # Save compressed data + for i, (compressed_item, item_metadata) in enumerate( + zip(compressed_data, metadata) + ): + if compressed_item is not None: + if item_metadata["type"] == "tensor": + # Save as numpy array + np.save( + path / f"compressed_data_{i}.npy", compressed_item.cpu().numpy() + ) + elif item_metadata["type"] == "tensordict": + # Save each field separately + item_dir = path / f"compressed_data_{i}" + item_dir.mkdir(exist_ok=True) + + for key, value in compressed_item.items(): + if isinstance(value, torch.Tensor): + np.save(item_dir / f"{key}.npy", value.cpu().numpy()) + else: + # Save non-tensor data as pickle + import pickle + + with open(item_dir / f"{key}.pkl", "wb") as f: + pickle.dump(value, f) + else: + # Save other types as pickle + import pickle + + with open(path / f"compressed_data_{i}.pkl", "wb") as f: + pickle.dump(compressed_item, f) + + def loads(self, storage, path): + path = Path(path) + + # Load metadata + with open(path / "compressed_metadata.json") as f: + metadata = json.load(f) + + # Load compressed data + compressed_data = [] + i = 0 + + while True: + if (path / f"compressed_data_{i}.npy").exists(): + # Load tensor data + data = np.load(path / f"compressed_data_{i}.npy") + compressed_data.append(torch.from_numpy(data)) + elif (path / f"compressed_data_{i}.pkl").exists(): + # Load other data + import pickle + + with open(path / f"compressed_data_{i}.pkl", "rb") as f: + data = pickle.load(f) + compressed_data.append(data) + elif (path / f"compressed_data_{i}").exists(): + # Load tensordict data + item_dir = path / f"compressed_data_{i}" + item_data = {} + + for key in metadata[i]["fields"].keys(): + if (item_dir / f"{key}.npy").exists(): + data = np.load(item_dir / f"{key}.npy") + item_data[key] = torch.from_numpy(data) + 
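# Non-tensor fields were pickled by dumps(); fall back to pickle when no .npy payload exists. +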
elif (item_dir / f"{key}.pkl").exists(): + import pickle + + with open(item_dir / f"{key}.pkl", "rb") as f: + data = pickle.load(f) + item_data[key] = data + + compressed_data.append(item_data) + else: + break + + i += 1 + + # Pad with None to match metadata length + while len(compressed_data) < len(metadata): + compressed_data.append(None) + + # Load into storage + storage._compressed_data = compressed_data + storage._metadata = metadata + + class TensorStorageCheckpointer(StorageCheckpointerBase): """A storage checkpointer for TensorStorages. diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index e5ca2367be5..b6c17769659 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -12,7 +12,7 @@ from collections import OrderedDict from copy import copy from multiprocessing.context import get_spawning_popen -from typing import Any, Sequence +from typing import Any, Callable, Mapping, Sequence import numpy as np import tensordict @@ -32,6 +32,7 @@ from torchrl._utils import _make_ordinal_device, implement_for, logger as torchrl_logger from torchrl.data.replay_buffers.checkpointers import ( + CompressedStorageCheckpointer, ListStorageCheckpointer, StorageCheckpointerBase, StorageEnsembleCheckpointer, @@ -1360,6 +1361,316 @@ def get(self, index: int | Sequence[int] | slice) -> Any: return result +class CompressedStorage(Storage): + """A storage that compresses and decompresses data. + + This storage compresses data when storing and decompresses when retrieving. + It's particularly useful for storing raw sensory observations like images + that can be compressed significantly to save memory. + + Args: + max_size (int): size of the storage, i.e. maximum number of elements stored + in the buffer. + compression_fn (callable, optional): function to compress data. Should take + a tensor and return a compressed byte tensor. Defaults to zstd compression. + decompression_fn (callable, optional): function to decompress data. Should take + a compressed byte tensor and return the original tensor. Defaults to zstd decompression. + compression_level (int, optional): compression level (1-22 for zstd) when using the default compression function. + Defaults to 3. + device (torch.device, optional): device where the sampled tensors will be + stored and sent. Default is :obj:`torch.device("cpu")`. + compilable (bool, optional): whether the storage is compilable. + If ``True``, the writer cannot be shared between multiple processes. + Defaults to ``False``. 
+ + Examples: + >>> import torch + >>> from torchrl.data import CompressedStorage, ReplayBuffer + >>> from tensordict import TensorDict + >>> + >>> # Create a compressed storage for image data + >>> storage = CompressedStorage(max_size=1000, compression_level=3) + >>> rb = ReplayBuffer(storage=storage, batch_size=5) + >>> + >>> # Add some image data + >>> images = torch.randn(10, 3, 84, 84) # Atari-like frames + >>> data = TensorDict({"obs": images}, batch_size=[10]) + >>> rb.extend(data) + >>> + >>> # Sample and verify data is decompressed correctly + >>> sample = rb.sample(3) + >>> print(sample["obs"].shape) # torch.Size([3, 3, 84, 84]) + + """ + + _default_checkpointer = CompressedStorageCheckpointer + + def __init__( + self, + max_size: int, + *, + compression_fn: Callable | None = None, + decompression_fn: Callable | None = None, + compression_level: int = 3, + device: torch.device = "cpu", + compilable: bool = False, + ): + super().__init__(max_size, compilable=compilable) + self.device = device + self.compression_level = compression_level + + # Set up compression functions + if compression_fn is None: + self.compression_fn = self._default_compression_fn + else: + self.compression_fn = compression_fn + + if decompression_fn is None: + self.decompression_fn = self._default_decompression_fn + else: + self.decompression_fn = decompression_fn + + # Store compressed data and metadata + self._compressed_data = [] + self._metadata = [] # Store shape, dtype, device info for each item + + def _default_compression_fn(self, tensor: torch.Tensor) -> torch.Tensor: + """Default compression using zstd.""" + try: + import zstandard as zstd + except ImportError: + raise ImportError( + "zstandard is required for default compression. " + "Install with: pip install zstandard" + ) + + # Convert tensor to bytes + tensor_bytes = tensor.cpu().numpy().tobytes() + + # Compress with zstd + compressor = zstd.ZstdCompressor(level=self.compression_level) + compressed_bytes = compressor.compress(tensor_bytes) + + # Convert to tensor + return torch.tensor(list(compressed_bytes), dtype=torch.uint8) + + def _default_decompression_fn( + self, compressed_tensor: torch.Tensor, metadata: dict + ) -> torch.Tensor: + """Default decompression using zstd.""" + try: + import zstandard as zstd + except ImportError: + raise ImportError( + "zstandard is required for default decompression. 
" + "Install with: pip install zstandard" + ) + + # Convert tensor to bytes + compressed_bytes = bytes(compressed_tensor.cpu().numpy()) + + # Decompress with zstd + decompressor = zstd.ZstdDecompressor() + decompressed_bytes = decompressor.decompress(compressed_bytes) + + # Convert back to tensor + tensor = torch.frombuffer(decompressed_bytes, dtype=metadata["dtype"]) + tensor = tensor.reshape(metadata["shape"]) + tensor = tensor.to(metadata["device"]) + + return tensor + + def set( + self, + cursor: int | Sequence[int] | slice, + data: Any, + *, + set_cursor: bool = True, + ): + """Set data in the storage with compression.""" + if isinstance(cursor, (INT_CLASSES, slice)): + cursor = [cursor] + + if isinstance(data, (list, tuple)): + data_list = data + elif isinstance(data, (torch.Tensor, TensorDictBase)): + data_list = data.unbind(0) + # determine if data is iterable and has length equal to cursor + elif hasattr(data, "__len__") and len(data) == len(cursor): + data_list = data + else: + data_list = [data] + + for idx, item in zip(cursor, data_list): + if idx >= self.max_size: + raise RuntimeError( + f"Cannot set data at index {idx}: storage max_size is {self.max_size}" + ) + + # Compress the data + compressed_data, metadata = self._compress_item(item) + + # Ensure we have enough space + while len(self._compressed_data) <= idx: + self._compressed_data.append(None) + self._metadata.append(None) + + # Store compressed data and metadata + self._compressed_data[idx] = compressed_data + self._metadata[idx] = metadata + + def _compress_item(self, item: Any) -> tuple[torch.Tensor, dict]: + """Compress a single item and return compressed data with metadata.""" + if isinstance(item, torch.Tensor): + metadata = { + "type": "tensor", + "shape": item.shape, + "dtype": item.dtype, + "device": item.device, + } + compressed = self.compression_fn(item) + elif is_tensor_collection(item): + # For TensorDict, compress each tensor field + compressed_fields = {} + metadata = {"type": "tensordict", "fields": {}} + + for key, value in item.items(): + if isinstance(value, torch.Tensor): + compressed_fields[key] = self.compression_fn(value) + metadata["fields"][key] = { + "type": "tensor", + "shape": value.shape, + "dtype": value.dtype, + "device": value.device, + } + else: + # For non-tensor data, store as-is + compressed_fields[key] = value + metadata["fields"][key] = {"type": "non_tensor", "value": value} + + compressed = compressed_fields + else: + # For other types, store as-is + compressed = item + metadata = {"type": "other", "value": item} + + return compressed, metadata + + def get(self, index: int | Sequence[int] | slice) -> Any: + """Get data from the storage with decompression.""" + if isinstance(index, (INT_CLASSES, slice)): + indices = [index] + else: + indices = index + + results = [] + for idx in indices: + if idx >= len(self._compressed_data) or self._compressed_data[idx] is None: + raise IndexError(f"Index {idx} out of bounds or not set") + + compressed_data = self._compressed_data[idx] + metadata = self._metadata[idx] + + # Decompress the data + decompressed = self._decompress_item(compressed_data, metadata) + results.append(decompressed) + + if len(results) == 1: + return results[0] + return results + + def _decompress_item(self, compressed_data: Any, metadata: dict) -> Any: + """Decompress a single item using its metadata.""" + if metadata["type"] == "tensor": + return self.decompression_fn(compressed_data, metadata) + elif metadata["type"] == "tensordict": + # Reconstruct TensorDict + result 
= TensorDict({}, batch_size=metadata.get("batch_size", [])) + + for key, field_metadata in metadata["fields"].items(): + if field_metadata["type"] == "non_tensor": + result[key] = field_metadata["value"] + else: + # Decompress tensor field + result[key] = self.decompression_fn( + compressed_data[key], field_metadata + ) + + return result + else: + # Return as-is for other types + return metadata["value"] + + def __len__(self): + """Return the number of items in the storage.""" + return len([item for item in self._compressed_data if item is not None]) + + def state_dict(self) -> dict[str, Any]: + """Save the storage state.""" + return { + "_compressed_data": self._compressed_data, + "_metadata": self._metadata, + } + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Load the storage state.""" + self._compressed_data = state_dict["_compressed_data"] + self._metadata = state_dict["_metadata"] + + def _empty(self): + """Empty the storage.""" + self._compressed_data = [] + self._metadata = [] + + def loads(self, path: str): + """Load the storage state from a file.""" + super().loads(path) + from tensordict.utils import _STR_DTYPE_TO_DTYPE + from torch.utils._pytree import tree_flatten, tree_unflatten + + leaves, spec = tree_flatten(self._metadata) + leaves = [ + _STR_DTYPE_TO_DTYPE.get(x, x) if isinstance(x, str) else x for x in leaves + ] + self._metadata = tree_unflatten(leaves, spec) + + def contains(self, item): + """Check if an item is in the storage.""" + if isinstance(item, int): + if item < 0: + item += len(self._compressed_data) + return ( + 0 <= item < len(self._compressed_data) + and self._compressed_data[item] is not None + ) + raise NotImplementedError(f"type {type(item)} is not supported yet.") + + def bytes(self): + """Return the number of bytes in the storage.""" + + def compressed_size_from_list(data: Any) -> int: + if data is None: + return 0 + elif isinstance(data, torch.Tensor): + return data.numel() + elif isinstance(data, (tuple, list, Sequence)): + return sum(compressed_size_from_list(item) for item in data) + elif isinstance(data, Mapping) or is_tensor_collection(data): + return sum(compressed_size_from_list(value) for value in data.values()) + else: + return 0 + + compressed_size_estimate = compressed_size_from_list(self._compressed_data) + if compressed_size_estimate == 0: + if len(self._compressed_data) > 0: + raise RuntimeError( + "Compressed storage is not empty but the compressed size is 0. This is a bug." + ) + warnings.warn("Compressed storage is empty, returning 0 bytes.") + + return compressed_size_estimate + + class StorageEnsemble(Storage): """An ensemble of storages. 
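For reference, the default zstd codec added above amounts to the following standalone round trip. This is a minimal sketch against the public ``zstandard`` API; the variable names are illustrative and the snippet is not part of the diff:

    >>> import torch
    >>> import zstandard as zstd
    >>> x = torch.randn(3, 84, 84)
    >>> blob = zstd.ZstdCompressor(level=3).compress(x.cpu().numpy().tobytes())  # tensor -> raw bytes -> zstd frame
    >>> raw = zstd.ZstdDecompressor().decompress(blob)
    >>> y = torch.frombuffer(bytearray(raw), dtype=x.dtype).reshape(x.shape)  # bytearray yields a writable buffer
    >>> assert torch.equal(x, y)  # zstd is lossless, so the round trip is exact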
@@ -1567,9 +1878,11 @@ def _collate_id(x): def _get_default_collate(storage, _is_tensordict=False): - if isinstance(storage, LazyStackStorage) or isinstance(storage, TensorStorage): + if isinstance(storage, (LazyStackStorage, TensorStorage)): return _collate_id - elif isinstance(storage, ListStorage): + elif isinstance(storage, CompressedStorage): + return lazy_stack + elif isinstance(storage, (ListStorage, StorageEnsemble)): return _stack_anything else: raise NotImplementedError( From d91f76478ea1e95922bbfbbbe3b1e0e0111d13ad Mon Sep 17 00:00:00 2001 From: Adrian Orenstein Date: Sat, 12 Jul 2025 05:54:58 -0600 Subject: [PATCH 2/7] compression is helpful for data transfer --- docs/source/reference/data.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 48fb5f1e06f..812c19201f5 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -196,7 +196,7 @@ Compressed Storage for Memory Efficiency ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -For applications where memory usage is a primary concern, especially when storing +For applications where memory usage or memory bandwidth is a primary concern, especially when storing or transferring large sensory observations like images or audio, the :class:`~torchrl.data.replay_buffers.storages.CompressedStorage` provides significant memory savings through compression. From 84d00a87838d09a8fcad4249ab2a655e0952c840 Mon Sep 17 00:00:00 2001 From: Adrian Orenstein Date: Sat, 12 Jul 2025 12:54:21 -0600 Subject: [PATCH 3/7] Move CompressedStorage to CompressedListStorage. Moved out all of the cursor logic to a view class. Passing all tests now. --- .pre-commit-config.yaml | 2 +- docs/source/reference/data.rst | 18 +- test/test_rb.py | 187 ++++++----- torchrl/data/__init__.py | 8 +- torchrl/data/replay_buffers/__init__.py | 8 +- torchrl/data/replay_buffers/checkpointers.py | 41 ++- torchrl/data/replay_buffers/storages.py | 315 ++++++++++++------- 7 files changed, 374 insertions(+), 205 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 37adaef7979..43d9ad1a525 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,7 +20,7 @@ repos: - libcst == 0.4.7 - repo: https://github.com/pycqa/flake8 - rev: 4.0.1 + rev: 6.0.0 hooks: - id: flake8 args: [--config=setup.cfg] diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 812c19201f5..fdfa5405b56 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -144,8 +144,8 @@ using the following components: :template: rl_template.rst - CompressedStorage - CompressedStorageCheckpointer + CompressedListStorage + CompressedListStorageCheckpointer FlatStorageCheckpointer H5StorageCheckpointer ImmutableDatasetWriter @@ -197,10 +197,10 @@ Compressed Storage for Memory Efficiency ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ For applications where memory usage or memory bandwidth is a primary concern, especially when storing or transferring -large sensory observations like images or audio, the :class:`~torchrl.data.replay_buffers.storages.CompressedStorage` +large sensory observations like images, audio, or text, the :class:`~torchrl.data.replay_buffers.storages.CompressedListStorage` provides significant memory savings through compression. 
-The `CompressedStorage`` compresses data when storing and decompresses when retrieving, +The ``CompressedListStorage`` compresses data when storing and decompresses when retrieving, achieving compression ratios of 2-10x for image data while maintaining full data fidelity. It uses zstd compression by default but supports custom compression algorithms. @@ -214,11 +214,11 @@ Key features: Example usage: >>> import torch - >>> from torchrl.data import ReplayBuffer, CompressedStorage + >>> from torchrl.data import ReplayBuffer, CompressedListStorage >>> from tensordict import TensorDict >>> >>> # Create a compressed storage for image data - >>> storage = CompressedStorage(max_size=1000, compression_level=3) + >>> storage = CompressedListStorage(max_size=1000, compression_level=3) >>> rb = ReplayBuffer(storage=storage, batch_size=32) >>> >>> # Add image data @@ -241,16 +241,16 @@ For custom compression algorithms: >>> def my_decompress(compressed_tensor, metadata): ... return compressed_tensor.to(metadata["dtype"]) >>> - >>> storage = CompressedStorage( + >>> storage = CompressedListStorage( ... max_size=1000, ... compression_fn=my_compress, ... decompression_fn=my_decompress ... ) -.. note:: The CompressedStorage requires the `zstandard` library for default compression. +.. note:: The CompressedListStorage requires the `zstandard` library for default compression. Install with: ``pip install zstandard`` -.. note:: An example of how to use the CompressedStorage is available in the +.. note:: An example of how to use the CompressedListStorage is available in the `examples/replay-buffers/compressed_replay_buffer_example.py `_ file. Sharing replay buffers across processes diff --git a/test/test_rb.py b/test/test_rb.py index a52295d9695..fbe3842b202 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -38,7 +38,7 @@ from torchrl.collectors import RandomPolicy, SyncDataCollector from torchrl.collectors.utils import split_trajectories from torchrl.data import ( - CompressedStorage, + CompressedListStorage, FlatStorageCheckpointer, MultiStep, NestedStorageCheckpointer, @@ -189,7 +189,6 @@ @pytest.mark.parametrize("size", [3, 5, 100]) class TestComposableBuffers: def _get_rb(self, rb_type, size, sampler, writer, storage, compilable=False): - if storage is not None: storage = storage(size, compilable=compilable) @@ -332,10 +331,14 @@ def test_cursor_position(self, rb_type, sampler, writer, storage, size, datatype writer.extend(batch1) return - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): writer.extend(batch1) # Added less data than storage max size @@ -383,10 +386,14 @@ def test_extend(self, rb_type, sampler, writer, storage, size, datatype): length = min(rb._storage.max_size, len(rb) + data_shape) if writer is TensorDictMaxValueWriter: data["next", "reward"][-length:] = 1_000_000 - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(data) length = len(rb) if is_tensor_collection(data): @@ -424,10 +431,14 @@ def data_iter(): and size < len(data2) and 
isinstance(rb._storage, TensorStorage) ) - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(data2) @pytest.mark.skipif( @@ -530,10 +541,14 @@ def test_sample(self, rb_type, sampler, writer, storage, size, datatype): ): rb.extend(data) return - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(data) rb_sample = rb.sample() # if not isinstance(new_data, (torch.Tensor, TensorDictBase)): @@ -550,7 +565,6 @@ def data_iter_func(maxval, data=data): rb_sample_iter = data_iter_func(rb._batch_size, rb_sample) for single_sample in rb_sample_iter: - if is_tensor_collection(data) or isinstance(data, torch.Tensor): data_iter = data else: @@ -600,10 +614,14 @@ def test_index(self, rb_type, sampler, writer, storage, size, datatype): ): rb.extend(data) return - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(data) d1 = rb[2] d2 = rb._storage[2] @@ -620,7 +638,6 @@ def test_index(self, rb_type, sampler, writer, storage, size, datatype): assert b def test_pickable(self, rb_type, sampler, writer, storage, size, datatype): - rb = self._get_rb( rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size ) @@ -879,7 +896,6 @@ def extend_and_sample(data): @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) def test_extend_lazystack(self, storage_type): - rb = ReplayBuffer( storage=storage_type(6), batch_size=2, @@ -1506,7 +1522,6 @@ def test_rng_dumps(self, tmpdir): @pytest.mark.parametrize("size", [3, 5, 100]) @pytest.mark.parametrize("prefetch", [0]) class TestBuffers: - default_constr = { ReplayBuffer: ReplayBuffer, PrioritizedReplayBuffer: functools.partial( @@ -1573,10 +1588,14 @@ def test_cursor_position2(self, rbtype, storage, size, prefetch): cond = ( OLD_TORCH and size < len(batch1) and isinstance(rb._storage, TensorStorage) ) - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(batch1) # Added fewer data than storage max size @@ -1631,10 +1650,14 @@ def test_extend(self, rbtype, storage, size, prefetch): rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) data = self._get_data(rbtype, size=5) cond = OLD_TORCH and size < len(data) and isinstance(rb._storage, TensorStorage) - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", 
+ ) + if cond + else contextlib.nullcontext() + ): rb.extend(data) length = len(rb) for d in data[-length:]: @@ -1658,10 +1681,14 @@ def test_sample(self, rbtype, storage, size, prefetch): rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) data = self._get_data(rbtype, size=5) cond = OLD_TORCH and size < len(data) and isinstance(rb._storage, TensorStorage) - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(data) new_data = rb.sample() if not isinstance(new_data, (torch.Tensor, TensorDictBase)): @@ -1688,10 +1715,14 @@ def test_index(self, rbtype, storage, size, prefetch): rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) data = self._get_data(rbtype, size=5) cond = OLD_TORCH and size < len(data) and isinstance(rb._storage, TensorStorage) - with pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) if cond else contextlib.nullcontext(): + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): rb.extend(data) d1 = rb[2] d2 = rb._storage[2] @@ -2568,7 +2599,7 @@ def test_slice_sampler( break else: raise AssertionError( - f"Not all items can be sampled: {set(range(100))-count_unique} are missing" + f"Not all items can be sampled: {set(range(100)) - count_unique} are missing" ) if strict_length: @@ -2775,7 +2806,6 @@ def test_slice_sampler_left_right_ndim(self): assert curr_eps.unique().numel() == 1 def test_slice_sampler_strictlength(self): - torch.manual_seed(0) data = TensorDict( @@ -3435,7 +3465,6 @@ def _robust_stack(tensor_list): def test_rb( self, storage_type, sampler_type, data_type, p, num_buffer_sampled, batch_size ): - storages = [self._make_storage(storage_type, data_type) for _ in range(3)] collate_fn = self._make_collate(storage_type) data = [self._make_data(data_type) for _ in range(3)] @@ -4033,22 +4062,29 @@ def test_ray_rb_iter(self): @pytest.mark.skipif(not _has_zstandard, reason="zstandard required for this test.") -class TestCompressedStorage: - """Test cases for CompressedStorage.""" +class TestCompressedListStorage: + """Test cases for CompressedListStorage.""" def test_compressed_storage_initialization(self): - """Test that CompressedStorage initializes correctly.""" - storage = CompressedStorage(max_size=100, compression_level=3) + """Test that CompressedListStorage initializes correctly.""" + storage = CompressedListStorage(max_size=100, compression_level=3) assert storage.max_size == 100 assert storage.compression_level == 3 assert len(storage) == 0 - def test_compressed_storage_tensor(self): - """Test compression and decompression of tensor data.""" - storage = CompressedStorage(max_size=10, compression_level=3) - - # Create test tensor - test_tensor = torch.randn(3, 84, 84, dtype=torch.float32) + @pytest.mark.parametrize( + "test_tensor", + [ + torch.rand(1), # 0D scalar + torch.randn(84, dtype=torch.float32), # 1D tensor + torch.randn(84, 84, dtype=torch.float32), # 2D tensor + torch.randn(1, 84, 84, dtype=torch.float32), # 3D tensor + torch.randn(32, 84, 84, dtype=torch.float32), # 3D tensor + ], + ) + def test_compressed_storage_tensor(self, test_tensor): + """Test compression and 
decompression of tensor data of various shapes.""" + storage = CompressedListStorage(max_size=10, compression_level=3) # Store tensor storage.set(0, test_tensor) @@ -4057,13 +4093,17 @@ def test_compressed_storage_tensor(self): retrieved_tensor = storage.get(0) # Verify data integrity + assert ( + test_tensor.shape == retrieved_tensor.shape + ), f"Expected shape {test_tensor.shape}, got {retrieved_tensor.shape}" + assert ( + test_tensor.dtype == retrieved_tensor.dtype + ), f"Expected dtype {test_tensor.dtype}, got {retrieved_tensor.dtype}" assert torch.allclose(test_tensor, retrieved_tensor, atol=1e-6) - assert test_tensor.shape == retrieved_tensor.shape - assert test_tensor.dtype == retrieved_tensor.dtype def test_compressed_storage_tensordict(self): """Test compression and decompression of TensorDict data.""" - storage = CompressedStorage(max_size=10, compression_level=3) + storage = CompressedListStorage(max_size=10, compression_level=3) # Create test TensorDict test_td = TensorDict( @@ -4090,7 +4130,7 @@ def test_compressed_storage_tensordict(self): def test_compressed_storage_multiple_indices(self): """Test storing and retrieving multiple items.""" - storage = CompressedStorage(max_size=10, compression_level=3) + storage = CompressedListStorage(max_size=10, compression_level=3) # Store multiple tensors tensors = [ @@ -4110,8 +4150,8 @@ def test_compressed_storage_multiple_indices(self): assert torch.allclose(original, retrieved_tensor, atol=1e-6) def test_compressed_storage_with_replay_buffer(self): - """Test CompressedStorage with ReplayBuffer.""" - storage = CompressedStorage(max_size=100, compression_level=3) + """Test CompressedListStorage with ReplayBuffer.""" + storage = CompressedListStorage(max_size=100, compression_level=3) rb = ReplayBuffer(storage=storage, batch_size=5) # Create test data @@ -4125,7 +4165,6 @@ def test_compressed_storage_with_replay_buffer(self): ) # Add data to replay buffer - print("extending") rb.extend(data) # Sample from replay buffer @@ -4140,7 +4179,7 @@ def test_compressed_storage_with_replay_buffer(self): def test_compressed_storage_state_dict(self): """Test saving and loading state dict.""" - storage = CompressedStorage(max_size=10, compression_level=3) + storage = CompressedListStorage(max_size=10, compression_level=3) # Add some data test_tensor = torch.randn(3, 3, dtype=torch.float32) @@ -4150,7 +4189,7 @@ def test_compressed_storage_state_dict(self): state_dict = storage.state_dict() # Create new storage and load state dict - new_storage = CompressedStorage(max_size=10, compression_level=3) + new_storage = CompressedListStorage(max_size=10, compression_level=3) new_storage.load_state_dict(state_dict) # Verify data integrity @@ -4159,7 +4198,7 @@ def test_compressed_storage_state_dict(self): def test_compressed_storage_checkpointing(self): """Test checkpointing functionality.""" - storage = CompressedStorage(max_size=10, compression_level=3) + storage = CompressedListStorage(max_size=10, compression_level=3) # Add some data test_td = TensorDict( @@ -4179,7 +4218,7 @@ def test_compressed_storage_checkpointing(self): storage.dumps(checkpoint_path) # Create new storage and load checkpoint - new_storage = CompressedStorage(max_size=10, compression_level=3) + new_storage = CompressedListStorage(max_size=10, compression_level=3) new_storage.loads(checkpoint_path) # Verify data integrity @@ -4189,7 +4228,7 @@ def test_compressed_storage_checkpointing(self): def test_compressed_storage_length(self): """Test that length is calculated correctly.""" - 
storage = CompressedStorage(max_size=10, compression_level=3) + storage = CompressedListStorage(max_size=10, compression_level=3) # Initially empty assert len(storage) == 0 @@ -4198,15 +4237,15 @@ def test_compressed_storage_length(self): storage.set(0, torch.randn(2, 2)) assert len(storage) == 1 - storage.set(2, torch.randn(2, 2)) + storage.set(1, torch.randn(2, 2)) assert len(storage) == 2 - storage.set(1, torch.randn(2, 2)) + storage.set(2, torch.randn(2, 2)) assert len(storage) == 3 def test_compressed_storage_contains(self): """Test the contains method.""" - storage = CompressedStorage(max_size=10, compression_level=3) + storage = CompressedListStorage(max_size=10, compression_level=3) # Initially empty assert not storage.contains(0) @@ -4218,7 +4257,7 @@ def test_compressed_storage_contains(self): def test_compressed_storage_empty(self): """Test emptying the storage.""" - storage = CompressedStorage(max_size=10, compression_level=3) + storage = CompressedListStorage(max_size=10, compression_level=3) # Add some data storage.set(0, torch.randn(2, 2)) @@ -4240,7 +4279,7 @@ def custom_decompress(compressed_tensor, metadata): # Simple decompression: convert back to original dtype return compressed_tensor.to(metadata["dtype"]) - storage = CompressedStorage( + storage = CompressedListStorage( max_size=10, compression_fn=custom_compress, decompression_fn=custom_decompress, @@ -4257,7 +4296,7 @@ def custom_decompress(compressed_tensor, metadata): def test_compressed_storage_error_handling(self): """Test error handling for invalid operations.""" - storage = CompressedStorage(max_size=5, compression_level=3) + storage = CompressedListStorage(max_size=5, compression_level=3) # Test setting data beyond max_size with pytest.raises(RuntimeError): @@ -4269,7 +4308,7 @@ def test_compressed_storage_error_handling(self): def test_compressed_storage_memory_efficiency(self): """Test that compression actually reduces memory usage.""" - storage = CompressedStorage(max_size=100, compression_level=3) + storage = CompressedListStorage(max_size=100, compression_level=3) # Create large tensor data large_tensor = torch.zeros(100, 3, 84, 84, dtype=torch.int64) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 5c8bff8a3af..226ec4d5bb9 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -32,8 +32,8 @@ ) from .postprocs import DensifyReward, MultiStep from .replay_buffers import ( - CompressedStorage, - CompressedStorageCheckpointer, + CompressedListStorage, + CompressedListStorageCheckpointer, Flat2TED, FlatStorageCheckpointer, H5Combine, @@ -120,8 +120,8 @@ "Choice", "Composite", "CompositeSpec", - "CompressedStorage", - "CompressedStorageCheckpointer", + "CompressedListStorage", + "CompressedListStorageCheckpointer", "ConstantKLController", "ContentBase", "DEVICE_TYPING", diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py index 71c1dd20795..540d7c129be 100644 --- a/torchrl/data/replay_buffers/__init__.py +++ b/torchrl/data/replay_buffers/__init__.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. 
from .checkpointers import ( - CompressedStorageCheckpointer, + CompressedListStorageCheckpointer, FlatStorageCheckpointer, H5StorageCheckpointer, ListStorageCheckpointer, @@ -33,7 +33,7 @@ SliceSamplerWithoutReplacement, ) from .storages import ( - CompressedStorage, + CompressedListStorage, LazyMemmapStorage, LazyStackStorage, LazyTensorStorage, @@ -53,8 +53,8 @@ ) __all__ = [ - "CompressedStorage", - "CompressedStorageCheckpointer", + "CompressedListStorage", + "CompressedListStorageCheckpointer", "FlatStorageCheckpointer", "H5StorageCheckpointer", "ListStorageCheckpointer", diff --git a/torchrl/data/replay_buffers/checkpointers.py b/torchrl/data/replay_buffers/checkpointers.py index 346dddf28af..82869204ea6 100644 --- a/torchrl/data/replay_buffers/checkpointers.py +++ b/torchrl/data/replay_buffers/checkpointers.py @@ -68,8 +68,8 @@ def loads(storage, path): ) -class CompressedStorageCheckpointer(StorageCheckpointerBase): - """A storage checkpointer for CompressedStorage. +class CompressedListStorageCheckpointer(StorageCheckpointerBase): + """A storage checkpointer for CompressedListStorage. This checkpointer saves compressed data and metadata separately for efficient storage. @@ -79,9 +79,12 @@ def dumps(self, storage, path): path = Path(path) path.mkdir(exist_ok=True) - if not hasattr(storage, "_compressed_data") or not storage._compressed_data: + if ( + not hasattr(storage, "_compressed_data") + or len(storage._compressed_data) == 0 + ): raise RuntimeError( - "Cannot save an empty or non-initialized CompressedStorage." + "Cannot save an empty or non-initialized CompressedListStorage." ) # Save compressed data and metadata @@ -107,7 +110,7 @@ def dumps(self, storage, path): ) elif item_metadata["type"] == "tensordict": # Save each field separately - item_dir = path / f"compressed_data_{i}" + item_dir = path / f"compressed_data_{i}.td" item_dir.mkdir(exist_ok=True) for key, value in compressed_item.items(): @@ -133,10 +136,34 @@ def loads(self, storage, path): with open(path / "compressed_metadata.json") as f: metadata = json.load(f) + # Convert string dtypes back to torch.dtype objects + def convert_dtype(item): + if isinstance(item, dict): + if "dtype" in item and isinstance(item["dtype"], str): + # Convert string back to torch.dtype + dtype_str = item["dtype"] + if hasattr(torch, dtype_str.replace("torch.", "")): + item["dtype"] = getattr(torch, dtype_str.replace("torch.", "")) + else: + # Handle cases like 'torch.float32' -> torch.float32 + item["dtype"] = eval(dtype_str) + + # Recursively handle nested dictionaries + for _key, value in item.items(): + if isinstance(value, dict): + convert_dtype(value) + return item + + for item in metadata: + if item is not None: + convert_dtype(item) + # Load compressed data compressed_data = [] i = 0 + # TODO(adrian): Can we not know the serialised format beforehand? 
Then we can use glob to iterate through the files we know exist: + # `for path in glob.glob(path / f"compressed_data_*.{fmt}" for fmt in ["npy", "pkl", "td"]):`` while True: if (path / f"compressed_data_{i}.npy").exists(): # Load tensor data @@ -149,9 +176,9 @@ def loads(self, storage, path): with open(path / f"compressed_data_{i}.pkl", "rb") as f: data = pickle.load(f) compressed_data.append(data) - elif (path / f"compressed_data_{i}").exists(): + elif (path / f"compressed_data_{i}.td").exists(): # Load tensordict data - item_dir = path / f"compressed_data_{i}" + item_dir = path / f"compressed_data_{i}.td" item_data = {} for key in metadata[i]["fields"].keys(): diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index b6c17769659..e836df3a828 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -32,7 +32,7 @@ from torchrl._utils import _make_ordinal_device, implement_for, logger as torchrl_logger from torchrl.data.replay_buffers.checkpointers import ( - CompressedStorageCheckpointer, + CompressedListStorageCheckpointer, ListStorageCheckpointer, StorageCheckpointerBase, StorageEnsembleCheckpointer, @@ -873,7 +873,6 @@ def set( # noqa: F811 *, set_cursor: bool = True, ): - if set_cursor: self._last_cursor = cursor @@ -1361,13 +1360,107 @@ def get(self, index: int | Sequence[int] | slice) -> Any: return result -class CompressedStorage(Storage): +class CompressedStorageView: + """A view that makes compressed storage look like a regular list to ListStorage methods.""" + + def __init__(self, compressed_data, metadata, parent_storage): + self._compressed_data = compressed_data + self._metadata = metadata + self._parent_storage = parent_storage + + def __len__(self): + return len([item for item in self._compressed_data if item is not None]) + + def __getitem__(self, index): + if isinstance(index, INT_CLASSES): + if ( + index >= len(self._compressed_data) + or self._compressed_data[index] is None + ): + raise IndexError(f"Index {index} out of bounds or not set") + return self._parent_storage._decompress_item( + self._compressed_data[index], self._metadata[index] + ) + elif isinstance(index, slice): + start, stop, step = index.indices(len(self._compressed_data)) + results = [] + for i in range(start, stop, step): + if ( + i < len(self._compressed_data) + and self._compressed_data[i] is not None + ): + results.append( + self._parent_storage._decompress_item( + self._compressed_data[i], self._metadata[i] + ) + ) + return results + else: + # Handle lists, tensors, etc. 
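+ # Fancy indices (tensors or lists) are normalized to plain Python ints below; each selected item is then decompressed one by one.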
+ if isinstance(index, torch.Tensor): + if index.ndim == 0: + # Handle 0-dimensional tensor (scalar) + return self[index.item()] + if index.device.type != "cpu": + index = index.cpu().tolist() + else: + index = index.tolist() + + results = [] + for i in index: + if i >= len(self._compressed_data) or self._compressed_data[i] is None: + raise IndexError(f"Index {i} out of bounds or not set") + results.append( + self._parent_storage._decompress_item( + self._compressed_data[i], self._metadata[i] + ) + ) + return results + + def __setitem__(self, index, value): + if isinstance(index, INT_CLASSES): + # Ensure we have enough space + while len(self._compressed_data) <= index: + self._compressed_data.append(None) + self._metadata.append(None) + + # Compress and store + compressed_data, metadata = self._parent_storage._compress_item(value) + self._compressed_data[index] = compressed_data + self._metadata[index] = metadata + elif isinstance(index, slice): + # Handle slice assignment + if not hasattr(value, "__iter__"): + value = [value] + start, stop, step = index.indices(len(self._compressed_data)) + indices = list(range(start, stop, step)) + + for i, v in zip(indices, value): + self[i] = v + else: + # Handle multiple indices + if isinstance(index, torch.Tensor) and index.device.type != "cpu": + index = index.cpu().tolist() + if not hasattr(value, "__iter__"): + value = [value] * len(index) + for i, v in zip(index, value): + self[i] = v + + def append(self, value): + # Find the next available slot + index = len(self._compressed_data) + self[index] = value + + +class CompressedListStorage(ListStorage): """A storage that compresses and decompresses data. This storage compresses data when storing and decompresses when retrieving. It's particularly useful for storing raw sensory observations like images that can be compressed significantly to save memory. Args: max_size (int): size of the storage, i.e. maximum number of elements stored in the buffer. 
@@ -1385,11 +1478,11 @@ class CompressedStorage(Storage): Examples: >>> import torch - >>> from torchrl.data import CompressedStorage, ReplayBuffer + >>> from torchrl.data import CompressedListStorage, ReplayBuffer >>> from tensordict import TensorDict >>> >>> # Create a compressed storage for image data - >>> storage = CompressedStorage(max_size=1000, compression_level=3) + >>> storage = CompressedListStorage(max_size=1000, compression_level=3) >>> rb = ReplayBuffer(storage=storage, batch_size=5) >>> >>> # Add some image data @@ -1403,7 +1496,7 @@ class CompressedStorage(Storage): """ - _default_checkpointer = CompressedStorageCheckpointer + _default_checkpointer = CompressedListStorageCheckpointer def __init__( self, @@ -1415,8 +1508,7 @@ def __init__( device: torch.device = "cpu", compilable: bool = False, ): - super().__init__(max_size, compilable=compilable) - self.device = device + super().__init__(max_size, compilable=compilable, device=device) self.compression_level = compression_level # Set up compression functions @@ -1474,50 +1566,61 @@ def _default_decompression_fn( decompressed_bytes = decompressor.decompress(compressed_bytes) # Convert back to tensor - tensor = torch.frombuffer(decompressed_bytes, dtype=metadata["dtype"]) + tensor = torch.frombuffer( + bytearray(decompressed_bytes), dtype=metadata["dtype"] + ) tensor = tensor.reshape(metadata["shape"]) tensor = tensor.to(metadata["device"]) return tensor - def set( - self, - cursor: int | Sequence[int] | slice, - data: Any, - *, - set_cursor: bool = True, - ): - """Set data in the storage with compression.""" - if isinstance(cursor, (INT_CLASSES, slice)): - cursor = [cursor] - - if isinstance(data, (list, tuple)): - data_list = data - elif isinstance(data, (torch.Tensor, TensorDictBase)): - data_list = data.unbind(0) - # determine if data is iterable and has length equal to cursor - elif hasattr(data, "__len__") and len(data) == len(cursor): - data_list = data + @property + def _storage(self): + """Virtual storage that handles compression/decompression transparently.""" + return CompressedStorageView(self._compressed_data, self._metadata, self) + + @_storage.setter + def _storage(self, value): + # This allows ListStorage.__init__ to set _storage initially + if hasattr(self, "_compressed_data"): + # If we already have compressed storage, ignore attempts to set _storage + pass else: - data_list = [data] + # During initialization, create our compressed storage + self._compressed_data = [] + self._metadata = [] - for idx, item in zip(cursor, data_list): - if idx >= self.max_size: - raise RuntimeError( - f"Cannot set data at index {idx}: storage max_size is {self.max_size}" - ) + def __len__(self): + """Return the number of items in the storage.""" + return len([item for item in self._compressed_data if item is not None]) - # Compress the data - compressed_data, metadata = self._compress_item(item) + def _empty(self): + """Empty the storage.""" + self._compressed_data = [] + self._metadata = [] - # Ensure we have enough space - while len(self._compressed_data) <= idx: - self._compressed_data.append(None) - self._metadata.append(None) + def state_dict(self) -> dict[str, Any]: + """Save the storage state.""" + return { + "_compressed_data": self._compressed_data, + "_metadata": self._metadata, + } + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Load the storage state.""" + self._compressed_data = state_dict["_compressed_data"] + self._metadata = state_dict["_metadata"] - # Store compressed data and metadata - 
self._compressed_data[idx] = compressed_data - self._metadata[idx] = metadata + def contains(self, item): + """Check if an item is in the storage.""" + if isinstance(item, int): + if item < 0: + item += len(self._compressed_data) + return ( + 0 <= item < len(self._compressed_data) + and self._compressed_data[item] is not None + ) + raise NotImplementedError(f"type {type(item)} is not supported yet.") def _compress_item(self, item: Any) -> tuple[torch.Tensor, dict]: """Compress a single item and return compressed data with metadata.""" @@ -1556,28 +1659,28 @@ def _compress_item(self, item: Any) -> tuple[torch.Tensor, dict]: return compressed, metadata - def get(self, index: int | Sequence[int] | slice) -> Any: - """Get data from the storage with decompression.""" - if isinstance(index, (INT_CLASSES, slice)): - indices = [index] - else: - indices = index + # def get(self, index: int | Sequence[int] | slice) -> Any: + # """Get data from the storage with decompression.""" + # if isinstance(index, (INT_CLASSES, slice)): + # indices = [index] + # else: + # indices = index - results = [] - for idx in indices: - if idx >= len(self._compressed_data) or self._compressed_data[idx] is None: - raise IndexError(f"Index {idx} out of bounds or not set") + # results = [] + # for idx in indices: + # if idx >= len(self._compressed_data) or self._compressed_data[idx] is None: + # raise IndexError(f"Index {idx} out of bounds or not set") - compressed_data = self._compressed_data[idx] - metadata = self._metadata[idx] + # compressed_data = self._compressed_data[idx] + # metadata = self._metadata[idx] - # Decompress the data - decompressed = self._decompress_item(compressed_data, metadata) - results.append(decompressed) + # # Decompress the data + # decompressed = self._decompress_item(compressed_data, metadata) + # results.append(decompressed) - if len(results) == 1: - return results[0] - return results + # if len(results) == 1: + # return results[0] + # return results def _decompress_item(self, compressed_data: Any, metadata: dict) -> Any: """Decompress a single item using its metadata.""" @@ -1601,49 +1704,49 @@ def _decompress_item(self, compressed_data: Any, metadata: dict) -> Any: # Return as-is for other types return metadata["value"] - def __len__(self): - """Return the number of items in the storage.""" - return len([item for item in self._compressed_data if item is not None]) - - def state_dict(self) -> dict[str, Any]: - """Save the storage state.""" - return { - "_compressed_data": self._compressed_data, - "_metadata": self._metadata, - } - - def load_state_dict(self, state_dict: dict[str, Any]) -> None: - """Load the storage state.""" - self._compressed_data = state_dict["_compressed_data"] - self._metadata = state_dict["_metadata"] - - def _empty(self): - """Empty the storage.""" - self._compressed_data = [] - self._metadata = [] - - def loads(self, path: str): - """Load the storage state from a file.""" - super().loads(path) - from tensordict.utils import _STR_DTYPE_TO_DTYPE - from torch.utils._pytree import tree_flatten, tree_unflatten - - leaves, spec = tree_flatten(self._metadata) - leaves = [ - _STR_DTYPE_TO_DTYPE.get(x, x) if isinstance(x, str) else x for x in leaves - ] - self._metadata = tree_unflatten(leaves, spec) - - def contains(self, item): - """Check if an item is in the storage.""" - if isinstance(item, int): - if item < 0: - item += len(self._compressed_data) - return ( - 0 <= item < len(self._compressed_data) - and self._compressed_data[item] is not None - ) - raise 
NotImplementedError(f"type {type(item)} is not supported yet.") + # def __len__(self): + # """Return the number of items in the storage.""" + # return len([item for item in self._compressed_data if item is not None]) + + # def state_dict(self) -> dict[str, Any]: + # """Save the storage state.""" + # return { + # "_compressed_data": self._compressed_data, + # "_metadata": self._metadata, + # } + + # def load_state_dict(self, state_dict: dict[str, Any]) -> None: + # """Load the storage state.""" + # self._compressed_data = state_dict["_compressed_data"] + # self._metadata = state_dict["_metadata"] + + # def _empty(self): + # """Empty the storage.""" + # self._compressed_data = [] + # self._metadata = [] + + # def loads(self, path: str): + # """Load the storage state from a file.""" + # super().loads(path) + # from tensordict.utils import _STR_DTYPE_TO_DTYPE + # from torch.utils._pytree import tree_flatten, tree_unflatten + + # leaves, spec = tree_flatten(self._metadata) + # leaves = [ + # _STR_DTYPE_TO_DTYPE.get(x, x) if isinstance(x, str) else x for x in leaves + # ] + # self._metadata = tree_unflatten(leaves, spec) + + # def contains(self, item): + # """Check if an item is in the storage.""" + # if isinstance(item, int): + # if item < 0: + # item += len(self._compressed_data) + # return ( + # 0 <= item < len(self._compressed_data) + # and self._compressed_data[item] is not None + # ) + # raise NotImplementedError(f"type {type(item)} is not supported yet.") def bytes(self): """Return the number of bytes in the storage.""" @@ -1706,7 +1809,7 @@ def __init__( self._transforms = transforms if transforms is not None and len(transforms) != len(storages): raise TypeError( - "transforms must have the same length as the storages " "provided." + "transforms must have the same length as the storages provided." ) @property @@ -1730,7 +1833,7 @@ def get(self, item): buffer_ids = item.get("buffer_ids") index = item.get("index") results = [] - for (buffer_id, sample) in zip(buffer_ids, index): + for buffer_id, sample in zip(buffer_ids, index): buffer_id = self._convert_id(buffer_id) results.append((buffer_id, self._get_storage(buffer_id).get(sample))) if self._transforms is not None: @@ -1880,7 +1983,7 @@ def _collate_id(x): def _get_default_collate(storage, _is_tensordict=False): if isinstance(storage, (LazyStackStorage, TensorStorage)): return _collate_id - elif isinstance(storage, CompressedStorage): + elif isinstance(storage, CompressedListStorage): return lazy_stack elif isinstance(storage, (ListStorage, StorageEnsemble)): return _stack_anything From 53bea2f6abda2a77ec271c817f811d3662f3e5c2 Mon Sep 17 00:00:00 2001 From: Adrian Orenstein Date: Sat, 12 Jul 2025 21:39:00 -0600 Subject: [PATCH 4/7] Refactor out the storage view. Expose functions in the ListStorage class. 
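The view class added earlier made `_storage` a property and proxied every access through a wrapper object; this commit drops that indirection and instead routes ListStorage's indexing through small overridable hooks (`_set_item`, `_set_slice`, `_get_item`, `_get_slice`, `_get_list`, `_contains_int`), so CompressedListStorage only overrides the hooks. A minimal sketch of the pattern under assumed names (`MiniListStorage`/`MiniCompressedStorage` are illustrative, not the torchrl classes):

import pickle
import zlib


class MiniListStorage:
    """Owns cursor and index handling; raw slot access goes through hooks."""

    def __init__(self):
        self._storage = []

    def set(self, cursor, data):
        self._set_item(cursor, data)

    def get(self, index):
        if isinstance(index, int):
            return self._get_item(index)
        return [self._get_item(i) for i in index]

    # --- hooks that subclasses may override ---
    def _set_item(self, cursor, data):
        if cursor == len(self._storage):
            self._storage.append(data)
        else:
            self._storage[cursor] = data

    def _get_item(self, index):
        return self._storage[index]


class MiniCompressedStorage(MiniListStorage):
    """Only the hooks change; set()/get() are inherited unchanged."""

    def _set_item(self, cursor, data):
        # compress on write
        super()._set_item(cursor, zlib.compress(pickle.dumps(data)))

    def _get_item(self, index):
        # decompress on read
        return pickle.loads(zlib.decompress(self._storage[index]))


storage = MiniCompressedStorage()
storage.set(0, [1, 2, 3])
assert storage.get(0) == [1, 2, 3]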
--- test/test_rb.py | 2 +- torchrl/data/replay_buffers/checkpointers.py | 11 +- torchrl/data/replay_buffers/storages.py | 339 +++++++------------ 3 files changed, 123 insertions(+), 229 deletions(-) diff --git a/test/test_rb.py b/test/test_rb.py index fbe3842b202..bf6c2cc9fbd 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -4322,7 +4322,7 @@ def test_compressed_storage_memory_efficiency(self): storage.set(0, large_tensor) # Estimate compressed size - compressed_data = storage._compressed_data[0] + compressed_data = storage._storage[0] compressed_size = compressed_data.numel() # uint8 bytes # Verify compression ratio is reasonable (at least 2x for random data) diff --git a/torchrl/data/replay_buffers/checkpointers.py b/torchrl/data/replay_buffers/checkpointers.py index 82869204ea6..71c525d7963 100644 --- a/torchrl/data/replay_buffers/checkpointers.py +++ b/torchrl/data/replay_buffers/checkpointers.py @@ -79,10 +79,7 @@ def dumps(self, storage, path): path = Path(path) path.mkdir(exist_ok=True) - if ( - not hasattr(storage, "_compressed_data") - or len(storage._compressed_data) == 0 - ): + if not hasattr(storage, "_storage") or len(storage._storage) == 0: raise RuntimeError( "Cannot save an empty or non-initialized CompressedListStorage." ) @@ -90,8 +87,8 @@ def dumps(self, storage, path): # Save compressed data and metadata state_dict = storage.state_dict() - # Save compressed data as separate files for efficiency - compressed_data = state_dict["_compressed_data"] + # Save compressed data and metadata + compressed_data = state_dict["_storage"] metadata = state_dict["_metadata"] # Save metadata @@ -203,7 +200,7 @@ def convert_dtype(item): compressed_data.append(None) # Load into storage - storage._compressed_data = compressed_data + storage._storage = compressed_data storage._metadata = metadata diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index e836df3a828..6e0a904460d 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -279,7 +279,7 @@ def set( return if isinstance(cursor, slice): data = self._to_device(data) - self._storage[cursor] = data + self._set_slice(cursor, data) return if isinstance( data, @@ -303,7 +303,7 @@ def set( ) return else: - if cursor > len(self._storage): + if cursor > len(self): raise RuntimeError( "Cannot append data located more than one item away from " f"the storage size: the storage size is {len(self)} " @@ -316,14 +316,24 @@ def set( f"and the index of the item to be set is {cursor}." 
) data = self._to_device(data) - if cursor == len(self._storage): - self._storage.append(data) - else: - self._storage[cursor] = data + self._set_item(cursor, data) + + def _set_item(self, cursor: int, data: Any) -> None: + """Set a single item in the storage.""" + if cursor == len(self._storage): + self._storage.append(data) + else: + self._storage[cursor] = data + + def _set_slice(self, cursor: slice, data: Any) -> None: + """Set a slice in the storage.""" + self._storage[cursor] = data def get(self, index: int | Sequence[int] | slice) -> Any: - if isinstance(index, (INT_CLASSES, slice)): - return self._storage[index] + if isinstance(index, INT_CLASSES): + return self._get_item(index) + elif isinstance(index, slice): + return self._get_slice(index) elif isinstance(index, tuple): if len(index) > 1: raise RuntimeError( @@ -333,9 +343,22 @@ def get(self, index: int | Sequence[int] | slice) -> Any: else: if isinstance(index, torch.Tensor) and index.device.type != "cpu": index = index.cpu().tolist() - return [self._storage[i] for i in index] + return self._get_list(index) + + def _get_item(self, index: int) -> Any: + """Get a single item from the storage.""" + return self._storage[index] + + def _get_slice(self, index: slice) -> Any: + """Get a slice from the storage.""" + return self._storage[index] + + def _get_list(self, index: list) -> list: + """Get a list of items from the storage.""" + return [self._storage[i] for i in index] def __len__(self): + """Get the length of the storage.""" return len(self._storage) def state_dict(self) -> dict[str, Any]: @@ -379,9 +402,8 @@ def __repr__(self): def contains(self, item): if isinstance(item, int): if item < 0: - item += len(self._storage) - - return 0 <= item < len(self._storage) + item += len(self) + return self._contains_int(item) if isinstance(item, torch.Tensor): return torch.tensor( [self.contains(elt) for elt in item.tolist()], @@ -390,6 +412,10 @@ def contains(self, item): ).reshape_as(item) raise NotImplementedError(f"type {type(item)} is not supported yet.") + def _contains_int(self, item: int) -> bool: + """Check if an integer index is contained in the storage.""" + return 0 <= item < len(self) + class LazyStackStorage(ListStorage): """A ListStorage that returns LazyStackTensorDict instances. @@ -1360,98 +1386,6 @@ def get(self, index: int | Sequence[int] | slice) -> Any: return result -class CompressedStorageView: - """A view that makes compressed storage look like a regular list to ListStorage methods.""" - - def __init__(self, compressed_data, metadata, parent_storage): - self._compressed_data = compressed_data - self._metadata = metadata - self._parent_storage = parent_storage - - def __len__(self): - return len([item for item in self._compressed_data if item is not None]) - - def __getitem__(self, index): - if isinstance(index, INT_CLASSES): - if ( - index >= len(self._compressed_data) - or self._compressed_data[index] is None - ): - raise IndexError(f"Index {index} out of bounds or not set") - return self._parent_storage._decompress_item( - self._compressed_data[index], self._metadata[index] - ) - elif isinstance(index, slice): - start, stop, step = index.indices(len(self._compressed_data)) - results = [] - for i in range(start, stop, step): - if ( - i < len(self._compressed_data) - and self._compressed_data[i] is not None - ): - results.append( - self._parent_storage._decompress_item( - self._compressed_data[i], self._metadata[i] - ) - ) - return results - else: - # Handle lists, tensors, etc. 
- if isinstance(index, torch.Tensor): - if index.ndim == 0: - # Handle 0-dimensional tensor (scalar) - return self[index.item()] - if index.device.type != "cpu": - index = index.cpu().tolist() - else: - index = index.tolist() - - results = [] - for i in index: - if i >= len(self._compressed_data) or self._compressed_data[i] is None: - raise IndexError(f"Index {i} out of bounds or not set") - results.append( - self._parent_storage._decompress_item( - self._compressed_data[i], self._metadata[i] - ) - ) - return results - - def __setitem__(self, index, value): - if isinstance(index, INT_CLASSES): - # Ensure we have enough space - while len(self._compressed_data) <= index: - self._compressed_data.append(None) - self._metadata.append(None) - - # Compress and store - compressed_data, metadata = self._parent_storage._compress_item(value) - self._compressed_data[index] = compressed_data - self._metadata[index] = metadata - elif isinstance(index, slice): - # Handle slice assignment - if not hasattr(value, "__iter__"): - value = [value] - start, stop, step = index.indices(len(self._compressed_data)) - indices = list(range(start, stop, step)) - - for i, v in zip(indices, value): - self[i] = v - else: - # Handle multiple indices - if isinstance(index, torch.Tensor) and index.device.type != "cpu": - index = index.cpu().tolist() - if not hasattr(value, "__iter__"): - value = [value] * len(index) - for i, v in zip(index, value): - self[i] = v - - def append(self, value): - # Find the next available slot - index = len(self._compressed_data) - self[index] = value - - class CompressedListStorage(ListStorage): """A storage that compresses and decompresses data. @@ -1523,7 +1457,7 @@ def __init__( self.decompression_fn = decompression_fn # Store compressed data and metadata - self._compressed_data = [] + self._storage = [] self._metadata = [] # Store shape, dtype, device info for each item def _default_compression_fn(self, tensor: torch.Tensor) -> torch.Tensor: @@ -1574,54 +1508,6 @@ def _default_decompression_fn( return tensor - @property - def _storage(self): - """Virtual storage that handles compression/decompression transparently.""" - return CompressedStorageView(self._compressed_data, self._metadata, self) - - @_storage.setter - def _storage(self, value): - # This allows ListStorage.__init__ to set _storage initially - if hasattr(self, "_compressed_data"): - # If we already have compressed storage, ignore attempts to set _storage - pass - else: - # During initialization, create our compressed storage - self._compressed_data = [] - self._metadata = [] - - def __len__(self): - """Return the number of items in the storage.""" - return len([item for item in self._compressed_data if item is not None]) - - def _empty(self): - """Empty the storage.""" - self._compressed_data = [] - self._metadata = [] - - def state_dict(self) -> dict[str, Any]: - """Save the storage state.""" - return { - "_compressed_data": self._compressed_data, - "_metadata": self._metadata, - } - - def load_state_dict(self, state_dict: dict[str, Any]) -> None: - """Load the storage state.""" - self._compressed_data = state_dict["_compressed_data"] - self._metadata = state_dict["_metadata"] - - def contains(self, item): - """Check if an item is in the storage.""" - if isinstance(item, int): - if item < 0: - item += len(self._compressed_data) - return ( - 0 <= item < len(self._compressed_data) - and self._compressed_data[item] is not None - ) - raise NotImplementedError(f"type {type(item)} is not supported yet.") - def _compress_item(self, 
item: Any) -> tuple[torch.Tensor, dict]: """Compress a single item and return compressed data with metadata.""" if isinstance(item, torch.Tensor): @@ -1659,29 +1545,6 @@ def _compress_item(self, item: Any) -> tuple[torch.Tensor, dict]: return compressed, metadata - # def get(self, index: int | Sequence[int] | slice) -> Any: - # """Get data from the storage with decompression.""" - # if isinstance(index, (INT_CLASSES, slice)): - # indices = [index] - # else: - # indices = index - - # results = [] - # for idx in indices: - # if idx >= len(self._compressed_data) or self._compressed_data[idx] is None: - # raise IndexError(f"Index {idx} out of bounds or not set") - - # compressed_data = self._compressed_data[idx] - # metadata = self._metadata[idx] - - # # Decompress the data - # decompressed = self._decompress_item(compressed_data, metadata) - # results.append(decompressed) - - # if len(results) == 1: - # return results[0] - # return results - def _decompress_item(self, compressed_data: Any, metadata: dict) -> Any: """Decompress a single item using its metadata.""" if metadata["type"] == "tensor": @@ -1704,49 +1567,83 @@ def _decompress_item(self, compressed_data: Any, metadata: dict) -> Any: # Return as-is for other types return metadata["value"] - # def __len__(self): - # """Return the number of items in the storage.""" - # return len([item for item in self._compressed_data if item is not None]) - - # def state_dict(self) -> dict[str, Any]: - # """Save the storage state.""" - # return { - # "_compressed_data": self._compressed_data, - # "_metadata": self._metadata, - # } - - # def load_state_dict(self, state_dict: dict[str, Any]) -> None: - # """Load the storage state.""" - # self._compressed_data = state_dict["_compressed_data"] - # self._metadata = state_dict["_metadata"] - - # def _empty(self): - # """Empty the storage.""" - # self._compressed_data = [] - # self._metadata = [] - - # def loads(self, path: str): - # """Load the storage state from a file.""" - # super().loads(path) - # from tensordict.utils import _STR_DTYPE_TO_DTYPE - # from torch.utils._pytree import tree_flatten, tree_unflatten - - # leaves, spec = tree_flatten(self._metadata) - # leaves = [ - # _STR_DTYPE_TO_DTYPE.get(x, x) if isinstance(x, str) else x for x in leaves - # ] - # self._metadata = tree_unflatten(leaves, spec) - - # def contains(self, item): - # """Check if an item is in the storage.""" - # if isinstance(item, int): - # if item < 0: - # item += len(self._compressed_data) - # return ( - # 0 <= item < len(self._compressed_data) - # and self._compressed_data[item] is not None - # ) - # raise NotImplementedError(f"type {type(item)} is not supported yet.") + def _set_item(self, cursor: int, data: Any) -> None: + """Set a single item in the compressed storage.""" + # Ensure we have enough space + while len(self._storage) <= cursor: + self._storage.append(None) + self._metadata.append(None) + + # Compress and store + compressed_data, metadata = self._compress_item(data) + self._storage[cursor] = compressed_data + self._metadata[cursor] = metadata + + def _set_slice(self, cursor: slice, data: Any) -> None: + """Set a slice in the compressed storage.""" + # Handle slice assignment + if not hasattr(data, "__iter__"): + data = [data] + start, stop, step = cursor.indices(len(self._storage)) + indices = list(range(start, stop, step)) + + for i, value in zip(indices, data): + self._set_item(i, value) + + def _get_item(self, index: int) -> Any: + """Get a single item from the compressed storage.""" + if index >= 
len(self._storage) or self._storage[index] is None: + raise IndexError(f"Index {index} out of bounds or not set") + + compressed_data = self._storage[index] + metadata = self._metadata[index] + return self._decompress_item(compressed_data, metadata) + + def _get_slice(self, index: slice) -> list: + """Get a slice from the compressed storage.""" + start, stop, step = index.indices(len(self._storage)) + results = [] + for i in range(start, stop, step): + if i < len(self._storage) and self._storage[i] is not None: + results.append(self._get_item(i)) + return results + + def _get_list(self, index: list) -> list: + """Get a list of items from the compressed storage.""" + if isinstance(index, torch.Tensor) and index.device.type != "cpu": + index = index.cpu().tolist() + + results = [] + for i in index: + if i >= len(self._storage) or self._storage[i] is None: + raise IndexError(f"Index {i} out of bounds or not set") + results.append(self._get_item(i)) + return results + + def __len__(self) -> int: + """Get the length of the compressed storage.""" + return len([item for item in self._storage if item is not None]) + + def _contains_int(self, item: int) -> bool: + """Check if an integer index is contained in the compressed storage.""" + return 0 <= item < len(self._storage) and self._storage[item] is not None + + def _empty(self): + """Empty the storage.""" + self._storage = [] + self._metadata = [] + + def state_dict(self) -> dict[str, Any]: + """Save the storage state.""" + return { + "_storage": self._storage, + "_metadata": self._metadata, + } + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Load the storage state.""" + self._storage = state_dict["_storage"] + self._metadata = state_dict["_metadata"] def bytes(self): """Return the number of bytes in the storage.""" @@ -1763,9 +1660,9 @@ def compressed_size_from_list(data: Any) -> int: else: return 0 - compressed_size_estimate = compressed_size_from_list(self._compressed_data) + compressed_size_estimate = compressed_size_from_list(self._storage) if compressed_size_estimate == 0: - if len(self._compressed_data) > 0: + if len(self._storage) > 0: raise RuntimeError( "Compressed storage is not empty but the compressed size is 0. This is a bug." ) From 16aba7f6244adfeafb3283355f4f022b5a49d4a4 Mon Sep 17 00:00:00 2001 From: Adrian Orenstein Date: Sun, 13 Jul 2025 15:51:47 -0600 Subject: [PATCH 5/7] Using python's default compressor. Created to_bytestream. Created a to_bytestream speed test. 
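In short: the optional `zstandard` dependency is dropped in favour of the standard library. `compression.zstd` only exists from Python 3.14 (PEP 784), so older interpreters fall back to `zlib` (levels 0-9 rather than zstd's 1-22), and tensors are serialized to bytes via numpy, the fastest method in the benchmark added below. A rough standalone sketch of the round-trip this enables (not the torchrl methods themselves):

import sys

if sys.version_info >= (3, 14):
    from compression import zstd as codec  # stdlib zstd, PEP 784
else:
    import zlib as codec  # fallback; keep level within 0-9 here

import torch

frame = torch.zeros(4, 84, 84, dtype=torch.uint8)
raw = frame.numpy().tobytes()        # tensor -> byte stream
blob = codec.compress(raw, level=3)  # compress
restored = torch.frombuffer(         # decompress -> tensor
    bytearray(codec.decompress(blob)), dtype=frame.dtype
).reshape(frame.shape)
assert torch.equal(frame, restored)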
--- .../unittest/linux/scripts/environment.yml | 2 +- test/test_rb.py | 143 ++++++++++ torchrl/data/replay_buffers/checkpointers.py | 265 +++++++++++------- torchrl/data/replay_buffers/storages.py | 94 +++++-- 4 files changed, 381 insertions(+), 123 deletions(-) diff --git a/.github/unittest/linux/scripts/environment.yml b/.github/unittest/linux/scripts/environment.yml index b7ca29ff0d4..d1debbdd510 100644 --- a/.github/unittest/linux/scripts/environment.yml +++ b/.github/unittest/linux/scripts/environment.yml @@ -35,4 +35,4 @@ dependencies: - transformers - ninja - timm - - zstandard + - safetensors diff --git a/test/test_rb.py b/test/test_rb.py index bf6c2cc9fbd..ad6d0411923 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -4210,6 +4210,18 @@ def test_compressed_storage_checkpointing(self): ) storage.set(0, test_td) + # second batch, different shape + test_td2 = TensorDict( + { + "obs": torch.randn(3, 85, 83, dtype=torch.float32), + "action": torch.tensor([1, 2, 3]), + "meta": torch.randn(3), + "astring": "a string!", + }, + batch_size=[3], + ) + storage.set(1, test_td) + # Create temporary directory for checkpointing with tempfile.TemporaryDirectory() as tmpdir: checkpoint_path = Path(tmpdir) / "checkpoint" @@ -4331,6 +4343,137 @@ def test_compressed_storage_memory_efficiency(self): compression_ratio > 1.5 ), f"Compression ratio {compression_ratio} is too low" + @staticmethod + def make_compressible_mock_data(num_experiences: int, device=None) -> dict: + """Easily compressible data for testing.""" + if device is None: + device = torch.device("cpu") + + return { + "observations": torch.zeros( + (num_experiences, 4, 84, 84), + dtype=torch.uint8, + device=device, + ), + "actions": torch.zeros((num_experiences,), device=device), + "rewards": torch.zeros((num_experiences,), device=device), + "next_observations": torch.zeros( + (num_experiences, 4, 84, 84), + dtype=torch.uint8, + device=device, + ), + "terminations": torch.zeros( + (num_experiences,), dtype=torch.bool, device=device + ), + "truncations": torch.zeros( + (num_experiences,), dtype=torch.bool, device=device + ), + "batch_size": [num_experiences], + } + + @staticmethod + def make_uncompressible_mock_data(num_experiences: int, device=None) -> dict: + """Uncompressible data for testing.""" + if device is None: + device = torch.device("cpu") + return { + "observations": torch.randn( + (num_experiences, 4, 84, 84), + dtype=torch.float32, + device=device, + ), + "actions": torch.randint(0, 10, (num_experiences,), device=device), + "rewards": torch.randn( + (num_experiences,), dtype=torch.float32, device=device + ), + "next_observations": torch.randn( + (num_experiences, 4, 84, 84), + dtype=torch.float32, + device=device, + ), + "terminations": torch.rand((num_experiences,), device=device) + < 0.2, # ~20% True + "truncations": torch.rand((num_experiences,), device=device) + < 0.1, # ~10% True + "batch_size": [num_experiences], + } + + @pytest.mark.benchmark( + group="tensor_serialization_speed", + min_time=0.1, + max_time=0.5, + min_rounds=5, + disable_gc=True, + warmup=False, + ) + @pytest.mark.parametrize( + "serialization_method", + ["pickle", "torch.save", "untyped_storage", "numpy", "safetensors"], + ) + def test_tensor_to_bytestream_speed(self, benchmark, serialization_method: str): + """Benchmark the speed of different tensor serialization methods. + + TODO: we might need to also test which methods work on the gpu. 
+ pytest test/test_rb.py::TestCompressedListStorage::test_tensor_to_bytestream_speed -v --benchmark-only --benchmark-sort='mean' --benchmark-columns='mean, ops' + + ------------------------ benchmark 'tensor_to_bytestream_speed': 5 tests ------------------------- + Name (time in us) Mean (smaller is better) OPS (bigger is better) + -------------------------------------------------------------------------------------------------- + test_tensor_serialization_speed[numpy] 2.3520 (1.0) 425,162.1779 (1.0) + test_tensor_serialization_speed[safetensors] 14.7170 (6.26) 67,948.7129 (0.16) + test_tensor_serialization_speed[pickle] 19.0711 (8.11) 52,435.3333 (0.12) + test_tensor_serialization_speed[torch.save] 32.0648 (13.63) 31,186.8261 (0.07) + test_tensor_serialization_speed[untyped_storage] 38,227.0224 (>1000.0) 26.1595 (0.00) + -------------------------------------------------------------------------------------------------- + """ + import io + import pickle + + import torch + from safetensors.torch import save + + def serialize_with_pickle(data: torch.Tensor) -> bytes: + """Serialize tensor using pickle.""" + buffer = io.BytesIO() + pickle.dump(data, buffer) + return buffer.getvalue() + + def serialize_with_untyped_storage(data: torch.Tensor) -> bytes: + """Serialize tensor using torch's built-in method.""" + return bytes(data.untyped_storage()) + + def serialize_with_numpy(data: torch.Tensor) -> bytes: + """Serialize tensor using numpy.""" + return data.numpy().tobytes() + + def serialize_with_safetensors(data: torch.Tensor) -> bytes: + return save({"0": data}) + + def serialize_with_torch(data: torch.Tensor) -> bytes: + """Serialize tensor using torch's built-in method.""" + buffer = io.BytesIO() + torch.save(data, buffer) + return buffer.getvalue() + + # Benchmark each serialization method + if serialization_method == "pickle": + serialize_fn = serialize_with_pickle + elif serialization_method == "torch.save": + serialize_fn = serialize_with_torch + elif serialization_method == "untyped_storage": + serialize_fn = serialize_with_untyped_storage + elif serialization_method == "numpy": + serialize_fn = serialize_with_numpy + elif serialization_method == "safetensors": + serialize_fn = serialize_with_safetensors + else: + raise ValueError(f"Unknown serialization method: {serialization_method}") + + data = self.make_compressible_mock_data(1).get("observations") + + # Run the actual benchmark + benchmark(serialize_fn, data) + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/data/replay_buffers/checkpointers.py b/torchrl/data/replay_buffers/checkpointers.py index 71c525d7963..267ab362e79 100644 --- a/torchrl/data/replay_buffers/checkpointers.py +++ b/torchrl/data/replay_buffers/checkpointers.py @@ -6,6 +6,7 @@ import abc import json +import tempfile import warnings from pathlib import Path @@ -13,11 +14,14 @@ import torch from tensordict import ( is_tensor_collection, + lazy_stack, NonTensorData, PersistentTensorDict, TensorDict, ) from tensordict.memmap import MemoryMappedTensor +from tensordict.utils import _zip_strict +from torch.utils._pytree import tree_map from torchrl._utils import _STRDTYPE2DTYPE from torchrl.data.replay_buffers.utils import ( @@ -71,11 +75,18 @@ def loads(storage, path): class CompressedListStorageCheckpointer(StorageCheckpointerBase): """A storage checkpointer for CompressedListStorage. - This checkpointer saves compressed data and metadata separately for efficient storage. 
+ This checkpointer saves compressed data and metadata using memory-mapped storage + for efficient disk I/O and memory usage. """ def dumps(self, storage, path): + """Save compressed storage to disk using memory-mapped storage. + + Args: + storage: The CompressedListStorage instance to save + path: Directory path where to save the storage + """ path = Path(path) path.mkdir(exist_ok=True) @@ -84,120 +95,182 @@ def dumps(self, storage, path): "Cannot save an empty or non-initialized CompressedListStorage." ) - # Save compressed data and metadata + # Get state dict from storage state_dict = storage.state_dict() - - # Save compressed data and metadata compressed_data = state_dict["_storage"] metadata = state_dict["_metadata"] - # Save metadata - with open(path / "compressed_metadata.json", "w") as f: - json.dump(metadata, f, default=str) - - # Save compressed data - for i, (compressed_item, item_metadata) in enumerate( - zip(compressed_data, metadata) - ): - if compressed_item is not None: - if item_metadata["type"] == "tensor": - # Save as numpy array - np.save( - path / f"compressed_data_{i}.npy", compressed_item.cpu().numpy() - ) - elif item_metadata["type"] == "tensordict": - # Save each field separately - item_dir = path / f"compressed_data_{i}.td" - item_dir.mkdir(exist_ok=True) - - for key, value in compressed_item.items(): - if isinstance(value, torch.Tensor): - np.save(item_dir / f"{key}.npy", value.cpu().numpy()) - else: - # Save non-tensor data as pickle - import pickle - - with open(item_dir / f"{key}.pkl", "wb") as f: - pickle.dump(value, f) + # Create a temporary directory for processing + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Process compressed data for memmap storage + processed_data = [] + for item in compressed_data: + if item is None: + processed_data.append(None) + continue + + if isinstance(item, torch.Tensor): + # For tensor data, create a TensorDict with the tensor + processed_item = TensorDict({"data": item}, batch_size=[]) + elif isinstance(item, dict): + # For dict data (tensordict fields), convert to TensorDict + processed_item = TensorDict(item, batch_size=[]) else: - # Save other types as pickle - import pickle + # For other types, wrap in TensorDict + processed_item = TensorDict({"data": item}, batch_size=[]) + + processed_data.append(processed_item) + + # Stack all non-None items into a single TensorDict for memmap + non_none_data = [item for item in processed_data if item is not None] + if non_none_data: + # Use lazy_stack to handle heterogeneous structures + stacked_data = lazy_stack(non_none_data) + + # Save to memmap + stacked_data.memmap_(tmp_path / "compressed_data") + + # Create index mapping for None values + data_indices = [] + current_idx = 0 + for item in processed_data: + if item is None: + data_indices.append(None) + else: + data_indices.append(current_idx) + current_idx += 1 + else: + # No data to save + data_indices = [] + + # Process metadata for JSON serialization + def is_leaf(item): + return isinstance( + item, + ( + torch.Size, + torch.dtype, + torch.device, + str, + int, + float, + bool, + torch.Tensor, + NonTensorData, + ), + ) + + def map_to_json_serializable(item): + if isinstance(item, torch.Size): + return {"__type__": "torch.Size", "value": list(item)} + elif isinstance(item, torch.dtype): + return {"__type__": "torch.dtype", "value": str(item)} + elif isinstance(item, torch.device): + return {"__type__": "torch.device", "value": str(item)} + elif isinstance(item, torch.Tensor): + return 
{"__type__": "torch.Tensor", "value": item.tolist()} + elif isinstance(item, NonTensorData): + return {"__type__": "NonTensorData", "value": item.data} + return item + + serializable_metadata = tree_map( + map_to_json_serializable, metadata, is_leaf=is_leaf + ) + + # Save metadata and indices + metadata_file = tmp_path / "metadata.json" + with open(metadata_file, "w") as f: + json.dump(serializable_metadata, f, indent=2) + + indices_file = tmp_path / "data_indices.json" + with open(indices_file, "w") as f: + json.dump(data_indices, f, indent=2) - with open(path / f"compressed_data_{i}.pkl", "wb") as f: - pickle.dump(compressed_item, f) + # Copy all files from temp directory to final destination + import shutil + + for item in tmp_path.iterdir(): + if item.is_file(): + shutil.copy2(item, path / item.name) + elif item.is_dir(): + shutil.copytree(item, path / item.name, dirs_exist_ok=True) def loads(self, storage, path): + """Load compressed storage from disk. + + Args: + storage: The CompressedListStorage instance to load into + path: Directory path where the storage was saved + """ path = Path(path) # Load metadata - with open(path / "compressed_metadata.json") as f: - metadata = json.load(f) - - # Convert string dtypes back to torch.dtype objects - def convert_dtype(item): - if isinstance(item, dict): - if "dtype" in item and isinstance(item["dtype"], str): - # Convert string back to torch.dtype - dtype_str = item["dtype"] + metadata_file = path / "metadata.json" + if not metadata_file.exists(): + raise FileNotFoundError(f"Metadata file not found at {metadata_file}") + + with open(metadata_file) as f: + serializable_metadata = json.load(f) + + # Load data indices + indices_file = path / "data_indices.json" + if not indices_file.exists(): + raise FileNotFoundError(f"Data indices file not found at {indices_file}") + + with open(indices_file) as f: + data_indices = json.load(f) + + # Convert serializable metadata back to original format + def is_leaf(item): + return isinstance(item, dict) and "__type__" in item + + def map_from_json_serializable(item): + if isinstance(item, dict) and "__type__" in item: + if item["__type__"] == "torch.Size": + return torch.Size(item["value"]) + elif item["__type__"] == "torch.dtype": + # Handle torch.dtype conversion + dtype_str = item["value"] if hasattr(torch, dtype_str.replace("torch.", "")): - item["dtype"] = getattr(torch, dtype_str.replace("torch.", "")) + return getattr(torch, dtype_str.replace("torch.", "")) else: # Handle cases like 'torch.float32' -> torch.float32 - item["dtype"] = eval(dtype_str) - - # Recursively handle nested dictionaries - for _key, value in item.items(): - if isinstance(value, dict): - convert_dtype(value) + return eval(dtype_str) + elif item["__type__"] == "torch.device": + return torch.device(item["value"]) + elif item["__type__"] == "torch.Tensor": + return torch.tensor(item["value"]) + elif item["__type__"] == "NonTensorData": + return NonTensorData(item["value"]) return item - for item in metadata: - if item is not None: - convert_dtype(item) + metadata = tree_map( + map_from_json_serializable, serializable_metadata, is_leaf=is_leaf + ) - # Load compressed data + # Load compressed data from memmap compressed_data = [] - i = 0 - - # TODO(adrian): Can we not know the serialised format beforehand? 
Then we can use glob to iterate through the files we know exist: - # `for path in glob.glob(path / f"compressed_data_*.{fmt}" for fmt in ["npy", "pkl", "td"]):`` - while True: - if (path / f"compressed_data_{i}.npy").exists(): - # Load tensor data - data = np.load(path / f"compressed_data_{i}.npy") - compressed_data.append(torch.from_numpy(data)) - elif (path / f"compressed_data_{i}.pkl").exists(): - # Load other data - import pickle - - with open(path / f"compressed_data_{i}.pkl", "rb") as f: - data = pickle.load(f) - compressed_data.append(data) - elif (path / f"compressed_data_{i}.td").exists(): - # Load tensordict data - item_dir = path / f"compressed_data_{i}.td" - item_data = {} - - for key in metadata[i]["fields"].keys(): - if (item_dir / f"{key}.npy").exists(): - data = np.load(item_dir / f"{key}.npy") - item_data[key] = torch.from_numpy(data) - elif (item_dir / f"{key}.pkl").exists(): - import pickle - - with open(item_dir / f"{key}.pkl", "rb") as f: - data = pickle.load(f) - item_data[key] = data - - compressed_data.append(item_data) - else: - break - - i += 1 + memmap_path = path / "compressed_data" + + if memmap_path.exists(): + # Load the memmapped data + stacked_data = TensorDict.load_memmap(memmap_path) + compressed_data = stacked_data.tolist() + if len(compressed_data) != len(data_indices): + raise ValueError( + f"Length of compressed data ({len(compressed_data)}) does not match length of data indices ({len(data_indices)})" + ) + for i, (data, mtdt) in enumerate(_zip_strict(compressed_data, metadata)): + if mtdt["type"] == "tensor": + compressed_data[i] = data["data"] + else: + compressed_data[i] = data - # Pad with None to match metadata length - while len(compressed_data) < len(metadata): - compressed_data.append(None) + else: + # No data to load + compressed_data = [None] * len(data_indices) # Load into storage storage._storage = compressed_data diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 6e0a904460d..def5c64bd3e 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -7,6 +7,7 @@ import abc import logging import os +import sys import textwrap import warnings from collections import OrderedDict @@ -303,10 +304,10 @@ def set( ) return else: - if cursor > len(self): + if cursor > len(self._storage): raise RuntimeError( "Cannot append data located more than one item away from " - f"the storage size: the storage size is {len(self)} " + f"the storage size: the storage size is {len(self._storage)} " f"and the index of the item to be set is {cursor}." ) if cursor >= self.max_size: @@ -402,7 +403,7 @@ def __repr__(self): def contains(self, item): if isinstance(item, int): if item < 0: - item += len(self) + item += len(self._storage) return self._contains_int(item) if isinstance(item, torch.Tensor): return torch.tensor( @@ -414,7 +415,7 @@ def contains(self, item): def _contains_int(self, item: int) -> bool: """Check if an integer index is contained in the storage.""" - return 0 <= item < len(self) + return 0 <= item < len(self._storage) class LazyStackStorage(ListStorage): @@ -1462,42 +1463,44 @@ def __init__( def _default_compression_fn(self, tensor: torch.Tensor) -> torch.Tensor: """Default compression using zstd.""" - try: - import zstandard as zstd - except ImportError: - raise ImportError( - "zstandard is required for default compression. 
" - "Install with: pip install zstandard" - ) + if sys.version_info >= (3, 14): + from compression import zstd + + compressor_fn = zstd.compress + + else: + import zlib + + compressor_fn = zlib.compress # Convert tensor to bytes - tensor_bytes = tensor.cpu().numpy().tobytes() + tensor_bytes = self.to_bytestream(tensor) # Compress with zstd - compressor = zstd.ZstdCompressor(level=self.compression_level) - compressed_bytes = compressor.compress(tensor_bytes) + compressed_bytes = compressor_fn(tensor_bytes, level=self.compression_level) # Convert to tensor - return torch.tensor(list(compressed_bytes), dtype=torch.uint8) + return torch.frombuffer(bytearray(compressed_bytes), dtype=torch.uint8) def _default_decompression_fn( self, compressed_tensor: torch.Tensor, metadata: dict ) -> torch.Tensor: """Default decompression using zstd.""" - try: - import zstandard as zstd - except ImportError: - raise ImportError( - "zstandard is required for default decompression. " - "Install with: pip install zstandard" - ) + if sys.version_info >= (3, 14): + from compression import zstd + + decompressor_fn = zstd.decompress + + else: + import zlib + + decompressor_fn = zlib.decompress # Convert tensor to bytes - compressed_bytes = bytes(compressed_tensor.cpu().numpy()) + compressed_bytes = self.to_bytestream(compressed_tensor.cpu()) # Decompress with zstd - decompressor = zstd.ZstdDecompressor() - decompressed_bytes = decompressor.decompress(compressed_bytes) + decompressed_bytes = decompressor_fn(compressed_bytes) # Convert back to tensor tensor = torch.frombuffer( @@ -1645,7 +1648,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: self._storage = state_dict["_storage"] self._metadata = state_dict["_metadata"] - def bytes(self): + def num_bytes(self): """Return the number of bytes in the storage.""" def compressed_size_from_list(data: Any) -> int: @@ -1670,6 +1673,45 @@ def compressed_size_from_list(data: Any) -> int: return compressed_size_estimate + def to_bytestream(self, data_to_bytestream: torch.Tensor | np.array | Any) -> bytes: + """Convert data to a byte stream.""" + if isinstance(data_to_bytestream, torch.Tensor): + byte_stream = data_to_bytestream.cpu().numpy().tobytes() + + elif isinstance(data_to_bytestream, np.array): + byte_stream = bytes(data_to_bytestream.tobytes()) + + else: + import io + import pickle + + buffer = io.BytesIO() + pickle.dump(data_to_bytestream, buffer) + buffer.seek(0) + byte_stream = bytes(buffer.read()) + + return byte_stream + + # def to_bytestream( + # self, data_to_bytestream: Union[torch.Tensor, np.array, Any] + # ) -> bytes: + # """Convert data to a byte stream.""" + # if isinstance(data_to_bytestream, torch.Tensor): + # from safetensors.torch import save + # byte_stream = save({"0": data_to_bytestream}) + + # elif isinstance(data_to_bytestream, np.array): + # from safetensors.numpy import save + # byte_stream = bytes(data_to_bytestream.tobytes()) + + # else: + # buffer = io.BytesIO() + # pickle.dump(data_to_bytestream, buffer) + # buffer.seek(0) + # byte_stream = bytes(buffer.read()) + + # return byte_stream + class StorageEnsemble(Storage): """An ensemble of storages. 
From eade378b9f2c8da40fc61d3a5af0bc84fad0ebb7 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 16 Jul 2025 23:47:11 +0100 Subject: [PATCH 6/7] add-examples --- .../compressed_replay_buffer.py | 143 ++++++++++++++++ .../compressed_replay_buffer_checkpoint.py | 156 ++++++++++++++++++ 2 files changed, 299 insertions(+) create mode 100644 examples/replay-buffers/compressed_replay_buffer.py create mode 100644 examples/replay-buffers/compressed_replay_buffer_checkpoint.py diff --git a/examples/replay-buffers/compressed_replay_buffer.py b/examples/replay-buffers/compressed_replay_buffer.py new file mode 100644 index 00000000000..3e68e76468a --- /dev/null +++ b/examples/replay-buffers/compressed_replay_buffer.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 +""" +Example demonstrating the use of CompressedStorage for memory-efficient replay buffers. + +This example shows how to use the new CompressedStorage to store image observations +with significant memory savings through compression. +""" + +import time + +import torch +from tensordict import TensorDict +from torchrl.data import CompressedStorage, ReplayBuffer + + +def main(): + print("=== Compressed Replay Buffer Example ===\n") + + # Create a compressed storage with zstd compression + print("Creating compressed storage...") + storage = CompressedStorage( + max_size=1000, + compression_level=3, # zstd compression level (1-22) + device="cpu", + ) + + # Create replay buffer with compressed storage + rb = ReplayBuffer(storage=storage, batch_size=32) + + # Simulate Atari-like image data (84x84 RGB frames) + print("Generating sample image data...") + num_frames = 100 + image_data = torch.zeros(num_frames, 3, 84, 84, dtype=torch.float32) + image_data.copy_( + torch.arange(num_frames * 3 * 84 * 84).reshape(num_frames, 3, 84, 84) + // (3 * 84 * 84) + ) + + # Create TensorDict with image observations + data = TensorDict( + { + "obs": image_data, + "action": torch.randint(0, 4, (num_frames,)), # 4 possible actions + "reward": torch.randn(num_frames), + "done": torch.randint(0, 2, (num_frames,), dtype=torch.bool), + }, + batch_size=[num_frames], + ) + + # Measure memory usage before adding data + print(f"Original data size: {data.bytes() / 1024 / 1024:.2f} MB") + + # Add data to replay buffer + print("Adding data to replay buffer...") + start_time = time.time() + rb.extend(data) + add_time = time.time() - start_time + print(f"Time to add data: {add_time:.3f} seconds") + + # Sample from replay buffer + print("Sampling from replay buffer...") + start_time = time.time() + sample = rb.sample(32) + sample_time = time.time() - start_time + print(f"Time to sample: {sample_time:.3f} seconds") + + # Verify data integrity + print("\nVerifying data integrity...") + original_shape = image_data.shape + sampled_shape = sample["obs"].shape + print(f"Original shape: {original_shape}") + print(f"Sampled shape: {sampled_shape}") + + # Check that shapes match (accounting for batch size) + assert sampled_shape[1:] == original_shape[1:], "Shape mismatch!" 
+ print("āœ“ Data integrity verified!") + + # Demonstrate compression ratio + print("\n=== Compression Analysis ===") + + # Estimate compressed size (this is approximate) + compressed_size_estimate = storage.bytes() + + original_size = data.bytes() + compression_ratio = ( + original_size / compressed_size_estimate if compressed_size_estimate > 0 else 1 + ) + + print(f"Original size: {original_size / 1024 / 1024:.2f} MB") + print( + f"Compressed size (estimate): {compressed_size_estimate / 1024 / 1024:.2f} MB" + ) + print(f"Compression ratio: {compression_ratio:.1f}x") + + # Test with different compression levels + print("\n=== Testing Different Compression Levels ===") + + for level in [1, 3, 6, 9]: + print(f"\nTesting compression level {level}...") + + # Create new storage with different compression level + test_storage = CompressedStorage( + max_size=100, compression_level=level, device="cpu" + ) + + # Test with a smaller dataset + N = 100 + obs = torch.zeros(N, 3, 84, 84, dtype=torch.float32) + obs.copy_(torch.arange(N * 3 * 84 * 84).reshape(N, 3, 84, 84) // (3 * 84 * 84)) + test_data = TensorDict( + { + "obs": obs, + }, + batch_size=[N], + ) + + test_rb = ReplayBuffer(storage=test_storage, batch_size=5) + + # Measure compression time + start_time = time.time() + test_rb.extend(test_data) + compress_time = time.time() - start_time + + # Measure decompression time + start_time = time.time() + test_rb.sample(5) + decompress_time = time.time() - start_time + + print(f" Compression time: {compress_time:.3f}s") + print(f" Decompression time: {decompress_time:.3f}s") + + # Estimate compression ratio + test_ratio = test_data.bytes() / test_storage.bytes() + print(f" Compression ratio: {test_ratio:.1f}x") + + print("\n=== Example Complete ===") + print( + "The CompressedStorage successfully reduces memory usage while maintaining data integrity!" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/replay-buffers/compressed_replay_buffer_checkpoint.py b/examples/replay-buffers/compressed_replay_buffer_checkpoint.py new file mode 100644 index 00000000000..9242c2ad4cd --- /dev/null +++ b/examples/replay-buffers/compressed_replay_buffer_checkpoint.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 +""" +Example demonstrating the improved CompressedListStorage with memmap functionality. + +This example shows how to use the new checkpointing capabilities that leverage +memory-mapped storage for efficient disk I/O and memory usage. 
+""" + +import tempfile +from pathlib import Path + +import torch +from tensordict import TensorDict + +from torchrl.data import CompressedListStorage, ReplayBuffer + + +def main(): + """Demonstrate compressed storage with memmap checkpointing.""" + + # Create a compressed storage with high compression level + storage = CompressedListStorage(max_size=1000, compression_level=6) + + # Create some sample data with different shapes and types + print("Creating sample data...") + + # Add tensor data + tensor_data = torch.randn(10, 3, 84, 84, dtype=torch.float32) # Image-like data + storage.set(0, tensor_data) + + # Add TensorDict data with mixed content + td_data = TensorDict( + { + "obs": torch.randn(5, 4, 84, 84, dtype=torch.float32), + "action": torch.randint(0, 18, (5,), dtype=torch.long), + "reward": torch.randn(5, dtype=torch.float32), + "done": torch.zeros(5, dtype=torch.bool), + "meta": "some metadata string", + }, + batch_size=[5], + ) + storage.set(1, td_data) + + # Add another tensor with different shape + tensor_data2 = torch.randn(8, 64, dtype=torch.float32) + storage.set(2, tensor_data2) + + print(f"Storage length: {len(storage)}") + print(f"Storage contains index 0: {storage.contains(0)}") + print(f"Storage contains index 3: {storage.contains(3)}") + + # Demonstrate data retrieval + print("\nRetrieving data...") + retrieved_tensor = storage.get(0) + retrieved_td = storage.get(1) + retrieved_tensor2 = storage.get(2) + + print(f"Retrieved tensor shape: {retrieved_tensor.shape}") + print(f"Retrieved TensorDict keys: {list(retrieved_td.keys())}") + print(f"Retrieved tensor2 shape: {retrieved_tensor2.shape}") + + # Verify data integrity + assert torch.allclose(tensor_data, retrieved_tensor, atol=1e-6) + assert torch.allclose(td_data["obs"], retrieved_td["obs"], atol=1e-6) + assert torch.allclose(tensor_data2, retrieved_tensor2, atol=1e-6) + print("āœ“ Data integrity verified!") + + # Demonstrate memmap checkpointing + print("\nDemonstrating memmap checkpointing...") + + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_path = Path(tmpdir) / "compressed_storage_checkpoint" + + # Save to disk using memmap + print(f"Saving to {checkpoint_path}...") + storage.dumps(checkpoint_path) + + # Check what files were created + print("Files created:") + for file_path in checkpoint_path.rglob("*"): + if file_path.is_file(): + size_mb = file_path.stat().st_size / (1024 * 1024) + print(f" {file_path.relative_to(checkpoint_path)}: {size_mb:.2f} MB") + + # Create new storage and load from checkpoint + print("\nLoading from checkpoint...") + new_storage = CompressedListStorage(max_size=1000, compression_level=6) + new_storage.loads(checkpoint_path) + + # Verify data integrity after checkpointing + print("Verifying data integrity after checkpointing...") + new_retrieved_tensor = new_storage.get(0) + new_retrieved_td = new_storage.get(1) + new_retrieved_tensor2 = new_storage.get(2) + + assert torch.allclose(tensor_data, new_retrieved_tensor, atol=1e-6) + assert torch.allclose(td_data["obs"], new_retrieved_td["obs"], atol=1e-6) + assert torch.allclose(tensor_data2, new_retrieved_tensor2, atol=1e-6) + print("āœ“ Data integrity after checkpointing verified!") + + print(f"New storage length: {len(new_storage)}") + + # Demonstrate with ReplayBuffer + print("\nDemonstrating with ReplayBuffer...") + + rb = ReplayBuffer(storage=CompressedListStorage(max_size=100, compression_level=4)) + + # Add some data to the replay buffer + for _ in range(5): + data = TensorDict( + { + "obs": torch.randn(3, 84, 84, 
dtype=torch.float32), + "action": torch.randint(0, 18, (3,), dtype=torch.long), + "reward": torch.randn(3, dtype=torch.float32), + }, + batch_size=[3], + ) + rb.extend(data) + + print(f"ReplayBuffer size: {len(rb)}") + + # Sample from the buffer + sample = rb.sample(2) + print(f"Sampled data keys: {list(sample.keys())}") + print(f"Sampled obs shape: {sample['obs'].shape}") + + # Checkpoint the replay buffer + with tempfile.TemporaryDirectory() as tmpdir: + rb_checkpoint_path = Path(tmpdir) / "rb_checkpoint" + print(f"\nCheckpointing ReplayBuffer to {rb_checkpoint_path}...") + rb.dumps(rb_checkpoint_path) + + # Create new replay buffer and load + new_rb = ReplayBuffer( + storage=CompressedListStorage(max_size=100, compression_level=4) + ) + new_rb.loads(rb_checkpoint_path) + + print(f"New ReplayBuffer size: {len(new_rb)}") + + # Verify sampling works + new_sample = new_rb.sample(2) + assert new_sample["obs"].shape == sample["obs"].shape + print("āœ“ ReplayBuffer checkpointing verified!") + + print("\nšŸŽ‰ All demonstrations completed successfully!") + print("\nKey benefits of the new memmap implementation:") + print("1. Efficient disk I/O using memory-mapped storage") + print("2. Reduced memory usage for large datasets") + print("3. Fast loading and saving of compressed data") + print("4. Support for heterogeneous data structures") + print("5. Seamless integration with ReplayBuffer") + + +if __name__ == "__main__": + main() From 783e3ee088abc7e9c8a38650832ef4465f7c8ce9 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 25 Jul 2025 10:59:38 +0100 Subject: [PATCH 7/7] move benchmarks to dedicated workflow --- .github/workflows/benchmarks.yml | 1 + .github/workflows/benchmarks_pr.yml | 1 + benchmarks/requirements.txt | 5 + .../test_compressed_storage_benchmark.py | 145 ++++++++++++++++++ test/test_rb.py | 131 ---------------- 5 files changed, 152 insertions(+), 131 deletions(-) create mode 100644 benchmarks/test_compressed_storage_benchmark.py diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index 489fbc137f2..30659f703f4 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -80,6 +80,7 @@ jobs: python3.10 -m pip install ninja pytest pytest-benchmark mujoco dm_control "gym[accept-rom-license,atari]" python3 -m pip install "pybind11[global]" python3.10 -m pip install git+https://github.com/pytorch/tensordict + python3.10 -m pip install safetensors tqdm pandas numpy matplotlib python3.10 setup.py develop # test import diff --git a/.github/workflows/benchmarks_pr.yml b/.github/workflows/benchmarks_pr.yml index 6e4b3a57073..6e3a401a739 100644 --- a/.github/workflows/benchmarks_pr.yml +++ b/.github/workflows/benchmarks_pr.yml @@ -82,6 +82,7 @@ jobs: python3.10 -m pip install ninja pytest pytest-benchmark mujoco dm_control "gym[accept-rom-license,atari]" python3.10 -m pip install "pybind11[global]" python3.10 -m pip install git+https://github.com/pytorch/tensordict + python3.10 -m pip install safetensors tqdm pandas numpy matplotlib python3.10 setup.py develop # python3.10 -m pip install git+https://github.com/pytorch/rl@$GITHUB_BRANCH diff --git a/benchmarks/requirements.txt b/benchmarks/requirements.txt index 30399e30270..8f9fe1cff3b 100644 --- a/benchmarks/requirements.txt +++ b/benchmarks/requirements.txt @@ -1,2 +1,7 @@ pytest-benchmark tenacity +safetensors +tqdm +pandas +numpy +matplotlib diff --git a/benchmarks/test_compressed_storage_benchmark.py b/benchmarks/test_compressed_storage_benchmark.py new file mode 100644 index 
00000000000..722bac86b2a
--- /dev/null
+++ b/benchmarks/test_compressed_storage_benchmark.py
@@ -0,0 +1,148 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import io
+import pickle
+
+import pytest
+import torch
+try:
+    from safetensors.torch import save
+except ImportError:
+    save = None  # optional dependency; the safetensors case is skipped below
+
+from torchrl.data import CompressedListStorage
+
+
+class TestCompressedStorageBenchmark:
+    """Benchmark tests for CompressedListStorage."""
+
+    @staticmethod
+    def make_compressible_mock_data(num_experiences: int, device=None) -> dict:
+        """Easily compressible data for testing."""
+        if device is None:
+            device = torch.device("cpu")
+
+        return {
+            "observations": torch.zeros(
+                (num_experiences, 4, 84, 84),
+                dtype=torch.uint8,
+                device=device,
+            ),
+            "actions": torch.zeros((num_experiences,), device=device),
+            "rewards": torch.zeros((num_experiences,), device=device),
+            "next_observations": torch.zeros(
+                (num_experiences, 4, 84, 84),
+                dtype=torch.uint8,
+                device=device,
+            ),
+            "terminations": torch.zeros(
+                (num_experiences,), dtype=torch.bool, device=device
+            ),
+            "truncations": torch.zeros(
+                (num_experiences,), dtype=torch.bool, device=device
+            ),
+            "batch_size": [num_experiences],
+        }
+
+    @staticmethod
+    def make_uncompressible_mock_data(num_experiences: int, device=None) -> dict:
+        """Incompressible (random) data for testing."""
+        if device is None:
+            device = torch.device("cpu")
+        return {
+            "observations": torch.randn(
+                (num_experiences, 4, 84, 84),
+                dtype=torch.float32,
+                device=device,
+            ),
+            "actions": torch.randint(0, 10, (num_experiences,), device=device),
+            "rewards": torch.randn(
+                (num_experiences,), dtype=torch.float32, device=device
+            ),
+            "next_observations": torch.randn(
+                (num_experiences, 4, 84, 84),
+                dtype=torch.float32,
+                device=device,
+            ),
+            "terminations": torch.rand((num_experiences,), device=device)
+            < 0.2,  # ~20% True
+            "truncations": torch.rand((num_experiences,), device=device)
+            < 0.1,  # ~10% True
+            "batch_size": [num_experiences],
+        }
+
+    @pytest.mark.benchmark(
+        group="tensor_serialization_speed",
+        min_time=0.1,
+        max_time=0.5,
+        min_rounds=5,
+        disable_gc=True,
+        warmup=False,
+    )
+    @pytest.mark.parametrize(
+        "serialization_method",
+        ["pickle", "torch.save", "untyped_storage", "numpy", "safetensors"],
+    )
+    def test_tensor_to_bytestream_speed(self, benchmark, serialization_method: str):
+        """Benchmark the speed of different tensor serialization methods.
+
+        TODO: we might need to also test which methods work on the GPU.
+        pytest benchmarks/test_compressed_storage_benchmark.py::TestCompressedStorageBenchmark::test_tensor_to_bytestream_speed -v --benchmark-only --benchmark-sort='mean' --benchmark-columns='mean, ops'
+
+        ------------------------ benchmark 'tensor_to_bytestream_speed': 5 tests -------------------------
+        Name (time in us)                                 Mean (smaller is better)  OPS (bigger is better)
+        --------------------------------------------------------------------------------------------------
+        test_tensor_serialization_speed[numpy]                2.3520 (1.0)          425,162.1779 (1.0)
+        test_tensor_serialization_speed[safetensors]         14.7170 (6.26)          67,948.7129 (0.16)
+        test_tensor_serialization_speed[pickle]              19.0711 (8.11)          52,435.3333 (0.12)
+        test_tensor_serialization_speed[torch.save]          32.0648 (13.63)         31,186.8261 (0.07)
+        test_tensor_serialization_speed[untyped_storage]  38,227.0224 (>1000.0)          26.1595 (0.00)
+        --------------------------------------------------------------------------------------------------
+        """
+        if serialization_method == "safetensors" and save is None:
+            pytest.skip("safetensors is not installed")
+
+        def serialize_with_pickle(data: torch.Tensor) -> bytes:
+            """Serialize tensor using pickle."""
+            buffer = io.BytesIO()
+            pickle.dump(data, buffer)
+            return buffer.getvalue()
+
+        def serialize_with_untyped_storage(data: torch.Tensor) -> bytes:
+            """Serialize tensor via its raw untyped storage bytes."""
+            return bytes(data.untyped_storage())
+
+        def serialize_with_numpy(data: torch.Tensor) -> bytes:
+            """Serialize tensor using numpy."""
+            return data.numpy().tobytes()
+
+        def serialize_with_safetensors(data: torch.Tensor) -> bytes:
+            """Serialize tensor using safetensors."""
+            return save({"0": data})
+
+        def serialize_with_torch(data: torch.Tensor) -> bytes:
+            """Serialize tensor using torch's built-in method."""
+            buffer = io.BytesIO()
+            torch.save(data, buffer)
+            return buffer.getvalue()
+
+        # Select the serialization method to benchmark
+        if serialization_method == "pickle":
+            serialize_fn = serialize_with_pickle
+        elif serialization_method == "torch.save":
+            serialize_fn = serialize_with_torch
+        elif serialization_method == "untyped_storage":
+            serialize_fn = serialize_with_untyped_storage
+        elif serialization_method == "numpy":
+            serialize_fn = serialize_with_numpy
+        elif serialization_method == "safetensors":
+            serialize_fn = serialize_with_safetensors
+        else:
+            raise ValueError(f"Unknown serialization method: {serialization_method}")
+
+        data = self.make_compressible_mock_data(1).get("observations")
+
+        # Run the actual benchmark
+        benchmark(serialize_fn, data)
\ No newline at end of file
diff --git a/test/test_rb.py b/test/test_rb.py
index ad6d0411923..07ceb82507e 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -4343,137 +4343,6 @@ def test_compressed_storage_memory_efficiency(self):
             compression_ratio > 1.5
         ), f"Compression ratio {compression_ratio} is too low"
-
-    @staticmethod
-    def make_compressible_mock_data(num_experiences: int, device=None) -> dict:
-        """Easily compressible data for testing."""
-        if device is None:
-            device = torch.device("cpu")
-
-        return {
-            "observations": torch.zeros(
-                (num_experiences, 4, 84, 84),
-                dtype=torch.uint8,
-                device=device,
-            ),
-            "actions": torch.zeros((num_experiences,), device=device),
-            "rewards": torch.zeros((num_experiences,), device=device),
-            "next_observations": torch.zeros(
-                (num_experiences, 4, 84, 84),
-                dtype=torch.uint8,
-                device=device,
-            ),
-            "terminations": torch.zeros(
-                (num_experiences,), dtype=torch.bool, device=device
-            ),
-            "truncations": torch.zeros(
-                (num_experiences,), dtype=torch.bool, device=device
-            ),
-            "batch_size": [num_experiences],
-        }
-
-    @staticmethod
-    def
make_uncompressible_mock_data(num_experiences: int, device=None) -> dict: - """Uncompressible data for testing.""" - if device is None: - device = torch.device("cpu") - return { - "observations": torch.randn( - (num_experiences, 4, 84, 84), - dtype=torch.float32, - device=device, - ), - "actions": torch.randint(0, 10, (num_experiences,), device=device), - "rewards": torch.randn( - (num_experiences,), dtype=torch.float32, device=device - ), - "next_observations": torch.randn( - (num_experiences, 4, 84, 84), - dtype=torch.float32, - device=device, - ), - "terminations": torch.rand((num_experiences,), device=device) - < 0.2, # ~20% True - "truncations": torch.rand((num_experiences,), device=device) - < 0.1, # ~10% True - "batch_size": [num_experiences], - } - - @pytest.mark.benchmark( - group="tensor_serialization_speed", - min_time=0.1, - max_time=0.5, - min_rounds=5, - disable_gc=True, - warmup=False, - ) - @pytest.mark.parametrize( - "serialization_method", - ["pickle", "torch.save", "untyped_storage", "numpy", "safetensors"], - ) - def test_tensor_to_bytestream_speed(self, benchmark, serialization_method: str): - """Benchmark the speed of different tensor serialization methods. - - TODO: we might need to also test which methods work on the gpu. - pytest test/test_rb.py::TestCompressedListStorage::test_tensor_to_bytestream_speed -v --benchmark-only --benchmark-sort='mean' --benchmark-columns='mean, ops' - - ------------------------ benchmark 'tensor_to_bytestream_speed': 5 tests ------------------------- - Name (time in us) Mean (smaller is better) OPS (bigger is better) - -------------------------------------------------------------------------------------------------- - test_tensor_serialization_speed[numpy] 2.3520 (1.0) 425,162.1779 (1.0) - test_tensor_serialization_speed[safetensors] 14.7170 (6.26) 67,948.7129 (0.16) - test_tensor_serialization_speed[pickle] 19.0711 (8.11) 52,435.3333 (0.12) - test_tensor_serialization_speed[torch.save] 32.0648 (13.63) 31,186.8261 (0.07) - test_tensor_serialization_speed[untyped_storage] 38,227.0224 (>1000.0) 26.1595 (0.00) - -------------------------------------------------------------------------------------------------- - """ - import io - import pickle - - import torch - from safetensors.torch import save - - def serialize_with_pickle(data: torch.Tensor) -> bytes: - """Serialize tensor using pickle.""" - buffer = io.BytesIO() - pickle.dump(data, buffer) - return buffer.getvalue() - - def serialize_with_untyped_storage(data: torch.Tensor) -> bytes: - """Serialize tensor using torch's built-in method.""" - return bytes(data.untyped_storage()) - - def serialize_with_numpy(data: torch.Tensor) -> bytes: - """Serialize tensor using numpy.""" - return data.numpy().tobytes() - - def serialize_with_safetensors(data: torch.Tensor) -> bytes: - return save({"0": data}) - - def serialize_with_torch(data: torch.Tensor) -> bytes: - """Serialize tensor using torch's built-in method.""" - buffer = io.BytesIO() - torch.save(data, buffer) - return buffer.getvalue() - - # Benchmark each serialization method - if serialization_method == "pickle": - serialize_fn = serialize_with_pickle - elif serialization_method == "torch.save": - serialize_fn = serialize_with_torch - elif serialization_method == "untyped_storage": - serialize_fn = serialize_with_untyped_storage - elif serialization_method == "numpy": - serialize_fn = serialize_with_numpy - elif serialization_method == "safetensors": - serialize_fn = serialize_with_safetensors - else: - raise ValueError(f"Unknown 
serialization method: {serialization_method}") - - data = self.make_compressible_mock_data(1).get("observations") - - # Run the actual benchmark - benchmark(serialize_fn, data) - if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args()
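
For readers connecting the two patches above: the benchmark singles out `tensor.numpy().tobytes()` as the fastest way to turn a tensor into a bytestream, and that bytestream is what a zstd codec would then shrink. The snippet below is a minimal, self-contained sketch of that round-trip — `compress_tensor` and `decompress_tensor` are illustrative names, not the helpers used inside `CompressedListStorage` — and it assumes the `zstandard` package is installed:

import torch
import zstandard

def compress_tensor(t: torch.Tensor, level: int = 3) -> tuple[bytes, dict]:
    # Flatten to a raw bytestream (fastest route per the benchmark above),
    # keeping dtype/shape so the tensor can be rebuilt after decompression.
    t = t.contiguous().cpu()
    metadata = {"dtype": t.dtype, "shape": tuple(t.shape)}
    blob = zstandard.ZstdCompressor(level=level).compress(t.numpy().tobytes())
    return blob, metadata

def decompress_tensor(blob: bytes, metadata: dict) -> torch.Tensor:
    # bytearray gives torch.frombuffer a writable buffer (avoids a warning).
    raw = bytearray(zstandard.ZstdDecompressor().decompress(blob))
    return torch.frombuffer(raw, dtype=metadata["dtype"]).reshape(metadata["shape"])

frames = torch.zeros(4, 84, 84, dtype=torch.uint8)  # compressible, Atari-like
blob, meta = compress_tensor(frames)
assert torch.equal(decompress_tensor(blob, meta), frames)  # lossless round-trip
ratio = frames.numel() * frames.element_size() / len(blob)
print(f"compression ratio: {ratio:.1f}x")

Any real implementation additionally has to record per-entry metadata (dtype, shape, nested TensorDict keys) next to each compressed blob so that heterogeneous entries, like those in the example script earlier in this series, can be rebuilt on `get`.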