Skip to content

Commit cef1ed6

Browse files
committed
Fixes #12673. record_stream is not working properly
1 parent c8656ed commit cef1ed6

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,6 @@ def _pinned_memory_tensors(self):
155155

156156
def _transfer_tensor_to_device(self, tensor, source_tensor):
157157
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
158-
if self.record_stream:
159-
tensor.data.record_stream(self._torch_accelerator_module.current_stream())
160158

161159
def _process_tensors_from_modules(self, pinned_memory=None):
162160
for group_module in self.modules:
@@ -238,12 +236,20 @@ def _offload_to_memory(self):
238236
if not self.record_stream:
239237
self._torch_accelerator_module.current_stream().synchronize()
240238

239+
current_stream = self._torch_accelerator_module.current_stream()
240+
241241
for group_module in self.modules:
242242
for param in group_module.parameters():
243+
if self.record_stream and param.device.type == 'cuda':
244+
param.data.record_stream(current_stream)
243245
param.data = self.cpu_param_dict[param]
244246
for param in self.parameters:
247+
if self.record_stream and param.device.type == 'cuda':
248+
param.data.record_stream(current_stream)
245249
param.data = self.cpu_param_dict[param]
246250
for buffer in self.buffers:
251+
if self.record_stream and buffer.device.type == 'cuda':
252+
buffer.data.record_stream(current_stream)
247253
buffer.data = self.cpu_param_dict[buffer]
248254
else:
249255
for group_module in self.modules:

0 commit comments

Comments
 (0)