diff --git a/libc-top-half/musl/src/string/strcspn.c b/libc-top-half/musl/src/string/strcspn.c
index a0c617bd..ca4b90dc 100644
--- a/libc-top-half/musl/src/string/strcspn.c
+++ b/libc-top-half/musl/src/string/strcspn.c
@@ -1,3 +1,7 @@
+#if !defined(__wasm_simd128__) || !defined(__wasilibc_simd_string) || \
+    __clang_major__ == 19 || __clang_major__ == 20
+// The SIMD implementation is in strspn_simd.c
+
 #include <string.h>
 
 #define BITOP(a,b,op) \
@@ -15,3 +19,5 @@ size_t strcspn(const char *s, const char *c)
 	for (; *s && !BITOP(byteset, *(unsigned char *)s, &); s++);
 	return s-a;
 }
+
+#endif
diff --git a/libc-top-half/musl/src/string/strspn.c b/libc-top-half/musl/src/string/strspn.c
index 9543dad0..64147a2a 100644
--- a/libc-top-half/musl/src/string/strspn.c
+++ b/libc-top-half/musl/src/string/strspn.c
@@ -1,3 +1,7 @@
+#if !defined(__wasm_simd128__) || !defined(__wasilibc_simd_string) || \
+    __clang_major__ == 19 || __clang_major__ == 20
+// The SIMD implementation is in strspn_simd.c
+
 #include <string.h>
 
 #define BITOP(a,b,op) \
@@ -18,3 +22,5 @@ size_t strspn(const char *s, const char *c)
 	for (; *s && BITOP(byteset, *(unsigned char *)s, &); s++);
 	return s-a;
 }
+
+#endif
diff --git a/libc-top-half/musl/src/string/strspn_simd.c b/libc-top-half/musl/src/string/strspn_simd.c
new file mode 100644
index 00000000..3d370e59
--- /dev/null
+++ b/libc-top-half/musl/src/string/strspn_simd.c
@@ -0,0 +1,180 @@
+#if defined(__wasm_simd128__) && defined(__wasilibc_simd_string)
+// Skip Clang 19 and Clang 20 which have a bug (llvm/llvm-project#146574)
+// which results in an ICE when inline assembly is used with a vector result.
+#if __clang_major__ != 19 && __clang_major__ != 20
+
+#include <stdint.h>
+#include <string.h>
+#include <wasm_simd128.h>
+
+#if !defined(__wasm_relaxed_simd__) || !defined(__RELAXED_FN_ATTRS)
+#define wasm_i8x16_relaxed_swizzle wasm_i8x16_swizzle
+#endif
+
+// SIMDized check which bytes are in a set (Geoff Langdale)
+// http://0x80.pl/notesen/2018-10-18-simd-byte-lookup.html
+
+// This is the same algorithm as truffle from Hyperscan:
+// https://github.com/intel/hyperscan/blob/v5.4.2/src/nfa/truffle.c#L64-L81
+// https://github.com/intel/hyperscan/blob/v5.4.2/src/nfa/trufflecompile.cpp
+
+typedef struct {
+	__u8x16 lo;
+	__u8x16 hi;
+} __wasm_v128_bitmap256_t;
+
+__attribute__((always_inline))
+static void __wasm_v128_setbit(__wasm_v128_bitmap256_t *bitmap, uint8_t i) {
+	uint8_t hi_nibble = i >> 4;
+	uint8_t lo_nibble = i & 0xf;
+	bitmap->lo[lo_nibble] |= (uint8_t)(1u << (hi_nibble - 0));
+	bitmap->hi[lo_nibble] |= (uint8_t)(1u << (hi_nibble - 8));
+}
+
+__attribute__((always_inline))
+static v128_t __wasm_v128_chkbits(__wasm_v128_bitmap256_t bitmap, v128_t v) {
+	v128_t hi_nibbles = wasm_u8x16_shr(v, 4);
+	v128_t bitmask_lookup = wasm_u64x2_const_splat(0x8040201008040201);
+	v128_t bitmask = wasm_i8x16_relaxed_swizzle(bitmask_lookup, hi_nibbles);
+
+	v128_t indices_0_7 = v & wasm_u8x16_const_splat(0x8f);
+	v128_t indices_8_15 = indices_0_7 ^ wasm_u8x16_const_splat(0x80);
+
+	v128_t row_0_7 = wasm_i8x16_swizzle((v128_t)bitmap.lo, indices_0_7);
+	v128_t row_8_15 = wasm_i8x16_swizzle((v128_t)bitmap.hi, indices_8_15);
+
+	v128_t bitsets = row_0_7 | row_8_15;
+	return bitsets & bitmask;
+}
+
+size_t strspn(const char *s, const char *c)
+{
+	// Note that reading before/after the allocation of a pointer is UB in
+	// C, so inline assembly is used to generate the exact machine
+	// instruction we want with opaque semantics to the compiler to avoid
+	// the UB.
+	uintptr_t align = (uintptr_t)s % sizeof(v128_t);
+	uintptr_t addr = (uintptr_t)s - align;
+
+	if (!c[0]) return 0;
+	if (!c[1]) {
+		v128_t vc = wasm_i8x16_splat(*c);
+		for (;;) {
+			v128_t v;
+			__asm__(
+				"local.get %1\n"
+				"v128.load 0\n"
+				"local.set %0\n"
+				: "=r"(v)
+				: "r"(addr)
+				: "memory");
+			v128_t cmp = wasm_i8x16_eq(v, vc);
+			// Bitmask is slow on AArch64, all_true is much faster.
+			if (!wasm_i8x16_all_true(cmp)) {
+				// Clear the bits corresponding to align (little-endian)
+				// so we can count trailing zeros.
+				int mask = (uint16_t)~wasm_i8x16_bitmask(cmp) >> align << align;
+				// At least one bit will be set, unless align cleared them.
+				// Knowing this helps the compiler if it unrolls the loop.
+				__builtin_assume(mask || align);
+				// If the mask became zero because of align,
+				// it's as if we didn't find anything.
+				if (mask) {
+					// Find the offset of the first one bit (little-endian).
+					return addr - (uintptr_t)s + __builtin_ctz(mask);
+				}
+			}
+			align = 0;
+			addr += sizeof(v128_t);
+		}
+	}
+
+	__wasm_v128_bitmap256_t bitmap = {};
+
+	for (; *c; c++) {
+		// Terminator IS NOT on the bitmap.
+		__wasm_v128_setbit(&bitmap, (uint8_t)*c);
+	}
+
+	for (;;) {
+		v128_t v;
+		__asm__(
+			"local.get %1\n"
+			"v128.load 0\n"
+			"local.set %0\n"
+			: "=r"(v)
+			: "r"(addr)
+			: "memory");
+		v128_t found = __wasm_v128_chkbits(bitmap, v);
+		// Bitmask is slow on AArch64, all_true is much faster.
+		if (!wasm_i8x16_all_true(found)) {
+			v128_t cmp = wasm_i8x16_eq(found, (v128_t){});
+			// Clear the bits corresponding to align (little-endian)
+			// so we can count trailing zeros.
+			int mask = wasm_i8x16_bitmask(cmp) >> align << align;
+			// At least one bit will be set, unless align cleared them.
+			// Knowing this helps the compiler if it unrolls the loop.
+			__builtin_assume(mask || align);
+			// If the mask became zero because of align,
+			// it's as if we didn't find anything.
+			if (mask) {
+				// Find the offset of the first one bit (little-endian).
+				return addr - (uintptr_t)s + __builtin_ctz(mask);
+			}
+		}
+		align = 0;
+		addr += sizeof(v128_t);
+	}
+}
+
+size_t strcspn(const char *s, const char *c)
+{
+	if (!c[0] || !c[1]) return __strchrnul(s, *c) - s;
+
+	// Note that reading before/after the allocation of a pointer is UB in
+	// C, so inline assembly is used to generate the exact machine
+	// instruction we want with opaque semantics to the compiler to avoid
+	// the UB.
+	uintptr_t align = (uintptr_t)s % sizeof(v128_t);
+	uintptr_t addr = (uintptr_t)s - align;
+
+	__wasm_v128_bitmap256_t bitmap = {};
+
+	do {
+		// Terminator IS on the bitmap.
+		__wasm_v128_setbit(&bitmap, (uint8_t)*c);
+	} while (*c++);
+
+	for (;;) {
+		v128_t v;
+		__asm__(
+			"local.get %1\n"
+			"v128.load 0\n"
+			"local.set %0\n"
+			: "=r"(v)
+			: "r"(addr)
+			: "memory");
+		v128_t found = __wasm_v128_chkbits(bitmap, v);
+		// Bitmask is slow on AArch64, any_true is much faster.
+		if (wasm_v128_any_true(found)) {
+			v128_t cmp = wasm_i8x16_eq(found, (v128_t){});
+			// Clear the bits corresponding to align (little-endian)
+			// so we can count trailing zeros.
+			int mask = (uint16_t)~wasm_i8x16_bitmask(cmp) >> align << align;
+			// At least one bit will be set, unless align cleared them.
+			// Knowing this helps the compiler if it unrolls the loop.
+			__builtin_assume(mask || align);
+			// If the mask became zero because of align,
+			// it's as if we didn't find anything.
+			if (mask) {
+				// Find the offset of the first one bit (little-endian).
+				return addr - (uintptr_t)s + __builtin_ctz(mask);
+			}
+		}
+		align = 0;
+		addr += sizeof(v128_t);
+	}
+}
+
+#endif
+#endif
diff --git a/test/src/misc/strcspn.c b/test/src/misc/strcspn.c
new file mode 100644
index 00000000..7897a86c
--- /dev/null
+++ b/test/src/misc/strcspn.c
@@ -0,0 +1,62 @@
+//! add-flags.py(LDFLAGS): -Wl,--stack-first -Wl,--initial-memory=327680
+
+#include <__macro_PAGESIZE.h>
+#include <stddef.h>
+#include <stdio.h>
+#include <string.h>
+
+void test(char *ptr, char *set, size_t want) {
+  size_t got = strcspn(ptr, set);
+  if (got != want) {
+    printf("strcspn(%p, \"%s\") = %lu, want %lu\n", ptr, set, got, want);
+  }
+}
+
+int main(void) {
+  char *const LIMIT = (char *)(__builtin_wasm_memory_size(0) * PAGESIZE);
+
+  for (ptrdiff_t length = 0; length < 64; length++) {
+    for (ptrdiff_t alignment = 0; alignment < 24; alignment++) {
+      for (ptrdiff_t pos = -2; pos < length + 2; pos++) {
+        // Create a buffer with the given length, at a pointer with the given
+        // alignment. Using the offset LIMIT - PAGESIZE - 8 means many buffers
+        // will straddle a (Wasm, and likely OS) page boundary. Place the
+        // character to find at every position in the buffer, including just
+        // prior to it and after its end.
+        char *ptr = LIMIT - PAGESIZE - 8 + alignment;
+        memset(LIMIT - 2 * PAGESIZE, 0, 2 * PAGESIZE);
+        memset(ptr, 5, length);
+
+        // The first instance of the character is found.
+        if (pos >= 0) ptr[pos + 2] = 7;
+        ptr[pos] = 7;
+        ptr[length] = 0;
+
+        // The character is found if it's within range.
+        ptrdiff_t want = 0 <= pos && pos < length ? pos : length;
+        test(ptr, "\x07", want);
+        test(ptr, "\x07\x03", want);
+        test(ptr, "\x07\x85", want);
+        test(ptr, "\x87\x85", length);
+      }
+    }
+
+    // We need space for the terminator.
+    if (length == 0) continue;
+
+    // Ensure we never read past the end of memory.
+    char *ptr = LIMIT - length;
+    memset(LIMIT - 2 * PAGESIZE, 0, 2 * PAGESIZE);
+    memset(ptr, 5, length);
+
+    ptr[length - 1] = 7;
+    test(ptr, "\x07", length - 1);
+    test(ptr, "\x07\x03", length - 1);
+
+    ptr[length - 1] = 0;
+    test(ptr, "\x07", length - 1);
+    test(ptr, "\x07\x03", length - 1);
+  }
+
+  return 0;
+}
diff --git a/test/src/misc/strspn.c b/test/src/misc/strspn.c
new file mode 100644
index 00000000..9c3495ae
--- /dev/null
+++ b/test/src/misc/strspn.c
@@ -0,0 +1,62 @@
+//! add-flags.py(LDFLAGS): -Wl,--stack-first -Wl,--initial-memory=327680
+
+#include <__macro_PAGESIZE.h>
+#include <stddef.h>
+#include <stdio.h>
+#include <string.h>
+
+void test(char *ptr, char *set, size_t want) {
+  size_t got = strspn(ptr, set);
+  if (got != want) {
+    printf("strspn(%p, \"%s\") = %lu, want %lu\n", ptr, set, got, want);
+  }
+}
+
+int main(void) {
+  char *const LIMIT = (char *)(__builtin_wasm_memory_size(0) * PAGESIZE);
+
+  for (ptrdiff_t length = 0; length < 64; length++) {
+    for (ptrdiff_t alignment = 0; alignment < 24; alignment++) {
+      for (ptrdiff_t pos = -2; pos < length + 2; pos++) {
+        // Create a buffer with the given length, at a pointer with the given
+        // alignment. Using the offset LIMIT - PAGESIZE - 8 means many buffers
+        // will straddle a (Wasm, and likely OS) page boundary. Place the
+        // character to find at every position in the buffer, including just
+        // prior to it and after its end.
+        char *ptr = LIMIT - PAGESIZE - 8 + alignment;
+        memset(LIMIT - 2 * PAGESIZE, 0, 2 * PAGESIZE);
+        memset(ptr, 5, length);
+
+        // The first instance of the character is found.
+        if (pos >= 0) ptr[pos + 2] = 7;
+        ptr[pos] = 7;
+        ptr[length] = 0;
+
+        // The character is found if it's within range.
+        ptrdiff_t want = 0 <= pos && pos < length ? pos : length;
+        test(ptr, "\x05", want);
+        test(ptr, "\x05\x03", want);
+        test(ptr, "\x05\x87", want);
+        test(ptr, "\x05\x07", length);
+      }
+    }
+
+    // We need space for the terminator.
+    if (length == 0) continue;
+
+    // Ensure we never read past the end of memory.
+    char *ptr = LIMIT - length;
+    memset(LIMIT - 2 * PAGESIZE, 0, 2 * PAGESIZE);
+    memset(ptr, 5, length);
+
+    ptr[length - 1] = 7;
+    test(ptr, "\x05", length - 1);
+    test(ptr, "\x05\x03", length - 1);
+
+    ptr[length - 1] = 0;
+    test(ptr, "\x05", length - 1);
+    test(ptr, "\x05\x03", length - 1);
+  }
+
+  return 0;
+}