From 720de57d2c959a52c5a4976f310cfb0bb33d302b Mon Sep 17 00:00:00 2001
From: Jianbo Liu
Date: Tue, 1 Jul 2025 15:16:52 -0700
Subject: [PATCH] Support get/set the whole row of metaheader+weight+optimizer
 from backend for checkpoint saving/loading (#3148)

Summary:

X-link: https://github.com/pytorch/FBGEMM/pull/4429
X-link: https://github.com/facebookresearch/FBGEMM/pull/1495

Pull Request resolved: https://github.com/pytorch/torchrec/pull/3148

# Context

In the current KVZCH checkpoint loading flow, we hold on to the weight_id, weight, and optimizer tensors throughout the checkpoint loading lifecycle. Only once all of these tensors have been downloaded do we explicitly call `apply_state_dict` to write them chunk by chunk to the backend, which ensures the id->weight and id->opt mappings are correct. The problem is that with a large number of weights we run short of memory, because all three tensors must be held at once (a double-memory issue).

To solve this, we save the whole row (metaheader + weight + opt) as a single "weight" tensor during checkpoint saving. When downloading the checkpoint, we can then extract the id from the metaheader and write the weight+opt portion directly to the backend by id. For loading the optimizer's checkpoint, we added a no-op KVTensor, so optimizer states do not need to be written to the backend a second time.

# This diff

* added a `backend_return_whole_row` flag in KVZCH params, with validation to ensure it is only True when optimizer offloading is used
* added a `read_only_` flag in KVTensorWrapper to be used for checkpoint calls; when read_only=True, all write operations to this KVT are no-ops
* added metadata recalculation for the optimizer state dict: since we now return a read-only KVT for the optimizer state dict, the model store needs to correct the global metadata before creating the save plan for KVZCH optimizer tensors
* updated the DRAM backend and memory pool so they can return metaheader + weight + optimizer_state together, and set them back to the backend (using pointers to skip the metaheader part when writing weight+opt to the backend)
* by default, optimizer offloading and return-whole-row are False on trunk, so existing KVZCH runs should be unaffected

Differential Revision: D77604158
---
 torchrec/distributed/embedding.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py
index 362240d88..553a1c9b4 100644
--- a/torchrec/distributed/embedding.py
+++ b/torchrec/distributed/embedding.py
@@ -954,7 +954,17 @@ def _initialize_torch_state(self) -> None:  # noqa
             (
                 [
                     # assuming virtual table only supports rw sharding for now
-                    0 if dim == 0 else dim_size
+                    # When backend return whole row, need to respect dim(1)
+                    # otherwise will see shard dim exceeded tensor dim error
+                    (
+                        0
+                        if dim == 0
+                        else (
+                            local_shards[0].metadata.shard_sizes[1]
+                            if dim == 1
+                            else dim_size
+                        )
+                    )
                     for dim, dim_size in enumerate(
                         self._name_to_table_size[table_name]
                     )
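
---

The row layout described in the summary can be sketched as follows. This is a minimal illustration only: the function names, the 8-byte header width, and the byte layout are hypothetical and are not the actual FBGEMM metaheader format.

```python
import struct

METAHEADER_BYTES = 8  # assume the metaheader stores the row id as a little-endian int64


def pack_row(row_id: int, weight: bytes, opt_state: bytes) -> bytes:
    """Serialize one backend row as metaheader + weight + optimizer state."""
    return struct.pack("<q", row_id) + weight + opt_state


def unpack_row(row: bytes, weight_bytes: int) -> tuple[int, bytes, bytes]:
    """Recover (id, weight, opt_state) from a whole row.

    On checkpoint load, the caller writes weight + opt_state back to the
    backend keyed by the recovered id, skipping the metaheader bytes.
    """
    (row_id,) = struct.unpack_from("<q", row, 0)
    weight = row[METAHEADER_BYTES : METAHEADER_BYTES + weight_bytes]
    opt_state = row[METAHEADER_BYTES + weight_bytes :]
    return row_id, weight, opt_state
```

Because the checkpointed "weight" tensor now carries the full row, the shard metadata must report the widened row size on dim 1, which is what the `local_shards[0].metadata.shard_sizes[1]` branch in the diff accounts for.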