@@ -827,6 +827,145 @@ void test_reference_dequantize_per_tensor(
827
827
ASSERT_TRUE (output_correct);
828
828
}
829
829
830
+ void test_vulkan_dequantize_per_tensor_impl (
831
+ const std::vector<int >& input_sizes,
832
+ float scale,
833
+ int zero_point,
834
+ int64_t quant_min,
835
+ int64_t quant_max,
836
+ at::ScalarType dtype,
837
+ at::ScalarType out_dtype,
838
+ const vkcompute::utils::StorageType in_storage,
839
+ const vkcompute::utils::StorageType out_storage) {
840
+ check_dequantize_args (quant_min, quant_max, dtype, out_dtype);
841
+ std::vector<int64_t > input_sizes_int64 (
842
+ input_sizes.begin (), input_sizes.end ());
843
+
844
+ // Create a quantized input tensor with values from quant_min to quant_max
845
+ at::Tensor input;
846
+ if (dtype == at::kByte ) {
847
+ input = at::zeros (input_sizes_int64, at::device (at::kCPU ).dtype (at::kByte ));
848
+ } else if (dtype == at::kChar ) {
849
+ input = at::zeros (input_sizes_int64, at::device (at::kCPU ).dtype (at::kChar ));
850
+ } else if (dtype == at::kShort ) {
851
+ input =
852
+ at::zeros (input_sizes_int64, at::device (at::kCPU ).dtype (at::kShort ));
853
+ } else if (dtype == at::kInt ) {
854
+ input = at::zeros (input_sizes_int64, at::device (at::kCPU ).dtype (at::kInt ));
855
+ } else {
856
+ input = at::zeros (input_sizes_int64, at::device (at::kCPU ).dtype (at::kLong ));
857
+ }
858
+
859
+ // Fill with a simple pattern: values from quant_min to quant_max in steps
860
+ float step = 1 .0f ;
861
+ if (input.numel () > 1 ) {
862
+ step = static_cast <float >(quant_max - quant_min) / (input.numel () - 1 );
863
+ }
864
+
865
+ auto flat_input = input.flatten ();
866
+ for (int i = 0 ; i < flat_input.numel (); i++) {
867
+ int64_t qvalue = quant_min + i * step;
868
+ if (dtype == at::kByte ) {
869
+ flat_input[i] = static_cast <uint8_t >(qvalue);
870
+ } else if (dtype == at::kChar ) {
871
+ flat_input[i] = static_cast <int8_t >(qvalue);
872
+ } else if (dtype == at::kShort ) {
873
+ flat_input[i] = static_cast <int16_t >(qvalue);
874
+ } else if (dtype == at::kInt ) {
875
+ flat_input[i] = static_cast <int32_t >(qvalue);
876
+ } else if (dtype == at::kLong ) {
877
+ flat_input[i] = static_cast <int64_t >(qvalue);
878
+ }
879
+ }
880
+
881
+ // Reshape back to original dimensions
882
+ input = flat_input.reshape (input_sizes_int64);
883
+
884
+ // Get reference output
885
+ at::Tensor reference_out =
886
+ torch::executor::native::dequantize_per_tensor_aten (
887
+ input, scale, zero_point, quant_min, quant_max, dtype, out_dtype);
888
+
889
+ // Build Vulkan dequantize_per_tensor graph
890
+ using namespace vkcompute ;
891
+
892
+ GraphConfig config;
893
+ config.set_storage_type_override (in_storage);
894
+ ComputeGraph graph (config);
895
+
896
+ IOValueRef r_input = graph.add_input_tensor (
897
+ input.sizes ().vec (), from_at_scalartype (dtype), in_storage);
898
+
899
+ const ValueRef r_scale = graph.add_scalar <double >(scale);
900
+ const ValueRef r_zero_point = graph.add_scalar <int64_t >(zero_point);
901
+ const ValueRef r_quant_min = graph.add_scalar <int64_t >(quant_min);
902
+ const ValueRef r_quant_max = graph.add_scalar <int64_t >(quant_max);
903
+
904
+ const ValueRef r_out = graph.add_tensor (
905
+ input.sizes ().vec (), from_at_scalartype (out_dtype), out_storage);
906
+
907
+ VK_GET_OP_FN (" dequantize_per_tensor.default" )
908
+ (graph,
909
+ {
910
+ r_input.value ,
911
+ r_scale,
912
+ r_zero_point,
913
+ r_quant_min,
914
+ r_quant_max,
915
+ r_out,
916
+ });
917
+
918
+ ValueRef staging_out = graph.set_output_tensor (r_out);
919
+
920
+ graph.prepare ();
921
+
922
+ graph.prepack ();
923
+ graph.encode_execute ();
924
+
925
+ // Run Vulkan dequantize_per_tensor
926
+ graph.copy_into_staging (
927
+ r_input.staging , input.const_data_ptr (), input.numel ());
928
+
929
+ graph.execute ();
930
+
931
+ at::Tensor vk_out = at::empty_like (reference_out).contiguous ();
932
+ graph.copy_from_staging (
933
+ staging_out, vk_out.mutable_data_ptr (), vk_out.numel ());
934
+
935
+ // Compare outputs with appropriate tolerance for half precision
936
+ bool output_correct;
937
+ if (out_dtype == at::kHalf ) {
938
+ // Use higher tolerance for half precision due to limited precision
939
+ output_correct =
940
+ at::allclose (reference_out, vk_out, /* rtol=*/ 1e-2 , /* atol=*/ 1e-2 );
941
+ } else {
942
+ output_correct = at::allclose (reference_out, vk_out);
943
+ }
944
+ if (!output_correct) {
945
+ std::cout << " \n "
946
+ << " Failed with parameters: " << std::endl;
947
+ std::cout << " scale: " << scale << std::endl;
948
+ std::cout << " zero_point: " << zero_point << std::endl;
949
+ std::cout << " quant_min: " << quant_min << std::endl;
950
+ std::cout << " quant_max: " << quant_max << std::endl;
951
+ std::cout << " storage type: "
952
+ << (in_storage == vkcompute::utils::kBuffer ? " buffer"
953
+ : " texture" )
954
+ << std::endl;
955
+ std::cout << " input dtype: " << dtype << std::endl;
956
+ std::cout << " output dtype: " << out_dtype << std::endl;
957
+
958
+ std::cout << " input:" << std::endl;
959
+ std::cout << input << std::endl;
960
+ std::cout << " reference:" << std::endl;
961
+ std::cout << reference_out << std::endl;
962
+ std::cout << " vulkan:" << std::endl;
963
+ std::cout << vk_out << std::endl;
964
+ }
965
+
966
+ ASSERT_TRUE (output_correct);
967
+ }
968
+
830
969
TEST (
831
970
VulkanDequantizePerTensorTest,
832
971
test_reference_dequantize_per_tensor_uint8_to_float) {
@@ -1138,7 +1277,7 @@ void test_vulkan_dequantize_per_token_impl(
1138
1277
ValueRef staging_out = graph.set_output_tensor (r_out);
1139
1278
1140
1279
graph.prepare ();
1141
- graph. encode_prepack ();
1280
+
1142
1281
graph.prepack ();
1143
1282
graph.encode_execute ();
1144
1283
@@ -1670,7 +1809,6 @@ void test_vulkan_dequantize_per_channel_impl(
1670
1809
ValueRef staging_out = graph.set_output_tensor (r_out);
1671
1810
1672
1811
graph.prepare ();
1673
- graph.encode_prepack ();
1674
1812
graph.prepack ();
1675
1813
graph.encode_execute ();
1676
1814
@@ -2345,7 +2483,6 @@ void test_vulkan_dequantize_per_tensor_tensor_impl(
2345
2483
ValueRef staging_out = graph.set_output_tensor (r_out);
2346
2484
2347
2485
graph.prepare ();
2348
- graph.encode_prepack ();
2349
2486
graph.prepack ();
2350
2487
graph.encode_execute ();
2351
2488
0 commit comments