@@ -1247,17 +1247,17 @@ SIMSIMD_PUBLIC void simsimd_intersect_u32_sve2(simsimd_u32_t const* a, simsimd_u
12471247 *results = c;
12481248}
12491249
1250- SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_sve2 ( //
1251- simsimd_u16_t const * a, simsimd_u16_t const * b, //
1252- simsimd_bf16_t const * a_weights, simsimd_bf16_t const * b_weights, //
1253- simsimd_size_t a_length, simsimd_size_t b_length, //
1250+ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_sve2 ( //
1251+ simsimd_u16_t const * a, simsimd_u16_t const * b, //
1252+ simsimd_i16_t const * a_weights, simsimd_i16_t const * b_weights, //
1253+ simsimd_size_t a_length, simsimd_size_t b_length, //
12541254 simsimd_distance_t * results) {
12551255
12561256 // A single SVE lane is 128 bits wide, so one lane fits 8 values.
12571257 simsimd_size_t const register_size = svcnth ();
12581258 simsimd_size_t const lanes_count = register_size / 8 ;
12591259 simsimd_size_t a_idx = 0 , b_idx = 0 ;
1260- svfloat32_t product_vec = svdupq_n_f32 ( 0 . f , 0 . f , 0 . f , 0 . f );
1260+ svint64_t product_vec = svdupq_n_s64 ( 0 , 0 );
12611261 simsimd_size_t intersection_size = 0 ;
12621262
12631263 while (a_idx < a_length && b_idx < b_length) {
@@ -1303,12 +1303,12 @@ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_sve2( //
13031303 simsimd_u64_t b_step = svcntp_b16 (b_progress, b_mask);
13041304
13051305 // Compare `a_vec` with each lane of `b_vec`
1306- svbfloat16_t a_weights_vec = svld1_bf16 (a_progress, a_weights + a_idx);
1307- svbfloat16_t b_weights_vec = svld1_bf16 (b_progress, b_weights + b_idx);
1306+ svint16_t a_weights_vec = svld1_s16 (a_progress, a_weights + a_idx);
1307+ svint16_t b_weights_vec = svld1_s16 (b_progress, b_weights + b_idx);
13081308 for (simsimd_size_t i = 0 ; i < lanes_count; i++) {
13091309 svbool_t equal_mask = svmatch_u16 (a_progress, a_vec, b_vec);
1310- svbfloat16_t b_equal_weights_vec = svsel_bf16 (equal_mask, b_weights_vec, svdup_n_bf16 (0 .f ));
1311- product_vec = svbfdot_f32 (product_vec, a_weights_vec, b_equal_weights_vec);
1310+ svint16_t b_equal_weights_vec = svsel_s16 (equal_mask, b_weights_vec, svdup_n_s16 (0 .f ));
1311+ product_vec = svdot_s64 (product_vec, a_weights_vec, b_equal_weights_vec);
13121312 b_vec = svext_u16 (b_vec, b_vec, 8 );
13131313 intersection_size += svcntp_b16 (svptrue_b16 (), equal_mask);
13141314 }
@@ -1318,20 +1318,29 @@ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_sve2( //
13181318 b_idx += b_step;
13191319 }
13201320 results[0 ] = (simsimd_distance_t )intersection_size;
1321- results[1 ] = svaddv_f32 ( svptrue_b32 (), product_vec);
1321+ results[1 ] = svaddv_s64 ( svptrue_b64 (), product_vec);
13221322}
13231323
1324- SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_sve2 ( //
1325- simsimd_u16_t const * a, simsimd_u16_t const * b, //
1326- simsimd_i16_t const * a_weights, simsimd_i16_t const * b_weights, //
1327- simsimd_size_t a_length, simsimd_size_t b_length, //
1324+ #pragma clang attribute pop
1325+ #pragma GCC pop_options
1326+ #endif // SIMSIMD_TARGET_SVE2
1327+
1328+ #if SIMSIMD_TARGET_SVE2 && SIMSIMD_TARGET_SVE_BF16
1329+ #pragma GCC push_options
1330+ #pragma GCC target("arch=armv8.6-a+sve+sve2+bf16")
1331+ #pragma clang attribute push(__attribute__((target("arch=armv8.6-a+sve+sve2+bf16"))), apply_to = function)
1332+
1333+ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_sve2 ( //
1334+ simsimd_u16_t const * a, simsimd_u16_t const * b, //
1335+ simsimd_bf16_t const * a_weights, simsimd_bf16_t const * b_weights, //
1336+ simsimd_size_t a_length, simsimd_size_t b_length, //
13281337 simsimd_distance_t * results) {
13291338
13301339 // A single SVE lane is 128 bits wide, so one lane fits 8 values.
13311340 simsimd_size_t const register_size = svcnth ();
13321341 simsimd_size_t const lanes_count = register_size / 8 ;
13331342 simsimd_size_t a_idx = 0 , b_idx = 0 ;
1334- svint64_t product_vec = svdupq_n_s64 ( 0 , 0 );
1343+ svfloat32_t product_vec = svdupq_n_f32 ( 0 . f , 0 . f , 0 . f , 0 . f );
13351344 simsimd_size_t intersection_size = 0 ;
13361345
13371346 while (a_idx < a_length && b_idx < b_length) {
@@ -1377,12 +1386,15 @@ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_sve2( //
13771386 simsimd_u64_t b_step = svcntp_b16 (b_progress, b_mask);
13781387
13791388 // Compare `a_vec` with each lane of `b_vec`
1380- svbfloat16_t a_weights_vec = svld1_s16 (a_progress, a_weights + a_idx);
1381- svbfloat16_t b_weights_vec = svld1_s16 (b_progress, b_weights + b_idx);
1389+ svbfloat16_t a_weights_vec = svld1_bf16 (a_progress, a_weights + a_idx);
1390+ svbfloat16_t b_weights_vec = svld1_bf16 (b_progress, b_weights + b_idx);
13821391 for (simsimd_size_t i = 0 ; i < lanes_count; i++) {
13831392 svbool_t equal_mask = svmatch_u16 (a_progress, a_vec, b_vec);
1384- svbfloat16_t b_equal_weights_vec = svsel_s16 (equal_mask, b_weights_vec, svdup_n_bf16 (0 .f ));
1385- product_vec = svdot_s64 (product_vec, a_weights_vec, b_equal_weights_vec);
1393+ // ! The `svsel_bf16` intrinsic is broken in many compilers, not returning the correct type.
1394+ // ! So we reinterprete floats as integers and apply `svsel_s16`.
1395+ svint16_t b_equal_weights_vec =
1396+ svsel_s16 (equal_mask, svreinterpret_s16_bs16 (b_weights_vec), svdup_n_s16 (0 ));
1397+ product_vec = svbfdot_f32 (product_vec, a_weights_vec, svreinterpret_bf16_s16 (b_equal_weights_vec));
13861398 b_vec = svext_u16 (b_vec, b_vec, 8 );
13871399 intersection_size += svcntp_b16 (svptrue_b16 (), equal_mask);
13881400 }
@@ -1392,12 +1404,12 @@ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_sve2( //
13921404 b_idx += b_step;
13931405 }
13941406 results[0 ] = (simsimd_distance_t )intersection_size;
1395- results[1 ] = svaddv_s64 ( svptrue_b64 (), product_vec);
1407+ results[1 ] = svaddv_f32 ( svptrue_b32 (), product_vec);
13961408}
13971409
13981410#pragma clang attribute pop
13991411#pragma GCC pop_options
1400- #endif // SIMSIMD_TARGET_SVE2
1412+ #endif // SIMSIMD_TARGET_SVE2 && SIMSIMD_TARGET_SVE_BF16
14011413#endif // SIMSIMD_TARGET_ARM
14021414
14031415#ifdef __cplusplus
0 commit comments