Skip to content

Commit 3883247

Browse files
committed
[ET-VK][ez] Rename run_prepack() to prepack() and replace encode_prepack() + prepack() with just prepack()
Pull Request resolved: #12443 Title says it all! See below diff for more context on why this new API exists. Differential Revision: [D78275583](https://our.internmc.facebook.com/intern/diff/D78275583/) ghstack-source-id: 296122480
1 parent 16632e5 commit 3883247

11 files changed

+407
-45
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
507507
compute_graph->prepare();
508508
compute_graph->prepare_pipelines();
509509

510-
compute_graph->run_prepack();
510+
compute_graph->prepack();
511511

512512
// If dynamic shapes are not expected, then the command buffer only needs to
513513
// be encoded once. Otherwise, wait until the first inference to encode the

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -768,23 +768,7 @@ void ComputeGraph::submit_current_cmd_and_wait(const bool final_use) {
768768
context_->flush();
769769
}
770770

771-
void ComputeGraph::encode_prepack() {
772-
for (std::unique_ptr<PrepackNode>& node : prepack_nodes_) {
773-
node->encode(this);
774-
}
775-
}
776-
777-
void ComputeGraph::prepack() const {
778-
// Submit and execute the command buffer
779-
vkapi::VulkanFence fence = context_->fences().get_fence();
780-
context_->submit_cmd_to_gpu(fence.get_submit_handle(), /*final_use = */ true);
781-
fence.wait();
782-
context_->fences().return_fence(fence);
783-
784-
context_->flush();
785-
}
786-
787-
void ComputeGraph::run_prepack() {
771+
void ComputeGraph::prepack() {
788772
int i = 0;
789773
bool submitted = false;
790774
for (std::unique_ptr<PrepackNode>& node : prepack_nodes_) {

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -835,14 +835,11 @@ class ComputeGraph final {
835835
staging_nbytes_in_cmd_ += staging_bytes;
836836
}
837837

838-
void encode_prepack();
839-
void prepack() const;
840-
841838
/*
842839
* Executes prepacking operations to transfer model weight data from the CPU
843840
* to GPU.
844841
*/
845-
void run_prepack();
842+
void prepack();
846843

847844
//
848845
// Graph Execution

backends/vulkan/test/op_tests/choose_qparams_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ void test_vulkan_choose_qparams_tensor_impl(
456456
ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point);
457457

458458
graph.prepare();
459-
graph.encode_prepack();
459+
460460
graph.prepack();
461461
graph.encode_execute();
462462

@@ -676,7 +676,7 @@ void test_vulkan_choose_qparams_per_token_asymmetric_impl(
676676
ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point);
677677

678678
graph.prepare();
679-
graph.encode_prepack();
679+
680680
graph.prepack();
681681
graph.encode_execute();
682682

backends/vulkan/test/op_tests/dequantize_test.cpp

Lines changed: 140 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,145 @@ void test_reference_dequantize_per_tensor(
827827
ASSERT_TRUE(output_correct);
828828
}
829829

830+
void test_vulkan_dequantize_per_tensor_impl(
831+
const std::vector<int>& input_sizes,
832+
float scale,
833+
int zero_point,
834+
int64_t quant_min,
835+
int64_t quant_max,
836+
at::ScalarType dtype,
837+
at::ScalarType out_dtype,
838+
const vkcompute::utils::StorageType in_storage,
839+
const vkcompute::utils::StorageType out_storage) {
840+
check_dequantize_args(quant_min, quant_max, dtype, out_dtype);
841+
std::vector<int64_t> input_sizes_int64(
842+
input_sizes.begin(), input_sizes.end());
843+
844+
// Create a quantized input tensor with values from quant_min to quant_max
845+
at::Tensor input;
846+
if (dtype == at::kByte) {
847+
input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte));
848+
} else if (dtype == at::kChar) {
849+
input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar));
850+
} else if (dtype == at::kShort) {
851+
input =
852+
at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort));
853+
} else if (dtype == at::kInt) {
854+
input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt));
855+
} else {
856+
input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong));
857+
}
858+
859+
// Fill with a simple pattern: values from quant_min to quant_max in steps
860+
float step = 1.0f;
861+
if (input.numel() > 1) {
862+
step = static_cast<float>(quant_max - quant_min) / (input.numel() - 1);
863+
}
864+
865+
auto flat_input = input.flatten();
866+
for (int i = 0; i < flat_input.numel(); i++) {
867+
int64_t qvalue = quant_min + i * step;
868+
if (dtype == at::kByte) {
869+
flat_input[i] = static_cast<uint8_t>(qvalue);
870+
} else if (dtype == at::kChar) {
871+
flat_input[i] = static_cast<int8_t>(qvalue);
872+
} else if (dtype == at::kShort) {
873+
flat_input[i] = static_cast<int16_t>(qvalue);
874+
} else if (dtype == at::kInt) {
875+
flat_input[i] = static_cast<int32_t>(qvalue);
876+
} else if (dtype == at::kLong) {
877+
flat_input[i] = static_cast<int64_t>(qvalue);
878+
}
879+
}
880+
881+
// Reshape back to original dimensions
882+
input = flat_input.reshape(input_sizes_int64);
883+
884+
// Get reference output
885+
at::Tensor reference_out =
886+
torch::executor::native::dequantize_per_tensor_aten(
887+
input, scale, zero_point, quant_min, quant_max, dtype, out_dtype);
888+
889+
// Build Vulkan dequantize_per_tensor graph
890+
using namespace vkcompute;
891+
892+
GraphConfig config;
893+
config.set_storage_type_override(in_storage);
894+
ComputeGraph graph(config);
895+
896+
IOValueRef r_input = graph.add_input_tensor(
897+
input.sizes().vec(), from_at_scalartype(dtype), in_storage);
898+
899+
const ValueRef r_scale = graph.add_scalar<double>(scale);
900+
const ValueRef r_zero_point = graph.add_scalar<int64_t>(zero_point);
901+
const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min);
902+
const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max);
903+
904+
const ValueRef r_out = graph.add_tensor(
905+
input.sizes().vec(), from_at_scalartype(out_dtype), out_storage);
906+
907+
VK_GET_OP_FN("dequantize_per_tensor.default")
908+
(graph,
909+
{
910+
r_input.value,
911+
r_scale,
912+
r_zero_point,
913+
r_quant_min,
914+
r_quant_max,
915+
r_out,
916+
});
917+
918+
ValueRef staging_out = graph.set_output_tensor(r_out);
919+
920+
graph.prepare();
921+
922+
graph.prepack();
923+
graph.encode_execute();
924+
925+
// Run Vulkan dequantize_per_tensor
926+
graph.copy_into_staging(
927+
r_input.staging, input.const_data_ptr(), input.numel());
928+
929+
graph.execute();
930+
931+
at::Tensor vk_out = at::empty_like(reference_out).contiguous();
932+
graph.copy_from_staging(
933+
staging_out, vk_out.mutable_data_ptr(), vk_out.numel());
934+
935+
// Compare outputs with appropriate tolerance for half precision
936+
bool output_correct;
937+
if (out_dtype == at::kHalf) {
938+
// Use higher tolerance for half precision due to limited precision
939+
output_correct =
940+
at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2);
941+
} else {
942+
output_correct = at::allclose(reference_out, vk_out);
943+
}
944+
if (!output_correct) {
945+
std::cout << "\n"
946+
<< "Failed with parameters: " << std::endl;
947+
std::cout << " scale: " << scale << std::endl;
948+
std::cout << " zero_point: " << zero_point << std::endl;
949+
std::cout << " quant_min: " << quant_min << std::endl;
950+
std::cout << " quant_max: " << quant_max << std::endl;
951+
std::cout << " storage type: "
952+
<< (in_storage == vkcompute::utils::kBuffer ? "buffer"
953+
: "texture")
954+
<< std::endl;
955+
std::cout << " input dtype: " << dtype << std::endl;
956+
std::cout << " output dtype: " << out_dtype << std::endl;
957+
958+
std::cout << "input:" << std::endl;
959+
std::cout << input << std::endl;
960+
std::cout << "reference:" << std::endl;
961+
std::cout << reference_out << std::endl;
962+
std::cout << "vulkan:" << std::endl;
963+
std::cout << vk_out << std::endl;
964+
}
965+
966+
ASSERT_TRUE(output_correct);
967+
}
968+
830969
TEST(
831970
VulkanDequantizePerTensorTest,
832971
test_reference_dequantize_per_tensor_uint8_to_float) {
@@ -1138,7 +1277,7 @@ void test_vulkan_dequantize_per_token_impl(
11381277
ValueRef staging_out = graph.set_output_tensor(r_out);
11391278

11401279
graph.prepare();
1141-
graph.encode_prepack();
1280+
11421281
graph.prepack();
11431282
graph.encode_execute();
11441283

@@ -1670,7 +1809,6 @@ void test_vulkan_dequantize_per_channel_impl(
16701809
ValueRef staging_out = graph.set_output_tensor(r_out);
16711810

16721811
graph.prepare();
1673-
graph.encode_prepack();
16741812
graph.prepack();
16751813
graph.encode_execute();
16761814

@@ -2345,7 +2483,6 @@ void test_vulkan_dequantize_per_tensor_tensor_impl(
23452483
ValueRef staging_out = graph.set_output_tensor(r_out);
23462484

23472485
graph.prepare();
2348-
graph.encode_prepack();
23492486
graph.prepack();
23502487
graph.encode_execute();
23512488

0 commit comments

Comments
 (0)