Skip to content
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions src/diffusers/hooks/group_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,27 +153,27 @@ def _pinned_memory_tensors(self):
finally:
pinned_dict = None

# NOTE(review): this span is a rendered diff with indentation and +/- markers stripped.
# Each removed line is immediately followed by the line that replaces it, so the text
# is not runnable as shown.
#
# Purpose: copy `source_tensor` to `self.onload_device` and rebind `tensor.data` to the
# copy; when `self.record_stream` is truthy, also call `record_stream` on the new data.
#
# Old signature (removed by this change):
def _transfer_tensor_to_device(self, tensor, source_tensor):
# New signature (added): the caller supplies the stream to record against.
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
# `non_blocking` is forwarded from `self.non_blocking` to the `.to(...)` call.
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
# Old (removed): queried the accelerator module's current stream at call time.
tensor.data.record_stream(self._torch_accelerator_module.current_stream())
# New (added): records against the stream object passed in as `default_stream`.
# NOTE(review): presumably this is the stream that was current before the transfer-stream
# context was entered; confirm against the caller that builds `default_stream`.
tensor.data.record_stream(default_stream)

# NOTE(review): this span is a rendered diff with indentation and +/- markers stripped.
# Each removed line is immediately followed by the line that replaces it, so the text
# is not runnable as shown.
#
# Purpose: call `self._transfer_tensor_to_device` for every parameter and buffer of each
# module in `self.modules`, then for every entry of `self.parameters` and `self.buffers`.
#
# pinned_memory: optional mapping indexed by the parameter/buffer object itself. When it
#   is truthy the transfer source is `pinned_memory[tensor]`; otherwise (None or empty)
#   the tensor's own `.data` is used.
# default_stream: (added by this change) forwarded unchanged to every transfer call.
#   Defaults to None, so the call that passes only `None` for `pinned_memory` still works.
#
# Old signature (removed by this change):
def _process_tensors_from_modules(self, pinned_memory=None):
# New signature (added):
def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
# Pass 1: tensors owned by the modules in this group.
for group_module in self.modules:
for param in group_module.parameters():
source = pinned_memory[param] if pinned_memory else param.data
# Old call (removed), then new call (added) that also forwards `default_stream`.
self._transfer_tensor_to_device(param, source)
self._transfer_tensor_to_device(param, source, default_stream)
for buffer in group_module.buffers():
source = pinned_memory[buffer] if pinned_memory else buffer.data
# Old call (removed), then new call (added).
self._transfer_tensor_to_device(buffer, source)
self._transfer_tensor_to_device(buffer, source, default_stream)

# Pass 2: parameters held directly on `self.parameters`.
for param in self.parameters:
source = pinned_memory[param] if pinned_memory else param.data
# Old call (removed), then new call (added).
self._transfer_tensor_to_device(param, source)
self._transfer_tensor_to_device(param, source, default_stream)

# Pass 3: buffers held directly on `self.buffers`.
for buffer in self.buffers:
source = pinned_memory[buffer] if pinned_memory else buffer.data
# Old call (removed), then new call (added).
self._transfer_tensor_to_device(buffer, source)
self._transfer_tensor_to_device(buffer, source, default_stream)

def _onload_from_disk(self):
if self.stream is not None:
Expand Down Expand Up @@ -208,10 +208,12 @@ def _onload_from_memory(self):
self.stream.synchronize()

context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
default_stream = self._torch_accelerator_module.current_stream() if self.stream is not None else None

with context:
if self.stream is not None:
with self._pinned_memory_tensors() as pinned_memory:
self._process_tensors_from_modules(pinned_memory)
self._process_tensors_from_modules(pinned_memory, default_stream=default_stream)
else:
self._process_tensors_from_modules(None)

Expand Down
Loading