|
11 | 11 | import os
|
12 | 12 | import pickle
|
13 | 13 | import sys
|
| 14 | +import tempfile |
14 | 15 | from functools import partial
|
| 16 | +from pathlib import Path |
15 | 17 | from unittest import mock
|
16 | 18 |
|
17 | 19 | import numpy as np
|
| 20 | + |
18 | 21 | import pytest
|
19 | 22 | import torch
|
20 | 23 | from packaging import version
|
|
35 | 38 | from torchrl.collectors import RandomPolicy, SyncDataCollector
|
36 | 39 | from torchrl.collectors.utils import split_trajectories
|
37 | 40 | from torchrl.data import (
|
| 41 | + CompressedStorage, |
38 | 42 | FlatStorageCheckpointer,
|
39 | 43 | MultiStep,
|
40 | 44 | NestedStorageCheckpointer,
|
|
# True when the interpreter reports the Windows platform string.
_os_is_windows = sys.platform == "win32"
# Optional-dependency flags: `find_spec` returns None when the named package
# cannot be located, so each flag is True only if the package is installed.
_has_transformers = importlib.util.find_spec("transformers") is not None
_has_ray = importlib.util.find_spec("ray") is not None
# Gates the TestCompressedStorage suite via its `skipif` marker.
_has_zstandard = importlib.util.find_spec("zstandard") is not None

# Installed torch version re-parsed from its `base_version`, i.e. with
# pre-release / dev / local suffixes dropped before any version comparison.
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
|
134 | 139 |
|
@@ -4027,6 +4032,267 @@ def test_ray_rb_iter(self):
|
4027 | 4032 | rb.close()
|
4028 | 4033 |
|
4029 | 4034 |
|
@pytest.mark.skipif(not _has_zstandard, reason="zstandard required for this test.")
class TestCompressedStorage:
    """Test cases for CompressedStorage.

    The suite is skipped when the optional ``zstandard`` package is not
    installed. Each test builds its own storage, so tests are independent of
    one another.
    """

    def test_compressed_storage_initialization(self):
        """Test that CompressedStorage initializes correctly.

        Constructor arguments must be exposed as attributes and a fresh
        storage must report a length of zero.
        """
        storage = CompressedStorage(max_size=100, compression_level=3)
        assert storage.max_size == 100
        assert storage.compression_level == 3
        assert len(storage) == 0

    def test_compressed_storage_tensor(self):
        """Test compression and decompression of tensor data.

        A float32 tensor written with ``set`` must come back from ``get``
        with the same values, shape and dtype.
        """
        storage = CompressedStorage(max_size=10, compression_level=3)

        # Create test tensor
        test_tensor = torch.randn(3, 84, 84, dtype=torch.float32)

        # Store tensor
        storage.set(0, test_tensor)

        # Retrieve tensor
        retrieved_tensor = storage.get(0)

        # Verify data integrity
        assert torch.allclose(test_tensor, retrieved_tensor, atol=1e-6)
        assert test_tensor.shape == retrieved_tensor.shape
        assert test_tensor.dtype == retrieved_tensor.dtype

    def test_compressed_storage_tensordict(self):
        """Test compression and decompression of TensorDict data.

        Every entry (float, integer and boolean) of a TensorDict must
        survive a ``set`` / ``get`` round trip.
        """
        storage = CompressedStorage(max_size=10, compression_level=3)

        # Create test TensorDict
        test_td = TensorDict(
            {
                "obs": torch.randn(3, 84, 84, dtype=torch.float32),
                "action": torch.tensor([1, 2, 3]),
                "reward": torch.randn(3),
                "done": torch.tensor([False, True, False]),
            },
            batch_size=[3],
        )

        # Store TensorDict
        storage.set(0, test_td)

        # Retrieve TensorDict
        retrieved_td = storage.get(0)

        # Verify data integrity
        assert torch.allclose(test_td["obs"], retrieved_td["obs"], atol=1e-6)
        assert torch.allclose(test_td["action"], retrieved_td["action"])
        assert torch.allclose(test_td["reward"], retrieved_td["reward"], atol=1e-6)
        assert torch.allclose(test_td["done"], retrieved_td["done"])

    def test_compressed_storage_multiple_indices(self):
        """Test storing and retrieving multiple items.

        Tensors of different shapes are written one index at a time and then
        read back together with a list of indices.
        """
        storage = CompressedStorage(max_size=10, compression_level=3)

        # Store multiple tensors
        tensors = [
            torch.randn(2, 2, dtype=torch.float32),
            torch.randn(3, 3, dtype=torch.float32),
            torch.randn(4, 4, dtype=torch.float32),
        ]

        for i, tensor in enumerate(tensors):
            storage.set(i, tensor)

        # Retrieve multiple tensors
        retrieved = storage.get([0, 1, 2])

        # Verify data integrity
        for original, retrieved_tensor in zip(tensors, retrieved):
            assert torch.allclose(original, retrieved_tensor, atol=1e-6)

    def test_compressed_storage_with_replay_buffer(self):
        """Test CompressedStorage with ReplayBuffer.

        A batch of 10 transitions is added with ``extend`` and a sample of 5
        must be a tensor collection with the expected per-key shapes.
        """
        storage = CompressedStorage(max_size=100, compression_level=3)
        rb = ReplayBuffer(storage=storage, batch_size=5)

        # Create test data
        data = TensorDict(
            {
                "obs": torch.randn(10, 3, 84, 84, dtype=torch.float32),
                "action": torch.randint(0, 4, (10,)),
                "reward": torch.randn(10),
            },
            batch_size=[10],
        )

        # Add data to replay buffer
        rb.extend(data)

        # Sample from replay buffer
        sample = rb.sample(5)

        # Verify sample has correct shape
        assert is_tensor_collection(sample), sample
        assert sample["obs"].shape[0] == 5
        assert sample["obs"].shape[1:] == (3, 84, 84)
        assert sample["action"].shape[0] == 5
        assert sample["reward"].shape[0] == 5

    def test_compressed_storage_state_dict(self):
        """Test saving and loading state dict.

        The state dict of a populated storage is loaded into a second,
        identically configured storage, which must return the same tensor.
        """
        storage = CompressedStorage(max_size=10, compression_level=3)

        # Add some data
        test_tensor = torch.randn(3, 3, dtype=torch.float32)
        storage.set(0, test_tensor)

        # Save state dict
        state_dict = storage.state_dict()

        # Create new storage and load state dict
        new_storage = CompressedStorage(max_size=10, compression_level=3)
        new_storage.load_state_dict(state_dict)

        # Verify data integrity
        retrieved_tensor = new_storage.get(0)
        assert torch.allclose(test_tensor, retrieved_tensor, atol=1e-6)

    def test_compressed_storage_checkpointing(self):
        """Test checkpointing functionality.

        The storage is written with ``dumps`` into a temporary directory and
        read back with ``loads`` into a fresh storage; the directory is
        removed automatically when the ``with`` block exits.
        """
        storage = CompressedStorage(max_size=10, compression_level=3)

        # Add some data
        test_td = TensorDict(
            {
                "obs": torch.randn(3, 84, 84, dtype=torch.float32),
                "action": torch.tensor([1, 2, 3]),
            },
            batch_size=[3],
        )
        storage.set(0, test_td)

        # Create temporary directory for checkpointing
        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint_path = Path(tmpdir) / "checkpoint"

            # Save checkpoint
            storage.dumps(checkpoint_path)

            # Create new storage and load checkpoint
            new_storage = CompressedStorage(max_size=10, compression_level=3)
            new_storage.loads(checkpoint_path)

            # Verify data integrity
            retrieved_td = new_storage.get(0)
            assert torch.allclose(test_td["obs"], retrieved_td["obs"], atol=1e-6)
            assert torch.allclose(test_td["action"], retrieved_td["action"])

    def test_compressed_storage_length(self):
        """Test that length is calculated correctly.

        Indices are deliberately written out of order (0, 2, 1) and the
        length is expected to grow by one after each write.
        """
        storage = CompressedStorage(max_size=10, compression_level=3)

        # Initially empty
        assert len(storage) == 0

        # Add some data
        storage.set(0, torch.randn(2, 2))
        assert len(storage) == 1

        storage.set(2, torch.randn(2, 2))
        assert len(storage) == 2

        storage.set(1, torch.randn(2, 2))
        assert len(storage) == 3

    def test_compressed_storage_contains(self):
        """Test the contains method.

        ``contains`` must be False for unwritten indices and True only for
        the index that has been set.
        """
        storage = CompressedStorage(max_size=10, compression_level=3)

        # Initially empty
        assert not storage.contains(0)

        # Add data
        storage.set(0, torch.randn(2, 2))
        assert storage.contains(0)
        assert not storage.contains(1)

    def test_compressed_storage_empty(self):
        """Test emptying the storage.

        After ``_empty`` a storage that held two items must report a length
        of zero.
        """
        storage = CompressedStorage(max_size=10, compression_level=3)

        # Add some data
        storage.set(0, torch.randn(2, 2))
        storage.set(1, torch.randn(2, 2))
        assert len(storage) == 2

        # Empty storage
        storage._empty()
        assert len(storage) == 0

    def test_compressed_storage_custom_compression(self):
        """Test custom compression functions.

        User-supplied ``compression_fn`` / ``decompression_fn`` callables are
        passed to the constructor. The round trip is lossy, so only the shape
        of the retrieved tensor is checked.
        """

        def custom_compress(tensor):
            # Simple compression: just convert to uint8
            return tensor.to(torch.uint8)

        def custom_decompress(compressed_tensor, metadata):
            # Simple decompression: convert back to original dtype
            return compressed_tensor.to(metadata["dtype"])

        storage = CompressedStorage(
            max_size=10,
            compression_fn=custom_compress,
            decompression_fn=custom_decompress,
        )

        # Test with tensor
        test_tensor = torch.randn(2, 2, dtype=torch.float32)
        storage.set(0, test_tensor)
        retrieved_tensor = storage.get(0)

        # Note: This will lose precision due to uint8 conversion
        # but should still work
        assert retrieved_tensor.shape == test_tensor.shape

    def test_compressed_storage_error_handling(self):
        """Test error handling for invalid operations.

        Writing past ``max_size`` must raise ``RuntimeError`` and reading an
        index that was never written must raise ``IndexError``.
        """
        storage = CompressedStorage(max_size=5, compression_level=3)

        # Test setting data beyond max_size
        with pytest.raises(RuntimeError):
            storage.set(10, torch.randn(2, 2))

        # Test getting non-existent data
        with pytest.raises(IndexError):
            storage.get(0)

    def test_compressed_storage_memory_efficiency(self):
        """Test that compression actually reduces memory usage.

        The payload is intentionally repetitive so that it compresses well;
        the stored representation must be at least 1.5x smaller than the raw
        int64 buffer.
        """
        storage = CompressedStorage(max_size=100, compression_level=3)

        # Create large, highly repetitive tensor data: integer division by the
        # per-frame element count fills each of the 100 frames with its index.
        large_tensor = torch.zeros(100, 3, 84, 84, dtype=torch.int64)
        large_tensor.copy_(
            torch.arange(large_tensor.numel(), dtype=torch.int32).view_as(large_tensor)
            // (3 * 84 * 84)
        )
        original_size = large_tensor.numel() * large_tensor.element_size()

        # Store in compressed storage
        storage.set(0, large_tensor)

        # Estimate compressed size
        # NOTE(review): assumes the private payload is a uint8 tensor with one
        # element per compressed byte -- confirm against CompressedStorage.
        compressed_data = storage._compressed_data[0]
        compressed_size = compressed_data.numel()  # uint8 bytes

        # Verify the compression ratio is reasonable (more than 1.5x on this
        # repetitive payload)
        compression_ratio = original_size / compressed_size
        assert (
            compression_ratio > 1.5
        ), f"Compression ratio {compression_ratio} is too low"
| 4295 | + |
if __name__ == "__main__":
    # Anything argparse does not recognise is handed to pytest unchanged, so
    # the file can be run directly with extra pytest options.
    cli_parser = argparse.ArgumentParser()
    _known, passthrough = cli_parser.parse_known_args()
    pytest_argv = [__file__, "--capture", "no", "--exitfirst", *passthrough]
    pytest.main(pytest_argv)
|
0 commit comments