Skip to content

Commit b69d099

Browse files
committed
update
1 parent fb99d94 commit b69d099

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

tests/quantization/torchao/test_torchao.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -640,16 +640,24 @@ def test_torch_compile(self):
640640
super()._test_torch_compile(quantization_config=self.quantization_config)
641641

642642
@unittest.skip(
643-
"Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work."
643+
"Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
644+
"when compiling."
644645
)
645646
def test_torch_compile_with_cpu_offload(self):
646647
# RuntimeError: _apply(): Couldn't swap Linear.weight
647648
super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
648649

649650
@unittest.skip(
650-
"Changing the device of AQT tensor with .to() does not work. Needs to be discussed with TorchAO team."
651+
"Changing the device of an AQT tensor, with `param.data = param.data.to(device)` as done in the group offloading implementation, "
652+
"is unsupported in TorchAO. When compiling, a FakeTensor device mismatch causes failure."
651653
)
652654
def test_torch_compile_with_group_offload_leaf(self):
655+
# If we run group offloading without compilation, we will see:
656+
# RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
657+
# When running with compilation, the error ends up being different:
658+
# Dynamo failed to run FX node with fake tensors: call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=FakeTensor(..., size=(1536, 256), dtype=torch.int8)... , scale=FakeTensor(..., size=(1536,), dtype=torch.bfloat16)... , zero_point=FakeTensor(..., size=(1536,), dtype=torch.int64)... , _layout=PlainLayout()), block_size=(1, 256), shape=torch.Size([1536, 256]), device=cpu, dtype=torch.bfloat16, requires_grad=False), Parameter(FakeTensor(..., device='cuda:0', size=(1536,), dtype=torch.bfloat16,
659+
# requires_grad=True))), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu')
660+
# Looks like something that will have to be looked into upstream.
653661
# for linear layers, weight.tensor_impl shows cuda... but:
654662
# weight.tensor_impl.{data,scale,zero_point}.device will be cpu
655663
super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)

0 commit comments

Comments
 (0)