/**
 * @brief StringZilla is a collection of simple string algorithms, designed to be used in Big Data applications.
 * It may be slower than LibC, but has a broader & cleaner interface, and a very short implementation
 * targeting modern x86 CPUs with AVX-512, Arm CPUs with NEON, and older CPUs with SWAR and auto-vectorization.
 *
 * Consider overriding the following macros to customize the library:
 *
 * - `SZ_DEBUG=0` - whether to enable debug assertions and logging.
 * - `SZ_DYNAMIC_DISPATCH=0` - whether to use runtime dispatching of the most advanced SIMD backend.
 * - `SZ_USE_MISALIGNED_LOADS=0` - whether to use misaligned loads on platforms that support them.
 * - `SZ_SWAR_THRESHOLD=24` - threshold for switching to SWAR backend over serial byte-level for-loops.
 * - `SZ_USE_X86_AVX512=?` - whether to use AVX-512 instructions on x86_64.
 * - `SZ_USE_X86_AVX2=?` - whether to use AVX2 instructions on x86_64.
 * - `SZ_USE_ARM_NEON=?` - whether to use NEON instructions on ARM.
 * - `SZ_USE_ARM_SVE=?` - whether to use SVE instructions on ARM.
 *
 * @see StringZilla: https://github.com/ashvardanian/StringZilla/blob/main/README.md
 * @see LibC String: https://pubs.opengroup.org/onlinepubs/009695399/basedefs/string.h.html
 *
 * @file stringzilla.h
 * @author Ash Vardanian
 */
#ifndef STRINGZILLA_H_
#define STRINGZILLA_H_

#define STRINGZILLA_VERSION_MAJOR 3
#define STRINGZILLA_VERSION_MINOR 12
#define STRINGZILLA_VERSION_PATCH 6

/**
 * @brief When set to 0 (the default), the library includes the LibC headers `<stddef.h>` and `<stdint.h>`.
 * In debug builds (SZ_DEBUG=1), the library also includes `<stdio.h>` and `<stdlib.h>`.
 *
 * You may want to avoid these includes when compiling for the kernel, or for embedded systems.
 * You may also avoid them if you are very sensitive to compilation time and avoid pre-compiled headers.
 * https://artificial-mind.net/projects/compile-health/
 */
#ifndef SZ_AVOID_LIBC
#define SZ_AVOID_LIBC (0) // true or false
#endif

/**
 * @brief A misaligned load fetches a word from an address that is not a multiple of the word size -
 * for example, eight consecutive bytes from an address that is not divisible by eight.
 * Enabled by default on x86; disabled by default on ARM.
 *
 * Most platforms support it, but there is no industry-standard way to check for that.
 * This value mostly affects the performance of the serial (SWAR) backend.
 */
#ifndef SZ_USE_MISALIGNED_LOADS
#if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || defined(_M_IX86)
#define SZ_USE_MISALIGNED_LOADS (1) // true or false
#else
#define SZ_USE_MISALIGNED_LOADS (0) // true or false
#endif
#endif

/**
 * @brief Replaces compile-time dispatching with runtime dispatching.
 * With it, the `sz_find` function will invoke the most advanced backend supported by the CPU
 * that runs the program, rather than the most advanced backend supported by the CPU
 * used to compile the library or the downstream application.
 */
#ifndef SZ_DYNAMIC_DISPATCH
#define SZ_DYNAMIC_DISPATCH (0) // true or false
#endif

/**
 * @brief Infers the pointer width, analogous to `size_t` and `std::size_t`: an unsigned integer
 * identical to the pointer size - 64-bit on most platforms where pointers are 64-bit,
 * 32-bit on platforms where pointers are 32-bit.
 */
#if defined(__LP64__) || defined(_LP64) || defined(__x86_64__) || defined(_WIN64)
#define SZ_DETECT_64_BIT (1)
#define SZ_SIZE_MAX (0xFFFFFFFFFFFFFFFFull)  // Largest unsigned integer that fits into 64 bits.
#define SZ_SSIZE_MAX (0x7FFFFFFFFFFFFFFFull) // Largest signed integer that fits into 64 bits.
#else
#define SZ_DETECT_64_BIT (0)
#define SZ_SIZE_MAX (0xFFFFFFFFu)  // Largest unsigned integer that fits into 32 bits.
#define SZ_SSIZE_MAX (0x7FFFFFFFu) // Largest signed integer that fits into 32 bits.
#endif

/**
 * @brief On Big-Endian machines StringZilla will work in compatibility mode.
 * This disables SWAR hacks to minimize code duplication, assuming practically
 * all modern popular platforms are Little-Endian.
 *
 * This variable is hard to infer from macros reliably. It's best to set it manually.
 * For that, CMake provides the `TestBigEndian` module and the `CMAKE_<LANG>_BYTE_ORDER` variable (from 3.20 onwards).
 * In Python one can check `sys.byteorder == 'big'` in the `setup.py` script and pass the appropriate macro.
 * https://stackoverflow.com/a/27054190
 */
#ifndef SZ_DETECT_BIG_ENDIAN
#if defined(__BYTE_ORDER) && __BYTE_ORDER == __BIG_ENDIAN || defined(__BIG_ENDIAN__) || defined(__ARMEB__) || \
    defined(__THUMBEB__) || defined(__AARCH64EB__) || defined(_MIBSEB) || defined(__MIBSEB) || defined(__MIBSEB__)
#define SZ_DETECT_BIG_ENDIAN (1) //< It's a big-endian target architecture
#else
#define SZ_DETECT_BIG_ENDIAN (0) //< It's a little-endian target architecture
#endif
#endif

/*
 * Debugging and testing.
 */
#ifndef SZ_DEBUG
#if defined(DEBUG) || defined(_DEBUG) // Most build systems define one of these in debug configurations.
#define SZ_DEBUG (1)
#else
#define SZ_DEBUG (0)
#endif
#endif

/**
 * @brief Threshold for switching to the SWAR (8 bytes at a time) backend over serial byte-level for-loops.
 * On very short strings, under 16 bytes long, at most a single word will be processed with SWAR.
 * Assuming potentially misaligned loads, SWAR makes sense only after ~24 bytes.
 */
#ifndef SZ_SWAR_THRESHOLD
#if SZ_DEBUG
#define SZ_SWAR_THRESHOLD (8u) // 8 bytes in debug builds
#else
#define SZ_SWAR_THRESHOLD (24u) // 24 bytes in release builds
#endif
#endif

/* Annotation for the public API symbols:
 *
 * - `SZ_PUBLIC` is used for functions that are part of the public API.
 * - `SZ_INTERNAL` is used for internal helper functions with unstable APIs.
 * - `SZ_DYNAMIC` is used for functions that are part of the public API, but are dispatched at runtime.
 */
#ifndef SZ_DYNAMIC
#if SZ_DYNAMIC_DISPATCH
#if defined(_WIN32) || defined(__CYGWIN__)
#define SZ_DYNAMIC __declspec(dllexport)
#define SZ_EXTERNAL __declspec(dllimport)
#define SZ_PUBLIC inline static
#define SZ_INTERNAL inline static
#else
#define SZ_DYNAMIC __attribute__((visibility("default")))
#define SZ_EXTERNAL extern
#define SZ_PUBLIC __attribute__((unused)) inline static
#define SZ_INTERNAL __attribute__((always_inline)) inline static
#endif // _WIN32 || __CYGWIN__
#else
#define SZ_DYNAMIC inline static
#define SZ_EXTERNAL extern
#define SZ_PUBLIC inline static
#define SZ_INTERNAL inline static
#endif // SZ_DYNAMIC_DISPATCH
#endif // SZ_DYNAMIC

/**
 * @brief Alignment macro for 64-byte alignment.
 */
#if defined(_MSC_VER)
#define SZ_ALIGN64 __declspec(align(64))
#elif defined(__GNUC__) || defined(__clang__)
#define SZ_ALIGN64 __attribute__((aligned(64)))
#else
#define SZ_ALIGN64
#endif

#ifdef __cplusplus
extern "C" {
#endif

/*
 * Let's infer the integer types or pull them from LibC,
 * if that is allowed by the user.
 */
#if !SZ_AVOID_LIBC
#include <stddef.h> // `size_t`
#include <stdint.h> // `uint8_t`
typedef int8_t sz_i8_t;       // Always 8 bits
typedef uint8_t sz_u8_t;      // Always 8 bits
typedef uint16_t sz_u16_t;    // Always 16 bits
typedef int32_t sz_i32_t;     // Always 32 bits
typedef uint32_t sz_u32_t;    // Always 32 bits
typedef uint64_t sz_u64_t;    // Always 64 bits
typedef int64_t sz_i64_t;     // Always 64 bits
typedef size_t sz_size_t;     // Pointer-sized unsigned integer, 32 or 64 bits
typedef ptrdiff_t sz_ssize_t; // Signed version of `sz_size_t`, 32 or 64 bits

#else // if SZ_AVOID_LIBC:

// ! The C standard doesn't specify the signedness of char.
// ! On x86 char is signed by default while on Arm it is unsigned by default.
// ! That's why we don't define `sz_char_t` and generally use explicit `sz_i8_t` and `sz_u8_t`.
typedef signed char sz_i8_t;         // Always 8 bits
typedef unsigned char sz_u8_t;       // Always 8 bits
typedef unsigned short sz_u16_t;     // Always 16 bits
typedef int sz_i32_t;                // Always 32 bits
typedef unsigned int sz_u32_t;       // Always 32 bits
typedef long long sz_i64_t;          // Always 64 bits
typedef unsigned long long sz_u64_t; // Always 64 bits

// Now we need to redefine the `size_t`.
// Microsoft Visual C++ (MSVC) typically follows the LLP64 data model on 64-bit platforms,
// where integers, pointers, and long types have different sizes:
//
//  > `int` is 32 bits
//  > `long` is 32 bits
//  > `long long` is 64 bits
//  > pointer (thus, `size_t`) is 64 bits
//
// In contrast, GCC and Clang on 64-bit Unix-like systems typically follow the LP64 model, where:
//
//  > `int` is 32 bits
//  > `long` and pointer (thus, `size_t`) are 64 bits
//  > `long long` is also 64 bits
//
// Source: https://learn.microsoft.com/en-us/windows/win32/winprog64/abstract-data-models
#if SZ_DETECT_64_BIT
typedef unsigned long long sz_size_t; // 64-bit.
typedef long long sz_ssize_t;         // 64-bit.
#else
typedef unsigned sz_size_t; // 32-bit.
typedef int sz_ssize_t;     // 32-bit, signed.
#endif // SZ_DETECT_64_BIT

#endif // SZ_AVOID_LIBC

/**
 * @brief Compile-time assert macro similar to `static_assert` in C++.
 */
#define sz_static_assert(condition, name)                \
    typedef struct {                                     \
        int static_assert_##name : (condition) ? 1 : -1; \
    } sz_static_assert_##name##_t

sz_static_assert(sizeof(sz_size_t) == sizeof(void *), sz_size_t_must_be_pointer_size);
sz_static_assert(sizeof(sz_ssize_t) == sizeof(void *), sz_ssize_t_must_be_pointer_size);

#pragma region Public API

typedef char *sz_ptr_t;          // A type alias for `char *`
typedef char const *sz_cptr_t;   // A type alias for `char const *`
typedef sz_i8_t sz_error_cost_t; // Character mismatch cost for fuzzy matching functions

typedef sz_u64_t sz_sorted_idx_t; // Index of a sorted string in a list of strings

typedef enum { sz_false_k = 0, sz_true_k = 1 } sz_bool_t;                        // Only one relevant bit
typedef enum { sz_less_k = -1, sz_equal_k = 0, sz_greater_k = 1 } sz_ordering_t; // Only three possible states: <=>

/**
 * @brief Tiny string-view structure. It's a POD type, unlike `std::string_view`.
 */
typedef struct sz_string_view_t {
    sz_cptr_t start;
    sz_size_t length;
} sz_string_view_t;

/**
 * @brief Enumeration of SIMD capabilities of the target architecture.
 * Used to introspect the supported functionality of the dynamic library.
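 *
 * A minimal introspection sketch (illustrative only; relies on `sz_capabilities()` declared below):
 *
 * @code{.c}
 *     sz_capability_t caps = sz_capabilities();
 *     if (caps & sz_cap_x86_avx2_k) { }  // AVX2 kernels are available on this machine
 *     if (caps & sz_cap_arm_neon_k) { }  // NEON kernels are available on this machine
 * @endcode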
*/ typedef enum sz_capability_t { sz_cap_serial_k = 1, /// Serial (non-SIMD) capability sz_cap_any_k = 0x7FFFFFFF, /// Mask representing any capability sz_cap_arm_neon_k = 1 << 10, /// ARM NEON capability sz_cap_arm_sve_k = 1 << 11, /// ARM SVE capability TODO: Not yet supported or used sz_cap_x86_avx2_k = 1 << 20, /// x86 AVX2 capability sz_cap_x86_avx512f_k = 1 << 21, /// x86 AVX512 F capability sz_cap_x86_avx512bw_k = 1 << 22, /// x86 AVX512 BW instruction capability sz_cap_x86_avx512vl_k = 1 << 23, /// x86 AVX512 VL instruction capability sz_cap_x86_avx512vbmi_k = 1 << 24, /// x86 AVX512 VBMI instruction capability sz_cap_x86_gfni_k = 1 << 25, /// x86 AVX512 GFNI instruction capability sz_cap_x86_avx512vbmi2_k = 1 << 26, /// x86 AVX512 VBMI 2 instruction capability } sz_capability_t; /** * @brief Function to determine the SIMD capabilities of the current machine @b only at @b runtime. * @return A bitmask of the SIMD capabilities represented as a `sz_capability_t` enum value. */ SZ_DYNAMIC sz_capability_t sz_capabilities(void); /** * @brief Bit-set structure for 256 possible byte values. Useful for filtering and search. * @see sz_charset_init, sz_charset_add, sz_charset_contains, sz_charset_invert */ typedef union sz_charset_t { sz_u64_t _u64s[4]; sz_u32_t _u32s[8]; sz_u16_t _u16s[16]; sz_u8_t _u8s[32]; } sz_charset_t; /** @brief Initializes a bit-set to an empty collection, meaning - all characters are banned. */ SZ_PUBLIC void sz_charset_init(sz_charset_t *s) { s->_u64s[0] = s->_u64s[1] = s->_u64s[2] = s->_u64s[3] = 0; } /** @brief Adds a character to the set and accepts @b unsigned integers. */ SZ_PUBLIC void sz_charset_add_u8(sz_charset_t *s, sz_u8_t c) { s->_u64s[c >> 6] |= (1ull << (c & 63u)); } /** @brief Adds a character to the set. Consider @b sz_charset_add_u8. */ SZ_PUBLIC void sz_charset_add(sz_charset_t *s, char c) { sz_charset_add_u8(s, *(sz_u8_t *)(&c)); } // bitcast /** @brief Checks if the set contains a given character and accepts @b unsigned integers. */ SZ_PUBLIC sz_bool_t sz_charset_contains_u8(sz_charset_t const *s, sz_u8_t c) { // Checking the bit can be done in different ways: // - (s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0 // - (s->_u32s[c >> 5] & (1u << (c & 31u))) != 0 // - (s->_u16s[c >> 4] & (1u << (c & 15u))) != 0 // - (s->_u8s[c >> 3] & (1u << (c & 7u))) != 0 return (sz_bool_t)((s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0); } /** @brief Checks if the set contains a given character. Consider @b sz_charset_contains_u8. */ SZ_PUBLIC sz_bool_t sz_charset_contains(sz_charset_t const *s, char c) { return sz_charset_contains_u8(s, *(sz_u8_t *)(&c)); // bitcast } /** @brief Inverts the contents of the set, so allowed character get disallowed, and vice versa. */ SZ_PUBLIC void sz_charset_invert(sz_charset_t *s) { s->_u64s[0] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[1] ^= 0xFFFFFFFFFFFFFFFFull, // s->_u64s[2] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[3] ^= 0xFFFFFFFFFFFFFFFFull; } typedef void *(*sz_memory_allocate_t)(sz_size_t, void *); typedef void (*sz_memory_free_t)(void *, sz_size_t, void *); typedef sz_u64_t (*sz_random_generator_t)(void *); /** * @brief Some complex pattern matching algorithms may require memory allocations. * This structure is used to pass the memory allocator to those functions. * @see sz_memory_allocator_init_fixed */ typedef struct sz_memory_allocator_t { sz_memory_allocate_t allocate; sz_memory_free_t free; void *handle; } sz_memory_allocator_t; /** * @brief Initializes a memory allocator to use the system default `malloc` and `free`. * ! 
The function is not available if the library was compiled with `SZ_AVOID_LIBC`. * * @param alloc Memory allocator to initialize. */ SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc); /** * @brief Initializes a memory allocator to use a static-capacity buffer. * No dynamic allocations will be performed. * * @param alloc Memory allocator to initialize. * @param buffer Buffer to use for allocations. * @param length Length of the buffer. @b Must be greater than 8 bytes. Different values would be optimal for * different algorithms and input lengths, but 4096 bytes (one RAM page) is a good default. */ SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length); /** * @brief The number of bytes a stack-allocated string can hold, including the SZ_NULL termination character. * ! This can't be changed from outside. Don't use the `#error` as it may already be included and set. */ #ifdef SZ_STRING_INTERNAL_SPACE #undef SZ_STRING_INTERNAL_SPACE #endif #define SZ_STRING_INTERNAL_SPACE (sizeof(sz_size_t) * 3 - 1) // 3 pointers minus one byte for an 8-bit length /** * @brief Tiny memory-owning string structure with a Small String Optimization (SSO). * Differs in layout from Folly, Clang, GCC, and probably most other implementations. * It's designed to avoid any branches on read-only operations, and can store up * to 22 characters on stack on 64-bit machines, followed by the SZ_NULL-termination character. * * @section Changing Length * * One nice thing about this design, is that you can, in many cases, change the length of the string * without any branches, invoking a `+=` or `-=` on the 64-bit `length` field. If the string is on heap, * the solution is obvious. If it's on stack, inplace decrement wouldn't affect the top bytes of the string, * only changing the last byte containing the length. */ typedef union sz_string_t { #if !SZ_DETECT_BIG_ENDIAN struct external { sz_ptr_t start; sz_size_t length; sz_size_t space; sz_size_t padding; } external; struct internal { sz_ptr_t start; sz_u8_t length; char chars[SZ_STRING_INTERNAL_SPACE]; } internal; #else struct external { sz_ptr_t start; sz_size_t space; sz_size_t padding; sz_size_t length; } external; struct internal { sz_ptr_t start; char chars[SZ_STRING_INTERNAL_SPACE]; sz_u8_t length; } internal; #endif sz_size_t words[4]; } sz_string_t; typedef sz_u64_t (*sz_hash_t)(sz_cptr_t, sz_size_t); typedef sz_u64_t (*sz_checksum_t)(sz_cptr_t, sz_size_t); typedef sz_bool_t (*sz_equal_t)(sz_cptr_t, sz_cptr_t, sz_size_t); typedef sz_ordering_t (*sz_order_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); typedef void (*sz_to_converter_t)(sz_cptr_t, sz_size_t, sz_ptr_t); /** * @brief Computes the 64-bit check-sum of bytes in a string. * Similar to `std::ranges::accumulate`. * * @param text String to aggregate. * @param length Number of bytes in the text. * @return 64-bit unsigned value. */ SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length); /** @copydoc sz_checksum */ SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length); /** * @brief Computes the 64-bit unsigned hash of a string. Fairly fast for short strings, * simple implementation, and supports rolling computation, reused in other APIs. * Similar to `std::hash` in C++. * * @param text String to hash. * @param length Number of bytes in the text. * @return 64-bit hash value. 
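 *
 * A minimal usage sketch (illustrative only):
 *
 * @code{.c}
 *     char const *text = "Hello, world!";
 *     sz_u64_t whole = sz_hash(text, 13); // Hash of the full 13-byte string
 *     sz_u64_t part = sz_hash(text, 5);   // Hash of the "Hello" prefix generally differs
 * @endcode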
 *
 * @see sz_hashes, sz_hashes_fingerprint, sz_hashes_intersection
 */
SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length);

/** @copydoc sz_hash */
SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t text, sz_size_t length);

/**
 * @brief Checks if two strings are equal.
 * Similar to `memcmp(a, b, length) == 0` in LibC and `a == b` in STL.
 *
 * The implementation of this function is very similar to `sz_order`, but the usage patterns are different.
 * This function is more often used in parsing, while `sz_order` is often used in sorting.
 * It works best on platforms with cheap unaligned loads.
 *
 * @param a First string to compare.
 * @param b Second string to compare.
 * @param length Number of bytes in both strings.
 * @return 1 if strings match, 0 otherwise.
 */
SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length);

/** @copydoc sz_equal */
SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length);

/**
 * @brief Estimates the relative order of two strings. Equivalent to `memcmp(a, b, length)` in LibC.
 * Can be used on strings of different lengths.
 *
 * @param a First string to compare.
 * @param a_length Number of bytes in the first string.
 * @param b Second string to compare.
 * @param b_length Number of bytes in the second string.
 * @return Negative if (a < b), positive if (a > b), zero if they are equal.
 */
SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length);

/** @copydoc sz_order */
SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length);

/**
 * @brief Look Up Table @b (LUT) transformation of a string. Equivalent to `for (char & c : text) c = lut[c]`.
 *
 * Can be used to implement some form of string normalization, partially masking punctuation marks,
 * or converting between different character sets, like uppercase or lowercase. Surprisingly, it also has
 * broad implications in image processing, where image channel transformations are often done using LUTs.
 *
 * @param text String to be normalized.
 * @param length Number of bytes in the string.
 * @param lut Look Up Table to apply. Must be exactly @b 256 bytes long.
 * @param result Output string, can point to the same address as ::text.
 */
SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result);

typedef void (*sz_look_up_transform_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_ptr_t);

/** @copydoc sz_look_up_transform */
SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result);

/**
 * @brief Equivalent to `for (char & c : text) c = tolower(c)`.
 *
 * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122].
 * So there are 26 English letters, shifted by 32 values, meaning that a conversion
 * can be done by flipping the 5th bit of each inappropriate character byte. This, however,
 * breaks for extended ASCII, so a different solution is needed.
 * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html
 *
 * @param text String to be normalized.
 * @param length Number of bytes in the string.
 * @param result Output string, can point to the same address as ::text.
 */
SZ_PUBLIC void sz_tolower(sz_cptr_t text, sz_size_t length, sz_ptr_t result);

/**
 * @brief Equivalent to `for (char & c : text) c = toupper(c)`.
 *
 * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122].
 * So there are 26 English letters, shifted by 32 values, meaning that a conversion
 * can be done by flipping the 5th bit of each inappropriate character byte. This, however,
 * breaks for extended ASCII, so a different solution is needed.
 * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html
 *
 * @param text String to be normalized.
 * @param length Number of bytes in the string.
 * @param result Output string, can point to the same address as ::text.
 */
SZ_PUBLIC void sz_toupper(sz_cptr_t text, sz_size_t length, sz_ptr_t result);

/**
 * @brief Equivalent to `for (char & c : text) c = toascii(c)`.
 *
 * @param text String to be normalized.
 * @param length Number of bytes in the string.
 * @param result Output string, can point to the same address as ::text.
 */
SZ_PUBLIC void sz_toascii(sz_cptr_t text, sz_size_t length, sz_ptr_t result);

/**
 * @brief Checks if all characters in the range are valid ASCII characters.
 *
 * @param text String to be analyzed.
 * @param length Number of bytes in the string.
 * @return Whether all characters are valid ASCII characters.
 */
SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t text, sz_size_t length);

/**
 * @brief Generates a random string for a given alphabet, avoiding integer division and modulo operations.
 * Similar to `text[i] = alphabet[rand() % cardinality]`.
 *
 * The modulo operation is expensive, and should be avoided in performance-critical code.
 * We avoid it using small lookup tables and replacing it with a multiplication and shifts, similar to `libdivide`.
 * Alternative algorithms would include:
 * - Montgomery form: https://en.algorithmica.org/hpc/number-theory/montgomery/
 * - Barrett reduction: https://www.nayuki.io/page/barrett-reduction-algorithm
 * - Lemire's trick: https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
 *
 * @param alphabet Set of characters to sample from.
 * @param cardinality Number of characters to sample from.
 * @param text Output string to be filled with random characters.
 * @param length Number of bytes to generate.
 * @param generate Callback producing random numbers given the generator state.
 * @param generator Generator state, can be a pointer to a seed, or a pointer to a random number generator.
 */
SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length,
                            sz_random_generator_t generate, void *generator);

/** @copydoc sz_generate */
SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length,
                                  sz_random_generator_t generate, void *generator);

/**
 * @brief Similar to `memcpy`, copies contents of one string into another.
 * The behavior is undefined if the strings overlap.
 *
 * @param target String to copy into.
 * @param length Number of bytes to copy.
 * @param source String to copy from.
 */
SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length);

/** @copydoc sz_copy */
SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length);

/**
 * @brief Similar to `memmove`, copies (moves) contents of one string into another.
 * Unlike `sz_copy`, allows overlapping strings as arguments.
 *
 * @param target String to copy into.
 * @param length Number of bytes to copy.
 * @param source String to copy from.
 */
SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length);

/** @copydoc sz_move */
SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length);

typedef void (*sz_move_t)(sz_ptr_t, sz_cptr_t, sz_size_t);

/**
 * @brief Similar to `memset`, fills a string with a given value.
 *
 * @param target String to fill.
 * @param length Number of bytes to fill.
 * @param value Value to fill with.
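 *
 * A short sketch combining the copy, move, and fill APIs above (illustrative only):
 *
 * @code{.c}
 *     char buffer[16];
 *     sz_fill(buffer, 16, 0);             // Zero the buffer, like `memset`
 *     sz_copy(buffer, "stringzilla", 11); // Non-overlapping copy, like `memcpy`
 *     sz_move(buffer + 1, buffer, 11);    // Overlapping ranges are fine, like `memmove`
 * @endcode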
*/ SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value); /** @copydoc sz_fill */ SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value); typedef void (*sz_fill_t)(sz_ptr_t, sz_size_t, sz_u8_t); /** * @brief Initializes a string class instance to an empty value. */ SZ_PUBLIC void sz_string_init(sz_string_t *string); /** * @brief Convenience function checking if the provided string is stored inside of the ::string instance itself, * alternative being - allocated in a remote region of the heap. */ SZ_PUBLIC sz_bool_t sz_string_is_on_stack(sz_string_t const *string); /** * @brief Unpacks the opaque instance of a string class into its components. * Recommended to use only in read-only operations. * * @param string String to unpack. * @param start Pointer to the start of the string. * @param length Number of bytes in the string, before the SZ_NULL character. * @param space Number of bytes allocated for the string (heap or stack), including the SZ_NULL character. * @param is_external Whether the string is allocated on the heap externally, or fits withing ::string instance. */ SZ_PUBLIC void sz_string_unpack(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space, sz_bool_t *is_external); /** * @brief Unpacks only the start and length of the string. * Recommended to use only in read-only operations. * * @param string String to unpack. * @param start Pointer to the start of the string. * @param length Number of bytes in the string, before the SZ_NULL character. */ SZ_PUBLIC void sz_string_range(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length); /** * @brief Constructs a string of a given ::length with noisy contents. * Use the returned character pointer to populate the string. * * @param string String to initialize. * @param length Number of bytes in the string, before the SZ_NULL character. * @param allocator Memory allocator to use for the allocation. * @return SZ_NULL if the operation failed, pointer to the start of the string otherwise. */ SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, sz_memory_allocator_t *allocator); /** * @brief Doesn't change the contents or the length of the string, but grows the available memory capacity. * This is beneficial, if several insertions are expected, and we want to minimize allocations. * * @param string String to grow. * @param new_capacity The number of characters to reserve space for, including existing ones. * @param allocator Memory allocator to use for the allocation. * @return SZ_NULL if the operation failed, pointer to the new start of the string otherwise. */ SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity, sz_memory_allocator_t *allocator); /** * @brief Grows the string by adding an uninitialized region of ::added_length at the given ::offset. * Would often be used in conjunction with one or more `sz_copy` calls to populate the allocated region. * Similar to `sz_string_reserve`, but changes the length of the ::string. * * @param string String to grow. * @param offset Offset of the first byte to reserve space for. * If provided offset is larger than the length, it will be capped. * @param added_length The number of new characters to reserve space for. * @param allocator Memory allocator to use for the allocation. * @return SZ_NULL if the operation failed, pointer to the new start of the string otherwise. 
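 *
 * A hedged sketch of appending bytes with this API (illustrative only):
 *
 * @code{.c}
 *     sz_string_t str;
 *     sz_memory_allocator_t alloc;
 *     sz_string_init(&str);
 *     sz_memory_allocator_init_default(&alloc);
 *     // Passing `SZ_SIZE_MAX` as the offset gets capped to the current length, so this appends.
 *     if (sz_string_expand(&str, SZ_SIZE_MAX, 5, &alloc)) {
 *         sz_ptr_t start; sz_size_t length;
 *         sz_string_range(&str, &start, &length);
 *         sz_copy(start + length - 5, "hello", 5); // Populate the newly added region
 *     }
 *     sz_string_free(&str, &alloc);
 * @endcode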
*/ SZ_PUBLIC sz_ptr_t sz_string_expand(sz_string_t *string, sz_size_t offset, sz_size_t added_length, sz_memory_allocator_t *allocator); /** * @brief Removes a range from a string. Changes the length, but not the capacity. * Performs no allocations or deallocations and can't fail. * * @param string String to clean. * @param offset Offset of the first byte to remove. * @param length Number of bytes to remove. Out-of-bound ranges will be capped. * @return Number of bytes removed. */ SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length); /** * @brief Shrinks the string to fit the current length, if it's allocated on the heap. * It's the reverse operation of ::sz_string_reserve. * * @param string String to shrink. * @param allocator Memory allocator to use for the allocation. * @return Whether the operation was successful. The only failures can come from the allocator. * On failure, the string will remain unchanged. */ SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_allocator_t *allocator); /** * @brief Frees the string, if it's allocated on the heap. * If the string is on the stack, the function clears/resets the state. */ SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *allocator); #pragma endregion #pragma region Fast Substring Search API typedef sz_cptr_t (*sz_find_byte_t)(sz_cptr_t, sz_size_t, sz_cptr_t); typedef sz_cptr_t (*sz_find_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); typedef sz_cptr_t (*sz_find_set_t)(sz_cptr_t, sz_size_t, sz_charset_t const *); /** * @brief Locates first matching byte in a string. Equivalent to `memchr(haystack, *needle, h_length)` in LibC. * * X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memchr.S * Aarch64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/aarch64/memchr.S * * @param haystack Haystack - the string to search in. * @param h_length Number of bytes in the haystack. * @param needle Needle - single-byte substring to find. * @return Address of the first match. */ SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); /** @copydoc sz_find_byte */ SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); /** * @brief Locates last matching byte in a string. Equivalent to `memrchr(haystack, *needle, h_length)` in LibC. * * X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memrchr.S * Aarch64 implementation: missing * * @param haystack Haystack - the string to search in. * @param h_length Number of bytes in the haystack. * @param needle Needle - single-byte substring to find. * @return Address of the last match. */ SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); /** @copydoc sz_rfind_byte */ SZ_PUBLIC sz_cptr_t sz_rfind_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); /** * @brief Locates first matching substring. * Equivalent to `memmem(haystack, h_length, needle, n_length)` in LibC. * Similar to `strstr(haystack, needle)` in LibC, but requires known length. * * @param haystack Haystack - the string to search in. * @param h_length Number of bytes in the haystack. * @param needle Needle - substring to find. * @param n_length Number of bytes in the needle. * @return Address of the first match. 
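 *
 * A short sketch enumerating non-overlapping matches (illustrative only):
 *
 * @code{.c}
 *     sz_cptr_t haystack = "abracadabra", needle = "ab";
 *     sz_size_t h_length = 11, n_length = 2;
 *     sz_cptr_t match = sz_find(haystack, h_length, needle, n_length);
 *     while (match) {
 *         // ... `match` points into `haystack`, here at offsets 0 and then 7 ...
 *         sz_size_t skipped = (sz_size_t)(match - haystack) + n_length;
 *         match = sz_find(haystack + skipped, h_length - skipped, needle, n_length);
 *     }
 * @endcode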
*/ SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); /** @copydoc sz_find */ SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); /** * @brief Locates the last matching substring. * * @param haystack Haystack - the string to search in. * @param h_length Number of bytes in the haystack. * @param needle Needle - substring to find. * @param n_length Number of bytes in the needle. * @return Address of the last match. */ SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); /** @copydoc sz_rfind */ SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); /** * @brief Finds the first character present from the ::set, present in ::text. * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. * May have identical implementation and performance to ::sz_rfind_charset. * * Useful for parsing, when we want to skip a set of characters. Examples: * * 6 whitespaces: " \t\n\r\v\f". * * 16 digits forming a float number: "0123456789,.eE+-". * * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing. * * 2 JSON string special characters useful to locate the end of the string: "\"\\". * * @param text String to be scanned. * @param set Set of relevant characters. * @return Pointer to the first matching character from ::set. */ SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); /** @copydoc sz_find_charset */ SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); /** * @brief Finds the last character present from the ::set, present in ::text. * Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC. * May have identical implementation and performance to ::sz_find_charset. * * Useful for parsing, when we want to skip a set of characters. Examples: * * 6 whitespaces: " \t\n\r\v\f". * * 16 digits forming a float number: "0123456789,.eE+-". * * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing. * * 2 JSON string special characters useful to locate the end of the string: "\"\\". * * @param text String to be scanned. * @param set Set of relevant characters. * @return Pointer to the last matching character from ::set. */ SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); /** @copydoc sz_rfind_charset */ SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); #pragma endregion #pragma region String Similarity Measures API /** * @brief Computes the Hamming distance between two strings - number of not matching characters. * Difference in length is is counted as a mismatch. * * @param a First string to compare. * @param a_length Number of bytes in the first string. * @param b Second string to compare. * @param b_length Number of bytes in the second string. * * @param bound Upper bound on the distance, that allows us to exit early. * If zero is passed, the maximum possible distance will be equal to the length of the longer input. * @return Unsigned integer for the distance, the `bound` if was exceeded. 
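 *
 * For example (illustrative only), "karolin" and "kathrin" differ in three positions:
 *
 * @code{.c}
 *     sz_size_t distance = sz_hamming_distance("karolin", 7, "kathrin", 7, 0); // == 3
 * @endcode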
* * @see sz_hamming_distance_utf8 * @see https://en.wikipedia.org/wiki/Hamming_distance */ SZ_DYNAMIC sz_size_t sz_hamming_distance(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); /** @copydoc sz_hamming_distance */ SZ_PUBLIC sz_size_t sz_hamming_distance_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); /** * @brief Computes the Hamming distance between two @b UTF8 strings - number of not matching characters. * Difference in length is is counted as a mismatch. * * @param a First string to compare. * @param a_length Number of bytes in the first string. * @param b Second string to compare. * @param b_length Number of bytes in the second string. * * @param bound Upper bound on the distance, that allows us to exit early. * If zero is passed, the maximum possible distance will be equal to the length of the longer input. * @return Unsigned integer for the distance, the `bound` if was exceeded. * * @see sz_hamming_distance * @see https://en.wikipedia.org/wiki/Hamming_distance */ SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); /** @copydoc sz_hamming_distance_utf8 */ SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, sz_size_t bound); typedef sz_size_t (*sz_hamming_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t); /** * @brief Computes the Levenshtein edit-distance between two strings using the Wagner-Fisher algorithm. * Similar to the Needleman-Wunsch alignment algorithm. Often used in fuzzy string matching. * * @param a First string to compare. * @param a_length Number of bytes in the first string. * @param b Second string to compare. * @param b_length Number of bytes in the second string. * * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, * so the memory usage is linear in relation to ::a_length and ::b_length. * If SZ_NULL is passed, will initialize to the systems default `malloc`. * @param bound Upper bound on the distance, that allows us to exit early. * If zero is passed, the maximum possible distance will be equal to the length of the longer input. * @return Unsigned integer for edit distance, the `bound` if was exceeded or `SZ_SIZE_MAX` * if the memory allocation failed. * * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default * @see https://en.wikipedia.org/wiki/Levenshtein_distance */ SZ_DYNAMIC sz_size_t sz_edit_distance(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // sz_size_t bound, sz_memory_allocator_t *alloc); /** @copydoc sz_edit_distance */ SZ_PUBLIC sz_size_t sz_edit_distance_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // sz_size_t bound, sz_memory_allocator_t *alloc); /** * @brief Computes the Levenshtein edit-distance between two @b UTF8 strings. * Unlike `sz_edit_distance`, reports the distance in Unicode codepoints, and not in bytes. * * @param a First string to compare. * @param a_length Number of bytes in the first string. * @param b Second string to compare. * @param b_length Number of bytes in the second string. * * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, * so the memory usage is linear in relation to ::a_length and ::b_length. * If SZ_NULL is passed, will initialize to the systems default `malloc`. 
* @param bound Upper bound on the distance, that allows us to exit early. * If zero is passed, the maximum possible distance will be equal to the length of the longer input. * @return Unsigned integer for edit distance, the `bound` if was exceeded or `SZ_SIZE_MAX` * if the memory allocation failed. * * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default, sz_edit_distance * @see https://en.wikipedia.org/wiki/Levenshtein_distance */ SZ_DYNAMIC sz_size_t sz_edit_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // sz_size_t bound, sz_memory_allocator_t *alloc); typedef sz_size_t (*sz_edit_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t, sz_memory_allocator_t *); /** @copydoc sz_edit_distance_utf8 */ SZ_PUBLIC sz_size_t sz_edit_distance_utf8_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // sz_size_t bound, sz_memory_allocator_t *alloc); /** * @brief Computes Needleman–Wunsch alignment score for two string. Often used in bioinformatics and cheminformatics. * Similar to the Levenshtein edit-distance, parameterized for gap and substitution penalties. * * Not commutative in the general case, as the order of the strings matters, as `sz_alignment_score(a, b)` may * not be equal to `sz_alignment_score(b, a)`. Becomes @b commutative, if the substitution costs are symmetric. * Equivalent to the negative Levenshtein distance, if: `gap == -1` and `subs[i][j] == (i == j ? 0: -1)`. * * @param a First string to compare. * @param a_length Number of bytes in the first string. * @param b Second string to compare. * @param b_length Number of bytes in the second string. * @param gap Penalty cost for gaps - insertions and removals. * @param subs Substitution costs matrix with 256 x 256 values for all pairs of characters. * * @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated, * so the memory usage is linear in relation to ::a_length and ::b_length. * If SZ_NULL is passed, will initialize to the systems default `malloc`. * @return Signed similarity score. Can be negative, depending on the substitution costs. * If the memory allocation fails, the function returns `SZ_SSIZE_MAX`. * * @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default * @see https://en.wikipedia.org/wiki/Needleman%E2%80%93Wunsch_algorithm */ SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // sz_error_cost_t const *subs, sz_error_cost_t gap, // sz_memory_allocator_t *alloc); /** @copydoc sz_alignment_score */ SZ_PUBLIC sz_ssize_t sz_alignment_score_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // sz_error_cost_t const *subs, sz_error_cost_t gap, // sz_memory_allocator_t *alloc); typedef sz_ssize_t (*sz_alignment_score_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_error_cost_t const *, sz_error_cost_t, sz_memory_allocator_t *); typedef void (*sz_hash_callback_t)(sz_cptr_t, sz_size_t, sz_u64_t, void *user); /** * @brief Computes the Karp-Rabin rolling hashes of a string supplying them to the provided `callback`. * Can be used for similarity scores, search, ranking, etc. * * Rabin-Karp-like rolling hashes can have very high-level of collisions and depend * on the choice of bases and the prime number. That's why, often two hashes from the same * family are used with different bases. * * 1. Kernighan and Ritchie's function uses 31, a prime close to the size of English alphabet. * 2. 
To be friendlier to byte-arrays and UTF8, we use 257 for the second function. * * Choosing the right ::window_length is task- and domain-dependant. For example, most English words are * between 3 and 7 characters long, so a window of 4 bytes would be a good choice. For DNA sequences, * the ::window_length might be a multiple of 3, as the codons are 3 (nucleotides) bytes long. * With such minimalistic alphabets of just four characters (AGCT) longer windows might be needed. * For protein sequences the alphabet is 20 characters long, so the window can be shorter, than for DNAs. * * @param text String to hash. * @param length Number of bytes in the string. * @param window_length Length of the rolling window in bytes. * @param window_step Step of reported hashes. @b Must be power of two. Should be smaller than `window_length`. * @param callback Function receiving the start & length of a substring, the hash, and the `callback_handle`. * @param callback_handle Optional user-provided pointer to be passed to the `callback`. * @see sz_hashes_fingerprint, sz_hashes_intersection */ SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // sz_hash_callback_t callback, void *callback_handle); /** @copydoc sz_hashes */ SZ_PUBLIC void sz_hashes_serial(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // sz_hash_callback_t callback, void *callback_handle); typedef void (*sz_hashes_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_size_t, sz_hash_callback_t, void *); /** * @brief Computes the Karp-Rabin rolling hashes of a string outputting a binary fingerprint. * Such fingerprints can be compared with Hamming or Jaccard (Tanimoto) distance for similarity. * * The algorithm doesn't clear the fingerprint buffer on start, so it can be invoked multiple times * to produce a fingerprint of a longer string, by passing the previous fingerprint as the ::fingerprint. * It can also be reused to produce multi-resolution fingerprints by changing the ::window_length * and calling the same function multiple times for the same input ::text. * * Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer, * avoiding cache-coherency penalties of remote on-heap buffers. * * @param text String to hash. * @param length Number of bytes in the string. * @param fingerprint Output fingerprint buffer. * @param fingerprint_bytes Number of bytes in the fingerprint buffer. * @param window_length Length of the rolling window in bytes. * @see sz_hashes, sz_hashes_intersection */ SZ_PUBLIC void sz_hashes_fingerprint(sz_cptr_t text, sz_size_t length, sz_size_t window_length, // sz_ptr_t fingerprint, sz_size_t fingerprint_bytes); typedef void (*sz_hashes_fingerprint_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_ptr_t, sz_size_t); /** * @brief Given a hash-fingerprint of a textual document, computes the number of intersecting hashes * of the incoming document. Can be used for document scoring and search. * * Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer, * avoiding cache-coherency penalties of remote on-heap buffers. * * @param text Input document. * @param length Number of bytes in the input document. * @param fingerprint Reference document fingerprint. * @param fingerprint_bytes Number of bytes in the reference documents fingerprint. * @param window_length Length of the rolling window in bytes. 
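 *
 * A hedged scoring sketch (illustrative only), assuming `reference` and `query` point to two documents
 * of `reference_length` and `query_length` bytes:
 *
 * @code{.c}
 *     char fingerprint[128] = {0}; // The API doesn't clear the buffer, so zero it first
 *     sz_hashes_fingerprint(reference, reference_length, 7, fingerprint, 128);
 *     sz_size_t shared = sz_hashes_intersection(query, query_length, 7, fingerprint, 128);
 *     // The larger `shared` is, the more 7-byte windows the two documents have in common.
 * @endcode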
* @see sz_hashes, sz_hashes_fingerprint */ SZ_PUBLIC sz_size_t sz_hashes_intersection(sz_cptr_t text, sz_size_t length, sz_size_t window_length, // sz_cptr_t fingerprint, sz_size_t fingerprint_bytes); typedef sz_size_t (*sz_hashes_intersection_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_cptr_t, sz_size_t); #pragma endregion #pragma region Convenience API /** * @brief Finds the first character in the haystack, that is present in the needle. * Convenience function, reused across different language bindings. * @see sz_find_charset */ SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); /** * @brief Finds the first character in the haystack, that is @b not present in the needle. * Convenience function, reused across different language bindings. * @see sz_find_charset */ SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); /** * @brief Finds the last character in the haystack, that is present in the needle. * Convenience function, reused across different language bindings. * @see sz_find_charset */ SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); /** * @brief Finds the last character in the haystack, that is @b not present in the needle. * Convenience function, reused across different language bindings. * @see sz_find_charset */ SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); #pragma endregion #pragma region String Sequences API struct sz_sequence_t; typedef sz_cptr_t (*sz_sequence_member_start_t)(struct sz_sequence_t const *, sz_size_t); typedef sz_size_t (*sz_sequence_member_length_t)(struct sz_sequence_t const *, sz_size_t); typedef sz_bool_t (*sz_sequence_predicate_t)(struct sz_sequence_t const *, sz_size_t); typedef sz_bool_t (*sz_sequence_comparator_t)(struct sz_sequence_t const *, sz_size_t, sz_size_t); typedef sz_bool_t (*sz_string_is_less_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); typedef struct sz_sequence_t { sz_sorted_idx_t *order; sz_size_t count; sz_sequence_member_start_t get_start; sz_sequence_member_length_t get_length; void const *handle; } sz_sequence_t; /** * @brief Initiates the sequence structure from a tape layout, used by Apache Arrow. * Expects ::offsets to contains `count + 1` entries, the last pointing at the end * of the last string, indicating the total length of the ::tape. */ SZ_PUBLIC void sz_sequence_from_u32tape(sz_cptr_t *start, sz_u32_t const *offsets, sz_size_t count, sz_sequence_t *sequence); /** * @brief Initiates the sequence structure from a tape layout, used by Apache Arrow. * Expects ::offsets to contains `count + 1` entries, the last pointing at the end * of the last string, indicating the total length of the ::tape. */ SZ_PUBLIC void sz_sequence_from_u64tape(sz_cptr_t *start, sz_u64_t const *offsets, sz_size_t count, sz_sequence_t *sequence); /** * @brief Similar to `std::partition`, given a predicate splits the sequence into two parts. * The algorithm is unstable, meaning that elements may change relative order, as long * as they are in the right partition. This is the simpler algorithm for partitioning. */ SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate); /** * @brief Inplace `std::set_union` for two consecutive chunks forming the same continuous `sequence`. * * @param partition The number of elements in the first sub-sequence in `sequence`. 
 * @param less Comparison function used to determine the lexicographic ordering.
 */
SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less);

/**
 * @brief Sorting algorithm, combining Radix Sort for the first 32 bits of every word
 * and a follow-up by a more conventional sorting procedure on equally prefixed parts.
 */
SZ_PUBLIC void sz_sort(sz_sequence_t *sequence);

/**
 * @brief Partial sorting algorithm, combining Radix Sort for the first 32 bits of every word
 * and a follow-up by a more conventional sorting procedure on equally prefixed parts.
 */
SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t n);

/**
 * @brief Intro-Sort algorithm that supports custom comparators.
 */
SZ_PUBLIC void sz_sort_intro(sz_sequence_t *sequence, sz_sequence_comparator_t less);

#pragma endregion

/*
 * Hardware feature detection.
 * All of those can be controlled by the user.
 */
#ifndef SZ_USE_X86_AVX512
#ifdef __AVX512BW__
#define SZ_USE_X86_AVX512 1
#else
#define SZ_USE_X86_AVX512 0
#endif
#endif

#ifndef SZ_USE_X86_AVX2
#ifdef __AVX2__
#define SZ_USE_X86_AVX2 1
#else
#define SZ_USE_X86_AVX2 0
#endif
#endif

#ifndef SZ_USE_ARM_NEON
#ifdef __ARM_NEON
#define SZ_USE_ARM_NEON 1
#else
#define SZ_USE_ARM_NEON 0
#endif
#endif

#ifndef SZ_USE_ARM_SVE
#ifdef __ARM_FEATURE_SVE
#define SZ_USE_ARM_SVE 1
#else
#define SZ_USE_ARM_SVE 0
#endif
#endif

/*
 * Include hardware-specific headers.
 */
#if SZ_USE_X86_AVX512 || SZ_USE_X86_AVX2
#include <immintrin.h>
#endif // SZ_USE_X86...
#if SZ_USE_ARM_NEON
#if !defined(_MSC_VER)
#include <arm_acle.h>
#endif
#include <arm_neon.h>
#endif // SZ_USE_ARM_NEON
#if SZ_USE_ARM_SVE
#if !defined(_MSC_VER)
#include <arm_sve.h>
#endif
#endif // SZ_USE_ARM_SVE

#pragma region Hardware-Specific API

#if SZ_USE_X86_AVX512

/** @copydoc sz_equal */
SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length);
/** @copydoc sz_order */
SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length);
/** @copydoc sz_copy */
SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
/** @copydoc sz_move */
SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
/** @copydoc sz_fill */
SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value);
/** @copydoc sz_look_up_transform */
SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target);
/** @copydoc sz_find_byte */
SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
/** @copydoc sz_rfind_byte */
SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
/** @copydoc sz_find */
SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length);
/** @copydoc sz_rfind */
SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length);
/** @copydoc sz_find_charset */
SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *set);
/** @copydoc sz_rfind_charset */
SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *set);
/** @copydoc sz_edit_distance */
SZ_PUBLIC sz_size_t sz_edit_distance_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
                                            sz_size_t bound, sz_memory_allocator_t *alloc);
/** @copydoc sz_alignment_score */
SZ_PUBLIC sz_ssize_t sz_alignment_score_avx512(sz_cptr_t a,
sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // sz_error_cost_t const *subs, sz_error_cost_t gap, // sz_memory_allocator_t *alloc); /** @copydoc sz_hashes */ SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, // sz_hash_callback_t callback, void *callback_handle); #endif #if SZ_USE_X86_AVX2 /** @copydoc sz_equal */ SZ_PUBLIC sz_bool_t sz_equal_avx2(sz_cptr_t a, sz_cptr_t b, sz_size_t length); /** @copydoc sz_order */ SZ_PUBLIC sz_ordering_t sz_order_avx2(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); /** @copydoc sz_copy */ SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length); /** @copydoc sz_move */ SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length); /** @copydoc sz_fill */ SZ_PUBLIC void sz_fill_avx2(sz_ptr_t target, sz_size_t length, sz_u8_t value); /** @copydoc sz_look_up_transform */ SZ_PUBLIC void sz_look_up_transform_avx2(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); /** @copydoc sz_find_byte */ SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); /** @copydoc sz_rfind_byte */ SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); /** @copydoc sz_find */ SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); /** @copydoc sz_rfind */ SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); /** @copydoc sz_hashes */ SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, // sz_hash_callback_t callback, void *callback_handle); #endif #if SZ_USE_ARM_NEON /** @copydoc sz_equal */ SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length); /** @copydoc sz_order */ SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); /** @copydoc sz_copy */ SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); /** @copydoc sz_move */ SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length); /** @copydoc sz_fill */ SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value); /** @copydoc sz_look_up_transform */ SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target); /** @copydoc sz_find_byte */ SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); /** @copydoc sz_rfind_byte */ SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); /** @copydoc sz_find */ SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); /** @copydoc sz_rfind */ SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); /** @copydoc sz_find_charset */ SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); /** @copydoc sz_rfind_charset */ SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); #endif #if SZ_USE_ARM_SVE /** @copydoc sz_equal */ SZ_PUBLIC sz_bool_t sz_equal_sve(sz_cptr_t a, sz_cptr_t b, sz_size_t length); /** @copydoc sz_order */ SZ_PUBLIC sz_ordering_t sz_order_sve(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); /** 
@copydoc sz_copy */ SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length); /** @copydoc sz_move */ SZ_PUBLIC void sz_move_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length); /** @copydoc sz_fill */ SZ_PUBLIC void sz_fill_sve(sz_ptr_t target, sz_size_t length, sz_u8_t value); /** @copydoc sz_find_byte */ SZ_PUBLIC sz_cptr_t sz_find_byte_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); /** @copydoc sz_rfind_byte */ SZ_PUBLIC sz_cptr_t sz_rfind_byte_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); /** @copydoc sz_find */ SZ_PUBLIC sz_cptr_t sz_find_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); /** @copydoc sz_rfind */ SZ_PUBLIC sz_cptr_t sz_rfind_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); /** @copydoc sz_find_charset */ SZ_PUBLIC sz_cptr_t sz_find_charset_sve(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); /** @copydoc sz_rfind_charset */ SZ_PUBLIC sz_cptr_t sz_rfind_charset_sve(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); #endif #pragma endregion #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wconversion" /* ********************************************************************************************************************** ********************************************************************************************************************** ********************************************************************************************************************** * * This is where we the actual implementation begins. * The rest of the file is hidden from the public API. * ********************************************************************************************************************** ********************************************************************************************************************** ********************************************************************************************************************** */ #pragma region Compiler Extensions and Helper Functions #pragma GCC visibility push(hidden) /** * @brief Helper-macro to mark potentially unused variables. */ #define sz_unused(x) ((void)(x)) /** * @brief Helper-macro casting a variable to another type of the same size. */ #define sz_bitcast(type, value) (*((type *)&(value))) /** * @brief Defines `SZ_NULL`, analogous to `NULL`. * The default often comes from locale.h, stddef.h, * stdio.h, stdlib.h, string.h, time.h, or wchar.h. */ #ifdef __GNUG__ #define SZ_NULL __null #define SZ_NULL_CHAR __null #else #define SZ_NULL ((void *)0) #define SZ_NULL_CHAR ((char *)0) #endif /** * @brief Cache-line width, that will affect the execution of some algorithms, * like equality checks and relative order computing. */ #define SZ_CACHE_LINE_WIDTH (64) // bytes /** * @brief Similar to `assert`, the `sz_assert` is used in the SZ_DEBUG mode * to check the invariants of the library. It's a no-op in the SZ_RELEASE mode. 
* @note If you want to catch it, put a breakpoint at @b `__GI_exit` */ #if SZ_DEBUG && defined(SZ_AVOID_LIBC) && !SZ_AVOID_LIBC && !defined(SZ_PIC) #include <stdio.h> // `fprintf` #include <stdlib.h> // `EXIT_FAILURE` SZ_PUBLIC void _sz_assert_failure(char const *condition, char const *file, int line) { fprintf(stderr, "Assertion failed: %s, in file %s, line %d\n", condition, file, line); exit(EXIT_FAILURE); } #define sz_assert(condition) \ do { \ if (!(condition)) { _sz_assert_failure(#condition, __FILE__, __LINE__); } \ } while (0) #else #define sz_assert(condition) ((void)(condition)) #endif /* Intrinsics aliases for MSVC, GCC, Clang, and Clang-Cl. * The following section of compiler intrinsics comes in 2 flavors. */ #if defined(_MSC_VER) && !defined(__clang__) // MSVC proper, excluding Clang-CL #include <intrin.h> // Sadly, when building Win32 images, we can't use the `_tzcnt_u64`, `_lzcnt_u64`, // `_BitScanForward64`, or `_BitScanReverse64` intrinsics. For now it's a simple `for`-loop. // TODO: In the future we can switch to a more efficient De Bruijn's algorithm. // https://www.chessprogramming.org/BitScan // https://www.chessprogramming.org/De_Bruijn_Sequence // https://gist.github.com/resilar/e722d4600dbec9752771ab4c9d47044f // // Use the serial version on 32-bit x86 and on Arm. #if (defined(_WIN32) && !defined(_WIN64)) || defined(_M_ARM) || defined(_M_ARM64) SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { sz_assert(x != 0); int n = 0; while ((x & 1) == 0) { n++, x >>= 1; } return n; } SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { sz_assert(x != 0); int n = 0; while ((x & 0x8000000000000000ull) == 0) { n++, x <<= 1; } return n; } SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { x = x - ((x >> 1) & 0x5555555555555555ull); x = (x & 0x3333333333333333ull) + ((x >> 2) & 0x3333333333333333ull); return (((x + (x >> 4)) & 0x0F0F0F0F0F0F0F0Full) * 0x0101010101010101ull) >> 56; } SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { sz_assert(x != 0); int n = 0; while ((x & 1) == 0) { n++, x >>= 1; } return n; } SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { sz_assert(x != 0); int n = 0; while ((x & 0x80000000u) == 0) { n++, x <<= 1; } return n; } SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { x = x - ((x >> 1) & 0x55555555); x = (x & 0x33333333) + ((x >> 2) & 0x33333333); return (((x + (x >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24; } #else SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { return (int)_tzcnt_u64(x); } SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { return (int)_lzcnt_u64(x); } SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { return (int)__popcnt64(x); } SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return (int)_tzcnt_u32(x); } SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return (int)_lzcnt_u32(x); } SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return (int)__popcnt(x); } #endif // Force the byteswap functions to be intrinsics, because when /Oi- is given, these will turn into CRT function calls, // which breaks when `SZ_AVOID_LIBC` is given. #pragma intrinsic(_byteswap_uint64) SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return _byteswap_uint64(val); } #pragma intrinsic(_byteswap_ulong) SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return _byteswap_ulong(val); } #else SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { return __builtin_popcountll(x); } SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return __builtin_popcount(x); } SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { return __builtin_ctzll(x); } SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { return __builtin_clzll(x); } SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return __builtin_ctz(x); } // !
Undefined if `x == 0` SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return __builtin_clz(x); } // ! Undefined if `x == 0` SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return __builtin_bswap64(val); } SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return __builtin_bswap32(val); } #endif SZ_INTERNAL sz_u64_t sz_u64_rotl(sz_u64_t x, sz_u64_t r) { return (x << r) | (x >> (64 - r)); } /** * @brief Select bits from either ::a or ::b depending on the value of ::mask bits. * * Similar to `_mm_blend_epi16` intrinsic on x86. * Described in the "Bit Twiddling Hacks" by Sean Eron Anderson. * https://graphics.stanford.edu/~seander/bithacks.html#ConditionalSetOrClearBitsWithoutBranching */ SZ_INTERNAL sz_u64_t sz_u64_blend(sz_u64_t a, sz_u64_t b, sz_u64_t mask) { return a ^ ((a ^ b) & mask); } /* * Efficiently computing the minimum and maximum of two or three values can be tricky. * The simple branching baseline would be: * * x < y ? x : y // can replace with 1 conditional move * * The branchless approach is well known for signed integers, but it doesn't apply to unsigned ones. * https://stackoverflow.com/questions/514435/templatized-branchless-int-max-min-function * https://graphics.stanford.edu/~seander/bithacks.html#IntegerMinOrMax * Using only bit-shifts for signed integers it would be: * * y + ((x - y) & (x - y) >> 31) // 4 unique operations * * Alternatively, for any integers using multiplication: * * (x > y) * y + (x <= y) * x // 5 operations * * Alternatively, to avoid multiplication: * * (x & ~((x < y) - 1)) + (y & ((x < y) - 1)) // 6 unique operations */ #define sz_min_of_two(x, y) (x < y ? x : y) #define sz_max_of_two(x, y) (x < y ? y : x) #define sz_min_of_three(x, y, z) sz_min_of_two(x, sz_min_of_two(y, z)) #define sz_max_of_three(x, y, z) sz_max_of_two(x, sz_max_of_two(y, z)) /** @brief Branchless minimum function for two signed 32-bit integers. */ SZ_INTERNAL sz_i32_t sz_i32_min_of_two(sz_i32_t x, sz_i32_t y) { return y + ((x - y) & (x - y) >> 31); } /** @brief Branchless maximum function for two signed 32-bit integers. */ SZ_INTERNAL sz_i32_t sz_i32_max_of_two(sz_i32_t x, sz_i32_t y) { return x - ((x - y) & (x - y) >> 31); } /** * @brief Clamps signed offsets in a string to a valid range. Used for Pythonic-style slicing. */ SZ_INTERNAL void sz_ssize_clamp_interval(sz_size_t length, sz_ssize_t start, sz_ssize_t end, sz_size_t *normalized_offset, sz_size_t *normalized_length) { // TODO: Remove branches. // Normalize negative indices if (start < 0) start += length; if (end < 0) end += length; // Clamp indices to a valid range if (start < 0) start = 0; if (end < 0) end = 0; if (start > (sz_ssize_t)length) start = length; if (end > (sz_ssize_t)length) end = length; // Ensure start <= end if (start > end) start = end; *normalized_offset = start; *normalized_length = end - start; } /** * @brief Compute the logarithm base 2 of a positive integer, rounding down. */ SZ_INTERNAL sz_size_t sz_size_log2i_nonzero(sz_size_t x) { sz_assert(x > 0 && "Non-positive numbers have no defined logarithm"); sz_size_t leading_zeros = sz_u64_clz(x); return 63 - leading_zeros; } /** * @brief Compute the smallest power of two greater than or equal to ::x. */ SZ_INTERNAL sz_size_t sz_size_bit_ceil(sz_size_t x) { // Unlike the commonly used trick with `clz` intrinsics, this is valid across the whole range of `x`.
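// For example, for x = 5 the decrement gives 0b100, the shifts smear it into 0b111, and the final increment produces 8.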
// https://stackoverflow.com/a/10143264 x--; x |= x >> 1; x |= x >> 2; x |= x >> 4; x |= x >> 8; x |= x >> 16; #if SZ_DETECT_64_BIT x |= x >> 32; #endif x++; return x; } /** * @brief Transposes an 8x8 bit matrix packed in a `sz_u64_t`. * * There is a well-known SWAR sequence for that, familiar to chess programmers * willing to flip a bit-matrix of pieces along the main A1-H8 diagonal. * https://www.chessprogramming.org/Flipping_Mirroring_and_Rotating * https://lukas-prokop.at/articles/2021-07-23-transpose */ SZ_INTERNAL sz_u64_t sz_u64_transpose(sz_u64_t x) { sz_u64_t t; t = x ^ (x << 36); x ^= 0xf0f0f0f00f0f0f0full & (t ^ (x >> 36)); t = 0xcccc0000cccc0000ull & (x ^ (x << 18)); x ^= t ^ (t >> 18); t = 0xaa00aa00aa00aa00ull & (x ^ (x << 9)); x ^= t ^ (t >> 9); return x; } /** * @brief Helper, that swaps two 64-bit integers representing the order of elements in the sequence. */ SZ_INTERNAL void sz_u64_swap(sz_u64_t *a, sz_u64_t *b) { sz_u64_t t = *a; *a = *b; *b = t; } /** * @brief Helper, that swaps two pointers representing the order of elements in the sequence. */ SZ_INTERNAL void sz_pointer_swap(void **a, void **b) { void *t = *a; *a = *b; *b = t; } /** * @brief Helper structure to simplify work with 16-bit words. * @see sz_u16_load */ typedef union sz_u16_vec_t { sz_u16_t u16; sz_u8_t u8s[2]; } sz_u16_vec_t; /** * @brief Load a 16-bit unsigned integer from a potentially unaligned pointer, which can be expensive on some platforms. */ SZ_INTERNAL sz_u16_vec_t sz_u16_load(sz_cptr_t ptr) { #if !SZ_USE_MISALIGNED_LOADS sz_u16_vec_t result; result.u8s[0] = ptr[0]; result.u8s[1] = ptr[1]; return result; #elif defined(_MSC_VER) && !defined(__clang__) #if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. return *((sz_u16_vec_t *)ptr); #else return *((__unaligned sz_u16_vec_t *)ptr); #endif #else __attribute__((aligned(1))) sz_u16_vec_t const *result = (sz_u16_vec_t const *)ptr; return *result; #endif } /** * @brief Helper structure to simplify work with 32-bit words. * @see sz_u32_load */ typedef union sz_u32_vec_t { sz_u32_t u32; sz_u16_t u16s[2]; sz_u8_t u8s[4]; } sz_u32_vec_t; /** * @brief Load a 32-bit unsigned integer from a potentially unaligned pointer, which can be expensive on some platforms. */ SZ_INTERNAL sz_u32_vec_t sz_u32_load(sz_cptr_t ptr) { #if !SZ_USE_MISALIGNED_LOADS sz_u32_vec_t result; result.u8s[0] = ptr[0]; result.u8s[1] = ptr[1]; result.u8s[2] = ptr[2]; result.u8s[3] = ptr[3]; return result; #elif defined(_MSC_VER) && !defined(__clang__) #if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. return *((sz_u32_vec_t *)ptr); #else return *((__unaligned sz_u32_vec_t *)ptr); #endif #else __attribute__((aligned(1))) sz_u32_vec_t const *result = (sz_u32_vec_t const *)ptr; return *result; #endif } /** * @brief Helper structure to simplify work with 64-bit words. * @see sz_u64_load */ typedef union sz_u64_vec_t { sz_u64_t u64; sz_u32_t u32s[2]; sz_u16_t u16s[4]; sz_u8_t u8s[8]; } sz_u64_vec_t; /** * @brief Load a 64-bit unsigned integer from a potentially unaligned pointer, which can be expensive on some platforms.
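* When misaligned loads are disabled, the value is assembled one byte at a time instead.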
*/ SZ_INTERNAL sz_u64_vec_t sz_u64_load(sz_cptr_t ptr) { #if !SZ_USE_MISALIGNED_LOADS sz_u64_vec_t result; result.u8s[0] = ptr[0]; result.u8s[1] = ptr[1]; result.u8s[2] = ptr[2]; result.u8s[3] = ptr[3]; result.u8s[4] = ptr[4]; result.u8s[5] = ptr[5]; result.u8s[6] = ptr[6]; result.u8s[7] = ptr[7]; return result; #elif defined(_MSC_VER) && !defined(__clang__) #if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform. return *((sz_u64_vec_t *)ptr); #else return *((__unaligned sz_u64_vec_t *)ptr); #endif #else __attribute__((aligned(1))) sz_u64_vec_t const *result = (sz_u64_vec_t const *)ptr; return *result; #endif } /** @brief Helper function, using the supplied fixed-capacity buffer to allocate memory. */ SZ_INTERNAL sz_ptr_t _sz_memory_allocate_fixed(sz_size_t length, void *handle) { sz_size_t capacity; sz_copy((sz_ptr_t)&capacity, (sz_cptr_t)handle, sizeof(sz_size_t)); sz_size_t consumed_capacity = sizeof(sz_size_t); if (consumed_capacity + length > capacity) return SZ_NULL_CHAR; return (sz_ptr_t)handle + consumed_capacity; } /** @brief Helper "no-op" function, simulating memory deallocation when we use a "static" memory buffer. */ SZ_INTERNAL void _sz_memory_free_fixed(sz_ptr_t start, sz_size_t length, void *handle) { sz_unused(start && length && handle); } /** @brief An internal callback used to set a bit in a power-of-two length binary fingerprint of a string. */ SZ_INTERNAL void _sz_hashes_fingerprint_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *handle) { sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; sz_size_t fingerprint_bytes = fingerprint_buffer->length; fingerprint_u8s[(hash / 8) & (fingerprint_bytes - 1)] |= (1 << (hash & 7)); sz_unused(start && length); } /** @brief An internal callback used to set a bit in a @b non power-of-two length binary fingerprint of a string. */ SZ_INTERNAL void _sz_hashes_fingerprint_non_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *handle) { sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle; sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start; sz_size_t fingerprint_bytes = fingerprint_buffer->length; fingerprint_u8s[(hash / 8) % fingerprint_bytes] |= (1 << (hash & 7)); sz_unused(start && length); } /** @brief An internal callback, used to mix all the running hashes into one pointer-size value. */ SZ_INTERNAL void _sz_hashes_fingerprint_scalar_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *scalar_handle) { sz_unused(start && length && hash && scalar_handle); sz_size_t *scalar_ptr = (sz_size_t *)scalar_handle; *scalar_ptr ^= hash; } /** * @brief Chooses the offsets of the most interesting characters in a search needle. * * Search throughput can significantly deteriorate if we are matching the wrong characters. * Say the needle is "aXaYa", and we are comparing the first, second, and last character. * If we use SIMD and compare many offsets at a time, comparing against "a" in every register is a waste. * * Similarly, dealing with UTF8 inputs, we know that the lower bits of each character code carry more information. * Cyrillic alphabet, for example, falls into [0x0410, 0x042F] code range for uppercase [А, Я], and * into [0x0430, 0x044F] for lowercase [а, я]. Scanning through a text written in Russian, half of the * bytes will carry absolutely no value and will be equal to 0x04. 
*/ SZ_INTERNAL void _sz_locate_needle_anomalies(sz_cptr_t start, sz_size_t length, // sz_size_t *first, sz_size_t *second, sz_size_t *third) { *first = 0; *second = length / 2; *third = length - 1; // int has_duplicates = // start[*first] == start[*second] || // start[*first] == start[*third] || // start[*second] == start[*third]; // Loop through letters to find non-colliding variants. if (length > 3 && has_duplicates) { // Pivot the middle point right, until we find a character different from the first one. for (; start[*second] == start[*first] && *second + 1 < *third; ++(*second)) {} // Pivot the third (last) point left, until we find a different character. for (; (start[*third] == start[*second] || start[*third] == start[*first]) && *third > (*second + 1); --(*third)) {} } // TODO: Investigate alternative strategies for long needles. // On very long needles we have the luxury to choose! // Often dealing with UTF8, we will likely benefit from shifting the first and second characters // further to the right, to achieve not only uniqueness within the needle, but also avoid common // rune prefixes of 2-, 3-, and 4-byte codes. if (length > 8) { // Pivot the first and second points right, until we find a character that: // > is different from others. // > doesn't start with 0b'110x'xxxx - only 5 bits of relevant info. // > doesn't start with 0b'1110'xxxx - only 4 bits of relevant info. // > doesn't start with 0b'1111'0xxx - only 3 bits of relevant info. // // So we are practically searching for byte values that start with 0b0xxx'xxxx or 0b'10xx'xxxx. // Meaning they fall in the range [0, 127] and [128, 191], in other words any unsigned int up to 191. sz_u8_t const *start_u8 = (sz_u8_t const *)start; sz_size_t vibrant_first = *first, vibrant_second = *second, vibrant_third = *third; // Let's begin with the second character, as the termination criteria there is more obvious // and we may end up with more variants to check for the first candidate. for (; (start_u8[vibrant_second] > 191 || start_u8[vibrant_second] == start_u8[vibrant_third]) && (vibrant_second + 1 < vibrant_third); ++vibrant_second) {} // Now check if we've indeed found a good candidate or should revert the `vibrant_second` to `second`. if (start_u8[vibrant_second] < 191) { *second = vibrant_second; } else { vibrant_second = *second; } // Now check the first character. for (; (start_u8[vibrant_first] > 191 || start_u8[vibrant_first] == start_u8[vibrant_second] || start_u8[vibrant_first] == start_u8[vibrant_third]) && (vibrant_first + 1 < vibrant_second); ++vibrant_first) {} // Now check if we've indeed found a good candidate or should revert the `vibrant_first` to `first`. // We don't need to shift the third one when dealing with texts as the last byte of the text is // also the last byte of a rune and contains the most information.
if (start_u8[vibrant_first] < 191) { *first = vibrant_first; } } } #pragma GCC visibility pop #pragma endregion #pragma region Serial Implementation #if !SZ_AVOID_LIBC #include <stdio.h> // `fprintf` #include <stdlib.h> // `malloc`, `EXIT_FAILURE` SZ_PUBLIC void *_sz_memory_allocate_default(sz_size_t length, void *handle) { sz_unused(handle); return malloc(length); } SZ_PUBLIC void _sz_memory_free_default(sz_ptr_t start, sz_size_t length, void *handle) { sz_unused(handle && length); free(start); } #endif SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc) { #if !SZ_AVOID_LIBC alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_default; alloc->free = (sz_memory_free_t)_sz_memory_free_default; #else alloc->allocate = (sz_memory_allocate_t)SZ_NULL; alloc->free = (sz_memory_free_t)SZ_NULL; #endif alloc->handle = SZ_NULL; } SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length) { // The logic here is simple - put the buffer length in the first slots of the buffer. // Later use it for bounds checking. alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_fixed; alloc->free = (sz_memory_free_t)_sz_memory_free_fixed; alloc->handle = buffer; sz_copy((sz_ptr_t)buffer, (sz_cptr_t)&length, sizeof(sz_size_t)); } /** * @brief Byte-level equality comparison between two strings. * If unaligned loads are allowed, compares eight bytes at a time using SWAR. */ SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { sz_cptr_t const a_end = a + length; #if SZ_USE_MISALIGNED_LOADS if (length >= SZ_SWAR_THRESHOLD) { sz_u64_vec_t a_vec, b_vec; for (; a + 8 <= a_end; a += 8, b += 8) { a_vec = sz_u64_load(a); b_vec = sz_u64_load(b); if (a_vec.u64 != b_vec.u64) return sz_false_k; } } #endif while (a != a_end && *a == *b) a++, b++; return (sz_bool_t)(a_end == a); } SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { for (sz_cptr_t const end = text + length; text != end; ++text) if (sz_charset_contains(set, *text)) return text; return SZ_NULL_CHAR; } SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Warray-bounds" sz_cptr_t const end = text; for (text += length; text != end;) if (sz_charset_contains(set, *(text -= 1))) return text; return SZ_NULL_CHAR; #pragma GCC diagnostic pop } /** * One option to avoid branching is to use conditional moves and lookup the comparison result in a table: * sz_ordering_t ordering_lookup[2] = {sz_greater_k, sz_less_k}; * for (; a != min_end; ++a, ++b) * if (*a != *b) return ordering_lookup[*a < *b]; * That, however, introduces a data-dependency. * A cleaner option is to perform two comparisons and a subtraction. * One instruction more, but no data-dependency. */ #define _sz_order_scalars(a, b) ((sz_ordering_t)((a > b) - (a < b))) SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { sz_bool_t a_shorter = (sz_bool_t)(a_length < b_length); sz_size_t min_length = a_shorter ?
a_length : b_length; sz_cptr_t min_end = a + min_length; #if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN for (sz_u64_vec_t a_vec, b_vec; a + 8 <= min_end; a += 8, b += 8) { a_vec = sz_u64_load(a); b_vec = sz_u64_load(b); if (a_vec.u64 != b_vec.u64) return _sz_order_scalars(sz_u64_bytes_reverse(a_vec.u64), sz_u64_bytes_reverse(b_vec.u64)); } #endif for (; a != min_end; ++a, ++b) if (*a != *b) return _sz_order_scalars(*a, *b); // If the strings are equal up to `min_end`, then the shorter string is smaller return _sz_order_scalars(a_length, b_length); } /** * @brief Byte-level equality comparison between two 64-bit integers. * @return 64-bit integer, where every top bit in each byte signifies a match. */ SZ_INTERNAL sz_u64_vec_t _sz_u64_each_byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { sz_u64_vec_t vec; vec.u64 = ~(a.u64 ^ b.u64); // The match is valid, if every bit within each byte is set. // For that take the bottom 7 bits of each byte, add one to them, // and if this sets the top bit to one, then all the 7 bits are ones as well. vec.u64 = ((vec.u64 & 0x7F7F7F7F7F7F7F7Full) + 0x0101010101010101ull) & ((vec.u64 & 0x8080808080808080ull)); return vec; } /** * @brief Find the first occurrence of a @b single-character needle in an arbitrary length haystack. * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. * Identical to `memchr(haystack, needle[0], haystack_length)`. */ SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { if (!h_length) return SZ_NULL_CHAR; sz_cptr_t const h_end = h + h_length; #if !SZ_DETECT_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevity. #if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to avoid UB on unaligned 64-bit loads. for (; ((sz_size_t)h & 7ull) && h < h_end; ++h) if (*h == *n) return h; #endif // Broadcast the n into every byte of a 64-bit integer to use SWAR // techniques and process eight characters at a time. sz_u64_vec_t h_vec, n_vec, match_vec; match_vec.u64 = 0; n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; for (; h + 8 <= h_end; h += 8) { h_vec.u64 = *(sz_u64_t const *)h; match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); if (match_vec.u64) return h + sz_u64_ctz(match_vec.u64) / 8; } #endif // Handle the misaligned tail. for (; h < h_end; ++h) if (*h == *n) return h; return SZ_NULL_CHAR; } /** * @brief Find the last occurrence of a @b single-character needle in an arbitrary length haystack. * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. * Identical to `memrchr(haystack, needle[0], haystack_length)`. */ SZ_PUBLIC sz_cptr_t sz_rfind_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { if (!h_length) return SZ_NULL_CHAR; sz_cptr_t const h_start = h; // Reposition the `h` pointer to the end, as we will be walking backwards. h = h + h_length - 1; #if !SZ_DETECT_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevity. #if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to avoid UB on unaligned 64-bit loads. for (; ((sz_size_t)(h + 1) & 7ull) && h >= h_start; --h) if (*h == *n) return h; #endif // Broadcast the n into every byte of a 64-bit integer to use SWAR // techniques and process eight characters at a time.
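// For example, a needle byte 'a' (0x61) becomes the 64-bit pattern 0x6161616161616161.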
sz_u64_vec_t h_vec, n_vec, match_vec; n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; for (; h >= h_start + 7; h -= 8) { h_vec.u64 = *(sz_u64_t const *)(h - 7); match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); if (match_vec.u64) return h - sz_u64_clz(match_vec.u64) / 8; } #endif for (; h >= h_start; --h) if (*h == *n) return h; return SZ_NULL_CHAR; } /** * @brief 2Byte-level equality comparison between two 64-bit integers. * @return 64-bit integer, where every top bit in each 2byte signifies a match. */ SZ_INTERNAL sz_u64_vec_t _sz_u64_each_2byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { sz_u64_vec_t vec; vec.u64 = ~(a.u64 ^ b.u64); // The match is valid, if every bit within each 2byte is set. // For that take the bottom 15 bits of each 2byte, add one to them, // and if this sets the top bit to one, then all the 15 bits are ones as well. vec.u64 = ((vec.u64 & 0x7FFF7FFF7FFF7FFFull) + 0x0001000100010001ull) & ((vec.u64 & 0x8000800080008000ull)); return vec; } /** * @brief Find the first occurrence of a @b two-character needle in an arbitrary length haystack. * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. */ SZ_INTERNAL sz_cptr_t _sz_find_2byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { // This is an internal method, and the haystack is guaranteed to be at least 2 bytes long. sz_assert(h_length >= 2 && "The haystack is too short."); sz_unused(n_length); sz_cptr_t const h_end = h + h_length; #if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to avoid UB on unaligned 64-bit loads. for (; ((sz_size_t)h & 7ull) && h + 2 <= h_end; ++h) if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h; #endif sz_u64_vec_t h_even_vec, h_odd_vec, n_vec, matches_even_vec, matches_odd_vec; n_vec.u64 = 0; n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1]; n_vec.u64 *= 0x0001000100010001ull; // broadcast // This code simulates hyper-scalar execution, analyzing 8 offsets at a time. for (; h + 9 <= h_end; h += 8) { h_even_vec.u64 = *(sz_u64_t *)h; h_odd_vec.u64 = (h_even_vec.u64 >> 8) | ((sz_u64_t)h[8] << 56); matches_even_vec = _sz_u64_each_2byte_equal(h_even_vec, n_vec); matches_odd_vec = _sz_u64_each_2byte_equal(h_odd_vec, n_vec); matches_even_vec.u64 >>= 8; if (matches_even_vec.u64 + matches_odd_vec.u64) { sz_u64_t match_indicators = matches_even_vec.u64 | matches_odd_vec.u64; return h + sz_u64_ctz(match_indicators) / 8; } } for (; h + 2 <= h_end; ++h) if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h; return SZ_NULL_CHAR; } /** * @brief 4Byte-level equality comparison between two 64-bit integers. * @return 64-bit integer, where every top bit in each 4byte signifies a match. */ SZ_INTERNAL sz_u64_vec_t _sz_u64_each_4byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { sz_u64_vec_t vec; vec.u64 = ~(a.u64 ^ b.u64); // The match is valid, if every bit within each 4byte is set. // For that take the bottom 31 bits of each 4byte, add one to them, // and if this sets the top bit to one, then all the 31 bits are ones as well. vec.u64 = ((vec.u64 & 0x7FFFFFFF7FFFFFFFull) + 0x0000000100000001ull) & ((vec.u64 & 0x8000000080000000ull)); return vec; } /** * @brief Find the first occurrence of a @b four-character needle in an arbitrary length haystack. * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time.
*/ SZ_INTERNAL sz_cptr_t _sz_find_4byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. sz_assert(h_length >= 4 && "The haystack is too short."); sz_unused(n_length); sz_cptr_t const h_end = h + h_length; #if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to avoid UB on unaligned 64-bit loads. for (; ((sz_size_t)h & 7ull) && h + 4 <= h_end; ++h) if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h; #endif sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, n_vec, matches0_vec, matches1_vec, matches2_vec, matches3_vec; n_vec.u64 = 0; n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2], n_vec.u8s[3] = n[3]; n_vec.u64 *= 0x0000000100000001ull; // broadcast // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using four 64-bit words. // We load the subsequent four-byte word as well, taking its first bytes. Think of it as a glorified prefetch :) sz_u64_t h_page_current, h_page_next; for (; h + sizeof(sz_u64_t) + sizeof(sz_u32_t) <= h_end; h += sizeof(sz_u64_t)) { h_page_current = *(sz_u64_t *)h; h_page_next = *(sz_u32_t *)(h + 8); h0_vec.u64 = (h_page_current); h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56); h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48); h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40); matches0_vec = _sz_u64_each_4byte_equal(h0_vec, n_vec); matches1_vec = _sz_u64_each_4byte_equal(h1_vec, n_vec); matches2_vec = _sz_u64_each_4byte_equal(h2_vec, n_vec); matches3_vec = _sz_u64_each_4byte_equal(h3_vec, n_vec); if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64) { matches0_vec.u64 >>= 24; matches1_vec.u64 >>= 16; matches2_vec.u64 >>= 8; sz_u64_t match_indicators = matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64; return h + sz_u64_ctz(match_indicators) / 8; } } for (; h + 4 <= h_end; ++h) if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h; return SZ_NULL_CHAR; } /** * @brief 3Byte-level equality comparison between two 64-bit integers. * @return 64-bit integer, where every top bit in each 3byte signifies a match. */ SZ_INTERNAL sz_u64_vec_t _sz_u64_each_3byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { sz_u64_vec_t vec; vec.u64 = ~(a.u64 ^ b.u64); // The match is valid, if every bit within each 3byte is set. // For that take the bottom 23 bits of each 3byte, add one to them, // and if this sets the top bit to one, then all the 23 bits are ones as well. vec.u64 = ((vec.u64 & 0xFFFF7FFFFF7FFFFFull) + 0x0000000001000001ull) & ((vec.u64 & 0x0000800000800000ull)); return vec; } /** * @brief Find the first occurrence of a @b three-character needle in an arbitrary length haystack. * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. */ SZ_INTERNAL sz_cptr_t _sz_find_3byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { // This is an internal method, and the haystack is guaranteed to be at least 3 bytes long. sz_assert(h_length >= 3 && "The haystack is too short."); sz_unused(n_length); sz_cptr_t const h_end = h + h_length; #if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to avoid UB on unaligned 64-bit loads.
for (; ((sz_size_t)h & 7ull) && h + 3 <= h_end; ++h) if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h; #endif // We fetch 12 sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, h4_vec; sz_u64_vec_t matches0_vec, matches1_vec, matches2_vec, matches3_vec, matches4_vec; sz_u64_vec_t n_vec; n_vec.u64 = 0; n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2]; n_vec.u64 *= 0x0000000001000001ull; // broadcast // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using three 64-bit words. // We load the subsequent two-byte word as well. sz_u64_t h_page_current, h_page_next; for (; h + sizeof(sz_u64_t) + sizeof(sz_u16_t) <= h_end; h += sizeof(sz_u64_t)) { h_page_current = *(sz_u64_t *)h; h_page_next = *(sz_u16_t *)(h + 8); h0_vec.u64 = (h_page_current); h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56); h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48); h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40); h4_vec.u64 = (h_page_current >> 32) | (h_page_next << 32); matches0_vec = _sz_u64_each_3byte_equal(h0_vec, n_vec); matches1_vec = _sz_u64_each_3byte_equal(h1_vec, n_vec); matches2_vec = _sz_u64_each_3byte_equal(h2_vec, n_vec); matches3_vec = _sz_u64_each_3byte_equal(h3_vec, n_vec); matches4_vec = _sz_u64_each_3byte_equal(h4_vec, n_vec); if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64) { matches0_vec.u64 >>= 16; matches1_vec.u64 >>= 8; matches3_vec.u64 <<= 8; matches4_vec.u64 <<= 16; sz_u64_t match_indicators = matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64; return h + sz_u64_ctz(match_indicators) / 8; } } for (; h + 3 <= h_end; ++h) if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h; return SZ_NULL_CHAR; } /** * @brief Boyer-Moore-Horspool algorithm for exact matching of patterns up to @b 256-bytes long. * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. */ SZ_INTERNAL sz_cptr_t _sz_find_horspool_upto_256bytes_serial(sz_cptr_t h_chars, sz_size_t h_length, // sz_cptr_t n_chars, sz_size_t n_length) { sz_assert(n_length <= 256 && "The pattern is too long."); // Several popular string matching algorithms are using a bad-character shift table. // Boyer Moore: https://www-igm.univ-mlv.fr/~lecroq/string/node14.html // Quick Search: https://www-igm.univ-mlv.fr/~lecroq/string/node19.html // Smith: https://www-igm.univ-mlv.fr/~lecroq/string/node21.html union { sz_u8_t jumps[256]; sz_u64_vec_t vecs[64]; } bad_shift_table; // Let's initialize the table using SWAR to the total length of the string. sz_u8_t const *h = (sz_u8_t const *)h_chars; sz_u8_t const *n = (sz_u8_t const *)n_chars; { sz_u64_vec_t n_length_vec; n_length_vec.u64 = ((sz_u8_t)(n_length - 1)) * 0x0101010101010101ull; // broadcast for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; for (sz_size_t i = 0; i + 1 < n_length; ++i) bad_shift_table.jumps[n[i]] = (sz_u8_t)(n_length - i - 1); } // Another common heuristic is to match a few characters from different parts of a string. // Raita suggests to use the first two, the last, and the middle character of the pattern. sz_u32_vec_t h_vec, n_vec; // Pick the parts of the needle that are worth comparing. sz_size_t offset_first, offset_mid, offset_last; _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); // Broadcast those characters into an unsigned integer. 
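// Together these four samples act as a cheap 32-bit pre-filter: the full `sz_equal` comparison below runs only when all of them match.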
n_vec.u8s[0] = n[offset_first]; n_vec.u8s[1] = n[offset_first + 1]; n_vec.u8s[2] = n[offset_mid]; n_vec.u8s[3] = n[offset_last]; // Scan through the whole haystack, skipping the last `n_length - 1` bytes. for (sz_size_t i = 0; i <= h_length - n_length;) { h_vec.u8s[0] = h[i + offset_first]; h_vec.u8s[1] = h[i + offset_first + 1]; h_vec.u8s[2] = h[i + offset_mid]; h_vec.u8s[3] = h[i + offset_last]; if (h_vec.u32 == n_vec.u32 && sz_equal((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; i += bad_shift_table.jumps[h[i + n_length - 1]]; } return SZ_NULL_CHAR; } /** * @brief Boyer-Moore-Horspool algorithm for @b reverse-order exact matching of patterns up to @b 256-bytes long. * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. */ SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_upto_256bytes_serial(sz_cptr_t h_chars, sz_size_t h_length, // sz_cptr_t n_chars, sz_size_t n_length) { sz_assert(n_length <= 256 && "The pattern is too long."); union { sz_u8_t jumps[256]; sz_u64_vec_t vecs[64]; } bad_shift_table; // Let's initialize the table using SWAR to the total length of the string. sz_u8_t const *h = (sz_u8_t const *)h_chars; sz_u8_t const *n = (sz_u8_t const *)n_chars; { sz_u64_vec_t n_length_vec; n_length_vec.u64 = ((sz_u8_t)(n_length - 1)) * 0x0101010101010101ull; // broadcast for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; for (sz_size_t i = 0; i + 1 < n_length; ++i) bad_shift_table.jumps[n[n_length - i - 1]] = (sz_u8_t)(n_length - i - 1); } // Another common heuristic is to match a few characters from different parts of a string. // Raita suggests to use the first two, the last, and the middle character of the pattern. sz_u32_vec_t h_vec, n_vec; // Pick the parts of the needle that are worth comparing. sz_size_t offset_first, offset_mid, offset_last; _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); // Broadcast those characters into an unsigned integer. n_vec.u8s[0] = n[offset_first]; n_vec.u8s[1] = n[offset_first + 1]; n_vec.u8s[2] = n[offset_mid]; n_vec.u8s[3] = n[offset_last]; // Scan through the whole haystack, skipping the first `n_length - 1` bytes. for (sz_size_t j = 0; j <= h_length - n_length;) { sz_size_t i = h_length - n_length - j; h_vec.u8s[0] = h[i + offset_first]; h_vec.u8s[1] = h[i + offset_first + 1]; h_vec.u8s[2] = h[i + offset_mid]; h_vec.u8s[3] = h[i + offset_last]; if (h_vec.u32 == n_vec.u32 && sz_equal((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; j += bad_shift_table.jumps[h[i]]; } return SZ_NULL_CHAR; } /** * @brief Exact substring search helper function, that finds the first occurrence of a prefix of the needle * using a given search function, and then verifies the remaining part of the needle. */ SZ_INTERNAL sz_cptr_t _sz_find_with_prefix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, sz_find_t find_prefix, sz_size_t prefix_length) { sz_size_t suffix_length = n_length - prefix_length; while (1) { sz_cptr_t found = find_prefix(h, h_length, n, prefix_length); if (!found) return SZ_NULL_CHAR; // Verify the remaining part of the needle sz_size_t remaining = h_length - (found - h); if (remaining < n_length) return SZ_NULL_CHAR; if (sz_equal(found + prefix_length, n + prefix_length, suffix_length)) return found; // Adjust the position. 
h = found + 1; h_length = remaining - 1; } // Unreachable, but helps silence compiler warnings: return SZ_NULL_CHAR; } /** * @brief Exact reverse-order substring search helper function, that finds the last occurrence of a suffix of the * needle using a given search function, and then verifies the remaining part of the needle. */ SZ_INTERNAL sz_cptr_t _sz_rfind_with_suffix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, sz_find_t find_suffix, sz_size_t suffix_length) { sz_size_t prefix_length = n_length - suffix_length; while (1) { sz_cptr_t found = find_suffix(h, h_length, n + prefix_length, suffix_length); if (!found) return SZ_NULL_CHAR; // Verify the remaining part of the needle sz_size_t remaining = found - h; if (remaining < prefix_length) return SZ_NULL_CHAR; if (sz_equal(found - prefix_length, n, prefix_length)) return found - prefix_length; // Adjust the position. h_length = remaining - 1; } // Unreachable, but helps silence compiler warnings: return SZ_NULL_CHAR; } SZ_INTERNAL sz_cptr_t _sz_find_byte_prefix_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { sz_unused(n_length); return sz_find_byte_serial(h, h_length, n); } SZ_INTERNAL sz_cptr_t _sz_rfind_byte_prefix_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { sz_unused(n_length); return sz_rfind_byte_serial(h, h_length, n); } SZ_INTERNAL sz_cptr_t _sz_find_over_4bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { return _sz_find_with_prefix(h, h_length, n, n_length, _sz_find_4byte_serial, 4); } SZ_INTERNAL sz_cptr_t _sz_find_horspool_over_256bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { return _sz_find_with_prefix(h, h_length, n, n_length, _sz_find_horspool_upto_256bytes_serial, 256); } SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_over_256bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { return _sz_rfind_with_suffix(h, h_length, n, n_length, _sz_rfind_horspool_upto_256bytes_serial, 256); } SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { // This almost never fires, but it's better to be safe than sorry. if (h_length < n_length || !n_length) return SZ_NULL_CHAR; #if SZ_DETECT_BIG_ENDIAN sz_find_t backends[] = { _sz_find_byte_prefix_serial, _sz_find_horspool_upto_256bytes_serial, _sz_find_horspool_over_256bytes_serial, }; return backends[(n_length > 1) + (n_length > 256)](h, h_length, n, n_length); #else sz_find_t backends[] = { // For very short strings brute-force SWAR makes sense. _sz_find_byte_prefix_serial, _sz_find_2byte_serial, _sz_find_3byte_serial, _sz_find_4byte_serial, // To avoid constructing the skip-table, let's use the prefixed approach. _sz_find_over_4bytes_serial, // For longer needles - use skip tables. _sz_find_horspool_upto_256bytes_serial, _sz_find_horspool_over_256bytes_serial, }; return backends[ // For very short strings brute-force SWAR makes sense. (n_length > 1) + (n_length > 2) + (n_length > 3) + // To avoid constructing the skip-table, let's use the prefixed approach. (n_length > 4) + // For longer needles - use skip tables. (n_length > 8) + (n_length > 256)](h, h_length, n, n_length); #endif } SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { // This almost never fires, but it's better to be safe than sorry. 
if (h_length < n_length || !n_length) return SZ_NULL_CHAR; sz_find_t backends[] = { // For very short strings brute-force SWAR makes sense. _sz_rfind_byte_prefix_serial, // TODO: implement reverse-order SWAR for 2/3/4 byte variants. // TODO: _sz_rfind_2byte_serial, // TODO: _sz_rfind_3byte_serial, // TODO: _sz_rfind_4byte_serial, // To avoid constructing the skip-table, let's use the prefixed approach. // _sz_rfind_over_4bytes_serial, // For longer needles - use skip tables. _sz_rfind_horspool_upto_256bytes_serial, _sz_rfind_horspool_over_256bytes_serial, }; return backends[ // For very short strings brute-force SWAR makes sense. 0 + // To avoid constructing the skip-table, let's use the prefixed approach. (n_length > 1) + // For longer needles - use skip tables. (n_length > 256)](h, h_length, n, n_length); } SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_serial( // sz_cptr_t shorter, sz_size_t shorter_length, // sz_cptr_t longer, sz_size_t longer_length, // sz_size_t bound, sz_memory_allocator_t *alloc) { // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. sz_memory_allocator_t global_alloc; if (!alloc) { sz_memory_allocator_init_default(&global_alloc); alloc = &global_alloc; } // TODO: Generalize to remove the following asserts! sz_assert(!bound && "For bounded search the method should only evaluate one band of the matrix."); sz_assert(shorter_length == longer_length && "The method hasn't been generalized to different length inputs yet."); sz_unused(longer_length && bound); // We are going to store 3 diagonals of the matrix. // The length of the longest (main) diagonal would be `n = (shorter_length + 1)`. sz_size_t n = shorter_length + 1; sz_size_t buffer_length = sizeof(sz_size_t) * n * 3; sz_size_t *distances = (sz_size_t *)alloc->allocate(buffer_length, alloc->handle); if (!distances) return SZ_SIZE_MAX; sz_size_t *previous_distances = distances; sz_size_t *current_distances = previous_distances + n; sz_size_t *next_distances = previous_distances + n * 2; // Initialize the first two diagonals: previous_distances[0] = 0; current_distances[0] = current_distances[1] = 1; // Progress through the upper triangle of the Levenshtein matrix. sz_size_t next_skew_diagonal_index = 2; for (; next_skew_diagonal_index != n; ++next_skew_diagonal_index) { sz_size_t const next_skew_diagonal_length = next_skew_diagonal_index + 1; for (sz_size_t i = 0; i + 2 < next_skew_diagonal_length; ++i) { sz_size_t cost_of_substitution = shorter[next_skew_diagonal_index - i - 2] != longer[i]; sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; next_distances[i + 1] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); } // Don't forget to populate the first row and the first column of the Levenshtein matrix. next_distances[0] = next_distances[next_skew_diagonal_length - 1] = next_skew_diagonal_index; // Perform a circular rotation of those buffers, to reuse the memory. sz_size_t *temporary = previous_distances; previous_distances = current_distances; current_distances = next_distances; next_distances = temporary; } // By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a // larger diagonal. From now onwards, we will be shrinking. Instead of adding value equal to the skewed diagonal // index on either side, we will be cropping those values out. 
sz_size_t total_diagonals = n + n - 1; for (; next_skew_diagonal_index != total_diagonals; ++next_skew_diagonal_index) { sz_size_t const next_skew_diagonal_length = total_diagonals - next_skew_diagonal_index; for (sz_size_t i = 0; i != next_skew_diagonal_length; ++i) { sz_size_t cost_of_substitution = shorter[shorter_length - 1 - i] != longer[next_skew_diagonal_index - n + i]; sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; next_distances[i] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); } // Perform a circular rotation of those buffers, to reuse the memory, this time, with a shift, // dropping the first element in the current array. sz_size_t *temporary = previous_distances; previous_distances = current_distances + 1; current_distances = next_distances; next_distances = temporary; } // Cache scalar before `free` call. sz_size_t result = current_distances[0]; alloc->free(distances, buffer_length, alloc->handle); return result; } /** * @brief Describes the length of a UTF8 character / codepoint / rune in bytes. */ typedef enum { sz_utf8_invalid_k = 0, //!< Invalid UTF8 character. sz_utf8_rune_1byte_k = 1, //!< 1-byte UTF8 character. sz_utf8_rune_2bytes_k = 2, //!< 2-byte UTF8 character. sz_utf8_rune_3bytes_k = 3, //!< 3-byte UTF8 character. sz_utf8_rune_4bytes_k = 4, //!< 4-byte UTF8 character. } sz_rune_length_t; typedef sz_u32_t sz_rune_t; /** * @brief Extracts just one UTF8 codepoint from a UTF8 string into a 32-bit unsigned integer. */ SZ_INTERNAL void _sz_extract_utf8_rune(sz_cptr_t utf8, sz_rune_t *code, sz_rune_length_t *code_length) { sz_u8_t const *current = (sz_u8_t const *)utf8; sz_u8_t leading_byte = *current++; sz_rune_t ch; sz_rune_length_t ch_length; // TODO: This can be made entirely branchless using 32-bit SWAR. if (leading_byte < 0x80) { // Single-byte rune (0xxxxxxx) ch = leading_byte; ch_length = sz_utf8_rune_1byte_k; } else if ((leading_byte & 0xE0) == 0xC0) { // Two-byte rune (110xxxxx 10xxxxxx) ch = (leading_byte & 0x1F) << 6; ch |= (*current++ & 0x3F); ch_length = sz_utf8_rune_2bytes_k; } else if ((leading_byte & 0xF0) == 0xE0) { // Three-byte rune (1110xxxx 10xxxxxx 10xxxxxx) ch = (leading_byte & 0x0F) << 12; ch |= (*current++ & 0x3F) << 6; ch |= (*current++ & 0x3F); ch_length = sz_utf8_rune_3bytes_k; } else if ((leading_byte & 0xF8) == 0xF0) { // Four-byte rune (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) ch = (leading_byte & 0x07) << 18; ch |= (*current++ & 0x3F) << 12; ch |= (*current++ & 0x3F) << 6; ch |= (*current++ & 0x3F); ch_length = sz_utf8_rune_4bytes_k; } else { // Invalid UTF8 rune. ch = 0; ch_length = sz_utf8_invalid_k; } *code = ch; *code_length = ch_length; } /** * @brief Exports a UTF8 string into a UTF32 buffer. * ! The result is undefined if the UTF8 string is corrupted. * @return The length in the number of codepoints. */ SZ_INTERNAL sz_size_t _sz_export_utf8_to_utf32(sz_cptr_t utf8, sz_size_t utf8_length, sz_rune_t *utf32) { sz_cptr_t const end = utf8 + utf8_length; sz_size_t count = 0; sz_rune_length_t rune_length; for (; utf8 != end; utf8 += rune_length, utf32++, count++) _sz_extract_utf8_rune(utf8, utf32, &rune_length); return count; } /** * @brief Compute the Levenshtein distance between two strings using the Wagner-Fisher algorithm.
* Stores only 2 rows of the Levenshtein matrix, but uses 64-bit integers for the distance values, * and upcasts UTF8 variable-length codepoints to 64-bit integers for faster addressing. * * ! In the worst case for 2 strings of length 100, that contain just one 16-bit codepoint this will result in extra: * + 2 rows * 100 slots * 8 bytes/slot = 1600 bytes of memory for the two rows of the Levenshtein matrix rows. * + 100 codepoints * 2 strings * 4 bytes/codepoint = 800 bytes of memory for the UTF8 buffer. * = 2400 bytes of memory or @b 12x memory amplification! */ SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( // sz_cptr_t longer, sz_size_t longer_length, // sz_cptr_t shorter, sz_size_t shorter_length, // sz_size_t bound, sz_bool_t can_be_unicode, sz_memory_allocator_t *alloc) { // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. sz_memory_allocator_t global_alloc; if (!alloc) { sz_memory_allocator_init_default(&global_alloc); alloc = &global_alloc; } // A good idea may be to dispatch different kernels for different string lengths. // Like using `uint8_t` counters for strings under 255 characters long. // Good in theory, this results in frequent upcasts and downcasts in serial code. // On strings over 20 bytes, using `uint8` over `uint64` on 64-bit x86 CPU doubles the execution time. // So one must be very cautious with such optimizations. typedef sz_size_t _distance_t; // Compute the number of columns in our Levenshtein matrix. sz_size_t const n = shorter_length + 1; // If a buffering memory-allocator is provided, this operation is practically free, // and cheaper than allocating even 512 bytes (for small distance matrices) on stack. sz_size_t buffer_length = sizeof(_distance_t) * (n * 2); // If the strings contain Unicode characters, let's estimate the max character width, // and use it to allocate a larger buffer to decode UTF8. if ((can_be_unicode == sz_true_k) && (sz_isascii(longer, longer_length) == sz_false_k || sz_isascii(shorter, shorter_length) == sz_false_k)) { buffer_length += (shorter_length + longer_length) * sizeof(sz_rune_t); } else { can_be_unicode = sz_false_k; } // If the allocation fails, return the maximum distance. sz_ptr_t const buffer = (sz_ptr_t)alloc->allocate(buffer_length, alloc->handle); if (!buffer) return SZ_SIZE_MAX; // Let's export the UTF8 sequence into the newly allocated buffer at the end. if (can_be_unicode == sz_true_k) { sz_rune_t *const longer_utf32 = (sz_rune_t *)(buffer + sizeof(_distance_t) * (n * 2)); sz_rune_t *const shorter_utf32 = longer_utf32 + longer_length; // Export the UTF8 sequences into the newly allocated buffer. longer_length = _sz_export_utf8_to_utf32(longer, longer_length, longer_utf32); shorter_length = _sz_export_utf8_to_utf32(shorter, shorter_length, shorter_utf32); longer = (sz_cptr_t)longer_utf32; shorter = (sz_cptr_t)shorter_utf32; } // Let's parameterize the core logic for different character types and distance types. #define _wagner_fisher_unbounded(_distance_t, _char_t) \ /* Now let's cast our pointer to avoid it in subsequent sections. */ \ _char_t const *const longer_chars = (_char_t const *)longer; \ _char_t const *const shorter_chars = (_char_t const *)shorter; \ _distance_t *previous_distances = (_distance_t *)buffer; \ _distance_t *current_distances = previous_distances + n; \ /* Initialize the first row of the Levenshtein matrix with `iota`-style arithmetic progression. 
*/ \ for (_distance_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = idx_shorter; \ /* The main loop of the algorithm with quadratic complexity. */ \ for (_distance_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { \ _char_t const longer_char = longer_chars[idx_longer]; \ /* Using pure pointer arithmetic is faster than iterating with an index. */ \ _char_t const *shorter_ptr = shorter_chars; \ _distance_t const *previous_ptr = previous_distances; \ _distance_t *current_ptr = current_distances; \ _distance_t *const current_end = current_ptr + shorter_length; \ current_ptr[0] = idx_longer + 1; \ for (; current_ptr != current_end; ++previous_ptr, ++current_ptr, ++shorter_ptr) { \ _distance_t cost_substitution = previous_ptr[0] + (_distance_t)(longer_char != shorter_ptr[0]); \ /* We can avoid `+1` for costs here, shifting it to post-minimum computation, */ \ /* saving one increment operation. */ \ _distance_t cost_deletion = previous_ptr[1]; \ _distance_t cost_insertion = current_ptr[0]; \ /* ? It might be a good idea to enforce branchless execution here. */ \ /* ? The caveat being that the benchmarks on longer sequences backfire and more research is needed. */ \ current_ptr[1] = sz_min_of_two(cost_substitution, sz_min_of_two(cost_deletion, cost_insertion) + 1); \ } \ /* Swap `previous_distances` and `current_distances` pointers. */ \ _distance_t *temporary = previous_distances; \ previous_distances = current_distances; \ current_distances = temporary; \ } \ /* Cache scalar before `free` call. */ \ sz_size_t result = previous_distances[shorter_length]; \ alloc->free(buffer, buffer_length, alloc->handle); \ return result; // Let's define a separate variant for bounded distance computation. // Practically the same as unbounded, but also collecting the running minimum within each row for early exit. 
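// Row minima of the Levenshtein matrix never decrease from one row to the next, so once every cell in a row reaches the `bound`, the final distance can't be smaller and we can return the bound right away.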
#define _wagner_fisher_bounded(_distance_t, _char_t) \ _char_t const *const longer_chars = (_char_t const *)longer; \ _char_t const *const shorter_chars = (_char_t const *)shorter; \ _distance_t *previous_distances = (_distance_t *)buffer; \ _distance_t *current_distances = previous_distances + n; \ for (_distance_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = idx_shorter; \ for (_distance_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { \ _char_t const longer_char = longer_chars[idx_longer]; \ _char_t const *shorter_ptr = shorter_chars; \ _distance_t const *previous_ptr = previous_distances; \ _distance_t *current_ptr = current_distances; \ _distance_t *const current_end = current_ptr + shorter_length; \ current_ptr[0] = idx_longer + 1; \ /* Initialize the running row minimum with the first-column value, so the early exit below stays sound */ \ _distance_t min_distance = current_ptr[0]; \ for (; current_ptr != current_end; ++previous_ptr, ++current_ptr, ++shorter_ptr) { \ _distance_t cost_substitution = previous_ptr[0] + (_distance_t)(longer_char != shorter_ptr[0]); \ _distance_t cost_deletion = previous_ptr[1]; \ _distance_t cost_insertion = current_ptr[0]; \ current_ptr[1] = sz_min_of_two(cost_substitution, sz_min_of_two(cost_deletion, cost_insertion) + 1); \ /* Keep track of the minimum distance seen so far in this row */ \ min_distance = sz_min_of_two(current_ptr[1], min_distance); \ } \ /* If even the smallest value in this row reached the bound, the result can't get below it - return early */ \ if (min_distance >= bound) { \ alloc->free(buffer, buffer_length, alloc->handle); \ return bound; \ } \ _distance_t *temporary = previous_distances; \ previous_distances = current_distances; \ current_distances = temporary; \ } \ sz_size_t result = previous_distances[shorter_length]; \ alloc->free(buffer, buffer_length, alloc->handle); \ return sz_min_of_two(result, bound); // Dispatch the actual computation. if (!bound) { if (can_be_unicode == sz_true_k) { _wagner_fisher_unbounded(sz_size_t, sz_rune_t); } else { _wagner_fisher_unbounded(sz_size_t, sz_u8_t); } } else { if (can_be_unicode == sz_true_k) { _wagner_fisher_bounded(sz_size_t, sz_rune_t); } else { _wagner_fisher_bounded(sz_size_t, sz_u8_t); } } } SZ_PUBLIC sz_size_t sz_edit_distance_serial( // sz_cptr_t longer, sz_size_t longer_length, // sz_cptr_t shorter, sz_size_t shorter_length, // sz_size_t bound, sz_memory_allocator_t *alloc) { // Let's make sure that we use the amount of memory proportional to the // number of elements in the shorter string, not the larger. if (shorter_length > longer_length) { sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); sz_pointer_swap((void **)&longer, (void **)&shorter); } // Skip the matching prefixes and suffixes, they won't affect the distance. for (sz_cptr_t a_end = longer + longer_length, b_end = shorter + shorter_length; longer != a_end && shorter != b_end && *longer == *shorter; ++longer, ++shorter, --longer_length, --shorter_length); for (; longer_length && shorter_length && longer[longer_length - 1] == shorter[shorter_length - 1]; --longer_length, --shorter_length); // Bounded computations may exit early. if (bound) { // If one of the strings is empty - the edit distance is equal to the length of the other one. if (longer_length == 0) return sz_min_of_two(shorter_length, bound); if (shorter_length == 0) return sz_min_of_two(longer_length, bound); // If the difference in length is beyond the `bound`, there is no need to check at all.
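// Every extra character in the longer string costs at least one insertion, so the length difference is a lower bound on the distance.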
if (longer_length - shorter_length > bound) return bound; } if (shorter_length == 0) return longer_length; // If no mismatches were found - the distance is zero. if (shorter_length == longer_length && !bound) return _sz_edit_distance_skewed_diagonals_serial(longer, longer_length, shorter, shorter_length, bound, alloc); return _sz_edit_distance_wagner_fisher_serial(longer, longer_length, shorter, shorter_length, bound, sz_false_k, alloc); } SZ_PUBLIC sz_ssize_t sz_alignment_score_serial( // sz_cptr_t longer, sz_size_t longer_length, // sz_cptr_t shorter, sz_size_t shorter_length, // sz_error_cost_t const *subs, sz_error_cost_t gap, // sz_memory_allocator_t *alloc) { // If one of the strings is empty - the score is the length of the other one multiplied by the gap penalty if (longer_length == 0) return (sz_ssize_t)shorter_length * gap; if (shorter_length == 0) return (sz_ssize_t)longer_length * gap; // Let's make sure that we use the amount of memory proportional to the // number of elements in the shorter string, not the larger. if (shorter_length > longer_length) { sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); sz_pointer_swap((void **)&longer, (void **)&shorter); } // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. sz_memory_allocator_t global_alloc; if (!alloc) { sz_memory_allocator_init_default(&global_alloc); alloc = &global_alloc; } sz_size_t n = shorter_length + 1; sz_size_t buffer_length = sizeof(sz_ssize_t) * n * 2; sz_ssize_t *distances = (sz_ssize_t *)alloc->allocate(buffer_length, alloc->handle); sz_ssize_t *previous_distances = distances; sz_ssize_t *current_distances = previous_distances + n; for (sz_size_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = (sz_ssize_t)idx_shorter * gap; sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter; sz_u8_t const *longer_unsigned = (sz_u8_t const *)longer; for (sz_size_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { current_distances[0] = ((sz_ssize_t)idx_longer + 1) * gap; // Pick the substitution costs row for the current character of the longer string sz_error_cost_t const *a_subs = subs + longer_unsigned[idx_longer] * 256ul; for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) { sz_ssize_t cost_deletion = previous_distances[idx_shorter + 1] + gap; sz_ssize_t cost_insertion = current_distances[idx_shorter] + gap; sz_ssize_t cost_substitution = previous_distances[idx_shorter] + a_subs[shorter_unsigned[idx_shorter]]; current_distances[idx_shorter + 1] = sz_max_of_three(cost_deletion, cost_insertion, cost_substitution); } // Swap previous_distances and current_distances pointers sz_pointer_swap((void **)&previous_distances, (void **)&current_distances); } // Cache scalar before `free` call. sz_ssize_t result = previous_distances[shorter_length]; alloc->free(distances, buffer_length, alloc->handle); return result; } SZ_PUBLIC sz_size_t sz_hamming_distance_serial( // sz_cptr_t a, sz_size_t a_length, // sz_cptr_t b, sz_size_t b_length, // sz_size_t bound) { sz_size_t const min_length = sz_min_of_two(a_length, b_length); sz_size_t const max_length = sz_max_of_two(a_length, b_length); sz_cptr_t const a_end = a + min_length; bound = bound == 0 ? max_length : bound; // Walk through both strings using SWAR and counting the number of differing characters.
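// In the 8-byte loop below, inverting the per-byte equality mask leaves the top bit set in every differing byte, so a single popcount counts the mismatches in that word.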
sz_size_t distance = max_length - min_length; #if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN if (min_length >= SZ_SWAR_THRESHOLD) { sz_u64_vec_t a_vec, b_vec, match_vec; for (; a + 8 <= a_end && distance < bound; a += 8, b += 8) { a_vec.u64 = sz_u64_load(a).u64; b_vec.u64 = sz_u64_load(b).u64; match_vec = _sz_u64_each_byte_equal(a_vec, b_vec); distance += sz_u64_popcount((~match_vec.u64) & 0x8080808080808080ull); } } #endif for (; a != a_end && distance < bound; ++a, ++b) { distance += (*a != *b); } return sz_min_of_two(distance, bound); } SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial( // sz_cptr_t a, sz_size_t a_length, // sz_cptr_t b, sz_size_t b_length, // sz_size_t bound) { sz_cptr_t const a_end = a + a_length; sz_cptr_t const b_end = b + b_length; sz_size_t distance = 0; sz_rune_t a_rune, b_rune; sz_rune_length_t a_rune_length, b_rune_length; if (bound) { for (; a < a_end && b < b_end && distance < bound; a += a_rune_length, b += b_rune_length) { _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); distance += (a_rune != b_rune); } // If one string has more runes, we need to go through the tail. if (distance < bound) { for (; a < a_end && distance < bound; a += a_rune_length, ++distance) _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); for (; b < b_end && distance < bound; b += b_rune_length, ++distance) _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); } } else { for (; a < a_end && b < b_end; a += a_rune_length, b += b_rune_length) { _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); distance += (a_rune != b_rune); } // If one string has more runes, we need to go through the tail. for (; a < a_end; a += a_rune_length, ++distance) _sz_extract_utf8_rune(a, &a_rune, &a_rune_length); for (; b < b_end; b += b_rune_length, ++distance) _sz_extract_utf8_rune(b, &b_rune, &b_rune_length); } return distance; } SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length) { sz_u64_t checksum = 0; sz_u8_t const *text_u8 = (sz_u8_t const *)text; sz_u8_t const *text_end = text_u8 + length; for (; text_u8 != text_end; ++text_u8) checksum += *text_u8; return checksum; } /** * @brief Largest prime number that fits into 31 bits. * @see https://mersenneforum.org/showthread.php?t=3471 */ #define SZ_U32_MAX_PRIME (2147483647u) /** * @brief Largest prime number that fits into 64 bits. * @see https://mersenneforum.org/showthread.php?t=3471 * * 2^64 = 18,446,744,073,709,551,616 * this = 18,446,744,073,709,551,557 * diff = 59 */ #define SZ_U64_MAX_PRIME (18446744073709551557ull) /* * One hardware-accelerated way of mixing hashes can be CRC, but it's only implemented for 32-bit values. * Using a Boost-like mixer works very poorly in such case: * * hash_first ^ (hash_second + 0x517cc1b727220a95 + (hash_first << 6) + (hash_first >> 2)); * * Let's stick to the Fibonacci hash trick using the golden ratio. 
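 *
 * As a worked reading of the constant in the `_sz_hash_mix` macro below: 11400714819323198485 is
 * 0x9E3779B97F4A7C15, roughly 2^64 divided by the golden ratio (~1.618), so the mixer evaluates to
 *
 *      (first * K) ^ (second * K), where K ~ 2^64 / 1.618...
 *
 * spreading consecutive integer inputs roughly uniformly over the 64-bit range before XOR-combining them.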
* https://probablydance.com/2018/06/16/fibonacci-hashing-the-optimization-that-the-world-forgot-or-a-better-alternative-to-integer-modulo/ */ #define _sz_hash_mix(first, second) ((first * 11400714819323198485ull) ^ (second * 11400714819323198485ull)) #define _sz_shift_low(x) (x) #define _sz_shift_high(x) ((x + 77ull) & 0xFFull) #define _sz_prime_mod(x) (x % SZ_U64_MAX_PRIME) SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t start, sz_size_t length) { sz_u64_t hash_low = 0; sz_u64_t hash_high = 0; sz_u8_t const *text = (sz_u8_t const *)start; sz_u8_t const *text_end = text + length; switch (length) { case 0: return 0; // Texts under 7 bytes long are definitely below the largest prime. case 1: hash_low = _sz_shift_low(text[0]); hash_high = _sz_shift_high(text[0]); break; case 2: hash_low = _sz_shift_low(text[0]) * 31ull + _sz_shift_low(text[1]); hash_high = _sz_shift_high(text[0]) * 257ull + _sz_shift_high(text[1]); break; case 3: hash_low = _sz_shift_low(text[0]) * 31ull * 31ull + // _sz_shift_low(text[1]) * 31ull + // _sz_shift_low(text[2]); hash_high = _sz_shift_high(text[0]) * 257ull * 257ull + // _sz_shift_high(text[1]) * 257ull + // _sz_shift_high(text[2]); break; case 4: hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull + // _sz_shift_low(text[1]) * 31ull * 31ull + // _sz_shift_low(text[2]) * 31ull + // _sz_shift_low(text[3]); hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull + // _sz_shift_high(text[1]) * 257ull * 257ull + // _sz_shift_high(text[2]) * 257ull + // _sz_shift_high(text[3]); break; case 5: hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull + // _sz_shift_low(text[1]) * 31ull * 31ull * 31ull + // _sz_shift_low(text[2]) * 31ull * 31ull + // _sz_shift_low(text[3]) * 31ull + // _sz_shift_low(text[4]); hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull + // _sz_shift_high(text[1]) * 257ull * 257ull * 257ull + // _sz_shift_high(text[2]) * 257ull * 257ull + // _sz_shift_high(text[3]) * 257ull + // _sz_shift_high(text[4]); break; case 6: hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull + // _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull + // _sz_shift_low(text[2]) * 31ull * 31ull * 31ull + // _sz_shift_low(text[3]) * 31ull * 31ull + // _sz_shift_low(text[4]) * 31ull + // _sz_shift_low(text[5]); hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull + // _sz_shift_high(text[1]) * 257ull * 257ull * 257ull * 257ull + // _sz_shift_high(text[2]) * 257ull * 257ull * 257ull + // _sz_shift_high(text[3]) * 257ull * 257ull + // _sz_shift_high(text[4]) * 257ull + // _sz_shift_high(text[5]); break; case 7: hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull * 31ull + // _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull * 31ull + // _sz_shift_low(text[2]) * 31ull * 31ull * 31ull * 31ull + // _sz_shift_low(text[3]) * 31ull * 31ull * 31ull + // _sz_shift_low(text[4]) * 31ull * 31ull + // _sz_shift_low(text[5]) * 31ull + // _sz_shift_low(text[6]); hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull * 257ull + // _sz_shift_high(text[1]) * 257ull * 257ull * 257ull * 257ull * 257ull + // _sz_shift_high(text[2]) * 257ull * 257ull * 257ull * 257ull + // _sz_shift_high(text[3]) * 257ull * 257ull * 257ull + // _sz_shift_high(text[4]) * 257ull * 257ull + // _sz_shift_high(text[5]) * 257ull + // _sz_shift_high(text[6]); break; default: // Unroll the first seven cycles: hash_low = hash_low * 31ull + _sz_shift_low(text[0]); hash_high = 
hash_high * 257ull + _sz_shift_high(text[0]); hash_low = hash_low * 31ull + _sz_shift_low(text[1]); hash_high = hash_high * 257ull + _sz_shift_high(text[1]); hash_low = hash_low * 31ull + _sz_shift_low(text[2]); hash_high = hash_high * 257ull + _sz_shift_high(text[2]); hash_low = hash_low * 31ull + _sz_shift_low(text[3]); hash_high = hash_high * 257ull + _sz_shift_high(text[3]); hash_low = hash_low * 31ull + _sz_shift_low(text[4]); hash_high = hash_high * 257ull + _sz_shift_high(text[4]); hash_low = hash_low * 31ull + _sz_shift_low(text[5]); hash_high = hash_high * 257ull + _sz_shift_high(text[5]); hash_low = hash_low * 31ull + _sz_shift_low(text[6]); hash_high = hash_high * 257ull + _sz_shift_high(text[6]); text += 7; // Iterate throw the rest with the modulus: for (; text != text_end; ++text) { hash_low = hash_low * 31ull + _sz_shift_low(text[0]); hash_high = hash_high * 257ull + _sz_shift_high(text[0]); // Wrap the hashes around: hash_low = _sz_prime_mod(hash_low); hash_high = _sz_prime_mod(hash_high); } break; } return _sz_hash_mix(hash_low, hash_high); } SZ_PUBLIC void sz_hashes_serial(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // sz_hash_callback_t callback, void *callback_handle) { if (length < window_length || !window_length) return; sz_u8_t const *text = (sz_u8_t const *)start; sz_u8_t const *text_end = text + length; // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. sz_u64_t prime_power_low = 1, prime_power_high = 1; for (sz_size_t i = 0; i + 1 < window_length; ++i) prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; // Compute the initial hash value for the first window. sz_u64_t hash_low = 0, hash_high = 0, hash_mix; for (sz_u8_t const *first_end = text + window_length; text < first_end; ++text) hash_low = (hash_low * 31ull + _sz_shift_low(*text)) % SZ_U64_MAX_PRIME, hash_high = (hash_high * 257ull + _sz_shift_high(*text)) % SZ_U64_MAX_PRIME; // In most cases the fingerprint length will be a power of two. hash_mix = _sz_hash_mix(hash_low, hash_high); callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); // Compute the hash value for every window, exporting into the fingerprint, // using the expensive modulo operation. sz_size_t cycles = 1; sz_size_t const step_mask = step - 1; for (; text < text_end; ++text, ++cycles) { // Discard one character: hash_low -= _sz_shift_low(*(text - window_length)) * prime_power_low; hash_high -= _sz_shift_high(*(text - window_length)) * prime_power_high; // And add a new one: hash_low = 31ull * hash_low + _sz_shift_low(*text); hash_high = 257ull * hash_high + _sz_shift_high(*text); // Wrap the hashes around: hash_low = _sz_prime_mod(hash_low); hash_high = _sz_prime_mod(hash_high); // Mix only if we've skipped enough hashes. if ((cycles & step_mask) == 0) { hash_mix = _sz_hash_mix(hash_low, hash_high); callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); } } } #undef _sz_shift_low #undef _sz_shift_high #undef _sz_hash_mix #undef _sz_prime_mod /** * @brief Uses a small lookup-table to convert a lowercase character to uppercase. 
*/ SZ_INTERNAL sz_u8_t sz_u8_tolower(sz_u8_t c) { static sz_u8_t const lowered[256] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, // 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, // 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, // 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 91, 92, 93, 94, 95, // 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, // 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, // 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, // 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, // 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, // 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // 240, 241, 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, // 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, // }; return lowered[c]; } /** * @brief Uses a small lookup-table to convert an uppercase character to lowercase. */ SZ_INTERNAL sz_u8_t sz_u8_toupper(sz_u8_t c) { static sz_u8_t const upped[256] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, // 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, // 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, // 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 91, 92, 93, 94, 95, // 96, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, // 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 123, 124, 125, 126, 127, // 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, // 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, // 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, // 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, // 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // 240, 241, 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, // 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, // }; return upped[c]; } /** * @brief Uses two small lookup tables (768 bytes total) to accelerate division by a small * unsigned integer. Performs two lookups, one multiplication, two shifts, and two accumulations. * * @param divisor Integral value @b larger than one. * @param number Integral value to divide. 
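 *
 * A worked example against the tables below: for `number == 200` and `divisor == 10`, the tables give
 * `multipliers[10] == 39322` and `shifts[10] == 3`, so `q == (39322 * 200) >> 16 == 120`,
 * `t == ((200 - 120) >> 1) + 120 == 160`, and the returned quotient is `160 >> 3 == 20`.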
*/ SZ_INTERNAL sz_u8_t sz_u8_divide(sz_u8_t number, sz_u8_t divisor) { sz_assert(divisor > 1); static sz_u16_t const multipliers[256] = { 0, 0, 0, 21846, 0, 39322, 21846, 9363, 0, 50973, 39322, 29790, 21846, 15124, 9363, 4370, 0, 57826, 50973, 44841, 39322, 34329, 29790, 25645, 21846, 18351, 15124, 12137, 9363, 6780, 4370, 2115, 0, 61565, 57826, 54302, 50973, 47824, 44841, 42011, 39322, 36765, 34329, 32006, 29790, 27671, 25645, 23705, 21846, 20063, 18351, 16706, 15124, 13602, 12137, 10725, 9363, 8049, 6780, 5554, 4370, 3224, 2115, 1041, 0, 63520, 61565, 59668, 57826, 56039, 54302, 52614, 50973, 49377, 47824, 46313, 44841, 43407, 42011, 40649, 39322, 38028, 36765, 35532, 34329, 33154, 32006, 30885, 29790, 28719, 27671, 26647, 25645, 24665, 23705, 22766, 21846, 20945, 20063, 19198, 18351, 17520, 16706, 15907, 15124, 14356, 13602, 12863, 12137, 11424, 10725, 10038, 9363, 8700, 8049, 7409, 6780, 6162, 5554, 4957, 4370, 3792, 3224, 2665, 2115, 1573, 1041, 517, 0, 64520, 63520, 62535, 61565, 60609, 59668, 58740, 57826, 56926, 56039, 55164, 54302, 53452, 52614, 51788, 50973, 50169, 49377, 48595, 47824, 47063, 46313, 45572, 44841, 44120, 43407, 42705, 42011, 41326, 40649, 39982, 39322, 38671, 38028, 37392, 36765, 36145, 35532, 34927, 34329, 33738, 33154, 32577, 32006, 31443, 30885, 30334, 29790, 29251, 28719, 28192, 27671, 27156, 26647, 26143, 25645, 25152, 24665, 24182, 23705, 23233, 22766, 22303, 21846, 21393, 20945, 20502, 20063, 19628, 19198, 18772, 18351, 17933, 17520, 17111, 16706, 16305, 15907, 15514, 15124, 14738, 14356, 13977, 13602, 13231, 12863, 12498, 12137, 11779, 11424, 11073, 10725, 10380, 10038, 9699, 9363, 9030, 8700, 8373, 8049, 7727, 7409, 7093, 6780, 6470, 6162, 5857, 5554, 5254, 4957, 4662, 4370, 4080, 3792, 3507, 3224, 2943, 2665, 2388, 2115, 1843, 1573, 1306, 1041, 778, 517, 258, }; // This table can be avoided using a single addition and counting trailing zeros. 
static sz_u8_t const shifts[256] = { 0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, // 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, // 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, // 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, // 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // }; sz_u32_t multiplier = multipliers[divisor]; sz_u8_t shift = shifts[divisor]; sz_u16_t q = (sz_u16_t)((multiplier * number) >> 16); sz_u16_t t = ((number - q) >> 1) + q; return (sz_u8_t)(t >> shift); } SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result) { sz_u8_t const *unsigned_lut = (sz_u8_t const *)lut; sz_u8_t const *unsigned_text = (sz_u8_t const *)text; sz_u8_t *unsigned_result = (sz_u8_t *)result; sz_u8_t const *end = unsigned_text + length; for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = unsigned_lut[*unsigned_text]; } SZ_PUBLIC void sz_tolower_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { sz_u8_t *unsigned_result = (sz_u8_t *)result; sz_u8_t const *unsigned_text = (sz_u8_t const *)text; sz_u8_t const *end = unsigned_text + length; for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = sz_u8_tolower(*unsigned_text); } SZ_PUBLIC void sz_toupper_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { sz_u8_t *unsigned_result = (sz_u8_t *)result; sz_u8_t const *unsigned_text = (sz_u8_t const *)text; sz_u8_t const *end = unsigned_text + length; for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = sz_u8_toupper(*unsigned_text); } SZ_PUBLIC void sz_toascii_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { sz_u8_t *unsigned_result = (sz_u8_t *)result; sz_u8_t const *unsigned_text = (sz_u8_t const *)text; sz_u8_t const *end = unsigned_text + length; for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = *unsigned_text & 0x7F; } /** * @brief Check if there is a byte in this buffer, that exceeds 127 and can't be an ASCII character. * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. */ SZ_PUBLIC sz_bool_t sz_isascii_serial(sz_cptr_t text, sz_size_t length) { if (!length) return sz_true_k; sz_u8_t const *h = (sz_u8_t const *)text; sz_u8_t const *const h_end = h + length; #if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads. for (; ((sz_size_t)h & 7ull) && h < h_end; ++h) if (*h & 0x80ull) return sz_false_k; #endif // Validate eight bytes at once using SWAR. sz_u64_vec_t text_vec; for (; h + 8 <= h_end; h += 8) { text_vec.u64 = *(sz_u64_t const *)h; if (text_vec.u64 & 0x8080808080808080ull) return sz_false_k; } // Handle the misaligned tail. 
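    /* Usage sketch (illustrative): the function reports whether every byte fits into 7 bits, so pure ASCII
     * passes, while any UTF-8 multi-byte sequence fails on its first byte with the 0x80 bit set:
     *
     *      sz_isascii_serial("hello", 5);        // sz_true_k
     *      sz_isascii_serial("h\xC3\xA9llo", 6); // sz_false_k, U+00E9 is encoded as the bytes 0xC3 0xA9
     */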
for (; h < h_end; ++h) if (*h & 0x80ull) return sz_false_k; return sz_true_k; } SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, sz_random_generator_t generator, void *generator_user_data) { sz_assert(alphabet_size > 0 && alphabet_size <= 256 && "Inadequate alphabet size"); if (alphabet_size == 1) sz_fill(result, result_length, *alphabet); else { sz_assert(generator && "Expects a valid random generator"); sz_u8_t divisor = (sz_u8_t)alphabet_size; for (sz_cptr_t end = result + result_length; result != end; ++result) { sz_u8_t random = generator(generator_user_data) & 0xFF; sz_u8_t quotient = sz_u8_divide(random, divisor); *result = alphabet[random - quotient * divisor]; } } } #pragma endregion /* * Serial implementation of string class operations. */ #pragma region Serial Implementation for the String Class SZ_PUBLIC sz_bool_t sz_string_is_on_stack(sz_string_t const *string) { // It doesn't matter if it's on stack or heap, the pointer location is the same. return (sz_bool_t)((sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]); } SZ_PUBLIC void sz_string_range(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length) { sz_size_t is_small = (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]; sz_size_t is_big_mask = is_small - 1ull; *start = string->external.start; // It doesn't matter if it's on stack or heap, the pointer location is the same. // If the string is small, use branch-less approach to mask-out the top 7 bytes of the length. *length = string->external.length & (0x00000000000000FFull | is_big_mask); } SZ_PUBLIC void sz_string_unpack(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space, sz_bool_t *is_external) { sz_size_t is_small = (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]; sz_size_t is_big_mask = is_small - 1ull; *start = string->external.start; // It doesn't matter if it's on stack or heap, the pointer location is the same. // If the string is small, use branch-less approach to mask-out the top 7 bytes of the length. *length = string->external.length & (0x00000000000000FFull | is_big_mask); // In case the string is small, the `is_small - 1ull` will become 0xFFFFFFFFFFFFFFFFull. *space = sz_u64_blend(SZ_STRING_INTERNAL_SPACE, string->external.space, is_big_mask); *is_external = (sz_bool_t)!is_small; } SZ_PUBLIC sz_bool_t sz_string_equal(sz_string_t const *a, sz_string_t const *b) { // Tempting to say that the external.length is bitwise the same even if it includes // some bytes of the on-stack payload, but we don't at this writing maintain that invariant. // (An on-stack string includes noise bytes in the high-order bits of external.length. So do this // the hard/correct way. #if SZ_USE_MISALIGNED_LOADS // Dealing with StringZilla strings, we know that the `start` pointer always points // to a word at least 8 bytes long. Therefore, we can compare the first 8 bytes at once. #endif // Alternatively, fall back to byte-by-byte comparison. sz_ptr_t a_start, b_start; sz_size_t a_length, b_length; sz_string_range(a, &a_start, &a_length); sz_string_range(b, &b_start, &b_length); return (sz_bool_t)(a_length == b_length && sz_equal(a_start, b_start, b_length)); } SZ_PUBLIC sz_ordering_t sz_string_order(sz_string_t const *a, sz_string_t const *b) { #if SZ_USE_MISALIGNED_LOADS // Dealing with StringZilla strings, we know that the `start` pointer always points // to a word at least 8 bytes long. 
Therefore, we can compare the first 8 bytes at once. #endif // Alternatively, fall back to byte-by-byte comparison. sz_ptr_t a_start, b_start; sz_size_t a_length, b_length; sz_string_range(a, &a_start, &a_length); sz_string_range(b, &b_start, &b_length); return sz_order(a_start, a_length, b_start, b_length); } SZ_PUBLIC void sz_string_init(sz_string_t *string) { sz_assert(string && "String can't be SZ_NULL."); // Only 8 + 1 + 1 need to be initialized. string->internal.start = &string->internal.chars[0]; // But for safety let's initialize the entire structure to zeros. // string->internal.chars[0] = 0; // string->internal.length = 0; string->words[1] = 0; string->words[2] = 0; string->words[3] = 0; } SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, sz_memory_allocator_t *allocator) { sz_size_t space_needed = length + 1; // space for trailing \0 sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); // Initialize the string to zeros for safety. string->words[1] = 0; string->words[2] = 0; string->words[3] = 0; // If we are lucky, no memory allocations will be needed. if (space_needed <= SZ_STRING_INTERNAL_SPACE) { string->internal.start = &string->internal.chars[0]; string->internal.length = (sz_u8_t)length; } else { // If we are not lucky, we need to allocate memory. string->external.start = (sz_ptr_t)allocator->allocate(space_needed, allocator->handle); if (!string->external.start) return SZ_NULL_CHAR; string->external.length = length; string->external.space = space_needed; } sz_assert(&string->internal.start == &string->external.start && "Alignment confusion"); string->external.start[length] = 0; return string->external.start; } SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity, sz_memory_allocator_t *allocator) { sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); sz_size_t new_space = new_capacity + 1; if (new_space <= SZ_STRING_INTERNAL_SPACE) return string->external.start; sz_ptr_t string_start; sz_size_t string_length; sz_size_t string_space; sz_bool_t string_is_external; sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); sz_assert(new_space > string_space && "New space must be larger than current."); sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle); if (!new_start) return SZ_NULL_CHAR; sz_copy(new_start, string_start, string_length); string->external.start = new_start; string->external.space = new_space; string->external.padding = 0; string->external.length = string_length; // Deallocate the old string. if (string_is_external) allocator->free(string_start, string_space, allocator->handle); return string->external.start; } SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_allocator_t *allocator) { sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL."); sz_ptr_t string_start; sz_size_t string_length; sz_size_t string_space; sz_bool_t string_is_external; sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); // We may already be space-optimal, and in that case we don't need to do anything. 
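    /* A worked reading of the branch-less unpacking performed above (illustrative, following the
     * `sz_string_range` and `sz_string_unpack` definitions earlier in this section): for an on-stack string
     * `is_small == 1`, so `is_big_mask == 0` and the length is masked down to its low byte,
     *
     *      length = external.length & (0x00000000000000FFull | 0x0000000000000000ull); // low 8 bits only
     *
     * while for a heap-allocated string `is_small == 0`, the `is_small - 1ull` subtraction wraps around to
     * all ones, and the full 64-bit length passes through:
     *
     *      length = external.length & (0x00000000000000FFull | 0xFFFFFFFFFFFFFFFFull); // all 64 bits
     */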
sz_size_t new_space = string_length + 1; if (string_space == new_space || !string_is_external) return string->external.start; sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle); if (!new_start) return SZ_NULL_CHAR; sz_copy(new_start, string_start, string_length); string->external.start = new_start; string->external.space = new_space; string->external.padding = 0; string->external.length = string_length; // Deallocate the old string. if (string_is_external) allocator->free(string_start, string_space, allocator->handle); return string->external.start; } SZ_PUBLIC sz_ptr_t sz_string_expand(sz_string_t *string, sz_size_t offset, sz_size_t added_length, sz_memory_allocator_t *allocator) { sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); sz_ptr_t string_start; sz_size_t string_length; sz_size_t string_space; sz_bool_t string_is_external; sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); // The user intended to extend the string. offset = sz_min_of_two(offset, string_length); // If we are lucky, no memory allocations will be needed. if (string_length + added_length < string_space) { sz_move(string_start + offset + added_length, string_start + offset, string_length - offset); string_start[string_length + added_length] = 0; // Even if the string is on the stack, the `+=` won't affect the tail of the string. string->external.length += added_length; } // If we are not lucky, we need to allocate more memory. else { sz_size_t next_planned_size = sz_max_of_two(SZ_CACHE_LINE_WIDTH, string_space * 2ull); sz_size_t min_needed_space = sz_size_bit_ceil(offset + string_length + added_length + 1); sz_size_t new_space = sz_max_of_two(min_needed_space, next_planned_size); string_start = sz_string_reserve(string, new_space - 1, allocator); if (!string_start) return SZ_NULL_CHAR; // Copy into the new buffer. sz_move(string_start + offset + added_length, string_start + offset, string_length - offset); string_start[string_length + added_length] = 0; string->external.length = string_length + added_length; } return string_start; } SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length) { sz_assert(string && "String can't be SZ_NULL."); sz_ptr_t string_start; sz_size_t string_length; sz_size_t string_space; sz_bool_t string_is_external; sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); // Normalize the offset, it can't be larger than the length. offset = sz_min_of_two(offset, string_length); // We shouldn't normalize the length, to avoid overflowing on `offset + length >= string_length`, // if receiving `length == SZ_SIZE_MAX`. After following expression the `length` will contain // exactly the delta between original and final length of this `string`. length = sz_min_of_two(length, string_length - offset); // There are 2 common cases, that wouldn't even require a `memmove`: // 1. Erasing the entire contents of the string. // In that case `length` argument will be equal or greater than `length` member. // 2. Removing the tail of the string with something like `string.pop_back()` in C++. // // In both of those, regardless of the location of the string - stack or heap, // the erasing is as easy as setting the length to the offset. // In every other case, we must `memmove` the tail of the string to the left. 
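    /* A small worked example (illustrative, assuming `string` currently holds "abcdefg" of length 7):
     *
     *      sz_string_erase(&string, 2, 3);           // moves the tail "fg" left: "abcdefg" -> "abfg", returns 3
     *      sz_string_erase(&string, 2, SZ_SIZE_MAX); // clips the length, drops the tail: "abfg" -> "ab", returns 2
     *
     * The first call falls into the `sz_move` branch below; the second one only shrinks the length.
     */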
if (offset + length < string_length) sz_move(string_start + offset, string_start + offset + length, string_length - offset - length); // The `string->external.length = offset` assignment would discard last characters // of the on-the-stack string, but inplace subtraction would work. string->external.length -= length; string_start[string_length - length] = 0; return length; } SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *allocator) { if (!sz_string_is_on_stack(string)) allocator->free(string->external.start, string->external.space, allocator->handle); sz_string_init(string); } // When overriding libc, disable optimizations for this function because MSVC will optimize the loops into a memset. // Which then causes a stack overflow due to infinite recursion (memset -> sz_fill_serial -> memset). #if defined(_MSC_VER) && defined(SZ_OVERRIDE_LIBC) && SZ_OVERRIDE_LIBC #pragma optimize("", off) #endif SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value) { // Dealing with short strings, a single sequential pass would be faster. // If the size is larger than 2 words, then at least 1 of them will be aligned. // But just one aligned word may not be worth SWAR. if (length < SZ_SWAR_THRESHOLD) while (length--) *(target++) = value; // In case of long strings, skip unaligned bytes, and then fill the rest in 64-bit chunks. else { sz_u64_t value64 = (sz_u64_t)value * 0x0101010101010101ull; while ((sz_size_t)target & 7ull) *(target++) = value, length--; while (length >= 8) *(sz_u64_t *)target = value64, target += 8, length -= 8; while (length--) *(target++) = value; } } #if defined(_MSC_VER) && defined(SZ_OVERRIDE_LIBC) && SZ_OVERRIDE_LIBC #pragma optimize("", on) #endif SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { // The most typical implementation of `memcpy` suffers from Undefined Behavior: // // for (char const *end = source + length; source < end; source++) *target++ = *source; // // As NULL pointer arithmetic is undefined for calls like `memcpy(NULL, NULL, 0)`. // That's mitigated in C2y with the N3322 proposal, but our solution uses a design, that has no such issues. // https://developers.redhat.com/articles/2024/12/11/making-memcpynull-null-0-well-defined #if SZ_USE_MISALIGNED_LOADS while (length >= 8) *(sz_u64_t *)target = *(sz_u64_t const *)source, target += 8, source += 8, length -= 8; #endif while (length--) *(target++) = *(source++); } SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { // Implementing `memmove` is trickier, than `memcpy`, as the ranges may overlap. // Existing implementations often have two passes, in normal and reversed order, // depending on the relation of `target` and `source` addresses. // https://student.cs.uwaterloo.ca/~cs350/common/os161-src-html/doxygen/html/memmove_8c_source.html // https://marmota.medium.com/c-language-making-memmove-def8792bb8d5 // // We can use the `memcpy` like left-to-right pass if we know that the `target` is before `source`. // Or if we know that they don't intersect! In that case the traversal order is irrelevant, // but older CPUs may predict and fetch forward-passes better. if (target < source || target >= source + length) { #if SZ_USE_MISALIGNED_LOADS while (length >= 8) *(sz_u64_t *)target = *(sz_u64_t const *)(source), target += 8, source += 8, length -= 8; #endif while (length--) *(target++) = *(source++); } else { // Jump to the end and walk backwards. 
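        /* Why the reversed pass matters (a worked illustration): moving 4 bytes from the start of "abcdef"
         * to offset 2 should produce "ababcd", but a left-to-right copy would re-read bytes it has already
         * overwritten and yield "ababab" instead. Walking from the end reads every source byte before the
         * copy can clobber it. */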
target += length, source += length; #if SZ_USE_MISALIGNED_LOADS while (length >= 8) *(sz_u64_t *)(target -= 8) = *(sz_u64_t const *)(source -= 8), length -= 8; #endif while (length--) *(--target) = *(--source); } } #pragma endregion /* * @brief Serial implementation for strings sequence processing. */ #pragma region Serial Implementation for Sequences SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate) { sz_size_t matches = 0; while (matches != sequence->count && predicate(sequence, sequence->order[matches])) ++matches; for (sz_size_t i = matches + 1; i < sequence->count; ++i) if (predicate(sequence, sequence->order[i])) sz_u64_swap(sequence->order + i, sequence->order + matches), ++matches; return matches; } SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less) { sz_size_t start_b = partition + 1; // If the direct merge is already sorted if (!less(sequence, sequence->order[start_b], sequence->order[partition])) return; sz_size_t start_a = 0; while (start_a <= partition && start_b <= sequence->count) { // If element 1 is in right place if (!less(sequence, sequence->order[start_b], sequence->order[start_a])) { start_a++; } else { sz_size_t value = sequence->order[start_b]; sz_size_t index = start_b; // Shift all the elements between element 1 // element 2, right by 1. while (index != start_a) { sequence->order[index] = sequence->order[index - 1], index--; } sequence->order[start_a] = value; // Update all the pointers start_a++; partition++; start_b++; } } } SZ_PUBLIC void sz_sort_insertion(sz_sequence_t *sequence, sz_sequence_comparator_t less) { sz_u64_t *keys = sequence->order; sz_size_t keys_count = sequence->count; for (sz_size_t i = 1; i < keys_count; i++) { sz_u64_t i_key = keys[i]; sz_size_t j = i; for (; j > 0 && less(sequence, i_key, keys[j - 1]); --j) keys[j] = keys[j - 1]; keys[j] = i_key; } } SZ_INTERNAL void _sz_sift_down(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t start, sz_size_t end) { sz_size_t root = start; while (2 * root + 1 <= end) { sz_size_t child = 2 * root + 1; if (child + 1 <= end && less(sequence, order[child], order[child + 1])) { child++; } if (!less(sequence, order[root], order[child])) { return; } sz_u64_swap(order + root, order + child); root = child; } } SZ_INTERNAL void _sz_heapify(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t count) { sz_size_t start = (count - 2) / 2; while (1) { _sz_sift_down(sequence, less, order, start, count - 1); if (start == 0) return; start--; } } SZ_INTERNAL void _sz_heapsort(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, sz_size_t last) { sz_u64_t *order = sequence->order; sz_size_t count = last - first; _sz_heapify(sequence, less, order + first, count); sz_size_t end = count - 1; while (end > 0) { sz_u64_swap(order + first, order + first + end); end--; _sz_sift_down(sequence, less, order + first, 0, end); } } SZ_PUBLIC void sz_sort_introsort_recursion(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, sz_size_t last, sz_size_t depth) { sz_size_t length = last - first; switch (length) { case 0: case 1: return; case 2: if (less(sequence, sequence->order[first + 1], sequence->order[first])) sz_u64_swap(&sequence->order[first], &sequence->order[first + 1]); return; case 3: { sz_u64_t a = sequence->order[first]; sz_u64_t b = sequence->order[first + 1]; sz_u64_t c = sequence->order[first + 2]; if (less(sequence, b, a)) 
sz_u64_swap(&a, &b); if (less(sequence, c, b)) sz_u64_swap(&c, &b); if (less(sequence, b, a)) sz_u64_swap(&a, &b); sequence->order[first] = a; sequence->order[first + 1] = b; sequence->order[first + 2] = c; return; } } // Until a certain length, the quadratic-complexity insertion-sort is fine if (length <= 16) { sz_sequence_t sub_seq = *sequence; sub_seq.order += first; sub_seq.count = length; sz_sort_insertion(&sub_seq, less); return; } // Fallback to N-logN-complexity heap-sort if (depth == 0) { _sz_heapsort(sequence, less, first, last); return; } --depth; // Median-of-three logic to choose pivot sz_size_t median = first + length / 2; if (less(sequence, sequence->order[median], sequence->order[first])) sz_u64_swap(&sequence->order[first], &sequence->order[median]); if (less(sequence, sequence->order[last - 1], sequence->order[first])) sz_u64_swap(&sequence->order[first], &sequence->order[last - 1]); if (less(sequence, sequence->order[median], sequence->order[last - 1])) sz_u64_swap(&sequence->order[median], &sequence->order[last - 1]); // Partition using the median-of-three as the pivot sz_u64_t pivot = sequence->order[median]; sz_size_t left = first; sz_size_t right = last - 1; while (1) { while (less(sequence, sequence->order[left], pivot)) left++; while (less(sequence, pivot, sequence->order[right])) right--; if (left >= right) break; sz_u64_swap(&sequence->order[left], &sequence->order[right]); left++; right--; } // Recursively sort the partitions sz_sort_introsort_recursion(sequence, less, first, left, depth); sz_sort_introsort_recursion(sequence, less, right + 1, last, depth); } SZ_PUBLIC void sz_sort_introsort(sz_sequence_t *sequence, sz_sequence_comparator_t less) { if (sequence->count == 0) return; sz_size_t size_is_not_power_of_two = (sequence->count & (sequence->count - 1)) != 0; sz_size_t depth_limit = sz_size_log2i_nonzero(sequence->count) + size_is_not_power_of_two; sz_sort_introsort_recursion(sequence, less, 0, sequence->count, depth_limit); } SZ_PUBLIC void sz_sort_recursion( // sz_sequence_t *sequence, sz_size_t bit_idx, sz_size_t bit_max, sz_sequence_comparator_t comparator, sz_size_t partial_order_length) { if (!sequence->count) return; // Array of size one doesn't need sorting - only needs the prefix to be discarded. if (sequence->count == 1) { sz_u32_t *order_half_words = (sz_u32_t *)sequence->order; order_half_words[1] = 0; return; } // Partition a range of integers according to a specific bit value sz_size_t split = 0; sz_u64_t mask = (1ull << 63) >> bit_idx; // The clean approach would be to perform a single pass over the sequence. // // while (split != sequence->count && !(sequence->order[split] & mask)) ++split; // for (sz_size_t i = split + 1; i < sequence->count; ++i) // if (!(sequence->order[i] & mask)) sz_u64_swap(sequence->order + i, sequence->order + split), ++split; // // This, however, doesn't take into account the high relative cost of writes and swaps. // To circumvent that, we can first count the total number entries to be mapped into either part. // And then walk through both parts, swapping the entries that are in the wrong part. // This would often lead to ~15% performance gain. sz_size_t count_with_bit_set = 0; for (sz_size_t i = 0; i != sequence->count; ++i) count_with_bit_set += (sequence->order[i] & mask) != 0; split = sequence->count - count_with_bit_set; // It's possible that the sequence is already partitioned. if (split != 0 && split != sequence->count) { // Use two pointers to efficiently reposition elements. 
// On pointer walks left-to-right from the start, and the other walks right-to-left from the end. sz_size_t left = 0; sz_size_t right = sequence->count - 1; while (1) { // Find the next element with the bit set on the left side. while (left < split && !(sequence->order[left] & mask)) ++left; // Find the next element without the bit set on the right side. while (right >= split && (sequence->order[right] & mask)) --right; // Swap the mispositioned elements. if (left < split && right >= split) { sz_u64_swap(sequence->order + left, sequence->order + right); ++left; --right; } else { break; } } } // Go down recursively. if (bit_idx < bit_max) { sz_sequence_t a = *sequence; a.count = split; sz_sort_recursion(&a, bit_idx + 1, bit_max, comparator, partial_order_length); sz_sequence_t b = *sequence; b.order += split; b.count -= split; sz_sort_recursion(&b, bit_idx + 1, bit_max, comparator, partial_order_length); } // Reached the end of recursion. else { // Discard the prefixes. sz_u32_t *order_half_words = (sz_u32_t *)sequence->order; for (sz_size_t i = 0; i != sequence->count; ++i) { order_half_words[i * 2 + 1] = 0; } sz_sequence_t a = *sequence; a.count = split; sz_sort_introsort(&a, comparator); sz_sequence_t b = *sequence; b.order += split; b.count -= split; sz_sort_introsort(&b, comparator); } } SZ_INTERNAL sz_bool_t _sz_sort_is_less(sz_sequence_t *sequence, sz_size_t i_key, sz_size_t j_key) { sz_cptr_t i_str = sequence->get_start(sequence, i_key); sz_cptr_t j_str = sequence->get_start(sequence, j_key); sz_size_t i_len = sequence->get_length(sequence, i_key); sz_size_t j_len = sequence->get_length(sequence, j_key); return (sz_bool_t)(sz_order_serial(i_str, i_len, j_str, j_len) == sz_less_k); } SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t partial_order_length) { #if SZ_DETECT_BIG_ENDIAN // TODO: Implement partial sort for big-endian systems. For now this sorts the whole thing. sz_unused(partial_order_length); sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less); #else // Export up to 4 bytes into the `sequence` bits themselves for (sz_size_t i = 0; i != sequence->count; ++i) { sz_cptr_t begin = sequence->get_start(sequence, sequence->order[i]); sz_size_t length = sequence->get_length(sequence, sequence->order[i]); length = length > 4u ? 4u : length; sz_ptr_t prefix = (sz_ptr_t)&sequence->order[i]; for (sz_size_t j = 0; j != length; ++j) prefix[7 - j] = begin[j]; } // Perform optionally-parallel radix sort on them sz_sort_recursion(sequence, 0, 32, (sz_sequence_comparator_t)_sz_sort_is_less, partial_order_length); #endif } SZ_PUBLIC void sz_sort(sz_sequence_t *sequence) { #if SZ_DETECT_BIG_ENDIAN sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less); #else sz_sort_partial(sequence, sequence->count); #endif } #pragma endregion /* * @brief AVX2 implementation of the string search algorithms. * Very minimalistic, but still faster than the serial implementation. */ #pragma region AVX2 Implementation #if SZ_USE_X86_AVX2 #pragma GCC push_options #pragma GCC target("avx2") #pragma clang attribute push(__attribute__((target("avx2"))), apply_to = function) #include /** * @brief Helper structure to simplify work with 256-bit registers. */ typedef union sz_u256_vec_t { __m256i ymm; __m128i xmms[2]; sz_u64_t u64s[4]; sz_u32_t u32s[8]; sz_u16_t u16s[16]; sz_u8_t u8s[32]; } sz_u256_vec_t; SZ_PUBLIC sz_ordering_t sz_order_avx2(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { //! 
Before optimizing this, read the "Operations Not Worth Optimizing" in Contributions Guide: //! https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations return sz_order_serial(a, a_length, b, b_length); } SZ_PUBLIC sz_bool_t sz_equal_avx2(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { sz_u256_vec_t a_vec, b_vec; while (length >= 32) { a_vec.ymm = _mm256_lddqu_si256((__m256i const *)a); b_vec.ymm = _mm256_lddqu_si256((__m256i const *)b); // One approach can be to use "movemasks", but we could also use a bitwise matching like `_mm256_testnzc_si256`. int difference_mask = ~_mm256_movemask_epi8(_mm256_cmpeq_epi8(a_vec.ymm, b_vec.ymm)); if (difference_mask == 0) { a += 32, b += 32, length -= 32; } else { return sz_false_k; } } if (length) return sz_equal_serial(a, b, length); return sz_true_k; } SZ_PUBLIC void sz_fill_avx2(sz_ptr_t target, sz_size_t length, sz_u8_t value) { char value_char = *(char *)&value; __m256i value_vec = _mm256_set1_epi8(value_char); // The naive implementation of this function is very simple. // It assumes the CPU is great at handling unaligned "stores". // // for (; length >= 32; target += 32, length -= 32) _mm256_storeu_si256(target, value_vec); // sz_fill_serial(target, length, value); // // When the buffer is small, there isn't much to innovate. if (length <= 32) sz_fill_serial(target, length, value); // When the buffer is aligned, we can avoid any split-stores. else { sz_size_t head_length = (32 - ((sz_size_t)target % 32)) % 32; // 31 or less. sz_size_t tail_length = (sz_size_t)(target + length) % 32; // 31 or less. sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. sz_u16_t value16 = (sz_u16_t)value * 0x0101u; sz_u32_t value32 = (sz_u32_t)value16 * 0x00010001u; sz_u64_t value64 = (sz_u64_t)value32 * 0x0000000100000001ull; // Fill the head of the buffer. This part is much cleaner with AVX-512. if (head_length & 1) *(sz_u8_t *)target = value, target++, head_length--; if (head_length & 2) *(sz_u16_t *)target = value16, target += 2, head_length -= 2; if (head_length & 4) *(sz_u32_t *)target = value32, target += 4, head_length -= 4; if (head_length & 8) *(sz_u64_t *)target = value64, target += 8, head_length -= 8; if (head_length & 16) _mm_store_si128((__m128i *)target, _mm_set1_epi8(value_char)), target += 16, head_length -= 16; sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); // Fill the aligned body of the buffer. for (; body_length >= 32; target += 32, body_length -= 32) _mm256_store_si256((__m256i *)target, value_vec); // Fill the tail of the buffer. This part is much cleaner with AVX-512. sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); if (tail_length & 16) _mm_store_si128((__m128i *)target, _mm_set1_epi8(value_char)), target += 16, tail_length -= 16; if (tail_length & 8) *(sz_u64_t *)target = value64, target += 8, tail_length -= 8; if (tail_length & 4) *(sz_u32_t *)target = value32, target += 4, tail_length -= 4; if (tail_length & 2) *(sz_u16_t *)target = value16, target += 2, tail_length -= 2; if (tail_length & 1) *(sz_u8_t *)target = value, target++, tail_length--; } } SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { // The naive implementation of this function is very simple. // It assumes the CPU is great at handling unaligned "stores" and "loads". 
// // for (; length >= 32; target += 32, source += 32, length -= 32) // _mm256_storeu_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); // sz_copy_serial(target, source, length); // // A typical AWS Skylake instance can have 32 KB x 2 blocks of L1 data cache per core, // 1 MB x 2 blocks of L2 cache per core, and one shared L3 cache buffer. // For now, let's avoid the cases beyond the L2 size. int is_huge = length > 1ull * 1024ull * 1024ull; if (length <= 32) { sz_copy_serial(target, source, length); } // When dealing with larger arrays, the optimization is not as simple as with the `sz_fill_avx2` function, // as both buffers may be unaligned. If we are lucky and the requested operation is some huge page transfer, // we can use aligned loads and stores, and the performance will be great. else if ((sz_size_t)target % 32 == 0 && (sz_size_t)source % 32 == 0 && !is_huge) { for (; length >= 32; target += 32, source += 32, length -= 32) _mm256_store_si256((__m256i *)target, _mm256_load_si256((__m256i const *)source)); if (length) sz_copy_serial(target, source, length); } // The trickiest case is when both `source` and `target` are not aligned. // In such and simpler cases we can copy enough bytes into `target` to reach its cacheline boundary, // and then combine unaligned loads with aligned stores. else { sz_size_t head_length = (32 - ((sz_size_t)target % 32)) % 32; // 31 or less. sz_size_t tail_length = (sz_size_t)(target + length) % 32; // 31 or less. sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. // Fill the head of the buffer. This part is much cleaner with AVX-512. if (head_length & 1) *(sz_u8_t *)target = *(sz_u8_t *)source, target++, source++, head_length--; if (head_length & 2) *(sz_u16_t *)target = *(sz_u16_t *)source, target += 2, source += 2, head_length -= 2; if (head_length & 4) *(sz_u32_t *)target = *(sz_u32_t *)source, target += 4, source += 4, head_length -= 4; if (head_length & 8) *(sz_u64_t *)target = *(sz_u64_t *)source, target += 8, source += 8, head_length -= 8; if (head_length & 16) _mm_store_si128((__m128i *)target, _mm_lddqu_si128((__m128i const *)source)), target += 16, source += 16, head_length -= 16; sz_assert(head_length == 0 && "The head length should be zero after the head copy."); sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); // Fill the aligned body of the buffer. if (!is_huge) { for (; body_length >= 32; target += 32, source += 32, body_length -= 32) _mm256_store_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); } // When the buffer is huge, we can traverse it in 2 directions. else { size_t tails_bytes_skipped = 0; for (; body_length >= 64; target += 32, source += 32, body_length -= 64, tails_bytes_skipped += 32) { _mm256_store_si256((__m256i *)(target), _mm256_lddqu_si256((__m256i const *)(source))); _mm256_store_si256((__m256i *)(target + body_length - 32), _mm256_lddqu_si256((__m256i const *)(source + body_length - 32))); } if (body_length) { sz_assert(body_length == 32 && "The only remaining body length should be 32 bytes."); _mm256_store_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); target += 32, source += 32, body_length -= 32; } target += tails_bytes_skipped; source += tails_bytes_skipped; } // Fill the tail of the buffer. This part is much cleaner with AVX-512. 
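            /* A worked example of the head/body/tail split used in this branch (illustrative): copying 1000
             * bytes to a target address with `target % 32 == 5` gives a head of (32 - 5) % 32 == 27 bytes,
             * peeled off above as 1 + 2 + 8 + 16 bytes; a tail of (5 + 1000) % 32 == 13 bytes, peeled off
             * below as 8 + 4 + 1; and a 960-byte body that is copied with aligned 32-byte stores. */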
sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size."); if (tail_length & 16) _mm_store_si128((__m128i *)target, _mm_lddqu_si128((__m128i const *)source)), target += 16, source += 16, tail_length -= 16; if (tail_length & 8) *(sz_u64_t *)target = *(sz_u64_t *)source, target += 8, source += 8, tail_length -= 8; if (tail_length & 4) *(sz_u32_t *)target = *(sz_u32_t *)source, target += 4, source += 4, tail_length -= 4; if (tail_length & 2) *(sz_u16_t *)target = *(sz_u16_t *)source, target += 2, source += 2, tail_length -= 2; if (tail_length & 1) *(sz_u8_t *)target = *(sz_u8_t *)source, target++, source++, tail_length--; sz_assert(tail_length == 0 && "The tail length should be zero after the tail copy."); } } SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { if (target < source || target >= source + length) { for (; length >= 32; target += 32, source += 32, length -= 32) _mm256_storeu_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); while (length--) *(target++) = *(source++); } else { // Jump to the end and walk backwards. for (target += length, source += length; length >= 32; length -= 32) _mm256_storeu_si256((__m256i *)(target -= 32), _mm256_lddqu_si256((__m256i const *)(source -= 32))); while (length--) *(--target) = *(--source); } } SZ_PUBLIC sz_u64_t sz_checksum_avx2(sz_cptr_t text, sz_size_t length) { // The naive implementation of this function is very simple. // It assumes the CPU is great at handling unaligned "loads". // // A typical AWS Skylake instance can have 32 KB x 2 blocks of L1 data cache per core, // 1 MB x 2 blocks of L2 cache per core, and one shared L3 cache buffer. // For now, let's avoid the cases beyond the L2 size. int is_huge = length > 1ull * 1024ull * 1024ull; // When the buffer is small, there isn't much to innovate. if (length <= 32) { return sz_checksum_serial(text, length); } else if (!is_huge) { sz_u256_vec_t text_vec, sums_vec; sums_vec.ymm = _mm256_setzero_si256(); for (; length >= 32; text += 32, length -= 32) { text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text); sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); } // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); sz_u64_t result = low + high; if (length) result += sz_checksum_serial(text, length); return result; } // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use. // Most notably, we can avoid populating the cache with the entire buffer, and instead traverse it in 2 directions. else { sz_size_t head_length = (32 - ((sz_size_t)text % 32)) % 32; // 31 or less. sz_size_t tail_length = (sz_size_t)(text + length) % 32; // 31 or less. sz_size_t body_length = length - head_length - tail_length; // Multiple of 32. sz_u64_t result = 0; // Handle the head while (head_length--) result += *text++; sz_u256_vec_t text_vec, sums_vec; sums_vec.ymm = _mm256_setzero_si256(); // Fill the aligned body of the buffer. 
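        /* The loops below lean on `_mm256_sad_epu8(text, zero)`: the sum of absolute differences against
         * zero adds each group of 8 unsigned bytes into one of the four 64-bit lanes, e.g. a lane holding
         * the bytes {1, 2, 3, 4, 5, 6, 7, 8} becomes 36. Those per-lane partial sums can then be safely
         * accumulated with `_mm256_add_epi64` without overflowing for any realistic buffer size. */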
if (!is_huge) { for (; body_length >= 32; text += 32, body_length -= 32) { text_vec.ymm = _mm256_stream_load_si256((__m256i const *)text); sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); } } // When the buffer is huge, we can traverse it in 2 directions. else { sz_u256_vec_t text_reversed_vec, sums_reversed_vec; sums_reversed_vec.ymm = _mm256_setzero_si256(); for (; body_length >= 64; text += 64, body_length -= 64) { text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text)); sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); text_reversed_vec.ymm = _mm256_stream_load_si256((__m256i *)(text + body_length - 64)); sums_reversed_vec.ymm = _mm256_add_epi64( sums_reversed_vec.ymm, _mm256_sad_epu8(text_reversed_vec.ymm, _mm256_setzero_si256())); } if (body_length >= 32) { text_vec.ymm = _mm256_stream_load_si256((__m256i *)(text)); sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256())); } sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, sums_reversed_vec.ymm); } // Handle the tail while (tail_length--) result += *text++; // Accumulating 256 bits is harder, as we need to extract the 128-bit sums first. __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm); __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1); __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); result += low + high; return result; } } SZ_PUBLIC void sz_look_up_transform_avx2(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { // If the input is tiny (especially smaller than the look-up table itself), we may end up paying // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. // But if at least 3 cache lines are touched, the AVX-2 implementation should be faster. if (length <= 128) { sz_look_up_transform_serial(source, length, lut, target); return; } // We need to pull the lookup table into 8x YMM registers. // The biggest issue is reorganizing the data in the lookup table, as AVX2 doesn't have 256-bit shuffle, // it only has 128-bit "within-lane" shuffle. Still, it's wiser to use full YMM registers, instead of XMM, // so that we can at least compensate high latency with twice larger window and one more level of lookup. 
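    /* The scalar equivalent of the scheme below (an illustrative sketch): treat the 256-byte LUT as a
     * 16x16 matrix, pick the row by the high nibble and the column by the low nibble of each input byte
     * (taken as unsigned), so that
     *
     *      lut[c] == lut[(c >> 4) * 16 + (c & 0x0F)];
     *
     * The sixteen 16-byte rows are what gets broadcast into the YMM registers below, `_mm256_shuffle_epi8`
     * performs the column lookup by the low nibble, and the tree of `_mm256_blendv_epi8` calls selects the
     * right row using the four bits of the high nibble, one bit per reduction round. */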
sz_u256_vec_t lut_0_to_15_vec, lut_16_to_31_vec, lut_32_to_47_vec, lut_48_to_63_vec, // lut_64_to_79_vec, lut_80_to_95_vec, lut_96_to_111_vec, lut_112_to_127_vec, // lut_128_to_143_vec, lut_144_to_159_vec, lut_160_to_175_vec, lut_176_to_191_vec, // lut_192_to_207_vec, lut_208_to_223_vec, lut_224_to_239_vec, lut_240_to_255_vec; lut_0_to_15_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut))); lut_16_to_31_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 16))); lut_32_to_47_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 32))); lut_48_to_63_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 48))); lut_64_to_79_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 64))); lut_80_to_95_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 80))); lut_96_to_111_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 96))); lut_112_to_127_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 112))); lut_128_to_143_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 128))); lut_144_to_159_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 144))); lut_160_to_175_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 160))); lut_176_to_191_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 176))); lut_192_to_207_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 192))); lut_208_to_223_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 208))); lut_224_to_239_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 224))); lut_240_to_255_vec.ymm = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + 240))); // Assuming each lookup is performed within 16 elements of 256, we need to reduce the scope by 16x = 2^4. sz_u256_vec_t not_first_bit_vec, not_second_bit_vec, not_third_bit_vec, not_fourth_bit_vec; /// Top and bottom nibbles of the source are used separately. sz_u256_vec_t source_vec, source_bot_vec; sz_u256_vec_t blended_0_to_31_vec, blended_32_to_63_vec, blended_64_to_95_vec, blended_96_to_127_vec, blended_128_to_159_vec, blended_160_to_191_vec, blended_192_to_223_vec, blended_224_to_255_vec; // Handling the head. while (length >= 32) { // Load and separate the nibbles of each byte in the source. source_vec.ymm = _mm256_lddqu_si256((__m256i const *)source); source_bot_vec.ymm = _mm256_and_si256(source_vec.ymm, _mm256_set1_epi8((char)0x0F)); // In the first round, we select using the 4th bit. 
not_fourth_bit_vec.ymm = _mm256_cmpeq_epi8( // _mm256_and_si256(_mm256_set1_epi8((char)0x10), source_vec.ymm), _mm256_setzero_si256()); blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // _mm256_shuffle_epi8(lut_16_to_31_vec.ymm, source_bot_vec.ymm), // _mm256_shuffle_epi8(lut_0_to_15_vec.ymm, source_bot_vec.ymm), // not_fourth_bit_vec.ymm); blended_32_to_63_vec.ymm = _mm256_blendv_epi8( // _mm256_shuffle_epi8(lut_48_to_63_vec.ymm, source_bot_vec.ymm), // _mm256_shuffle_epi8(lut_32_to_47_vec.ymm, source_bot_vec.ymm), // not_fourth_bit_vec.ymm); blended_64_to_95_vec.ymm = _mm256_blendv_epi8( // _mm256_shuffle_epi8(lut_80_to_95_vec.ymm, source_bot_vec.ymm), // _mm256_shuffle_epi8(lut_64_to_79_vec.ymm, source_bot_vec.ymm), // not_fourth_bit_vec.ymm); blended_96_to_127_vec.ymm = _mm256_blendv_epi8( // _mm256_shuffle_epi8(lut_112_to_127_vec.ymm, source_bot_vec.ymm), // _mm256_shuffle_epi8(lut_96_to_111_vec.ymm, source_bot_vec.ymm), // not_fourth_bit_vec.ymm); blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // _mm256_shuffle_epi8(lut_144_to_159_vec.ymm, source_bot_vec.ymm), // _mm256_shuffle_epi8(lut_128_to_143_vec.ymm, source_bot_vec.ymm), // not_fourth_bit_vec.ymm); blended_160_to_191_vec.ymm = _mm256_blendv_epi8( // _mm256_shuffle_epi8(lut_176_to_191_vec.ymm, source_bot_vec.ymm), // _mm256_shuffle_epi8(lut_160_to_175_vec.ymm, source_bot_vec.ymm), // not_fourth_bit_vec.ymm); blended_192_to_223_vec.ymm = _mm256_blendv_epi8( // _mm256_shuffle_epi8(lut_208_to_223_vec.ymm, source_bot_vec.ymm), // _mm256_shuffle_epi8(lut_192_to_207_vec.ymm, source_bot_vec.ymm), // not_fourth_bit_vec.ymm); blended_224_to_255_vec.ymm = _mm256_blendv_epi8( // _mm256_shuffle_epi8(lut_240_to_255_vec.ymm, source_bot_vec.ymm), // _mm256_shuffle_epi8(lut_224_to_239_vec.ymm, source_bot_vec.ymm), // not_fourth_bit_vec.ymm); // Perform a tree-like reduction of the 8x "blended" YMM registers, depending on the "source" content. // The first round selects using the 3rd bit. not_third_bit_vec.ymm = _mm256_cmpeq_epi8( // _mm256_and_si256(_mm256_set1_epi8((char)0x20), source_vec.ymm), _mm256_setzero_si256()); blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // blended_32_to_63_vec.ymm, // blended_0_to_31_vec.ymm, // not_third_bit_vec.ymm); blended_64_to_95_vec.ymm = _mm256_blendv_epi8( // blended_96_to_127_vec.ymm, // blended_64_to_95_vec.ymm, // not_third_bit_vec.ymm); blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // blended_160_to_191_vec.ymm, // blended_128_to_159_vec.ymm, // not_third_bit_vec.ymm); blended_192_to_223_vec.ymm = _mm256_blendv_epi8( // blended_224_to_255_vec.ymm, // blended_192_to_223_vec.ymm, // not_third_bit_vec.ymm); // The second round selects using the 2nd bit. not_second_bit_vec.ymm = _mm256_cmpeq_epi8( // _mm256_and_si256(_mm256_set1_epi8((char)0x40), source_vec.ymm), _mm256_setzero_si256()); blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // blended_64_to_95_vec.ymm, // blended_0_to_31_vec.ymm, // not_second_bit_vec.ymm); blended_128_to_159_vec.ymm = _mm256_blendv_epi8( // blended_192_to_223_vec.ymm, // blended_128_to_159_vec.ymm, // not_second_bit_vec.ymm); // The third round selects using the 1st bit. not_first_bit_vec.ymm = _mm256_cmpeq_epi8( // _mm256_and_si256(_mm256_set1_epi8((char)0x80), source_vec.ymm), _mm256_setzero_si256()); blended_0_to_31_vec.ymm = _mm256_blendv_epi8( // blended_128_to_159_vec.ymm, // blended_0_to_31_vec.ymm, // not_first_bit_vec.ymm); // And dump the result into the target. 
_mm256_storeu_si256((__m256i *)target, blended_0_to_31_vec.ymm); source += 32, target += 32, length -= 32; } // Handle the tail. if (length) sz_look_up_transform_serial(source, length, lut, target); } SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { int mask; sz_u256_vec_t h_vec, n_vec; n_vec.ymm = _mm256_set1_epi8(n[0]); while (h_length >= 32) { h_vec.ymm = _mm256_lddqu_si256((__m256i const *)h); mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_vec.ymm, n_vec.ymm)); if (mask) return h + sz_u32_ctz(mask); h += 32, h_length -= 32; } return sz_find_byte_serial(h, h_length, n); } SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { int mask; sz_u256_vec_t h_vec, n_vec; n_vec.ymm = _mm256_set1_epi8(n[0]); while (h_length >= 32) { h_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + h_length - 32)); mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_vec.ymm, n_vec.ymm)); if (mask) return h + h_length - 1 - sz_u32_clz(mask); h_length -= 32; } return sz_rfind_byte_serial(h, h_length, n); } SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { // This almost never fires, but it's better to be safe than sorry. if (h_length < n_length || !n_length) return SZ_NULL_CHAR; if (n_length == 1) return sz_find_byte_avx2(h, h_length, n); // Pick the parts of the needle that are worth comparing. sz_size_t offset_first, offset_mid, offset_last; _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); // Broadcast those characters into YMM registers. int matches; sz_u256_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; n_first_vec.ymm = _mm256_set1_epi8(n[offset_first]); n_mid_vec.ymm = _mm256_set1_epi8(n[offset_mid]); n_last_vec.ymm = _mm256_set1_epi8(n[offset_last]); // Scan through the string. for (; h_length >= n_length + 32; h += 32, h_length -= 32) { h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_first)); h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_mid)); h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_last)); matches = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); while (matches) { int potential_offset = sz_u32_ctz(matches); if (sz_equal(h + potential_offset, n, n_length)) return h + potential_offset; matches &= matches - 1; } } return sz_find_serial(h, h_length, n, n_length); } SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { // This almost never fires, but it's better to be safe than sorry. if (h_length < n_length || !n_length) return SZ_NULL_CHAR; if (n_length == 1) return sz_rfind_byte_avx2(h, h_length, n); // Pick the parts of the needle that are worth comparing. sz_size_t offset_first, offset_mid, offset_last; _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); // Broadcast those characters into YMM registers. int matches; sz_u256_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; n_first_vec.ymm = _mm256_set1_epi8(n[offset_first]); n_mid_vec.ymm = _mm256_set1_epi8(n[offset_mid]); n_last_vec.ymm = _mm256_set1_epi8(n[offset_last]); // Scan through the string. 
sz_cptr_t h_reversed; for (; h_length >= n_length + 32; h_length -= 32) { h_reversed = h + h_length - n_length - 32 + 1; h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_first)); h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_mid)); h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_last)); matches = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); while (matches) { int potential_offset = sz_u32_clz(matches); if (sz_equal(h + h_length - n_length - potential_offset, n, n_length)) return h + h_length - n_length - potential_offset; matches &= ~(1 << (31 - potential_offset)); } } return sz_rfind_serial(h, h_length, n, n_length); } SZ_PUBLIC sz_cptr_t sz_find_charset_avx2(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { // Let's unzip even and odd elements and replicate them into both lanes of the YMM register. // That way when we invoke `_mm256_shuffle_epi8` we can use the same mask for both lanes. sz_u256_vec_t filter_even_vec, filter_odd_vec; for (sz_size_t i = 0; i != 16; ++i) filter_even_vec.u8s[i] = filter->_u8s[i * 2], filter_odd_vec.u8s[i] = filter->_u8s[i * 2 + 1]; filter_even_vec.xmms[1] = filter_even_vec.xmms[0]; filter_odd_vec.xmms[1] = filter_odd_vec.xmms[0]; sz_u256_vec_t text_vec; sz_u256_vec_t matches_vec; sz_u256_vec_t lower_nibbles_vec, higher_nibbles_vec; sz_u256_vec_t bitset_even_vec, bitset_odd_vec; sz_u256_vec_t bitmask_vec, bitmask_lookup_vec; bitmask_lookup_vec.ymm = _mm256_set_epi8(-128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1); while (length >= 32) { // The following algorithm is a transposed equivalent of the "SIMDized check which bytes are in a set" // solutions by Wojciech Muła. We populate the bitmask differently and target newer CPUs, so // StrinZilla uses a somewhat different approach. // http://0x80.pl/articles/simd-byte-lookup.html#alternative-implementation-new // // sz_u8_t input = *(sz_u8_t const *)text; // sz_u8_t lo_nibble = input & 0x0f; // sz_u8_t hi_nibble = input >> 4; // sz_u8_t bitset_even = filter_even_vec.u8s[hi_nibble]; // sz_u8_t bitset_odd = filter_odd_vec.u8s[hi_nibble]; // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); // sz_u8_t bitset = lo_nibble < 8 ? bitset_even : bitset_odd; // if ((bitset & bitmask) != 0) return text; // else { length--, text++; } // // The nice part about this, loading the strided data is vey easy with Arm NEON, // while with x86 CPUs after AVX, shuffles within 256 bits shouldn't be an issue either. text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text); lower_nibbles_vec.ymm = _mm256_and_si256(text_vec.ymm, _mm256_set1_epi8(0x0f)); bitmask_vec.ymm = _mm256_shuffle_epi8(bitmask_lookup_vec.ymm, lower_nibbles_vec.ymm); // // At this point we can validate the `bitmask_vec` contents like this: // // for (sz_size_t i = 0; i != 32; ++i) { // sz_u8_t input = *(sz_u8_t const *)(text + i); // sz_u8_t lo_nibble = input & 0x0f; // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); // sz_assert(bitmask_vec.u8s[i] == bitmask); // } // // Shift right every byte by 4 bits. // There is no `_mm256_srli_epi8` intrinsic, so we have to use `_mm256_srli_epi16` // and combine it with a mask to clear the higher bits. 
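    // In scalar terms, the expression below computes `(text[i] >> 4) & 0x0F` for every byte,
    // i.e. the higher nibble of each character.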
higher_nibbles_vec.ymm = _mm256_and_si256(_mm256_srli_epi16(text_vec.ymm, 4), _mm256_set1_epi8(0x0f)); bitset_even_vec.ymm = _mm256_shuffle_epi8(filter_even_vec.ymm, higher_nibbles_vec.ymm); bitset_odd_vec.ymm = _mm256_shuffle_epi8(filter_odd_vec.ymm, higher_nibbles_vec.ymm); // // At this point we can validate the `bitset_even_vec` and `bitset_odd_vec` contents like this: // // for (sz_size_t i = 0; i != 32; ++i) { // sz_u8_t input = *(sz_u8_t const *)(text + i); // sz_u8_t const *bitset_ptr = &filter->_u8s[0]; // sz_u8_t hi_nibble = input >> 4; // sz_u8_t bitset_even = bitset_ptr[hi_nibble * 2]; // sz_u8_t bitset_odd = bitset_ptr[hi_nibble * 2 + 1]; // sz_assert(bitset_even_vec.u8s[i] == bitset_even); // sz_assert(bitset_odd_vec.u8s[i] == bitset_odd); // } // __m256i take_first = _mm256_cmpgt_epi8(_mm256_set1_epi8(8), lower_nibbles_vec.ymm); bitset_even_vec.ymm = _mm256_blendv_epi8(bitset_odd_vec.ymm, bitset_even_vec.ymm, take_first); // It would have been great to have an instruction that tests the bits and then broadcasts // the matching bit into all bits in that byte. But we don't have that, so we have to // `and`, `cmpeq`, `movemask`, and then invert at the end... matches_vec.ymm = _mm256_and_si256(bitset_even_vec.ymm, bitmask_vec.ymm); matches_vec.ymm = _mm256_cmpeq_epi8(matches_vec.ymm, _mm256_setzero_si256()); int matches_mask = ~_mm256_movemask_epi8(matches_vec.ymm); if (matches_mask) { int offset = sz_u32_ctz(matches_mask); return text + offset; } else { text += 32, length -= 32; } } return sz_find_charset_serial(text, length, filter); } SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx2(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { return sz_rfind_charset_serial(text, length, filter); } /** * @brief There is no AVX2 instruction for fast multiplication of 64-bit integers. * This implementation is coming from Agner Fog's Vector Class Library. */ SZ_INTERNAL __m256i _mm256_mul_epu64(__m256i a, __m256i b) { __m256i bswap = _mm256_shuffle_epi32(b, 0xB1); __m256i prodlh = _mm256_mullo_epi32(a, bswap); __m256i zero = _mm256_setzero_si256(); __m256i prodlh2 = _mm256_hadd_epi32(prodlh, zero); __m256i prodlh3 = _mm256_shuffle_epi32(prodlh2, 0x73); __m256i prodll = _mm256_mul_epu32(a, b); __m256i prod = _mm256_add_epi64(prodll, prodlh3); return prod; } SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // sz_hash_callback_t callback, void *callback_handle) { if (length < window_length || !window_length) return; if (length < 4 * window_length) { sz_hashes_serial(start, length, window_length, step, callback, callback_handle); return; } // Using AVX2, we can perform 4 long integer multiplications and additions within one register. // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel. sz_size_t const max_hashes = length - window_length + 1; sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads. sz_u8_t const *text_first = (sz_u8_t const *)start; sz_u8_t const *text_second = text_first + min_hashes_per_thread; sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2; sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3; sz_u8_t const *text_end = text_first + length; // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. 
sz_u64_t prime_power_low = 1, prime_power_high = 1; for (sz_size_t i = 0; i + 1 < window_length; ++i) prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; // Broadcast the constants into the registers. sz_u256_vec_t prime_vec, golden_ratio_vec; sz_u256_vec_t base_low_vec, base_high_vec, prime_power_low_vec, prime_power_high_vec, shift_high_vec; base_low_vec.ymm = _mm256_set1_epi64x(31ull); base_high_vec.ymm = _mm256_set1_epi64x(257ull); shift_high_vec.ymm = _mm256_set1_epi64x(77ull); prime_vec.ymm = _mm256_set1_epi64x(SZ_U64_MAX_PRIME); golden_ratio_vec.ymm = _mm256_set1_epi64x(11400714819323198485ull); prime_power_low_vec.ymm = _mm256_set1_epi64x(prime_power_low); prime_power_high_vec.ymm = _mm256_set1_epi64x(prime_power_high); // Compute the initial hash values for every one of the four windows. sz_u256_vec_t hash_low_vec, hash_high_vec, hash_mix_vec, chars_low_vec, chars_high_vec; hash_low_vec.ymm = _mm256_setzero_si256(); hash_high_vec.ymm = _mm256_setzero_si256(); for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end; ++text_first, ++text_second, ++text_third, ++text_fourth) { // 1. Multiply the hashes by the base. hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); // 3. Add the incoming characters. hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); // 4. Compute the modulo. Assuming there are only 59 values between our prime // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. hash_low_vec.ymm = _mm256_blendv_epi8(hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); hash_high_vec.ymm = _mm256_blendv_epi8(hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); } // 5. Compute the hash mix, that will be used to index into the fingerprint. // This includes a serial step at the end. hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); // Now repeat that operation for the remaining characters, discarding older characters. sz_size_t cycle = 1; sz_size_t const step_mask = step - 1; for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) { // 0. Load again the four characters we are dropping, shift them, and subtract. 
chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[-window_length], text_third[-window_length], text_second[-window_length], text_first[-window_length]); chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); hash_low_vec.ymm = _mm256_sub_epi64(hash_low_vec.ymm, _mm256_mul_epu64(chars_low_vec.ymm, prime_power_low_vec.ymm)); hash_high_vec.ymm = _mm256_sub_epi64(hash_high_vec.ymm, _mm256_mul_epu64(chars_high_vec.ymm, prime_power_high_vec.ymm)); // 1. Multiply the hashes by the base. hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); // 3. Add the incoming characters. hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); // 4. Compute the modulo. Assuming there are only 59 values between our prime // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. hash_low_vec.ymm = _mm256_blendv_epi8(hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); hash_high_vec.ymm = _mm256_blendv_epi8(hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); // 5. Compute the hash mix, that will be used to index into the fingerprint. // This includes a serial step at the end. hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); if ((cycle & step_mask) == 0) { callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); } } } #pragma clang attribute pop #pragma GCC pop_options #endif #pragma endregion /* * @brief AVX-512 implementation of the string search algorithms. * * Different subsets of AVX-512 were introduced in different years: * * 2017 SkyLake: F, CD, ER, PF, VL, DQ, BW * * 2018 CannonLake: IFMA, VBMI * * 2019 IceLake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES * * 2020 TigerLake: VP2INTERSECT */ #pragma region AVX - 512 Implementation #if SZ_USE_X86_AVX512 #pragma GCC push_options #pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "bmi", "bmi2") #pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,bmi,bmi2"))), apply_to = function) #include /** * @brief Helper structure to simplify work with 512-bit registers. 
 */
typedef union sz_u512_vec_t {
    __m512i zmm;
    __m256i ymms[2];
    __m128i xmms[4];
    sz_u64_t u64s[8];
    sz_u32_t u32s[16];
    sz_u16_t u16s[32];
    sz_u8_t u8s[64];
    sz_i64_t i64s[8];
    sz_i32_t i32s[16];
} sz_u512_vec_t;

SZ_INTERNAL __mmask64 _sz_u64_clamp_mask_until(sz_size_t n) {
    // The simplest approach to compute this, if we know that `n` is below or equal to 64:
    //      return (1ull << n) - 1;
    // A slightly more complex approach, if we don't know that `n` is under 64:
    return _bzhi_u64(0xFFFFFFFFFFFFFFFF, n < 64 ? (sz_u32_t)n : 64);
}

SZ_INTERNAL __mmask32 _sz_u32_clamp_mask_until(sz_size_t n) {
    // The simplest approach to compute this, if we know that `n` is below or equal to 32:
    //      return (1ull << n) - 1;
    // A slightly more complex approach, if we don't know that `n` is under 32:
    return _bzhi_u32(0xFFFFFFFF, n < 32 ? (sz_u32_t)n : 32);
}

SZ_INTERNAL __mmask16 _sz_u16_clamp_mask_until(sz_size_t n) {
    // The simplest approach to compute this, if we know that `n` is below or equal to 16:
    //      return (1ull << n) - 1;
    // A slightly more complex approach, if we don't know that `n` is under 16:
    return _bzhi_u32(0xFFFFFFFF, n < 16 ? (sz_u32_t)n : 16);
}

SZ_INTERNAL __mmask16 _sz_u16_mask_until(sz_size_t n) {
    // The simplest approach to compute this, if we know that `n` is below or equal to 16:
    //      return (1ull << n) - 1;
    // A slightly more complex approach, if we don't know that `n` is under 16:
    return (__mmask16)_bzhi_u32(0xFFFFFFFF, (sz_u32_t)n);
}

SZ_INTERNAL __mmask32 _sz_u32_mask_until(sz_size_t n) {
    // The simplest approach to compute this, if we know that `n` is below or equal to 32:
    //      return (1ull << n) - 1;
    // A slightly more complex approach, if we don't know that `n` is under 32:
    return _bzhi_u32(0xFFFFFFFF, (sz_u32_t)n);
}

SZ_INTERNAL __mmask64 _sz_u64_mask_until(sz_size_t n) {
    // The simplest approach to compute this, if we know that `n` is below or equal to 64:
    //      return (1ull << n) - 1;
    // A slightly more complex approach, if we don't know that `n` is under 64:
    return _bzhi_u64(0xFFFFFFFFFFFFFFFF, (sz_u32_t)n);
}

SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) {
    sz_u512_vec_t a_vec, b_vec;

    // Pointer arithmetic is cheap, fetching memory is not!
    // So we can use the masked loads to fetch at most one cache-line for each string,
    // compare the prefixes, and only then move forward.
    sz_size_t a_head_length = 64 - ((sz_size_t)a % 64); // 63 or less.
    sz_size_t b_head_length = 64 - ((sz_size_t)b % 64); // 63 or less.
    a_head_length = a_head_length < a_length ? a_head_length : a_length;
    b_head_length = b_head_length < b_length ? b_head_length : b_length;
    sz_size_t head_length = a_head_length < b_head_length ? a_head_length : b_head_length;
    __mmask64 head_mask = _sz_u64_mask_until(head_length);
    a_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, a);
    b_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, b);
    __mmask64 mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm);
    if (mask_not_equal != 0) {
        sz_u64_t first_diff = _tzcnt_u64(mask_not_equal);
        char a_char = a_vec.u8s[first_diff];
        char b_char = b_vec.u8s[first_diff];
        return _sz_order_scalars(a_char, b_char);
    }
    else if (head_length == a_length && head_length == b_length) { return sz_equal_k; }
    else { a += head_length, b += head_length, a_length -= head_length, b_length -= head_length; }

    // The rare case, when both strings are very long.
__mmask64 a_mask, b_mask; while ((a_length >= 64) & (b_length >= 64)) { a_vec.zmm = _mm512_loadu_si512(a); b_vec.zmm = _mm512_loadu_si512(b); mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); if (mask_not_equal != 0) { sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); char a_char = a_vec.u8s[first_diff]; char b_char = b_vec.u8s[first_diff]; return _sz_order_scalars(a_char, b_char); } a += 64, b += 64, a_length -= 64, b_length -= 64; } // In most common scenarios at least one of the strings is under 64 bytes. if (a_length | b_length) { a_mask = _sz_u64_clamp_mask_until(a_length); b_mask = _sz_u64_clamp_mask_until(b_length); a_vec.zmm = _mm512_maskz_loadu_epi8(a_mask, a); b_vec.zmm = _mm512_maskz_loadu_epi8(b_mask, b); // The AVX-512 `_mm512_mask_cmpneq_epi8_mask` intrinsics are generally handy in such environments. // They, however, have latency 3 on most modern CPUs. Using AVX2: `_mm256_cmpeq_epi8` would have // been cheaper, if we didn't have to apply `_mm256_movemask_epi8` afterwards. mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); if (mask_not_equal != 0) { sz_u64_t first_diff = _tzcnt_u64(mask_not_equal); char a_char = a_vec.u8s[first_diff]; char b_char = b_vec.u8s[first_diff]; return _sz_order_scalars(a_char, b_char); } // From logic perspective, the hardest cases are "abc\0" and "abc". // The result must be `sz_greater_k`, as the latter is shorter. else { return _sz_order_scalars(a_length, b_length); } } return sz_equal_k; } SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { __mmask64 mask; sz_u512_vec_t a_vec, b_vec; while (length >= 64) { a_vec.zmm = _mm512_loadu_si512(a); b_vec.zmm = _mm512_loadu_si512(b); mask = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm); if (mask != 0) return sz_false_k; a += 64, b += 64, length -= 64; } if (length) { mask = _sz_u64_mask_until(length); a_vec.zmm = _mm512_maskz_loadu_epi8(mask, a); b_vec.zmm = _mm512_maskz_loadu_epi8(mask, b); // Reuse the same `mask` variable to find the bit that doesn't match mask = _mm512_mask_cmpneq_epi8_mask(mask, a_vec.zmm, b_vec.zmm); return (sz_bool_t)(mask == 0); } return sz_true_k; } SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value) { __m512i value_vec = _mm512_set1_epi8(value); // The naive implementation of this function is very simple. // It assumes the CPU is great at handling unaligned "stores". // // for (; length >= 64; target += 64, length -= 64) _mm512_storeu_si512(target, value_vec); // _mm512_mask_storeu_epi8(target, _sz_u64_mask_until(length), value_vec); // // When the buffer is small, there isn't much to innovate. if (length <= 64) { __mmask64 mask = _sz_u64_mask_until(length); _mm512_mask_storeu_epi8(target, mask, value_vec); } // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked // for the body. else { sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. 
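        // A hypothetical example of that split: for `target % 64 == 7` and `length == 200`, we'd get
        // `head_length == 57`, `tail_length == (7 + 200) % 64 == 15`, and `body_length == 128`,
        // i.e. one masked store on each side and two full aligned stores in between.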
__mmask64 head_mask = _sz_u64_mask_until(head_length); __mmask64 tail_mask = _sz_u64_mask_until(tail_length); _mm512_mask_storeu_epi8(target, head_mask, value_vec); for (target += head_length; body_length >= 64; target += 64, body_length -= 64) _mm512_store_si512(target, value_vec); _mm512_mask_storeu_epi8(target, tail_mask, value_vec); } } SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { // The naive implementation of this function is very simple. // It assumes the CPU is great at handling unaligned "stores" and "loads". // // for (; length >= 64; target += 64, source += 64, length -= 64) // _mm512_storeu_si512(target, _mm512_loadu_si512(source)); // __mmask64 mask = _sz_u64_mask_until(length); // _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); // // A typical AWS Sapphire Rapids instance can have 48 KB x 2 blocks of L1 data cache per core, // 2 MB x 2 blocks of L2 cache per core, and one shared 60 MB buffer of L3 cache. // With two strings, we may consider the overal workload huge, if each exceeds 1 MB in length. int const is_huge = length >= 1ull * 1024ull * 1024ull; // When the buffer is small, there isn't much to innovate. if (length <= 64) { __mmask64 mask = _sz_u64_mask_until(length); _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); } // When dealing with larger arrays, the optimization is not as simple as with the `sz_fill_avx512` function, // as both buffers may be unaligned. If we are lucky and the requested operation is some huge page transfer, // we can use aligned loads and stores, and the performance will be great. else if ((sz_size_t)target % 64 == 0 && (sz_size_t)source % 64 == 0 && !is_huge) { for (; length >= 64; target += 64, source += 64, length -= 64) _mm512_store_si512(target, _mm512_load_si512(source)); // At this point the length is guaranteed to be under 64. __mmask64 mask = _sz_u64_mask_until(length); // Aligned load and stores would work too, but it's not defined. _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); } // The trickiest case is when both `source` and `target` are not aligned. // In such and simpler cases we can copy enough bytes into `target` to reach its cacheline boundary, // and then combine unaligned loads with aligned stores. else if (!is_huge) { sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. __mmask64 head_mask = _sz_u64_mask_until(head_length); __mmask64 tail_mask = _sz_u64_mask_until(tail_length); _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); for (target += head_length, source += head_length; body_length >= 64; target += 64, source += 64, body_length -= 64) _mm512_store_si512(target, _mm512_loadu_si512(source)); // Unaligned load, but aligned store! _mm512_mask_storeu_epi8(target, tail_mask, _mm512_maskz_loadu_epi8(tail_mask, source)); } // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use. // // 1. Moving in both directions to maximize the throughput, when fetching from multiple // memory pages. Also helps with cache set-associativity issues, as we won't always // be fetching the same entries in the lookup table. // 2. Using non-temporal stores to avoid polluting the cache. // 3. Prefetching the next cache line, to avoid stalling the CPU. 
This generally useless // for predictable patterns, so disregard this advice. // // Bidirectional traversal adds about 10%, accelerating from 11 GB/s to 12 GB/s. // Using "streaming stores" boosts us from 12 GB/s to 19 GB/s. else { sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; sz_size_t tail_length = (sz_size_t)(target + length) % 64; sz_size_t body_length = length - head_length - tail_length; __mmask64 head_mask = _sz_u64_mask_until(head_length); __mmask64 tail_mask = _sz_u64_mask_until(tail_length); _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); _mm512_mask_storeu_epi8(target + head_length + body_length, tail_mask, _mm512_maskz_loadu_epi8(tail_mask, source + head_length + body_length)); // Now in the main loop, we can use non-temporal loads and stores, // performing the operation in both directions. for (target += head_length, source += head_length; // body_length >= 128; // target += 64, source += 64, body_length -= 128) { _mm512_stream_si512((__m512i *)(target), _mm512_loadu_si512(source)); _mm512_stream_si512((__m512i *)(target + body_length - 64), _mm512_loadu_si512(source + body_length - 64)); } if (body_length >= 64) _mm512_stream_si512((__m512i *)target, _mm512_loadu_si512(source)); } } SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { if (target == source) return; // Don't be silly, don't move the data if it's already there. // On very short buffers, that are one cache line in width or less, we don't need any loops. // We can also avoid any data-dependencies between iterations, assuming we have 32 registers // to pre-load the data, before writing it back. if (length <= 64) { __mmask64 mask = _sz_u64_mask_until(length); _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source)); } else if (length <= 128) { sz_size_t last_length = length - 64; __mmask64 mask = _sz_u64_mask_until(last_length); __m512i source0 = _mm512_loadu_epi8(source); __m512i source1 = _mm512_maskz_loadu_epi8(mask, source + 64); _mm512_storeu_epi8(target, source0); _mm512_mask_storeu_epi8(target + 64, mask, source1); } else if (length <= 192) { sz_size_t last_length = length - 128; __mmask64 mask = _sz_u64_mask_until(last_length); __m512i source0 = _mm512_loadu_epi8(source); __m512i source1 = _mm512_loadu_epi8(source + 64); __m512i source2 = _mm512_maskz_loadu_epi8(mask, source + 128); _mm512_storeu_epi8(target, source0); _mm512_storeu_epi8(target + 64, source1); _mm512_mask_storeu_epi8(target + 128, mask, source2); } else if (length <= 256) { sz_size_t last_length = length - 192; __mmask64 mask = _sz_u64_mask_until(last_length); __m512i source0 = _mm512_loadu_epi8(source); __m512i source1 = _mm512_loadu_epi8(source + 64); __m512i source2 = _mm512_loadu_epi8(source + 128); __m512i source3 = _mm512_maskz_loadu_epi8(mask, source + 192); _mm512_storeu_epi8(target, source0); _mm512_storeu_epi8(target + 64, source1); _mm512_storeu_epi8(target + 128, source2); _mm512_mask_storeu_epi8(target + 192, mask, source3); } // If the regions don't overlap at all, just use "copy" and save some brain cells thinking about corner cases. else if (target + length < source || target >= source + length) { sz_copy_avx512(target, source, length); } // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, // and may include more cache-lines in-between. 
Knowing this, we can avoid expensive unaligned stores
    // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked
    // for the body.
    else {
        sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less.
        sz_size_t tail_length = (sz_size_t)(target + length) % 64;    // 63 or less.
        sz_size_t body_length = length - head_length - tail_length;   // Multiple of 64.
        __mmask64 head_mask = _sz_u64_mask_until(head_length);
        __mmask64 tail_mask = _sz_u64_mask_until(tail_length);

        // The absolute most common case of using "moves" is shifting the data within a continuous buffer
        // when adding or removing some values in it. In such cases, a typical shift is by 1, 2, 4, 8, 16,
        // or 32 bytes, rarely larger. For small shifts, under the size of the ZMM register, we can use shuffles.
        //
        // Remember:
        //      - if we are shifting data left, we are traversing to the right;
        //      - if we are shifting data right, we are traversing to the left.
        int const left_to_right_traversal = source > target;

        // Now we guarantee that the relative shift within registers is from 1 to 63 bytes and the output is aligned.
        // Hopefully, we need to shift more than two ZMM registers, so we could consider the `valignr` instruction.
        // Sadly, using `_mm512_alignr_epi8` doesn't make sense, as it operates at a 128-bit granularity.
        //
        //      - `_mm256_alignr_epi8` shifts an entire 256-bit register, but we need many of them.
        //      - `_mm512_alignr_epi32` shifts 512-bit chunks, but only if the `shift` is a multiple of 4 bytes.
        //      - `_mm512_alignr_epi64` shifts 512-bit chunks by 8 bytes.
        //
        // All of those have a latency of 1 cycle, and the shift amount must be an immediate value!
        // For 1-byte-shift granularity, the `_mm512_permutex2var_epi8` has a latency of 6 and needs VBMI!
        // The most efficient and broadly compatible alternative could be to use a combination of align and shuffle.
        // A similar approach was outlined in "Byte-wise alignr in AVX512F" by Wojciech Muła.
        // http://0x80.pl/notesen/2016-10-16-avx512-byte-alignr.html
        //
        // That solution, however, is quite a mouthful, as it requires compile-time constants for the shift amount.
        // A cleaner one, with a latency of 3 cycles, is to use `_mm512_permutexvar_epi8` or
        // `_mm512_mask_permutexvar_epi8`, which can be seen as a combination of a cross-register shuffle and blend,
        // and is available with VBMI. That solution is still noticeably slower than AVX2.
        //
        // The GLibC implementation also uses non-temporal stores for larger buffers; we don't.
        // https://codebrowser.dev/glibc/glibc/sysdeps/x86_64/multiarch/memmove-avx512-no-vzeroupper.S.html
        if (left_to_right_traversal) {
            // Head, body, and tail.
            _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source));
            for (target += head_length, source += head_length; body_length >= 64;
                 target += 64, source += 64, body_length -= 64)
                _mm512_store_si512(target, _mm512_loadu_si512(source));
            _mm512_mask_storeu_epi8(target, tail_mask, _mm512_maskz_loadu_epi8(tail_mask, source));
        }
        else {
            // Tail, body, and head.
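            // If the buffers overlap and `target` sits at a higher address than `source`, a front-to-back pass
            // would overwrite `source` bytes we haven't read yet, hence the tail-first traversal below.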
_mm512_mask_storeu_epi8(target + head_length + body_length, tail_mask, _mm512_maskz_loadu_epi8(tail_mask, source + head_length + body_length)); for (; body_length >= 64; body_length -= 64) _mm512_store_si512(target + head_length + body_length - 64, _mm512_loadu_si512(source + head_length + body_length - 64)); _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source)); } } } SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { __mmask64 mask; sz_u512_vec_t h_vec, n_vec; n_vec.zmm = _mm512_set1_epi8(n[0]); while (h_length >= 64) { h_vec.zmm = _mm512_loadu_si512(h); mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm); if (mask) return h + sz_u64_ctz(mask); h += 64, h_length -= 64; } if (h_length) { mask = _sz_u64_mask_until(h_length); h_vec.zmm = _mm512_maskz_loadu_epi8(mask, h); // Reuse the same `mask` variable to find the bit that doesn't match mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm); if (mask) return h + sz_u64_ctz(mask); } return SZ_NULL_CHAR; } SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { // This almost never fires, but it's better to be safe than sorry. if (h_length < n_length || !n_length) return SZ_NULL_CHAR; if (n_length == 1) return sz_find_byte_avx512(h, h_length, n); // Pick the parts of the needle that are worth comparing. sz_size_t offset_first, offset_mid, offset_last; _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); // Broadcast those characters into ZMM registers. __mmask64 matches; __mmask64 mask; sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]); n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]); n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]); // Scan through the string. // We have several optimized versions of the lagorithm for shorter strings, // but they all mimic the default case for unbounded length needles if (n_length >= 64) { for (; h_length >= n_length + 64; h += 64, h_length -= 64) { h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); matches = _kand_mask64(_kand_mask64( // Intersect the masks _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); while (matches) { int potential_offset = sz_u64_ctz(matches); if (sz_equal_avx512(h + potential_offset, n, n_length)) return h + potential_offset; matches &= matches - 1; } // TODO: If the last character contains a bad byte, we can reposition the start of the next iteration. // This will be very helpful for very long needles. } } // If there are only 2 or 3 characters in the needle, we don't even need the nested loop. 
else if (n_length <= 3) { for (; h_length >= n_length + 64; h += 64, h_length -= 64) { h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); matches = _kand_mask64(_kand_mask64( // Intersect the masks _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); if (matches) return h + sz_u64_ctz(matches); } } // If the needle is smaller than the size of the ZMM register, we can use masked comparisons // to avoid the the inner-most nested loop and compare the entire needle against a haystack // slice in 3 CPU cycles. else { __mmask64 n_mask = _sz_u64_mask_until(n_length); sz_u512_vec_t n_full_vec, h_full_vec; n_full_vec.zmm = _mm512_maskz_loadu_epi8(n_mask, n); for (; h_length >= n_length + 64; h += 64, h_length -= 64) { h_first_vec.zmm = _mm512_loadu_si512(h + offset_first); h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid); h_last_vec.zmm = _mm512_loadu_si512(h + offset_last); matches = _kand_mask64(_kand_mask64( // Intersect the masks _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); while (matches) { int potential_offset = sz_u64_ctz(matches); h_full_vec.zmm = _mm512_maskz_loadu_epi8(n_mask, h + potential_offset); if (_mm512_mask_cmpneq_epi8_mask(n_mask, h_full_vec.zmm, n_full_vec.zmm) == 0) return h + potential_offset; matches &= matches - 1; } } } // The "tail" of the function uses masked loads to process the remaining bytes. { mask = _sz_u64_mask_until(h_length - n_length + 1); h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first); h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid); h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last); matches = _kand_mask64(_kand_mask64( // Intersect the masks _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); while (matches) { int potential_offset = sz_u64_ctz(matches); if (n_length <= 3 || sz_equal_avx512(h + potential_offset, n, n_length)) return h + potential_offset; matches &= matches - 1; } } return SZ_NULL_CHAR; } SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { __mmask64 mask; sz_u512_vec_t h_vec, n_vec; n_vec.zmm = _mm512_set1_epi8(n[0]); while (h_length >= 64) { h_vec.zmm = _mm512_loadu_si512(h + h_length - 64); mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm); if (mask) return h + h_length - 1 - sz_u64_clz(mask); h_length -= 64; } if (h_length) { mask = _sz_u64_mask_until(h_length); h_vec.zmm = _mm512_maskz_loadu_epi8(mask, h); // Reuse the same `mask` variable to find the bit that doesn't match mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm); if (mask) return h + 64 - sz_u64_clz(mask) - 1; } return SZ_NULL_CHAR; } SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { // This almost never fires, but it's better to be safe than sorry. if (h_length < n_length || !n_length) return SZ_NULL_CHAR; if (n_length == 1) return sz_rfind_byte_avx512(h, h_length, n); // Pick the parts of the needle that are worth comparing. 
sz_size_t offset_first, offset_mid, offset_last; _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); // Broadcast those characters into ZMM registers. __mmask64 mask; __mmask64 matches; sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]); n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]); n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]); // Scan through the string. sz_cptr_t h_reversed; for (; h_length >= n_length + 64; h_length -= 64) { h_reversed = h + h_length - n_length - 64 + 1; h_first_vec.zmm = _mm512_loadu_si512(h_reversed + offset_first); h_mid_vec.zmm = _mm512_loadu_si512(h_reversed + offset_mid); h_last_vec.zmm = _mm512_loadu_si512(h_reversed + offset_last); matches = _kand_mask64(_kand_mask64( // Intersect the masks _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); while (matches) { int potential_offset = sz_u64_clz(matches); if (n_length <= 3 || sz_equal_avx512(h + h_length - n_length - potential_offset, n, n_length)) return h + h_length - n_length - potential_offset; sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && "The bit must be set before we squash it"); matches &= ~((sz_u64_t)1 << (63 - potential_offset)); } } // The "tail" of the function uses masked loads to process the remaining bytes. { mask = _sz_u64_mask_until(h_length - n_length + 1); h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first); h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid); h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last); matches = _kand_mask64(_kand_mask64( // Intersect the masks _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm), _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)), _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm)); while (matches) { int potential_offset = sz_u64_clz(matches); if (n_length <= 3 || sz_equal_avx512(h + 64 - potential_offset - 1, n, n_length)) return h + 64 - potential_offset - 1; sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 && "The bit must be set before we squash it"); matches &= ~((sz_u64_t)1 << (63 - potential_offset)); } } return SZ_NULL_CHAR; } SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( // sz_cptr_t shorter, sz_size_t shorter_length, // sz_cptr_t longer, sz_size_t longer_length, // sz_size_t bound, sz_memory_allocator_t *alloc) { // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. sz_memory_allocator_t global_alloc; if (!alloc) { sz_memory_allocator_init_default(&global_alloc); alloc = &global_alloc; } // TODO: Generalize! sz_size_t max_length = 256u * 256u; sz_assert(!bound && "For bounded search the method should only evaluate one band of the matrix."); sz_assert(shorter_length == longer_length && "The method hasn't been generalized to different length inputs yet."); sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant."); sz_unused(longer_length && bound && max_length); // We are going to store 3 diagonals of the matrix. // The length of the longest (main) diagonal would be `n = (shorter_length + 1)`. sz_size_t n = shorter_length + 1; // Unlike the serial version, we also want to avoid reverse-order iteration over teh shorter string. 
// So let's allocate a bit more memory and reverse-export our shorter string into that buffer. sz_size_t buffer_length = sizeof(sz_u16_t) * n * 3 + shorter_length; sz_u16_t *distances = (sz_u16_t *)alloc->allocate(buffer_length, alloc->handle); if (!distances) return SZ_SIZE_MAX; sz_u16_t *previous_distances = distances; sz_u16_t *current_distances = previous_distances + n; sz_u16_t *next_distances = current_distances + n; sz_ptr_t shorter_reversed = (sz_ptr_t)(next_distances + n); // Export the reversed string into the buffer. for (sz_size_t i = 0; i != shorter_length; ++i) shorter_reversed[i] = shorter[shorter_length - 1 - i]; // Initialize the first two diagonals: previous_distances[0] = 0; current_distances[0] = current_distances[1] = 1; // Using ZMM registers, we can process 32x 16-bit values at once, // storing 16 bytes of each string in YMM registers. sz_u512_vec_t insertions_vec, deletions_vec, substitutions_vec, next_vec; sz_u512_vec_t ones_u16_vec; ones_u16_vec.zmm = _mm512_set1_epi16(1); // This is a mixed-precision implementation, using 8-bit representations for part of the operations. // Even there, in case `SZ_USE_X86_AVX2=0`, let's use the `sz_u512_vec_t` type, addressing the first YMM halfs. sz_u512_vec_t shorter_vec, longer_vec; sz_u512_vec_t ones_u8_vec; ones_u8_vec.ymms[0] = _mm256_set1_epi8(1); // Progress through the upper triangle of the Levenshtein matrix. sz_size_t next_skew_diagonal_index = 2; for (; next_skew_diagonal_index != n; ++next_skew_diagonal_index) { sz_size_t const next_skew_diagonal_length = next_skew_diagonal_index + 1; for (sz_size_t i = 0; i + 2 < next_skew_diagonal_length;) { sz_u32_t remaining_length = (sz_u32_t)(next_skew_diagonal_length - i - 2); sz_u32_t register_length = remaining_length < 32 ? remaining_length : 32; sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length); longer_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, longer + i); // Our original code addressed the shorter string `[next_skew_diagonal_index - i - 2]` for growing `i`. // If the `shorter` string was reversed, the `[next_skew_diagonal_index - i - 2]` would // be equal to `[shorter_length - 1 - next_skew_diagonal_index + i + 2]`. // Which simplified would be equal to `[shorter_length - next_skew_diagonal_index + i + 1]`. shorter_vec.ymms[0] = _mm256_maskz_loadu_epi8( remaining_length_mask, shorter_reversed + shorter_length - next_skew_diagonal_index + i + 1); // For substitutions, perform the equality comparison using AVX2 instead of AVX-512 // to get the result as a vector, instead of a bitmask. Adding 1 to every scalar we can overflow // transforming from {0xFF, 0} values to {0, 1} values - exactly what we need. Then - upcast to 16-bit. substitutions_vec.zmm = _mm512_cvtepi8_epi16( // _mm256_add_epi8(_mm256_cmpeq_epi8(longer_vec.ymms[0], shorter_vec.ymms[0]), ones_u8_vec.ymms[0])); substitutions_vec.zmm = _mm512_add_epi16( // substitutions_vec.zmm, _mm512_maskz_loadu_epi16(remaining_length_mask, previous_distances + i)); // For insertions and deletions, on modern hardware, it's faster to issue two separate loads, // than rotate the bytes in the ZMM register. insertions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i); deletions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i + 1); // First get the minimum of insertions and deletions. 
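                // For reference, a scalar sketch of the update performed below (with unit costs) is:
                //      next[i + 1] = min(min(current[i], current[i + 1]) + 1, previous[i] + (longer_char != shorter_char))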
next_vec.zmm = _mm512_add_epi16(_mm512_min_epu16(insertions_vec.zmm, deletions_vec.zmm), ones_u16_vec.zmm); next_vec.zmm = _mm512_min_epu16(next_vec.zmm, substitutions_vec.zmm); _mm512_mask_storeu_epi16(next_distances + i + 1, remaining_length_mask, next_vec.zmm); i += register_length; } // Don't forget to populate the first row and the fiest column of the Levenshtein matrix. next_distances[0] = next_distances[next_skew_diagonal_length - 1] = (sz_u16_t)next_skew_diagonal_index; // Perform a circular rotation of those buffers, to reuse the memory. sz_u16_t *temporary = previous_distances; previous_distances = current_distances; current_distances = next_distances; next_distances = temporary; } // By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a // larger diagonal. From now onwards, we will be shrinking. Instead of adding value equal to the skewed diagonal // index on either side, we will be cropping those values out. sz_size_t total_diagonals = n + n - 1; for (; next_skew_diagonal_index != total_diagonals; ++next_skew_diagonal_index) { sz_size_t const next_skew_diagonal_length = total_diagonals - next_skew_diagonal_index; for (sz_size_t i = 0; i != next_skew_diagonal_length;) { sz_u32_t remaining_length = (sz_u32_t)(next_skew_diagonal_length - i); sz_u32_t register_length = remaining_length < 32 ? remaining_length : 32; sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length); longer_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, longer + next_skew_diagonal_index - n + i); // Our original code addressed the shorter string `[shorter_length - 1 - i]` for growing `i`. // If the `shorter` string was reversed, the `[shorter_length - 1 - i]` would // be equal to `[shorter_length - 1 - shorter_length + 1 + i]`. // Which simplified would be equal to just `[i]`. Beautiful! shorter_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, shorter_reversed + i); // For substitutions, perform the equality comparison using AVX2 instead of AVX-512 // to get the result as a vector, instead of a bitmask. The compare it against the accumulated // substitution costs. substitutions_vec.zmm = _mm512_cvtepi8_epi16( // _mm256_add_epi8(_mm256_cmpeq_epi8(longer_vec.ymms[0], shorter_vec.ymms[0]), ones_u8_vec.ymms[0])); substitutions_vec.zmm = _mm512_add_epi16( // substitutions_vec.zmm, _mm512_maskz_loadu_epi16(remaining_length_mask, previous_distances + i)); // For insertions and deletions, on modern hardware, it's faster to issue two separate loads, // than rotate the bytes in the ZMM register. insertions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i); deletions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i + 1); // First get the minimum of insertions and deletions. next_vec.zmm = _mm512_add_epi16(_mm512_min_epu16(insertions_vec.zmm, deletions_vec.zmm), ones_u16_vec.zmm); next_vec.zmm = _mm512_min_epu16(next_vec.zmm, substitutions_vec.zmm); _mm512_mask_storeu_epi16(next_distances + i, remaining_length_mask, next_vec.zmm); i += register_length; } // Perform a circular rotation of those buffers, to reuse the memory, this time, with a shift, // dropping the first element in the current array. sz_u16_t *temporary = previous_distances; previous_distances = current_distances + 1; current_distances = next_distances; next_distances = temporary; } // Cache scalar before `free` call. 
sz_size_t result = current_distances[0]; alloc->free(distances, buffer_length, alloc->handle); return result; } SZ_INTERNAL sz_size_t sz_edit_distance_avx512( // sz_cptr_t shorter, sz_size_t shorter_length, // sz_cptr_t longer, sz_size_t longer_length, // sz_size_t bound, sz_memory_allocator_t *alloc) { if (shorter_length == longer_length && !bound && shorter_length && shorter_length < 256u * 256u) return _sz_edit_distance_skewed_diagonals_upto65k_avx512(shorter, shorter_length, longer, longer_length, bound, alloc); else return sz_edit_distance_serial(shorter, shorter_length, longer, longer_length, bound, alloc); } #pragma clang attribute pop #pragma GCC pop_options #pragma GCC push_options #pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "bmi", "bmi2") #pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,bmi,bmi2"))), \ apply_to = function) SZ_PUBLIC sz_u64_t sz_checksum_avx512(sz_cptr_t text, sz_size_t length) { // The naive implementation of this function is very simple. // It assumes the CPU is great at handling unaligned "loads". // // A typical AWS Sapphire Rapids instance can have 48 KB x 2 blocks of L1 data cache per core, // 2 MB x 2 blocks of L2 cache per core, and one shared 60 MB buffer of L3 cache. // With two strings, we may consider the overal workload huge, if each exceeds 1 MB in length. int const is_huge = length >= 1ull * 1024ull * 1024ull; sz_u512_vec_t text_vec, sums_vec; // When the buffer is small, there isn't much to innovate. if (length <= 16) { __mmask16 mask = _sz_u16_mask_until(length); text_vec.xmms[0] = _mm_maskz_loadu_epi8(mask, text); sums_vec.xmms[0] = _mm_sad_epu8(text_vec.xmms[0], _mm_setzero_si128()); sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_vec.xmms[0]); sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_vec.xmms[0], 1); return low + high; } else if (length <= 32) { __mmask32 mask = _sz_u32_mask_until(length); text_vec.ymms[0] = _mm256_maskz_loadu_epi8(mask, text); sums_vec.ymms[0] = _mm256_sad_epu8(text_vec.ymms[0], _mm256_setzero_si256()); // Accumulating 256 bits is harders, as we need to extract the 128-bit sums first. __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymms[0]); __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymms[0], 1); __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm); sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm); sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1); return low + high; } else if (length <= 64) { __mmask64 mask = _sz_u64_mask_until(length); text_vec.zmm = _mm512_maskz_loadu_epi8(mask, text); sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()); return _mm512_reduce_add_epi64(sums_vec.zmm); } else if (!is_huge) { sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64; // 63 or less. sz_size_t tail_length = (sz_size_t)(text + length) % 64; // 63 or less. sz_size_t body_length = length - head_length - tail_length; // Multiple of 64. 
        __mmask64 head_mask = _sz_u64_mask_until(head_length);
        __mmask64 tail_mask = _sz_u64_mask_until(tail_length);
        text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text);
        sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512());
        for (text += head_length; body_length >= 64; text += 64, body_length -= 64) {
            text_vec.zmm = _mm512_load_si512((__m512i const *)text);
            sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
        }
        text_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text);
        sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
        return _mm512_reduce_add_epi64(sums_vec.zmm);
    }
    // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use.
    //
    //      1. Moving in both directions to maximize the throughput, when fetching from multiple
    //         memory pages. Also helps with cache set-associativity issues, as we won't always
    //         be fetching the same entries in the lookup table.
    //      2. Using non-temporal stores to avoid polluting the cache.
    //      3. Prefetching the next cache line, to avoid stalling the CPU. This is generally useless
    //         for predictable patterns, so disregard this advice.
    //
    // Bidirectional traversal generally adds about 10% to such algorithms.
    else {
        sz_u512_vec_t text_reversed_vec, sums_reversed_vec;
        sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64;
        sz_size_t tail_length = (sz_size_t)(text + length) % 64;
        sz_size_t body_length = length - head_length - tail_length;
        __mmask64 head_mask = _sz_u64_mask_until(head_length);
        __mmask64 tail_mask = _sz_u64_mask_until(tail_length);
        text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text);
        sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512());
        text_reversed_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text + head_length + body_length);
        sums_reversed_vec.zmm = _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512());

        // Now in the main loop, we can use non-temporal loads,
        // performing the operation in both directions.
        for (text += head_length; body_length >= 128; text += 64, body_length -= 128) {
            text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text));
            sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
            text_reversed_vec.zmm = _mm512_stream_load_si512((__m512i *)(text + body_length - 64));
            sums_reversed_vec.zmm = _mm512_add_epi64(sums_reversed_vec.zmm,
                                                     _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512()));
        }
        if (body_length >= 64) {
            text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text));
            sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
        }
        return _mm512_reduce_add_epi64(_mm512_add_epi64(sums_vec.zmm, sums_reversed_vec.zmm));
    }
}

SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, //
                                sz_hash_callback_t callback, void *callback_handle) {

    if (length < window_length || !window_length) return;
    if (length < 4 * window_length) {
        sz_hashes_serial(start, length, window_length, step, callback, callback_handle);
        return;
    }

    // Using AVX-512, we can perform 8 long integer multiplications and additions within one ZMM register.
    // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel,
    // hashing each of them with two different functions.
    sz_size_t const max_hashes = length - window_length + 1;
    sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads.
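    // A sketch of the per-window recurrence maintained below, in scalar terms (Rabin-Karp style):
    //
    //      hash = ((hash - outgoing_char * base^(window_length - 1)) * base + incoming_char) % prime
    //
    // with `base == 31` for the "low" hashes and `base == 257` for the "high" ones (whose characters are
    // also shifted by 77), and the modulo reduced to a conditional subtraction of `SZ_U64_MAX_PRIME`.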
sz_u8_t const *text_first = (sz_u8_t const *)start; sz_u8_t const *text_second = text_first + min_hashes_per_thread; sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2; sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3; sz_u8_t const *text_end = text_first + length; // Broadcast the global constants into the registers. // Both high and low hashes will work with the same prime and golden ratio. sz_u512_vec_t prime_vec, golden_ratio_vec; prime_vec.zmm = _mm512_set1_epi64(SZ_U64_MAX_PRIME); golden_ratio_vec.zmm = _mm512_set1_epi64(11400714819323198485ull); // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. sz_u64_t prime_power_low = 1, prime_power_high = 1; for (sz_size_t i = 0; i + 1 < window_length; ++i) prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; // We will be evaluating 4 offsets at a time with 2 different hash functions. // We can fit all those 8 state variables in each of the following ZMM registers. sz_u512_vec_t base_vec, prime_power_vec, shift_vec; base_vec.zmm = _mm512_set_epi64(31ull, 31ull, 31ull, 31ull, 257ull, 257ull, 257ull, 257ull); shift_vec.zmm = _mm512_set_epi64(0ull, 0ull, 0ull, 0ull, 77ull, 77ull, 77ull, 77ull); prime_power_vec.zmm = _mm512_set_epi64(prime_power_low, prime_power_low, prime_power_low, prime_power_low, prime_power_high, prime_power_high, prime_power_high, prime_power_high); // Compute the initial hash values for every one of the four windows. sz_u512_vec_t hash_vec, chars_vec; hash_vec.zmm = _mm512_setzero_si512(); for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end; ++text_first, ++text_second, ++text_third, ++text_fourth) { // 1. Multiply the hashes by the base. hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm); // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`... chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], // text_fourth[0], text_third[0], text_second[0], text_first[0]); chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); // 3. Add the incoming characters. hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm); // 4. Compute the modulo. Assuming there are only 59 values between our prime // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm, _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm)); } // 5. Compute the hash mix, that will be used to index into the fingerprint. // This includes a serial step at the end. sz_u512_vec_t hash_mix_vec; hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm); hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), // _mm512_extracti64x4_epi64(hash_mix_vec.zmm, 0)); callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); // Now repeat that operation for the remaining characters, discarding older characters. 
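// Per lane, the loop below is a plain rolling-hash update; an illustrative scalar sketch of one step,
// ignoring the 77-point shift applied to the high-hash lanes, would be:
//
//      hash = ((hash - out_char * prime_power) * base + in_char) % prime;
//
// where `out_char` is the byte leaving the window, `in_char` the byte entering it, `base` is 31 or 257,
// and `prime_power` caches `base ^ (window_length - 1) % prime` computed above. The modulo itself is
// approximated with a single conditional subtraction, as the comments below explain.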
sz_size_t cycle = 1; sz_size_t step_mask = step - 1; for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) { // 0. Load again the four characters we are dropping, shift them, and subtract. chars_vec.zmm = _mm512_set_epi64(text_fourth[-window_length], text_third[-window_length], text_second[-window_length], text_first[-window_length], // text_fourth[-window_length], text_third[-window_length], text_second[-window_length], text_first[-window_length]); chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); hash_vec.zmm = _mm512_sub_epi64(hash_vec.zmm, _mm512_mullo_epi64(chars_vec.zmm, prime_power_vec.zmm)); // 1. Multiply the hashes by the base. hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm); // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], // text_fourth[0], text_third[0], text_second[0], text_first[0]); chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm); // ... and prefetch the next four characters into Level 2 or higher. _mm_prefetch((sz_cptr_t)text_fourth + 1, _MM_HINT_T1); _mm_prefetch((sz_cptr_t)text_third + 1, _MM_HINT_T1); _mm_prefetch((sz_cptr_t)text_second + 1, _MM_HINT_T1); _mm_prefetch((sz_cptr_t)text_first + 1, _MM_HINT_T1); // 3. Add the incoming characters. hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm); // 4. Compute the modulo. Assuming there are only 59 values between our prime // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm, _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm)); // 5. Compute the hash mix, that will be used to index into the fingerprint. // This includes a serial step at the end. hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm); hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), // _mm512_castsi512_si256(hash_mix_vec.zmm)); if ((cycle & step_mask) == 0) { callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); } } } #pragma clang attribute pop #pragma GCC pop_options #pragma GCC push_options #pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512vbmi", "avx512vbmi2", "bmi", "bmi2") #pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512vbmi,avx512vbmi2,bmi,bmi2"))), \ apply_to = function) SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { // If the input is tiny (especially smaller than the look-up table itself), we may end up paying // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. // But if at least 3 cache lines are touched, the AVX-512 implementation should be faster. 
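// For reference, the scalar equivalent of this whole transform (an illustrative sketch of what the
// serial fallback boils down to) is just:
//
//      for (sz_size_t i = 0; i != length; ++i) target[i] = lut[(sz_u8_t)source[i]];
//
// The AVX-512 version below only changes how many of those lookups happen per cycle.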
if (length <= 128) { sz_look_up_transform_serial(source, length, lut, target); return; } // When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail, // and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores // by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked // for the body. sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. __mmask64 head_mask = _sz_u64_mask_until(head_length); __mmask64 tail_mask = _sz_u64_mask_until(tail_length); // We need to pull the lookup table into 4x ZMM registers. // We can use the `vpermi2b` instruction to perform the lookup in two ZMM registers with `_mm512_permutex2var_epi8` // intrinsics, but it has a 6-cycle latency on Sapphire Rapids and requires AVX512-VBMI. Assuming we need to // operate on 4 registers, it might be cleaner to use 4x separate `_mm512_permutexvar_epi8` calls. // Combining the results with 2x `_mm512_test_epi8_mask` and 3x blends afterwards. // // - 4x `_mm512_permutexvar_epi8` maps to "VPERMB (ZMM, ZMM, ZMM)": // - On Ice Lake: 3 cycles latency, ports: 1*p5 // - On Genoa: 6 cycles latency, ports: 1*FP12 // - 3x `_mm512_mask_blend_epi8` maps to "VPBLENDMB_Z (ZMM, K, ZMM, ZMM)": // - On Ice Lake: 3 cycles latency, ports: 1*p05 // - On Genoa: 1 cycle latency, ports: 1*FP0123 // - 2x `_mm512_test_epi8_mask` maps to "VPTESTMB (K, ZMM, ZMM)": // - On Ice Lake: 3 cycles latency, ports: 1*p5 // - On Genoa: 4 cycles latency, ports: 1*FP01 // sz_u512_vec_t lut_0_to_63_vec, lut_64_to_127_vec, lut_128_to_191_vec, lut_192_to_255_vec; lut_0_to_63_vec.zmm = _mm512_loadu_si512((lut)); lut_64_to_127_vec.zmm = _mm512_loadu_si512((lut + 64)); lut_128_to_191_vec.zmm = _mm512_loadu_si512((lut + 128)); lut_192_to_255_vec.zmm = _mm512_loadu_si512((lut + 192)); sz_u512_vec_t first_bit_vec, second_bit_vec; first_bit_vec.zmm = _mm512_set1_epi8((char)0x80); second_bit_vec.zmm = _mm512_set1_epi8((char)0x40); __mmask64 first_bit_mask, second_bit_mask; sz_u512_vec_t source_vec; // If the top bit is set in a byte of `source_vec`, then we use `lookup_128_to_191_vec` or // `lookup_192_to_255_vec`. If the second bit is set, we use `lookup_64_to_127_vec` or `lookup_192_to_255_vec`. sz_u512_vec_t lookup_0_to_63_vec, lookup_64_to_127_vec, lookup_128_to_191_vec, lookup_192_to_255_vec; sz_u512_vec_t blended_0_to_127_vec, blended_128_to_255_vec, blended_0_to_255_vec; // Handling the head.
if (head_length) { source_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, source); lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); blended_0_to_127_vec.zmm = _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); blended_128_to_255_vec.zmm = _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); blended_0_to_255_vec.zmm = _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); _mm512_mask_storeu_epi8(target, head_mask, blended_0_to_255_vec.zmm); source += head_length, target += head_length, length -= head_length; } // Handling the body in 64-byte chunks aligned to cache-line boundaries with respect to `target`. while (length >= 64) { source_vec.zmm = _mm512_loadu_si512(source); lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); blended_0_to_127_vec.zmm = _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); blended_128_to_255_vec.zmm = _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); blended_0_to_255_vec.zmm = _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); _mm512_store_si512(target, blended_0_to_255_vec.zmm); //! Aligned store, our main weapon! source += 64, target += 64, length -= 64; } // Handling the tail. 
if (tail_length) { source_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, source); lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm); lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm); lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm); lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm); first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm); second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm); blended_0_to_127_vec.zmm = _mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm); blended_128_to_255_vec.zmm = _mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm); blended_0_to_255_vec.zmm = _mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm); _mm512_mask_storeu_epi8(target, tail_mask, blended_0_to_255_vec.zmm); source += tail_length, target += tail_length, length -= tail_length; } } SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { // Before initializing the AVX-512 vectors, we may want to run the sequential code for the first few bytes. // In practice, that only hurts, even when we have matches every 5-ish bytes. // // if (length < SZ_SWAR_THRESHOLD) return sz_find_charset_serial(text, length, filter); // sz_cptr_t early_result = sz_find_charset_serial(text, SZ_SWAR_THRESHOLD, filter); // if (early_result) return early_result; // text += SZ_SWAR_THRESHOLD; // length -= SZ_SWAR_THRESHOLD; // // Let's unzip even and odd elements and replicate them into both lanes of the YMM register. // That way when we invoke `_mm512_shuffle_epi8` we can use the same mask for both lanes. sz_u512_vec_t filter_even_vec, filter_odd_vec; __m256i filter_ymm = _mm256_lddqu_si256((__m256i const *)filter); // There are a few way to initialize filters without having native strided loads. 
// In the chronological order of experiments: // - serial code initializing 128 bytes of odd and even mask // - using several shuffles // - using `_mm512_permutexvar_epi8` // - using `_mm512_broadcast_i32x4(_mm256_castsi256_si128(_mm256_maskz_compress_epi8(0x55555555, filter_ymm)))` // and `_mm512_broadcast_i32x4(_mm256_castsi256_si128(_mm256_maskz_compress_epi8(0xaaaaaaaa, filter_ymm)))` filter_even_vec.zmm = _mm512_broadcast_i32x4(_mm256_castsi256_si128( // broadcast __m128i to __m512i _mm256_maskz_compress_epi8(0x55555555, filter_ymm))); filter_odd_vec.zmm = _mm512_broadcast_i32x4(_mm256_castsi256_si128( // broadcast __m128i to __m512i _mm256_maskz_compress_epi8(0xaaaaaaaa, filter_ymm))); // After the unzipping operation, we can validate the contents of the vectors like this: // // for (sz_size_t i = 0; i != 16; ++i) { // sz_assert(filter_even_vec.u8s[i] == filter->_u8s[i * 2]); // sz_assert(filter_odd_vec.u8s[i] == filter->_u8s[i * 2 + 1]); // sz_assert(filter_even_vec.u8s[i + 16] == filter->_u8s[i * 2]); // sz_assert(filter_odd_vec.u8s[i + 16] == filter->_u8s[i * 2 + 1]); // sz_assert(filter_even_vec.u8s[i + 32] == filter->_u8s[i * 2]); // sz_assert(filter_odd_vec.u8s[i + 32] == filter->_u8s[i * 2 + 1]); // sz_assert(filter_even_vec.u8s[i + 48] == filter->_u8s[i * 2]); // sz_assert(filter_odd_vec.u8s[i + 48] == filter->_u8s[i * 2 + 1]); // } // sz_u512_vec_t text_vec; sz_u512_vec_t lower_nibbles_vec, higher_nibbles_vec; sz_u512_vec_t bitset_even_vec, bitset_odd_vec; sz_u512_vec_t bitmask_vec, bitmask_lookup_vec; bitmask_lookup_vec.zmm = _mm512_set_epi8(-128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1, // -128, 64, 32, 16, 8, 4, 2, 1, -128, 64, 32, 16, 8, 4, 2, 1); while (length) { // The following algorithm is a transposed equivalent of the "SIMDized check which bytes are in a set" // solutions by Wojciech Muła. We populate the bitmask differently and target newer CPUs, so // StringZilla uses a somewhat different approach. // http://0x80.pl/articles/simd-byte-lookup.html#alternative-implementation-new // // sz_u8_t input = *(sz_u8_t const *)text; // sz_u8_t lo_nibble = input & 0x0f; // sz_u8_t hi_nibble = input >> 4; // sz_u8_t bitset_even = filter_even_vec.u8s[hi_nibble]; // sz_u8_t bitset_odd = filter_odd_vec.u8s[hi_nibble]; // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); // sz_u8_t bitset = lo_nibble < 8 ? bitset_even : bitset_odd; // if ((bitset & bitmask) != 0) return text; // else { length--, text++; } // // The nice part about this is that loading the strided data is very easy with Arm NEON, // while with x86 CPUs after AVX, shuffles within 256 bits shouldn't be an issue either. sz_size_t load_length = sz_min_of_two(length, 64); __mmask64 load_mask = _sz_u64_mask_until(load_length); text_vec.zmm = _mm512_maskz_loadu_epi8(load_mask, text); lower_nibbles_vec.zmm = _mm512_and_si512(text_vec.zmm, _mm512_set1_epi8(0x0f)); bitmask_vec.zmm = _mm512_shuffle_epi8(bitmask_lookup_vec.zmm, lower_nibbles_vec.zmm); // // At this point we can validate the `bitmask_vec` contents like this: // // for (sz_size_t i = 0; i != load_length; ++i) { // sz_u8_t input = *(sz_u8_t const *)(text + i); // sz_u8_t lo_nibble = input & 0x0f; // sz_u8_t bitmask = (1 << (lo_nibble & 0x7)); // sz_assert(bitmask_vec.u8s[i] == bitmask); // } // // Shift right every byte by 4 bits.
// There is no `_mm512_srli_epi8` intrinsic, so we have to use `_mm512_srli_epi16` // and combine it with a mask to clear the higher bits. higher_nibbles_vec.zmm = _mm512_and_si512(_mm512_srli_epi16(text_vec.zmm, 4), _mm512_set1_epi8(0x0f)); bitset_even_vec.zmm = _mm512_shuffle_epi8(filter_even_vec.zmm, higher_nibbles_vec.zmm); bitset_odd_vec.zmm = _mm512_shuffle_epi8(filter_odd_vec.zmm, higher_nibbles_vec.zmm); // // At this point we can validate the `bitset_even_vec` and `bitset_odd_vec` contents like this: // // for (sz_size_t i = 0; i != load_length; ++i) { // sz_u8_t input = *(sz_u8_t const *)(text + i); // sz_u8_t const *bitset_ptr = &filter->_u8s[0]; // sz_u8_t hi_nibble = input >> 4; // sz_u8_t bitset_even = bitset_ptr[hi_nibble * 2]; // sz_u8_t bitset_odd = bitset_ptr[hi_nibble * 2 + 1]; // sz_assert(bitset_even_vec.u8s[i] == bitset_even); // sz_assert(bitset_odd_vec.u8s[i] == bitset_odd); // } // // TODO: Is this a good place for ternary logic? __mmask64 take_first = _mm512_cmplt_epi8_mask(lower_nibbles_vec.zmm, _mm512_set1_epi8(8)); bitset_even_vec.zmm = _mm512_mask_blend_epi8(take_first, bitset_odd_vec.zmm, bitset_even_vec.zmm); __mmask64 matches_mask = _mm512_mask_test_epi8_mask(load_mask, bitset_even_vec.zmm, bitmask_vec.zmm); if (matches_mask) { int offset = sz_u64_ctz(matches_mask); return text + offset; } else { text += load_length, length -= load_length; } } return SZ_NULL_CHAR; } SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) { return sz_rfind_charset_serial(text, length, filter); } /** * Computes the Needleman-Wunsch alignment score between two strings. * The method uses 32-bit integers to accumulate the running score for every cell in the matrix. * Assuming the costs of substitutions can be arbitrary signed 8-bit integers, the method is expected to be used * on strings not exceeding 2^24 (~16.7 million) characters in length. * * Unlike the `_sz_edit_distance_skewed_diagonals_upto65k_avx512` method, this one uses signed integers to store * the accumulated score. Moreover, its primary bottleneck is the latency of gathering the substitution costs * from the substitution matrix. If we use the diagonal order, we will be comparing a slice of the first string with * a slice of the second. If we stick to the conventional horizontal order, we will be comparing one character against * a slice, which is much easier to optimize. In that case we are sampling costs not from arbitrary parts of * a 256 x 256 matrix, but from a single row! */ SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_avx512( // sz_cptr_t shorter, sz_size_t shorter_length, // sz_cptr_t longer, sz_size_t longer_length, // sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { // If one of the strings is empty - the score is the length of the other one multiplied by the gap cost if (longer_length == 0) return (sz_ssize_t)shorter_length * gap; if (shorter_length == 0) return (sz_ssize_t)longer_length * gap; // Let's make sure that we use the amount of memory proportional to the // number of elements in the shorter string, not the longer. if (shorter_length > longer_length) { sz_pointer_swap((void **)&longer_length, (void **)&shorter_length); sz_pointer_swap((void **)&longer, (void **)&shorter); } // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome.
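// For example, a hypothetical caller with two buffers `a` and `b` can simply pass a NULL allocator:
//
//      sz_ssize_t score = sz_alignment_score_avx512(a, a_length, b, b_length, subs, -1, NULL);
//
// and the default allocator, initialized below, would be used for the two rows of the DP matrix.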
sz_memory_allocator_t global_alloc; if (!alloc) { sz_memory_allocator_init_default(&global_alloc); alloc = &global_alloc; } sz_size_t const max_length = 256ull * 256ull * 256ull; sz_size_t const n = longer_length + 1; sz_assert(n < max_length && "The length must fit into a 24-bit integer. Otherwise use the serial variant."); sz_unused(longer_length && max_length); sz_size_t buffer_length = sizeof(sz_i32_t) * n * 2; sz_i32_t *distances = (sz_i32_t *)alloc->allocate(buffer_length, alloc->handle); sz_i32_t *previous_distances = distances; sz_i32_t *current_distances = previous_distances + n; // Initialize the first row of the Levenshtein matrix with `iota`. for (sz_size_t idx_longer = 0; idx_longer != n; ++idx_longer) previous_distances[idx_longer] = (sz_i32_t)idx_longer * gap; /// Contains up to 64 consecutive characters from the longer string. sz_u512_vec_t longer_vec; sz_u512_vec_t cost_deletion_vec, cost_substitution_vec, lookup_substitution_vec, current_vec; sz_u512_vec_t row_first_subs_vec, row_second_subs_vec, row_third_subs_vec, row_fourth_subs_vec; sz_u512_vec_t shuffled_first_subs_vec, shuffled_second_subs_vec, shuffled_third_subs_vec, shuffled_fourth_subs_vec; // Prepare constants and masks. sz_u512_vec_t is_third_or_fourth_vec, is_second_or_fourth_vec, gap_vec; { char is_third_or_fourth_check, is_second_or_fourth_check; *(sz_u8_t *)&is_third_or_fourth_check = 0x80, *(sz_u8_t *)&is_second_or_fourth_check = 0x40; is_third_or_fourth_vec.zmm = _mm512_set1_epi8(is_third_or_fourth_check); is_second_or_fourth_vec.zmm = _mm512_set1_epi8(is_second_or_fourth_check); gap_vec.zmm = _mm512_set1_epi32(gap); } sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter; for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) { sz_i32_t last_in_row = current_distances[0] = (sz_i32_t)(idx_shorter + 1) * gap; // Load one row of the substitution matrix into four ZMM registers. sz_error_cost_t const *row_subs = subs + shorter_unsigned[idx_shorter] * 256u; row_first_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 0); row_second_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 1); row_third_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 2); row_fourth_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 3); // In the serial version we have one forward pass, that computes the deletion, // insertion, and substitution costs at once. // for (sz_size_t idx_longer = 0; idx_longer < longer_length; ++idx_longer) { // sz_ssize_t cost_deletion = previous_distances[idx_longer + 1] + gap; // sz_ssize_t cost_insertion = current_distances[idx_longer] + gap; // sz_ssize_t cost_substitution = previous_distances[idx_longer] + row_subs[longer_unsigned[idx_longer]]; // current_distances[idx_longer + 1] = sz_min_of_three(cost_deletion, cost_insertion, cost_substitution); // } // // Given the complexity of handling the data-dependency between consecutive insertion cost computations // within a Levenshtein matrix, the simplest design would be to vectorize every kind of cost computation // separately. // 1. Compute substitution costs for up to 64 characters at once, upcasting from 8-bit integers to 32. // 2. Compute the pairwise minimum with deletion costs. // 3. Inclusive prefix minimum computation to combine with insertion costs (sketched below).
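// As a purely illustrative scalar sketch, those three steps collapse into two passes over one row
// (mirroring the `max`-based updates used below, and reusing the hypothetical `longer_unsigned` alias
// from the comment above):
//
//      for (sz_size_t j = 0; j != longer_length; ++j) // substitutions & deletions, no loop dependency
//          current_distances[j + 1] = sz_max_of_two(previous_distances[j] + row_subs[longer_unsigned[j]],
//                                                   previous_distances[j + 1] + gap);
//      for (sz_size_t j = 0; j != longer_length; ++j) // running insertions, inherently sequential
//          current_distances[j + 1] = sz_max_of_two(current_distances[j + 1], current_distances[j] + gap);
//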
// Proceeding with substitutions: for (sz_size_t idx_longer = 0; idx_longer < longer_length; idx_longer += 64) { sz_size_t register_length = sz_min_of_two(longer_length - idx_longer, 64); __mmask64 mask = _sz_u64_mask_until(register_length); longer_vec.zmm = _mm512_maskz_loadu_epi8(mask, longer + idx_longer); // Blend the `row_(first|second|third|fourth)_subs_vec` into `current_vec`, picking the right source // for every character in `longer_vec`. Before that, we need to permute the substitution vectors. // Only the bottom 6 bits of a byte are used in VPERMB, so we don't even need to mask. shuffled_first_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_first_subs_vec.zmm); shuffled_second_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_second_subs_vec.zmm); shuffled_third_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_third_subs_vec.zmm); shuffled_fourth_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_fourth_subs_vec.zmm); // To blend we can invoke three `_mm512_cmplt_epu8_mask`, but we can also achieve the same using // the AND logical operation, checking the top two bits of every byte. // Continuing this thought, we can use the VPTESTMB instruction to output the mask after the AND. __mmask64 is_third_or_fourth = _mm512_mask_test_epi8_mask(mask, longer_vec.zmm, is_third_or_fourth_vec.zmm); __mmask64 is_second_or_fourth = _mm512_mask_test_epi8_mask(mask, longer_vec.zmm, is_second_or_fourth_vec.zmm); lookup_substitution_vec.zmm = _mm512_mask_blend_epi8( is_third_or_fourth, // Choose between the first and the second. _mm512_mask_blend_epi8(is_second_or_fourth, shuffled_first_subs_vec.zmm, shuffled_second_subs_vec.zmm), // Choose between the third and the fourth. _mm512_mask_blend_epi8(is_second_or_fourth, shuffled_third_subs_vec.zmm, shuffled_fourth_subs_vec.zmm)); // First, sign-extend the lower and upper 32 bytes to 16-bit integers. __m512i current_0_31_vec = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(lookup_substitution_vec.zmm, 0)); __m512i current_32_63_vec = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(lookup_substitution_vec.zmm, 1)); // Now extend those 16-bit integers to 32-bit. // This isn't free, same as the subsequent store, so we only want to do that for the populated lanes. // To minimize the number of loads and stores, we can combine our substitution costs with the previous // distances, containing the deletion costs. { cost_substitution_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer); cost_substitution_vec.zmm = _mm512_add_epi32( cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_0_31_vec, 0))); cost_deletion_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer); cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm); current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm); // Inclusive prefix minimum computation to combine with insertion costs. // Simply disabling this operation results in 5x performance improvement, meaning // that this operation is responsible for 80% of the total runtime.
// for (sz_size_t idx_longer = 0; idx_longer < longer_length; ++idx_longer) { // current_distances[idx_longer + 1] = // sz_max_of_two(current_distances[idx_longer] + gap, current_distances[idx_longer + 1]); // } // // To perform the same operation in vectorized form, we need to perform a tree-like reduction, // that will involve multiple steps. It's quite expensive and should be first tested in the // "experimental" section. // // Another approach might be loop unrolling: // current_vec.i32s[0] = last_in_row = sz_i32_max_of_two(current_vec.i32s[0], last_in_row + gap); // current_vec.i32s[1] = last_in_row = sz_i32_max_of_two(current_vec.i32s[1], last_in_row + gap); // current_vec.i32s[2] = last_in_row = sz_i32_max_of_two(current_vec.i32s[2], last_in_row + gap); // ... yet this approach is also quite expensive. for (int i = 0; i != 16; ++i) current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap); _mm512_mask_storeu_epi32(current_distances + idx_longer + 1, (__mmask16)mask, current_vec.zmm); } // Export the values from 16 to 31. if (register_length > 16) { mask = _kshiftri_mask64(mask, 16); cost_substitution_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 16); cost_substitution_vec.zmm = _mm512_add_epi32( cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_0_31_vec, 1))); cost_deletion_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 16); cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm); current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm); // Aggregate running insertion costs within the register. for (int i = 0; i != 16; ++i) current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap); _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 16, (__mmask16)mask, current_vec.zmm); } // Export the values from 32 to 47. if (register_length > 32) { mask = _kshiftri_mask64(mask, 16); cost_substitution_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 32); cost_substitution_vec.zmm = _mm512_add_epi32( cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_32_63_vec, 0))); cost_deletion_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 32); cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm); current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm); // Aggregate running insertion costs within the register. for (int i = 0; i != 16; ++i) current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap); _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 32, (__mmask16)mask, current_vec.zmm); } // Export the values from 48 to 63. if (register_length > 48) { mask = _kshiftri_mask64(mask, 16); cost_substitution_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 48); cost_substitution_vec.zmm = _mm512_add_epi32( cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_32_63_vec, 1))); cost_deletion_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 48); cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm); current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm); // Aggregate running insertion costs within the register.
for (int i = 0; i != 16; ++i) current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap); _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 48, (__mmask16)mask, current_vec.zmm); } } // Swap previous_distances and current_distances pointers sz_pointer_swap((void **)&previous_distances, (void **)&current_distances); } // Cache scalar before `free` call. sz_ssize_t result = previous_distances[longer_length]; alloc->free(distances, buffer_length, alloc->handle); return result; } SZ_INTERNAL sz_ssize_t sz_alignment_score_avx512( // sz_cptr_t shorter, sz_size_t shorter_length, // sz_cptr_t longer, sz_size_t longer_length, // sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { if (sz_max_of_two(shorter_length, longer_length) < (256ull * 256ull * 256ull)) return _sz_alignment_score_wagner_fisher_upto17m_avx512(shorter, shorter_length, longer, longer_length, subs, gap, alloc); else return sz_alignment_score_serial(shorter, shorter_length, longer, longer_length, subs, gap, alloc); } enum sz_encoding_t { sz_encoding_unknown_k = 0, sz_encoding_ascii_k = 1, sz_encoding_utf8_k = 2, sz_encoding_utf16_k = 3, sz_encoding_utf32_k = 4, sz_jwt_k, sz_base64_k, // Low priority encodings: sz_encoding_utf8bom_k = 5, sz_encoding_utf16le_k = 6, sz_encoding_utf16be_k = 7, sz_encoding_utf32le_k = 8, sz_encoding_utf32be_k = 9, }; // Character Set Detection is one of the most commonly performed operations in data processing with // [Chardet](https://github.com/chardet/chardet), [Charset Normalizer](https://github.com/jawah/charset_normalizer), // [cChardet](https://github.com/PyYoshi/cChardet) being the most commonly used options in the Python ecosystem. // All of them are notoriously slow. // // Moreover, as of October 2024, UTF-8 is the dominant character encoding on the web, used by 98.4% of websites. // Others have minimal usage, according to [W3Techs](https://w3techs.com/technologies/overview/character_encoding): // - ISO-8859-1: 1.2% // - Windows-1252: 0.3% // - Windows-1251: 0.2% // - EUC-JP: 0.1% // - Shift JIS: 0.1% // - EUC-KR: 0.1% // - GB2312: 0.1% // - Windows-1250: 0.1% // Within programming language implementations and database management systems, 16-bit and 32-bit fixed-width encodings // are also very popular and we need a way to efficiently differentiate between the most common UTF flavors, ASCII, and // the rest. // // One good solution is the [simdutf](https://github.com/simdutf/simdutf) library, but it depends on the C++ runtime // and focuses more on incremental validation & transcoding, rather than detection. // // So we need a very fast and efficient way of determining the encoding ourselves. SZ_PUBLIC sz_bool_t sz_detect_encoding(sz_cptr_t text, sz_size_t length) { // https://github.com/simdutf/simdutf/blob/master/src/icelake/icelake_utf8_validation.inl.cpp // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_from_utf8.inl.cpp#L81 // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_utf8_common.inl.cpp#L661 // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_utf8_common.inl.cpp#L788 // We can implement this operation in a simpler & different way, assuming most of the time continuous chunks of memory // have identical encoding. With Russian and many European languages, we generally deal with 2-byte codepoints // with occasional 1-byte punctuation marks.
In the case of Chinese, Japanese, and Korean, we deal with 3-byte // codepoints. In the case of emojis, we deal with 4-byte codepoints. // We can also use the idea, that misaligned reads are quite cheap on modern CPUs. int can_be_ascii = 1, can_be_utf8 = 1, can_be_utf16 = 1, can_be_utf32 = 1; sz_unused(can_be_ascii + can_be_utf8 + can_be_utf16 + can_be_utf32); sz_unused(text && length); return sz_false_k; } #pragma clang attribute pop #pragma GCC pop_options #endif #pragma endregion /* @brief Implementation of the string search algorithms using the Arm NEON instruction set, available on 64-bit * Arm processors. Implements: {substring search, character search, character set search} x {forward, reverse}. */ #pragma region ARM NEON #if SZ_USE_ARM_NEON #pragma GCC push_options #pragma GCC target("arch=armv8.2-a+simd") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function) /** * @brief Helper structure to simplify work with 64-bit words. */ typedef union sz_u128_vec_t { uint8x16_t u8x16; uint16x8_t u16x8; uint32x4_t u32x4; uint64x2_t u64x2; sz_u64_t u64s[2]; sz_u32_t u32s[4]; sz_u16_t u16s[8]; sz_u8_t u8s[16]; } sz_u128_vec_t; SZ_INTERNAL sz_u64_t _sz_vreinterpretq_u8_u4(uint8x16_t vec) { // Use `vshrn` to produce a bitmask, similar to `movemask` in SSE. // https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon return vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(vec), 4)), 0) & 0x8888888888888888ull; } SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { //! Before optimizing this, read the "Operations Not Worth Optimizing" in Contributions Guide: //! https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations return sz_order_serial(a, a_length, b, b_length); } SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { sz_u128_vec_t a_vec, b_vec; for (; length >= 16; a += 16, b += 16, length -= 16) { a_vec.u8x16 = vld1q_u8((sz_u8_t const *)a); b_vec.u8x16 = vld1q_u8((sz_u8_t const *)b); uint8x16_t cmp = vceqq_u8(a_vec.u8x16, b_vec.u8x16); if (vminvq_u8(cmp) != 255) { return sz_false_k; } // Check if all bytes match } // Handle remaining bytes if (length) return sz_equal_serial(a, b, length); return sz_true_k; } SZ_PUBLIC sz_u64_t sz_checksum_neon(sz_cptr_t text, sz_size_t length) { uint64x2_t sum_vec = vdupq_n_u64(0); // Process 16 bytes (128 bits) at a time for (; length >= 16; text += 16, length -= 16) { uint8x16_t vec = vld1q_u8((sz_u8_t const *)text); // Load 16 bytes uint16x8_t pairwise_sum1 = vpaddlq_u8(vec); // Pairwise add lower and upper 8 bits uint32x4_t pairwise_sum2 = vpaddlq_u16(pairwise_sum1); // Pairwise add 16-bit results uint64x2_t pairwise_sum3 = vpaddlq_u32(pairwise_sum2); // Pairwise add 32-bit results sum_vec = vaddq_u64(sum_vec, pairwise_sum3); // Accumulate the sum } // Final reduction of `sum_vec` to a single scalar sz_u64_t sum = vgetq_lane_u64(sum_vec, 0) + vgetq_lane_u64(sum_vec, 1); if (length) sum += sz_checksum_serial(text, length); return sum; } SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { // In most cases the `source` and the `target` are not aligned, but we should // at least make sure that writes don't touch many cache lines. // NEON has an instruction to load and write 64 bytes at once. 
// // sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less. // sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less. // for (; head_length; target += 1, source += 1, head_length -= 1) *target = *source; // length -= head_length; // for (; length >= 64; target += 64, source += 64, length -= 64) // vst4q_u8((sz_u8_t *)target, vld1q_u8_x4((sz_u8_t const *)source)); // for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = *source; // // Sadly, those instructions end up being 20% slower than the code processing 16 bytes at a time: for (; length >= 16; target += 16, source += 16, length -= 16) vst1q_u8((sz_u8_t *)target, vld1q_u8((sz_u8_t const *)source)); if (length) sz_copy_serial(target, source, length); } SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { // When moving small buffers, using a small buffer on stack as a temporary storage is faster. if (target < source || target >= source + length) { // Non-overlapping, proceed forward sz_copy_neon(target, source, length); } else { // Overlapping, proceed backward target += length; source += length; sz_u128_vec_t src_vec; while (length >= 16) { target -= 16, source -= 16, length -= 16; src_vec.u8x16 = vld1q_u8((sz_u8_t const *)source); vst1q_u8((sz_u8_t *)target, src_vec.u8x16); } while (length) { target -= 1, source -= 1, length -= 1; *target = *source; } } } SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value) { uint8x16_t fill_vec = vdupq_n_u8(value); // Broadcast the value across the register while (length >= 16) { vst1q_u8((sz_u8_t *)target, fill_vec); target += 16; length -= 16; } // Handle remaining bytes if (length) sz_fill_serial(target, length, value); } SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) { // If the input is tiny (especially smaller than the look-up table itself), we may end up paying // more for organizing the SIMD registers and changing the CPU state, than for the actual computation. if (length <= 128) { sz_look_up_transform_serial(source, length, lut, target); return; } sz_size_t head_length = (16 - ((sz_size_t)target % 16)) % 16; // 15 or less. sz_size_t tail_length = (sz_size_t)(target + length) % 16; // 15 or less. // We need to pull the lookup table into 16x NEON registers. We have a total of 32 such registers. // According to the Neoverse V2 manual, the 4-table lookup has a latency of 6 cycles, and 4x throughput. uint8x16x4_t lut_0_to_63_vec, lut_64_to_127_vec, lut_128_to_191_vec, lut_192_to_255_vec; lut_0_to_63_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 0)); lut_64_to_127_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 64)); lut_128_to_191_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 128)); lut_192_to_255_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 192)); sz_u128_vec_t source_vec; // If the top bit is set in each word of `source_vec`, than we use `lookup_128_to_191_vec` or // `lookup_192_to_255_vec`. If the second bit is set, we use `lookup_64_to_127_vec` or `lookup_192_to_255_vec`. sz_u128_vec_t lookup_0_to_63_vec, lookup_64_to_127_vec, lookup_128_to_191_vec, lookup_192_to_255_vec; sz_u128_vec_t blended_0_to_255_vec; // Process the head with serial code for (; head_length; target += 1, source += 1, head_length -= 1) *target = lut[*(sz_u8_t const *)source]; // Table lookups on Arm are much simpler to use than on x86, as we can use the `vqtbl4q_u8` instruction // to perform a 4-table lookup in a single instruction. 
The XORs are used to adjust the lookup position // within each 64-byte range of the table. // Details on the 4-table lookup: https://lemire.me/blog/2019/07/23/arbitrary-byte-to-byte-maps-using-arm-neon/ length -= head_length; length -= tail_length; for (; length >= 16; source += 16, target += 16, length -= 16) { source_vec.u8x16 = vld1q_u8((sz_u8_t const *)source); lookup_0_to_63_vec.u8x16 = vqtbl4q_u8(lut_0_to_63_vec, source_vec.u8x16); lookup_64_to_127_vec.u8x16 = vqtbl4q_u8(lut_64_to_127_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0x40))); lookup_128_to_191_vec.u8x16 = vqtbl4q_u8(lut_128_to_191_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0x80))); lookup_192_to_255_vec.u8x16 = vqtbl4q_u8(lut_192_to_255_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0xc0))); blended_0_to_255_vec.u8x16 = vorrq_u8(vorrq_u8(lookup_0_to_63_vec.u8x16, lookup_64_to_127_vec.u8x16), vorrq_u8(lookup_128_to_191_vec.u8x16, lookup_192_to_255_vec.u8x16)); vst1q_u8((sz_u8_t *)target, blended_0_to_255_vec.u8x16); } // Process the tail with serial code for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = lut[*(sz_u8_t const *)source]; } SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { sz_u64_t matches; sz_u128_vec_t h_vec, n_vec, matches_vec; n_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n); while (h_length >= 16) { h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h); matches_vec.u8x16 = vceqq_u8(h_vec.u8x16, n_vec.u8x16); // In Arm NEON we don't have a `movemask` to combine it with `ctz` and get the offset of the match. // But assuming the `vmaxvq` is cheap, we can use it to find the first match, by blending (bitwise selecting) // the vector with a relative offsets array. matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); if (matches) return h + sz_u64_ctz(matches) / 4; h += 16, h_length -= 16; } return sz_find_byte_serial(h, h_length, n); } SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { sz_u64_t matches; sz_u128_vec_t h_vec, n_vec, matches_vec; n_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n); while (h_length >= 16) { h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h + h_length - 16); matches_vec.u8x16 = vceqq_u8(h_vec.u8x16, n_vec.u8x16); matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); if (matches) return h + h_length - 1 - sz_u64_clz(matches) / 4; h_length -= 16; } return sz_rfind_byte_serial(h, h_length, n); } SZ_PUBLIC sz_u64_t _sz_find_charset_neon_register(sz_u128_vec_t h_vec, uint8x16_t set_top_vec_u8x16, uint8x16_t set_bottom_vec_u8x16) { // Once we've read the characters in the haystack, we want to // compare them against our bitset. The serial version of that code // would look like: `(set_->_u8s[c >> 3] & (1u << (c & 7u))) != 0`. uint8x16_t byte_index_vec = vshrq_n_u8(h_vec.u8x16, 3); uint8x16_t byte_mask_vec = vshlq_u8(vdupq_n_u8(1), vreinterpretq_s8_u8(vandq_u8(h_vec.u8x16, vdupq_n_u8(7)))); uint8x16_t matches_top_vec = vqtbl1q_u8(set_top_vec_u8x16, byte_index_vec); // The table lookup instruction in NEON replies to out-of-bound requests with zeros. // The values in `byte_index_vec` all fall in [0; 32). So for values under 16, substracting 16 will underflow // and map into interval [240, 256). Meaning that those will be populated with zeros and we can safely // merge `matches_top_vec` and `matches_bottom_vec` with a bitwise OR. 
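// As a concrete illustration: for the byte 0x41 ('A') the index is 0x41 >> 3 = 8, so the top-half lookup
// reads entry 8, while 8 - 16 wraps around to 248 and the bottom-half lookup yields zero; for the byte
// 0xC1 the index is 24, the top-half lookup is out-of-bounds and yields zero, while the bottom-half
// lookup reads entry 24 - 16 = 8.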
uint8x16_t matches_bottom_vec = vqtbl1q_u8(set_bottom_vec_u8x16, vsubq_u8(byte_index_vec, vdupq_n_u8(16))); uint8x16_t matches_vec = vorrq_u8(matches_top_vec, matches_bottom_vec); // Instead of pure `vandq_u8`, we can immediately broadcast a match presence across each 8-bit word. matches_vec = vtstq_u8(matches_vec, byte_mask_vec); return _sz_vreinterpretq_u8_u4(matches_vec); } SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { // This almost never fires, but it's better to be safe than sorry. if (h_length < n_length || !n_length) return SZ_NULL_CHAR; if (n_length == 1) return sz_find_byte_neon(h, h_length, n); // Scan through the string. // Given how tiny the Arm NEON registers are, we should avoid internal branches at all costs. // That's why, for smaller needles, we use different loops. if (n_length == 2) { // Broadcast needle characters into SIMD registers. sz_u64_t matches; sz_u128_vec_t h_first_vec, h_last_vec, n_first_vec, n_last_vec, matches_vec; // Dealing with 16-bit values, we can load 2 registers at a time and compare 16 possible offsets // in a single loop iteration. n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[0]); n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[1]); for (; h_length >= 17; h += 16, h_length -= 16) { h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 0)); h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 1)); matches_vec.u8x16 = vandq_u8(vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); if (matches) return h + sz_u64_ctz(matches) / 4; } } else if (n_length == 3) { // Broadcast needle characters into SIMD registers. sz_u64_t matches; sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; // Comparing 24-bit values is a bummer. Being lazy, I went with the same approach // as when searching for strings over 4 characters long. I only avoid the last comparison. n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[0]); n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[1]); n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[2]); for (; h_length >= 18; h += 16, h_length -= 16) { h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 0)); h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 1)); h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 2)); matches_vec.u8x16 = vandq_u8( // vandq_u8( // vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); if (matches) return h + sz_u64_ctz(matches) / 4; } } else { // Pick the parts of the needle that are worth comparing. sz_size_t offset_first, offset_mid, offset_last; _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); // Broadcast those characters into SIMD registers. sz_u64_t matches; sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_first]); n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_mid]); n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_last]); // Walk through the string.
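// Every position where all three probed characters match is only a candidate:
// the loop below confirms it with a full `sz_equal` comparison before returning.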
for (; h_length >= n_length + 16; h += 16, h_length -= 16) { h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_first)); h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_mid)); h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_last)); matches_vec.u8x16 = vandq_u8( // vandq_u8( // vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); while (matches) { int potential_offset = sz_u64_ctz(matches) / 4; if (sz_equal(h + potential_offset, n, n_length)) return h + potential_offset; matches &= matches - 1; } } } return sz_find_serial(h, h_length, n, n_length); } SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { // This almost never fires, but it's better to be safe than sorry. if (h_length < n_length || !n_length) return SZ_NULL_CHAR; if (n_length == 1) return sz_rfind_byte_neon(h, h_length, n); // Pick the parts of the needle that are worth comparing. sz_size_t offset_first, offset_mid, offset_last; _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); // Will contain 4 bits per character. sz_u64_t matches; sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_first]); n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_mid]); n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_last]); sz_cptr_t h_reversed; for (; h_length >= n_length + 16; h_length -= 16) { h_reversed = h + h_length - n_length - 16 + 1; h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_first)); h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_mid)); h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_last)); matches_vec.u8x16 = vandq_u8( // vandq_u8( // vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); matches = _sz_vreinterpretq_u8_u4(matches_vec.u8x16); while (matches) { int potential_offset = sz_u64_clz(matches) / 4; if (sz_equal(h + h_length - n_length - potential_offset, n, n_length)) return h + h_length - n_length - potential_offset; sz_assert((matches & (1ull << (63 - potential_offset * 4))) != 0 && "The bit must be set before we squash it"); matches &= ~(1ull << (63 - potential_offset * 4)); } } return sz_rfind_serial(h, h_length, n, n_length); } SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) { sz_u64_t matches; sz_u128_vec_t h_vec; uint8x16_t set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]); uint8x16_t set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]); for (; h_length >= 16; h += 16, h_length -= 16) { h_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h)); matches = _sz_find_charset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); if (matches) return h + sz_u64_ctz(matches) / 4; } return sz_find_charset_serial(h, h_length, set); } SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) { sz_u64_t matches; sz_u128_vec_t h_vec; uint8x16_t set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]); uint8x16_t set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]); // Check `sz_find_charset_neon` for explanations. 
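// The only difference is that we slide from the back of the haystack and use the leading-zero count
// of the 4-bit-per-byte match mask to locate the last matching byte within each 16-byte block.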
for (; h_length >= 16; h_length -= 16) { h_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h) + h_length - 16); matches = _sz_find_charset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); if (matches) return h + h_length - 1 - sz_u64_clz(matches) / 4; } return sz_rfind_charset_serial(h, h_length, set); } #pragma clang attribute pop #pragma GCC pop_options #endif // Arm Neon #pragma endregion /* @brief Implementation of the string search algorithms using the Arm SVE variable-length registers, available * in Arm v9 processors. * * Implements: * - memory: {copy, move, fill} * - comparisons: {equal, order} * - search: {substring, character, character set} x {forward, reverse}. */ #pragma region ARM SVE #if SZ_USE_ARM_SVE #pragma GCC push_options #pragma GCC target("arch=armv8.2-a+sve") #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function) SZ_PUBLIC void sz_fill_sve(sz_ptr_t target, sz_size_t length, sz_u8_t value) { svuint8_t value_vec = svdup_u8(value); sz_size_t vec_len = svcntb(); // Vector length in bytes (scalable) if (length <= vec_len) { // Small buffer case: use mask to handle small writes svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)length); svst1_u8(mask, (unsigned char *)target, value_vec); } else { // Calculate head, body, and tail sizes sz_size_t head_length = vec_len - ((sz_size_t)target % vec_len); sz_size_t tail_length = (sz_size_t)(target + length) % vec_len; sz_size_t body_length = length - head_length - tail_length; // Handle unaligned head svbool_t head_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)head_length); svst1_u8(head_mask, (unsigned char *)target, value_vec); target += head_length; // Aligned body loop for (; body_length >= vec_len; target += vec_len, body_length -= vec_len) { svst1_u8(svptrue_b8(), (unsigned char *)target, value_vec); } // Handle unaligned tail svbool_t tail_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)tail_length); svst1_u8(tail_mask, (unsigned char *)target, value_vec); } } SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { sz_size_t vec_len = svcntb(); // Vector length in bytes // Arm Neoverse V2 cores in Graviton 4, for example, come with 256 KB of L1 data cache per core, // and 8 MB of L2 cache per core. Moreover, the L1 cache is fully associative. // With two strings, we may consider the overall workload huge, if each exceeds 1 MB in length. // // int is_huge = length >= 4ull * 1024ull * 1024ull; // // When the buffer is small, there isn't much to innovate. if (length <= vec_len) { // Small buffer case: use mask to handle small writes svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)length); svuint8_t data = svld1_u8(mask, (unsigned char *)source); svst1_u8(mask, (unsigned char *)target, data); } // When dealing with larger buffers, similar to AVX-512, we want to minimize unaligned operations // and handle the head, body, and tail separately. We can also traverse the buffer in both directions // as Arm generally supports more simultaneous stores than x86 CPUs. // // For gigantic datasets, similar to AVX-512, non-temporal "loads" and "stores" can be used. // Sadly, if the register size (16 bytes or larger) is smaller than a cache-line (64 bytes) // we will pay a huge penalty on loads, fetching the same content many times. // It may be better to allow caching (and subsequent eviction), in favor of using four-element // tuples, which will be guaranteed to be a multiple of a cache line.
// // Another approach is to use the `LD4B` instructions, which will populate four registers at once. // This however, further decreases the performance from LibC-like 29 GB/s to 20 GB/s. else { // Calculating head, body, and tail sizes depends on the `vec_len`, // but it's runtime constant, and the modulo operation is expensive! // Instead we use the fact, that it's always a multiple of 128 bits or 16 bytes. sz_size_t head_length = 16 - ((sz_size_t)target % 16); sz_size_t tail_length = (sz_size_t)(target + length) % 16; sz_size_t body_length = length - head_length - tail_length; // Handle unaligned parts svbool_t head_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)head_length); svuint8_t head_data = svld1_u8(head_mask, (unsigned char *)source); svst1_u8(head_mask, (unsigned char *)target, head_data); svbool_t tail_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)tail_length); svuint8_t tail_data = svld1_u8(tail_mask, (unsigned char *)source + head_length + body_length); svst1_u8(tail_mask, (unsigned char *)target + head_length + body_length, tail_data); target += head_length; source += head_length; // Aligned body loop, walking in two directions for (; body_length >= vec_len * 2; target += vec_len, source += vec_len, body_length -= vec_len * 2) { svuint8_t forward_data = svld1_u8(svptrue_b8(), (unsigned char *)source); svuint8_t backward_data = svld1_u8(svptrue_b8(), (unsigned char *)source + body_length - vec_len); svst1_u8(svptrue_b8(), (unsigned char *)target, forward_data); svst1_u8(svptrue_b8(), (unsigned char *)target + body_length - vec_len, backward_data); } // Up to (vec_len * 2 - 1) bytes of data may be left in the body, // so we can unroll the last two optional loop iterations. if (body_length > vec_len) { svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)body_length); svuint8_t data = svld1_u8(mask, (unsigned char *)source); svst1_u8(mask, (unsigned char *)target, data); body_length -= vec_len; source += body_length; target += body_length; } if (body_length) { svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)body_length); svuint8_t data = svld1_u8(mask, (unsigned char *)source); svst1_u8(mask, (unsigned char *)target, data); } } } #pragma clang attribute pop #pragma GCC pop_options #endif // Arm SVE #pragma endregion /* * @brief Pick the right implementation for the string search algorithms. */ #pragma region Compile - Time Dispatching SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t ins, sz_size_t length) { return sz_hash_serial(ins, length); } SZ_PUBLIC void sz_tolower(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_tolower_serial(ins, length, outs); } SZ_PUBLIC void sz_toupper(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_toupper_serial(ins, length, outs); } SZ_PUBLIC void sz_toascii(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_toascii_serial(ins, length, outs); } SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t ins, sz_size_t length) { return sz_isascii_serial(ins, length); } SZ_PUBLIC void sz_hashes_fingerprint(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_ptr_t fingerprint, sz_size_t fingerprint_bytes) { sz_bool_t fingerprint_length_is_power_of_two = (sz_bool_t)((fingerprint_bytes & (fingerprint_bytes - 1)) == 0); sz_string_view_t fingerprint_buffer = {fingerprint, fingerprint_bytes}; // There are several issues related to the fingerprinting algorithm. // First, the memory traversal order is important. // https://blog.stuffedcow.net/2015/08/pagewalk-coherence/ // In most cases the fingerprint length will be a power of two. 
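// When it is, the power-of-two callback can presumably pick the target bit of the fingerprint with a
// cheap bitwise AND instead of an integer modulo - hence the two separate callbacks below.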
    if (fingerprint_length_is_power_of_two == sz_false_k)
        sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_non_pow2_callback, &fingerprint_buffer);
    else
        sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_pow2_callback, &fingerprint_buffer);
}

#if !SZ_DYNAMIC_DISPATCH

SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length) {
#if SZ_USE_X86_AVX512
    return sz_checksum_avx512(text, length);
#elif SZ_USE_X86_AVX2
    return sz_checksum_avx2(text, length);
#elif SZ_USE_ARM_NEON
    return sz_checksum_neon(text, length);
#else
    return sz_checksum_serial(text, length);
#endif
}

SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length) {
#if SZ_USE_X86_AVX512
    return sz_equal_avx512(a, b, length);
#elif SZ_USE_X86_AVX2
    return sz_equal_avx2(a, b, length);
#elif SZ_USE_ARM_NEON
    return sz_equal_neon(a, b, length);
#else
    return sz_equal_serial(a, b, length);
#endif
}

SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) {
#if SZ_USE_X86_AVX512
    return sz_order_avx512(a, a_length, b, b_length);
#elif SZ_USE_X86_AVX2
    return sz_order_avx2(a, a_length, b, b_length);
#elif SZ_USE_ARM_NEON
    return sz_order_neon(a, a_length, b, b_length);
#else
    return sz_order_serial(a, a_length, b, b_length);
#endif
}

SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
#if SZ_USE_X86_AVX512
    sz_copy_avx512(target, source, length);
#elif SZ_USE_X86_AVX2
    sz_copy_avx2(target, source, length);
#elif SZ_USE_ARM_NEON
    sz_copy_neon(target, source, length);
#else
    sz_copy_serial(target, source, length);
#endif
}

SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
#if SZ_USE_X86_AVX512
    sz_move_avx512(target, source, length);
#elif SZ_USE_X86_AVX2
    sz_move_avx2(target, source, length);
#elif SZ_USE_ARM_NEON
    sz_move_neon(target, source, length);
#else
    sz_move_serial(target, source, length);
#endif
}

SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value) {
#if SZ_USE_X86_AVX512
    sz_fill_avx512(target, length, value);
#elif SZ_USE_X86_AVX2
    sz_fill_avx2(target, length, value);
#elif SZ_USE_ARM_NEON
    sz_fill_neon(target, length, value);
#else
    sz_fill_serial(target, length, value);
#endif
}

SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) {
#if SZ_USE_X86_AVX512
    sz_look_up_transform_avx512(source, length, lut, target);
#elif SZ_USE_X86_AVX2
    sz_look_up_transform_avx2(source, length, lut, target);
#elif SZ_USE_ARM_NEON
    sz_look_up_transform_neon(source, length, lut, target);
#else
    sz_look_up_transform_serial(source, length, lut, target);
#endif
}

SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) {
#if SZ_USE_X86_AVX512
    return sz_find_byte_avx512(haystack, h_length, needle);
#elif SZ_USE_X86_AVX2
    return sz_find_byte_avx2(haystack, h_length, needle);
#elif SZ_USE_ARM_NEON
    return sz_find_byte_neon(haystack, h_length, needle);
#else
    return sz_find_byte_serial(haystack, h_length, needle);
#endif
}

SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) {
#if SZ_USE_X86_AVX512
    return sz_rfind_byte_avx512(haystack, h_length, needle);
#elif SZ_USE_X86_AVX2
    return sz_rfind_byte_avx2(haystack, h_length, needle);
#elif SZ_USE_ARM_NEON
    return sz_rfind_byte_neon(haystack, h_length, needle);
#else
    return sz_rfind_byte_serial(haystack, h_length, needle);
#endif
}
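/*  Every wrapper in this region resolves to exactly one backend at preprocessing time.
 *  A minimal caller sketch (the buffer below is illustrative, not part of the library):
 *
 *      char const haystack[] = "stringzilla";
 *      char const needle = 'l';
 *      sz_cptr_t match = sz_find_byte(haystack, sizeof(haystack) - 1, &needle);
 *      // `match` points at the first 'l'; if the byte is absent, a null pointer is returned.
 */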
SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) {
#if SZ_USE_X86_AVX512
    return sz_find_avx512(haystack, h_length, needle, n_length);
#elif SZ_USE_X86_AVX2
    return sz_find_avx2(haystack, h_length, needle, n_length);
#elif SZ_USE_ARM_NEON
    return sz_find_neon(haystack, h_length, needle, n_length);
#else
    return sz_find_serial(haystack, h_length, needle, n_length);
#endif
}

SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) {
#if SZ_USE_X86_AVX512
    return sz_rfind_avx512(haystack, h_length, needle, n_length);
#elif SZ_USE_X86_AVX2
    return sz_rfind_avx2(haystack, h_length, needle, n_length);
#elif SZ_USE_ARM_NEON
    return sz_rfind_neon(haystack, h_length, needle, n_length);
#else
    return sz_rfind_serial(haystack, h_length, needle, n_length);
#endif
}

SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) {
#if SZ_USE_X86_AVX512
    return sz_find_charset_avx512(text, length, set);
#elif SZ_USE_X86_AVX2
    return sz_find_charset_avx2(text, length, set);
#elif SZ_USE_ARM_NEON
    return sz_find_charset_neon(text, length, set);
#else
    return sz_find_charset_serial(text, length, set);
#endif
}

SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) {
#if SZ_USE_X86_AVX512
    return sz_rfind_charset_avx512(text, length, set);
#elif SZ_USE_X86_AVX2
    return sz_rfind_charset_avx2(text, length, set);
#elif SZ_USE_ARM_NEON
    return sz_rfind_charset_neon(text, length, set);
#else
    return sz_rfind_charset_serial(text, length, set);
#endif
}

SZ_DYNAMIC sz_size_t sz_hamming_distance( //
    sz_cptr_t a, sz_size_t a_length,      //
    sz_cptr_t b, sz_size_t b_length,      //
    sz_size_t bound) {
    return sz_hamming_distance_serial(a, a_length, b, b_length, bound);
}

SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8( //
    sz_cptr_t a, sz_size_t a_length,           //
    sz_cptr_t b, sz_size_t b_length,           //
    sz_size_t bound) {
    return sz_hamming_distance_utf8_serial(a, a_length, b, b_length, bound);
}

SZ_DYNAMIC sz_size_t sz_edit_distance( //
    sz_cptr_t a, sz_size_t a_length,   //
    sz_cptr_t b, sz_size_t b_length,   //
    sz_size_t bound, sz_memory_allocator_t *alloc) {
#if SZ_USE_X86_AVX512
    return sz_edit_distance_avx512(a, a_length, b, b_length, bound, alloc);
#else
    return sz_edit_distance_serial(a, a_length, b, b_length, bound, alloc);
#endif
}

SZ_DYNAMIC sz_size_t sz_edit_distance_utf8( //
    sz_cptr_t a, sz_size_t a_length,        //
    sz_cptr_t b, sz_size_t b_length,        //
    sz_size_t bound, sz_memory_allocator_t *alloc) {
    return _sz_edit_distance_wagner_fisher_serial(a, a_length, b, b_length, bound, sz_true_k, alloc);
}

SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length,
                                         sz_error_cost_t const *subs, sz_error_cost_t gap,
                                         sz_memory_allocator_t *alloc) {
#if SZ_USE_X86_AVX512
    return sz_alignment_score_avx512(a, a_length, b, b_length, subs, gap, alloc);
#else
    return sz_alignment_score_serial(a, a_length, b, b_length, subs, gap, alloc);
#endif
}

SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, //
                          sz_hash_callback_t callback, void *callback_handle) {
#if SZ_USE_X86_AVX512
    sz_hashes_avx512(text, length, window_length, window_step, callback, callback_handle);
#elif SZ_USE_X86_AVX2
    sz_hashes_avx2(text, length, window_length, window_step, callback, callback_handle);
#else
    sz_hashes_serial(text, length, window_length, window_step, callback, callback_handle);
#endif
}
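/*  The `*_from` / `*_not_from` helpers below mirror `strpbrk`-style lookups by packing the
 *  needle's bytes into an `sz_charset_t` bitset, so the search cost does not grow with the
 *  needle length. An illustrative sketch (the inputs are hypothetical):
 *
 *      sz_cptr_t text = "key=value;flag";
 *      sz_cptr_t separators = "=;,";
 *      sz_cptr_t first = sz_find_char_from(text, 14, separators, 3);
 *      // `first` points at '=', the first byte of `text` that also occurs in `separators`.
 */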
SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
    sz_charset_t set;
    sz_charset_init(&set);
    for (; n_length; ++n, --n_length) sz_charset_add(&set, *n);
    return sz_find_charset(h, h_length, &set);
}

SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
    sz_charset_t set;
    sz_charset_init(&set);
    for (; n_length; ++n, --n_length) sz_charset_add(&set, *n);
    sz_charset_invert(&set);
    return sz_find_charset(h, h_length, &set);
}

SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
    sz_charset_t set;
    sz_charset_init(&set);
    for (; n_length; ++n, --n_length) sz_charset_add(&set, *n);
    return sz_rfind_charset(h, h_length, &set);
}

SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
    sz_charset_t set;
    sz_charset_init(&set);
    for (; n_length; ++n, --n_length) sz_charset_add(&set, *n);
    sz_charset_invert(&set);
    return sz_rfind_charset(h, h_length, &set);
}

SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length,
                            sz_random_generator_t generator, void *generator_user_data) {
    sz_generate_serial(alphabet, alphabet_size, result, result_length, generator, generator_user_data);
}

#endif // !SZ_DYNAMIC_DISPATCH
#pragma endregion

#ifdef __cplusplus
#pragma GCC diagnostic pop
}
#endif // __cplusplus

#endif // STRINGZILLA_H_