@@ -4210,6 +4210,18 @@ def test_compressed_storage_checkpointing(self):
4210
4210
)
4211
4211
storage .set (0 , test_td )
4212
4212
4213
+ # second batch, different shape
4214
+ test_td2 = TensorDict (
4215
+ {
4216
+ "obs" : torch .randn (3 , 85 , 83 , dtype = torch .float32 ),
4217
+ "action" : torch .tensor ([1 , 2 , 3 ]),
4218
+ "meta" : torch .randn (3 ),
4219
+ "astring" : "a string!" ,
4220
+ },
4221
+ batch_size = [3 ],
4222
+ )
4223
+ storage .set (1 , test_td )
4224
+
4213
4225
# Create temporary directory for checkpointing
4214
4226
with tempfile .TemporaryDirectory () as tmpdir :
4215
4227
checkpoint_path = Path (tmpdir ) / "checkpoint"
@@ -4331,6 +4343,137 @@ def test_compressed_storage_memory_efficiency(self):
4331
4343
compression_ratio > 1.5
4332
4344
), f"Compression ratio { compression_ratio } is too low"
4333
4345
4346
+ @staticmethod
4347
+ def make_compressible_mock_data (num_experiences : int , device = None ) -> dict :
4348
+ """Easily compressible data for testing."""
4349
+ if device is None :
4350
+ device = torch .device ("cpu" )
4351
+
4352
+ return {
4353
+ "observations" : torch .zeros (
4354
+ (num_experiences , 4 , 84 , 84 ),
4355
+ dtype = torch .uint8 ,
4356
+ device = device ,
4357
+ ),
4358
+ "actions" : torch .zeros ((num_experiences ,), device = device ),
4359
+ "rewards" : torch .zeros ((num_experiences ,), device = device ),
4360
+ "next_observations" : torch .zeros (
4361
+ (num_experiences , 4 , 84 , 84 ),
4362
+ dtype = torch .uint8 ,
4363
+ device = device ,
4364
+ ),
4365
+ "terminations" : torch .zeros (
4366
+ (num_experiences ,), dtype = torch .bool , device = device
4367
+ ),
4368
+ "truncations" : torch .zeros (
4369
+ (num_experiences ,), dtype = torch .bool , device = device
4370
+ ),
4371
+ "batch_size" : [num_experiences ],
4372
+ }
4373
+
4374
+ @staticmethod
4375
+ def make_uncompressible_mock_data (num_experiences : int , device = None ) -> dict :
4376
+ """Uncompressible data for testing."""
4377
+ if device is None :
4378
+ device = torch .device ("cpu" )
4379
+ return {
4380
+ "observations" : torch .randn (
4381
+ (num_experiences , 4 , 84 , 84 ),
4382
+ dtype = torch .float32 ,
4383
+ device = device ,
4384
+ ),
4385
+ "actions" : torch .randint (0 , 10 , (num_experiences ,), device = device ),
4386
+ "rewards" : torch .randn (
4387
+ (num_experiences ,), dtype = torch .float32 , device = device
4388
+ ),
4389
+ "next_observations" : torch .randn (
4390
+ (num_experiences , 4 , 84 , 84 ),
4391
+ dtype = torch .float32 ,
4392
+ device = device ,
4393
+ ),
4394
+ "terminations" : torch .rand ((num_experiences ,), device = device )
4395
+ < 0.2 , # ~20% True
4396
+ "truncations" : torch .rand ((num_experiences ,), device = device )
4397
+ < 0.1 , # ~10% True
4398
+ "batch_size" : [num_experiences ],
4399
+ }
4400
+
4401
+ @pytest .mark .benchmark (
4402
+ group = "tensor_serialization_speed" ,
4403
+ min_time = 0.1 ,
4404
+ max_time = 0.5 ,
4405
+ min_rounds = 5 ,
4406
+ disable_gc = True ,
4407
+ warmup = False ,
4408
+ )
4409
+ @pytest .mark .parametrize (
4410
+ "serialization_method" ,
4411
+ ["pickle" , "torch.save" , "untyped_storage" , "numpy" , "safetensors" ],
4412
+ )
4413
+ def test_tensor_to_bytestream_speed (self , benchmark , serialization_method : str ):
4414
+ """Benchmark the speed of different tensor serialization methods.
4415
+
4416
+ TODO: we might need to also test which methods work on the gpu.
4417
+ pytest test/test_rb.py::TestCompressedListStorage::test_tensor_to_bytestream_speed -v --benchmark-only --benchmark-sort='mean' --benchmark-columns='mean, ops'
4418
+
4419
+ ------------------------ benchmark 'tensor_to_bytestream_speed': 5 tests -------------------------
4420
+ Name (time in us) Mean (smaller is better) OPS (bigger is better)
4421
+ --------------------------------------------------------------------------------------------------
4422
+ test_tensor_serialization_speed[numpy] 2.3520 (1.0) 425,162.1779 (1.0)
4423
+ test_tensor_serialization_speed[safetensors] 14.7170 (6.26) 67,948.7129 (0.16)
4424
+ test_tensor_serialization_speed[pickle] 19.0711 (8.11) 52,435.3333 (0.12)
4425
+ test_tensor_serialization_speed[torch.save] 32.0648 (13.63) 31,186.8261 (0.07)
4426
+ test_tensor_serialization_speed[untyped_storage] 38,227.0224 (>1000.0) 26.1595 (0.00)
4427
+ --------------------------------------------------------------------------------------------------
4428
+ """
4429
+ import io
4430
+ import pickle
4431
+
4432
+ import torch
4433
+ from safetensors .torch import save
4434
+
4435
+ def serialize_with_pickle (data : torch .Tensor ) -> bytes :
4436
+ """Serialize tensor using pickle."""
4437
+ buffer = io .BytesIO ()
4438
+ pickle .dump (data , buffer )
4439
+ return buffer .getvalue ()
4440
+
4441
+ def serialize_with_untyped_storage (data : torch .Tensor ) -> bytes :
4442
+ """Serialize tensor using torch's built-in method."""
4443
+ return bytes (data .untyped_storage ())
4444
+
4445
+ def serialize_with_numpy (data : torch .Tensor ) -> bytes :
4446
+ """Serialize tensor using numpy."""
4447
+ return data .numpy ().tobytes ()
4448
+
4449
+ def serialize_with_safetensors (data : torch .Tensor ) -> bytes :
4450
+ return save ({"0" : data })
4451
+
4452
+ def serialize_with_torch (data : torch .Tensor ) -> bytes :
4453
+ """Serialize tensor using torch's built-in method."""
4454
+ buffer = io .BytesIO ()
4455
+ torch .save (data , buffer )
4456
+ return buffer .getvalue ()
4457
+
4458
+ # Benchmark each serialization method
4459
+ if serialization_method == "pickle" :
4460
+ serialize_fn = serialize_with_pickle
4461
+ elif serialization_method == "torch.save" :
4462
+ serialize_fn = serialize_with_torch
4463
+ elif serialization_method == "untyped_storage" :
4464
+ serialize_fn = serialize_with_untyped_storage
4465
+ elif serialization_method == "numpy" :
4466
+ serialize_fn = serialize_with_numpy
4467
+ elif serialization_method == "safetensors" :
4468
+ serialize_fn = serialize_with_safetensors
4469
+ else :
4470
+ raise ValueError (f"Unknown serialization method: { serialization_method } " )
4471
+
4472
+ data = self .make_compressible_mock_data (1 ).get ("observations" )
4473
+
4474
+ # Run the actual benchmark
4475
+ benchmark (serialize_fn , data )
4476
+
4334
4477
4335
4478
if __name__ == "__main__" :
4336
4479
args , unknown = argparse .ArgumentParser ().parse_known_args ()
0 commit comments