@@ -19,6 +19,7 @@
 from typing import List
 
 import numpy as np
+from parameterized import parameterized
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
 
 from diffusers import (
@@ -648,10 +649,17 @@ def test_torch_compile_with_cpu_offload(self):
         super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
 
     @unittest.skip(
-        "Changing the device of AQT tensor, with `param.data = param.data.to(device)` as done in group offloading implementation "
-        "is unsupported in TorchAO. When compiling, FakeTensor device mismatch causes failure."
+        """
+        For `use_stream=False`:
+        - Changing the device of an AQT tensor with `param.data = param.data.to(device)`, as done in the group
+          offloading implementation, is unsupported in TorchAO. When compiling, a FakeTensor device mismatch causes failure.
+        For `use_stream=True`:
+        - Using a non-default stream requires the ability to pin tensors. AQT does not seem to support this yet in TorchAO.
+        """
     )
-    def test_torch_compile_with_group_offload_leaf(self):
+    @parameterized.expand([False, True])
+    def test_torch_compile_with_group_offload_leaf(self, use_stream):
+        # For use_stream=False:
         # If we run group offloading without compilation, we will see:
         # RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
         # When running with compilation, the error ends up being different:
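As an illustrative aside (not part of the commit): `parameterized.expand` is what lets the previously separate stream and non-stream tests collapse into a single method. A minimal, self-contained sketch of the mechanism, using a hypothetical test class and method name:

import unittest

from parameterized import parameterized


class ExampleTests(unittest.TestCase):
    # Expands into test_offload_0 (use_stream=False) and test_offload_1 (use_stream=True),
    # so a single skip reason can document both cases, as in the diff above.
    @parameterized.expand([False, True])
    def test_offload(self, use_stream):
        self.assertIn(use_stream, (False, True))


if __name__ == "__main__":
    unittest.main()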
@@ -660,14 +668,10 @@ def test_torch_compile_with_group_offload_leaf(self):
         # Looks like something that will have to be looked into upstream.
         # for linear layers, weight.tensor_impl shows cuda... but:
         # weight.tensor_impl.{data,scale,zero_point}.device will be cpu
-        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
 
-    @unittest.skip(
-        "Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO."
-    )
-    def test_torch_compile_with_group_offload_leaf_stream(self):
-        # NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
-        super()._test_torch_compile_with_group_offload_leaf_stream(quantization_config=self.quantization_config)
+        # For use_stream=True:
+        # NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
+        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config, use_stream=use_stream)
 
 
     # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
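For context on the two failure modes named in the skip reason, here is a hypothetical, minimal repro sketch. It is not part of the commit; it assumes a CUDA machine and the torchao `quantize_` / `int8_weight_only` API, and each failure block is meant to be tried separately since the first may raise before the second runs:

import torch
import torch.nn as nn
from torchao.quantization import int8_weight_only, quantize_

linear = nn.Linear(64, 64).cuda()
quantize_(linear, int8_weight_only())  # weight's data becomes an AffineQuantizedTensor (AQT)
param = linear.weight

# use_stream=False path: group offloading moves parameters with the pattern below.
# For AQT this is the unsupported operation the skip reason describes; the outer
# tensor can report the new device while tensor_impl.{data,scale,zero_point} do not,
# leading to the RuntimeError / FakeTensor mismatch quoted in the test comments.
param.data = param.data.to("cpu")

# use_stream=True path: streamed offloading needs pinned (page-locked) host memory,
# and probing for it dispatches aten.is_pinned, which AQT does not implement
# (the NotImplementedError quoted in the test comments).
param.data.is_pinned()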