Skip to content

Commit 600077a

Browse files
Adrian Orensteinvmoens
authored andcommitted
Refactor out the storage view. Expose functions in the ListStorage class.
1 parent 6a7b9e0 commit 600077a

File tree

3 files changed

+123
-229
lines changed

3 files changed

+123
-229
lines changed

test/test_rb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4322,7 +4322,7 @@ def test_compressed_storage_memory_efficiency(self):
43224322
storage.set(0, large_tensor)
43234323

43244324
# Estimate compressed size
4325-
compressed_data = storage._compressed_data[0]
4325+
compressed_data = storage._storage[0]
43264326
compressed_size = compressed_data.numel() # uint8 bytes
43274327

43284328
# Verify compression ratio is reasonable (at least 2x for random data)

torchrl/data/replay_buffers/checkpointers.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,19 +79,16 @@ def dumps(self, storage, path):
7979
path = Path(path)
8080
path.mkdir(exist_ok=True)
8181

82-
if (
83-
not hasattr(storage, "_compressed_data")
84-
or len(storage._compressed_data) == 0
85-
):
82+
if not hasattr(storage, "_storage") or len(storage._storage) == 0:
8683
raise RuntimeError(
8784
"Cannot save an empty or non-initialized CompressedListStorage."
8885
)
8986

9087
# Save compressed data and metadata
9188
state_dict = storage.state_dict()
9289

93-
# Save compressed data as separate files for efficiency
94-
compressed_data = state_dict["_compressed_data"]
90+
# Save compressed data and metadata
91+
compressed_data = state_dict["_storage"]
9592
metadata = state_dict["_metadata"]
9693

9794
# Save metadata
@@ -203,7 +200,7 @@ def convert_dtype(item):
203200
compressed_data.append(None)
204201

205202
# Load into storage
206-
storage._compressed_data = compressed_data
203+
storage._storage = compressed_data
207204
storage._metadata = metadata
208205

209206

0 commit comments

Comments
 (0)