
[Feature] Compressed storage gpu #3062


Draft · wants to merge 7 commits into base: main
1 change: 1 addition & 0 deletions .github/unittest/linux/scripts/environment.yml
@@ -35,3 +35,4 @@ dependencies:
- transformers
- ninja
- timm
- safetensors
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -20,7 +20,7 @@ repos:
- libcst == 0.4.7

- repo: https://github.com/pycqa/flake8
rev: 4.0.1
rev: 6.0.0
hooks:
- id: flake8
args: [--config=setup.cfg]
62 changes: 62 additions & 0 deletions docs/source/reference/data.rst
@@ -144,6 +144,8 @@ using the following components:
:template: rl_template.rst


CompressedListStorage
CompressedListStorageCheckpointer
FlatStorageCheckpointer
H5StorageCheckpointer
ImmutableDatasetWriter
@@ -191,6 +193,66 @@ were found from rough benchmarking in https://github.com/pytorch/rl/tree/main/be
| :class:`LazyMemmapStorage` | 3.44x |
+-------------------------------+-----------+

Compressed Storage for Memory Efficiency
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

When memory usage or memory bandwidth is a primary concern, especially when storing or transferring
large sensory observations such as images, audio, or text, the :class:`~torchrl.data.replay_buffers.storages.CompressedListStorage`
provides significant memory savings through compression.

The :class:`CompressedListStorage` compresses data when storing and decompresses it when retrieving,
achieving compression ratios of 2-10x for image data while maintaining full data fidelity.
It uses zstd compression by default and also supports custom compression algorithms.
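The store/retrieve cycle amounts to a lossless compress-on-write, decompress-on-read round trip. A minimal sketch of that idea, using the stdlib ``zlib`` as a stand-in for zstd (the helper names are hypothetical, not TorchRL internals):

```python
import zlib

def compress_blob(raw: bytes, level: int = 3) -> bytes:
    # On write: compress the tensor's raw bytes before they enter storage.
    return zlib.compress(raw, level)

def decompress_blob(blob: bytes) -> bytes:
    # On read: recover the exact original bytes (lossless).
    return zlib.decompress(blob)

# Image-like data is often highly redundant, so it compresses well.
frame = bytes(i % 16 for i in range(3 * 84 * 84))
blob = compress_blob(frame)
assert decompress_blob(blob) == frame  # full data fidelity
print(f"{len(frame)} -> {len(blob)} bytes ({len(frame) / len(blob):.1f}x)")
```

The actual storage applies the same pattern per stored item, keeping shape/dtype metadata alongside the compressed bytes so tensors can be reconstructed exactly.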

Key features:

- **Memory Efficiency**: Achieves significant memory savings through compression
- **Data Integrity**: Maintains full data fidelity through lossless compression
- **Flexible Compression**: Supports custom compression algorithms or uses zstd by default
- **TensorDict Support**: Seamlessly works with TensorDict structures
- **Checkpointing**: Full support for saving and loading compressed data

Example usage:

>>> import torch
>>> from torchrl.data import ReplayBuffer, CompressedListStorage
>>> from tensordict import TensorDict
>>>
>>> # Create a compressed storage for image data
>>> storage = CompressedListStorage(max_size=1000, compression_level=3)
>>> rb = ReplayBuffer(storage=storage, batch_size=32)
>>>
>>> # Add image data
>>> images = torch.randn(100, 3, 84, 84) # Atari-like frames
>>> data = TensorDict({"obs": images}, batch_size=[100])
>>> rb.extend(data)
>>>
>>> # Sample data (automatically decompressed)
>>> sample = rb.sample(16)
>>> print(sample["obs"].shape) # torch.Size([16, 3, 84, 84])

The compression level can be adjusted from 1 (fast, less compression) to 22 (slow, more compression),
with level 3 being a good default for most use cases.
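The speed/size trade-off behind that recommendation can be observed with a quick level sweep. A hedged sketch using the stdlib ``zlib`` (levels 1-9) as a stand-in, since zstd's range extends to 22:

```python
import time
import zlib

data = bytes(range(256)) * 4096  # ~1 MiB of patterned sample data

for level in (1, 6, 9):
    t0 = time.perf_counter()
    blob = zlib.compress(data, level)
    elapsed_ms = (time.perf_counter() - t0) * 1e3
    # Higher levels spend more CPU time searching for matches.
    print(f"level {level}: {len(blob):>7} bytes, {elapsed_ms:.2f} ms")
```

The general pattern (true for zstd as well) is that higher levels produce smaller output at the cost of slower writes, while decompression speed stays roughly constant.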

For custom compression algorithms:

>>> def my_compress(tensor):
...     return tensor.to(torch.uint8)  # Simple example
>>>
>>> def my_decompress(compressed_tensor, metadata):
...     return compressed_tensor.to(metadata["dtype"])
>>>
>>> storage = CompressedListStorage(
...     max_size=1000,
...     compression_fn=my_compress,
...     decompression_fn=my_decompress,
... )

.. note:: The :class:`CompressedListStorage` requires the ``zstandard`` library for default
   compression. Install it with ``pip install zstandard``.

.. note:: An example of how to use the :class:`CompressedListStorage` is available in the
   `examples/replay-buffers/compressed_replay_buffer_example.py <https://github.com/pytorch/rl/blob/main/examples/replay-buffers/compressed_replay_buffer_example.py>`_ file.

Sharing replay buffers across processes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
143 changes: 143 additions & 0 deletions examples/replay-buffers/compressed_replay_buffer.py
@@ -0,0 +1,143 @@
#!/usr/bin/env python3
"""
Example demonstrating the use of CompressedStorage for memory-efficient replay buffers.

This example shows how to use the new CompressedStorage to store image observations
with significant memory savings through compression.
"""

import time

import torch
from tensordict import TensorDict
from torchrl.data import CompressedStorage, ReplayBuffer


def main():
    print("=== Compressed Replay Buffer Example ===\n")

    # Create a compressed storage with zstd compression
    print("Creating compressed storage...")
    storage = CompressedStorage(
        max_size=1000,
        compression_level=3,  # zstd compression level (1-22)
        device="cpu",
    )

    # Create replay buffer with compressed storage
    rb = ReplayBuffer(storage=storage, batch_size=32)

    # Simulate Atari-like image data (84x84 RGB frames)
    print("Generating sample image data...")
    num_frames = 100
    image_data = torch.zeros(num_frames, 3, 84, 84, dtype=torch.float32)
    image_data.copy_(
        torch.arange(num_frames * 3 * 84 * 84).reshape(num_frames, 3, 84, 84)
        // (3 * 84 * 84)
    )

    # Create TensorDict with image observations
    data = TensorDict(
        {
            "obs": image_data,
            "action": torch.randint(0, 4, (num_frames,)),  # 4 possible actions
            "reward": torch.randn(num_frames),
            "done": torch.randint(0, 2, (num_frames,), dtype=torch.bool),
        },
        batch_size=[num_frames],
    )

    # Measure memory usage before adding data
    print(f"Original data size: {data.bytes() / 1024 / 1024:.2f} MB")

    # Add data to replay buffer
    print("Adding data to replay buffer...")
    start_time = time.time()
    rb.extend(data)
    add_time = time.time() - start_time
    print(f"Time to add data: {add_time:.3f} seconds")

    # Sample from replay buffer
    print("Sampling from replay buffer...")
    start_time = time.time()
    sample = rb.sample(32)
    sample_time = time.time() - start_time
    print(f"Time to sample: {sample_time:.3f} seconds")

    # Verify data integrity
    print("\nVerifying data integrity...")
    original_shape = image_data.shape
    sampled_shape = sample["obs"].shape
    print(f"Original shape: {original_shape}")
    print(f"Sampled shape: {sampled_shape}")

    # Check that shapes match (accounting for batch size)
    assert sampled_shape[1:] == original_shape[1:], "Shape mismatch!"
    print("✓ Data integrity verified!")

    # Demonstrate compression ratio
    print("\n=== Compression Analysis ===")

    # Estimate compressed size (this is approximate)
    compressed_size_estimate = storage.bytes()

    original_size = data.bytes()
    compression_ratio = (
        original_size / compressed_size_estimate if compressed_size_estimate > 0 else 1
    )

    print(f"Original size: {original_size / 1024 / 1024:.2f} MB")
    print(
        f"Compressed size (estimate): {compressed_size_estimate / 1024 / 1024:.2f} MB"
    )
    print(f"Compression ratio: {compression_ratio:.1f}x")

    # Test with different compression levels
    print("\n=== Testing Different Compression Levels ===")

    for level in [1, 3, 6, 9]:
        print(f"\nTesting compression level {level}...")

        # Create new storage with different compression level
        test_storage = CompressedStorage(
            max_size=100, compression_level=level, device="cpu"
        )

        # Test with a smaller dataset
        N = 100
        obs = torch.zeros(N, 3, 84, 84, dtype=torch.float32)
        obs.copy_(torch.arange(N * 3 * 84 * 84).reshape(N, 3, 84, 84) // (3 * 84 * 84))
        test_data = TensorDict(
            {
                "obs": obs,
            },
            batch_size=[N],
        )

        test_rb = ReplayBuffer(storage=test_storage, batch_size=5)

        # Measure compression time
        start_time = time.time()
        test_rb.extend(test_data)
        compress_time = time.time() - start_time

        # Measure decompression time
        start_time = time.time()
        test_rb.sample(5)
        decompress_time = time.time() - start_time

        print(f"  Compression time: {compress_time:.3f}s")
        print(f"  Decompression time: {decompress_time:.3f}s")

        # Estimate compression ratio
        test_ratio = test_data.bytes() / test_storage.bytes()
        print(f"  Compression ratio: {test_ratio:.1f}x")

    print("\n=== Example Complete ===")
    print(
        "The CompressedStorage successfully reduces memory usage while maintaining data integrity!"
    )


if __name__ == "__main__":
    main()
156 changes: 156 additions & 0 deletions examples/replay-buffers/compressed_replay_buffer_checkpoint.py
@@ -0,0 +1,156 @@
#!/usr/bin/env python3
"""
Example demonstrating the improved CompressedListStorage with memmap functionality.

This example shows how to use the new checkpointing capabilities that leverage
memory-mapped storage for efficient disk I/O and memory usage.
"""

import tempfile
from pathlib import Path

import torch
from tensordict import TensorDict

from torchrl.data import CompressedListStorage, ReplayBuffer


def main():
    """Demonstrate compressed storage with memmap checkpointing."""

    # Create a compressed storage with high compression level
    storage = CompressedListStorage(max_size=1000, compression_level=6)

    # Create some sample data with different shapes and types
    print("Creating sample data...")

    # Add tensor data
    tensor_data = torch.randn(10, 3, 84, 84, dtype=torch.float32)  # Image-like data
    storage.set(0, tensor_data)

    # Add TensorDict data with mixed content
    td_data = TensorDict(
        {
            "obs": torch.randn(5, 4, 84, 84, dtype=torch.float32),
            "action": torch.randint(0, 18, (5,), dtype=torch.long),
            "reward": torch.randn(5, dtype=torch.float32),
            "done": torch.zeros(5, dtype=torch.bool),
            "meta": "some metadata string",
        },
        batch_size=[5],
    )
    storage.set(1, td_data)

    # Add another tensor with different shape
    tensor_data2 = torch.randn(8, 64, dtype=torch.float32)
    storage.set(2, tensor_data2)

    print(f"Storage length: {len(storage)}")
    print(f"Storage contains index 0: {storage.contains(0)}")
    print(f"Storage contains index 3: {storage.contains(3)}")

    # Demonstrate data retrieval
    print("\nRetrieving data...")
    retrieved_tensor = storage.get(0)
    retrieved_td = storage.get(1)
    retrieved_tensor2 = storage.get(2)

    print(f"Retrieved tensor shape: {retrieved_tensor.shape}")
    print(f"Retrieved TensorDict keys: {list(retrieved_td.keys())}")
    print(f"Retrieved tensor2 shape: {retrieved_tensor2.shape}")

    # Verify data integrity
    assert torch.allclose(tensor_data, retrieved_tensor, atol=1e-6)
    assert torch.allclose(td_data["obs"], retrieved_td["obs"], atol=1e-6)
    assert torch.allclose(tensor_data2, retrieved_tensor2, atol=1e-6)
    print("✓ Data integrity verified!")

    # Demonstrate memmap checkpointing
    print("\nDemonstrating memmap checkpointing...")

    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_path = Path(tmpdir) / "compressed_storage_checkpoint"

        # Save to disk using memmap
        print(f"Saving to {checkpoint_path}...")
        storage.dumps(checkpoint_path)

        # Check what files were created
        print("Files created:")
        for file_path in checkpoint_path.rglob("*"):
            if file_path.is_file():
                size_mb = file_path.stat().st_size / (1024 * 1024)
                print(f"  {file_path.relative_to(checkpoint_path)}: {size_mb:.2f} MB")

        # Create new storage and load from checkpoint
        print("\nLoading from checkpoint...")
        new_storage = CompressedListStorage(max_size=1000, compression_level=6)
        new_storage.loads(checkpoint_path)

        # Verify data integrity after checkpointing
        print("Verifying data integrity after checkpointing...")
        new_retrieved_tensor = new_storage.get(0)
        new_retrieved_td = new_storage.get(1)
        new_retrieved_tensor2 = new_storage.get(2)

        assert torch.allclose(tensor_data, new_retrieved_tensor, atol=1e-6)
        assert torch.allclose(td_data["obs"], new_retrieved_td["obs"], atol=1e-6)
        assert torch.allclose(tensor_data2, new_retrieved_tensor2, atol=1e-6)
        print("✓ Data integrity after checkpointing verified!")

        print(f"New storage length: {len(new_storage)}")

    # Demonstrate with ReplayBuffer
    print("\nDemonstrating with ReplayBuffer...")

    rb = ReplayBuffer(storage=CompressedListStorage(max_size=100, compression_level=4))

    # Add some data to the replay buffer
    for _ in range(5):
        data = TensorDict(
            {
                "obs": torch.randn(3, 84, 84, dtype=torch.float32),
                "action": torch.randint(0, 18, (3,), dtype=torch.long),
                "reward": torch.randn(3, dtype=torch.float32),
            },
            batch_size=[3],
        )
        rb.extend(data)

    print(f"ReplayBuffer size: {len(rb)}")

    # Sample from the buffer
    sample = rb.sample(2)
    print(f"Sampled data keys: {list(sample.keys())}")
    print(f"Sampled obs shape: {sample['obs'].shape}")

    # Checkpoint the replay buffer
    with tempfile.TemporaryDirectory() as tmpdir:
        rb_checkpoint_path = Path(tmpdir) / "rb_checkpoint"
        print(f"\nCheckpointing ReplayBuffer to {rb_checkpoint_path}...")
        rb.dumps(rb_checkpoint_path)

        # Create new replay buffer and load
        new_rb = ReplayBuffer(
            storage=CompressedListStorage(max_size=100, compression_level=4)
        )
        new_rb.loads(rb_checkpoint_path)

        print(f"New ReplayBuffer size: {len(new_rb)}")

        # Verify sampling works
        new_sample = new_rb.sample(2)
        assert new_sample["obs"].shape == sample["obs"].shape
        print("✓ ReplayBuffer checkpointing verified!")

    print("\n🎉 All demonstrations completed successfully!")
    print("\nKey benefits of the new memmap implementation:")
    print("1. Efficient disk I/O using memory-mapped storage")
    print("2. Reduced memory usage for large datasets")
    print("3. Fast loading and saving of compressed data")
    print("4. Support for heterogeneous data structures")
    print("5. Seamless integration with ReplayBuffer")


if __name__ == "__main__":
    main()