Skip to content

Commit b9c31e5

Browse files
Arm backend: Make per-channel quantization default for VgfPipeline (#12705)
Signed-off-by: Yufeng Shi <[email protected]>
1 parent 3ab7063 commit b9c31e5

File tree

2 files changed

+39
-10
lines changed

2 files changed

+39
-10
lines changed

backends/arm/test/ops/test_multihead_attention.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
EthosU85PipelineBI,
1212
TosaPipelineBI,
1313
TosaPipelineMI,
14+
VgfPipeline,
1415
)
1516

1617

@@ -105,3 +106,39 @@ def test_multihead_attention_u85_BI(test_data: input_t1):
105106
per_channel_quantization=False,
106107
)
107108
pipeline.run()
109+
110+
111+
@common.parametrize(
112+
"test_data",
113+
test_suite,
114+
)
115+
@common.SkipIfNoModelConverter
116+
def test_multihead_attention_vgf_FP(test_data: input_t1):
117+
test_data_vals, module = test_data()
118+
pipeline = VgfPipeline[input_t1](
119+
module,
120+
(*test_data_vals, *test_data_vals, *test_data_vals),
121+
[],
122+
[],
123+
tosa_version="TOSA-1.0+FP",
124+
)
125+
pipeline.run()
126+
127+
128+
@common.parametrize(
129+
"test_data",
130+
test_suite,
131+
)
132+
@common.SkipIfNoModelConverter
133+
def test_multihead_attention_vgf_INT(test_data: input_t1):
134+
test_data_vals, module = test_data()
135+
pipeline = VgfPipeline[input_t1](
136+
module,
137+
(*test_data_vals, *test_data_vals, *test_data_vals),
138+
[],
139+
[],
140+
tosa_version="TOSA-1.0+INT",
141+
# TODO: Per-channel quantization is broken (MLETORCH-1144)
142+
per_channel_quantization=False,
143+
)
144+
pipeline.run()

backends/arm/test/tester/test_pipeline.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -854,7 +854,7 @@ def __init__(
854854
vgf_compiler_flags: Optional[str] = "",
855855
tosa_version: str = "TOSA-1.0+FP",
856856
symmetric_io_quantization: bool = False,
857-
per_channel_quantization: bool = False,
857+
per_channel_quantization: bool = True,
858858
use_to_edge_transform_and_lower: bool = True,
859859
custom_path: str = None,
860860
atol: float = 1e-03,
@@ -866,11 +866,6 @@ def __init__(
866866
] = None,
867867
):
868868

869-
if (
870-
symmetric_io_quantization or per_channel_quantization
871-
) and tosa_version == "TOSA-1.0+FP":
872-
raise ValueError("Dont configure quantization with FP TOSA profile.")
873-
874869
tosa_profile = TosaSpecification.create_from_string(tosa_version)
875870
compile_spec = common.get_vgf_compile_spec(
876871
tosa_profile, compiler_flags=vgf_compiler_flags, custom_path=custom_path
@@ -887,18 +882,15 @@ def __init__(
887882
transform_passes=transform_passes,
888883
)
889884

890-
if symmetric_io_quantization or per_channel_quantization:
885+
if "INT" in tosa_version:
891886
quantizer = VgfQuantizer(compile_spec)
892887
quantization_config = get_symmetric_quantization_config(
893888
is_per_channel=per_channel_quantization
894889
)
895890
if symmetric_io_quantization:
896891
quantizer.set_io(quantization_config)
897892
quant_stage = Quantize(quantizer, quantization_config)
898-
else:
899-
quant_stage = None
900893

901-
if "INT" in tosa_version:
902894
self.add_stage(self.tester.quantize, quant_stage, pos=0)
903895

904896
self.add_stage_after(

0 commit comments

Comments
 (0)