Skip to content

Commit 0acd477

Browse files
committed
amend
1 parent db0e30d commit 0acd477

File tree

7 files changed

+768
-6
lines changed

7 files changed

+768
-6
lines changed

.github/unittest/linux/scripts/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,4 @@ dependencies:
3535
- transformers
3636
- ninja
3737
- timm
38+
- zstandard

docs/source/reference/data.rst

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ using the following components:
144144
:template: rl_template.rst
145145

146146

147+
CompressedStorage
148+
CompressedStorageCheckpointer
147149
FlatStorageCheckpointer
148150
H5StorageCheckpointer
149151
ImmutableDatasetWriter
@@ -191,6 +193,66 @@ were found from rough benchmarking in https://github.com/pytorch/rl/tree/main/be
191193
| :class:`LazyMemmapStorage` | 3.44x |
192194
+-------------------------------+-----------+
193195

196+
Compressed Storage for Memory Efficiency
197+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
198+
199+
For applications where memory usage is a primary concern, especially when storing
200+
large sensory observations like images or audio, the :class:`~torchrl.data.replay_buffers.storages.CompressedStorage`
201+
provides significant memory savings through compression.
202+
203+
The `CompressedStorage`` compresses data when storing and decompresses when retrieving,
204+
achieving compression ratios of 2-10x for image data while maintaining full data fidelity.
205+
It uses zstd compression by default but supports custom compression algorithms.
206+
207+
Key features:
208+
- **Memory Efficiency**: Achieves significant memory savings through compression
209+
- **Data Integrity**: Maintains full data fidelity through lossless compression
210+
- **Flexible Compression**: Supports custom compression algorithms or uses zstd by default
211+
- **TensorDict Support**: Seamlessly works with TensorDict structures
212+
- **Checkpointing**: Full support for saving and loading compressed data
213+
214+
Example usage:
215+
216+
>>> import torch
217+
>>> from torchrl.data import ReplayBuffer, CompressedStorage
218+
>>> from tensordict import TensorDict
219+
>>>
220+
>>> # Create a compressed storage for image data
221+
>>> storage = CompressedStorage(max_size=1000, compression_level=3)
222+
>>> rb = ReplayBuffer(storage=storage, batch_size=32)
223+
>>>
224+
>>> # Add image data
225+
>>> images = torch.randn(100, 3, 84, 84) # Atari-like frames
226+
>>> data = TensorDict({"obs": images}, batch_size=[100])
227+
>>> rb.extend(data)
228+
>>>
229+
>>> # Sample data (automatically decompressed)
230+
>>> sample = rb.sample(16)
231+
>>> print(sample["obs"].shape) # torch.Size([16, 3, 84, 84])
232+
233+
The compression level can be adjusted from 1 (fast, less compression) to 22 (slow, more compression),
234+
with level 3 being a good default for most use cases.
235+
236+
For custom compression algorithms:
237+
238+
>>> def my_compress(tensor):
239+
... return tensor.to(torch.uint8) # Simple example
240+
>>>
241+
>>> def my_decompress(compressed_tensor, metadata):
242+
... return compressed_tensor.to(metadata["dtype"])
243+
>>>
244+
>>> storage = CompressedStorage(
245+
... max_size=1000,
246+
... compression_fn=my_compress,
247+
... decompression_fn=my_decompress
248+
... )
249+
250+
.. note:: The CompressedStorage requires the `zstandard` library for default compression.
251+
Install with: ``pip install zstandard``
252+
253+
.. note:: An example of how to use the CompressedStorage is available in the
254+
`examples/replay-buffers/compressed_replay_buffer_example.py <https://github.com/pytorch/rl/blob/main/examples/replay-buffers/compressed_replay_buffer_example.py>`_ file.
255+
194256
Sharing replay buffers across processes
195257
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
196258

test/test_rb.py

Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@
1111
import os
1212
import pickle
1313
import sys
14+
import tempfile
1415
from functools import partial
16+
from pathlib import Path
1517
from unittest import mock
1618

1719
import numpy as np
20+
1821
import pytest
1922
import torch
2023
from packaging import version
@@ -35,6 +38,7 @@
3538
from torchrl.collectors import RandomPolicy, SyncDataCollector
3639
from torchrl.collectors.utils import split_trajectories
3740
from torchrl.data import (
41+
CompressedStorage,
3842
FlatStorageCheckpointer,
3943
MultiStep,
4044
NestedStorageCheckpointer,
@@ -129,6 +133,7 @@
129133
_os_is_windows = sys.platform == "win32"
130134
_has_transformers = importlib.util.find_spec("transformers") is not None
131135
_has_ray = importlib.util.find_spec("ray") is not None
136+
_has_zstandard = importlib.util.find_spec("zstandard") is not None
132137

133138
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
134139

@@ -4027,6 +4032,267 @@ def test_ray_rb_iter(self):
40274032
rb.close()
40284033

40294034

4035+
@pytest.mark.skipif(not _has_zstandard, reason="zstandard required for this test.")
4036+
class TestCompressedStorage:
4037+
"""Test cases for CompressedStorage."""
4038+
4039+
def test_compressed_storage_initialization(self):
4040+
"""Test that CompressedStorage initializes correctly."""
4041+
storage = CompressedStorage(max_size=100, compression_level=3)
4042+
assert storage.max_size == 100
4043+
assert storage.compression_level == 3
4044+
assert len(storage) == 0
4045+
4046+
def test_compressed_storage_tensor(self):
4047+
"""Test compression and decompression of tensor data."""
4048+
storage = CompressedStorage(max_size=10, compression_level=3)
4049+
4050+
# Create test tensor
4051+
test_tensor = torch.randn(3, 84, 84, dtype=torch.float32)
4052+
4053+
# Store tensor
4054+
storage.set(0, test_tensor)
4055+
4056+
# Retrieve tensor
4057+
retrieved_tensor = storage.get(0)
4058+
4059+
# Verify data integrity
4060+
assert torch.allclose(test_tensor, retrieved_tensor, atol=1e-6)
4061+
assert test_tensor.shape == retrieved_tensor.shape
4062+
assert test_tensor.dtype == retrieved_tensor.dtype
4063+
4064+
def test_compressed_storage_tensordict(self):
4065+
"""Test compression and decompression of TensorDict data."""
4066+
storage = CompressedStorage(max_size=10, compression_level=3)
4067+
4068+
# Create test TensorDict
4069+
test_td = TensorDict(
4070+
{
4071+
"obs": torch.randn(3, 84, 84, dtype=torch.float32),
4072+
"action": torch.tensor([1, 2, 3]),
4073+
"reward": torch.randn(3),
4074+
"done": torch.tensor([False, True, False]),
4075+
},
4076+
batch_size=[3],
4077+
)
4078+
4079+
# Store TensorDict
4080+
storage.set(0, test_td)
4081+
4082+
# Retrieve TensorDict
4083+
retrieved_td = storage.get(0)
4084+
4085+
# Verify data integrity
4086+
assert torch.allclose(test_td["obs"], retrieved_td["obs"], atol=1e-6)
4087+
assert torch.allclose(test_td["action"], retrieved_td["action"])
4088+
assert torch.allclose(test_td["reward"], retrieved_td["reward"], atol=1e-6)
4089+
assert torch.allclose(test_td["done"], retrieved_td["done"])
4090+
4091+
def test_compressed_storage_multiple_indices(self):
4092+
"""Test storing and retrieving multiple items."""
4093+
storage = CompressedStorage(max_size=10, compression_level=3)
4094+
4095+
# Store multiple tensors
4096+
tensors = [
4097+
torch.randn(2, 2, dtype=torch.float32),
4098+
torch.randn(3, 3, dtype=torch.float32),
4099+
torch.randn(4, 4, dtype=torch.float32),
4100+
]
4101+
4102+
for i, tensor in enumerate(tensors):
4103+
storage.set(i, tensor)
4104+
4105+
# Retrieve multiple tensors
4106+
retrieved = storage.get([0, 1, 2])
4107+
4108+
# Verify data integrity
4109+
for original, retrieved_tensor in zip(tensors, retrieved):
4110+
assert torch.allclose(original, retrieved_tensor, atol=1e-6)
4111+
4112+
def test_compressed_storage_with_replay_buffer(self):
4113+
"""Test CompressedStorage with ReplayBuffer."""
4114+
storage = CompressedStorage(max_size=100, compression_level=3)
4115+
rb = ReplayBuffer(storage=storage, batch_size=5)
4116+
4117+
# Create test data
4118+
data = TensorDict(
4119+
{
4120+
"obs": torch.randn(10, 3, 84, 84, dtype=torch.float32),
4121+
"action": torch.randint(0, 4, (10,)),
4122+
"reward": torch.randn(10),
4123+
},
4124+
batch_size=[10],
4125+
)
4126+
4127+
# Add data to replay buffer
4128+
print("extending")
4129+
rb.extend(data)
4130+
4131+
# Sample from replay buffer
4132+
sample = rb.sample(5)
4133+
4134+
# Verify sample has correct shape
4135+
assert is_tensor_collection(sample), sample
4136+
assert sample["obs"].shape[0] == 5
4137+
assert sample["obs"].shape[1:] == (3, 84, 84)
4138+
assert sample["action"].shape[0] == 5
4139+
assert sample["reward"].shape[0] == 5
4140+
4141+
def test_compressed_storage_state_dict(self):
4142+
"""Test saving and loading state dict."""
4143+
storage = CompressedStorage(max_size=10, compression_level=3)
4144+
4145+
# Add some data
4146+
test_tensor = torch.randn(3, 3, dtype=torch.float32)
4147+
storage.set(0, test_tensor)
4148+
4149+
# Save state dict
4150+
state_dict = storage.state_dict()
4151+
4152+
# Create new storage and load state dict
4153+
new_storage = CompressedStorage(max_size=10, compression_level=3)
4154+
new_storage.load_state_dict(state_dict)
4155+
4156+
# Verify data integrity
4157+
retrieved_tensor = new_storage.get(0)
4158+
assert torch.allclose(test_tensor, retrieved_tensor, atol=1e-6)
4159+
4160+
def test_compressed_storage_checkpointing(self):
4161+
"""Test checkpointing functionality."""
4162+
storage = CompressedStorage(max_size=10, compression_level=3)
4163+
4164+
# Add some data
4165+
test_td = TensorDict(
4166+
{
4167+
"obs": torch.randn(3, 84, 84, dtype=torch.float32),
4168+
"action": torch.tensor([1, 2, 3]),
4169+
},
4170+
batch_size=[3],
4171+
)
4172+
storage.set(0, test_td)
4173+
4174+
# Create temporary directory for checkpointing
4175+
with tempfile.TemporaryDirectory() as tmpdir:
4176+
checkpoint_path = Path(tmpdir) / "checkpoint"
4177+
4178+
# Save checkpoint
4179+
storage.dumps(checkpoint_path)
4180+
4181+
# Create new storage and load checkpoint
4182+
new_storage = CompressedStorage(max_size=10, compression_level=3)
4183+
new_storage.loads(checkpoint_path)
4184+
4185+
# Verify data integrity
4186+
retrieved_td = new_storage.get(0)
4187+
assert torch.allclose(test_td["obs"], retrieved_td["obs"], atol=1e-6)
4188+
assert torch.allclose(test_td["action"], retrieved_td["action"])
4189+
4190+
def test_compressed_storage_length(self):
4191+
"""Test that length is calculated correctly."""
4192+
storage = CompressedStorage(max_size=10, compression_level=3)
4193+
4194+
# Initially empty
4195+
assert len(storage) == 0
4196+
4197+
# Add some data
4198+
storage.set(0, torch.randn(2, 2))
4199+
assert len(storage) == 1
4200+
4201+
storage.set(2, torch.randn(2, 2))
4202+
assert len(storage) == 2
4203+
4204+
storage.set(1, torch.randn(2, 2))
4205+
assert len(storage) == 3
4206+
4207+
def test_compressed_storage_contains(self):
4208+
"""Test the contains method."""
4209+
storage = CompressedStorage(max_size=10, compression_level=3)
4210+
4211+
# Initially empty
4212+
assert not storage.contains(0)
4213+
4214+
# Add data
4215+
storage.set(0, torch.randn(2, 2))
4216+
assert storage.contains(0)
4217+
assert not storage.contains(1)
4218+
4219+
def test_compressed_storage_empty(self):
4220+
"""Test emptying the storage."""
4221+
storage = CompressedStorage(max_size=10, compression_level=3)
4222+
4223+
# Add some data
4224+
storage.set(0, torch.randn(2, 2))
4225+
storage.set(1, torch.randn(2, 2))
4226+
assert len(storage) == 2
4227+
4228+
# Empty storage
4229+
storage._empty()
4230+
assert len(storage) == 0
4231+
4232+
def test_compressed_storage_custom_compression(self):
4233+
"""Test custom compression functions."""
4234+
4235+
def custom_compress(tensor):
4236+
# Simple compression: just convert to uint8
4237+
return tensor.to(torch.uint8)
4238+
4239+
def custom_decompress(compressed_tensor, metadata):
4240+
# Simple decompression: convert back to original dtype
4241+
return compressed_tensor.to(metadata["dtype"])
4242+
4243+
storage = CompressedStorage(
4244+
max_size=10,
4245+
compression_fn=custom_compress,
4246+
decompression_fn=custom_decompress,
4247+
)
4248+
4249+
# Test with tensor
4250+
test_tensor = torch.randn(2, 2, dtype=torch.float32)
4251+
storage.set(0, test_tensor)
4252+
retrieved_tensor = storage.get(0)
4253+
4254+
# Note: This will lose precision due to uint8 conversion
4255+
# but should still work
4256+
assert retrieved_tensor.shape == test_tensor.shape
4257+
4258+
def test_compressed_storage_error_handling(self):
4259+
"""Test error handling for invalid operations."""
4260+
storage = CompressedStorage(max_size=5, compression_level=3)
4261+
4262+
# Test setting data beyond max_size
4263+
with pytest.raises(RuntimeError):
4264+
storage.set(10, torch.randn(2, 2))
4265+
4266+
# Test getting non-existent data
4267+
with pytest.raises(IndexError):
4268+
storage.get(0)
4269+
4270+
def test_compressed_storage_memory_efficiency(self):
4271+
"""Test that compression actually reduces memory usage."""
4272+
storage = CompressedStorage(max_size=100, compression_level=3)
4273+
4274+
# Create large tensor data
4275+
large_tensor = torch.zeros(100, 3, 84, 84, dtype=torch.int64)
4276+
large_tensor.copy_(
4277+
torch.arange(large_tensor.numel(), dtype=torch.int32).view_as(large_tensor)
4278+
// (3 * 84 * 84)
4279+
)
4280+
original_size = large_tensor.numel() * large_tensor.element_size()
4281+
4282+
# Store in compressed storage
4283+
storage.set(0, large_tensor)
4284+
4285+
# Estimate compressed size
4286+
compressed_data = storage._compressed_data[0]
4287+
compressed_size = compressed_data.numel() # uint8 bytes
4288+
4289+
# Verify compression ratio is reasonable (at least 2x for random data)
4290+
compression_ratio = original_size / compressed_size
4291+
assert (
4292+
compression_ratio > 1.5
4293+
), f"Compression ratio {compression_ratio} is too low"
4294+
4295+
40304296
if __name__ == "__main__":
40314297
args, unknown = argparse.ArgumentParser().parse_known_args()
40324298
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

0 commit comments

Comments
 (0)