Commit 5581cf6

Adrian Orenstein authored and vmoens committed
Using Python's default compressor. Created to_bytestream. Created a to_bytestream speed test.
1 parent 600077a commit 5581cf6

4 files changed: 381 additions, 123 deletions
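
The to_bytestream helper and the compressor change named in the commit message are not part of the two diffs shown below. As a rough, illustrative sketch only, assuming zlib (Python's stdlib compressor) and the numpy tobytes() path that the new benchmark measures as fastest, the names here are hypothetical:

import zlib

import torch


def to_bytestream(data: torch.Tensor) -> bytes:
    # Flatten a CPU tensor into its raw bytes via numpy
    # (the fastest path in the benchmark added below).
    return data.contiguous().cpu().numpy().tobytes()


def compress_tensor(data: torch.Tensor, level: int = zlib.Z_DEFAULT_COMPRESSION) -> bytes:
    # Compress the bytestream with zlib; "Python's default compressor" is an
    # assumption about what the commit message refers to.
    return zlib.compress(to_bytestream(data), level)


if __name__ == "__main__":
    obs = torch.zeros((4, 84, 84), dtype=torch.uint8)  # easily compressible
    print(len(to_bytestream(obs)), "->", len(compress_tensor(obs)), "bytes")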

.github/unittest/linux/scripts/environment.yml

Lines changed: 1 addition & 1 deletion
@@ -35,4 +35,4 @@ dependencies:
 - transformers
 - ninja
 - timm
-- zstandard
+- safetensors
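
The dependency swap above (zstandard out, safetensors in) matches the new benchmark in test_rb.py below, which serializes tensors via safetensors.torch.save. A minimal round-trip sketch using only the public save/load helpers (the tensor name "obs" is arbitrary):

import torch
from safetensors.torch import load, save

tensor = torch.zeros((4, 84, 84), dtype=torch.uint8)
blob = save({"obs": tensor})   # dict of named tensors -> bytes
restored = load(blob)["obs"]   # bytes -> dict of named tensors
assert torch.equal(tensor, restored)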

test/test_rb.py

Lines changed: 143 additions & 0 deletions
@@ -4210,6 +4210,18 @@ def test_compressed_storage_checkpointing(self):
         )
         storage.set(0, test_td)
 
+        # second batch, different shape
+        test_td2 = TensorDict(
+            {
+                "obs": torch.randn(3, 85, 83, dtype=torch.float32),
+                "action": torch.tensor([1, 2, 3]),
+                "meta": torch.randn(3),
+                "astring": "a string!",
+            },
+            batch_size=[3],
+        )
+        storage.set(1, test_td2)
+
         # Create temporary directory for checkpointing
         with tempfile.TemporaryDirectory() as tmpdir:
             checkpoint_path = Path(tmpdir) / "checkpoint"
@@ -4331,6 +4343,137 @@ def test_compressed_storage_memory_efficiency(self):
             compression_ratio > 1.5
         ), f"Compression ratio {compression_ratio} is too low"
 
+    @staticmethod
+    def make_compressible_mock_data(num_experiences: int, device=None) -> dict:
+        """Easily compressible data for testing."""
+        if device is None:
+            device = torch.device("cpu")
+
+        return {
+            "observations": torch.zeros(
+                (num_experiences, 4, 84, 84),
+                dtype=torch.uint8,
+                device=device,
+            ),
+            "actions": torch.zeros((num_experiences,), device=device),
+            "rewards": torch.zeros((num_experiences,), device=device),
+            "next_observations": torch.zeros(
+                (num_experiences, 4, 84, 84),
+                dtype=torch.uint8,
+                device=device,
+            ),
+            "terminations": torch.zeros(
+                (num_experiences,), dtype=torch.bool, device=device
+            ),
+            "truncations": torch.zeros(
+                (num_experiences,), dtype=torch.bool, device=device
+            ),
+            "batch_size": [num_experiences],
+        }
+
+    @staticmethod
+    def make_uncompressible_mock_data(num_experiences: int, device=None) -> dict:
+        """Uncompressible data for testing."""
+        if device is None:
+            device = torch.device("cpu")
+        return {
+            "observations": torch.randn(
+                (num_experiences, 4, 84, 84),
+                dtype=torch.float32,
+                device=device,
+            ),
+            "actions": torch.randint(0, 10, (num_experiences,), device=device),
+            "rewards": torch.randn(
+                (num_experiences,), dtype=torch.float32, device=device
+            ),
+            "next_observations": torch.randn(
+                (num_experiences, 4, 84, 84),
+                dtype=torch.float32,
+                device=device,
+            ),
+            "terminations": torch.rand((num_experiences,), device=device)
+            < 0.2,  # ~20% True
+            "truncations": torch.rand((num_experiences,), device=device)
+            < 0.1,  # ~10% True
+            "batch_size": [num_experiences],
+        }
+
+    @pytest.mark.benchmark(
+        group="tensor_serialization_speed",
+        min_time=0.1,
+        max_time=0.5,
+        min_rounds=5,
+        disable_gc=True,
+        warmup=False,
+    )
+    @pytest.mark.parametrize(
+        "serialization_method",
+        ["pickle", "torch.save", "untyped_storage", "numpy", "safetensors"],
+    )
+    def test_tensor_to_bytestream_speed(self, benchmark, serialization_method: str):
+        """Benchmark the speed of different tensor serialization methods.
+
+        TODO: we might also need to test which methods work on the GPU.
+        pytest test/test_rb.py::TestCompressedListStorage::test_tensor_to_bytestream_speed -v --benchmark-only --benchmark-sort='mean' --benchmark-columns='mean, ops'
+
+        ------------------------ benchmark 'tensor_to_bytestream_speed': 5 tests -------------------------
+        Name (time in us)                                  Mean (smaller is better)   OPS (bigger is better)
+        --------------------------------------------------------------------------------------------------
+        test_tensor_serialization_speed[numpy]                  2.3520 (1.0)          425,162.1779 (1.0)
+        test_tensor_serialization_speed[safetensors]           14.7170 (6.26)          67,948.7129 (0.16)
+        test_tensor_serialization_speed[pickle]                19.0711 (8.11)          52,435.3333 (0.12)
+        test_tensor_serialization_speed[torch.save]            32.0648 (13.63)         31,186.8261 (0.07)
+        test_tensor_serialization_speed[untyped_storage]   38,227.0224 (>1000.0)           26.1595 (0.00)
+        --------------------------------------------------------------------------------------------------
+        """
+        import io
+        import pickle
+
+        import torch
+        from safetensors.torch import save
+
+        def serialize_with_pickle(data: torch.Tensor) -> bytes:
+            """Serialize tensor using pickle."""
+            buffer = io.BytesIO()
+            pickle.dump(data, buffer)
+            return buffer.getvalue()
+
+        def serialize_with_untyped_storage(data: torch.Tensor) -> bytes:
+            """Serialize tensor by copying out its raw untyped storage."""
+            return bytes(data.untyped_storage())
+
+        def serialize_with_numpy(data: torch.Tensor) -> bytes:
+            """Serialize tensor using numpy."""
+            return data.numpy().tobytes()
+
+        def serialize_with_safetensors(data: torch.Tensor) -> bytes:
+            return save({"0": data})
+
+        def serialize_with_torch(data: torch.Tensor) -> bytes:
+            """Serialize tensor using torch.save."""
+            buffer = io.BytesIO()
+            torch.save(data, buffer)
+            return buffer.getvalue()
+
+        # Benchmark each serialization method
+        if serialization_method == "pickle":
+            serialize_fn = serialize_with_pickle
+        elif serialization_method == "torch.save":
+            serialize_fn = serialize_with_torch
+        elif serialization_method == "untyped_storage":
+            serialize_fn = serialize_with_untyped_storage
+        elif serialization_method == "numpy":
+            serialize_fn = serialize_with_numpy
+        elif serialization_method == "safetensors":
+            serialize_fn = serialize_with_safetensors
+        else:
+            raise ValueError(f"Unknown serialization method: {serialization_method}")
+
+        data = self.make_compressible_mock_data(1).get("observations")
+
+        # Run the actual benchmark
+        benchmark(serialize_fn, data)
+
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
