Commit fb99d94

update

1 parent 8173a29

3 files changed, +10 -8 lines

src/diffusers/hooks/group_offloading.py

Lines changed: 0 additions & 4 deletions
@@ -219,7 +219,6 @@ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         return module
 
     def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
-        breakpoint()
         # If there wasn't an onload_leader assigned, we assume that the submodule that first called its forward
         # method is the onload_leader of the group.
         if self.group.onload_leader is None:
@@ -286,7 +285,6 @@ def callback():
         return module
 
     def post_forward(self, module, output):
-        breakpoint()
         # At this point, for the current modules' submodules, we know the execution order of the layers. We can now
         # remove the layer execution tracker hooks and apply prefetching by setting the next_group attribute for each
         # group offloading hook.
@@ -626,9 +624,7 @@ def _apply_group_offloading_leaf_level(
     modules_with_group_offloading = set()
     for name, submodule in module.named_modules():
         if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
-            print("unsupported module", name, type(submodule))
             continue
-        print("applying group offloading to", name, type(submodule))
         group = ModuleGroup(
             modules=[submodule],
             offload_device=offload_device,
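The comment kept in pre_forward above describes the hook's leader-election rule: the first submodule whose forward runs is taken as the group's onload_leader. A generic sketch of that pattern using plain PyTorch forward pre-hooks (illustrative only, not the diffusers implementation; the Group class and hook names below are made up):

import torch

# Minimal stand-in for a module group; only the onload_leader field matters here.
class Group:
    def __init__(self):
        self.onload_leader = None

def make_pre_hook(group):
    def pre_hook(module, args):
        # The first submodule to run its forward claims leadership of the group.
        if group.onload_leader is None:
            group.onload_leader = module
    return pre_hook

group = Group()
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
for submodule in model:
    submodule.register_forward_pre_hook(make_pre_hook(group))

model(torch.randn(1, 4))
assert group.onload_leader is model[0]  # the Linear layer ran its forward first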

tests/quantization/test_torch_compile_utils.py

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ def _test_torch_compile_with_group_offload_leaf(self, quantization_config, torch
             "use_stream": False,
         }
         pipe.transformer.enable_group_offload(**group_offload_kwargs)
-        # pipe.transformer.compile()
+        pipe.transformer.compile()
         for name, component in pipe.components.items():
             if name != "transformer" and isinstance(component, torch.nn.Module):
                 if torch.device(component.device).type == "cpu":
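The one-line change above re-enables torch.compile on the transformer after group offloading has been set up. A minimal usage sketch of that combination (assumptions: the model id and the device/offload_type values below are placeholders; only "use_stream": False and the enable_group_offload/compile calls are taken from this diff):

import torch
from diffusers import DiffusionPipeline

# Placeholder checkpoint; the pipeline actually used by the test is not shown in this diff.
pipe = DiffusionPipeline.from_pretrained("some/model-id", torch_dtype=torch.bfloat16)

group_offload_kwargs = {
    "onload_device": torch.device("cuda"),
    "offload_device": torch.device("cpu"),
    "offload_type": "leaf_level",
    "use_stream": False,
}
pipe.transformer.enable_group_offload(**group_offload_kwargs)
pipe.transformer.compile()  # compile the transformer in place after the offloading hooks are installed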

tests/quantization/torchao/test_torchao.py

Lines changed: 9 additions & 3 deletions
@@ -639,13 +639,19 @@ class TorchAoCompileTest(QuantCompileTests):
     def test_torch_compile(self):
         super()._test_torch_compile(quantization_config=self.quantization_config)
 
+    @unittest.skip(
+        "Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work."
+    )
     def test_torch_compile_with_cpu_offload(self):
+        # RuntimeError: _apply(): Couldn't swap Linear.weight
         super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
 
+    @unittest.skip(
+        "Changing the device of AQT tensor with .to() does not work. Needs to be discussed with TorchAO team."
+    )
     def test_torch_compile_with_group_offload_leaf(self):
-        from diffusers.utils.logging import set_verbosity_debug
-
-        set_verbosity_debug()
+        # for linear layers, weight.tensor_impl shows cuda... but:
+        # weight.tensor_impl.{data,scale,zero_point}.device will be cpu
         super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
 
     @unittest.skip(
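The new skip reasons and inline comments describe a device mismatch on torchao AffineQuantizedTensor (AQT) weights: the quantized weight wrapper can report cuda while its inner tensors remain on cpu. A small diagnostic sketch of that check (assumptions: the weight exposes tensor_impl with data/scale/zero_point attributes, as named in the comment above; everything is probed defensively since the exact layout can differ):

import torch

def report_aqt_devices(linear: torch.nn.Linear) -> None:
    # Compare the device reported by the quantized weight wrapper with the
    # devices of its inner tensors; a mismatch is what the skipped tests hit.
    weight = linear.weight
    print("weight.device:", weight.device)
    impl = getattr(weight, "tensor_impl", None)
    if impl is None:
        print("no tensor_impl attribute; weight does not look like an AQT weight")
        return
    for attr in ("data", "scale", "zero_point"):
        inner = getattr(impl, attr, None)
        if isinstance(inner, torch.Tensor):
            print(f"weight.tensor_impl.{attr}.device:", inner.device)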
