1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import os
1516from contextlib import contextmanager , nullcontext
1617from typing import Dict , List , Optional , Set , Tuple , Union
1718
19+ import safetensors .torch
1820import torch
1921
2022from ..utils import get_logger , is_accelerate_available
@@ -59,6 +61,7 @@ def __init__(
5961 record_stream : Optional [bool ] = False ,
6062 low_cpu_mem_usage : bool = False ,
6163 onload_self : bool = True ,
64+ offload_to_disk_path : Optional [str ] = None ,
6265 ) -> None :
6366 self .modules = modules
6467 self .offload_device = offload_device
@@ -72,7 +75,26 @@ def __init__(
7275 self .record_stream = record_stream
7376 self .onload_self = onload_self
7477 self .low_cpu_mem_usage = low_cpu_mem_usage
75- self .cpu_param_dict = self ._init_cpu_param_dict ()
78+
79+ self .offload_to_disk_path = offload_to_disk_path
80+ self ._is_offloaded_to_disk = False
81+
82+ if self .offload_to_disk_path :
83+ self .safetensors_file_path = os .path .join (self .offload_to_disk_path , f"group_{ id (self )} .safetensors" )
84+
85+ all_tensors = []
86+ for module in self .modules :
87+ all_tensors .extend (list (module .parameters ()))
88+ all_tensors .extend (list (module .buffers ()))
89+ all_tensors .extend (self .parameters )
90+ all_tensors .extend (self .buffers )
91+ all_tensors = list (dict .fromkeys (all_tensors )) # Remove duplicates
92+
93+ self .tensor_to_key = {tensor : f"tensor_{ i } " for i , tensor in enumerate (all_tensors )}
94+ self .key_to_tensor = {v : k for k , v in self .tensor_to_key .items ()}
95+ self .cpu_param_dict = {}
96+ else :
97+ self .cpu_param_dict = self ._init_cpu_param_dict ()
7698
7799 if self .stream is None and self .record_stream :
78100 raise ValueError ("`record_stream` cannot be True when `stream` is None." )
@@ -124,6 +146,30 @@ def onload_(self):
124146 context = nullcontext () if self .stream is None else torch_accelerator_module .stream (self .stream )
125147 current_stream = torch_accelerator_module .current_stream () if self .record_stream else None
126148
149+ if self .offload_to_disk_path :
150+ if self .stream is not None :
151+ # Wait for previous Host->Device transfer to complete
152+ self .stream .synchronize ()
153+
154+ with context :
155+ if self .stream is not None :
156+ # Load to CPU, pin, and async copy to device for overlapping transfer and compute
157+ loaded_cpu_tensors = safetensors .torch .load_file (self .safetensors_file_path , device = "cpu" )
158+ for key , tensor_obj in self .key_to_tensor .items ():
159+ pinned_tensor = loaded_cpu_tensors [key ].pin_memory ()
160+ tensor_obj .data = pinned_tensor .to (self .onload_device , non_blocking = self .non_blocking )
161+ if self .record_stream :
162+ tensor_obj .data .record_stream (current_stream )
163+ else :
164+ # Load directly to the target device (synchronous)
165+ onload_device = (
166+ self .onload_device .type if isinstance (self .onload_device , torch .device ) else self .onload_device
167+ )
168+ loaded_tensors = safetensors .torch .load_file (self .safetensors_file_path , device = onload_device )
169+ for key , tensor_obj in self .key_to_tensor .items ():
170+ tensor_obj .data = loaded_tensors [key ]
171+ return
172+
127173 if self .stream is not None :
128174 # Wait for previous Host->Device transfer to complete
129175 self .stream .synchronize ()
@@ -169,6 +215,26 @@ def onload_(self):
169215 @torch .compiler .disable ()
170216 def offload_ (self ):
171217 r"""Offloads the group of modules to the offload_device."""
218+ if self .offload_to_disk_path :
219+ # TODO: we can potentially optimize this code path by checking if _all_ the desired
220+ # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
221+ # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
222+ # we perform a write.
223+ # Check if the file has been saved in this session or if it already exists on disk.
224+ if not self ._is_offloaded_to_disk and not os .path .exists (self .safetensors_file_path ):
225+ os .makedirs (os .path .dirname (self .safetensors_file_path ), exist_ok = True )
226+ tensors_to_save = {
227+ key : tensor .data .to (self .offload_device ) for tensor , key in self .tensor_to_key .items ()
228+ }
229+ safetensors .torch .save_file (tensors_to_save , self .safetensors_file_path )
230+
231+ # The group is now considered offloaded to disk for the rest of the session.
232+ self ._is_offloaded_to_disk = True
233+
234+ # We do this to free up the RAM that is still holding the tensor data.
235+ for tensor_obj in self .tensor_to_key .keys ():
236+ tensor_obj .data = torch .empty_like (tensor_obj .data , device = self .offload_device )
237+ return
172238
173239 torch_accelerator_module = (
174240 getattr (torch , torch .accelerator .current_accelerator ().type )
@@ -205,11 +271,7 @@ class GroupOffloadingHook(ModelHook):
205271
206272 _is_stateful = False
207273
208- def __init__ (
209- self ,
210- group : ModuleGroup ,
211- next_group : Optional [ModuleGroup ] = None ,
212- ) -> None :
274+ def __init__ (self , group : ModuleGroup , next_group : Optional [ModuleGroup ] = None ) -> None :
213275 self .group = group
214276 self .next_group = next_group
215277
@@ -363,6 +425,7 @@ def apply_group_offloading(
363425 use_stream : bool = False ,
364426 record_stream : bool = False ,
365427 low_cpu_mem_usage : bool = False ,
428+ offload_to_disk_path : Optional [str ] = None ,
366429) -> None :
367430 r"""
368431 Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -401,6 +464,9 @@ def apply_group_offloading(
401464 offload_type (`str`, defaults to "block_level"):
402465 The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
403466 "block_level".
467+ offload_to_disk_path (`str`, *optional*, defaults to `None`):
468+ The path to the directory where parameters will be offloaded. Setting this option can be useful in
469+ RAM-constrained environments where a reasonable speed-memory trade-off is desired.
404470 num_blocks_per_group (`int`, *optional*):
405471 The number of blocks per group when using offload_type="block_level". This is required when using
406472 offload_type="block_level".
@@ -458,6 +524,7 @@ def apply_group_offloading(
458524 num_blocks_per_group = num_blocks_per_group ,
459525 offload_device = offload_device ,
460526 onload_device = onload_device ,
527+ offload_to_disk_path = offload_to_disk_path ,
461528 non_blocking = non_blocking ,
462529 stream = stream ,
463530 record_stream = record_stream ,
@@ -468,6 +535,7 @@ def apply_group_offloading(
468535 module = module ,
469536 offload_device = offload_device ,
470537 onload_device = onload_device ,
538+ offload_to_disk_path = offload_to_disk_path ,
471539 non_blocking = non_blocking ,
472540 stream = stream ,
473541 record_stream = record_stream ,
@@ -486,6 +554,7 @@ def _apply_group_offloading_block_level(
486554 stream : Union [torch .cuda .Stream , torch .Stream , None ] = None ,
487555 record_stream : Optional [bool ] = False ,
488556 low_cpu_mem_usage : bool = False ,
557+ offload_to_disk_path : Optional [str ] = None ,
489558) -> None :
490559 r"""
491560 This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -496,6 +565,9 @@ def _apply_group_offloading_block_level(
496565 The module to which group offloading is applied.
497566 offload_device (`torch.device`):
498567 The device to which the group of modules are offloaded. This should typically be the CPU.
568+ offload_to_disk_path (`str`, *optional*, defaults to `None`):
569+ The path to the directory where parameters will be offloaded. Setting this option can be useful in
570+ RAM-constrained environments where a reasonable speed-memory trade-off is desired.
499571 onload_device (`torch.device`):
500572 The device to which the group of modules are onloaded.
501573 non_blocking (`bool`):
@@ -535,6 +607,7 @@ def _apply_group_offloading_block_level(
535607 modules = current_modules ,
536608 offload_device = offload_device ,
537609 onload_device = onload_device ,
610+ offload_to_disk_path = offload_to_disk_path ,
538611 offload_leader = current_modules [- 1 ],
539612 onload_leader = current_modules [0 ],
540613 non_blocking = non_blocking ,
@@ -567,6 +640,7 @@ def _apply_group_offloading_block_level(
567640 modules = unmatched_modules ,
568641 offload_device = offload_device ,
569642 onload_device = onload_device ,
643+ offload_to_disk_path = offload_to_disk_path ,
570644 offload_leader = module ,
571645 onload_leader = module ,
572646 parameters = parameters ,
@@ -590,6 +664,7 @@ def _apply_group_offloading_leaf_level(
590664 stream : Union [torch .cuda .Stream , torch .Stream , None ] = None ,
591665 record_stream : Optional [bool ] = False ,
592666 low_cpu_mem_usage : bool = False ,
667+ offload_to_disk_path : Optional [str ] = None ,
593668) -> None :
594669 r"""
595670 This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -604,6 +679,9 @@ def _apply_group_offloading_leaf_level(
604679 The device to which the group of modules are offloaded. This should typically be the CPU.
605680 onload_device (`torch.device`):
606681 The device to which the group of modules are onloaded.
682+ offload_to_disk_path (`str`, *optional*, defaults to `None`):
683+ The path to the directory where parameters will be offloaded. Setting this option can be useful in
684+ RAM-constrained environments where a reasonable speed-memory trade-off is desired.
607685 non_blocking (`bool`):
608686 If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
609687 and data transfer.
@@ -629,6 +707,7 @@ def _apply_group_offloading_leaf_level(
629707 modules = [submodule ],
630708 offload_device = offload_device ,
631709 onload_device = onload_device ,
710+ offload_to_disk_path = offload_to_disk_path ,
632711 offload_leader = submodule ,
633712 onload_leader = submodule ,
634713 non_blocking = non_blocking ,
@@ -675,6 +754,7 @@ def _apply_group_offloading_leaf_level(
675754 onload_device = onload_device ,
676755 offload_leader = parent_module ,
677756 onload_leader = parent_module ,
757+ offload_to_disk_path = offload_to_disk_path ,
678758 parameters = parameters ,
679759 buffers = buffers ,
680760 non_blocking = non_blocking ,
@@ -693,6 +773,7 @@ def _apply_group_offloading_leaf_level(
693773 modules = [],
694774 offload_device = offload_device ,
695775 onload_device = onload_device ,
776+ offload_to_disk_path = offload_to_disk_path ,
696777 offload_leader = module ,
697778 onload_leader = module ,
698779 parameters = None ,
0 commit comments