Skip to content

Commit d62df0e

Browse files
ScalarTensor operation (#12394)
Summary: Adding the scalar tensor operation. Includes unit test. Differential Revision: D75939657
1 parent af07feb commit d62df0e

File tree

6 files changed

+160
-0
lines changed

6 files changed

+160
-0
lines changed

backends/vulkan/op_registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,7 @@ def register_transfer_ops(features: OpFeatures):
688688
exir_ops.edge.aten.full_like.default,
689689
exir_ops.edge.aten.ones.default,
690690
exir_ops.edge.aten.ones_like.default,
691+
exir_ops.edge.aten.scalar_tensor.default,
691692
exir_ops.edge.aten.upsample_nearest2d.vec,
692693
exir_ops.edge.aten.upsample_bilinear2d.vec,
693694
exir_ops.edge.aten.zeros.default,

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,14 @@ vkapi::ScalarType ComputeGraph::dtype_of(const ValueRef idx) const {
273273
return val.toConstTensor().dtype();
274274
} else if (val.isTensorRef()) {
275275
return val.toConstTensorRef().dtype;
276+
} else if (val.isBool()) {
277+
return vkapi::ScalarType::Bool;
278+
} else if (val.isDouble()) {
279+
// We downcast anyway in the shader and we want to avoid having to
280+
// write special cases there.
281+
return vkapi::ScalarType::Float;
282+
} else if (val.isInt()) {
283+
return vkapi::ScalarType::Int;
276284
}
277285
VK_THROW("Could not get dtype of value with type ", val.type());
278286
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define BUF_T ${buffer_scalar_type(DTYPE)}
14+
#define VEC4_T ${texel_type(DTYPE)}
15+
16+
${define_active_storage_type(STORAGE)}
17+
${define_required_extensions(DTYPE)}
18+
${define_required_extensions(SCALAR_VALUE_TYPE)}
19+
20+
#include "indexing_utils.h"
21+
22+
layout(std430) buffer;
23+
24+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
25+
${layout_declare_ubo(B, buffer_scalar_type(SCALAR_VALUE_TYPE), "scalar_value")}
26+
27+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
28+
29+
#ifdef USING_BUFFER
30+
31+
void main() {
  // Linear index of this invocation along x.
  const int elem_idx = int(gl_GlobalInvocationID.x);

  // Only the first invocation stores the scalar (converted to the output
  // buffer's element type); every other invocation leaves t_out untouched.
  if (elem_idx <= 0) {
    t_out[elem_idx] = BUF_T(scalar_value);
  }
}
40+
41+
# else // !USING_BUFFER
42+
43+
void main() {
  const ivec3 texel_pos = ivec3(gl_GlobalInvocationID);

  // A scalar tensor is a special case where the packed dim is always 1, so
  // only invocations whose coordinates are all below 1 write a texel; the
  // rest do nothing.
  if (all(lessThan(texel_pos, ivec3(1)))) {
    write_texel(t_out, texel_pos, VEC4_T(scalar_value));
  }
}
54+
55+
#endif // !USING_BUFFER
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Variant configuration for the scalar_tensor shader template.
scalar_tensor:
8+
# Default value for each template parameter.
# NOTE(review): NDIM and PACKING are not referenced by the scalar_tensor GLSL
# template shown in this commit — confirm whether they are still required.
parameter_names_with_default_values:
9+
NDIM: 3
10+
DTYPE: float
11+
SCALAR_VALUE_TYPE: float
12+
PACKING: C_packed
13+
STORAGE: texture3d
14+
# Values to generate variants for. Presumably one shader is emitted per
# combination of DTYPE x STORAGE x SCALAR_VALUE_TYPE — confirm against the
# shader codegen.
generate_variant_forall:
15+
# Element type of the output tensor (t_out).
DTYPE:
16+
- VALUE: half
17+
- VALUE: float
18+
- VALUE: int32
19+
# Storage backing of the output tensor.
STORAGE:
20+
- VALUE: texture3d
21+
- VALUE: buffer
22+
# Type used to declare the scalar_value uniform in the shader.
SCALAR_VALUE_TYPE:
23+
- VALUE: float
24+
- VALUE: int32
25+
- VALUE: bool
26+
shader_variants:
27+
- NAME: scalar_tensor
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
12+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
13+
14+
namespace vkcompute {
15+
16+
/**
 * Adds a dispatch node that writes a single scalar value into the output
 * tensor.
 *
 * @param graph Compute graph that receives the new execute node.
 * @param args  Operator arguments; args[0] is the scalar value and the last
 *              entry is the output tensor.
 *
 * The shader variant is selected from the output dtype, the output storage
 * type and the dtype of the scalar argument. The selected shader declares its
 * `scalar_value` uniform with the type that corresponds to the scalar's dtype,
 * so the uploaded payload must use the matching type. Previously the scalar
 * was always uploaded as a float, which makes an integer-typed variant read
 * the float's bit pattern as an integer.
 */
void scalar_tensor(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef scalar_in = args[0];
  // The output tensor is always the last argument.
  const ValueRef out = args.back();

  // dtype_of() reports Float for double scalars, Int for int scalars and Bool
  // for bool scalars.
  const vkapi::ScalarType scalar_dtype = graph.dtype_of(scalar_in);

  std::string kernel_name("scalar_tensor");
  kernel_name.reserve(kShaderNameReserve);

  add_dtype_suffix(kernel_name, graph.dtype_of(out));
  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
  add_dtype_suffix(kernel_name, scalar_dtype);

  // Build the params buffer with a payload whose type matches the uniform
  // declared by the selected shader variant. Only the branch that is taken
  // extracts the scalar, so no out-of-range float-to-int conversion happens
  // for floating point inputs.
  const auto scalar_ubo = [&]() {
    if (scalar_dtype == vkapi::ScalarType::Float) {
      float float_value = graph.extract_scalar<float>(scalar_in);
      return graph.create_params_buffer(float_value);
    }
    // Int scalars are uploaded as a 32-bit integer. Bool scalars take the same
    // path as 0/1 — assumes a bool uniform member is backed by a 32-bit
    // integer where non-zero means true; TODO confirm against the shader
    // codegen.
    int int_value = graph.extract_scalar<int>(scalar_in);
    return graph.create_params_buffer(int_value);
  }();

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      graph.create_global_wg_size(out),
      graph.create_local_wg_size(out),
      // Inputs and Outputs
      {{out, vkapi::kWrite}},
      // Shader params buffers
      {scalar_ubo},
      // Push Constants
      {},
      // Specialization Constants
      {},
      // Resize Args
      {},
      // Resizing Logic
      nullptr));
}
49+
50+
// Registers scalar_tensor as the implementation of aten.scalar_tensor.default.
REGISTER_OPERATORS {
51+
VK_REGISTER_OP(aten.scalar_tensor.default, scalar_tensor);
52+
}
53+
54+
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,21 @@ def get_full_inputs():
763763
return test_suite
764764

765765

766+
@register_test_suite("aten.scalar_tensor.default")
def get_scalar_tensor_inputs():
    """Build the test suite for aten.scalar_tensor.default.

    Each test case is a one-element argument tuple holding the scalar value.
    """
    scalar_values = (42.0, 3.14, 2.72, 0.0, -1.0, 100.0)
    return VkTestSuite([(value,) for value in scalar_values])
779+
780+
766781
@register_test_suite(
767782
[
768783
"aten.zeros.default",

0 commit comments

Comments
 (0)