
Commit b4c2c8f

[ET-VK][qlinear] Faster weight only quantized linear gemv kernel
Differential Revision: D78275584
Pull Request resolved: #12444
1 parent 4b5b75f commit b4c2c8f

11 files changed: +523, -171 lines changed

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 12 additions & 0 deletions
@@ -68,6 +68,18 @@
  */
 #define mod4(x) ((x) & 3)

+#define ALIGN_UP_4(x) (((x) + 3) & ~3)
+
+#define DIV_UP_8(x) (((x) + 7) >> 3)
+#define DIV_UP_4(x) (((x) + 3) >> 2)
+
+#define DIV_4(x) ((x) >> 2)
+#define DIV_2(x) ((x) >> 1)
+
+#define MUL_8(x) ((x) << 3)
+#define MUL_4(x) ((x) << 2)
+#define MUL_2(x) ((x) << 1)
+
 /*
  * Get the staging buffer indices that contain the data of the texel that
  * corresponds to the provided tensor index. Since the texel have 4 elements,
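
The added helpers are shift-and-mask forms of align-up, ceil-division, division, and multiplication by powers of two. Below is a minimal CPU-side sketch of what each macro evaluates to, assuming non-negative inputs (C++; the lowercase function names are illustrative stand-ins for the macros, not part of the header).

#include <cassert>
#include <cstdint>

// CPU equivalents of the GLSL macros added above, valid for non-negative x.
constexpr uint32_t align_up_4(uint32_t x) { return (x + 3) & ~3u; } // round up to a multiple of 4
constexpr uint32_t div_up_4(uint32_t x) { return (x + 3) >> 2; }    // ceil(x / 4)
constexpr uint32_t div_up_8(uint32_t x) { return (x + 7) >> 3; }    // ceil(x / 8)
constexpr uint32_t div_4(uint32_t x) { return x >> 2; }             // x / 4
constexpr uint32_t mul_8(uint32_t x) { return x << 3; }             // x * 8

int main() {
  assert(align_up_4(10) == 12);
  assert(div_up_4(10) == 3); // 10 scalars occupy 3 vec4 texels
  assert(div_up_8(17) == 3);
  assert(div_4(10) == 2);
  assert(mul_8(5) == 40);
  return 0;
}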

backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl

Lines changed: 105 additions & 145 deletions
@@ -13,187 +13,147 @@
 #define T ${buffer_scalar_type(DTYPE)}
 #define VEC4_T ${buffer_gvec_type(DTYPE, 4)}

-#define TILE_ROWS ${TILE_ROWS}
-
-#define NGROUPS 8
-#define NWORKERS 8
+#define WGS ${WGS}

 ${define_required_extensions(DTYPE)}
-$if WEIGHT_STORAGE == "buffer":
-  ${define_required_extensions("uint8")}
+${define_required_extensions("uint8")}

 #extension GL_EXT_control_flow_attributes : require
+#extension GL_EXT_debug_printf : require

 layout(std430) buffer;

-${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
-${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)}
-${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
+#include "indexing_utils.h"
+
+${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_qmat2", "uint", WEIGHT_STORAGE, is_scalar_array=False)}
 ${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "buffer", is_scalar_array=False)}

 layout(push_constant) uniform restrict Block {
-  ivec4 out_sizes;
-  ivec4 mat1_sizes;
-  ivec4 qmat2_sizes;
+  ivec4 output_sizes;
+  ivec4 input_sizes;
+  ivec4 weight_sizes;
 };

 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

 layout(constant_id = 3) const int group_size = 64;

-shared VEC4_T partial_sums[NGROUPS][NWORKERS][TILE_ROWS][2];
+shared VEC4_T partial_sums[WGS][2];
+
+$if IO_STORAGE == "buffer":
+  #define BUFFER_IO
+$if WEIGHT_STORAGE == "buffer":
+  #define BUFFER_WEIGHT
+
+#include "qlinear_utils.glslh"

-/*
- * This shader computes a linear operator between a floating point input matrix
- * x and a weights matrix that is quantized to 4 bits. Please refer to the
- * q_4w_linear shader for more details.
- *
- * This shader implements a co-operative algorithm to compute the output. The
- * work group size is {NGROUP, 1, NWORKERS}, and each group of NWORKERS threads
- * cooperative to compute TILE_ROWS * 2 output texels. Therefore,
- * NGROUP * TILE_ROWS * 2 output texels are computed across one work group.
- *
- * The threads co-operate by each thread computing a partial reduction along the
- * K dimension. To illustrate the computation, consider a scalar variant of the
- * algorithm that computes the dot product of 2 vectors. Also assume that
- * NWORKERS is 8.
- *
- * Thread 1 in each group will compute:
- * (mat1[0] * mat2[0]) + (mat1[8] * mat2[8]) + (mat1[16] * mat2[16]) + ...
- *
- * Thread 2 in each group will compute:
- * (mat1[1] * mat2[1]) + (mat2[9] * mat2[9]) + (mat1[17] * mat2[17]) + ...
- *
- * Thread 3 in each group will compute:
- * (mat1[2] * mat2[2]) + (mat2[10] * mat2[10]) + (mat1[18] * mat2[18]) + ...
- *
- * The partial accumulations is structured such that memory accesses in each
- * loop iteration can be coalesced.
- *
- * Then, at the end first thread in each group will accumulate the partial
- * accumulations computed by each thread to obtain the final result.
- *
- * Note that this shader assumes that all tensors are width packed.
- */
 void main() {
-  const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
-  // Each thread writes out 2 texels along the width axis, equivalent to 8
-  // scalar elements. Therefore multiply the thread_idx.x by 8.
-  const uint out_col = gl_GlobalInvocationID.x << 3;
-  // Similar reasoning to the above, each thread works on 2 texels along the
-  // width axis so multiply thread_idx.x by 2.
-  const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1;
-
-  const uint gid = gl_LocalInvocationID.x; // group id
-  const uint wid = gl_LocalInvocationID.z; // worker id
-
-  if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
+  const uint lid = gl_LocalInvocationID.x;
+  const uint n8 = gl_GlobalInvocationID.y;
+  // The output tensor will have a shape of [n, 1, 1, 1]. Each thread computes
+  // 8 output elements, so each thread will write to 8 elements starting at the
+  // tensor index (gid.x * 8, 0, 0, 0).
+  const uint n = MUL_8(n8);
+  const uint K4 = DIV_UP_4(input_sizes.x);
+
+  if (n >= output_sizes.x) {
     return;
   }

-  const int num_blocks = mat1_sizes.x / group_size;
+  VEC4_T out_texels[2];
+  out_texels[0] = VEC4_T(0);
+  out_texels[1] = VEC4_T(0);

-  VEC4_T mat1[TILE_ROWS];
-  VEC4_T qmat2[4][2];
-  VEC4_T local_sums[TILE_ROWS][2];
+  // initialize the group index to a value larger than the largest possible
+  uint cur_group_idx = input_sizes.x;

-  [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
-    local_sums[r][0] = VEC4_T(0);
-    local_sums[r][1] = VEC4_T(0);
-  }
+  // Each thread in the work group accumulates a partial result.
+  for (uint k4 = lid; k4 < DIV_UP_4(input_sizes.x); k4 += WGS) {
+    const uint k = MUL_4(k4);
+    const uint group_idx = k / group_size;

-  VEC4_T scales[2];
-  VEC4_T zeros[2];
-
-  $if WEIGHT_STORAGE == "buffer":
-    const int qmat2_stride = qmat2_sizes.x >> 2;
-  $if PARAMS_STORAGE == "buffer":
-    const int qparams_y_stride = out_sizes.x >> 2;
-    const int qparams_z_stride = qparams_y_stride * 2;
-
-  for (int block_idx = 0; block_idx < num_blocks; ++block_idx) {
-    $if PARAMS_STORAGE == "buffer":
-      scales[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx];
-      zeros[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + qparams_y_stride];
-
-      scales[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1];
-      zeros[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1 + qparams_y_stride];
-    $else:
-      scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0);
-      zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0);
-
-      scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0);
-      zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0);
-
-    for (uint g_idx = 4 * wid; g_idx < group_size; g_idx += (4 * NWORKERS)) {
-      const uint k = block_idx * group_size + g_idx;
-
-      // Preload B
-      [[unroll]] for (int r = 0; r < 4; ++r) {
-        $if WEIGHT_STORAGE == "buffer":
-          const u8vec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x];
-        $else:
-          const uvec4 packed_weight_tex = texelFetch(
-            t_qmat2,
-            ivec2(gl_GlobalInvocationID.x, k + r),
-            0);
-
-        qmat2[r][0] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0) * scales[0] + zeros[0];
-        qmat2[r][1] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0) * scales[1] + zeros[1];
-      }
+    VEC4_T scales[2];
+    VEC4_T zeros[2];

-      // Preload A
-      [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
-        $if IN_STORAGE == "buffer":
-          mat1[r] = t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2];
-        $else:
-          mat1[r] = texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0);
-      }
+    // Only update the scales/zeros if the current iteration is now working on a
+    // new quantization group.
+    if (group_idx != cur_group_idx) {
+      // The qparams tensor contains the quantization scales and zeros, with
+      // shape [2, N, K / group_size, 1].
+      // Loading a texel from the qparams tensor will return 2 scales and 2
+      // zeros for 2 adjacent output channels.
+      uint qparams_bufi = group_idx * DIV_2(output_sizes.x) + DIV_2(n);
+      VEC4_T scales_zeros_texels[4];
+      $for comp in range(4):
+        scales_zeros_texels[${comp}] = t_qparams[qparams_bufi++];

-      // Accumulate local output tile
-      [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
-        local_sums[r][0] += mat1[r].x * qmat2[0][0]
-                          + mat1[r].y * qmat2[1][0]
-                          + mat1[r].z * qmat2[2][0]
-                          + mat1[r].w * qmat2[3][0];
-
-        local_sums[r][1] += mat1[r].x * qmat2[0][1]
-                          + mat1[r].y * qmat2[1][1]
-                          + mat1[r].z * qmat2[2][1]
-                          + mat1[r].w * qmat2[3][1];
-      }
+      scales[0] = VEC4_T(scales_zeros_texels[0].xz, scales_zeros_texels[1].xz);
+      zeros[0] = VEC4_T(scales_zeros_texels[0].yw, scales_zeros_texels[1].yw);
+
+      scales[1] = VEC4_T(scales_zeros_texels[2].xz, scales_zeros_texels[3].xz);
+      zeros[1] = VEC4_T(scales_zeros_texels[2].yw, scales_zeros_texels[3].yw);
+
+      cur_group_idx = group_idx;
     }
+    // The input tensor will have a shape of [K, 1, 1, 1]; in each iteration,
+    // load 4 elements starting from the tensor index (k, 0, 0, 0).
+    VEC4_T in_texel = load_input_texel(k4);
+    // Extract each element of the in_texel into a separate vectorized variable;
+    // these are used to "broadcast" the input values in subsequent fma calls.
+    VEC4_T in_texel_val[4];
+    $for comp in range(4):
+      in_texel_val[${comp}] = VEC4_T(in_texel[${comp}]);
+
+    uvec4 packed_weight_block = load_transposed_weight_block(k4, n8, K4);
+
+    VEC4_T weight_texels[2];
+    $for comp in range(4):
+      {
+        weight_texels[0].x = extract_4bit_from_transposed_block(packed_weight_block, 0, ${comp});
+        weight_texels[0].y = extract_4bit_from_transposed_block(packed_weight_block, 1, ${comp});
+        weight_texels[0].z = extract_4bit_from_transposed_block(packed_weight_block, 2, ${comp});
+        weight_texels[0].w = extract_4bit_from_transposed_block(packed_weight_block, 3, ${comp});
+
+        weight_texels[1].x = extract_4bit_from_transposed_block(packed_weight_block, 4, ${comp});
+        weight_texels[1].y = extract_4bit_from_transposed_block(packed_weight_block, 5, ${comp});
+        weight_texels[1].z = extract_4bit_from_transposed_block(packed_weight_block, 6, ${comp});
+        weight_texels[1].w = extract_4bit_from_transposed_block(packed_weight_block, 7, ${comp});
+
+        weight_texels[0] = fma(weight_texels[0], scales[0], zeros[0]);
+        weight_texels[1] = fma(weight_texels[1], scales[1], zeros[1]);
+
+        out_texels[0] = fma(in_texel_val[${comp}], weight_texels[0], out_texels[0]);
+        out_texels[1] = fma(in_texel_val[${comp}], weight_texels[1], out_texels[1]);
+      }
   }

-  [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
-    partial_sums[gid][wid][r][0] = local_sums[r][0];
-    partial_sums[gid][wid][r][1] = local_sums[r][1];
-  }
+  partial_sums[lid][0] = out_texels[0];
+  partial_sums[lid][1] = out_texels[1];

   memoryBarrierShared();
   barrier();

-  if (wid != 0) {
-    return;
+  // Tree reduction to compute the overall result.
+  for (int i = WGS / 2; i > 0; i /= 2) {
+    if (lid < i) {
+      partial_sums[lid][0] = partial_sums[lid][0] + partial_sums[lid + i][0];
+      partial_sums[lid][1] = partial_sums[lid][1] + partial_sums[lid + i][1];
+    }
+    memoryBarrierShared();
+    barrier();
   }

-  VEC4_T sums[TILE_ROWS][2];
+  // Only the first thread will write out result
+  if (lid == 0) {
+    out_texels[0] = partial_sums[0][0];
+    out_texels[1] = partial_sums[0][1];

-  for (int r = 0; r < TILE_ROWS; ++r) {
-    sums[r][0] = VEC4_T(0);
-    sums[r][1] = VEC4_T(0);
-    [[unroll]] for (int worker = 0; worker < NWORKERS; ++ worker) {
-      sums[r][0] += partial_sums[gid][worker][r][0];
-      sums[r][1] += partial_sums[gid][worker][r][1];
+    uint n4 = DIV_4(n);
+    write_output_texel(out_texels[0], n4);
+    if (n + 4 < output_sizes.x) {
+      write_output_texel(out_texels[1], n4 + 1);
     }
   }
-
-  [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
-    $if OUT_STORAGE == "buffer":
-      t_out[((out_row + r) * out_sizes.x + out_col) >> 2] = sums[r][0];
-      t_out[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = sums[r][1];
-    $else:
-      imageStore(t_out, ivec3(out_col_texel_idx, out_row + r, 0), sums[r][0]);
-      imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row + r, 0), sums[r][1]);
-  }
 }
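
The removed header comment described the previous co-operative scheme, where each group of NWORKERS threads accumulated a strided partial reduction along K and the first worker summed the partials serially. The new kernel keeps the strided partial reduction but flattens the work group to WGS threads and combines the partials with a shared-memory tree reduction. The following CPU emulation of that pattern for a single dot product is a sketch only; the names and sizes are illustrative and not part of the runtime.

#include <cstdio>
#include <vector>

int main() {
  constexpr int WGS = 64;   // work group size, matching the shader's default
  constexpr int K = 1024;   // reduction (input channel) dimension
  std::vector<float> input(K, 1.0f), weight(K, 0.5f);

  // Phase 1: each "thread" lid accumulates a strided partial sum over
  // k = lid, lid + WGS, lid + 2 * WGS, ...
  float partial_sums[WGS] = {};
  for (int lid = 0; lid < WGS; ++lid) {
    for (int k = lid; k < K; k += WGS) {
      partial_sums[lid] += input[k] * weight[k];
    }
  }

  // Phase 2: tree reduction; each halving step corresponds to one barrier()
  // in the shader, so the combine finishes in log2(WGS) steps.
  for (int i = WGS / 2; i > 0; i /= 2) {
    for (int lid = 0; lid < i; ++lid) {
      partial_sums[lid] += partial_sums[lid + i];
    }
  }

  std::printf("dot = %.1f (expected %.1f)\n", partial_sums[0], 0.5f * K);
  return 0;
}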

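The dequantization formula is visible in the removed code path: each byte packs two 4-bit weights, and each weight dequantizes as (q - 8) * scale + zero before entering the fma accumulation. The new load_transposed_weight_block and extract_4bit_from_transposed_block helpers live in qlinear_utils.glslh and are not shown in this diff, so the sketch below only illustrates the per-nibble formula with made-up values.

#include <cstdint>
#include <cstdio>

int main() {
  const uint8_t packed = 0xB3;            // high nibble = 11, low nibble = 3
  const float scale = 0.05f, zero = 0.1f;

  const int q_hi = (packed & 0xF0) >> 4;
  const int q_lo = packed & 0x0F;

  // Same affine mapping as the removed GLSL: (q - 8) * scale + zero.
  const float w_hi = (q_hi - 8) * scale + zero; // (11 - 8) * 0.05 + 0.1 = 0.25
  const float w_lo = (q_lo - 8) * scale + zero; // (3 - 8) * 0.05 + 0.1 = -0.15
  std::printf("w_hi = %.2f, w_lo = %.2f\n", w_hi, w_lo);
  return 0;
}
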
backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.yaml

Lines changed: 4 additions & 8 deletions
@@ -7,17 +7,13 @@
 linear_qga4w_coop:
   parameter_names_with_default_values:
     DTYPE: float
-    OUT_STORAGE: texture3d
-    IN_STORAGE: texture3d
+    IO_STORAGE: texture3d
     WEIGHT_STORAGE: texture2d
-    PARAMS_STORAGE: buffer
-    TILE_ROWS: 1
+    WGS: 64
   shader_variants:
     - NAME: linear_qga4w_coop_texture3d_texture3d_texture2d_float
     - NAME: linear_qga4w_coop_buffer_buffer_texture2d_float
-      OUT_STORAGE: buffer
-      IN_STORAGE: buffer
+      IO_STORAGE: buffer
     - NAME: linear_qga4w_coop_buffer_buffer_buffer_float
-      OUT_STORAGE: buffer
-      IN_STORAGE: buffer
+      IO_STORAGE: buffer
       WEIGHT_STORAGE: buffer

backends/vulkan/runtime/graph/ops/glsl/no_op.yaml

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ no_op:
       - VALUE: half
       - VALUE: float
       - VALUE: int32
+      - VALUE: uint32
       - VALUE: int8
       - VALUE: uint8
     STORAGE:
