@@ -40,34 +40,52 @@ void add_transfer_copy_node(
40
40
41
41
int64_t dim_whcn = nchw_dim_to_whcn_dim (dim, ndim);
42
42
43
+ struct TransferParams {
44
+ int32_t dim;
45
+ int32_t index_or_start_ref;
46
+ int32_t step_ref;
47
+ } transfer_params{static_cast <int32_t >(dim_whcn), 0 , 0 };
48
+
49
+ const bool param_is_scalar = graph.is_scalar_or_none (index_or_start_ref) &&
50
+ (transfer_type == TransferType::SELECT ||
51
+ graph.is_scalar_or_none (step_ref));
52
+
43
53
vkapi::ParamsBindList param_buffers;
44
- if (transfer_type == TransferType::SELECT) {
45
- param_buffers = {
46
- graph.get_or_create_int_param_buffer (index_or_start_ref, 0 )};
47
- } else { // TransferType::SLICE
48
- param_buffers = {
49
- graph.get_or_create_int_param_buffer (index_or_start_ref, 0 ),
50
- graph.get_or_create_int_param_buffer (step_ref, 1 )};
54
+ if (!param_is_scalar) {
55
+ if (transfer_type == TransferType::SELECT) {
56
+ param_buffers = {
57
+ graph.get_or_create_int_param_buffer (index_or_start_ref, 0 )};
58
+ } else { // TransferType::SLICE
59
+ param_buffers = {
60
+ graph.get_or_create_int_param_buffer (index_or_start_ref, 0 ),
61
+ graph.get_or_create_int_param_buffer (step_ref, 1 )};
62
+ }
63
+ } else {
64
+ transfer_params.index_or_start_ref =
65
+ graph.extract_scalar_or <int32_t >(index_or_start_ref, 0 );
66
+ if (transfer_type != TransferType::SELECT) {
67
+ transfer_params.step_ref = graph.extract_scalar_or <int32_t >(step_ref, 1 );
68
+ }
51
69
}
52
70
53
- const struct TransferParams {
54
- const int32_t dim;
55
- } transfer_params{static_cast <int32_t >(dim_whcn)};
56
-
57
71
std::vector<PushConstantDataInfo> push_constants;
72
+ push_constants.reserve (graph.is_buffer_storage (out) ? 5 : 3 );
58
73
59
74
if (graph.is_buffer_storage (out)) {
60
- push_constants = {
61
- graph.sizes_pc_of (in),
62
- graph.strides_pc_of (out),
63
- graph.strides_pc_of (in),
64
- graph.numel_pc_of (out),
65
- PushConstantDataInfo (&transfer_params, sizeof (transfer_params))};
75
+ push_constants.emplace_back (graph.sizes_pc_of (in));
76
+ push_constants.emplace_back (graph.strides_pc_of (out));
77
+ push_constants.emplace_back (graph.strides_pc_of (in));
78
+ push_constants.emplace_back (graph.numel_pc_of (out));
66
79
} else {
67
- push_constants = {
68
- graph.sizes_pc_of (out),
69
- graph.sizes_pc_of (in),
70
- PushConstantDataInfo (&transfer_params, sizeof (transfer_params))};
80
+ push_constants.emplace_back (graph.sizes_pc_of (out));
81
+ push_constants.emplace_back (graph.sizes_pc_of (in));
82
+ }
83
+
84
+ if (param_is_scalar) {
85
+ push_constants.emplace_back (&transfer_params, sizeof (transfer_params));
86
+ } else {
87
+ push_constants.emplace_back (
88
+ &transfer_params.dim , sizeof (transfer_params.dim ));
71
89
}
72
90
73
91
vkapi::SpecVarList spec_vars = {
@@ -82,6 +100,9 @@ void add_transfer_copy_node(
82
100
} else { // TransferType::SLICE
83
101
kernel_name = " slice" ;
84
102
}
103
+ if (!param_is_scalar) {
104
+ kernel_name += " _ubo" ;
105
+ }
85
106
add_storage_type_suffix (kernel_name, graph.storage_type_of (out));
86
107
add_dtype_suffix (kernel_name, graph.dtype_of (out));
87
108
0 commit comments