File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -155,8 +155,6 @@ def _pinned_memory_tensors(self):
155155
156156 def _transfer_tensor_to_device (self , tensor , source_tensor ):
157157 tensor .data = source_tensor .to (self .onload_device , non_blocking = self .non_blocking )
158- if self .record_stream :
159- tensor .data .record_stream (self ._torch_accelerator_module .current_stream ())
160158
161159 def _process_tensors_from_modules (self , pinned_memory = None ):
162160 for group_module in self .modules :
@@ -240,6 +238,8 @@ def _offload_to_memory(self):
240238
241239 for group_module in self .modules :
242240 for param in group_module .parameters ():
241+ if self .record_stream and param .device .type == 'cuda' :
242+ param .data .record_stream (self ._torch_accelerator_module .current_stream ())
243243 param .data = self .cpu_param_dict [param ]
244244 for param in self .parameters :
245245 param .data = self .cpu_param_dict [param ]
You can’t perform that action at this time.
0 commit comments