Files
wg_cpso/CppLibrary/stringzilla.h
2026-03-25 18:20:24 +08:00

6867 lines
335 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/**
* @brief StringZilla is a collection of simple string algorithms, designed to be used in Big Data applications.
* It may be slower than LibC, but has a broader & cleaner interface, and a very short implementation
* targeting modern x86 CPUs with AVX-512 and Arm NEON and older CPUs with SWAR and auto-vectorization.
*
* Consider overriding the following macros to customize the library:
*
* - `SZ_DEBUG=0` - whether to enable debug assertions and logging.
* - `SZ_DYNAMIC_DISPATCH=0` - whether to use runtime dispatching of the most advanced SIMD backend.
* - `SZ_USE_MISALIGNED_LOADS=0` - whether to use misaligned loads on platforms that support them.
* - `SZ_SWAR_THRESHOLD=24` - threshold for switching to SWAR backend over serial byte-level for-loops.
* - `SZ_USE_X86_AVX512=?` - whether to use AVX-512 instructions on x86_64.
* - `SZ_USE_X86_AVX2=?` - whether to use AVX2 instructions on x86_64.
* - `SZ_USE_ARM_NEON=?` - whether to use NEON instructions on ARM.
* - `SZ_USE_ARM_SVE=?` - whether to use SVE instructions on ARM.
*
* @see StringZilla: https://github.com/ashvardanian/StringZilla/blob/main/README.md
* @see LibC String: https://pubs.opengroup.org/onlinepubs/009695399/basedefs/string.h.html
*
* @file stringzilla.h
* @author Ash Vardanian
*/
#ifndef STRINGZILLA_H_
#define STRINGZILLA_H_
#define STRINGZILLA_VERSION_MAJOR 3
#define STRINGZILLA_VERSION_MINOR 12
#define STRINGZILLA_VERSION_PATCH 6
/**
 *  @brief  When left at 0 (the default), the library includes the LibC headers <stddef.h> and <stdint.h>
 *          and derives its integer types from them. When set to 1, those headers are skipped
 *          and the integer types are declared manually from the built-in C types.
 *          In debug builds (SZ_DEBUG=1), the library will also include <stdio.h> and <stdlib.h>.
 *
 *  You may want to set this to 1 when compiling for use in the kernel, or in embedded systems.
 *  You may also avoid the headers, if you are very sensitive to compilation time and avoid pre-compiled headers.
 *  https://artificial-mind.net/projects/compile-health/
 */
#ifndef SZ_AVOID_LIBC
#define SZ_AVOID_LIBC (0) // true or false
#endif
/**
 *  @brief  Whether misaligned loads may be used. A misaligned load is, for example, fetching eight
 *          consecutive bytes from an address that is not divisible by eight.
 *          Unless overridden, it is enabled (1) on 32-bit and 64-bit x86 and disabled (0) everywhere else.
 *
 *  Most platforms support it, but there is no industry standard way to check for those.
 *  This value will mostly affect the performance of the serial (SWAR) backend.
 */
#ifndef SZ_USE_MISALIGNED_LOADS
#if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || defined(_M_IX86)
#define SZ_USE_MISALIGNED_LOADS (1) // true or false
#else
#define SZ_USE_MISALIGNED_LOADS (0) // true or false
#endif
#endif
/**
 *  @brief  Removes compile-time dispatching, and replaces it with runtime dispatching.
 *          So the `sz_find` function will invoke the most advanced backend supported by the CPU,
 *          that runs the program, rather than the most advanced backend supported by the CPU
 *          used to compile the library or the downstream application.
 *
 *  Defaults to 0 (compile-time dispatching) unless the user defines it before including this header.
 *  The value also selects how the `SZ_DYNAMIC` family of annotation macros is defined.
 */
#ifndef SZ_DYNAMIC_DISPATCH
#define SZ_DYNAMIC_DISPATCH (0) // true or false
#endif
/**
 *  @brief  Detects the pointer width of the target and defines the limits of pointer-sized integers,
 *          analogous to `size_t` and `std::size_t`:
 *          64-bit on most platforms where pointers are 64-bit,
 *          32-bit on platforms where pointers are 32-bit.
 *
 *  - `SZ_DETECT_64_BIT` is 1 if any of `__LP64__`, `_LP64`, `__x86_64__`, or `_WIN64` is defined, 0 otherwise.
 *  - `SZ_SIZE_MAX` is the largest value of the unsigned pointer-sized integer.
 *  - `SZ_SSIZE_MAX` is the largest value of the signed pointer-sized integer.
 *
 *  NOTE(review): the `SZ_SSIZE_MAX` literals carry an unsigned suffix (`ull` / `u`), so the constant itself
 *  has an unsigned type. Comparing it against a negative signed value converts that value to unsigned.
 */
#if defined(__LP64__) || defined(_LP64) || defined(__x86_64__) || defined(_WIN64)
#define SZ_DETECT_64_BIT (1)
#define SZ_SIZE_MAX (0xFFFFFFFFFFFFFFFFull) // Largest unsigned integer that fits into 64 bits.
#define SZ_SSIZE_MAX (0x7FFFFFFFFFFFFFFFull) // Largest signed integer that fits into 64 bits.
#else
#define SZ_DETECT_64_BIT (0)
#define SZ_SIZE_MAX (0xFFFFFFFFu) // Largest unsigned integer that fits into 32 bits.
#define SZ_SSIZE_MAX (0x7FFFFFFFu) // Largest signed integer that fits into 32 bits.
#endif
/**
 *  @brief  On Big-Endian machines StringZilla will work in compatibility mode.
 *          This disables SWAR hacks to minimize code duplication, assuming practically
 *          all modern popular platforms are Little-Endian.
 *
 *  Unless the user defines `SZ_DETECT_BIG_ENDIAN` explicitly, it is inferred from, in order:
 *  - `__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__`, which GCC and Clang predefine without any header;
 *  - `__BYTE_ORDER == __BIG_ENDIAN`, available only if an `<endian.h>`-like header was already included;
 *  - the target-specific macros `__BIG_ENDIAN__`, `__ARMEB__`, `__THUMBEB__`, `__AARCH64EB__`,
 *    and the MIPS family `_MIPSEB`, `__MIPSEB`, `__MIPSEB__`.
 *
 *  The misspelled `_MIBSEB`, `__MIBSEB`, `__MIBSEB__` names that this check historically used are still
 *  honored, so builds that define them by hand keep working.
 *
 *  This variable is hard to infer from macros reliably on other compilers. It's best to set it manually.
 *  For that CMake provides the `TestBigEndian` and `CMAKE_<LANG>_BYTE_ORDER` (from 3.20 onwards).
 *  In Python one can check `sys.byteorder == 'big'` in the `setup.py` script and pass the appropriate macro.
 *  https://stackoverflow.com/a/27054190
 */
#ifndef SZ_DETECT_BIG_ENDIAN
#if (defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__) || \
    (defined(__BYTE_ORDER) && defined(__BIG_ENDIAN) && __BYTE_ORDER == __BIG_ENDIAN) || \
    defined(__BIG_ENDIAN__) || defined(__ARMEB__) || defined(__THUMBEB__) || defined(__AARCH64EB__) || \
    defined(_MIPSEB) || defined(__MIPSEB) || defined(__MIPSEB__) || \
    defined(_MIBSEB) || defined(__MIBSEB) || defined(__MIBSEB__)
#define SZ_DETECT_BIG_ENDIAN (1) //< It's a big-endian target architecture
#else
#define SZ_DETECT_BIG_ENDIAN (0) //< It's a little-endian target architecture
#endif
#endif
/*
 *  Debugging and testing.
 *
 *  `SZ_DEBUG` defaults to 1 when the build defines `DEBUG` or `_DEBUG`, and to 0 otherwise.
 *  Define it explicitly before including this header to override the detection.
 */
#ifndef SZ_DEBUG
#if defined(DEBUG) || defined(_DEBUG) // Either macro being defined signals a debug build.
#define SZ_DEBUG (1)
#else
#define SZ_DEBUG (0)
#endif
#endif
/**
 *  @brief  Threshold for switching to SWAR (8-bytes at a time) backend over serial byte-level for-loops.
 *          On very short strings, under 16 bytes long, at most a single word will be processed with SWAR.
 *          Assuming potentially misaligned loads, SWAR makes sense only after ~24 bytes.
 *
 *  The default depends on `SZ_DEBUG`: 8 bytes in debug builds and 24 bytes in release builds.
 */
#ifndef SZ_SWAR_THRESHOLD
#if SZ_DEBUG
#define SZ_SWAR_THRESHOLD (8u) // 8 bytes in debug builds
#else
#define SZ_SWAR_THRESHOLD (24u) // 24 bytes in release builds
#endif
#endif
/*  Annotation for the public API symbols:
 *
 *  - `SZ_PUBLIC` is used for functions that are part of the public API.
 *  - `SZ_INTERNAL` is used for internal helper functions with unstable APIs.
 *  - `SZ_DYNAMIC` is used for functions that are part of the public API, but are dispatched at runtime.
 *  - `SZ_EXTERNAL` is the import-side counterpart: `__declspec(dllimport)` on Windows/Cygwin with
 *    dynamic dispatch, plain `extern` otherwise.
 *
 *  Without `SZ_DYNAMIC_DISPATCH`, `SZ_DYNAMIC` expands to `inline static`, exactly like `SZ_PUBLIC`.
 *
 *  NOTE: all four macros are defined inside the single `#ifndef SZ_DYNAMIC` guard. A user who pre-defines
 *  `SZ_DYNAMIC` must therefore also provide `SZ_EXTERNAL`, `SZ_PUBLIC`, and `SZ_INTERNAL`.
 */
#ifndef SZ_DYNAMIC
#if SZ_DYNAMIC_DISPATCH
#if defined(_WIN32) || defined(__CYGWIN__)
#define SZ_DYNAMIC __declspec(dllexport)
#define SZ_EXTERNAL __declspec(dllimport)
#define SZ_PUBLIC inline static
#define SZ_INTERNAL inline static
#else
#define SZ_DYNAMIC __attribute__((visibility("default")))
#define SZ_EXTERNAL extern
#define SZ_PUBLIC __attribute__((unused)) inline static
#define SZ_INTERNAL __attribute__((always_inline)) inline static
#endif // _WIN32 || __CYGWIN__
#else
#define SZ_DYNAMIC inline static
#define SZ_EXTERNAL extern
#define SZ_PUBLIC inline static
#define SZ_INTERNAL inline static
#endif // SZ_DYNAMIC_DISPATCH
#endif // SZ_DYNAMIC
/**
 *  @brief  Alignment macro for 64-byte alignment.
 *
 *  Expands to `__declspec(align(64))` when `_MSC_VER` is defined, to `__attribute__((aligned(64)))`
 *  on GCC and Clang, and to nothing on any other compiler, where no alignment is requested.
 *  The MSVC check comes first, so compilers that define both `_MSC_VER` and `__clang__` take the MSVC form.
 */
#if defined(_MSC_VER)
#define SZ_ALIGN64 __declspec(align(64))
#elif defined(__GNUC__) || defined(__clang__)
#define SZ_ALIGN64 __attribute__((aligned(64)))
#else
#define SZ_ALIGN64
#endif
#ifdef __cplusplus
extern "C" {
#endif
/*
* Let's infer the integer types or pull them from LibC,
* if that is allowed by the user.
*/
#if !SZ_AVOID_LIBC
#include <stddef.h> // `size_t`
#include <stdint.h> // `uint8_t`
typedef int8_t sz_i8_t; // Always 8 bits
typedef uint8_t sz_u8_t; // Always 8 bits
typedef uint16_t sz_u16_t; // Always 16 bits
typedef int32_t sz_i32_t; // Always 32 bits
typedef uint32_t sz_u32_t; // Always 32 bits
typedef uint64_t sz_u64_t; // Always 64 bits
typedef int64_t sz_i64_t; // Always 64 bits
typedef size_t sz_size_t; // Pointer-sized unsigned integer, 32 or 64 bits
typedef ptrdiff_t sz_ssize_t; // Signed version of `sz_size_t`, 32 or 64 bits
#else // if SZ_AVOID_LIBC:
// ! The C standard doesn't specify the signedness of char.
// ! On x86 char is signed by default while on Arm it is unsigned by default.
// ! That's why we don't define `sz_char_t` and generally use explicit `sz_i8_t` and `sz_u8_t`.
typedef signed char sz_i8_t; // Always 8 bits
typedef unsigned char sz_u8_t; // Always 8 bits
typedef unsigned short sz_u16_t; // Always 16 bits
typedef int sz_i32_t; // Always 32 bits
typedef unsigned int sz_u32_t; // Always 32 bits
typedef long long sz_i64_t; // Always 64 bits
typedef unsigned long long sz_u64_t; // Always 64 bits
// Now we need to define our own `size_t` and `ptrdiff_t` replacements, as LibC headers are unavailable here.
// Microsoft Visual C++ (MSVC) typically follows LLP64 data model on 64-bit platforms,
// where integers, pointers, and long types have different sizes:
//
// > `int` is 32 bits
// > `long` is 32 bits
// > `long long` is 64 bits
// > pointer (thus, `size_t`) is 64 bits
//
// In contrast, GCC and Clang on 64-bit Unix-like systems typically follow the LP64 model, where:
//
// > `int` is 32 bits
// > `long` and pointer (thus, `size_t`) are 64 bits
// > `long long` is also 64 bits
//
// `long long` is 64 bits wide in both models, so it is used for the 64-bit case instead of `long`.
// Source: https://learn.microsoft.com/en-us/windows/win32/winprog64/abstract-data-models
#if SZ_DETECT_64_BIT
typedef unsigned long long sz_size_t; // 64-bit, unsigned.
typedef long long sz_ssize_t;         // 64-bit, signed.
#else
typedef unsigned sz_size_t; // 32-bit, unsigned.
typedef int sz_ssize_t;     // 32-bit, signed: the counterpart of `ptrdiff_t`, so it must hold negative values.
#endif // SZ_DETECT_64_BIT
#endif // SZ_AVOID_LIBC
/**
 *  @brief  Compile-time assert macro similar to `static_assert` in C++.
 *
 *  Declares a struct type with a single bit-field whose width is 1 if ::condition holds and -1 otherwise.
 *  A negative bit-field width is ill-formed, so a false condition stops the compilation.
 *
 *  @param condition    Integer constant expression to verify.
 *  @param name         Identifier fragment pasted into the field and typedef names.
 *                      It must be unique per scope, as every use declares `sz_static_assert_<name>_t`.
 */
#define sz_static_assert(condition, name) \
typedef struct { \
int static_assert_##name : (condition) ? 1 : -1; \
} sz_static_assert_##name##_t
// Both pointer-sized integer types must really be as wide as a pointer on this target.
sz_static_assert(sizeof(sz_size_t) == sizeof(void *), sz_size_t_must_be_pointer_size);
sz_static_assert(sizeof(sz_ssize_t) == sizeof(void *), sz_ssize_t_must_be_pointer_size);
#pragma region Public API
typedef char *sz_ptr_t; // A type alias for `char *`
typedef char const *sz_cptr_t; // A type alias for `char const *`
typedef sz_i8_t sz_error_cost_t; // Character mismatch cost for fuzzy matching functions
typedef sz_u64_t sz_sorted_idx_t; // Index of a sorted string in a list of strings
typedef enum { sz_false_k = 0, sz_true_k = 1 } sz_bool_t; // Only one relevant bit
typedef enum { sz_less_k = -1, sz_equal_k = 0, sz_greater_k = 1 } sz_ordering_t; // Only three possible states: <=>
/**
 *  @brief  Tiny string-view structure. It's POD type, unlike the `std::string_view`.
 *          Holds nothing but a pointer to constant characters and a length.
 */
typedef struct sz_string_view_t {
sz_cptr_t start; // Address of the first character of the viewed range
sz_size_t length; // Length of the viewed range; presumably in bytes, like the `length` arguments of this API
} sz_string_view_t;
/**
 *  @brief  Enumeration of SIMD capabilities of the target architecture.
 *          Used to introspect the supported functionality of the dynamic library.
 *
 *  Apart from `sz_cap_any_k`, every enumerator is a distinct single bit, so the values can be
 *  combined into a bitmask: bit 0 is the serial backend, bits 10-11 are Arm extensions,
 *  and bits 20-26 are x86 extensions.
 */
typedef enum sz_capability_t {
sz_cap_serial_k = 1, /// Serial (non-SIMD) capability
sz_cap_any_k = 0x7FFFFFFF, /// Mask representing any capability: all 31 low bits are set
sz_cap_arm_neon_k = 1 << 10, /// ARM NEON capability
sz_cap_arm_sve_k = 1 << 11, /// ARM SVE capability TODO: Not yet supported or used
sz_cap_x86_avx2_k = 1 << 20, /// x86 AVX2 capability
sz_cap_x86_avx512f_k = 1 << 21, /// x86 AVX512 F capability
sz_cap_x86_avx512bw_k = 1 << 22, /// x86 AVX512 BW instruction capability
sz_cap_x86_avx512vl_k = 1 << 23, /// x86 AVX512 VL instruction capability
sz_cap_x86_avx512vbmi_k = 1 << 24, /// x86 AVX512 VBMI instruction capability
sz_cap_x86_gfni_k = 1 << 25, /// x86 AVX512 GFNI instruction capability
sz_cap_x86_avx512vbmi2_k = 1 << 26, /// x86 AVX512 VBMI 2 instruction capability
} sz_capability_t;
/**
* @brief Function to determine the SIMD capabilities of the current machine @b only at @b runtime.
* @return A bitmask of the SIMD capabilities represented as a `sz_capability_t` enum value.
*/
SZ_DYNAMIC sz_capability_t sz_capabilities(void);
/**
 *  @brief  Bit-set structure for 256 possible byte values. Useful for filtering and search.
 *          The union exposes the same 256 bits (32 bytes) as words of four different widths.
 *  @see    sz_charset_init, sz_charset_add, sz_charset_contains, sz_charset_invert
 */
typedef union sz_charset_t {
sz_u64_t _u64s[4]; // 4 x 64-bit words; the view used by the accessor functions in this header
sz_u32_t _u32s[8]; // 8 x 32-bit words
sz_u16_t _u16s[16]; // 16 x 16-bit words
sz_u8_t _u8s[32]; // 32 x 8-bit words
} sz_charset_t;
/** @brief Resets a bit-set to the empty collection, so that no character is accepted. */
SZ_PUBLIC void sz_charset_init(sz_charset_t *s) {
    // Clear all four 64-bit words, one statement per word.
    s->_u64s[3] = 0;
    s->_u64s[2] = 0;
    s->_u64s[1] = 0;
    s->_u64s[0] = 0;
}
/** @brief Inserts a character, passed as an @b unsigned integer, into the set. */
SZ_PUBLIC void sz_charset_add_u8(sz_charset_t *s, sz_u8_t c) {
    // The top two bits of the byte pick one of four words, the low six bits pick the bit inside it.
    sz_u64_t const bit_in_word = 1ull << (c & 63u);
    s->_u64s[c >> 6] |= bit_in_word;
}
/** @brief Inserts a `char` into the set, reinterpreting it as a byte. Consider @b sz_charset_add_u8. */
SZ_PUBLIC void sz_charset_add(sz_charset_t *s, char c) {
    // Read the very same byte through an unsigned pointer, so the signedness of `char` doesn't matter.
    sz_u8_t *as_byte = (sz_u8_t *)(&c);
    sz_charset_add_u8(s, *as_byte);
}
/** @brief Tells whether a character, passed as an @b unsigned integer, is a member of the set. */
SZ_PUBLIC sz_bool_t sz_charset_contains_u8(sz_charset_t const *s, sz_u8_t c) {
    // The membership bit could be tested through any of the word widths of the union:
    // - (s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0
    // - (s->_u32s[c >> 5] & (1u << (c & 31u))) != 0
    // - (s->_u16s[c >> 4] & (1u << (c & 15u))) != 0
    // - (s->_u8s[c >> 3] & (1u << (c & 7u))) != 0
    // The 64-bit form is the one used here.
    sz_u64_t const word = s->_u64s[c >> 6];
    sz_u64_t const mask = 1ull << (c & 63u);
    if ((word & mask) != 0) return sz_true_k;
    return sz_false_k;
}
/** @brief Tells whether a `char` is a member of the set. Consider @b sz_charset_contains_u8. */
SZ_PUBLIC sz_bool_t sz_charset_contains(sz_charset_t const *s, char c) {
    // Read the very same byte through an unsigned pointer, so the signedness of `char` doesn't matter.
    sz_u8_t *as_byte = (sz_u8_t *)(&c);
    return sz_charset_contains_u8(s, *as_byte);
}
/** @brief Flips every bit of the set: accepted characters become rejected, and vice versa. */
SZ_PUBLIC void sz_charset_invert(sz_charset_t *s) {
    int word_idx;
    // Complementing a 64-bit word is the same as XOR-ing it with all ones.
    for (word_idx = 0; word_idx != 4; ++word_idx) s->_u64s[word_idx] = ~s->_u64s[word_idx];
}
/** @brief Allocation callback: receives a size and an opaque pointer, returns the allocated memory.
 *  NOTE(review): the size is presumably in bytes and the pointer presumably the allocator's `handle` - confirm. */
typedef void *(*sz_memory_allocate_t)(sz_size_t, void *);
/** @brief Deallocation callback: receives a memory pointer, a size, and an opaque pointer.
 *  NOTE(review): presumably mirrors the arguments of the allocation callback - confirm against the callers. */
typedef void (*sz_memory_free_t)(void *, sz_size_t, void *);
/** @brief Callback producing a 64-bit random number from an opaque generator state. @see sz_generate */
typedef sz_u64_t (*sz_random_generator_t)(void *);
/**
 *  @brief  Some complex pattern matching algorithms may require memory allocations.
 *          This structure is used to pass the memory allocator to those functions.
 *  @see    sz_memory_allocator_init_default, sz_memory_allocator_init_fixed
 */
typedef struct sz_memory_allocator_t {
sz_memory_allocate_t allocate; // Callback used to obtain memory
sz_memory_free_t free; // Callback used to release memory
void *handle; // Opaque allocator state; presumably forwarded to both callbacks - TODO confirm
} sz_memory_allocator_t;
/**
* @brief Initializes a memory allocator to use the system default `malloc` and `free`.
* ! The function is not available if the library was compiled with `SZ_AVOID_LIBC`.
*
* @param alloc Memory allocator to initialize.
*/
SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc);
/**
* @brief Initializes a memory allocator to use a static-capacity buffer.
* No dynamic allocations will be performed.
*
* @param alloc Memory allocator to initialize.
* @param buffer Buffer to use for allocations.
* @param length Length of the buffer. @b Must be greater than 8 bytes. Different values would be optimal for
* different algorithms and input lengths, but 4096 bytes (one RAM page) is a good default.
*/
SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length);
/**
 *  @brief  The number of bytes a stack-allocated string can hold, including the SZ_NULL termination character.
 *          Evaluates to 23 when `sz_size_t` is 8 bytes wide, and to 11 when it is 4 bytes wide.
 *          ! This can't be changed from outside. Don't use the `#error` as it may already be included and set.
 *
 *  Any earlier definition is silently dropped with `#undef`, so the value always follows `sz_size_t`.
 */
#ifdef SZ_STRING_INTERNAL_SPACE
#undef SZ_STRING_INTERNAL_SPACE
#endif
#define SZ_STRING_INTERNAL_SPACE (sizeof(sz_size_t) * 3 - 1) // 3 pointers minus one byte for an 8-bit length
/**
 *  @brief  Tiny memory-owning string structure with a Small String Optimization (SSO).
 *          Differs in layout from Folly, Clang, GCC, and probably most other implementations.
 *          It's designed to avoid any branches on read-only operations, and can store up
 *          to 22 characters on stack on 64-bit machines, followed by the SZ_NULL-termination character.
 *
 *  @section Layout
 *
 *  The union overlays three views of the same four pointer-sized words: the `external` view,
 *  the `internal` view, and the raw `words`. In both byte orders the 8-bit `internal.length` is placed
 *  on the least significant byte of `external.length`: right after `start` on little-endian targets,
 *  and in the very last byte of the structure on big-endian targets.
 *
 *  @section Changing Length
 *
 *  One nice thing about this design, is that you can, in many cases, change the length of the string
 *  without any branches, invoking a `+=` or `-=` on the 64-bit `length` field. If the string is on heap,
 *  the solution is obvious. If it's on stack, inplace decrement wouldn't affect the top bytes of the string,
 *  only changing the last byte containing the length.
 */
typedef union sz_string_t {
#if !SZ_DETECT_BIG_ENDIAN
// Little-endian layout: the length follows the `start` pointer in both views.
struct external {
sz_ptr_t start; // First character of the string
sz_size_t length; // Full-width length of the string
sz_size_t space; // Allocated capacity; see the `space` argument of `sz_string_unpack`
sz_size_t padding; // Fourth word of the view; no use of it is visible in this part of the file
} external;
struct internal {
sz_ptr_t start; // Presumably points at `chars` while the string lives in-place - TODO confirm
sz_u8_t length; // 8-bit length, sharing its address with the lowest byte of `external.length`
char chars[SZ_STRING_INTERNAL_SPACE]; // In-place character storage
} internal;
#else
// Big-endian layout: the length is moved to the end, keeping the 8-bit length on the lowest byte.
struct external {
sz_ptr_t start; // First character of the string
sz_size_t space; // Allocated capacity; see the `space` argument of `sz_string_unpack`
sz_size_t padding; // Third word of the view; no use of it is visible in this part of the file
sz_size_t length; // Full-width length of the string
} external;
struct internal {
sz_ptr_t start; // Presumably points at `chars` while the string lives in-place - TODO confirm
char chars[SZ_STRING_INTERNAL_SPACE]; // In-place character storage
sz_u8_t length; // 8-bit length, occupying the last byte of the structure
} internal;
#endif
sz_size_t words[4]; // Word-wise view of the whole structure
} sz_string_t;
typedef sz_u64_t (*sz_hash_t)(sz_cptr_t, sz_size_t);
typedef sz_u64_t (*sz_checksum_t)(sz_cptr_t, sz_size_t);
typedef sz_bool_t (*sz_equal_t)(sz_cptr_t, sz_cptr_t, sz_size_t);
typedef sz_ordering_t (*sz_order_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t);
typedef void (*sz_to_converter_t)(sz_cptr_t, sz_size_t, sz_ptr_t);
/**
* @brief Computes the 64-bit check-sum of bytes in a string.
* Similar to `std::ranges::accumulate`.
*
* @param text String to aggregate.
* @param length Number of bytes in the text.
* @return 64-bit unsigned value.
*/
SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length);
/** @copydoc sz_checksum */
SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length);
/**
* @brief Computes the 64-bit unsigned hash of a string. Fairly fast for short strings,
* simple implementation, and supports rolling computation, reused in other APIs.
* Similar to `std::hash` in C++.
*
* @param text String to hash.
* @param length Number of bytes in the text.
* @return 64-bit hash value.
*
* @see sz_hashes, sz_hashes_fingerprint, sz_hashes_intersection
*/
SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length);
/** @copydoc sz_hash */
SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t text, sz_size_t length);
/**
* @brief Checks if two strings are equal.
* Similar to `memcmp(a, b, length) == 0` in LibC and `a == b` in STL.
*
* The implementation of this function is very similar to `sz_order`, but the usage patterns are different.
* This function is more often used in parsing, while `sz_order` is often used in sorting.
* It presumably works best on platforms with cheap misaligned loads (see `SZ_USE_MISALIGNED_LOADS`).
*
* @param a First string to compare.
* @param b Second string to compare.
* @param length Number of bytes in both strings.
* @return 1 if strings match, 0 otherwise.
*/
SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length);
/** @copydoc sz_equal */
SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length);
/**
* @brief Estimates the relative order of two strings. Equivalent to `memcmp(a, b, length)` in LibC.
* Can be used on different length strings.
*
* @param a First string to compare.
* @param a_length Number of bytes in the first string.
* @param b Second string to compare.
* @param b_length Number of bytes in the second string.
* @return Negative if (a < b), positive if (a > b), zero if they are equal.
*/
SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length);
/** @copydoc sz_order */
SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length);
/**
* @brief Look Up Table @b (LUT) transformation of a string. Equivalent to `for (char & c : text) c = lut[c]`.
*
* Can be used to implement some form of string normalization, partially masking punctuation marks,
* or converting between different character sets, like uppercase or lowercase. Surprisingly, also has
* broad implications in image processing, where image channel transformations are often done using LUTs.
*
* @param text String to be normalized.
* @param length Number of bytes in the string.
* @param lut Look Up Table to apply. Must be exactly @b 256 bytes long.
* @param result Output string, can point to the same address as ::text.
*/
SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result);
typedef void (*sz_look_up_transform_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_ptr_t);
/** @copydoc sz_look_up_transform */
SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result);
/**
* @brief Equivalent to `for (char & c : text) c = tolower(c)`.
*
* ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122].
* So there are 26 english letters, shifted by 32 values, meaning that a conversion
* can be done by flipping the 5th bit each inappropriate character byte. This, however,
* breaks for extended ASCII, so a different solution is needed.
* http://0x80.pl/notesen/2016-01-06-swar-swap-case.html
*
* @param text String to be normalized.
* @param length Number of bytes in the string.
* @param result Output string, can point to the same address as ::text.
*/
SZ_PUBLIC void sz_tolower(sz_cptr_t text, sz_size_t length, sz_ptr_t result);
/**
* @brief Equivalent to `for (char & c : text) c = toupper(c)`.
*
* ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122].
* So there are 26 english letters, shifted by 32 values, meaning that a conversion
* can be done by flipping the 5th bit each inappropriate character byte. This, however,
* breaks for extended ASCII, so a different solution is needed.
* http://0x80.pl/notesen/2016-01-06-swar-swap-case.html
*
* @param text String to be normalized.
* @param length Number of bytes in the string.
* @param result Output string, can point to the same address as ::text.
*/
SZ_PUBLIC void sz_toupper(sz_cptr_t text, sz_size_t length, sz_ptr_t result);
/**
* @brief Equivalent to `for (char & c : text) c = toascii(c)`.
*
* @param text String to be normalized.
* @param length Number of bytes in the string.
* @param result Output string, can point to the same address as ::text.
*/
SZ_PUBLIC void sz_toascii(sz_cptr_t text, sz_size_t length, sz_ptr_t result);
/**
* @brief Checks if all characters in the range are valid ASCII characters.
*
* @param text String to be analyzed.
* @param length Number of bytes in the string.
* @return Whether all characters are valid ASCII characters.
*/
SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t text, sz_size_t length);
/**
* @brief Generates a random string for a given alphabet, avoiding integer division and modulo operations.
* Similar to `text[i] = alphabet[rand() % cardinality]`.
*
* The modulo operation is expensive, and should be avoided in performance-critical code.
* We avoid it using small lookup tables and replacing it with a multiplication and shifts, similar to `libdivide`.
* Alternative algorithms would include:
* - Montgomery form: https://en.algorithmica.org/hpc/number-theory/montgomery/
* - Barret reduction: https://www.nayuki.io/page/barrett-reduction-algorithm
* - Lemire's trick: https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
*
* @param alphabet Set of characters to sample from.
* @param cardinality Number of characters to sample from.
* @param text Output string to be populated with characters sampled from the ::alphabet.
* @param length Number of bytes to generate in the ::text.
* @param generate Callback producing random numbers given the generator state.
* @param generator Generator state, can be a pointer to a seed, or a pointer to a random number generator.
*/
SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length,
sz_random_generator_t generate, void *generator);
/** @copydoc sz_generate */
SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length,
sz_random_generator_t generate, void *generator);
/**
* @brief Similar to `memcpy`, copies contents of one string into another.
* The behavior is undefined if the strings overlap.
*
* @param target String to copy into.
* @param length Number of bytes to copy.
* @param source String to copy from.
*/
SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
/** @copydoc sz_copy */
SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
/**
* @brief Similar to `memmove`, copies (moves) contents of one string into another.
* Unlike `sz_copy`, allows overlapping strings as arguments.
*
* @param target String to copy into.
* @param length Number of bytes to copy.
* @param source String to copy from.
*/
SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
/** @copydoc sz_move */
SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
typedef void (*sz_move_t)(sz_ptr_t, sz_cptr_t, sz_size_t);
/**
* @brief Similar to `memset`, fills a string with a given value.
*
* @param target String to fill.
* @param length Number of bytes to fill.
* @param value Value to fill with.
*/
SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value);
/** @copydoc sz_fill */
SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value);
typedef void (*sz_fill_t)(sz_ptr_t, sz_size_t, sz_u8_t);
/**
* @brief Initializes a string class instance to an empty value.
*/
SZ_PUBLIC void sz_string_init(sz_string_t *string);
/**
* @brief Convenience function checking if the provided string is stored inside of the ::string instance itself,
* alternative being - allocated in a remote region of the heap.
*/
SZ_PUBLIC sz_bool_t sz_string_is_on_stack(sz_string_t const *string);
/**
* @brief Unpacks the opaque instance of a string class into its components.
* Recommended to use only in read-only operations.
*
* @param string String to unpack.
* @param start Pointer to the start of the string.
* @param length Number of bytes in the string, before the SZ_NULL character.
* @param space Number of bytes allocated for the string (heap or stack), including the SZ_NULL character.
* @param is_external Whether the string is allocated on the heap externally, or fits within the ::string instance.
*/
SZ_PUBLIC void sz_string_unpack(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space,
sz_bool_t *is_external);
/**
* @brief Unpacks only the start and length of the string.
* Recommended to use only in read-only operations.
*
* @param string String to unpack.
* @param start Pointer to the start of the string.
* @param length Number of bytes in the string, before the SZ_NULL character.
*/
SZ_PUBLIC void sz_string_range(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length);
/**
* @brief Constructs a string of a given ::length with noisy contents.
* Use the returned character pointer to populate the string.
*
* @param string String to initialize.
* @param length Number of bytes in the string, before the SZ_NULL character.
* @param allocator Memory allocator to use for the allocation.
* @return SZ_NULL if the operation failed, pointer to the start of the string otherwise.
*/
SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, sz_memory_allocator_t *allocator);
/**
* @brief Doesn't change the contents or the length of the string, but grows the available memory capacity.
* This is beneficial, if several insertions are expected, and we want to minimize allocations.
*
* @param string String to grow.
* @param new_capacity The number of characters to reserve space for, including existing ones.
* @param allocator Memory allocator to use for the allocation.
* @return SZ_NULL if the operation failed, pointer to the new start of the string otherwise.
*/
SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity, sz_memory_allocator_t *allocator);
/**
* @brief Grows the string by adding an uninitialized region of ::added_length at the given ::offset.
* Would often be used in conjunction with one or more `sz_copy` calls to populate the allocated region.
* Similar to `sz_string_reserve`, but changes the length of the ::string.
*
* @param string String to grow.
* @param offset Offset of the first byte to reserve space for.
* If provided offset is larger than the length, it will be capped.
* @param added_length The number of new characters to reserve space for.
* @param allocator Memory allocator to use for the allocation.
* @return SZ_NULL if the operation failed, pointer to the new start of the string otherwise.
*/
SZ_PUBLIC sz_ptr_t sz_string_expand(sz_string_t *string, sz_size_t offset, sz_size_t added_length,
sz_memory_allocator_t *allocator);
/**
* @brief Removes a range from a string. Changes the length, but not the capacity.
* Performs no allocations or deallocations and can't fail.
*
* @param string String to clean.
* @param offset Offset of the first byte to remove.
* @param length Number of bytes to remove. Out-of-bound ranges will be capped.
* @return Number of bytes removed.
*/
SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length);
/**
* @brief Shrinks the string to fit the current length, if it's allocated on the heap.
* It's the reverse operation of ::sz_string_reserve.
*
* @param string String to shrink.
* @param allocator Memory allocator to use for the allocation.
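* @return A pointer rather than a boolean: presumably SZ_NULL if the operation failed, and the pointer to the
* start of the string otherwise, like ::sz_string_reserve (TODO: confirm against the implementation).
* The only failures can come from the allocator. On failure, the string will remain unchanged.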
*/
SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_allocator_t *allocator);
/**
* @brief Frees the string, if it's allocated on the heap.
* If the string is on the stack, the function clears/resets the state.
*/
SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *allocator);
#pragma endregion
#pragma region Fast Substring Search API
typedef sz_cptr_t (*sz_find_byte_t)(sz_cptr_t, sz_size_t, sz_cptr_t);
typedef sz_cptr_t (*sz_find_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t);
typedef sz_cptr_t (*sz_find_set_t)(sz_cptr_t, sz_size_t, sz_charset_t const *);
/**
* @brief Locates first matching byte in a string. Equivalent to `memchr(haystack, *needle, h_length)` in LibC.
*
* X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memchr.S
* Aarch64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/aarch64/memchr.S
*
* @param haystack Haystack - the string to search in.
* @param h_length Number of bytes in the haystack.
* @param needle Needle - single-byte substring to find.
* @return Address of the first match.
*/
SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
/** @copydoc sz_find_byte */
SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
/**
* @brief Locates last matching byte in a string. Equivalent to `memrchr(haystack, *needle, h_length)` in LibC.
*
* X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memrchr.S
* Aarch64 implementation: missing
*
* @param haystack Haystack - the string to search in.
* @param h_length Number of bytes in the haystack.
* @param needle Needle - single-byte substring to find.
* @return Address of the last match.
*/
SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
/** @copydoc sz_rfind_byte */
SZ_PUBLIC sz_cptr_t sz_rfind_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
/**
* @brief Locates first matching substring.
* Equivalent to `memmem(haystack, h_length, needle, n_length)` in LibC.
* Similar to `strstr(haystack, needle)` in LibC, but requires known length.
*
* @param haystack Haystack - the string to search in.
* @param h_length Number of bytes in the haystack.
* @param needle Needle - substring to find.
* @param n_length Number of bytes in the needle.
* @return Address of the first match.
*/
SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length);
/** @copydoc sz_find */
SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length);
/**
* @brief Locates the last matching substring.
*
* @param haystack Haystack - the string to search in.
* @param h_length Number of bytes in the haystack.
* @param needle Needle - substring to find.
* @param n_length Number of bytes in the needle.
* @return Address of the last match.
*/
SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length);
/** @copydoc sz_rfind */
SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length);
/**
* @brief Finds the first character present from the ::set, present in ::text.
* Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC.
* May have identical implementation and performance to ::sz_rfind_charset.
*
* Useful for parsing, when we want to skip a set of characters. Examples:
* * 6 whitespaces: " \t\n\r\v\f".
* * 16 digits forming a float number: "0123456789,.eE+-".
* * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing.
* * 2 JSON string special characters useful to locate the end of the string: "\"\\".
*
* @param text String to be scanned.
* @param length Number of bytes in the ::text.
* @param set Set of relevant characters.
* @return Pointer to the first matching character from ::set.
*/
SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set);
/** @copydoc sz_find_charset */
SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set);
/**
* @brief Finds the last character present from the ::set, present in ::text.
* Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC.
* May have identical implementation and performance to ::sz_find_charset.
*
* Useful for parsing, when we want to skip a set of characters. Examples:
* * 6 whitespaces: " \t\n\r\v\f".
* * 16 digits forming a float number: "0123456789,.eE+-".
* * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing.
* * 2 JSON string special characters useful to locate the end of the string: "\"\\".
*
* @param text String to be scanned.
* @param length Number of bytes in the ::text.
* @param set Set of relevant characters.
* @return Pointer to the last matching character from ::set.
*/
SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set);
/** @copydoc sz_rfind_charset */
SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set);
#pragma endregion
#pragma region String Similarity Measures API
/**
* @brief Computes the Hamming distance between two strings - number of not matching characters.
* Difference in length is counted as a mismatch.
*
* @param a First string to compare.
* @param a_length Number of bytes in the first string.
* @param b Second string to compare.
* @param b_length Number of bytes in the second string.
*
* @param bound Upper bound on the distance, that allows us to exit early.
* If zero is passed, the maximum possible distance will be equal to the length of the longer input.
* @return Unsigned integer for the distance, or the `bound` if it was exceeded.
*
* @see sz_hamming_distance_utf8
* @see https://en.wikipedia.org/wiki/Hamming_distance
*/
SZ_DYNAMIC sz_size_t sz_hamming_distance(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length,
sz_size_t bound);
/** @copydoc sz_hamming_distance */
SZ_PUBLIC sz_size_t sz_hamming_distance_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length,
sz_size_t bound);
/**
* @brief Computes the Hamming distance between two @b UTF8 strings - number of not matching characters.
* Difference in length is counted as a mismatch.
*
* @param a First string to compare.
* @param a_length Number of bytes in the first string.
* @param b Second string to compare.
* @param b_length Number of bytes in the second string.
*
* @param bound Upper bound on the distance, that allows us to exit early.
* If zero is passed, the maximum possible distance will be equal to the length of the longer input.
* @return Unsigned integer for the distance, or the `bound` if it was exceeded.
*
* @see sz_hamming_distance
* @see https://en.wikipedia.org/wiki/Hamming_distance
*/
SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length,
sz_size_t bound);
/** @copydoc sz_hamming_distance_utf8 */
SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length,
sz_size_t bound);
typedef sz_size_t (*sz_hamming_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t);
/**
* @brief Computes the Levenshtein edit-distance between two strings using the Wagner-Fisher algorithm.
* Similar to the Needleman-Wunsch alignment algorithm. Often used in fuzzy string matching.
*
* @param a First string to compare.
* @param a_length Number of bytes in the first string.
* @param b Second string to compare.
* @param b_length Number of bytes in the second string.
*
* @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated,
* so the memory usage is linear in relation to ::a_length and ::b_length.
* If SZ_NULL is passed, will initialize to the systems default `malloc`.
* @param bound Upper bound on the distance, that allows us to exit early.
* If zero is passed, the maximum possible distance will be equal to the length of the longer input.
* @return Unsigned integer for edit distance, the `bound` if it was exceeded, or `SZ_SIZE_MAX`
* if the memory allocation failed.
*
* @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default
* @see https://en.wikipedia.org/wiki/Levenshtein_distance
*/
SZ_DYNAMIC sz_size_t sz_edit_distance(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
sz_size_t bound, sz_memory_allocator_t *alloc);
/** @copydoc sz_edit_distance */
SZ_PUBLIC sz_size_t sz_edit_distance_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
sz_size_t bound, sz_memory_allocator_t *alloc);
/**
* @brief Computes the Levenshtein edit-distance between two @b UTF8 strings.
* Unlike `sz_edit_distance`, reports the distance in Unicode codepoints, and not in bytes.
*
* @param a First string to compare.
* @param a_length Number of bytes in the first string.
* @param b Second string to compare.
* @param b_length Number of bytes in the second string.
*
* @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated,
* so the memory usage is linear in relation to ::a_length and ::b_length.
* If SZ_NULL is passed, will initialize to the systems default `malloc`.
* @param bound Upper bound on the distance, that allows us to exit early.
* If zero is passed, the maximum possible distance will be equal to the length of the longer input.
* @return Unsigned integer for edit distance, the `bound` if it was exceeded, or `SZ_SIZE_MAX`
* if the memory allocation failed.
*
* @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default, sz_edit_distance
* @see https://en.wikipedia.org/wiki/Levenshtein_distance
*/
SZ_DYNAMIC sz_size_t sz_edit_distance_utf8(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
sz_size_t bound, sz_memory_allocator_t *alloc);
typedef sz_size_t (*sz_edit_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t, sz_memory_allocator_t *);
/** @copydoc sz_edit_distance_utf8 */
SZ_PUBLIC sz_size_t sz_edit_distance_utf8_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
sz_size_t bound, sz_memory_allocator_t *alloc);
/**
* @brief Computes the Needleman-Wunsch alignment score for two strings. Often used in bioinformatics and cheminformatics.
* Similar to the Levenshtein edit-distance, parameterized for gap and substitution penalties.
*
* Not commutative in the general case, as the order of the strings matters, as `sz_alignment_score(a, b)` may
* not be equal to `sz_alignment_score(b, a)`. Becomes @b commutative, if the substitution costs are symmetric.
* Equivalent to the negative Levenshtein distance, if: `gap == -1` and `subs[i][j] == (i == j ? 0: -1)`.
*
* @param a First string to compare.
* @param a_length Number of bytes in the first string.
* @param b Second string to compare.
* @param b_length Number of bytes in the second string.
* @param gap Penalty cost for gaps - insertions and removals.
* @param subs Substitution costs matrix with 256 x 256 values for all pairs of characters.
*
* @param alloc Temporary memory allocator. Only some of the rows of the matrix will be allocated,
* so the memory usage is linear in relation to ::a_length and ::b_length.
* If SZ_NULL is passed, will initialize to the systems default `malloc`.
* @return Signed similarity score. Can be negative, depending on the substitution costs.
* If the memory allocation fails, the function returns `SZ_SSIZE_MAX`.
*
* @see sz_memory_allocator_init_fixed, sz_memory_allocator_init_default
* @see https://en.wikipedia.org/wiki/Needleman%E2%80%93Wunsch_algorithm
*/
SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
sz_error_cost_t const *subs, sz_error_cost_t gap, //
sz_memory_allocator_t *alloc);
/** @copydoc sz_alignment_score */
SZ_PUBLIC sz_ssize_t sz_alignment_score_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
sz_error_cost_t const *subs, sz_error_cost_t gap, //
sz_memory_allocator_t *alloc);
typedef sz_ssize_t (*sz_alignment_score_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_error_cost_t const *,
sz_error_cost_t, sz_memory_allocator_t *);
typedef void (*sz_hash_callback_t)(sz_cptr_t, sz_size_t, sz_u64_t, void *user);
/**
* @brief Computes the Karp-Rabin rolling hashes of a string supplying them to the provided `callback`.
* Can be used for similarity scores, search, ranking, etc.
*
* Rabin-Karp-like rolling hashes can have very high-level of collisions and depend
* on the choice of bases and the prime number. That's why, often two hashes from the same
* family are used with different bases.
*
* 1. Kernighan and Ritchie's function uses 31, a prime close to the size of English alphabet.
* 2. To be friendlier to byte-arrays and UTF8, we use 257 for the second function.
*
* Choosing the right ::window_length is task- and domain-dependent. For example, most English words are
* between 3 and 7 characters long, so a window of 4 bytes would be a good choice. For DNA sequences,
* the ::window_length might be a multiple of 3, as the codons are 3 (nucleotides) bytes long.
* With such minimalistic alphabets of just four characters (AGCT) longer windows might be needed.
* For protein sequences the alphabet is 20 characters long, so the window can be shorter, than for DNAs.
*
* @param text String to hash.
* @param length Number of bytes in the string.
* @param window_length Length of the rolling window in bytes.
* @param window_step Step of reported hashes. @b Must be power of two. Should be smaller than `window_length`.
* @param callback Function receiving the start & length of a substring, the hash, and the `callback_handle`.
* @param callback_handle Optional user-provided pointer to be passed to the `callback`.
* @see sz_hashes_fingerprint, sz_hashes_intersection
*/
SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, //
sz_hash_callback_t callback, void *callback_handle);
/** @copydoc sz_hashes */
SZ_PUBLIC void sz_hashes_serial(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, //
sz_hash_callback_t callback, void *callback_handle);
typedef void (*sz_hashes_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_size_t, sz_hash_callback_t, void *);
/**
* @brief Computes the Karp-Rabin rolling hashes of a string outputting a binary fingerprint.
* Such fingerprints can be compared with Hamming or Jaccard (Tanimoto) distance for similarity.
*
* The algorithm doesn't clear the fingerprint buffer on start, so it can be invoked multiple times
* to produce a fingerprint of a longer string, by passing the previous fingerprint as the ::fingerprint.
* It can also be reused to produce multi-resolution fingerprints by changing the ::window_length
* and calling the same function multiple times for the same input ::text.
*
* Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer,
* avoiding cache-coherency penalties of remote on-heap buffers.
*
* @param text String to hash.
* @param length Number of bytes in the string.
* @param fingerprint Output fingerprint buffer.
* @param fingerprint_bytes Number of bytes in the fingerprint buffer.
* @param window_length Length of the rolling window in bytes.
* @see sz_hashes, sz_hashes_intersection
*/
SZ_PUBLIC void sz_hashes_fingerprint(sz_cptr_t text, sz_size_t length, sz_size_t window_length, //
sz_ptr_t fingerprint, sz_size_t fingerprint_bytes);
typedef void (*sz_hashes_fingerprint_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_ptr_t, sz_size_t);
/**
* @brief Given a hash-fingerprint of a textual document, computes the number of intersecting hashes
* of the incoming document. Can be used for document scoring and search.
*
* Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer,
* avoiding cache-coherency penalties of remote on-heap buffers.
*
* @param text Input document.
* @param length Number of bytes in the input document.
* @param fingerprint Reference document fingerprint.
* @param fingerprint_bytes Number of bytes in the reference documents fingerprint.
* @param window_length Length of the rolling window in bytes.
* @see sz_hashes, sz_hashes_fingerprint
*/
SZ_PUBLIC sz_size_t sz_hashes_intersection(sz_cptr_t text, sz_size_t length, sz_size_t window_length, //
sz_cptr_t fingerprint, sz_size_t fingerprint_bytes);
typedef sz_size_t (*sz_hashes_intersection_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_cptr_t, sz_size_t);
#pragma endregion
#pragma region Convenience API
/**
* @brief Finds the first character in the haystack, that is present in the needle.
* Convenience function, reused across different language bindings.
* @see sz_find_charset
*/
SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length);
/**
* @brief Finds the first character in the haystack, that is @b not present in the needle.
* Convenience function, reused across different language bindings.
* @see sz_find_charset
*/
SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length);
/**
* @brief Finds the last character in the haystack, that is present in the needle.
* Convenience function, reused across different language bindings.
* @see sz_find_charset
*/
SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length);
/**
* @brief Finds the last character in the haystack, that is @b not present in the needle.
* Convenience function, reused across different language bindings.
* @see sz_find_charset
*/
SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length);
#pragma endregion
#pragma region String Sequences API
struct sz_sequence_t;
typedef sz_cptr_t (*sz_sequence_member_start_t)(struct sz_sequence_t const *, sz_size_t);
typedef sz_size_t (*sz_sequence_member_length_t)(struct sz_sequence_t const *, sz_size_t);
typedef sz_bool_t (*sz_sequence_predicate_t)(struct sz_sequence_t const *, sz_size_t);
typedef sz_bool_t (*sz_sequence_comparator_t)(struct sz_sequence_t const *, sz_size_t, sz_size_t);
typedef sz_bool_t (*sz_string_is_less_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t);
/**
 * @brief Describes a collection of strings that is accessed through callbacks instead of a fixed memory layout.
 *        The same structure is what the partitioning, merging, and sorting routines of this header receive.
 */
typedef struct sz_sequence_t {
sz_sorted_idx_t *order; // NOTE(review): presumably the permutation of member indices rearranged by sorting - confirm.
sz_size_t count; // NOTE(review): presumably the number of members, i.e. entries in `order` - confirm.
sz_sequence_member_start_t get_start; // Callback returning a `sz_cptr_t` for a (sequence, `sz_size_t`) pair.
sz_sequence_member_length_t get_length; // Callback returning a `sz_size_t` for a (sequence, `sz_size_t`) pair.
void const *handle; // Opaque read-only pointer; presumably the user's container consumed by the callbacks.
} sz_sequence_t;
/**
* @brief Initiates the sequence structure from a tape layout, used by Apache Arrow.
* Expects ::offsets to contain `count + 1` entries, the last pointing at the end
* of the last string, indicating the total length of the ::tape.
*/
SZ_PUBLIC void sz_sequence_from_u32tape(sz_cptr_t *start, sz_u32_t const *offsets, sz_size_t count,
sz_sequence_t *sequence);
/**
* @brief Initiates the sequence structure from a tape layout, used by Apache Arrow.
* Expects ::offsets to contain `count + 1` entries, the last pointing at the end
* of the last string, indicating the total length of the ::tape.
*/
SZ_PUBLIC void sz_sequence_from_u64tape(sz_cptr_t *start, sz_u64_t const *offsets, sz_size_t count,
sz_sequence_t *sequence);
/**
* @brief Similar to `std::partition`, given a predicate splits the sequence into two parts.
* The algorithm is unstable, meaning that elements may change relative order, as long
* as they are in the right partition. This is the simpler algorithm for partitioning.
*/
SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate);
/**
* @brief Inplace `std::set_union` for two consecutive chunks forming the same continuous `sequence`.
*
* @param partition The number of elements in the first sub-sequence in `sequence`.
* @param less Comparison function, to determine the lexicographic ordering.
*/
SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less);
/**
* @brief Sorting algorithm, combining Radix Sort for the first 32 bits of every word
* and a follow-up by a more conventional sorting procedure on equally prefixed parts.
*/
SZ_PUBLIC void sz_sort(sz_sequence_t *sequence);
/**
* @brief Partial sorting algorithm, combining Radix Sort for the first 32 bits of every word
* and a follow-up by a more conventional sorting procedure on equally prefixed parts.
*/
SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t n);
/**
* @brief Intro-Sort algorithm that supports custom comparators.
*/
SZ_PUBLIC void sz_sort_intro(sz_sequence_t *sequence, sz_sequence_comparator_t less);
#pragma endregion
/*
* Hardware feature detection.
* All of those can be controlled by the user.
*/
#ifndef SZ_USE_X86_AVX512
#ifdef __AVX512BW__
#define SZ_USE_X86_AVX512 1
#else
#define SZ_USE_X86_AVX512 0
#endif
#endif
#ifndef SZ_USE_X86_AVX2
#ifdef __AVX2__
#define SZ_USE_X86_AVX2 1
#else
#define SZ_USE_X86_AVX2 0
#endif
#endif
#ifndef SZ_USE_ARM_NEON
#ifdef __ARM_NEON
#define SZ_USE_ARM_NEON 1
#else
#define SZ_USE_ARM_NEON 0
#endif
#endif
#ifndef SZ_USE_ARM_SVE
#ifdef __ARM_FEATURE_SVE
#define SZ_USE_ARM_SVE 1
#else
#define SZ_USE_ARM_SVE 0
#endif
#endif
/*
* Include hardware-specific headers.
*/
#if SZ_USE_X86_AVX512 || SZ_USE_X86_AVX2
#include <immintrin.h>
#endif // SZ_USE_X86...
#if SZ_USE_ARM_NEON
#if !defined(_MSC_VER)
#include <arm_acle.h>
#endif
#include <arm_neon.h>
#endif // SZ_USE_ARM_NEON
#if SZ_USE_ARM_SVE
#if !defined(_MSC_VER)
#include <arm_sve.h>
#endif
#endif // SZ_USE_ARM_SVE
#pragma region Hardware - Specific API
#if SZ_USE_X86_AVX512
/** @copydoc sz_equal */
SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length);
/** @copydoc sz_order */
SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length);
/** @copydoc sz_copy */
SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
/** @copydoc sz_move */
SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
/** @copydoc sz_fill */
SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value);
/** @copydoc sz_look_up_transform */
SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target);
/** @copydoc sz_find_byte */
SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
/** @copydoc sz_rfind_byte */
SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
/** @copydoc sz_find */
SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length);
/** @copydoc sz_rfind */
SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length);
/** @copydoc sz_find_charset */
SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *set);
/** @copydoc sz_rfind_charset */
SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *set);
/** @copydoc sz_edit_distance */
SZ_PUBLIC sz_size_t sz_edit_distance_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
sz_size_t bound, sz_memory_allocator_t *alloc);
/** @copydoc sz_alignment_score */
SZ_PUBLIC sz_ssize_t sz_alignment_score_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
sz_error_cost_t const *subs, sz_error_cost_t gap, //
sz_memory_allocator_t *alloc);
/** @copydoc sz_hashes */
SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, //
sz_hash_callback_t callback, void *callback_handle);
#endif
#if SZ_USE_X86_AVX2
/** @copydoc sz_equal */
SZ_PUBLIC sz_bool_t sz_equal_avx2(sz_cptr_t a, sz_cptr_t b, sz_size_t length);
/** @copydoc sz_order */
SZ_PUBLIC sz_ordering_t sz_order_avx2(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length);
/** @copydoc sz_copy */
SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
/** @copydoc sz_move */
SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
/** @copydoc sz_fill */
SZ_PUBLIC void sz_fill_avx2(sz_ptr_t target, sz_size_t length, sz_u8_t value);
/** @copydoc sz_look_up_transform */
SZ_PUBLIC void sz_look_up_transform_avx2(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target);
/** @copydoc sz_find_byte */
SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
/** @copydoc sz_rfind_byte */
SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
/** @copydoc sz_find */
SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length);
/** @copydoc sz_rfind */
SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length);
/** @copydoc sz_hashes */
SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, //
sz_hash_callback_t callback, void *callback_handle);
#endif
#if SZ_USE_ARM_NEON
/** @copydoc sz_equal */
SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length);
/** @copydoc sz_order */
SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length);
/** @copydoc sz_copy */
SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
/** @copydoc sz_move */
SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
/** @copydoc sz_fill */
SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value);
/** @copydoc sz_look_up_transform */
SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t table, sz_ptr_t target);
/** @copydoc sz_find_byte */
SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
/** @copydoc sz_rfind_byte */
SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
/** @copydoc sz_find */
SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length);
/** @copydoc sz_rfind */
SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length);
/** @copydoc sz_find_charset */
SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const *set);
/** @copydoc sz_rfind_charset */
SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const *set);
#endif
#if SZ_USE_ARM_SVE
/** @copydoc sz_equal */
SZ_PUBLIC sz_bool_t sz_equal_sve(sz_cptr_t a, sz_cptr_t b, sz_size_t length);
/** @copydoc sz_order */
SZ_PUBLIC sz_ordering_t sz_order_sve(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length);
/** @copydoc sz_copy */
SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
/** @copydoc sz_move */
SZ_PUBLIC void sz_move_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length);
/** @copydoc sz_fill */
SZ_PUBLIC void sz_fill_sve(sz_ptr_t target, sz_size_t length, sz_u8_t value);
/** @copydoc sz_find_byte */
SZ_PUBLIC sz_cptr_t sz_find_byte_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
/** @copydoc sz_rfind_byte */
SZ_PUBLIC sz_cptr_t sz_rfind_byte_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);
/** @copydoc sz_find */
SZ_PUBLIC sz_cptr_t sz_find_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length);
/** @copydoc sz_rfind */
SZ_PUBLIC sz_cptr_t sz_rfind_sve(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length);
/** @copydoc sz_find_charset */
SZ_PUBLIC sz_cptr_t sz_find_charset_sve(sz_cptr_t text, sz_size_t length, sz_charset_t const *set);
/** @copydoc sz_rfind_charset */
SZ_PUBLIC sz_cptr_t sz_rfind_charset_sve(sz_cptr_t text, sz_size_t length, sz_charset_t const *set);
#endif
#pragma endregion
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wconversion"
/*
**********************************************************************************************************************
**********************************************************************************************************************
**********************************************************************************************************************
*
* This is where the actual implementation begins.
* The rest of the file is hidden from the public API.
*
**********************************************************************************************************************
**********************************************************************************************************************
**********************************************************************************************************************
*/
#pragma region Compiler Extensions and Helper Functions
#pragma GCC visibility push(hidden)
/**
* @brief Helper-macro to mark potentially unused variables.
*/
#define sz_unused(x) ((void)(x))
/**
 * @brief Helper-macro reinterpreting the bytes of a variable as another type of the same size.
 *
 * Expands to `*((type *)&(value))`, so `value` must be an addressable lvalue - a literal or
 * a temporary can't be passed, as its address is taken.
 *
 * NOTE(review): reading an object through a cast pointer of an unrelated type runs into the
 * strict-aliasing rules; `memcpy` or a union is the portable way to type-pun - confirm call sites.
 */
#define sz_bitcast(type, value) (*((type *)&(value)))
/**
* @brief Defines `SZ_NULL`, analogous to `NULL`.
* The default often comes from locale.h, stddef.h,
* stdio.h, stdlib.h, string.h, time.h, or wchar.h.
*/
#ifdef __GNUG__
#define SZ_NULL __null
#define SZ_NULL_CHAR __null
#else
#define SZ_NULL ((void *)0)
#define SZ_NULL_CHAR ((char *)0)
#endif
/**
* @brief Cache-line width, that will affect the execution of some algorithms,
* like equality checks and relative order computing.
*/
#define SZ_CACHE_LINE_WIDTH (64) // bytes
/**
 * @brief Similar to `assert`, the `sz_assert` is used in the SZ_DEBUG mode
 *        to check the invariants of the library.
 *
 * The checking variant is compiled only when `SZ_DEBUG` is non-zero, `SZ_AVOID_LIBC` is defined
 * as zero, and `SZ_PIC` is not defined. A failed condition then prints the stringified expression
 * with its file and line to `stderr` and terminates the process with `exit(EXIT_FAILURE)`.
 *
 * In every other configuration the macro expands to `((void)(condition))`: nothing is reported,
 * but the expression is still evaluated and its value discarded. Keep the conditions free of
 * side effects, so that both configurations behave the same.
 *
 * @note If you want to catch it, put a breakpoint at @b `__GI_exit`
 */
#if SZ_DEBUG && defined(SZ_AVOID_LIBC) && !SZ_AVOID_LIBC && !defined(SZ_PIC)
#include <stdio.h> // `fprintf`
#include <stdlib.h> // `EXIT_FAILURE`
/** @brief Reports the failed ::condition with its source location to `stderr` and exits the process. */
SZ_PUBLIC void _sz_assert_failure(char const *condition, char const *file, int line) {
fprintf(stderr, "Assertion failed: %s, in file %s, line %d\n", condition, file, line);
exit(EXIT_FAILURE);
}
#define sz_assert(condition) \
do { \
if (!(condition)) { _sz_assert_failure(#condition, __FILE__, __LINE__); } \
} while (0)
#else
#define sz_assert(condition) ((void)(condition))
#endif
/* Intrinsics aliases for MSVC, GCC, Clang, and Clang-Cl.
* The following section of compiler intrinsics comes in 2 flavors.
*/
#if defined(_MSC_VER) && !defined(__clang__) // MSVC proper; Clang-CL takes the `#else` branch with GCC-style builtins
#include <intrin.h>
// Sadly, when building Win32 images, we can't use the `_tzcnt_u64`, `_lzcnt_u64`,
// `_BitScanForward64`, or `_BitScanReverse64` intrinsics. For now it's a simple `for`-loop.
// TODO: In the future we can switch to a more efficient De Bruijn's algorithm.
// https://www.chessprogramming.org/BitScan
// https://www.chessprogramming.org/De_Bruijn_Sequence
// https://gist.github.com/resilar/e722d4600dbec9752771ab4c9d47044f
//
// Use the serial version on 32-bit x86 and on Arm.
#if (defined(_WIN32) && !defined(_WIN64)) || defined(_M_ARM) || defined(_M_ARM64)
/** @brief Serial fallback counting the trailing zero bits of a non-zero 64-bit word. */
SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) {
    sz_assert(x != 0);
    int trailing = 0;
    // Drop the lowest bit until a set one reaches position zero.
    for (; !(x & 1); x >>= 1) ++trailing;
    return trailing;
}
/** @brief Serial fallback counting the leading zero bits of a non-zero 64-bit word. */
SZ_INTERNAL int sz_u64_clz(sz_u64_t x) {
    sz_assert(x != 0);
    int leading = 0;
    // Push the word left until a set bit reaches the topmost position.
    for (; !(x & 0x8000000000000000ull); x <<= 1) ++leading;
    return leading;
}
/** @brief Serial SWAR fallback counting the set bits of a 64-bit word. */
SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) {
    sz_u64_t const pairs = 0x5555555555555555ull;
    sz_u64_t const nibbles = 0x3333333333333333ull;
    sz_u64_t const bytes = 0x0F0F0F0F0F0F0F0Full;
    sz_u64_t const ones = 0x0101010101010101ull;
    x -= (x >> 1) & pairs;                      // per-pair counts
    x = (x & nibbles) + ((x >> 2) & nibbles);   // per-nibble counts
    x = (x + (x >> 4)) & bytes;                 // per-byte counts
    x *= ones;                                  // the top byte accumulates the sum of all bytes
    return x >> 56;
}
/** @brief Serial fallback counting the trailing zero bits of a non-zero 32-bit word. */
SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) {
    sz_assert(x != 0);
    int trailing = 0;
    // Drop the lowest bit until a set one reaches position zero.
    for (; !(x & 1); x >>= 1) ++trailing;
    return trailing;
}
/** @brief Serial fallback counting the leading zero bits of a non-zero 32-bit word. */
SZ_INTERNAL int sz_u32_clz(sz_u32_t x) {
    sz_assert(x != 0);
    int leading = 0;
    // Push the word left until a set bit reaches the topmost position.
    for (; !(x & 0x80000000u); x <<= 1) ++leading;
    return leading;
}
/** @brief Serial SWAR fallback counting the set bits of a 32-bit word. */
SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) {
    sz_u32_t const pairs = 0x55555555;
    sz_u32_t const nibbles = 0x33333333;
    sz_u32_t const bytes = 0x0F0F0F0F;
    sz_u32_t const ones = 0x01010101;
    x -= (x >> 1) & pairs;                      // per-pair counts
    x = (x & nibbles) + ((x >> 2) & nibbles);   // per-nibble counts
    x = (x + (x >> 4)) & bytes;                 // per-byte counts
    x *= ones;                                  // the top byte accumulates the sum of all bytes
    return x >> 24;
}
#else
// Hardware-backed flavor: thin wrappers over the `_tzcnt`, `_lzcnt`, and `__popcnt` intrinsics,
// each casting the intrinsic's result to `int`.
// NOTE(review): unlike the serial fallbacks, these don't `sz_assert(x != 0)`; confirm that the
// zero-input results and the behavior on CPUs lacking these instructions suit the callers.
SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { return (int)_tzcnt_u64(x); }
SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { return (int)_lzcnt_u64(x); }
SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { return (int)__popcnt64(x); }
SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return (int)_tzcnt_u32(x); }
SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return (int)_lzcnt_u32(x); }
SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return (int)__popcnt(x); }
#endif
// Force the byteswap functions to be intrinsics, because when /Oi- is given, these will turn into CRT function calls,
// which breaks when `SZ_AVOID_LIBC` is given. Each `#pragma intrinsic` has to precede the wrapper using the function.
#pragma intrinsic(_byteswap_uint64)
/** @brief Reverses the order of the eight bytes in a 64-bit word. */
SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return _byteswap_uint64(val); }
#pragma intrinsic(_byteswap_ulong)
/** @brief Reverses the order of the four bytes in a 32-bit word. */
SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return _byteswap_ulong(val); }
#else
// GCC/Clang flavor: every helper forwards to the matching compiler builtin.
// The count-zeros builtins have an undefined result for a zero argument, for both widths.
SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { return __builtin_popcountll(x); }
SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return __builtin_popcount(x); }
SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { return __builtin_ctzll(x); } // ! Undefined if `x == 0`
SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { return __builtin_clzll(x); } // ! Undefined if `x == 0`
SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return __builtin_ctz(x); } // ! Undefined if `x == 0`
SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return __builtin_clz(x); } // ! Undefined if `x == 0`
SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return __builtin_bswap64(val); }
SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return __builtin_bswap32(val); }
#endif
/**
 * @brief Rotates the bits of a 64-bit word to the left.
 *
 * The rotation count is reduced modulo 64, so `r == 0` and `r >= 64` are well defined.
 * The naive `(x << r) | (x >> (64 - r))` shifts by the full width of the type when `r == 0`,
 * which is undefined behavior in C. For counts from 1 to 63 the result is unchanged.
 *
 * @param x Word to rotate.
 * @param r Number of bit positions to rotate by.
 * @return ::x with its bits rotated left by `r % 64` positions.
 */
SZ_INTERNAL sz_u64_t sz_u64_rotl(sz_u64_t x, sz_u64_t r) {
    r &= 63;
    // The second mask turns the `64 - 0` case into a shift by zero instead of by the type width.
    return (x << r) | (x >> ((64 - r) & 63));
}
/**
 * @brief Merges two words bit by bit: where a ::mask bit is set the result takes
 *        the bit of ::b, and where it is clear the result keeps the bit of ::a.
 *
 * Similar to `_mm_blend_epi16` intrinsic on x86.
 * A XOR-based formulation of the same merge is described in the "Bit Twiddling Hacks" by Sean Eron Anderson.
 * https://graphics.stanford.edu/~seander/bithacks.html#ConditionalSetOrClearBitsWithoutBranching
 */
SZ_INTERNAL sz_u64_t sz_u64_blend(sz_u64_t a, sz_u64_t b, sz_u64_t mask) {
    sz_u64_t const kept_from_a = a & ~mask;
    sz_u64_t const taken_from_b = b & mask;
    return kept_from_a | taken_from_b;
}
/*
* Efficiently computing the minimum and maximum of two or three values can be tricky.
* The simple branching baseline would be:
*
* x < y ? x : y // can replace with 1 conditional move
*
* Branchless approach is well known for signed integers, but it doesn't apply to unsigned ones.
* https://stackoverflow.com/questions/514435/templatized-branchless-int-max-min-function
* https://graphics.stanford.edu/~seander/bithacks.html#IntegerMinOrMax
* Using only bit-shifts for signed integers it would be:
*
* y + ((x - y) & (x - y) >> 31) // 4 unique operations
*
* Alternatively, for any integers using multiplication:
*
* (x > y) * y + (x <= y) * x // 5 operations
*
* Alternatively, to avoid multiplication:
*
* x & ~((x < y) - 1) + y & ((x < y) - 1) // 6 unique operations
*/
/*
 * Type-generic minimum and maximum of two or three values.
 * Every argument is parenthesized, so compound expressions such as `a | b` or `c ? d : e`
 * can be passed without changing the meaning of the comparison.
 * Arguments may still be evaluated more than once, so they must not carry side effects.
 */
#define sz_min_of_two(x, y) ((x) < (y) ? (x) : (y))
#define sz_max_of_two(x, y) ((x) < (y) ? (y) : (x))
#define sz_min_of_three(x, y, z) sz_min_of_two((x), sz_min_of_two((y), (z)))
#define sz_max_of_three(x, y, z) sz_max_of_two((x), sz_max_of_two((y), (z)))
/**
 * @brief Branchless minimum of two signed 32-bit integers.
 *        Relies on `x - y` not overflowing and on `>>` sign-extending negative values.
 */
SZ_INTERNAL sz_i32_t sz_i32_min_of_two(sz_i32_t x, sz_i32_t y) {
    sz_i32_t const difference = x - y;
    sz_i32_t const sign_mask = difference >> 31; // all ones for a negative difference, given an arithmetic shift
    return y + (difference & sign_mask);
}
/**
 * @brief Branchless maximum of two signed 32-bit integers.
 *        Relies on `x - y` not overflowing and on `>>` sign-extending negative values.
 */
SZ_INTERNAL sz_i32_t sz_i32_max_of_two(sz_i32_t x, sz_i32_t y) {
    sz_i32_t const difference = x - y;
    sz_i32_t const sign_mask = difference >> 31; // all ones for a negative difference, given an arithmetic shift
    return x - (difference & sign_mask);
}
/**
 * @brief Converts a possibly negative, Python-style `[start, end)` pair into an in-bounds offset and length.
 *
 * Negative indices are shifted by ::length, both edges are then clamped into `[0, length]`,
 * and an inverted interval collapses into an empty slice positioned at the clamped ::end.
 *
 * @param length            Number of bytes in the string being sliced.
 * @param start             Signed index of the first byte of the slice.
 * @param end               Signed index one past the last byte of the slice.
 * @param normalized_offset Output: offset of the slice from the beginning of the string.
 * @param normalized_length Output: number of bytes in the slice.
 */
SZ_INTERNAL void sz_ssize_clamp_interval(sz_size_t length, sz_ssize_t start, sz_ssize_t end,
                                         sz_size_t *normalized_offset, sz_size_t *normalized_length) {
    // TODO: Remove branches.
    sz_ssize_t const upper = (sz_ssize_t)length;

    // Lower edge: wrap a negative index around the end, then clamp.
    if (start < 0) {
        start += length;
        if (start < 0) start = 0;
    }
    if (start > upper) start = upper;

    // Upper edge: same treatment.
    if (end < 0) {
        end += length;
        if (end < 0) end = 0;
    }
    if (end > upper) end = upper;

    // An inverted interval becomes an empty one.
    if (end < start) start = end;

    *normalized_offset = start;
    *normalized_length = end - start;
}
/**
 * @brief Computes the base-2 logarithm of a positive integer, rounded down.
 * @param x Positive integer; the invariant is checked with `sz_assert`.
 * @return Zero-based position of the highest set bit of ::x.
 */
SZ_INTERNAL sz_size_t sz_size_log2i_nonzero(sz_size_t x) {
    sz_assert(x > 0 && "Non-positive numbers have no defined logarithm");
    // The argument is treated as a 64-bit word, whose highest set bit sits at `63 - clz`.
    return 63 - (sz_size_t)sz_u64_clz(x);
}
/**
 * @brief Rounds ::x up to the nearest power of two; exact powers of two map to themselves.
 */
SZ_INTERNAL sz_size_t sz_size_bit_ceil(sz_size_t x) {
    // Unlike the commonly used trick with `clz` intrinsics, is valid across the whole range of `x`.
    // https://stackoverflow.com/a/10143264
    // Step one below the input, copy its highest set bit into every lower position, then step back up.
    sz_size_t smeared = x - 1;
    smeared |= smeared >> 1;
    smeared |= smeared >> 2;
    smeared |= smeared >> 4;
    smeared |= smeared >> 8;
    smeared |= smeared >> 16;
#if SZ_DETECT_64_BIT
    smeared |= smeared >> 32;
#endif
    return smeared + 1;
}
/**
 *  @brief  Transposes an 8x8 bit matrix packed in a `sz_u64_t`.
 *
 *  There is a well known SWAR sequence for that known to chess programmers,
 *  willing to flip a bit-matrix of pieces along a board diagonal.
 *  https://www.chessprogramming.org/Flipping_Mirroring_and_Rotating
 *  https://lukas-prokop.at/articles/2021-07-23-transpose
 *
 *  Three XOR-swap stages are applied with shift distances of 36, 18, and 9 bits.
 *  NOTE(review): confirm against the references which of the two diagonals these constants flip along.
 */
SZ_INTERNAL sz_u64_t sz_u64_transpose(sz_u64_t x) {
    sz_u64_t const stage_36_mask = 0xf0f0f0f00f0f0f0full;
    sz_u64_t const stage_18_mask = 0xcccc0000cccc0000ull;
    sz_u64_t const stage_9_mask = 0xaa00aa00aa00aa00ull;
    sz_u64_t delta;

    // First stage: shift distance of 36 bits.
    delta = x ^ (x << 36);
    x ^= stage_36_mask & (delta ^ (x >> 36));

    // Second stage: shift distance of 18 bits.
    delta = stage_18_mask & (x ^ (x << 18));
    x ^= delta ^ (delta >> 18);

    // Third stage: shift distance of 9 bits.
    delta = stage_9_mask & (x ^ (x << 9));
    x ^= delta ^ (delta >> 9);
    return x;
}
/**
 *  @brief  Helper, that exchanges the values of two 64-bit integers through their pointers.
 */
SZ_INTERNAL void sz_u64_swap(sz_u64_t *a, sz_u64_t *b) {
    sz_u64_t const old_a = *a;
    sz_u64_t const old_b = *b;
    *a = old_b;
    *b = old_a;
}
/**
 *  @brief  Helper, that exchanges the values of two untyped pointers through their addresses.
 */
SZ_INTERNAL void sz_pointer_swap(void **a, void **b) {
    void *const old_a = *a;
    void *const old_b = *b;
    *a = old_b;
    *b = old_a;
}
/**
 *  @brief  Helper structure to simplify work with 16-bit words.
 *          The same two bytes can be accessed as one `u16` scalar or as the `u8s` byte array.
 *  @see    sz_u16_load
 */
typedef union sz_u16_vec_t {
sz_u16_t u16;
sz_u8_t u8s[2];
} sz_u16_vec_t;
/**
 *  @brief  Load a 16-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms.
 *
 *  One of three code paths is compiled in: a byte-by-byte copy when `SZ_USE_MISALIGNED_LOADS` is off,
 *  an MSVC-specific dereference (using `__unaligned` outside of 32-bit x86), or a plain pointer-cast
 *  dereference for the remaining compilers.
 *
 *  @param ptr  Address of the first of the 2 bytes to read.
 *  @return     The 2 bytes packed into a `sz_u16_vec_t`.
 */
SZ_INTERNAL sz_u16_vec_t sz_u16_load(sz_cptr_t ptr) {
#if !SZ_USE_MISALIGNED_LOADS
// Portable path: assemble the word one byte at a time.
sz_u16_vec_t result;
result.u8s[0] = ptr[0];
result.u8s[1] = ptr[1];
return result;
#elif defined(_MSC_VER) && !defined(__clang__)
#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform.
return *((sz_u16_vec_t *)ptr);
#else
return *((__unaligned sz_u16_vec_t *)ptr);
#endif
#else
// NOTE(review): the `aligned(1)` attribute is attached to the declaration of the local pointer;
// confirm it has the intended effect on the access through that pointer.
__attribute__((aligned(1))) sz_u16_vec_t const *result = (sz_u16_vec_t const *)ptr;
return *result;
#endif
}
/**
 *  @brief  Helper structure to simplify work with 32-bit words.
 *          The same four bytes can be accessed as one `u32` scalar, two `u16s`, or four `u8s`.
 *  @see    sz_u32_load
 */
typedef union sz_u32_vec_t {
sz_u32_t u32;
sz_u16_t u16s[2];
sz_u8_t u8s[4];
} sz_u32_vec_t;
/**
 *  @brief  Load a 32-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms.
 *
 *  One of three code paths is compiled in: a byte-by-byte copy when `SZ_USE_MISALIGNED_LOADS` is off,
 *  an MSVC-specific dereference (using `__unaligned` outside of 32-bit x86), or a plain pointer-cast
 *  dereference for the remaining compilers.
 *
 *  @param ptr  Address of the first of the 4 bytes to read.
 *  @return     The 4 bytes packed into a `sz_u32_vec_t`.
 */
SZ_INTERNAL sz_u32_vec_t sz_u32_load(sz_cptr_t ptr) {
#if !SZ_USE_MISALIGNED_LOADS
// Portable path: assemble the word one byte at a time.
sz_u32_vec_t result;
result.u8s[0] = ptr[0];
result.u8s[1] = ptr[1];
result.u8s[2] = ptr[2];
result.u8s[3] = ptr[3];
return result;
#elif defined(_MSC_VER) && !defined(__clang__)
#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform.
return *((sz_u32_vec_t *)ptr);
#else
return *((__unaligned sz_u32_vec_t *)ptr);
#endif
#else
// NOTE(review): the `aligned(1)` attribute is attached to the declaration of the local pointer;
// confirm it has the intended effect on the access through that pointer.
__attribute__((aligned(1))) sz_u32_vec_t const *result = (sz_u32_vec_t const *)ptr;
return *result;
#endif
}
/**
 *  @brief  Helper structure to simplify work with 64-bit words.
 *          The same eight bytes can be accessed as one `u64` scalar, two `u32s`, four `u16s`, or eight `u8s`.
 *  @see    sz_u64_load
 */
typedef union sz_u64_vec_t {
sz_u64_t u64;
sz_u32_t u32s[2];
sz_u16_t u16s[4];
sz_u8_t u8s[8];
} sz_u64_vec_t;
/**
 *  @brief  Load a 64-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms.
 *
 *  One of three code paths is compiled in: a byte-by-byte copy when `SZ_USE_MISALIGNED_LOADS` is off,
 *  an MSVC-specific dereference (using `__unaligned` outside of 32-bit x86), or a plain pointer-cast
 *  dereference for the remaining compilers.
 *
 *  @param ptr  Address of the first of the 8 bytes to read.
 *  @return     The 8 bytes packed into a `sz_u64_vec_t`.
 */
SZ_INTERNAL sz_u64_vec_t sz_u64_load(sz_cptr_t ptr) {
#if !SZ_USE_MISALIGNED_LOADS
// Portable path: assemble the word one byte at a time.
sz_u64_vec_t result;
result.u8s[0] = ptr[0];
result.u8s[1] = ptr[1];
result.u8s[2] = ptr[2];
result.u8s[3] = ptr[3];
result.u8s[4] = ptr[4];
result.u8s[5] = ptr[5];
result.u8s[6] = ptr[6];
result.u8s[7] = ptr[7];
return result;
#elif defined(_MSC_VER) && !defined(__clang__)
#if defined(_M_IX86) //< The __unaligned modifier isn't valid for the x86 platform.
return *((sz_u64_vec_t *)ptr);
#else
return *((__unaligned sz_u64_vec_t *)ptr);
#endif
#else
// NOTE(review): the `aligned(1)` attribute is attached to the declaration of the local pointer;
// confirm it has the intended effect on the access through that pointer.
__attribute__((aligned(1))) sz_u64_vec_t const *result = (sz_u64_vec_t const *)ptr;
return *result;
#endif
}
/**
 *  @brief  Helper function, using the supplied fixed-capacity buffer to allocate memory.
 *
 *  The first `sizeof(sz_size_t)` bytes behind `handle` are read as the total capacity of the buffer,
 *  and the returned pointer addresses the bytes right after that header.
 *
 *  @param length   Number of bytes requested.
 *  @param handle   Start of the buffer, beginning with its capacity stored as a `sz_size_t`.
 *  @return         Pointer past the capacity header, or `SZ_NULL_CHAR` if the request doesn't fit.
 */
SZ_INTERNAL sz_ptr_t _sz_memory_allocate_fixed(sz_size_t length, void *handle) {
    sz_size_t const header_length = sizeof(sz_size_t);
    sz_size_t capacity;
    sz_copy((sz_ptr_t)&capacity, (sz_cptr_t)handle, header_length);
    // Compare against the remaining space instead of computing `header_length + length`,
    // which could wrap around for huge requests and wrongly pass the bounds check.
    if (capacity < header_length || length > capacity - header_length) return SZ_NULL_CHAR;
    return (sz_ptr_t)handle + header_length;
}
/**
 *  @brief  Helper "no-op" function, simulating memory deallocation when we use a "static" memory buffer.
 *          All three arguments are ignored.
 */
SZ_INTERNAL void _sz_memory_free_fixed(sz_ptr_t start, sz_size_t length, void *handle) {
    sz_unused(start);
    sz_unused(length);
    sz_unused(handle);
}
/**
 *  @brief  An internal callback used to set a bit in a power-of-two length binary fingerprint of a string.
 *
 *  @param start    Unused.
 *  @param length   Unused.
 *  @param hash     Hash value: its lowest 3 bits choose the bit, the remaining ones choose the byte.
 *  @param handle   Pointer to a `sz_string_view_t` describing the fingerprint bytes.
 */
SZ_INTERNAL void _sz_hashes_fingerprint_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *handle) {
    sz_string_view_t *view = (sz_string_view_t *)handle;
    sz_u8_t *bytes = (sz_u8_t *)view->start;
    // With a power-of-two size, masking with `size - 1` wraps the byte index into the buffer.
    sz_size_t const byte_index = (hash / 8) & (view->length - 1);
    bytes[byte_index] |= (1 << (hash & 7));
    sz_unused(start && length);
}
/**
 *  @brief  An internal callback used to set a bit in a @b non power-of-two length binary fingerprint of a string.
 *
 *  @param start    Unused.
 *  @param length   Unused.
 *  @param hash     Hash value: its lowest 3 bits choose the bit, the remaining ones choose the byte.
 *  @param handle   Pointer to a `sz_string_view_t` describing the fingerprint bytes.
 */
SZ_INTERNAL void _sz_hashes_fingerprint_non_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash,
                                                          void *handle) {
    sz_string_view_t *view = (sz_string_view_t *)handle;
    sz_u8_t *bytes = (sz_u8_t *)view->start;
    // Arbitrary sizes need a modulo to wrap the byte index into the buffer.
    sz_size_t const byte_index = (hash / 8) % view->length;
    bytes[byte_index] |= (1 << (hash & 7));
    sz_unused(start && length);
}
/**
 *  @brief  An internal callback, used to mix all the running hashes into one pointer-size value.
 *
 *  @param start            Unused.
 *  @param length           Unused.
 *  @param hash             Hash value to be XOR-ed into the accumulator.
 *  @param scalar_handle    Pointer to the `sz_size_t` accumulator.
 */
SZ_INTERNAL void _sz_hashes_fingerprint_scalar_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash,
                                                        void *scalar_handle) {
    sz_unused(start && length && hash && scalar_handle);
    *(sz_size_t *)scalar_handle ^= hash;
}
/**
 *  @brief  Chooses the offsets of the most interesting characters in a search needle.
 *
 *  Search throughput can significantly deteriorate if we are matching the wrong characters.
 *  Say the needle is "aXaYa", and we are comparing the first, second, and last character.
 *  If we use SIMD and compare many offsets at a time, comparing against "a" in every register is a waste.
 *
 *  Similarly, dealing with UTF8 inputs, we know that the lower bits of each character code carry more information.
 *  Cyrillic alphabet, for example, falls into [0x0410, 0x042F] code range for uppercase [А, Я], and
 *  into [0x0430, 0x044F] for lowercase [а, я]. Scanning through a text written in Russian, half of the
 *  bytes will carry absolutely no value and will be equal to 0x04.
 *
 *  @param start    The needle. Its first, middle, and last bytes are read unconditionally.
 *  @param length   Length of the needle. Must be non-zero, as `length - 1` is used as an offset.
 *  @param first    Output: offset of the first character to compare, starts at zero.
 *  @param second   Output: offset of the second character to compare, starts at `length / 2`.
 *  @param third    Output: offset of the third character to compare, starts at `length - 1`.
 */
SZ_INTERNAL void _sz_locate_needle_anomalies(sz_cptr_t start, sz_size_t length, //
                                             sz_size_t *first, sz_size_t *second, sz_size_t *third) {
    *first = 0;
    *second = length / 2;
    *third = length - 1;

    //
    int has_duplicates =                   //
        start[*first] == start[*second] || //
        start[*first] == start[*third] ||  //
        start[*second] == start[*third];

    // Loop through letters to find non-colliding variants.
    if (length > 3 && has_duplicates) {
        // Pivot the middle point right, until we find a character different from the first one.
        for (; start[*second] == start[*first] && *second + 1 < *third; ++(*second)) {}
        // Pivot the third (last) point left, until we find a different character.
        for (; (start[*third] == start[*second] || start[*third] == start[*first]) && *third > (*second + 1);
             --(*third)) {}
    }

    // TODO: Investigate alternative strategies for long needles.
    // On very long needles we have the luxury to choose!
    // Often dealing with UTF8, we will likely benefit from shifting the first and second characters
    // further to the right, to achieve not only uniqueness within the needle, but also avoid common
    // rune prefixes of 2-, 3-, and 4-byte codes.
    if (length > 8) {
        // Pivot the first and second points right, until we find a character, that:
        // > is different from others.
        // > doesn't start with 0b'110x'xxxx - only 5 bits of relevant info.
        // > doesn't start with 0b'1110'xxxx - only 4 bits of relevant info.
        // > doesn't start with 0b'1111'0xxx - only 3 bits of relevant info.
        //
        // So we are practically searching for byte values that start with 0b0xxx'xxxx or 0b'10xx'xxxx.
        // Meaning they fall in the range [0, 127] and [128, 191], in other words any unsigned int up to 191.
        sz_u8_t const *start_u8 = (sz_u8_t const *)start;
        sz_size_t vibrant_first = *first, vibrant_second = *second, vibrant_third = *third;

        // Let's begin with the second character, as the termination criteria there is more obvious
        // and we may end up with more variants to check for the first candidate.
        for (; (start_u8[vibrant_second] > 191 || start_u8[vibrant_second] == start_u8[vibrant_third]) &&
               (vibrant_second + 1 < vibrant_third);
             ++vibrant_second) {}

        // Now check if we've indeed found a good candidate or should revert the `vibrant_second` to `second`.
        // The byte 191 itself is a valid candidate, matching the `> 191` rejection rule of the loop above.
        if (start_u8[vibrant_second] <= 191) { *second = vibrant_second; }
        else { vibrant_second = *second; }

        // Now check the first character.
        for (; (start_u8[vibrant_first] > 191 || start_u8[vibrant_first] == start_u8[vibrant_second] ||
                start_u8[vibrant_first] == start_u8[vibrant_third]) &&
               (vibrant_first + 1 < vibrant_second);
             ++vibrant_first) {}

        // Now check if we've indeed found a good candidate or should revert the `vibrant_first` to `first`.
        // We don't need to shift the third one when dealing with texts as the last byte of the text is
        // also the last byte of a rune and contains the most information.
        if (start_u8[vibrant_first] <= 191) { *first = vibrant_first; }
    }
}
#pragma GCC visibility pop
#pragma endregion
#pragma region Serial Implementation
#if !SZ_AVOID_LIBC
#include <stdio.h> // `fprintf`
#include <stdlib.h> // `malloc`, `EXIT_FAILURE`
/** @brief Default allocation callback, forwarding the request to `malloc` and ignoring the handle. */
SZ_PUBLIC void *_sz_memory_allocate_default(sz_size_t length, void *handle) {
    void *const block = malloc(length);
    sz_unused(handle);
    return block;
}
/** @brief Default deallocation callback, forwarding the pointer to `free`; the length and handle are ignored. */
SZ_PUBLIC void _sz_memory_free_default(sz_ptr_t start, sz_size_t length, void *handle) {
    free(start);
    sz_unused(handle && length);
}
#endif
/**
 *  @brief  Fills the allocator with the default callbacks and a null handle.
 *          With `SZ_AVOID_LIBC` set, both callbacks are set to null instead of the LibC-based ones.
 */
SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc) {
    alloc->handle = SZ_NULL;
#if SZ_AVOID_LIBC
    alloc->allocate = (sz_memory_allocate_t)SZ_NULL;
    alloc->free = (sz_memory_free_t)SZ_NULL;
#else
    alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_default;
    alloc->free = (sz_memory_free_t)_sz_memory_free_default;
#endif
}
/**
 *  @brief  Configures the allocator to serve requests from a caller-provided fixed-capacity buffer.
 *
 *  @param alloc    Allocator to initialize.
 *  @param buffer   Backing memory; its first `sizeof(sz_size_t)` bytes are overwritten with `length`.
 *  @param length   Total capacity of the buffer in bytes, including that header.
 */
SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length) {
    // The logic here is simple - put the buffer length in the first slots of the buffer.
    // Later use it for bounds checking.
    alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_fixed;
    alloc->free = (sz_memory_free_t)_sz_memory_free_fixed;
    // The handle must be the buffer itself: the allocation callback reads the capacity from the handle
    // and returns addresses relative to it. Storing `&buffer` would keep the address of this function's
    // own by-value parameter, which dangles as soon as we return.
    alloc->handle = buffer;
    sz_copy((sz_ptr_t)buffer, (sz_cptr_t)&length, sizeof(sz_size_t));
}
/**
 *  @brief  Byte-level equality comparison between two strings.
 *          If unaligned loads are allowed and the inputs reach `SZ_SWAR_THRESHOLD` bytes,
 *          compares 8 bytes at a time before falling back to a byte-by-byte loop for the rest.
 *
 *  @return Non-zero `sz_bool_t` if the first `length` bytes of both strings match.
 */
SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length) {
    sz_cptr_t const a_end = a + length;
#if SZ_USE_MISALIGNED_LOADS
    if (length >= SZ_SWAR_THRESHOLD) {
        while (a + 8 <= a_end) {
            sz_u64_vec_t const a_vec = sz_u64_load(a);
            sz_u64_vec_t const b_vec = sz_u64_load(b);
            if (a_vec.u64 != b_vec.u64) return sz_false_k;
            a += 8, b += 8;
        }
    }
#endif
    // Walk the remaining bytes, stopping at the first mismatch.
    for (; a != a_end; ++a, ++b)
        if (*a != *b) break;
    // Reaching the end means no mismatch was found.
    return (sz_bool_t)(a == a_end);
}
/**
 *  @brief  Finds the first byte of the text that is contained in the given character set.
 *  @return Pointer to that byte, or `SZ_NULL_CHAR` if none of the bytes are in the set.
 */
SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) {
    sz_cptr_t const end = text + length;
    while (text != end) {
        if (sz_charset_contains(set, *text)) return text;
        ++text;
    }
    return SZ_NULL_CHAR;
}
/**
 *  @brief  Finds the last byte of the text that is contained in the given character set.
 *  @return Pointer to that byte, or `SZ_NULL_CHAR` if none of the bytes are in the set.
 */
SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) {
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds"
    // Walk backwards from one-past-the-end, stepping back before every check.
    sz_cptr_t cursor = text + length;
    while (cursor != text) {
        --cursor;
        if (sz_charset_contains(set, *cursor)) return cursor;
    }
    return SZ_NULL_CHAR;
#pragma GCC diagnostic pop
}
/**
 *  One option to avoid branching is to use conditional moves and lookup the comparison result in a table:
 *       sz_ordering_t ordering_lookup[2] = {sz_greater_k, sz_less_k};
 *       for (; a != min_end; ++a, ++b)
 *           if (*a != *b) return ordering_lookup[*a < *b];
 *  That, however, introduces a data-dependency.
 *  A cleaner option is to perform two comparisons and a subtraction.
 *  One instruction more, but no data-dependency.
 *
 *  Both arguments are parenthesized in the expansion, so that operands containing lower-precedence
 *  operators are compared as a whole. Each argument is evaluated twice, so avoid side effects.
 */
#define _sz_order_scalars(a, b) ((sz_ordering_t)(((a) > (b)) - ((a) < (b))))
/**
 *  @brief  Lexicographic three-way comparison of two strings, treating their contents as unsigned bytes.
 *
 *  @param a, a_length  First string and its length.
 *  @param b, b_length  Second string and its length.
 *  @return Ordering of the first differing byte within the common prefix, or, if the common prefix
 *          is identical, the ordering of the lengths (the shorter string is the smaller one).
 */
SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) {
    sz_bool_t a_shorter = (sz_bool_t)(a_length < b_length);
    sz_size_t min_length = a_shorter ? a_length : b_length;
    sz_cptr_t min_end = a + min_length;
#if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN
    // Compare 8 bytes at a time. On a little-endian machine the bytes must be reversed,
    // so that the earliest byte of the string becomes the most significant one.
    for (sz_u64_vec_t a_vec, b_vec; a + 8 <= min_end; a += 8, b += 8) {
        a_vec = sz_u64_load(a);
        b_vec = sz_u64_load(b);
        if (a_vec.u64 != b_vec.u64)
            return _sz_order_scalars(sz_u64_bytes_reverse(a_vec.u64), sz_u64_bytes_reverse(b_vec.u64));
    }
#endif
    // Compare the remaining bytes as unsigned values. Plain `char` may be signed, which would order
    // bytes above 0x7F before ASCII here, contradicting the unsigned comparison of the word loop above.
    for (; a != min_end; ++a, ++b)
        if (*a != *b) return _sz_order_scalars((unsigned char)*a, (unsigned char)*b);

    // If the strings are equal up to `min_end`, then the shorter string is smaller
    return _sz_order_scalars(a_length, b_length);
}
/**
 *  @brief  Byte-level equality comparison between two 64-bit integers.
 *  @return 64-bit integer, where every top bit in each byte signifies a match.
 */
SZ_INTERNAL sz_u64_vec_t _sz_u64_each_byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) {
    sz_u64_vec_t same_bits, low_bits, top_bits;
    // Every bit that is identical in both inputs becomes a one.
    same_bits.u64 = ~(a.u64 ^ b.u64);
    // A byte matches only if all 8 of its bits are set. Adding one to the bottom 7 bits
    // carries into the top position only if all 7 are ones; the top bit is checked separately.
    low_bits.u64 = same_bits.u64 & 0x7F7F7F7F7F7F7F7Full;
    top_bits.u64 = same_bits.u64 & 0x8080808080808080ull;
    same_bits.u64 = (low_bits.u64 + 0x0101010101010101ull) & top_bits.u64;
    return same_bits;
}
/**
 *  @brief  Find the first occurrence of a @b single-character needle in an arbitrary length haystack.
 *          This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time.
 *          Identical to `memchr(haystack, needle[0], haystack_length)`.
 *
 *  @param h        Haystack to scan.
 *  @param h_length Number of bytes in the haystack.
 *  @param n        Pointer to the needle; only its first byte is used.
 *  @return         Pointer to the first matching byte, or `SZ_NULL_CHAR` if there is none.
 */
SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
if (!h_length) return SZ_NULL_CHAR;
sz_cptr_t const h_end = h + h_length;
#if !SZ_DETECT_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevity.
#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to avoid UB on unaligned 64-bit loads.
for (; ((sz_size_t)h & 7ull) && h < h_end; ++h)
if (*h == *n) return h;
#endif
// Broadcast the n into every byte of a 64-bit integer to use SWAR
// techniques and process eight characters at a time.
sz_u64_vec_t h_vec, n_vec, match_vec;
match_vec.u64 = 0;
n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull;
for (; h + 8 <= h_end; h += 8) {
h_vec.u64 = *(sz_u64_t const *)h;
match_vec = _sz_u64_each_byte_equal(h_vec, n_vec);
// The lowest set indicator bit belongs to the earliest matching byte of the word.
if (match_vec.u64) return h + sz_u64_ctz(match_vec.u64) / 8;
}
#endif
// Handle the misaligned tail.
for (; h < h_end; ++h)
if (*h == *n) return h;
return SZ_NULL_CHAR;
}
/**
* @brief Find the last occurrence of a @b single-character needle in an arbitrary length haystack.
* This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time.
* Identical to `memrchr(haystack, needle[0], haystack_length)`.
*/
sz_cptr_t sz_rfind_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
if (!h_length) return SZ_NULL_CHAR;
sz_cptr_t const h_start = h;
// Reposition the `h` pointer to the end, as we will be walking backwards.
h = h + h_length - 1;
#if !SZ_DETECT_BIG_ENDIAN // Use SWAR only on little-endian platforms for brevity.
#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to void UB on unaligned 64-bit loads.
for (; ((sz_size_t)(h + 1) & 7ull) && h >= h_start; --h)
if (*h == *n) return h;
#endif
// Broadcast the n into every byte of a 64-bit integer to use SWAR
// techniques and process eight characters at a time.
sz_u64_vec_t h_vec, n_vec, match_vec;
n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull;
for (; h >= h_start + 7; h -= 8) {
h_vec.u64 = *(sz_u64_t const *)(h - 7);
match_vec = _sz_u64_each_byte_equal(h_vec, n_vec);
if (match_vec.u64) return h - sz_u64_clz(match_vec.u64) / 8;
}
#endif
for (; h >= h_start; --h)
if (*h == *n) return h;
return SZ_NULL_CHAR;
}
/**
 *  @brief  2Byte-level equality comparison between two 64-bit integers.
 *  @return 64-bit integer, where every top bit in each 2byte signifies a match.
 */
SZ_INTERNAL sz_u64_vec_t _sz_u64_each_2byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) {
    sz_u64_vec_t same_bits, low_bits, top_bits;
    // Every bit that is identical in both inputs becomes a one.
    same_bits.u64 = ~(a.u64 ^ b.u64);
    // A 2byte matches only if all 16 of its bits are set. Adding one to the bottom 15 bits
    // carries into the top position only if all 15 are ones; the top bit is checked separately.
    low_bits.u64 = same_bits.u64 & 0x7FFF7FFF7FFF7FFFull;
    top_bits.u64 = same_bits.u64 & 0x8000800080008000ull;
    same_bits.u64 = (low_bits.u64 + 0x0001000100010001ull) & top_bits.u64;
    return same_bits;
}
/**
 *  @brief  Find the first occurrence of a @b two-character needle in an arbitrary length haystack.
 *          This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time.
 *
 *  @param h        Haystack to scan, at least 2 bytes long.
 *  @param h_length Number of bytes in the haystack.
 *  @param n        Needle; only its first 2 bytes are used.
 *  @param n_length Unused.
 *  @return         Pointer to the first match, or `SZ_NULL_CHAR` if there is none.
 */
SZ_INTERNAL sz_cptr_t _sz_find_2byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
// This is an internal method, and the haystack is guaranteed to be at least 2 bytes long.
sz_assert(h_length >= 2 && "The haystack is too short.");
sz_unused(n_length);
sz_cptr_t const h_end = h + h_length;
#if !SZ_USE_MISALIGNED_LOADS
// Process the misaligned head, to avoid UB on unaligned 64-bit loads.
for (; ((sz_size_t)h & 7ull) && h + 2 <= h_end; ++h)
if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h;
#endif
sz_u64_vec_t h_even_vec, h_odd_vec, n_vec, matches_even_vec, matches_odd_vec;
n_vec.u64 = 0;
n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1];
n_vec.u64 *= 0x0001000100010001ull; // broadcast
// This code simulates hyper-scalar execution, analyzing 8 offsets at a time.
// Every iteration reads 9 bytes: a 64-bit word at `h` and one more byte at `h + 8`.
for (; h + 9 <= h_end; h += 8) {
h_even_vec.u64 = *(sz_u64_t *)h;
// The "odd" view is the same window shifted by one byte, completed with the ninth byte.
h_odd_vec.u64 = (h_even_vec.u64 >> 8) | ((sz_u64_t)h[8] << 56);
matches_even_vec = _sz_u64_each_2byte_equal(h_even_vec, n_vec);
matches_odd_vec = _sz_u64_each_2byte_equal(h_odd_vec, n_vec);
// Move the "even" indicators one byte down, so they don't collide with the "odd" ones.
matches_even_vec.u64 >>= 8;
if (matches_even_vec.u64 + matches_odd_vec.u64) {
sz_u64_t match_indicators = matches_even_vec.u64 | matches_odd_vec.u64;
return h + sz_u64_ctz(match_indicators) / 8;
}
}
// Handle the tail byte-by-byte.
for (; h + 2 <= h_end; ++h)
if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h;
return SZ_NULL_CHAR;
}
/**
 *  @brief  4Byte-level equality comparison between two 64-bit integers.
 *  @return 64-bit integer, where every top bit in each 4byte signifies a match.
 */
SZ_INTERNAL sz_u64_vec_t _sz_u64_each_4byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) {
    sz_u64_vec_t same_bits, low_bits, top_bits;
    // Every bit that is identical in both inputs becomes a one.
    same_bits.u64 = ~(a.u64 ^ b.u64);
    // A 4byte matches only if all 32 of its bits are set. Adding one to the bottom 31 bits
    // carries into the top position only if all 31 are ones; the top bit is checked separately.
    low_bits.u64 = same_bits.u64 & 0x7FFFFFFF7FFFFFFFull;
    top_bits.u64 = same_bits.u64 & 0x8000000080000000ull;
    same_bits.u64 = (low_bits.u64 + 0x0000000100000001ull) & top_bits.u64;
    return same_bits;
}
/**
 *  @brief  Find the first occurrence of a @b four-character needle in an arbitrary length haystack.
 *          This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time.
 *
 *  @param h        Haystack to scan, at least 4 bytes long.
 *  @param h_length Number of bytes in the haystack.
 *  @param n        Needle; only its first 4 bytes are used.
 *  @param n_length Unused.
 *  @return         Pointer to the first match, or `SZ_NULL_CHAR` if there is none.
 */
SZ_INTERNAL sz_cptr_t _sz_find_4byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
// This is an internal method, and the haystack is guaranteed to be at least 4 bytes long.
sz_assert(h_length >= 4 && "The haystack is too short.");
sz_unused(n_length);
sz_cptr_t const h_end = h + h_length;
#if !SZ_USE_MISALIGNED_LOADS
// Process the misaligned head, to avoid UB on unaligned 64-bit loads.
for (; ((sz_size_t)h & 7ull) && h + 4 <= h_end; ++h)
if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h;
#endif
sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, n_vec, matches0_vec, matches1_vec, matches2_vec, matches3_vec;
n_vec.u64 = 0;
n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2], n_vec.u8s[3] = n[3];
n_vec.u64 *= 0x0000000100000001ull; // broadcast
// This code simulates hyper-scalar execution, analyzing 8 offsets at a time using four 64-bit words.
// We load the subsequent four-byte word as well, taking its first bytes. Think of it as a glorified prefetch :)
sz_u64_t h_page_current, h_page_next;
for (; h + sizeof(sz_u64_t) + sizeof(sz_u32_t) <= h_end; h += sizeof(sz_u64_t)) {
h_page_current = *(sz_u64_t *)h;
h_page_next = *(sz_u32_t *)(h + 8);
// Four views of the same window, shifted by 0, 1, 2, and 3 bytes.
h0_vec.u64 = (h_page_current);
h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56);
h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48);
h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40);
matches0_vec = _sz_u64_each_4byte_equal(h0_vec, n_vec);
matches1_vec = _sz_u64_each_4byte_equal(h1_vec, n_vec);
matches2_vec = _sz_u64_each_4byte_equal(h2_vec, n_vec);
matches3_vec = _sz_u64_each_4byte_equal(h3_vec, n_vec);
if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64) {
// Shift the indicators of each view, so that every offset lands in its own byte.
matches0_vec.u64 >>= 24;
matches1_vec.u64 >>= 16;
matches2_vec.u64 >>= 8;
sz_u64_t match_indicators = matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64;
return h + sz_u64_ctz(match_indicators) / 8;
}
}
// Handle the tail byte-by-byte.
for (; h + 4 <= h_end; ++h)
if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h;
return SZ_NULL_CHAR;
}
/**
 *  @brief  3Byte-level equality comparison between two 64-bit integers.
 *          Only the two 3byte lanes in the lowest 48 bits are reported.
 *  @return 64-bit integer, where every top bit in each 3byte signifies a match.
 */
SZ_INTERNAL sz_u64_vec_t _sz_u64_each_3byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) {
    sz_u64_vec_t same_bits, low_bits, top_bits;
    // Every bit that is identical in both inputs becomes a one.
    same_bits.u64 = ~(a.u64 ^ b.u64);
    // A 3byte matches only if all 24 of its bits are set. Adding one to the bottom 23 bits
    // carries into the top position only if all 23 are ones; the top bit is checked separately.
    low_bits.u64 = same_bits.u64 & 0xFFFF7FFFFF7FFFFFull;
    top_bits.u64 = same_bits.u64 & 0x0000800000800000ull;
    same_bits.u64 = (low_bits.u64 + 0x0000000001000001ull) & top_bits.u64;
    return same_bits;
}
/**
 *  @brief  Find the first occurrence of a @b three-character needle in an arbitrary length haystack.
 *          This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time.
 *
 *  @param h        Haystack to scan, at least 3 bytes long.
 *  @param h_length Number of bytes in the haystack.
 *  @param n        Needle; only its first 3 bytes are used.
 *  @param n_length Unused.
 *  @return         Pointer to the first match, or `SZ_NULL_CHAR` if there is none.
 */
SZ_INTERNAL sz_cptr_t _sz_find_3byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
// This is an internal method, and the haystack is guaranteed to be at least 3 bytes long.
sz_assert(h_length >= 3 && "The haystack is too short.");
sz_unused(n_length);
sz_cptr_t const h_end = h + h_length;
#if !SZ_USE_MISALIGNED_LOADS
// Process the misaligned head, to avoid UB on unaligned 64-bit loads.
for (; ((sz_size_t)h & 7ull) && h + 3 <= h_end; ++h)
if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h;
#endif
// Every iteration reads 10 bytes: a 64-bit word at `h` and a 16-bit word at `h + 8`.
sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, h4_vec;
sz_u64_vec_t matches0_vec, matches1_vec, matches2_vec, matches3_vec, matches4_vec;
sz_u64_vec_t n_vec;
n_vec.u64 = 0;
n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2];
n_vec.u64 *= 0x0000000001000001ull; // broadcast
// This code simulates hyper-scalar execution, analyzing 8 offsets at a time using five shifted 64-bit views.
// We load the subsequent two-byte word as well.
sz_u64_t h_page_current, h_page_next;
for (; h + sizeof(sz_u64_t) + sizeof(sz_u16_t) <= h_end; h += sizeof(sz_u64_t)) {
h_page_current = *(sz_u64_t *)h;
h_page_next = *(sz_u16_t *)(h + 8);
// Five views of the same window, shifted by 0, 1, 2, 3, and 4 bytes.
h0_vec.u64 = (h_page_current);
h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56);
h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48);
h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40);
h4_vec.u64 = (h_page_current >> 32) | (h_page_next << 32);
matches0_vec = _sz_u64_each_3byte_equal(h0_vec, n_vec);
matches1_vec = _sz_u64_each_3byte_equal(h1_vec, n_vec);
matches2_vec = _sz_u64_each_3byte_equal(h2_vec, n_vec);
matches3_vec = _sz_u64_each_3byte_equal(h3_vec, n_vec);
matches4_vec = _sz_u64_each_3byte_equal(h4_vec, n_vec);
if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64) {
// Shift the indicators of each view, so that every offset lands in its own byte.
matches0_vec.u64 >>= 16;
matches1_vec.u64 >>= 8;
matches3_vec.u64 <<= 8;
matches4_vec.u64 <<= 16;
sz_u64_t match_indicators =
matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64;
return h + sz_u64_ctz(match_indicators) / 8;
}
}
// Handle the tail byte-by-byte.
for (; h + 3 <= h_end; ++h)
if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h;
return SZ_NULL_CHAR;
}
/**
 *  @brief  Boyer-Moore-Horspool algorithm for exact matching of patterns up to @b 256-bytes long.
 *          Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern.
 *
 *  @param h_chars, h_length    Haystack and its length.
 *  @param n_chars, n_length    Needle and its length, at most 256 bytes. The dispatchers in this file
 *                              only route needles longer than one byte here.
 *  @return Pointer to the first match, or `SZ_NULL_CHAR` if there is none.
 */
SZ_INTERNAL sz_cptr_t _sz_find_horspool_upto_256bytes_serial(sz_cptr_t h_chars, sz_size_t h_length, //
                                                             sz_cptr_t n_chars, sz_size_t n_length) {
    sz_assert(n_length <= 256 && "The pattern is too long.");
    // A haystack shorter than the needle can't contain it, and `h_length - n_length` below would wrap around.
    if (h_length < n_length) return SZ_NULL_CHAR;

    // Several popular string matching algorithms are using a bad-character shift table.
    // Boyer Moore: https://www-igm.univ-mlv.fr/~lecroq/string/node14.html
    // Quick Search: https://www-igm.univ-mlv.fr/~lecroq/string/node19.html
    // Smith: https://www-igm.univ-mlv.fr/~lecroq/string/node21.html
    // The 256 one-byte entries are aliased by 32 eight-byte words for a faster fill.
    union {
        sz_u8_t jumps[256];
        sz_u64_vec_t vecs[32];
    } bad_shift_table;

    // Let's initialize the table using SWAR to the total length of the string.
    sz_u8_t const *h = (sz_u8_t const *)h_chars;
    sz_u8_t const *n = (sz_u8_t const *)n_chars;
    {
        sz_size_t const vecs_count = sizeof(bad_shift_table.jumps) / sizeof(bad_shift_table.vecs[0]);
        sz_u64_vec_t n_length_vec;
        n_length_vec.u64 = ((sz_u8_t)(n_length - 1)) * 0x0101010101010101ull; // broadcast
        for (sz_size_t i = 0; i != vecs_count; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64;
        // For every needle byte, except the last one, remember its distance to the end of the needle.
        for (sz_size_t i = 0; i + 1 < n_length; ++i) bad_shift_table.jumps[n[i]] = (sz_u8_t)(n_length - i - 1);
    }

    // Another common heuristic is to match a few characters from different parts of a string.
    // Raita suggests to use the first two, the last, and the middle character of the pattern.
    sz_u32_vec_t h_vec, n_vec;

    // Pick the parts of the needle that are worth comparing.
    sz_size_t offset_first, offset_mid, offset_last;
    _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last);

    // Broadcast those characters into an unsigned integer.
    n_vec.u8s[0] = n[offset_first];
    n_vec.u8s[1] = n[offset_first + 1];
    n_vec.u8s[2] = n[offset_mid];
    n_vec.u8s[3] = n[offset_last];

    // Scan through the whole haystack, skipping the last `n_length - 1` bytes.
    for (sz_size_t i = 0; i <= h_length - n_length;) {
        h_vec.u8s[0] = h[i + offset_first];
        h_vec.u8s[1] = h[i + offset_first + 1];
        h_vec.u8s[2] = h[i + offset_mid];
        h_vec.u8s[3] = h[i + offset_last];
        if (h_vec.u32 == n_vec.u32 && sz_equal((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i;
        // Jump by the shift associated with the haystack byte aligned with the end of the needle.
        i += bad_shift_table.jumps[h[i + n_length - 1]];
    }
    return SZ_NULL_CHAR;
}
/**
 *  @brief  Boyer-Moore-Horspool algorithm for @b reverse-order exact matching of patterns up to @b 256-bytes long.
 *          Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern.
 *
 *  @param h_chars, h_length    Haystack and its length.
 *  @param n_chars, n_length    Needle and its length, at most 256 bytes. The dispatchers in this file
 *                              only route needles longer than one byte here.
 *  @return Pointer to the last match, or `SZ_NULL_CHAR` if there is none.
 */
SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_upto_256bytes_serial(sz_cptr_t h_chars, sz_size_t h_length, //
                                                              sz_cptr_t n_chars, sz_size_t n_length) {
    sz_assert(n_length <= 256 && "The pattern is too long.");
    // A haystack shorter than the needle can't contain it, and `h_length - n_length` below would wrap around.
    if (h_length < n_length) return SZ_NULL_CHAR;

    // The 256 one-byte entries are aliased by 32 eight-byte words for a faster fill.
    union {
        sz_u8_t jumps[256];
        sz_u64_vec_t vecs[32];
    } bad_shift_table;

    // Let's initialize the table using SWAR to the total length of the string.
    sz_u8_t const *h = (sz_u8_t const *)h_chars;
    sz_u8_t const *n = (sz_u8_t const *)n_chars;
    {
        sz_size_t const vecs_count = sizeof(bad_shift_table.jumps) / sizeof(bad_shift_table.vecs[0]);
        sz_u64_vec_t n_length_vec;
        n_length_vec.u64 = ((sz_u8_t)(n_length - 1)) * 0x0101010101010101ull; // broadcast
        for (sz_size_t i = 0; i != vecs_count; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64;
        // For every needle byte, except the first one, remember its distance to the start of the needle.
        for (sz_size_t i = 0; i + 1 < n_length; ++i)
            bad_shift_table.jumps[n[n_length - i - 1]] = (sz_u8_t)(n_length - i - 1);
    }

    // Another common heuristic is to match a few characters from different parts of a string.
    // Raita suggests to use the first two, the last, and the middle character of the pattern.
    sz_u32_vec_t h_vec, n_vec;

    // Pick the parts of the needle that are worth comparing.
    sz_size_t offset_first, offset_mid, offset_last;
    _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last);

    // Broadcast those characters into an unsigned integer.
    n_vec.u8s[0] = n[offset_first];
    n_vec.u8s[1] = n[offset_first + 1];
    n_vec.u8s[2] = n[offset_mid];
    n_vec.u8s[3] = n[offset_last];

    // Scan through the whole haystack, skipping the first `n_length - 1` bytes.
    // Here `j` is the distance of the candidate from the last possible position, and `i` is its offset.
    for (sz_size_t j = 0; j <= h_length - n_length;) {
        sz_size_t i = h_length - n_length - j;
        h_vec.u8s[0] = h[i + offset_first];
        h_vec.u8s[1] = h[i + offset_first + 1];
        h_vec.u8s[2] = h[i + offset_mid];
        h_vec.u8s[3] = h[i + offset_last];
        if (h_vec.u32 == n_vec.u32 && sz_equal((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i;
        // Jump by the shift associated with the haystack byte aligned with the start of the needle.
        j += bad_shift_table.jumps[h[i]];
    }
    return SZ_NULL_CHAR;
}
/**
 *  @brief  Exact substring search helper function, that finds the first occurrence of a prefix of the needle
 *          using a given search function, and then verifies the remaining part of the needle.
 *
 *  @param h, h_length      Haystack and its length.
 *  @param n, n_length      Needle and its length.
 *  @param find_prefix      Search function used to locate the first `prefix_length` bytes of the needle.
 *  @param prefix_length    Number of leading needle bytes handled by `find_prefix`.
 *  @return Pointer to the first full match, or `SZ_NULL_CHAR` if there is none.
 */
SZ_INTERNAL sz_cptr_t _sz_find_with_prefix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length,
                                           sz_find_t find_prefix, sz_size_t prefix_length) {
    sz_size_t const suffix_length = n_length - prefix_length;
    for (;;) {
        sz_cptr_t const candidate = find_prefix(h, h_length, n, prefix_length);
        if (!candidate) break;

        // Not enough bytes left after the candidate to fit the whole needle.
        sz_size_t const bytes_left = h_length - (candidate - h);
        if (bytes_left < n_length) break;

        // Verify the remaining part of the needle.
        if (sz_equal(candidate + prefix_length, n + prefix_length, suffix_length)) return candidate;

        // Resume the search one byte past the rejected candidate.
        h = candidate + 1;
        h_length = bytes_left - 1;
    }
    return SZ_NULL_CHAR;
}
/**
 *  @brief  Exact reverse-order substring search helper function, that finds the last occurrence of a suffix of the
 *          needle using a given search function, and then verifies the remaining part of the needle.
 *
 *  @param h, h_length      Haystack and its length.
 *  @param n, n_length      Needle and its length.
 *  @param find_suffix      Search function used to locate the last `suffix_length` bytes of the needle.
 *  @param suffix_length    Number of trailing needle bytes handled by `find_suffix`.
 *  @return Pointer to the last full match, or `SZ_NULL_CHAR` if there is none.
 */
SZ_INTERNAL sz_cptr_t _sz_rfind_with_suffix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length,
                                            sz_find_t find_suffix, sz_size_t suffix_length) {

    sz_size_t prefix_length = n_length - suffix_length;
    while (1) {
        // A window shorter than the suffix can't contain it, so don't hand it to the suffix finder.
        if (h_length < suffix_length) return SZ_NULL_CHAR;
        sz_cptr_t found = find_suffix(h, h_length, n + prefix_length, suffix_length);
        if (!found) return SZ_NULL_CHAR;

        // Verify the remaining part of the needle
        sz_size_t remaining = found - h;
        if (remaining < prefix_length) return SZ_NULL_CHAR;
        if (sz_equal(found - prefix_length, n, prefix_length)) return found - prefix_length;

        // Adjust the position: drop only the last byte of the rejected suffix occurrence.
        // Earlier occurrences may overlap the rejected one, so the window must still reach
        // up to `found + suffix_length - 1`; cutting it at `found - 1` would skip them.
        h_length = remaining + suffix_length - 1;
    }

    // Unreachable, but helps silence compiler warnings:
    return SZ_NULL_CHAR;
}
/** @brief Adapter exposing `sz_find_byte_serial` through the four-argument search signature; `n_length` is ignored. */
SZ_INTERNAL sz_cptr_t _sz_find_byte_prefix_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
    sz_cptr_t const match = sz_find_byte_serial(h, h_length, n);
    sz_unused(n_length);
    return match;
}
/** @brief Adapter exposing `sz_rfind_byte_serial` through the four-argument search signature; `n_length` is ignored. */
SZ_INTERNAL sz_cptr_t _sz_rfind_byte_prefix_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
    sz_cptr_t const match = sz_rfind_byte_serial(h, h_length, n);
    sz_unused(n_length);
    return match;
}
/** @brief Locates the first 4 bytes of the needle with the 4-byte SWAR search, then verifies the rest. */
SZ_INTERNAL sz_cptr_t _sz_find_over_4bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
    sz_size_t const prefix_length = 4;
    return _sz_find_with_prefix(h, h_length, n, n_length, _sz_find_4byte_serial, prefix_length);
}
/** @brief Locates the first 256 bytes of the needle with the Horspool search, then verifies the rest. */
SZ_INTERNAL sz_cptr_t _sz_find_horspool_over_256bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n,
                                                             sz_size_t n_length) {
    sz_size_t const prefix_length = 256;
    return _sz_find_with_prefix(h, h_length, n, n_length, _sz_find_horspool_upto_256bytes_serial, prefix_length);
}
/** @brief Locates the last 256 bytes of the needle with the reverse Horspool search, then verifies the rest. */
SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_over_256bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n,
                                                              sz_size_t n_length) {
    sz_size_t const suffix_length = 256;
    return _sz_rfind_with_suffix(h, h_length, n, n_length, _sz_rfind_horspool_upto_256bytes_serial, suffix_length);
}
/**
 *  @brief  Finds the first occurrence of the needle in the haystack, picking a backend by the needle length.
 *
 *  On little-endian targets: 1, 2, 3, and 4 byte needles use dedicated SWAR routines, 5 to 8 byte needles
 *  use the 4-byte prefix search with verification, 9 to 256 byte needles use Horspool, and longer ones
 *  use Horspool over the first 256 bytes with verification. On big-endian targets only the single-byte
 *  and the two Horspool backends are used.
 *
 *  @return Pointer to the first match, or `SZ_NULL_CHAR` if the needle is empty, longer than the haystack,
 *          or simply not present.
 */
SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
    // This almost never fires, but it's better to be safe than sorry.
    if (!n_length || h_length < n_length) return SZ_NULL_CHAR;

#if SZ_DETECT_BIG_ENDIAN
    sz_find_t backends[] = {
        _sz_find_byte_prefix_serial,
        _sz_find_horspool_upto_256bytes_serial,
        _sz_find_horspool_over_256bytes_serial,
    };
    sz_size_t const backend_index = (n_length > 1) + (n_length > 256);
#else
    sz_find_t backends[] = {
        // For very short strings brute-force SWAR makes sense.
        _sz_find_byte_prefix_serial,
        _sz_find_2byte_serial,
        _sz_find_3byte_serial,
        _sz_find_4byte_serial,
        // To avoid constructing the skip-table, let's use the prefixed approach.
        _sz_find_over_4bytes_serial,
        // For longer needles - use skip tables.
        _sz_find_horspool_upto_256bytes_serial,
        _sz_find_horspool_over_256bytes_serial,
    };
    // Every length threshold that is crossed moves us one slot further in the table.
    sz_size_t const backend_index =
        // For very short strings brute-force SWAR makes sense.
        (n_length > 1) + (n_length > 2) + (n_length > 3) +
        // To avoid constructing the skip-table, let's use the prefixed approach.
        (n_length > 4) +
        // For longer needles - use skip tables.
        (n_length > 8) + (n_length > 256);
#endif
    return backends[backend_index](h, h_length, n, n_length);
}
/**
 *  @brief  Serial reverse substring search: picks a kernel purely by needle length.
 *
 *  @param h, h_length  Haystack and its length in bytes.
 *  @param n, n_length  Needle and its length in bytes.
 *  @return Whatever the selected kernel returns, or `SZ_NULL_CHAR` when the needle
 *          is empty or longer than the haystack.
 */
SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
    // Rarely taken, but an empty or over-long needle can never match.
    if (!n_length || h_length < n_length) return SZ_NULL_CHAR;
    // Single-byte needles go to the byte-level reverse search.
    if (n_length == 1) return _sz_rfind_byte_prefix_serial(h, h_length, n, n_length);
    // TODO: implement reverse-order SWAR for 2/3/4 byte variants:
    // TODO: _sz_rfind_2byte_serial,
    // TODO: _sz_rfind_3byte_serial,
    // TODO: _sz_rfind_4byte_serial,
    // and a prefixed `_sz_rfind_over_4bytes_serial` that avoids constructing the skip-table.
    // Until then every needle of 2 bytes or more uses skip tables.
    if (n_length <= 256) return _sz_rfind_horspool_upto_256bytes_serial(h, h_length, n, n_length);
    return _sz_rfind_horspool_over_256bytes_serial(h, h_length, n, n_length);
}
/**
 *  @brief  Unbounded Levenshtein distance of two equal-length byte strings, evaluated along the
 *          anti-diagonals ("skewed diagonals") of the matrix, keeping three diagonals in memory.
 *
 *  @param shorter, shorter_length  First string; its length sizes the scratch buffer.
 *  @param longer, longer_length    Second string; asserted to have the same length as the first.
 *  @param bound                    Asserted to be zero; bounded evaluation is not implemented here.
 *  @param alloc                    Scratch allocator; a default one is initialized when NULL.
 *  @return The edit distance, or `SZ_SIZE_MAX` if the scratch allocation fails.
 *
 *  NOTE(review): the first loop starts at diagonal index 2 and stops on `!= n`, so an input with
 *  `shorter_length == 0` (n == 1) would not terminate promptly - callers must filter empty inputs.
 */
SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_serial( //
sz_cptr_t shorter, sz_size_t shorter_length, //
sz_cptr_t longer, sz_size_t longer_length, //
sz_size_t bound, sz_memory_allocator_t *alloc) {
// Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome.
sz_memory_allocator_t global_alloc;
if (!alloc) {
sz_memory_allocator_init_default(&global_alloc);
alloc = &global_alloc;
}
// TODO: Generalize to remove the following asserts!
sz_assert(!bound && "For bounded search the method should only evaluate one band of the matrix.");
sz_assert(shorter_length == longer_length && "The method hasn't been generalized to different length inputs yet.");
// Keeps both parameters "used" when the asserts above compile to nothing.
sz_unused(longer_length && bound);
// We are going to store 3 diagonals of the matrix.
// The length of the longest (main) diagonal would be `n = (shorter_length + 1)`.
sz_size_t n = shorter_length + 1;
sz_size_t buffer_length = sizeof(sz_size_t) * n * 3;
sz_size_t *distances = (sz_size_t *)alloc->allocate(buffer_length, alloc->handle);
if (!distances) return SZ_SIZE_MAX;
// The single allocation is split into three consecutive slices of `n` slots each.
sz_size_t *previous_distances = distances;
sz_size_t *current_distances = previous_distances + n;
sz_size_t *next_distances = previous_distances + n * 2;
// Initialize the first two diagonals:
previous_distances[0] = 0;
current_distances[0] = current_distances[1] = 1;
// Progress through the upper triangle of the Levenshtein matrix.
sz_size_t next_skew_diagonal_index = 2;
for (; next_skew_diagonal_index != n; ++next_skew_diagonal_index) {
sz_size_t const next_skew_diagonal_length = next_skew_diagonal_index + 1;
// Only the interior cells are computed here; the two border cells are filled in after the loop.
for (sz_size_t i = 0; i + 2 < next_skew_diagonal_length; ++i) {
sz_size_t cost_of_substitution = shorter[next_skew_diagonal_index - i - 2] != longer[i];
sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution;
sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1;
next_distances[i + 1] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution);
}
// Don't forget to populate the first row and the first column of the Levenshtein matrix.
next_distances[0] = next_distances[next_skew_diagonal_length - 1] = next_skew_diagonal_index;
// Perform a circular rotation of those buffers, to reuse the memory.
sz_size_t *temporary = previous_distances;
previous_distances = current_distances;
current_distances = next_distances;
next_distances = temporary;
}
// By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a
// larger diagonal. From now onwards, we will be shrinking. Instead of adding value equal to the skewed diagonal
// index on either side, we will be cropping those values out.
sz_size_t total_diagonals = n + n - 1;
for (; next_skew_diagonal_index != total_diagonals; ++next_skew_diagonal_index) {
sz_size_t const next_skew_diagonal_length = total_diagonals - next_skew_diagonal_index;
for (sz_size_t i = 0; i != next_skew_diagonal_length; ++i) {
sz_size_t cost_of_substitution =
shorter[shorter_length - 1 - i] != longer[next_skew_diagonal_index - n + i];
sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution;
sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1;
next_distances[i] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution);
}
// Perform a circular rotation of those buffers, to reuse the memory, this time, with a shift,
// dropping the first element in the current array.
sz_size_t *temporary = previous_distances;
previous_distances = current_distances + 1;
current_distances = next_distances;
next_distances = temporary;
}
// Cache scalar before `free` call.
// The last diagonal holds a single cell, which after the final rotation is `current_distances[0]`.
sz_size_t result = current_distances[0];
// `distances` still points at the start of the allocation; the rotated pointers may not.
alloc->free(distances, buffer_length, alloc->handle);
return result;
}
/**
 * @brief Describes the length of a UTF8 character / codepoint / rune in bytes.
 *        Every valid enumerator's numeric value equals the number of bytes in the sequence,
 *        while zero marks a leading byte that does not start any recognized sequence.
 */
typedef enum {
sz_utf8_invalid_k = 0, //!< Invalid UTF8 character.
sz_utf8_rune_1byte_k = 1, //!< 1-byte UTF8 character.
sz_utf8_rune_2bytes_k = 2, //!< 2-byte UTF8 character.
sz_utf8_rune_3bytes_k = 3, //!< 3-byte UTF8 character.
sz_utf8_rune_4bytes_k = 4, //!< 4-byte UTF8 character.
} sz_rune_length_t;
/** @brief A single decoded codepoint, stored in a 32-bit unsigned integer. */
typedef sz_u32_t sz_rune_t;
/**
 *  @brief  Decodes one UTF8 codepoint starting at @p utf8 into a 32-bit unsigned integer.
 *
 *  The sequence length is chosen from the leading byte alone. Continuation bytes are read
 *  without validation and without any bounds check, so the caller must guarantee that the
 *  whole sequence is readable.
 *
 *  @param utf8         Pointer to the leading byte.
 *  @param code         Receives the decoded codepoint, or 0 for an unrecognized leading byte.
 *  @param code_length  Receives the sequence length, or `sz_utf8_invalid_k` (0) if unrecognized.
 */
SZ_INTERNAL void _sz_extract_utf8_rune(sz_cptr_t utf8, sz_rune_t *code, sz_rune_length_t *code_length) {
    sz_u8_t const *const bytes = (sz_u8_t const *)utf8;
    sz_u8_t const lead = bytes[0];
    // Start from the "invalid" outcome and overwrite it once a known pattern is matched.
    sz_rune_t rune = 0;
    sz_rune_length_t rune_length = sz_utf8_invalid_k;
    // TODO: This can be made entirely branchless using 32-bit SWAR.
    if (lead < 0x80) {
        // Single-byte rune (0xxxxxxx)
        rune = lead;
        rune_length = sz_utf8_rune_1byte_k;
    }
    else if ((lead & 0xE0) == 0xC0) {
        // Two-byte rune (110xxxxx 10xxxxxx)
        rune = ((lead & 0x1F) << 6) | (bytes[1] & 0x3F);
        rune_length = sz_utf8_rune_2bytes_k;
    }
    else if ((lead & 0xF0) == 0xE0) {
        // Three-byte rune (1110xxxx 10xxxxxx 10xxxxxx)
        rune = ((lead & 0x0F) << 12) | ((bytes[1] & 0x3F) << 6) | (bytes[2] & 0x3F);
        rune_length = sz_utf8_rune_3bytes_k;
    }
    else if ((lead & 0xF8) == 0xF0) {
        // Four-byte rune (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx)
        rune = ((lead & 0x07) << 18) | ((bytes[1] & 0x3F) << 12) | ((bytes[2] & 0x3F) << 6) | (bytes[3] & 0x3F);
        rune_length = sz_utf8_rune_4bytes_k;
    }
    *code = rune;
    *code_length = rune_length;
}
/**
 *  @brief  Exports a UTF8 string into a UTF32 buffer.
 *          ! The decoded values are unspecified if the UTF8 string is corrupted, but the export
 *          ! always terminates and never writes more than @p utf8_length runes.
 *
 *  @param utf8         Input bytes.
 *  @param utf8_length  Number of input bytes.
 *  @param utf32        Output buffer with room for at least @p utf8_length runes.
 *  @return The length in the number of codepoints.
 *
 *  NOTE(review): `_sz_extract_utf8_rune` does no bounds checking, so a multi-byte sequence
 *  truncated at the very end of the input may still be read up to 3 bytes past it.
 */
SZ_INTERNAL sz_size_t _sz_export_utf8_to_utf32(sz_cptr_t utf8, sz_size_t utf8_length, sz_rune_t *utf32) {
    sz_cptr_t const end = utf8 + utf8_length;
    sz_size_t count = 0;
    // `<` rather than `!=`: a truncated trailing sequence can move the cursor past `end`.
    while (utf8 < end) {
        sz_rune_length_t rune_length;
        _sz_extract_utf8_rune(utf8, utf32 + count, &rune_length);
        // An unrecognized leading byte reports a zero length; skip it as a single byte,
        // otherwise the cursor would never advance while the output keeps growing.
        utf8 += rune_length ? (sz_size_t)rune_length : 1;
        ++count;
    }
    return count;
}
/**
 *  @brief  Compute the Levenshtein distance between two strings using the Wagner-Fisher algorithm.
 *          Stores only 2 rows of the Levenshtein matrix, uses `sz_size_t` for the distance values,
 *          and, when non-ASCII input is present, upcasts UTF8 variable-length codepoints to
 *          fixed-width 32-bit runes for faster addressing.
 *
 *  ! In the worst case for 2 strings of length 100, that contain just one 16-bit codepoint this will result in extra:
 *      + 2 rows * 100 slots * 8 bytes/slot = 1600 bytes of memory for the two rows of the Levenshtein matrix rows.
 *      + 100 codepoints * 2 strings * 4 bytes/codepoint = 800 bytes of memory for the UTF32 buffer.
 *      = 2400 bytes of memory or @b 12x memory amplification!
 *
 *  @param longer, longer_length    First string and its length in bytes.
 *  @param shorter, shorter_length  Second string and its length in bytes; sizes the two matrix rows.
 *  @param bound            Zero for an exact result; otherwise the result is capped at this value.
 *  @param can_be_unicode   When `sz_true_k` and either input is non-ASCII, inputs are compared rune-by-rune.
 *  @param alloc            Scratch allocator; a default one is initialized when NULL.
 *  @return The edit distance (at most @p bound when bounded), or `SZ_SIZE_MAX` if the allocation fails.
 */
SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( //
    sz_cptr_t longer, sz_size_t longer_length,                //
    sz_cptr_t shorter, sz_size_t shorter_length,              //
    sz_size_t bound, sz_bool_t can_be_unicode, sz_memory_allocator_t *alloc) {
    // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome.
    sz_memory_allocator_t global_alloc;
    if (!alloc) {
        sz_memory_allocator_init_default(&global_alloc);
        alloc = &global_alloc;
    }
    // A good idea may be to dispatch different kernels for different string lengths.
    // Like using `uint8_t` counters for strings under 255 characters long.
    // Good in theory, this results in frequent upcasts and downcasts in serial code.
    // On strings over 20 bytes, using `uint8` over `uint64` on 64-bit x86 CPU doubles the execution time.
    // So one must be very cautious with such optimizations.
    typedef sz_size_t _distance_t;
    // Compute the number of columns in our Levenshtein matrix.
    // It is derived from the byte length, so it stays an upper bound once runes are decoded.
    sz_size_t const n = shorter_length + 1;
    // If a buffering memory-allocator is provided, this operation is practically free,
    // and cheaper than allocating even 512 bytes (for small distance matrices) on stack.
    sz_size_t buffer_length = sizeof(_distance_t) * (n * 2);
    // If the strings contain Unicode characters, let's estimate the max character width,
    // and use it to allocate a larger buffer to decode UTF8.
    if ((can_be_unicode == sz_true_k) &&
        (sz_isascii(longer, longer_length) == sz_false_k || sz_isascii(shorter, shorter_length) == sz_false_k)) {
        buffer_length += (shorter_length + longer_length) * sizeof(sz_rune_t);
    }
    else { can_be_unicode = sz_false_k; }
    // If the allocation fails, return the maximum distance.
    sz_ptr_t const buffer = (sz_ptr_t)alloc->allocate(buffer_length, alloc->handle);
    if (!buffer) return SZ_SIZE_MAX;
    // Let's export the UTF8 sequence into the newly allocated buffer at the end.
    if (can_be_unicode == sz_true_k) {
        sz_rune_t *const longer_utf32 = (sz_rune_t *)(buffer + sizeof(_distance_t) * (n * 2));
        sz_rune_t *const shorter_utf32 = longer_utf32 + longer_length;
        // Export the UTF8 sequences into the newly allocated buffer.
        // From here on both lengths are counted in runes, not bytes.
        longer_length = _sz_export_utf8_to_utf32(longer, longer_length, longer_utf32);
        shorter_length = _sz_export_utf8_to_utf32(shorter, shorter_length, shorter_utf32);
        longer = (sz_cptr_t)longer_utf32;
        shorter = (sz_cptr_t)shorter_utf32;
    }
    // Let's parameterize the core logic for different character types and distance types.
#define _wagner_fisher_unbounded(_distance_t, _char_t)                                                                 \
    /* Now let's cast our pointer to avoid it in subsequent sections. */                                               \
    _char_t const *const longer_chars = (_char_t const *)longer;                                                       \
    _char_t const *const shorter_chars = (_char_t const *)shorter;                                                     \
    _distance_t *previous_distances = (_distance_t *)buffer;                                                           \
    _distance_t *current_distances = previous_distances + n;                                                           \
    /* Initialize the first row of the Levenshtein matrix with `iota`-style arithmetic progression. */                 \
    for (_distance_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = idx_shorter;  \
    /* The main loop of the algorithm with quadratic complexity. */                                                    \
    for (_distance_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) {                                      \
        _char_t const longer_char = longer_chars[idx_longer];                                                          \
        /* Using pure pointer arithmetic is faster than iterating with an index. */                                    \
        _char_t const *shorter_ptr = shorter_chars;                                                                    \
        _distance_t const *previous_ptr = previous_distances;                                                          \
        _distance_t *current_ptr = current_distances;                                                                  \
        _distance_t *const current_end = current_ptr + shorter_length;                                                 \
        current_ptr[0] = idx_longer + 1;                                                                               \
        for (; current_ptr != current_end; ++previous_ptr, ++current_ptr, ++shorter_ptr) {                             \
            _distance_t cost_substitution = previous_ptr[0] + (_distance_t)(longer_char != shorter_ptr[0]);            \
            /* We can avoid `+1` for costs here, shifting it to post-minimum computation, */                           \
            /* saving one increment operation. */                                                                      \
            _distance_t cost_deletion = previous_ptr[1];                                                               \
            _distance_t cost_insertion = current_ptr[0];                                                               \
            /* ? It might be a good idea to enforce branchless execution here. */                                      \
            /* ? The caveat being that the benchmarks on longer sequences backfire and more research is needed. */     \
            current_ptr[1] = sz_min_of_two(cost_substitution, sz_min_of_two(cost_deletion, cost_insertion) + 1);       \
        }                                                                                                              \
        /* Swap `previous_distances` and `current_distances` pointers. */                                              \
        _distance_t *temporary = previous_distances;                                                                   \
        previous_distances = current_distances;                                                                        \
        current_distances = temporary;                                                                                 \
    }                                                                                                                  \
    /* Cache scalar before `free` call. */                                                                             \
    sz_size_t result = previous_distances[shorter_length];                                                             \
    alloc->free(buffer, buffer_length, alloc->handle);                                                                 \
    return result;
    // Let's define a separate variant for bounded distance computation.
    // Practically the same as unbounded, but also collecting the running minimum within each row for early exit.
#define _wagner_fisher_bounded(_distance_t, _char_t)                                                                   \
    _char_t const *const longer_chars = (_char_t const *)longer;                                                       \
    _char_t const *const shorter_chars = (_char_t const *)shorter;                                                     \
    _distance_t *previous_distances = (_distance_t *)buffer;                                                           \
    _distance_t *current_distances = previous_distances + n;                                                           \
    for (_distance_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) previous_distances[idx_shorter] = idx_shorter;  \
    for (_distance_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) {                                      \
        _char_t const longer_char = longer_chars[idx_longer];                                                          \
        _char_t const *shorter_ptr = shorter_chars;                                                                    \
        _distance_t const *previous_ptr = previous_distances;                                                          \
        _distance_t *current_ptr = current_distances;                                                                  \
        _distance_t *const current_end = current_ptr + shorter_length;                                                 \
        current_ptr[0] = idx_longer + 1;                                                                               \
        /* Seed the running row minimum with the first column: it belongs to the row as well. */                       \
        /* Seeding with anything below `bound` would make the early exit further down unreachable. */                  \
        _distance_t min_distance = current_ptr[0];                                                                     \
        for (; current_ptr != current_end; ++previous_ptr, ++current_ptr, ++shorter_ptr) {                             \
            _distance_t cost_substitution = previous_ptr[0] + (_distance_t)(longer_char != shorter_ptr[0]);            \
            _distance_t cost_deletion = previous_ptr[1];                                                               \
            _distance_t cost_insertion = current_ptr[0];                                                               \
            current_ptr[1] = sz_min_of_two(cost_substitution, sz_min_of_two(cost_deletion, cost_insertion) + 1);       \
            /* Keep track of the minimum distance seen so far in this row */                                           \
            min_distance = sz_min_of_two(current_ptr[1], min_distance);                                                \
        }                                                                                                              \
        /* Every later cell derives from some cell of this row at no discount, so once the whole */                    \
        /* row has reached the bound, the final distance cannot drop below it - return early. */                       \
        if (min_distance >= bound) {                                                                                   \
            alloc->free(buffer, buffer_length, alloc->handle);                                                         \
            return bound;                                                                                              \
        }                                                                                                              \
        _distance_t *temporary = previous_distances;                                                                   \
        previous_distances = current_distances;                                                                        \
        current_distances = temporary;                                                                                 \
    }                                                                                                                  \
    sz_size_t result = previous_distances[shorter_length];                                                             \
    alloc->free(buffer, buffer_length, alloc->handle);                                                                 \
    return sz_min_of_two(result, bound);
    // Dispatch the actual computation. Every expansion below ends in a `return`.
    if (!bound) {
        if (can_be_unicode == sz_true_k) { _wagner_fisher_unbounded(sz_size_t, sz_rune_t); }
        else { _wagner_fisher_unbounded(sz_size_t, sz_u8_t); }
    }
    else {
        if (can_be_unicode == sz_true_k) { _wagner_fisher_bounded(sz_size_t, sz_rune_t); }
        else { _wagner_fisher_bounded(sz_size_t, sz_u8_t); }
    }
}
/**
 *  @brief  Serial byte-level Levenshtein distance between two strings.
 *
 *  Orders the inputs by length, strips the common prefix and suffix, resolves the trivial
 *  cases, and then dispatches to the skewed-diagonals kernel (equal lengths, unbounded) or
 *  the Wagner-Fisher kernel (everything else).
 *
 *  @param longer, longer_length    First string; swapped with the second if it is the shorter one.
 *  @param shorter, shorter_length  Second string.
 *  @param bound    Zero for an exact result; otherwise the result is capped at this value.
 *  @param alloc    Scratch allocator forwarded to the kernels; may be NULL.
 *  @return The edit distance (at most @p bound when bounded). The kernels return `SZ_SIZE_MAX`
 *          if their scratch allocation fails.
 */
SZ_PUBLIC sz_size_t sz_edit_distance_serial(     //
    sz_cptr_t longer, sz_size_t longer_length,   //
    sz_cptr_t shorter, sz_size_t shorter_length, //
    sz_size_t bound, sz_memory_allocator_t *alloc) {
    // Let's make sure that we use the amount proportional to the
    // number of elements in the shorter string, not the larger.
    // Plain temporaries are used for the swap: punning `sz_size_t` through `void **`
    // breaks strict aliasing and assumes sizes and pointers have the same width.
    if (shorter_length > longer_length) {
        sz_size_t const swapped_length = longer_length;
        sz_cptr_t const swapped_text = longer;
        longer_length = shorter_length;
        shorter_length = swapped_length;
        longer = shorter;
        shorter = swapped_text;
    }
    // Skip the matching prefixes and suffixes, they won't affect the distance.
    for (sz_cptr_t a_end = longer + longer_length, b_end = shorter + shorter_length;
         longer != a_end && shorter != b_end && *longer == *shorter;
         ++longer, ++shorter, --longer_length, --shorter_length);
    for (; longer_length && shorter_length && longer[longer_length - 1] == shorter[shorter_length - 1];
         --longer_length, --shorter_length);
    // Bounded computations may exit early.
    if (bound) {
        // If one of the strings is empty - the edit distance is equal to the length of the other one.
        if (longer_length == 0) return sz_min_of_two(shorter_length, bound);
        if (shorter_length == 0) return sz_min_of_two(longer_length, bound);
        // If the difference in length is beyond the `bound`, there is no need to check at all.
        if (longer_length - shorter_length > bound) return bound;
    }
    // Nothing left of the shorter string: only insertions remain, one per leftover byte.
    // This also covers identical inputs, where both remaining lengths are zero.
    if (shorter_length == 0) return longer_length;
    if (shorter_length == longer_length && !bound)
        return _sz_edit_distance_skewed_diagonals_serial(longer, longer_length, shorter, shorter_length, bound, alloc);
    return _sz_edit_distance_wagner_fisher_serial(longer, longer_length, shorter, shorter_length, bound, sz_false_k,
                                                  alloc);
}
/**
 *  @brief  Serial global alignment score with a linear gap penalty and a 256x256 substitution
 *          matrix, keeping only two rows of the score matrix in memory.
 *
 *  @param longer, longer_length    First string; swapped with the second if it is the shorter one.
 *  @param shorter, shorter_length  Second string.
 *  @param subs     Substitution costs, indexed as `subs[first_byte * 256 + second_byte]`.
 *  @param gap      Cost added for every inserted or deleted byte.
 *  @param alloc    Scratch allocator; a default one is initialized when NULL.
 *  @return The maximal score, or the largest `sz_ssize_t` value if the scratch allocation fails.
 *
 *  NOTE(review): when the inputs are swapped the row/column roles in the @p subs lookup swap too,
 *  which matters for asymmetric matrices - confirm this is intended.
 */
SZ_PUBLIC sz_ssize_t sz_alignment_score_serial(       //
    sz_cptr_t longer, sz_size_t longer_length,        //
    sz_cptr_t shorter, sz_size_t shorter_length,      //
    sz_error_cost_t const *subs, sz_error_cost_t gap, //
    sz_memory_allocator_t *alloc) {
    // If one of the strings is empty - the score is the gap cost times the length of the other one.
    if (longer_length == 0) return (sz_ssize_t)shorter_length * gap;
    if (shorter_length == 0) return (sz_ssize_t)longer_length * gap;
    // Let's make sure that we use the amount proportional to the
    // number of elements in the shorter string, not the larger.
    // Plain temporaries are used for the swap: punning `sz_size_t` through `void **`
    // breaks strict aliasing and assumes sizes and pointers have the same width.
    if (shorter_length > longer_length) {
        sz_size_t const swapped_length = longer_length;
        sz_cptr_t const swapped_text = longer;
        longer_length = shorter_length;
        shorter_length = swapped_length;
        longer = shorter;
        shorter = swapped_text;
    }
    // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome.
    sz_memory_allocator_t global_alloc;
    if (!alloc) {
        sz_memory_allocator_init_default(&global_alloc);
        alloc = &global_alloc;
    }
    sz_size_t n = shorter_length + 1;
    sz_size_t buffer_length = sizeof(sz_ssize_t) * n * 2;
    sz_ssize_t *distances = (sz_ssize_t *)alloc->allocate(buffer_length, alloc->handle);
    // Never touch a failed allocation: report the largest signed value as the failure sentinel.
    // NOTE(review): presumably equal to the library's `SZ_SSIZE_MAX` - confirm against the public docs.
    if (!distances) return (sz_ssize_t)(SZ_SIZE_MAX >> 1);
    sz_ssize_t *previous_distances = distances;
    sz_ssize_t *current_distances = previous_distances + n;
    // The first row corresponds to aligning a prefix of `shorter` against nothing: gaps only.
    for (sz_size_t idx_shorter = 0; idx_shorter != n; ++idx_shorter)
        previous_distances[idx_shorter] = (sz_ssize_t)idx_shorter * gap;
    sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter;
    sz_u8_t const *longer_unsigned = (sz_u8_t const *)longer;
    for (sz_size_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) {
        current_distances[0] = ((sz_ssize_t)idx_longer + 1) * gap;
        // Select the substitution-matrix row for the current byte of the longer string.
        sz_error_cost_t const *a_subs = subs + longer_unsigned[idx_longer] * 256ul;
        for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) {
            sz_ssize_t cost_deletion = previous_distances[idx_shorter + 1] + gap;
            sz_ssize_t cost_insertion = current_distances[idx_shorter] + gap;
            sz_ssize_t cost_substitution = previous_distances[idx_shorter] + a_subs[shorter_unsigned[idx_shorter]];
            current_distances[idx_shorter + 1] = sz_max_of_three(cost_deletion, cost_insertion, cost_substitution);
        }
        // Swap previous_distances and current_distances pointers
        sz_ssize_t *const finished_row = current_distances;
        current_distances = previous_distances;
        previous_distances = finished_row;
    }
    // Cache scalar before `free` call.
    sz_ssize_t result = previous_distances[shorter_length];
    alloc->free(distances, buffer_length, alloc->handle);
    return result;
}
/**
 *  @brief  Serial byte-level Hamming distance. Bytes beyond the shorter input all count as mismatches.
 *
 *  @param a, a_length  First string and its length in bytes.
 *  @param b, b_length  Second string and its length in bytes.
 *  @param bound        Zero for an exact result; otherwise counting stops once the bound is
 *                      reached and the result is capped at it.
 *  @return Number of mismatching positions, at most @p bound when it is non-zero.
 */
SZ_PUBLIC sz_size_t sz_hamming_distance_serial( //
    sz_cptr_t a, sz_size_t a_length,            //
    sz_cptr_t b, sz_size_t b_length,            //
    sz_size_t bound) {
    sz_size_t const shared_length = sz_min_of_two(a_length, b_length);
    sz_size_t const longest_length = sz_max_of_two(a_length, b_length);
    sz_cptr_t const a_end = a + shared_length;
    // Without a bound, the longest length is the natural ceiling.
    if (bound == 0) bound = longest_length;
    // The unmatched tail of the longer input is charged up-front.
    sz_size_t mismatches = longest_length - shared_length;
#if SZ_USE_MISALIGNED_LOADS && !SZ_DETECT_BIG_ENDIAN
    // Walk through both strings using SWAR, 8 bytes at a time, counting the differing bytes.
    if (shared_length >= SZ_SWAR_THRESHOLD) {
        while (a + 8 <= a_end && mismatches < bound) {
            sz_u64_vec_t a_vec, b_vec, match_vec;
            a_vec.u64 = sz_u64_load(a).u64;
            b_vec.u64 = sz_u64_load(b).u64;
            match_vec = _sz_u64_each_byte_equal(a_vec, b_vec);
            // Count the bytes whose top bit is clear in the match mask.
            mismatches += sz_u64_popcount((~match_vec.u64) & 0x8080808080808080ull);
            a += 8, b += 8;
        }
    }
#endif
    // Finish (or do everything) one byte at a time.
    while (a != a_end && mismatches < bound) {
        mismatches += (*a != *b);
        ++a, ++b;
    }
    return sz_min_of_two(mismatches, bound);
}
/**
 *  @brief  Serial rune-level Hamming distance of two UTF8 strings. Runes beyond the end of the
 *          input with fewer runes all count as mismatches.
 *
 *  @param a, a_length  First string and its length in bytes.
 *  @param b, b_length  Second string and its length in bytes.
 *  @param bound        Zero for an exact result; otherwise counting stops at this value.
 *  @return Number of mismatching runes, at most @p bound when it is non-zero.
 *
 *  An unrecognized leading byte decodes to rune 0 with a zero length. It is skipped as a single
 *  byte, so malformed input cannot stall the scan.
 *  NOTE(review): `_sz_extract_utf8_rune` does no bounds checking, so a multi-byte sequence
 *  truncated at the very end of an input may still be read up to 3 bytes past it.
 */
SZ_PUBLIC sz_size_t sz_hamming_distance_utf8_serial( //
    sz_cptr_t a, sz_size_t a_length,                 //
    sz_cptr_t b, sz_size_t b_length,                 //
    sz_size_t bound) {
    sz_cptr_t const a_end = a + a_length;
    sz_cptr_t const b_end = b + b_length;
    // An absent bound becomes a limit that the counter can never reach in practice,
    // which lets the bounded and unbounded cases share the same loops.
    sz_size_t const limit = bound ? bound : SZ_SIZE_MAX;
    sz_size_t distance = 0;
    sz_rune_t a_rune, b_rune;
    sz_rune_length_t a_rune_length, b_rune_length;
    // Compare rune against rune while both inputs last.
    while (a < a_end && b < b_end && distance < limit) {
        _sz_extract_utf8_rune(a, &a_rune, &a_rune_length);
        _sz_extract_utf8_rune(b, &b_rune, &b_rune_length);
        distance += (a_rune != b_rune);
        // Always make progress, even over bytes that start no valid sequence.
        a += a_rune_length ? (sz_size_t)a_rune_length : 1;
        b += b_rune_length ? (sz_size_t)b_rune_length : 1;
    }
    // If one string has more runes, we need to go through the tail: each extra rune is a mismatch.
    for (; a < a_end && distance < limit; ++distance) {
        _sz_extract_utf8_rune(a, &a_rune, &a_rune_length);
        a += a_rune_length ? (sz_size_t)a_rune_length : 1;
    }
    for (; b < b_end && distance < limit; ++distance) {
        _sz_extract_utf8_rune(b, &b_rune, &b_rune_length);
        b += b_rune_length ? (sz_size_t)b_rune_length : 1;
    }
    return distance;
}
/**
 *  @brief  Sums all bytes of @p text, each treated as an unsigned 8-bit value, into a 64-bit total.
 *  @return The sum, or zero for an empty input.
 */
SZ_PUBLIC sz_u64_t sz_checksum_serial(sz_cptr_t text, sz_size_t length) {
    sz_u8_t const *const bytes = (sz_u8_t const *)text;
    sz_u64_t sum = 0;
    for (sz_size_t i = 0; i != length; ++i) sum += bytes[i];
    return sum;
}
/**
 * @brief Largest prime number that fits into 31 bits.
 * @see https://mersenneforum.org/showthread.php?t=3471
 */
#define SZ_U32_MAX_PRIME (2147483647u)
/**
 * @brief Largest prime number that fits into 64 bits.
 * @see https://mersenneforum.org/showthread.php?t=3471
 *
 * 2^64 = 18,446,744,073,709,551,616
 * this = 18,446,744,073,709,551,557
 * diff = 59
 */
#define SZ_U64_MAX_PRIME (18446744073709551557ull)
/*
 * One hardware-accelerated way of mixing hashes can be CRC, but it's only implemented for 32-bit values.
 * Using a Boost-like mixer works very poorly in such case:
 *
 * hash_first ^ (hash_second + 0x517cc1b727220a95 + (hash_first << 6) + (hash_first >> 2));
 *
 * Let's stick to the Fibonacci hash trick using the golden ratio.
 * https://probablydance.com/2018/06/16/fibonacci-hashing-the-optimization-that-the-world-forgot-or-a-better-alternative-to-integer-modulo/
 *
 * Every macro argument is parenthesized, so that compound expressions passed as arguments
 * are evaluated as a whole before the macro's own operators are applied.
 */
#define _sz_hash_mix(first, second) (((first) * 11400714819323198485ull) ^ ((second) * 11400714819323198485ull))
#define _sz_shift_low(x) (x)
#define _sz_shift_high(x) (((x) + 77ull) & 0xFFull)
#define _sz_prime_mod(x) ((x) % SZ_U64_MAX_PRIME)
/**
 *  @brief  Serial 64-bit hash of a byte string, built from two polynomial hashes that are
 *          mixed together at the end.
 *
 *  The "low" hash uses base 31 over the raw bytes. The "high" hash uses base 257 over bytes
 *  shifted by 77 (modulo 256). The first seven bytes are folded in without any reduction;
 *  every byte after that is followed by a reduction modulo `SZ_U64_MAX_PRIME`.
 *
 *  @param start   First byte of the input.
 *  @param length  Number of bytes to hash.
 *  @return The mixed hash; zero for an empty input.
 */
SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t start, sz_size_t length) {
    sz_u8_t const *const text = (sz_u8_t const *)start;
    sz_u64_t hash_low = 0;
    sz_u64_t hash_high = 0;
    // Horner's scheme in 64-bit arithmetic produces exactly the sum of `byte * base^k` terms,
    // so a single loop covers every input of up to seven bytes, including the empty one.
    sz_size_t const unreduced_length = length < 7 ? length : 7;
    sz_size_t position = 0;
    for (; position != unreduced_length; ++position) {
        hash_low = hash_low * 31ull + _sz_shift_low(text[position]);
        hash_high = hash_high * 257ull + _sz_shift_high(text[position]);
    }
    // Iterate through the rest with the modulus:
    for (; position != length; ++position) {
        hash_low = hash_low * 31ull + _sz_shift_low(text[position]);
        hash_high = hash_high * 257ull + _sz_shift_high(text[position]);
        // Wrap the hashes around:
        hash_low = _sz_prime_mod(hash_low);
        hash_high = _sz_prime_mod(hash_high);
    }
    return _sz_hash_mix(hash_low, hash_high);
}
/**
 *  @brief  Serial rolling hash: hashes every @p window_length -byte window of the input and
 *          reports selected windows through @p callback.
 *
 *  @param start, length    Input bytes.
 *  @param window_length    Window size in bytes; nothing is reported if it is zero or exceeds @p length.
 *  @param step             Reporting stride. The selection uses `cycles & (step - 1)`, so it only
 *                          picks every step-th window when @p step is a power of two.
 *  @param callback         Invoked as `callback(position, window_length, hash, callback_handle)`.
 *  @param callback_handle  Opaque pointer forwarded to @p callback.
 *
 *  NOTE: the position handed to the callback is the scan cursor as-is: one past the end of the
 *  window for the very first call, and the last byte of the window for every rolled call.
 */
SZ_PUBLIC void sz_hashes_serial(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, //
                                sz_hash_callback_t callback, void *callback_handle) {
    if (!window_length || length < window_length) return;
    sz_u8_t const *cursor = (sz_u8_t const *)start;
    sz_u8_t const *const text_end = cursor + length;

    // Prepare the `prime ^ (window_length - 1)` values, that we are going to use for modulo arithmetic.
    sz_u64_t prime_power_low = 1, prime_power_high = 1;
    for (sz_size_t power = 1; power < window_length; ++power) {
        prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME;
        prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME;
    }

    // Compute the initial hash value for the first window.
    sz_u64_t hash_low = 0, hash_high = 0;
    sz_u8_t const *const first_window_end = cursor + window_length;
    while (cursor < first_window_end) {
        hash_low = (hash_low * 31ull + _sz_shift_low(*cursor)) % SZ_U64_MAX_PRIME;
        hash_high = (hash_high * 257ull + _sz_shift_high(*cursor)) % SZ_U64_MAX_PRIME;
        ++cursor;
    }
    // The first window is always reported.
    callback((sz_cptr_t)cursor, window_length, _sz_hash_mix(hash_low, hash_high), callback_handle);

    // Compute the hash value for every following window, using the expensive modulo operation.
    sz_size_t const step_mask = step - 1;
    for (sz_size_t cycles = 1; cursor < text_end; ++cursor, ++cycles) {
        sz_u8_t const outgoing = *(cursor - window_length);
        sz_u8_t const incoming = *cursor;
        // Discard one character:
        hash_low -= _sz_shift_low(outgoing) * prime_power_low;
        hash_high -= _sz_shift_high(outgoing) * prime_power_high;
        // And add a new one:
        hash_low = 31ull * hash_low + _sz_shift_low(incoming);
        hash_high = 257ull * hash_high + _sz_shift_high(incoming);
        // Wrap the hashes around:
        hash_low = _sz_prime_mod(hash_low);
        hash_high = _sz_prime_mod(hash_high);
        // Mix only if we've skipped enough hashes.
        if ((cycles & step_mask) != 0) continue;
        callback((sz_cptr_t)cursor, window_length, _sz_hash_mix(hash_low, hash_high), callback_handle);
    }
}
#undef _sz_shift_low
#undef _sz_shift_high
#undef _sz_hash_mix
#undef _sz_prime_mod
/**
 * @brief Uses a small lookup-table to convert an uppercase character to lowercase.
 *        Bytes 65-90 ('A'-'Z') and bytes 192-222 except 215 are mapped to `c + 32`;
 *        every other byte is returned unchanged.
 *        NOTE(review): the upper range presumably follows the Latin-1 layout - confirm.
 */
SZ_INTERNAL sz_u8_t sz_u8_tolower(sz_u8_t c) {
static sz_u8_t const lowered[256] = {
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, //
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, //
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, //
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, //
64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, //
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 91, 92, 93, 94, 95, //
96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, //
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, //
128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, //
144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, //
160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, //
176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, //
224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, //
240, 241, 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, //
224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, //
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, //
};
return lowered[c];
}
/**
 * @brief Uses a small lookup-table to convert a lowercase character to uppercase.
 *        Bytes 97-122 ('a'-'z') and bytes 224-254 except 247 are mapped to `c - 32`;
 *        every other byte, including all uppercase letters, is returned unchanged.
 *        The table is the exact mirror of the one in `sz_u8_tolower`.
 *        NOTE(review): the upper range presumably follows the Latin-1 layout - confirm.
 */
SZ_INTERNAL sz_u8_t sz_u8_toupper(sz_u8_t c) {
    static sz_u8_t const upped[256] = {
        0,   1,   2,   3,   4,   5,   6,   7,   8,   9,   10,  11,  12,  13,  14,  15,  //
        16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,  29,  30,  31,  //
        32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,  //
        48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  //
        64,  65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  //
        80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  //
        96,  65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  //
        80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  123, 124, 125, 126, 127, //
        128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, //
        144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, //
        160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, //
        176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, //
        192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, //
        208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, //
        192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, //
        208, 209, 210, 211, 212, 213, 214, 247, 216, 217, 218, 219, 220, 221, 222, 255, //
    };
    return upped[c];
}
/**
 * @brief Uses two small lookup tables (768 bytes total) to accelerate division by a small
 * unsigned integer. Performs two lookups, one multiplication, and a few shifts and additions.
 *
 * Both tables are indexed by the divisor. For divisors that are powers of two the multiplier
 * entry is zero, so the intermediate `q` is zero and the result reduces to a plain right shift.
 *
 * @param number Integral value to divide.
 * @param divisor Integral value @b larger than one (checked with `sz_assert`).
 * @return The quotient; intended to equal `number / divisor` with integer truncation.
 */
SZ_INTERNAL sz_u8_t sz_u8_divide(sz_u8_t number, sz_u8_t divisor) {
sz_assert(divisor > 1);
// 256 x 16-bit entries (512 bytes): per-divisor fixed-point multipliers.
static sz_u16_t const multipliers[256] = {
0, 0, 0, 21846, 0, 39322, 21846, 9363, 0, 50973, 39322, 29790, 21846, 15124, 9363, 4370,
0, 57826, 50973, 44841, 39322, 34329, 29790, 25645, 21846, 18351, 15124, 12137, 9363, 6780, 4370, 2115,
0, 61565, 57826, 54302, 50973, 47824, 44841, 42011, 39322, 36765, 34329, 32006, 29790, 27671, 25645, 23705,
21846, 20063, 18351, 16706, 15124, 13602, 12137, 10725, 9363, 8049, 6780, 5554, 4370, 3224, 2115, 1041,
0, 63520, 61565, 59668, 57826, 56039, 54302, 52614, 50973, 49377, 47824, 46313, 44841, 43407, 42011, 40649,
39322, 38028, 36765, 35532, 34329, 33154, 32006, 30885, 29790, 28719, 27671, 26647, 25645, 24665, 23705, 22766,
21846, 20945, 20063, 19198, 18351, 17520, 16706, 15907, 15124, 14356, 13602, 12863, 12137, 11424, 10725, 10038,
9363, 8700, 8049, 7409, 6780, 6162, 5554, 4957, 4370, 3792, 3224, 2665, 2115, 1573, 1041, 517,
0, 64520, 63520, 62535, 61565, 60609, 59668, 58740, 57826, 56926, 56039, 55164, 54302, 53452, 52614, 51788,
50973, 50169, 49377, 48595, 47824, 47063, 46313, 45572, 44841, 44120, 43407, 42705, 42011, 41326, 40649, 39982,
39322, 38671, 38028, 37392, 36765, 36145, 35532, 34927, 34329, 33738, 33154, 32577, 32006, 31443, 30885, 30334,
29790, 29251, 28719, 28192, 27671, 27156, 26647, 26143, 25645, 25152, 24665, 24182, 23705, 23233, 22766, 22303,
21846, 21393, 20945, 20502, 20063, 19628, 19198, 18772, 18351, 17933, 17520, 17111, 16706, 16305, 15907, 15514,
15124, 14738, 14356, 13977, 13602, 13231, 12863, 12498, 12137, 11779, 11424, 11073, 10725, 10380, 10038, 9699,
9363, 9030, 8700, 8373, 8049, 7727, 7409, 7093, 6780, 6470, 6162, 5857, 5554, 5254, 4957, 4662,
4370, 4080, 3792, 3507, 3224, 2943, 2665, 2388, 2115, 1843, 1573, 1306, 1041, 778, 517, 258,
};
// This table can be avoided using a single addition and counting trailing zeros.
// 256 x 8-bit entries (256 bytes): the final right-shift applied for each divisor.
static sz_u8_t const shifts[256] = {
0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, //
4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, //
5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, //
6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, //
6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, //
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, //
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, //
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, //
};
// The multiplier is widened to 32 bits so that the product with an 8-bit number can't overflow.
sz_u32_t multiplier = multipliers[divisor];
sz_u8_t shift = shifts[divisor];
// Keep only the high 16 bits of the 32-bit product.
sz_u16_t q = (sz_u16_t)((multiplier * number) >> 16);
// Average `number` and `q` without needing an extra bit: ((number - q) / 2) + q.
sz_u16_t t = ((number - q) >> 1) + q;
// The per-divisor shift produces the final quotient.
return (sz_u8_t)(t >> shift);
}
/**
 * @brief Byte-wise table substitution: writes `lut[text[i]]` into `result[i]` for every input byte.
 *
 * @param text   Input bytes.
 * @param length Number of bytes to transform.
 * @param lut    Substitution table, indexed by the unsigned byte value (0..255).
 * @param result Output buffer receiving `length` bytes.
 */
SZ_PUBLIC void sz_look_up_transform_serial(sz_cptr_t text, sz_size_t length, sz_cptr_t lut, sz_ptr_t result) {
    // All three buffers are viewed as unsigned bytes, so values above 127 index the table correctly.
    sz_u8_t const *table = (sz_u8_t const *)lut;
    sz_u8_t const *input = (sz_u8_t const *)text;
    sz_u8_t *output = (sz_u8_t *)result;
    for (sz_size_t i = 0; i != length; ++i) output[i] = table[input[i]];
}
/**
 * @brief Writes the lowercase form of every byte of `text` into `result`, one byte at a time,
 *        using `sz_u8_tolower` for the per-byte mapping.
 *
 * @param text   Input bytes.
 * @param length Number of bytes to convert.
 * @param result Output buffer receiving `length` bytes.
 */
SZ_PUBLIC void sz_tolower_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) {
    sz_u8_t const *input = (sz_u8_t const *)text;
    sz_u8_t *output = (sz_u8_t *)result;
    for (sz_size_t i = 0; i != length; ++i) output[i] = sz_u8_tolower(input[i]);
}
/**
 * @brief Maps every byte of `text` through `sz_u8_toupper` and writes the outcome into `result`,
 *        one byte at a time.
 *
 * @param text   Input bytes.
 * @param length Number of bytes to convert.
 * @param result Output buffer receiving `length` bytes.
 */
SZ_PUBLIC void sz_toupper_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) {
    sz_u8_t const *input = (sz_u8_t const *)text;
    sz_u8_t *output = (sz_u8_t *)result;
    for (sz_size_t i = 0; i != length; ++i) output[i] = sz_u8_toupper(input[i]);
}
/**
 * @brief Clears the top bit of every byte of `text`, writing the 7-bit outcome into `result`.
 *
 * @param text   Input bytes.
 * @param length Number of bytes to convert.
 * @param result Output buffer receiving `length` bytes.
 */
SZ_PUBLIC void sz_toascii_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) {
    sz_u8_t const *input = (sz_u8_t const *)text;
    sz_u8_t *output = (sz_u8_t *)result;
    for (sz_size_t i = 0; i != length; ++i) output[i] = input[i] & 0x7F;
}
/**
 * @brief Checks that no byte in the buffer has its top bit set, i.e. exceeds 127.
 *        Uses the hardware-agnostic SWAR technique to test 8 bytes per iteration.
 *
 * @param text   Buffer to inspect.
 * @param length Number of bytes in the buffer; an empty buffer counts as ASCII.
 * @return `sz_true_k` if every byte is below 128, `sz_false_k` otherwise.
 */
SZ_PUBLIC sz_bool_t sz_isascii_serial(sz_cptr_t text, sz_size_t length) {
    if (length == 0) return sz_true_k;
    sz_u8_t const *cursor = (sz_u8_t const *)text;
    sz_u8_t const *const stop = cursor + length;

#if !SZ_USE_MISALIGNED_LOADS
    // Walk byte-by-byte up to the next 8-byte boundary, to avoid UB on unaligned 64-bit loads.
    while (((sz_size_t)cursor & 7ull) != 0 && cursor < stop) {
        if (*cursor & 0x80ull) return sz_false_k;
        ++cursor;
    }
#endif

    // One mask test covers the top bits of eight bytes at once.
    sz_u64_vec_t word;
    while (cursor + 8 <= stop) {
        word.u64 = *(sz_u64_t const *)cursor;
        if ((word.u64 & 0x8080808080808080ull) != 0) return sz_false_k;
        cursor += 8;
    }

    // Fewer than eight bytes are left; check them individually.
    while (cursor < stop) {
        if (*cursor & 0x80ull) return sz_false_k;
        ++cursor;
    }
    return sz_true_k;
}
/**
 * @brief Fills `result` with characters drawn from `alphabet`, picking each one with the low
 *        byte of a value produced by `generator`.
 *
 * @param alphabet            Characters to choose from.
 * @param alphabet_size       Number of characters in the alphabet, between 1 and 256 inclusive.
 * @param result              Output buffer.
 * @param result_length       Number of bytes to produce.
 * @param generator           Callback producing random values; not used when `alphabet_size == 1`.
 * @param generator_user_data Opaque state forwarded to every `generator` call.
 *
 * The index is `random % alphabet_size`, computed through `sz_u8_divide`. A full 256-character
 * alphabet is handled separately: its size doesn't fit into the 8-bit divisor (it would wrap to
 * zero and violate the `divisor > 1` precondition of `sz_u8_divide`), and no reduction is needed
 * anyway, as every random byte already is a valid index.
 */
SZ_PUBLIC void sz_generate_serial(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length,
                                  sz_random_generator_t generator, void *generator_user_data) {
    sz_assert(alphabet_size > 0 && alphabet_size <= 256 && "Inadequate alphabet size");

    // With a single character there is nothing random about the output.
    if (alphabet_size == 1) {
        sz_fill(result, result_length, *alphabet);
        return;
    }

    sz_assert(generator && "Expects a valid random generator");
    sz_cptr_t const end = result + result_length;

    // Every possible byte value is a valid index, so skip the division altogether.
    if (alphabet_size == 256) {
        for (; result != end; ++result) {
            sz_u8_t random = generator(generator_user_data) & 0xFF;
            *result = alphabet[random];
        }
        return;
    }

    sz_u8_t divisor = (sz_u8_t)alphabet_size;
    for (; result != end; ++result) {
        sz_u8_t random = generator(generator_user_data) & 0xFF;
        sz_u8_t quotient = sz_u8_divide(random, divisor);
        // Remainder of the division: `random - (random / divisor) * divisor`.
        *result = alphabet[random - quotient * divisor];
    }
}
#pragma endregion
/*
* Serial implementation of string class operations.
*/
#pragma region Serial Implementation for the String Class
/**
 * @brief Tells whether the string keeps its characters in its own embedded buffer.
 *
 * @param string String to inspect.
 * @return `sz_true_k`-like non-zero value if `start` points at the embedded `chars` array.
 */
SZ_PUBLIC sz_bool_t sz_string_is_on_stack(sz_string_t const *string) {
    // The `start` pointer sits at the same location in both layouts, so it can be read
    // through either view and simply compared against the address of the embedded buffer.
    sz_cptr_t const data = (sz_cptr_t)string->internal.start;
    sz_cptr_t const embedded = (sz_cptr_t)&string->internal.chars[0];
    return (sz_bool_t)(data == embedded);
}
/**
 * @brief Extracts the start pointer and the length of the string without branching
 *        on whether it uses the embedded buffer or an external one.
 *
 * @param string String to inspect.
 * @param start  Receives the pointer to the first character.
 * @param length Receives the number of characters.
 */
SZ_PUBLIC void sz_string_range(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length) {
    sz_size_t const uses_embedded_buffer =
        (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0];
    // All-ones for an external string, zero for an embedded one.
    sz_size_t const external_mask = uses_embedded_buffer - 1ull;
    // An embedded string only keeps its length in the lowest byte, the remaining bytes of that
    // word are payload. An external string uses the whole word, so the mask keeps every bit.
    sz_size_t const length_mask = 0x00000000000000FFull | external_mask;

    // The `start` pointer is located at the same place in both layouts.
    *start = string->external.start;
    *length = string->external.length & length_mask;
}
/**
 * @brief Extracts the start pointer, length, available space and storage kind of the string,
 *        without branching on whether it uses the embedded buffer or an external one.
 *
 * @param string      String to inspect.
 * @param start       Receives the pointer to the first character.
 * @param length      Receives the number of characters.
 * @param space       Receives the size of the buffer: `SZ_STRING_INTERNAL_SPACE` for an embedded
 *                    string, the stored `external.space` otherwise.
 * @param is_external Receives a non-zero value if the string doesn't use the embedded buffer.
 */
SZ_PUBLIC void sz_string_unpack(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space,
                                sz_bool_t *is_external) {
    sz_size_t const uses_embedded_buffer =
        (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0];
    // Zero for an embedded string (1 - 1), all-ones for an external one (0 - 1 wraps around).
    sz_size_t const external_mask = uses_embedded_buffer - 1ull;

    // The `start` pointer is located at the same place in both layouts.
    *start = string->external.start;
    // An embedded string only keeps its length in the lowest byte of that word.
    *length = string->external.length & (0x00000000000000FFull | external_mask);
    // Presumably `sz_u64_blend` picks the second argument where mask bits are set - verify against its definition.
    *space = sz_u64_blend(SZ_STRING_INTERNAL_SPACE, string->external.space, external_mask);
    *is_external = (sz_bool_t)!uses_embedded_buffer;
}
/**
 * @brief Checks whether two strings have the same length and the same contents.
 *
 * @param a First string.
 * @param b Second string.
 * @return Non-zero if the strings are equal, zero otherwise.
 */
SZ_PUBLIC sz_bool_t sz_string_equal(sz_string_t const *a, sz_string_t const *b) {
    // Comparing the raw `external.length` words is not an option: for an embedded string the
    // high-order bytes of that word hold payload ("noise"), and that invariant isn't maintained.
    // So both strings are properly unpacked first.
    //
    // With `SZ_USE_MISALIGNED_LOADS` a possible shortcut would be to compare the first 8 bytes
    // at once, as `start` always points to a buffer at least 8 bytes long. Not implemented.
    sz_ptr_t lhs_start, rhs_start;
    sz_size_t lhs_length, rhs_length;
    sz_string_range(a, &lhs_start, &lhs_length);
    sz_string_range(b, &rhs_start, &rhs_length);

    if (lhs_length != rhs_length) return (sz_bool_t)0;
    return (sz_bool_t)(sz_equal(lhs_start, rhs_start, rhs_length) != 0);
}
/**
 * @brief Three-way comparison of two strings, delegating to `sz_order` on their contents.
 *
 * @param a First string.
 * @param b Second string.
 * @return Whatever `sz_order` reports for the two character ranges.
 */
SZ_PUBLIC sz_ordering_t sz_string_order(sz_string_t const *a, sz_string_t const *b) {
    // With `SZ_USE_MISALIGNED_LOADS` a possible shortcut would be to compare the first 8 bytes
    // at once, as `start` always points to a buffer at least 8 bytes long. Not implemented,
    // so the ranges are unpacked and compared in full.
    sz_ptr_t lhs_start, rhs_start;
    sz_size_t lhs_length, rhs_length;
    sz_string_range(a, &lhs_start, &lhs_length);
    sz_string_range(b, &rhs_start, &rhs_length);

    sz_ordering_t const ordering = sz_order(lhs_start, lhs_length, rhs_start, rhs_length);
    return ordering;
}
/**
 * @brief Initializes the string as an empty one that uses its embedded buffer.
 *
 * @param string String to initialize; must not be null.
 */
SZ_PUBLIC void sz_string_init(sz_string_t *string) {
    sz_assert(string && "String can't be SZ_NULL.");
    // Point the string at its own embedded buffer.
    string->internal.start = &string->internal.chars[0];
    // Strictly speaking only the length and the first character have to be cleared,
    // but for safety every word after the pointer is zeroed.
    for (sz_size_t word = 1; word != 4; ++word) string->words[word] = 0;
}
/**
 * @brief Initializes the string with the given length, allocating an external buffer only if
 *        the contents plus the trailing NUL don't fit into the embedded one.
 *        The characters themselves are left for the caller to fill in.
 *
 * @param string    String to initialize; must not be null.
 * @param length    Number of characters the string will hold.
 * @param allocator Allocator used if an external buffer is needed; must not be null.
 * @return Pointer to the first character, or `SZ_NULL_CHAR` if the allocation failed or the
 *         requested length is too large to be represented. In both failure cases the string
 *         is left as a valid empty string, so it can still be safely used or freed.
 */
SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, sz_memory_allocator_t *allocator) {
    sz_size_t space_needed = length + 1; // space for trailing \0
    sz_assert(string && allocator && "String and allocator can't be SZ_NULL.");
    // Initialize the string to zeros for safety.
    string->words[1] = 0;
    string->words[2] = 0;
    string->words[3] = 0;

    // `length + 1` wrapped around: such a string can't exist, and must not be mistaken for a tiny one.
    if (space_needed == 0) {
        string->internal.start = &string->internal.chars[0];
        return SZ_NULL_CHAR;
    }

    // If we are lucky, no memory allocations will be needed.
    if (space_needed <= SZ_STRING_INTERNAL_SPACE) {
        string->internal.start = &string->internal.chars[0];
        string->internal.length = (sz_u8_t)length;
    }
    else {
        // If we are not lucky, we need to allocate memory.
        sz_ptr_t allocated = (sz_ptr_t)allocator->allocate(space_needed, allocator->handle);
        if (!allocated) {
            // Don't leave a null `start` behind: fall back to the empty embedded string.
            string->internal.start = &string->internal.chars[0];
            return SZ_NULL_CHAR;
        }
        string->external.start = allocated;
        string->external.length = length;
        string->external.space = space_needed;
    }
    sz_assert(&string->internal.start == &string->external.start && "Alignment confusion");
    string->external.start[length] = 0;
    return string->external.start;
}
/**
 * @brief Grows the buffer of the string, so that it can hold `new_capacity` characters plus
 *        the trailing NUL. The contents and the length are preserved.
 *
 * @param string       String to grow; must not be null.
 * @param new_capacity Number of characters the string must be able to hold.
 * @param allocator    Allocator used for the new buffer and to release the old external one.
 * @return Pointer to the first character, or `SZ_NULL_CHAR` if the allocation failed, in which
 *         case the string is left untouched.
 *
 * Requests that fit into the embedded buffer are a no-op. Requests that don't exceed the current
 * space are a contract violation caught by `sz_assert`; in builds without assertions they are
 * ignored rather than reallocating into a buffer that may be too small for the contents.
 */
SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity, sz_memory_allocator_t *allocator) {
    sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL.");
    sz_size_t new_space = new_capacity + 1;
    if (new_space <= SZ_STRING_INTERNAL_SPACE) return string->external.start;

    sz_ptr_t string_start;
    sz_size_t string_length;
    sz_size_t string_space;
    sz_bool_t string_is_external;
    sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external);
    sz_assert(new_space > string_space && "New space must be larger than current.");
    if (new_space <= string_space) return string_start;

    sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle);
    if (!new_start) return SZ_NULL_CHAR;

    sz_copy(new_start, string_start, string_length);
    // Keep the string NUL-terminated: the copy above covers the characters only.
    new_start[string_length] = 0;
    string->external.start = new_start;
    string->external.space = new_space;
    string->external.padding = 0;
    string->external.length = string_length;

    // Deallocate the old string.
    if (string_is_external) allocator->free(string_start, string_space, allocator->handle);
    return string->external.start;
}
/**
 * @brief Reallocates an externally stored string into a buffer that is exactly large enough
 *        for its characters and the trailing NUL.
 *
 * @param string    String to shrink; must not be null.
 * @param allocator Allocator used for the new buffer and to release the old one.
 * @return Pointer to the first character, or `SZ_NULL_CHAR` if the allocation failed, in which
 *         case the string is left untouched.
 *
 * Strings that use the embedded buffer, or already occupy exactly `length + 1` bytes, are
 * returned as-is.
 */
SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_allocator_t *allocator) {
    sz_assert(string && allocator && "Strings and allocators can't be SZ_NULL.");
    sz_ptr_t string_start;
    sz_size_t string_length;
    sz_size_t string_space;
    sz_bool_t string_is_external;
    sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external);

    // We may already be space-optimal, and in that case we don't need to do anything.
    sz_size_t new_space = string_length + 1;
    if (string_space == new_space || !string_is_external) return string->external.start;

    sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle);
    if (!new_start) return SZ_NULL_CHAR;

    sz_copy(new_start, string_start, string_length);
    // Keep the string NUL-terminated: the copy above covers the characters only.
    new_start[string_length] = 0;
    string->external.start = new_start;
    string->external.space = new_space;
    string->external.padding = 0;
    string->external.length = string_length;

    // Only external strings reach this point, so the old buffer always has to be released.
    allocator->free(string_start, string_space, allocator->handle);
    return string->external.start;
}
/**
 * @brief Opens a gap of `added_length` characters at `offset`, shifting the tail of the string
 *        to the right and growing the buffer if needed. The gap itself is not initialized.
 *
 * @param string       String to expand; must not be null.
 * @param offset       Position of the gap; values past the end are clamped to the length.
 * @param added_length Number of characters to make room for.
 * @param allocator    Allocator used if the buffer has to grow; must not be null.
 * @return Pointer to the first character, or `SZ_NULL_CHAR` if the reallocation failed.
 */
SZ_PUBLIC sz_ptr_t sz_string_expand(sz_string_t *string, sz_size_t offset, sz_size_t added_length,
                                    sz_memory_allocator_t *allocator) {
    sz_assert(string && allocator && "String and allocator can't be SZ_NULL.");
    sz_ptr_t data;
    sz_size_t old_length;
    sz_size_t capacity;
    sz_bool_t on_heap;
    sz_string_unpack(string, &data, &old_length, &capacity, &on_heap);

    // The user intended to extend the string.
    offset = sz_min_of_two(offset, old_length);
    sz_size_t const grown_length = old_length + added_length;
    sz_size_t const tail_length = old_length - offset;

    // If we are lucky, the current buffer has enough room, including the terminator.
    if (grown_length < capacity) {
        sz_move(data + offset + added_length, data + offset, tail_length);
        data[grown_length] = 0;
        // Even if the string uses the embedded buffer, the `+=` won't affect the tail of the string.
        string->external.length += added_length;
        return data;
    }

    // Otherwise grow geometrically, but by no less than what is actually needed.
    sz_size_t const doubled_space = sz_max_of_two(SZ_CACHE_LINE_WIDTH, capacity * 2ull);
    sz_size_t const rounded_space = sz_size_bit_ceil(offset + old_length + added_length + 1);
    sz_size_t const new_space = sz_max_of_two(rounded_space, doubled_space);
    data = sz_string_reserve(string, new_space - 1, allocator);
    if (!data) return SZ_NULL_CHAR;

    // The contents were carried over into the new buffer, so the gap is opened there.
    sz_move(data + offset + added_length, data + offset, tail_length);
    data[grown_length] = 0;
    string->external.length = grown_length;
    return data;
}
/**
 * @brief Removes up to `length` characters starting at `offset`. No memory is released.
 *
 * @param string String to modify; must not be null.
 * @param offset Position of the first character to remove; clamped to the string length.
 * @param length Number of characters to remove; clamped to what is available after `offset`.
 * @return Number of characters actually removed.
 */
SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length) {
    sz_assert(string && "String can't be SZ_NULL.");
    sz_ptr_t data;
    sz_size_t old_length;
    sz_size_t capacity;
    sz_bool_t on_heap;
    sz_string_unpack(string, &data, &old_length, &capacity, &on_heap);

    // The offset can't be larger than the length.
    offset = sz_min_of_two(offset, old_length);
    // Clamping against the remainder, rather than checking `offset + length >= old_length`,
    // avoids an overflow when receiving `length == SZ_SIZE_MAX`. From here on `length` is
    // exactly the delta between the original and the final length of the string.
    length = sz_min_of_two(length, old_length - offset);

    // Erasing everything, or just the tail (like `pop_back()` in C++), needs no `memmove`.
    // In every other case the surviving tail has to be shifted to the left.
    sz_size_t const erased_end = offset + length;
    if (erased_end < old_length) sz_move(data + offset, data + erased_end, old_length - erased_end);

    // Assigning the new length would clobber the characters of an embedded string that share
    // the same word; an in-place subtraction only touches the length bits.
    string->external.length -= length;
    data[old_length - length] = 0;
    return length;
}
/**
 * @brief Releases the external buffer of the string, if it has one, and resets the string
 *        to the empty state.
 *
 * @param string    String to release.
 * @param allocator Allocator whose `free` receives the buffer, its size, and the allocator handle.
 */
SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *allocator) {
    sz_bool_t const uses_embedded_buffer = sz_string_is_on_stack(string);
    if (!uses_embedded_buffer) {
        allocator->free(string->external.start, string->external.space, allocator->handle);
    }
    sz_string_init(string);
}
// MSVC recognizes the loops below and turns them into a `memset` call. When this library is the
// one overriding LibC, that ends in infinite recursion (memset -> sz_fill_serial -> memset) and
// a stack overflow, so optimizations are switched off around this function.
#if defined(_MSC_VER) && defined(SZ_OVERRIDE_LIBC) && SZ_OVERRIDE_LIBC
#pragma optimize("", off)
#endif
/**
 * @brief Sets every byte of the buffer to `value`, similar to `memset`.
 *
 * @param target Buffer to fill.
 * @param length Number of bytes to fill.
 * @param value  Byte value to write.
 */
SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value) {
    // For short buffers a single sequential pass is faster: SWAR only pays off
    // once there are several aligned words to write.
    if (length < SZ_SWAR_THRESHOLD) {
        for (; length != 0; --length) *(target++) = value;
        return;
    }

    // Replicate the byte into all eight lanes of a 64-bit word.
    sz_u64_t const pattern = (sz_u64_t)value * 0x0101010101010101ull;
    // Head: single bytes until the pointer is 8-byte aligned.
    for (; ((sz_size_t)target & 7ull) != 0; --length) *(target++) = value;
    // Body: whole aligned words.
    for (; length >= 8; length -= 8, target += 8) *(sz_u64_t *)target = pattern;
    // Tail: whatever is left, at most seven bytes.
    for (; length != 0; --length) *(target++) = value;
}
#if defined(_MSC_VER) && defined(SZ_OVERRIDE_LIBC) && SZ_OVERRIDE_LIBC
#pragma optimize("", on)
#endif
/**
 * @brief Copies `length` bytes from `source` to `target`, similar to `memcpy`.
 *        The ranges are not expected to overlap.
 *
 * @param target Destination buffer.
 * @param source Source buffer.
 * @param length Number of bytes to copy.
 */
SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
    // The most typical implementation of `memcpy` suffers from Undefined Behavior:
    //
    //      for (char const *end = source + length; source < end; source++) *target++ = *source;
    //
    // as NULL pointer arithmetic is undefined for calls like `memcpy(NULL, NULL, 0)`.
    // That's mitigated in C2y with the N3322 proposal, but counting the length down, as done
    // here, never touches the pointers when there is nothing to copy.
    // https://developers.redhat.com/articles/2024/12/11/making-memcpynull-null-0-well-defined
#if SZ_USE_MISALIGNED_LOADS
    // Move whole 64-bit words for as long as possible.
    for (; length >= 8; length -= 8, target += 8, source += 8) *(sz_u64_t *)target = *(sz_u64_t const *)source;
#endif
    // Copy the remaining bytes one by one.
    for (; length != 0; --length) *(target++) = *(source++);
}
/**
 * @brief Copies `length` bytes from `source` to `target`, similar to `memmove`.
 *        The ranges are allowed to overlap.
 *
 * @param target Destination buffer.
 * @param source Source buffer.
 * @param length Number of bytes to move.
 */
SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
    // Implementing `memmove` is trickier than `memcpy`, as the ranges may overlap.
    // Existing implementations often have two passes, in normal and reversed order,
    // depending on the relation of `target` and `source` addresses.
    // https://student.cs.uwaterloo.ca/~cs350/common/os161-src-html/doxygen/html/memmove_8c_source.html
    // https://marmota.medium.com/c-language-making-memmove-def8792bb8d5
    //
    // The `memcpy`-like left-to-right pass is safe if the `target` is before the `source`,
    // or if the ranges don't intersect at all. In the latter case the traversal order is
    // irrelevant, but older CPUs may predict and fetch forward passes better.
    int const forward_pass_is_safe = target < source || target >= source + length;

    if (!forward_pass_is_safe) {
        // Jump to the end and walk backwards.
        target += length, source += length;
#if SZ_USE_MISALIGNED_LOADS
        for (; length >= 8; length -= 8) *(sz_u64_t *)(target -= 8) = *(sz_u64_t const *)(source -= 8);
#endif
        for (; length != 0; --length) *(--target) = *(--source);
        return;
    }

#if SZ_USE_MISALIGNED_LOADS
    for (; length >= 8; length -= 8, target += 8, source += 8) *(sz_u64_t *)target = *(sz_u64_t const *)(source);
#endif
    for (; length != 0; --length) *(target++) = *(source++);
}
#pragma endregion
/*
* @brief Serial implementation for strings sequence processing.
*/
#pragma region Serial Implementation for Sequences
/**
 * @brief Reorders the sequence so that the entries satisfying the predicate come first.
 *
 * @param sequence  Sequence whose `order` array is rearranged in place.
 * @param predicate Callback deciding to which part an entry belongs.
 * @return Number of entries satisfying the predicate, i.e. the index of the first one that doesn't.
 */
SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate) {
    // Skip the prefix that is already in place.
    sz_size_t boundary = 0;
    while (boundary != sequence->count && predicate(sequence, sequence->order[boundary])) ++boundary;

    // Pull every later match to the boundary, growing the matching part by one each time.
    for (sz_size_t probe = boundary + 1; probe < sequence->count; ++probe) {
        if (!predicate(sequence, sequence->order[probe])) continue;
        sz_u64_swap(sequence->order + probe, sequence->order + boundary);
        ++boundary;
    }
    return boundary;
}
/**
 * @brief In-place merge of two consecutive sorted runs of the sequence, using no extra memory.
 *
 * @param sequence  Sequence whose `order` array holds the two runs back to back.
 * @param partition Index of the @b last entry of the first run; the second run starts right
 *                  after it and ends with the last entry of the sequence.
 * @param less      Comparator, returning non-zero if its first key must precede the second.
 *
 * Entries of the second run that belong earlier are inserted into place by shifting the
 * entries in between one slot to the right, so the worst case is quadratic.
 */
SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less) {
    sz_size_t start_b = partition + 1;
    // An empty second run leaves nothing to merge, and must not be read from.
    if (start_b >= sequence->count) return;
    // If the direct merge is already sorted
    if (!less(sequence, sequence->order[start_b], sequence->order[partition])) return;

    sz_size_t start_a = 0;
    // `start_b` is a valid index only while it is strictly below `count`.
    while (start_a <= partition && start_b < sequence->count) {
        // The head of the first run is in the right place
        if (!less(sequence, sequence->order[start_b], sequence->order[start_a])) { start_a++; }
        else {
            sz_u64_t value = sequence->order[start_b];
            sz_size_t index = start_b;
            // Shift everything between the two heads right by one,
            // and put the head of the second run in front of it.
            while (index != start_a) { sequence->order[index] = sequence->order[index - 1], index--; }
            sequence->order[start_a] = value;
            // The first run moved one slot to the right, and the second one lost its head.
            start_a++;
            partition++;
            start_b++;
        }
    }
}
/**
 * @brief Insertion sort of the `order` array of the sequence. Quadratic, so only meant for short inputs.
 *
 * @param sequence Sequence to sort in place.
 * @param less     Comparator, returning non-zero if its first key must precede the second.
 */
SZ_PUBLIC void sz_sort_insertion(sz_sequence_t *sequence, sz_sequence_comparator_t less) {
    sz_u64_t *const order = sequence->order;
    sz_size_t const count = sequence->count;
    // Everything before `sorted` is already in order; insert the next key into that prefix.
    for (sz_size_t sorted = 1; sorted < count; ++sorted) {
        sz_u64_t const pending = order[sorted];
        sz_size_t slot = sorted;
        // Shift larger keys one position to the right, until the right slot is found.
        while (slot > 0 && less(sequence, pending, order[slot - 1])) {
            order[slot] = order[slot - 1];
            --slot;
        }
        order[slot] = pending;
    }
}
/**
 * @brief Restores the max-heap property for the sub-tree rooted at `start`, by repeatedly
 *        swapping the root with its larger child.
 *
 * @param sequence Sequence forwarded to the comparator.
 * @param less     Comparator, returning non-zero if its first key must precede the second.
 * @param order    Array of keys forming the heap.
 * @param start    Index of the sub-tree root.
 * @param end      Index of the last entry of the heap, inclusive.
 */
SZ_INTERNAL void _sz_sift_down(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t start,
                               sz_size_t end) {
    sz_size_t parent = start;
    for (;;) {
        // The left child; once it falls out of the heap, the parent is a leaf.
        sz_size_t candidate = 2 * parent + 1;
        if (candidate > end) return;
        // Prefer the right child, if it exists and is the larger of the two.
        if (candidate + 1 <= end && less(sequence, order[candidate], order[candidate + 1])) ++candidate;
        // The parent is not smaller than its larger child, so the heap property holds.
        if (!less(sequence, order[parent], order[candidate])) return;
        sz_u64_swap(order + parent, order + candidate);
        parent = candidate;
    }
}
/**
 * @brief Turns the array into a max-heap, sifting down every non-leaf entry,
 *        starting from the last parent and walking towards the root.
 *
 * @param sequence Sequence forwarded to the comparator.
 * @param less     Comparator, returning non-zero if its first key must precede the second.
 * @param order    Array of keys to rearrange.
 * @param count    Number of keys in the array.
 */
SZ_INTERNAL void _sz_heapify(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t count) {
    // Zero or one entries already form a heap, and `count - 2` below would wrap around.
    if (count < 2) return;
    // The parent of the last entry is the last node that has any children.
    sz_size_t start = (count - 2) / 2;
    while (1) {
        _sz_sift_down(sequence, less, order, start, count - 1);
        // The index is unsigned, so the loop can't be terminated by `start >= 0`.
        if (start == 0) return;
        start--;
    }
}
/**
 * @brief Heap-sort of the `[first, last)` slice of the `order` array of the sequence.
 *
 * @param sequence Sequence to sort in place.
 * @param less     Comparator, returning non-zero if its first key must precede the second.
 * @param first    Index of the first entry of the slice.
 * @param last     Index one past the last entry of the slice.
 */
SZ_INTERNAL void _sz_heapsort(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, sz_size_t last) {
    sz_u64_t *order = sequence->order;
    sz_size_t count = last - first;
    // Nothing to sort, and `count - 1` below would wrap around for an empty slice.
    if (count < 2) return;
    _sz_heapify(sequence, less, order + first, count);
    sz_size_t end = count - 1;
    while (end > 0) {
        // The largest entry is at the root: move it behind the heap, and shrink the heap by one.
        sz_u64_swap(order + first, order + first + end);
        end--;
        _sz_sift_down(sequence, less, order + first, 0, end);
    }
}
/**
 * @brief Recursive step of the introsort over the `[first, last)` slice of the `order` array.
 *
 * @param sequence Sequence to sort in place.
 * @param less     Comparator, returning non-zero if its first key must precede the second.
 * @param first    Index of the first entry of the slice.
 * @param last     Index one past the last entry of the slice.
 * @param depth    Remaining number of quick-sort levels before falling back to heap-sort.
 *
 * Slices of up to three entries are sorted with explicit comparisons, slices of up to sixteen
 * with insertion sort. Longer slices are partitioned around the median of their first, middle
 * and last entries, unless the `depth` budget is exhausted, in which case heap-sort takes over.
 */
SZ_PUBLIC void sz_sort_introsort_recursion(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first,
                                           sz_size_t last, sz_size_t depth) {
    sz_size_t length = last - first;
    switch (length) {
    case 0:
    case 1: return;
    case 2:
        if (less(sequence, sequence->order[first + 1], sequence->order[first]))
            sz_u64_swap(&sequence->order[first], &sequence->order[first + 1]);
        return;
    case 3: {
        sz_u64_t a = sequence->order[first];
        sz_u64_t b = sequence->order[first + 1];
        sz_u64_t c = sequence->order[first + 2];
        if (less(sequence, b, a)) sz_u64_swap(&a, &b);
        if (less(sequence, c, b)) sz_u64_swap(&c, &b);
        if (less(sequence, b, a)) sz_u64_swap(&a, &b);
        sequence->order[first] = a;
        sequence->order[first + 1] = b;
        sequence->order[first + 2] = c;
        return;
    }
    default: break;
    }

    // Until a certain length, the quadratic-complexity insertion-sort is fine
    if (length <= 16) {
        sz_sequence_t sub_seq = *sequence;
        sub_seq.order += first;
        sub_seq.count = length;
        sz_sort_insertion(&sub_seq, less);
        return;
    }

    // Fallback to N-logN-complexity heap-sort
    if (depth == 0) {
        _sz_heapsort(sequence, less, first, last);
        return;
    }
    --depth;

    // Median-of-three logic to choose pivot. After the first two checks the smallest of the
    // three is in the `first` slot. The third one orders the remaining two, so that the
    // median ends up in the `median` slot and the largest in the `last - 1` slot.
    sz_size_t median = first + length / 2;
    if (less(sequence, sequence->order[median], sequence->order[first]))
        sz_u64_swap(&sequence->order[first], &sequence->order[median]);
    if (less(sequence, sequence->order[last - 1], sequence->order[first]))
        sz_u64_swap(&sequence->order[first], &sequence->order[last - 1]);
    if (less(sequence, sequence->order[last - 1], sequence->order[median]))
        sz_u64_swap(&sequence->order[median], &sequence->order[last - 1]);

    // Partition using the median-of-three as the pivot. The entries in the `first` and
    // `last - 1` slots double as sentinels, keeping both scans within the slice.
    sz_u64_t pivot = sequence->order[median];
    sz_size_t left = first;
    sz_size_t right = last - 1;
    while (1) {
        while (less(sequence, sequence->order[left], pivot)) left++;
        while (less(sequence, pivot, sequence->order[right])) right--;
        if (left >= right) break;
        sz_u64_swap(&sequence->order[left], &sequence->order[right]);
        left++;
        right--;
    }

    // Recursively sort the partitions
    sz_sort_introsort_recursion(sequence, less, first, left, depth);
    sz_sort_introsort_recursion(sequence, less, right + 1, last, depth);
}
/**
 * @brief Sorts the whole `order` array of the sequence with introsort.
 *
 * @param sequence Sequence to sort in place.
 * @param less     Comparator, returning non-zero if its first key must precede the second.
 */
SZ_PUBLIC void sz_sort_introsort(sz_sequence_t *sequence, sz_sequence_comparator_t less) {
    sz_size_t const count = sequence->count;
    if (count == 0) return;
    // The recursion budget is the base-2 logarithm of the size, plus one unless the size is a power of two.
    sz_size_t depth_limit = sz_size_log2i_nonzero(count);
    if ((count & (count - 1)) != 0) depth_limit += 1;
    sz_sort_introsort_recursion(sequence, less, 0, count, depth_limit);
}
/**
 * @brief Recursive step of a binary radix sort over the `order` array of the sequence:
 *        partitions the entries by a single bit, and recurses into both parts with the next bit.
 *        Once `bit_idx` reaches `bit_max`, the upper 32-bit half of every entry is cleared,
 *        and both parts are finished with `sz_sort_introsort`.
 *
 * @param sequence              Sequence to sort in place.
 * @param bit_idx               Index of the bit to partition by, counted from the most significant one.
 * @param bit_max               Value of `bit_idx` at which the recursion stops.
 * @param comparator            Comparator forwarded to `sz_sort_introsort`.
 * @param partial_order_length  Only forwarded to the recursive calls; not otherwise used here.
 *
 * NOTE(review): the recursion stops when `bit_idx == bit_max`, but that last level still
 * partitions by its bit. With `bit_max == 32` that is bit 31, which lies outside the upper
 * 32-bit half - confirm whether `bit_max` is meant to be exclusive.
 */
SZ_PUBLIC void sz_sort_recursion( //
sz_sequence_t *sequence, sz_size_t bit_idx, sz_size_t bit_max, sz_sequence_comparator_t comparator,
sz_size_t partial_order_length) {
if (!sequence->count) return;
// Array of size one doesn't need sorting - only needs the prefix to be discarded.
if (sequence->count == 1) {
// Viewed as pairs of 32-bit halves; assumes little-endian layout, so index 1 is the upper half.
sz_u32_t *order_half_words = (sz_u32_t *)sequence->order;
order_half_words[1] = 0;
return;
}
// Partition a range of integers according to a specific bit value
sz_size_t split = 0;
// Bit 63 for `bit_idx == 0`, bit 62 for `bit_idx == 1`, and so on.
sz_u64_t mask = (1ull << 63) >> bit_idx;
// The clean approach would be to perform a single pass over the sequence.
//
// while (split != sequence->count && !(sequence->order[split] & mask)) ++split;
// for (sz_size_t i = split + 1; i < sequence->count; ++i)
// if (!(sequence->order[i] & mask)) sz_u64_swap(sequence->order + i, sequence->order + split), ++split;
//
// This, however, doesn't take into account the high relative cost of writes and swaps.
// To circumvent that, we can first count the total number entries to be mapped into either part.
// And then walk through both parts, swapping the entries that are in the wrong part.
// This would often lead to ~15% performance gain.
sz_size_t count_with_bit_set = 0;
for (sz_size_t i = 0; i != sequence->count; ++i) count_with_bit_set += (sequence->order[i] & mask) != 0;
// Entries with the bit cleared go first, so this is where the second part will begin.
split = sequence->count - count_with_bit_set;
// It's possible that the sequence is already partitioned.
if (split != 0 && split != sequence->count) {
// Use two pointers to efficiently reposition elements.
// One pointer walks left-to-right from the start, and the other walks right-to-left from the end.
sz_size_t left = 0;
sz_size_t right = sequence->count - 1;
while (1) {
// Find the next element with the bit set on the left side.
while (left < split && !(sequence->order[left] & mask)) ++left;
// Find the next element without the bit set on the right side.
// As `split` is non-zero here, `right` can't wrap around below zero.
while (right >= split && (sequence->order[right] & mask)) --right;
// Swap the mispositioned elements.
if (left < split && right >= split) {
sz_u64_swap(sequence->order + left, sequence->order + right);
++left;
--right;
}
else { break; }
}
}
// Go down recursively.
if (bit_idx < bit_max) {
// First part: entries with the bit cleared.
sz_sequence_t a = *sequence;
a.count = split;
sz_sort_recursion(&a, bit_idx + 1, bit_max, comparator, partial_order_length);
// Second part: entries with the bit set.
sz_sequence_t b = *sequence;
b.order += split;
b.count -= split;
sz_sort_recursion(&b, bit_idx + 1, bit_max, comparator, partial_order_length);
}
// Reached the end of recursion.
else {
// Discard the prefixes.
// Clears the second 32-bit half of every entry, leaving only the first one.
sz_u32_t *order_half_words = (sz_u32_t *)sequence->order;
for (sz_size_t i = 0; i != sequence->count; ++i) { order_half_words[i * 2 + 1] = 0; }
// Finish both parts with a comparison-based sort.
sz_sequence_t a = *sequence;
a.count = split;
sz_sort_introsort(&a, comparator);
sz_sequence_t b = *sequence;
b.order += split;
b.count -= split;
sz_sort_introsort(&b, comparator);
}
}
/**
 * @brief Default comparator for sorting: fetches both strings through the callbacks of the
 *        sequence, and reports whether the first one is ordered before the second one.
 *
 * @param sequence Sequence providing the `get_start` and `get_length` callbacks.
 * @param i_key    Key of the first string.
 * @param j_key    Key of the second string.
 * @return Non-zero if `sz_order_serial` reports `sz_less_k` for the pair.
 */
SZ_INTERNAL sz_bool_t _sz_sort_is_less(sz_sequence_t *sequence, sz_size_t i_key, sz_size_t j_key) {
    sz_cptr_t const first_text = sequence->get_start(sequence, i_key);
    sz_cptr_t const second_text = sequence->get_start(sequence, j_key);
    sz_size_t const first_length = sequence->get_length(sequence, i_key);
    sz_size_t const second_length = sequence->get_length(sequence, j_key);
    sz_ordering_t const ordering = sz_order_serial(first_text, first_length, second_text, second_length);
    return (sz_bool_t)(ordering == sz_less_k);
}
/**
 * @brief Sorts the sequence of strings. On little-endian targets the first characters of every
 *        string are packed into the upper bytes of its `order` entry, so that the radix pass of
 *        `sz_sort_recursion` can work on plain integers before any string is compared in full.
 *
 * @param sequence             Sequence to sort in place.
 * @param partial_order_length Forwarded to `sz_sort_recursion`; ignored on big-endian targets.
 */
SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t partial_order_length) {
#if SZ_DETECT_BIG_ENDIAN
    // TODO: Implement partial sort for big-endian systems. For now this sorts the whole thing.
    sz_unused(partial_order_length);
    sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less);
#else
    // Export up to 4 bytes of every string into the bits of its own `order` entry.
    for (sz_size_t i = 0; i != sequence->count; ++i) {
        sz_u64_t *const entry = &sequence->order[i];
        sz_cptr_t const text = sequence->get_start(sequence, *entry);
        sz_size_t text_length = sequence->get_length(sequence, *entry);
        if (text_length > 4u) text_length = 4u;
        // The first character lands in byte 7 of the entry, the second in byte 6, and so on.
        sz_ptr_t const entry_bytes = (sz_ptr_t)entry;
        for (sz_size_t j = 0; j != text_length; ++j) entry_bytes[7 - j] = text[j];
    }
    // Perform optionally-parallel radix sort on them
    sz_sort_recursion(sequence, 0, 32, (sz_sequence_comparator_t)_sz_sort_is_less, partial_order_length);
#endif
}
/**
 * @brief Sorts the sequence of strings: with a plain introsort on big-endian targets,
 *        and through `sz_sort_partial` over the whole sequence everywhere else.
 *
 * @param sequence Sequence to sort in place.
 */
SZ_PUBLIC void sz_sort(sz_sequence_t *sequence) {
#if !SZ_DETECT_BIG_ENDIAN
    sz_sort_partial(sequence, sequence->count);
#else
    sz_sort_introsort(sequence, (sz_sequence_comparator_t)_sz_sort_is_less);
#endif
}
#pragma endregion
/*
* @brief AVX2 implementation of the string search algorithms.
* Very minimalistic, but still faster than the serial implementation.
*/
#pragma region AVX2 Implementation
#if SZ_USE_X86_AVX2
#pragma GCC push_options
#pragma GCC target("avx2")
#pragma clang attribute push(__attribute__((target("avx2"))), apply_to = function)
#include <immintrin.h>
/**
 * @brief Helper structure to simplify work with 256-bit registers.
 *        All members overlay the same 32 bytes, so a value loaded into `ymm`
 *        can be inspected as halves, words, or individual bytes.
 */
typedef union sz_u256_vec_t {
__m256i ymm;       // The whole 256-bit register.
__m128i xmms[2];   // Two 128-bit halves.
sz_u64_t u64s[4];  // Four 64-bit lanes.
sz_u32_t u32s[8];  // Eight 32-bit lanes.
sz_u16_t u16s[16]; // Sixteen 16-bit lanes.
sz_u8_t u8s[32];   // Thirty-two bytes.
} sz_u256_vec_t;
/**
 * @brief AVX2 entry point for the three-way comparison; simply forwards to the serial implementation.
 *
 * @param a        First string.
 * @param a_length Length of the first string.
 * @param b        Second string.
 * @param b_length Length of the second string.
 * @return Whatever `sz_order_serial` reports for the pair.
 */
SZ_PUBLIC sz_ordering_t sz_order_avx2(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) {
    //! Before optimizing this, read the "Operations Not Worth Optimizing" in Contributions Guide:
    //! https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations
    sz_ordering_t const ordering = sz_order_serial(a, a_length, b, b_length);
    return ordering;
}
/**
 * @brief Checks two buffers of the same length for equality, 32 bytes at a time.
 *        The remainder shorter than 32 bytes is handed over to `sz_equal_serial`.
 *
 * @param a      First buffer.
 * @param b      Second buffer.
 * @param length Number of bytes to compare.
 * @return `sz_true_k` if all the bytes match, `sz_false_k` at the first differing chunk.
 */
SZ_PUBLIC sz_bool_t sz_equal_avx2(sz_cptr_t a, sz_cptr_t b, sz_size_t length) {
    sz_u256_vec_t lhs, rhs;
    for (; length >= 32; a += 32, b += 32, length -= 32) {
        lhs.ymm = _mm256_lddqu_si256((__m256i const *)a);
        rhs.ymm = _mm256_lddqu_si256((__m256i const *)b);
        // One approach can be to use "movemasks", but we could also use a bitwise matching like
        // `_mm256_testnzc_si256`. The movemask has one bit per byte, set where the bytes are
        // equal, so its complement is non-zero as soon as a single byte differs.
        int const mismatches = ~_mm256_movemask_epi8(_mm256_cmpeq_epi8(lhs.ymm, rhs.ymm));
        if (mismatches != 0) return sz_false_k;
    }
    return length ? sz_equal_serial(a, b, length) : sz_true_k;
}
/**
 *  @brief  Writes @p value into each of the @p length bytes starting at @p target.
 *
 *  Buffers of up to 32 bytes go to `sz_fill_serial`. Longer buffers are split into three parts so that
 *  every vector store is aligned: a head of at most 31 bytes that brings @p target to a 32-byte boundary,
 *  a body made of whole 32-byte blocks, and a tail of at most 31 bytes.
 */
SZ_PUBLIC void sz_fill_avx2(sz_ptr_t target, sz_size_t length, sz_u8_t value) {
    if (length <= 32) {
        sz_fill_serial(target, length, value);
        return;
    }

    // Replicate the byte into every width we are going to store.
    char const value_char = *(char *)&value;
    __m256i const fill_ymm = _mm256_set1_epi8(value_char);
    __m128i const fill_xmm = _mm_set1_epi8(value_char);
    sz_u16_t const fill_u16 = (sz_u16_t)value * 0x0101u;
    sz_u32_t const fill_u32 = (sz_u32_t)fill_u16 * 0x00010001u;
    sz_u64_t const fill_u64 = (sz_u64_t)fill_u32 * 0x0000000100000001ull;

    sz_size_t const head_length = (32 - ((sz_size_t)target % 32)) % 32; // 31 or less.
    sz_size_t const tail_length = (sz_size_t)(target + length) % 32;    // 31 or less.
    sz_size_t body_length = length - head_length - tail_length;         // Multiple of 32.

    // Head: walk the bits of `head_length` from the lowest up, so that every store is naturally aligned.
    if (head_length & 1) {
        *(sz_u8_t *)target = value;
        target += 1;
    }
    if (head_length & 2) {
        *(sz_u16_t *)target = fill_u16;
        target += 2;
    }
    if (head_length & 4) {
        *(sz_u32_t *)target = fill_u32;
        target += 4;
    }
    if (head_length & 8) {
        *(sz_u64_t *)target = fill_u64;
        target += 8;
    }
    if (head_length & 16) {
        _mm_store_si128((__m128i *)target, fill_xmm);
        target += 16;
    }
    sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size.");

    // Body: aligned 32-byte stores.
    while (body_length >= 32) {
        _mm256_store_si256((__m256i *)target, fill_ymm);
        target += 32;
        body_length -= 32;
    }
    sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size.");

    // Tail: walk the bits of `tail_length` from the highest down, again keeping every store aligned.
    if (tail_length & 16) {
        _mm_store_si128((__m128i *)target, fill_xmm);
        target += 16;
    }
    if (tail_length & 8) {
        *(sz_u64_t *)target = fill_u64;
        target += 8;
    }
    if (tail_length & 4) {
        *(sz_u32_t *)target = fill_u32;
        target += 4;
    }
    if (tail_length & 2) {
        *(sz_u16_t *)target = fill_u16;
        target += 2;
    }
    if (tail_length & 1) { *(sz_u8_t *)target = value; }
}
/**
 *  @brief  Copies @p length bytes from @p source to @p target using AVX2 loads and stores.
 *
 *  The naive implementation of this function would be very simple, assuming the CPU is great
 *  at handling unaligned "stores" and "loads":
 *
 *      for (; length >= 32; target += 32, source += 32, length -= 32)
 *          _mm256_storeu_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source));
 *      sz_copy_serial(target, source, length);
 *
 *  Instead, three paths are used:
 *  - up to 32 bytes: forwarded to `sz_copy_serial`;
 *  - both pointers 32-byte aligned and at most 1 MB: aligned loads and aligned stores;
 *  - everything else: a head of at most 31 bytes aligns @p target, the body combines unaligned loads
 *    with aligned stores, and a tail of at most 31 bytes finishes the job. Bodies above 1 MB are
 *    traversed from both ends towards the middle.
 *
 *  @param target  Destination buffer.
 *  @param source  Source buffer.
 *  @param length  Number of bytes to copy.
 */
SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
    // A typical AWS Skylake instance can have 32 KB x 2 blocks of L1 data cache per core,
    // 1 MB x 2 blocks of L2 cache per core, and one shared L3 cache buffer.
    // For now, let's avoid the cases beyond the L2 size.
    int is_huge = length > 1ull * 1024ull * 1024ull;
    if (length <= 32) { sz_copy_serial(target, source, length); }
    // If we are lucky and both buffers are aligned, we can use aligned loads and stores.
    else if ((sz_size_t)target % 32 == 0 && (sz_size_t)source % 32 == 0 && !is_huge) {
        for (; length >= 32; target += 32, source += 32, length -= 32)
            _mm256_store_si256((__m256i *)target, _mm256_load_si256((__m256i const *)source));
        if (length) sz_copy_serial(target, source, length);
    }
    // Otherwise copy enough bytes to bring `target` to a 32-byte boundary,
    // and then combine unaligned loads with aligned stores.
    else {
        sz_size_t head_length = (32 - ((sz_size_t)target % 32)) % 32; // 31 or less.
        sz_size_t tail_length = (sz_size_t)(target + length) % 32;    // 31 or less.
        sz_size_t body_length = length - head_length - tail_length;   // Multiple of 32.

        // Copy the head, consuming the bits of `head_length` from the lowest up,
        // so that every store into `target` is naturally aligned.
        if (head_length & 1) {
            *(sz_u8_t *)target = *(sz_u8_t const *)source;
            target += 1, source += 1, head_length -= 1;
        }
        if (head_length & 2) {
            *(sz_u16_t *)target = *(sz_u16_t const *)source;
            target += 2, source += 2, head_length -= 2;
        }
        if (head_length & 4) {
            *(sz_u32_t *)target = *(sz_u32_t const *)source;
            target += 4, source += 4, head_length -= 4;
        }
        if (head_length & 8) {
            *(sz_u64_t *)target = *(sz_u64_t const *)source;
            target += 8, source += 8, head_length -= 8;
        }
        if (head_length & 16) {
            _mm_store_si128((__m128i *)target, _mm_lddqu_si128((__m128i const *)source));
            target += 16, source += 16, head_length -= 16;
        }
        sz_assert(head_length == 0 && "The head length should be zero after the head copy.");
        sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size.");

        // Copy the aligned body of the buffer.
        if (!is_huge) {
            for (; body_length >= 32; target += 32, source += 32, body_length -= 32)
                _mm256_store_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source));
        }
        // When the buffer is huge, we can traverse it in 2 directions.
        else {
            // Counts the bytes already copied at the far end, so the pointers can jump over them afterwards.
            sz_size_t tails_bytes_skipped = 0;
            for (; body_length >= 64; target += 32, source += 32, body_length -= 64, tails_bytes_skipped += 32) {
                _mm256_store_si256((__m256i *)(target), _mm256_lddqu_si256((__m256i const *)(source)));
                _mm256_store_si256((__m256i *)(target + body_length - 32),
                                   _mm256_lddqu_si256((__m256i const *)(source + body_length - 32)));
            }
            if (body_length) {
                sz_assert(body_length == 32 && "The only remaining body length should be 32 bytes.");
                _mm256_store_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source));
                target += 32, source += 32, body_length -= 32;
            }
            target += tails_bytes_skipped;
            source += tails_bytes_skipped;
        }

        // Copy the tail, consuming the bits of `tail_length` from the highest down.
        sz_assert((sz_size_t)target % 32 == 0 && "Target is supposed to be aligned to the YMM register size.");
        if (tail_length & 16) {
            _mm_store_si128((__m128i *)target, _mm_lddqu_si128((__m128i const *)source));
            target += 16, source += 16, tail_length -= 16;
        }
        if (tail_length & 8) {
            *(sz_u64_t *)target = *(sz_u64_t const *)source;
            target += 8, source += 8, tail_length -= 8;
        }
        if (tail_length & 4) {
            *(sz_u32_t *)target = *(sz_u32_t const *)source;
            target += 4, source += 4, tail_length -= 4;
        }
        if (tail_length & 2) {
            *(sz_u16_t *)target = *(sz_u16_t const *)source;
            target += 2, source += 2, tail_length -= 2;
        }
        if (tail_length & 1) {
            *(sz_u8_t *)target = *(sz_u8_t const *)source;
            target += 1, source += 1, tail_length -= 1;
        }
        sz_assert(tail_length == 0 && "The tail length should be zero after the tail copy.");
    }
}
/**
 *  @brief  Copies @p length bytes from @p source to @p target, where the two ranges may overlap.
 *
 *  When @p target starts inside the source range, the copy runs from the last byte towards the first;
 *  in every other case it runs front to back. Both directions move 32 bytes at a time with unaligned
 *  AVX2 loads and stores, and finish byte-by-byte.
 */
SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
    if (target >= source && target < source + length) {
        // Jump to the end and walk backwards.
        sz_ptr_t target_end = target + length;
        sz_cptr_t source_end = source + length;
        while (length >= 32) {
            target_end -= 32;
            source_end -= 32;
            _mm256_storeu_si256((__m256i *)target_end, _mm256_lddqu_si256((__m256i const *)source_end));
            length -= 32;
        }
        for (; length; --length) *(--target_end) = *(--source_end);
    }
    else {
        while (length >= 32) {
            _mm256_storeu_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source));
            target += 32;
            source += 32;
            length -= 32;
        }
        for (; length; --length) *(target++) = *(source++);
    }
}
/**
 *  @brief  Computes the sum of all bytes in @p text, each taken as an unsigned value, using AVX2.
 *
 *  Three paths are selected by @p length:
 *  - up to 32 bytes: forwarded to `sz_checksum_serial`;
 *  - up to 1 MB: a single forward pass with unaligned 32-byte loads;
 *  - above 1 MB: the unaligned head and tail are summed byte-by-byte, while the 32-byte-aligned body
 *    is consumed with non-temporal loads from both ends towards the middle.
 *
 *  @param text    Start of the buffer.
 *  @param length  Number of bytes in the buffer.
 *  @return        The 64-bit sum of the byte values.
 */
SZ_PUBLIC sz_u64_t sz_checksum_avx2(sz_cptr_t text, sz_size_t length) {
    // A typical AWS Skylake instance can have 32 KB x 2 blocks of L1 data cache per core,
    // 1 MB x 2 blocks of L2 cache per core, and one shared L3 cache buffer.
    // For now, let's avoid the cases beyond the L2 size.
    int is_huge = length > 1ull * 1024ull * 1024ull;

    // When the buffer is small, there isn't much to innovate.
    if (length <= 32) { return sz_checksum_serial(text, length); }
    else if (!is_huge) {
        sz_u256_vec_t text_vec, sums_vec;
        sums_vec.ymm = _mm256_setzero_si256();
        for (; length >= 32; text += 32, length -= 32) {
            text_vec.ymm = _mm256_lddqu_si256((__m256i const *)text);
            // The sum of absolute differences against zero adds up each group of 8 unsigned bytes.
            sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256()));
        }
        // Accumulating 256 bits is harder, as we need to extract the 128-bit sums first.
        __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm);
        __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1);
        __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm);
        sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm);
        sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1);
        sz_u64_t result = low + high;
        if (length) result += sz_checksum_serial(text, length);
        return result;
    }
    // For gigantic buffers, we avoid populating the cache with the entire buffer by using
    // non-temporal loads, and traverse the aligned part in 2 directions.
    else {
        // Bytes must be accumulated as unsigned values to agree with `_mm256_sad_epu8` used for the body.
        sz_u8_t const *text_u8 = (sz_u8_t const *)text;
        sz_size_t head_length = (32 - ((sz_size_t)text % 32)) % 32; // 31 or less.
        sz_size_t tail_length = (sz_size_t)(text + length) % 32;    // 31 or less.
        sz_size_t body_length = length - head_length - tail_length; // Multiple of 32.
        sz_u64_t result = 0;

        // Handle the head.
        while (head_length--) result += *text_u8++;

        // The body spans `[body_front, body_back)`; the tail begins right where the body ends.
        sz_u8_t const *body_front = text_u8;
        sz_u8_t const *body_back = text_u8 + body_length;
        sz_u8_t const *tail = body_back;

        sz_u256_vec_t text_vec, sums_vec, text_reversed_vec, sums_reversed_vec;
        sums_vec.ymm = _mm256_setzero_si256();
        sums_reversed_vec.ymm = _mm256_setzero_si256();

        // Each step takes one 32-byte block from the front and one from the back.
        for (; body_length >= 64; body_front += 32, body_length -= 64) {
            body_back -= 32;
            text_vec.ymm = _mm256_stream_load_si256((__m256i const *)body_front);
            sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256()));
            text_reversed_vec.ymm = _mm256_stream_load_si256((__m256i const *)body_back);
            sums_reversed_vec.ymm = _mm256_add_epi64(
                sums_reversed_vec.ymm, _mm256_sad_epu8(text_reversed_vec.ymm, _mm256_setzero_si256()));
        }
        // An odd number of blocks leaves exactly one in the middle.
        if (body_length >= 32) {
            text_vec.ymm = _mm256_stream_load_si256((__m256i const *)body_front);
            sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, _mm256_sad_epu8(text_vec.ymm, _mm256_setzero_si256()));
        }
        sums_vec.ymm = _mm256_add_epi64(sums_vec.ymm, sums_reversed_vec.ymm);

        // Handle the tail.
        while (tail_length--) result += *tail++;

        // Accumulating 256 bits is harder, as we need to extract the 128-bit sums first.
        __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymm);
        __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymm, 1);
        __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm);
        sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm);
        sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1);
        result += low + high;
        return result;
    }
}
/**
 *  @brief  Replaces every byte of @p source with the byte found at that index in the 256-entry @p lut,
 *          writing the results into @p target.
 *
 *  Inputs of up to 128 bytes go straight to `sz_look_up_transform_serial`. Longer inputs are processed
 *  32 bytes at a time: the table is kept as sixteen 16-byte rows, the low nibble of each input byte
 *  indexes inside a row via an in-lane shuffle, and the four high bits pick the row through a binary
 *  tree of blends. The remaining tail of fewer than 32 bytes is finished by the serial routine.
 */
SZ_PUBLIC void sz_look_up_transform_avx2(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) {
    // If the input is tiny, organizing the SIMD registers costs more than the actual computation.
    if (length <= 128) {
        sz_look_up_transform_serial(source, length, lut, target);
        return;
    }

    // AVX2 only shuffles bytes within 128-bit lanes, so every 16-byte row of the table
    // is replicated into both lanes of its own YMM register.
    __m256i lut_rows[16];
    for (sz_size_t row = 0; row != 16; ++row)
        lut_rows[row] = _mm256_broadcastsi128_si256(_mm_lddqu_si128((__m128i const *)(lut + row * 16)));

    for (; length >= 32; source += 32, target += 32, length -= 32) {
        __m256i const bytes = _mm256_lddqu_si256((__m256i const *)source);
        __m256i const low_nibbles = _mm256_and_si256(bytes, _mm256_set1_epi8((char)0x0F));

        // Each mask is all-ones in the bytes where the tested bit is NOT set.
        __m256i const bit_0x10_clear =
            _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_set1_epi8((char)0x10), bytes), _mm256_setzero_si256());
        __m256i const bit_0x20_clear =
            _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_set1_epi8((char)0x20), bytes), _mm256_setzero_si256());
        __m256i const bit_0x40_clear =
            _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_set1_epi8((char)0x40), bytes), _mm256_setzero_si256());
        __m256i const bit_0x80_clear =
            _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_set1_epi8((char)0x80), bytes), _mm256_setzero_si256());

        // Level 1: look the low nibble up in all 16 rows and merge neighbouring rows by bit 0x10,
        // leaving 8 candidates that each cover 32 consecutive table entries.
        __m256i candidates[8];
        for (sz_size_t i = 0; i != 8; ++i)
            candidates[i] = _mm256_blendv_epi8(                           //
                _mm256_shuffle_epi8(lut_rows[2 * i + 1], low_nibbles),    //
                _mm256_shuffle_epi8(lut_rows[2 * i], low_nibbles),        //
                bit_0x10_clear);

        // Levels 2-4: halve the number of candidates with bits 0x20, 0x40, and 0x80.
        // Slot `i` is written only after slots `2 * i` and `2 * i + 1` have been read.
        for (sz_size_t i = 0; i != 4; ++i)
            candidates[i] = _mm256_blendv_epi8(candidates[2 * i + 1], candidates[2 * i], bit_0x20_clear);
        for (sz_size_t i = 0; i != 2; ++i)
            candidates[i] = _mm256_blendv_epi8(candidates[2 * i + 1], candidates[2 * i], bit_0x40_clear);
        __m256i const transformed = _mm256_blendv_epi8(candidates[1], candidates[0], bit_0x80_clear);

        // And dump the result into the target.
        _mm256_storeu_si256((__m256i *)target, transformed);
    }

    // Handle the tail.
    if (length) sz_look_up_transform_serial(source, length, lut, target);
}
/**
 *  @brief  Locates the first occurrence of the byte `n[0]` within the @p h_length bytes at @p h.
 *
 *  The haystack is scanned in 32-byte blocks from the front; the remaining bytes (fewer than 32)
 *  are searched by `sz_find_byte_serial`, whose result is returned as is.
 *
 *  @return Pointer to the first matching byte found by the vectorized scan, or the serial result.
 */
SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
    __m256i const needle_ymm = _mm256_set1_epi8(n[0]);
    for (; h_length >= 32; h += 32, h_length -= 32) {
        __m256i const block_ymm = _mm256_lddqu_si256((__m256i const *)h);
        // One bit per byte of the block; the lowest set bit is the earliest match.
        int const hits = _mm256_movemask_epi8(_mm256_cmpeq_epi8(block_ymm, needle_ymm));
        if (hits != 0) return h + sz_u32_ctz(hits);
    }
    return sz_find_byte_serial(h, h_length, n);
}
/**
 *  @brief  Locates the last occurrence of the byte `n[0]` within the @p h_length bytes at @p h.
 *
 *  The haystack is scanned in 32-byte blocks from the back; the remaining prefix (fewer than 32 bytes)
 *  is searched by `sz_rfind_byte_serial`, whose result is returned as is.
 *
 *  @return Pointer to the last matching byte found by the vectorized scan, or the serial result.
 */
SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
    __m256i const needle_ymm = _mm256_set1_epi8(n[0]);
    for (; h_length >= 32; h_length -= 32) {
        sz_cptr_t const block = h + h_length - 32;
        __m256i const block_ymm = _mm256_lddqu_si256((__m256i const *)block);
        // Bit 31 of the mask stands for the last byte of the block, so the number of leading zeros
        // is the distance of the latest match from the end of the block.
        int const hits = _mm256_movemask_epi8(_mm256_cmpeq_epi8(block_ymm, needle_ymm));
        if (hits != 0) return block + 31 - sz_u32_clz(hits);
    }
    return sz_rfind_byte_serial(h, h_length, n);
}
/**
 *  @brief  Locates the first occurrence of the needle @p n inside the haystack @p h.
 *
 *  Three characters of the needle, at offsets chosen by `_sz_locate_needle_anomalies`, are broadcast
 *  and compared against 32 candidate positions at once. Only positions where all three characters
 *  agree are verified with `sz_equal`. Once fewer than `n_length + 32` bytes remain, the rest of
 *  the search is delegated to `sz_find_serial`.
 *
 *  @param h         Haystack start.
 *  @param h_length  Haystack length in bytes.
 *  @param n         Needle start.
 *  @param n_length  Needle length in bytes.
 *  @return          Pointer to the first match, `SZ_NULL_CHAR` when the needle is empty or longer
 *                   than the haystack, or whatever the byte-level or serial fallback returns.
 */
SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
    // This almost never fires, but it's better to be safe than sorry.
    if (h_length < n_length || !n_length) return SZ_NULL_CHAR;
    if (n_length == 1) return sz_find_byte_avx2(h, h_length, n);

    // Pick the parts of the needle that are worth comparing.
    sz_size_t offset_first, offset_mid, offset_last;
    _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last);

    // Broadcast those characters into YMM registers.
    sz_u256_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec;
    n_first_vec.ymm = _mm256_set1_epi8(n[offset_first]);
    n_mid_vec.ymm = _mm256_set1_epi8(n[offset_mid]);
    n_last_vec.ymm = _mm256_set1_epi8(n[offset_last]);

    // Scan through the string.
    for (; h_length >= n_length + 32; h += 32, h_length -= 32) {
        h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_first));
        h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_mid));
        h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_last));
        // The mask is kept unsigned: clearing the lowest set bit with `matches - 1` would overflow
        // a signed integer when only the top bit is set.
        sz_u32_t matches = (sz_u32_t)_mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) &
                           (sz_u32_t)_mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) &
                           (sz_u32_t)_mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm));
        while (matches) {
            int potential_offset = sz_u32_ctz(matches);
            if (sz_equal(h + potential_offset, n, n_length)) return h + potential_offset;
            matches &= matches - 1;
        }
    }
    return sz_find_serial(h, h_length, n, n_length);
}
/**
 *  @brief  Locates the last occurrence of the needle @p n inside the haystack @p h.
 *
 *  Mirrors the forward search: three characters of the needle, at offsets chosen by
 *  `_sz_locate_needle_anomalies`, are compared against the 32 right-most candidate positions at once.
 *  Candidates are verified with `sz_equal` starting from the right-most one. Once fewer than
 *  `n_length + 32` bytes remain, the rest of the search is delegated to `sz_rfind_serial`.
 *
 *  @param h         Haystack start.
 *  @param h_length  Haystack length in bytes.
 *  @param n         Needle start.
 *  @param n_length  Needle length in bytes.
 *  @return          Pointer to the last match, `SZ_NULL_CHAR` when the needle is empty or longer
 *                   than the haystack, or whatever the byte-level or serial fallback returns.
 */
SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
    // This almost never fires, but it's better to be safe than sorry.
    if (h_length < n_length || !n_length) return SZ_NULL_CHAR;
    if (n_length == 1) return sz_rfind_byte_avx2(h, h_length, n);

    // Pick the parts of the needle that are worth comparing.
    sz_size_t offset_first, offset_mid, offset_last;
    _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last);

    // Broadcast those characters into YMM registers.
    sz_u256_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec;
    n_first_vec.ymm = _mm256_set1_epi8(n[offset_first]);
    n_mid_vec.ymm = _mm256_set1_epi8(n[offset_mid]);
    n_last_vec.ymm = _mm256_set1_epi8(n[offset_last]);

    // Scan through the string.
    sz_cptr_t h_reversed;
    for (; h_length >= n_length + 32; h_length -= 32) {
        // The 32 candidates start at `h_reversed`; the last of them ends exactly at the haystack end.
        h_reversed = h + h_length - n_length - 32 + 1;
        h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_first));
        h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_mid));
        h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_last));
        // The mask is kept unsigned: clearing bit 31 would otherwise require shifting a one
        // into the sign bit of a signed integer.
        sz_u32_t matches = (sz_u32_t)_mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) &
                           (sz_u32_t)_mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) &
                           (sz_u32_t)_mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm));
        while (matches) {
            // Leading zeros measure how far the candidate is from the right-most position.
            int potential_offset = sz_u32_clz(matches);
            if (sz_equal(h + h_length - n_length - potential_offset, n, n_length))
                return h + h_length - n_length - potential_offset;
            matches &= ~((sz_u32_t)1 << (31 - potential_offset));
        }
    }
    return sz_rfind_serial(h, h_length, n, n_length);
}
/**
 *  @brief  Locates the first byte of @p text that belongs to the character set @p filter.
 *
 *  This is a transposed variant of the "SIMDized check which bytes are in a set" approach
 *  described by Wojciech Muła: http://0x80.pl/articles/simd-byte-lookup.html#alternative-implementation-new
 *  For a single byte the logic is equivalent to:
 *
 *      sz_u8_t lo_nibble = input & 0x0f;
 *      sz_u8_t hi_nibble = input >> 4;
 *      sz_u8_t bitset = filter->_u8s[hi_nibble * 2 + (lo_nibble < 8 ? 0 : 1)];
 *      sz_u8_t bitmask = (1 << (lo_nibble & 0x7));
 *      if ((bitset & bitmask) != 0) return text;
 *
 *  The text is scanned in 32-byte blocks; the remaining bytes (fewer than 32) are searched by
 *  `sz_find_charset_serial`, whose result is returned as is.
 */
SZ_PUBLIC sz_cptr_t sz_find_charset_avx2(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) {
    // Unzip the even and odd bytes of the bitset and replicate each group into both 128-bit lanes,
    // so that the in-lane `_mm256_shuffle_epi8` can use the same indices in either lane.
    sz_u256_vec_t even_bytes_vec, odd_bytes_vec;
    for (sz_size_t i = 0; i != 16; ++i) {
        even_bytes_vec.u8s[i] = filter->_u8s[i * 2];
        odd_bytes_vec.u8s[i] = filter->_u8s[i * 2 + 1];
    }
    even_bytes_vec.xmms[1] = even_bytes_vec.xmms[0];
    odd_bytes_vec.xmms[1] = odd_bytes_vec.xmms[0];

    // Maps a nibble `x` to the single-bit mask `1 << (x & 7)`.
    __m256i const bit_lookup = _mm256_setr_epi8(1, 2, 4, 8, 16, 32, 64, -128, 1, 2, 4, 8, 16, 32, 64, -128, //
                                                1, 2, 4, 8, 16, 32, 64, -128, 1, 2, 4, 8, 16, 32, 64, -128);

    for (; length >= 32; text += 32, length -= 32) {
        __m256i const chars = _mm256_lddqu_si256((__m256i const *)text);
        __m256i const low_nibbles = _mm256_and_si256(chars, _mm256_set1_epi8(0x0f));
        // There is no per-byte right shift, so shift 16-bit words and mask away the bits
        // that leaked in from the neighbouring byte.
        __m256i const high_nibbles = _mm256_and_si256(_mm256_srli_epi16(chars, 4), _mm256_set1_epi8(0x0f));

        // The low nibble selects the bit to test, the high nibble selects a pair of bitset bytes.
        __m256i const bit_to_test = _mm256_shuffle_epi8(bit_lookup, low_nibbles);
        __m256i const even_bitsets = _mm256_shuffle_epi8(even_bytes_vec.ymm, high_nibbles);
        __m256i const odd_bitsets = _mm256_shuffle_epi8(odd_bytes_vec.ymm, high_nibbles);

        // Low nibbles below 8 live in the even byte of the pair, the others in the odd byte.
        __m256i const use_even = _mm256_cmpgt_epi8(_mm256_set1_epi8(8), low_nibbles);
        __m256i const bitsets = _mm256_blendv_epi8(odd_bitsets, even_bitsets, use_even);

        // No instruction broadcasts a tested bit into the whole byte, so we find the bytes
        // where the test came out zero and invert the resulting mask.
        __m256i const misses = _mm256_cmpeq_epi8(_mm256_and_si256(bitsets, bit_to_test), _mm256_setzero_si256());
        int const hits = ~_mm256_movemask_epi8(misses);
        if (hits != 0) return text + sz_u32_ctz(hits);
    }
    return sz_find_charset_serial(text, length, filter);
}
/**
 *  @brief  AVX2-named entry point for the reverse character-set search.
 *          Contains no SIMD code of its own: all three arguments are forwarded to
 *          `sz_rfind_charset_serial` and its result is returned unchanged.
 */
SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx2(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) {
return sz_rfind_charset_serial(text, length, filter);
}
/**
 *  @brief  Multiplies four pairs of 64-bit integers, keeping the low 64 bits of every product.
 *          There is no AVX2 instruction for that, so the product is assembled from 32-bit multiplications.
 *          This implementation is coming from Agner Fog's Vector Class Library.
 *
 *  With `a = aH:aL` and `b = bH:bL` split into 32-bit halves, the low 64 bits of the product are
 *  `aL * bL + ((aL * bH + aH * bL) << 32)`.
 */
SZ_INTERNAL __m256i _mm256_mul_epu64(__m256i a, __m256i b) {
    // Full 64-bit products of the low halves: `aL * bL`.
    __m256i const low_by_low = _mm256_mul_epu32(a, b);
    // Swap the halves of `b` inside every 64-bit lane, then take the low 32 bits
    // of the cross products `aL * bH` and `aH * bL`.
    __m256i const b_halves_swapped = _mm256_shuffle_epi32(b, 0xB1);
    __m256i const cross_products = _mm256_mullo_epi32(a, b_halves_swapped);
    // Horizontal addition against zeros sums each pair of cross products,
    // leaving the sums in the two lowest 32-bit slots of each 128-bit lane.
    __m256i const cross_sums = _mm256_hadd_epi32(cross_products, _mm256_setzero_si256());
    // Move each sum into the upper half of its 64-bit lane, which amounts to a shift left by 32 bits.
    __m256i const cross_sums_high = _mm256_shuffle_epi32(cross_sums, 0x73);
    return _mm256_add_epi64(low_by_low, cross_sums_high);
}
/**
 *  @brief  Computes rolling hashes over windows of @p window_length bytes and reports them via @p callback.
 *
 *  Inputs shorter than one window, or a zero window length, produce no callbacks. Inputs shorter than four
 *  windows are forwarded to `sz_hashes_serial`. Otherwise the text is cut into four overlapping slices
 *  that are hashed in parallel, one per 64-bit lane of a YMM register. Every reported value is the XOR
 *  of two polynomial hashes (bases 31 and 257), each multiplied by a golden-ratio constant.
 *
 *  @param start            Start of the text.
 *  @param length           Length of the text in bytes.
 *  @param window_length    Number of bytes covered by every hash.
 *  @param step             Hashes are reported only on cycles where `cycle & (step - 1)` is zero.
 *                          NOTE(review): that equals "every `step`-th cycle" only if `step` is
 *                          a power of two - confirm that callers guarantee it.
 *  @param callback         Invoked with a slice pointer, @p window_length, the hash, and @p callback_handle.
 *                          On each reporting cycle it is called four times, in slice order.
 *  @param callback_handle  Opaque pointer passed through to @p callback.
 */
SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, //
sz_hash_callback_t callback, void *callback_handle) {
if (length < window_length || !window_length) return;
if (length < 4 * window_length) {
sz_hashes_serial(start, length, window_length, step, callback, callback_handle);
return;
}
// Using AVX2, we can perform 4 long integer multiplications and additions within one register.
// So let's slice the entire string into 4 overlapping windows, to slide over them in parallel.
sz_size_t const max_hashes = length - window_length + 1;
sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads.
sz_u8_t const *text_first = (sz_u8_t const *)start;
sz_u8_t const *text_second = text_first + min_hashes_per_thread;
sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2;
sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3;
sz_u8_t const *text_end = text_first + length;
// Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic.
// The loop below runs `window_length - 1` times, so the resulting exponent is `window_length - 1`.
sz_u64_t prime_power_low = 1, prime_power_high = 1;
for (sz_size_t i = 0; i + 1 < window_length; ++i)
prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME,
prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME;
// Broadcast the constants into the registers.
sz_u256_vec_t prime_vec, golden_ratio_vec;
sz_u256_vec_t base_low_vec, base_high_vec, prime_power_low_vec, prime_power_high_vec, shift_high_vec;
base_low_vec.ymm = _mm256_set1_epi64x(31ull);
base_high_vec.ymm = _mm256_set1_epi64x(257ull);
shift_high_vec.ymm = _mm256_set1_epi64x(77ull);
prime_vec.ymm = _mm256_set1_epi64x(SZ_U64_MAX_PRIME);
golden_ratio_vec.ymm = _mm256_set1_epi64x(11400714819323198485ull);
prime_power_low_vec.ymm = _mm256_set1_epi64x(prime_power_low);
prime_power_high_vec.ymm = _mm256_set1_epi64x(prime_power_high);
// Compute the initial hash values for every one of the four windows.
sz_u256_vec_t hash_low_vec, hash_high_vec, hash_mix_vec, chars_low_vec, chars_high_vec;
hash_low_vec.ymm = _mm256_setzero_si256();
hash_high_vec.ymm = _mm256_setzero_si256();
for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end;
++text_first, ++text_second, ++text_third, ++text_fourth) {
// 1. Multiply the hashes by the base.
hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm);
hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm);
// 2. Load the four characters from `text_first`, `text_first + min_hashes_per_thread`,
// `text_first + min_hashes_per_thread * 2`, `text_first + min_hashes_per_thread * 3`.
chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]);
// The addition is byte-wise, so the shifted character wraps around within its lowest byte.
chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm);
// 3. Add the incoming characters.
hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm);
hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm);
// 4. Compute the modulo. Assuming there are only 59 values between our prime
// and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime.
// NOTE(review): `_mm256_cmpgt_epi64` compares the lanes as signed integers - confirm this is intended.
hash_low_vec.ymm = _mm256_blendv_epi8(hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm),
_mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm));
hash_high_vec.ymm = _mm256_blendv_epi8(hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm),
_mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm));
}
// 5. Compute the hash mix, that will be used to index into the fingerprint.
// This includes a serial step at the end.
// Note that the mixed values are written back into the running hash registers themselves.
hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm);
hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm);
hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm);
// The first hash of every slice is reported unconditionally, regardless of `step`.
callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle);
callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle);
callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle);
callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle);
// Now repeat that operation for the remaining characters, discarding older characters.
sz_size_t cycle = 1;
sz_size_t const step_mask = step - 1;
// All four slices advance in lockstep until the fourth one reaches the end of the text.
for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) {
// 0. Load again the four characters we are dropping, shift them, and subtract.
chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[-window_length], text_third[-window_length],
text_second[-window_length], text_first[-window_length]);
chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm);
hash_low_vec.ymm =
_mm256_sub_epi64(hash_low_vec.ymm, _mm256_mul_epu64(chars_low_vec.ymm, prime_power_low_vec.ymm));
hash_high_vec.ymm =
_mm256_sub_epi64(hash_high_vec.ymm, _mm256_mul_epu64(chars_high_vec.ymm, prime_power_high_vec.ymm));
// 1. Multiply the hashes by the base.
hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm);
hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm);
// 2. Load the four characters from `text_first`, `text_first + min_hashes_per_thread`,
// `text_first + min_hashes_per_thread * 2`, `text_first + min_hashes_per_thread * 3`.
chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]);
chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm);
// 3. Add the incoming characters.
hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm);
hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm);
// 4. Compute the modulo. Assuming there are only 59 values between our prime
// and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime.
hash_low_vec.ymm = _mm256_blendv_epi8(hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm),
_mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm));
hash_high_vec.ymm = _mm256_blendv_epi8(hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm),
_mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm));
// 5. Compute the hash mix, that will be used to index into the fingerprint.
// This includes a serial step at the end.
hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm);
hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm);
hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm);
// Report only the cycles selected by the `step` mask.
if ((cycle & step_mask) == 0) {
callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle);
callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle);
callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle);
callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle);
}
}
}
#pragma clang attribute pop
#pragma GCC pop_options
#endif
#pragma endregion
/*
* @brief AVX-512 implementation of the string search algorithms.
*
* Different subsets of AVX-512 were introduced in different years:
* * 2017 SkyLake: F, CD, ER, PF, VL, DQ, BW
* * 2018 CannonLake: IFMA, VBMI
* * 2019 IceLake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES
* * 2020 TigerLake: VP2INTERSECT
*/
#pragma region AVX - 512 Implementation
#if SZ_USE_X86_AVX512
#pragma GCC push_options
#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "bmi", "bmi2")
#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,bmi,bmi2"))), apply_to = function)
#include <immintrin.h>
/**
 * @brief Helper structure to simplify work with 512-bit registers.
 *
 * All members alias the same 64 bytes, so a value written through `zmm` can be read back
 * lane-by-lane through one of the scalar arrays, or as narrower SIMD registers.
 */
typedef union sz_u512_vec_t {
__m512i zmm; // The whole 512-bit register.
__m256i ymms[2]; // The same bytes as two 256-bit registers.
__m128i xmms[4]; // The same bytes as four 128-bit registers.
sz_u64_t u64s[8]; // Eight unsigned 64-bit lanes.
sz_u32_t u32s[16]; // Sixteen unsigned 32-bit lanes.
sz_u16_t u16s[32]; // Thirty-two unsigned 16-bit lanes.
sz_u8_t u8s[64]; // Sixty-four unsigned bytes.
sz_i64_t i64s[8]; // Eight signed 64-bit lanes.
sz_i32_t i32s[16]; // Sixteen signed 32-bit lanes.
} sz_u512_vec_t;
/**
 * @brief Builds a 64-bit mask with the lowest `n` bits set, treating any `n` of 64 or more as 64.
 */
SZ_INTERNAL __mmask64 _sz_u64_clamp_mask_until(sz_size_t n) {
    // With `n` known to be below 64, `(1ull << n) - 1` would be enough.
    // Here `n` is arbitrary, so clamp it first and let BZHI drop the bits at and above that index.
    sz_u32_t const bits_to_keep = n < 64 ? (sz_u32_t)n : 64;
    return _bzhi_u64(0xFFFFFFFFFFFFFFFF, bits_to_keep);
}
/**
 * @brief Builds a 32-bit mask with the lowest `n` bits set, treating any `n` of 32 or more as 32.
 */
SZ_INTERNAL __mmask32 _sz_u32_clamp_mask_until(sz_size_t n) {
    // With `n` known to be below 32, `(1u << n) - 1` would be enough.
    // Here `n` is arbitrary, so clamp it first and let BZHI drop the bits at and above that index.
    sz_u32_t const bits_to_keep = n < 32 ? (sz_u32_t)n : 32;
    return _bzhi_u32(0xFFFFFFFF, bits_to_keep);
}
/**
 * @brief Builds a 16-bit mask with the lowest `n` bits set, treating any `n` of 16 or more as 16.
 */
SZ_INTERNAL __mmask16 _sz_u16_clamp_mask_until(sz_size_t n) {
    // With `n` known to be below 16, `(1u << n) - 1` would be enough.
    // Here `n` is arbitrary, so clamp it first; the 32-bit BZHI result then fits the 16-bit mask.
    sz_u32_t const bits_to_keep = n < 16 ? (sz_u32_t)n : 16;
    return _bzhi_u32(0xFFFFFFFF, bits_to_keep);
}
/**
 * @brief Builds a 16-bit mask with the lowest `n` bits set.
 *
 * Unlike `_sz_u16_clamp_mask_until`, the index is not clamped here: `n` is truncated to 32 bits,
 * handed to BZHI as is, and the 32-bit result is then narrowed to 16 bits.
 */
SZ_INTERNAL __mmask16 _sz_u16_mask_until(sz_size_t n) {
    // Equivalent to `(1u << n) - 1` for small `n`, without a shift.
    sz_u32_t const low_bits = _bzhi_u32(0xFFFFFFFF, (sz_u32_t)n);
    return (__mmask16)low_bits;
}
/**
 * @brief Builds a 32-bit mask with the lowest `n` bits set.
 *
 * Unlike `_sz_u32_clamp_mask_until`, the index is not clamped here: `n` is truncated to 32 bits
 * and handed to BZHI as is.
 */
SZ_INTERNAL __mmask32 _sz_u32_mask_until(sz_size_t n) {
    // Equivalent to `(1u << n) - 1` for `n` below 32, without a shift.
    sz_u32_t const index = (sz_u32_t)n;
    return _bzhi_u32(0xFFFFFFFF, index);
}
/**
 * @brief Builds a 64-bit mask with the lowest `n` bits set.
 *
 * Unlike `_sz_u64_clamp_mask_until`, the index is not clamped here: `n` is truncated to 32 bits
 * and handed to BZHI as is.
 */
SZ_INTERNAL __mmask64 _sz_u64_mask_until(sz_size_t n) {
    // Equivalent to `(1ull << n) - 1` for `n` below 64, without a shift.
    sz_u32_t const index = (sz_u32_t)n;
    return _bzhi_u64(0xFFFFFFFFFFFFFFFF, index);
}
/**
 * @brief Lexicographic three-way comparison of two byte strings, AVX-512 variant.
 *
 * Bytes are compared as unsigned values. When no byte differs within the compared range,
 * the remaining lengths decide the outcome, so a proper prefix orders before the longer string.
 *
 * @param a         First string.
 * @param a_length  Number of bytes in `a`.
 * @param b         Second string.
 * @param b_length  Number of bytes in `b`.
 * @return `sz_equal_k` for identical inputs, otherwise the result of `_sz_order_scalars`
 *         applied to the first differing bytes or to the remaining lengths.
 */
SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) {
    sz_u512_vec_t a_vec, b_vec;

    // Pointer arithmetic is cheap, fetching memory is not!
    // So we can use the masked loads to fetch at most one cache-line for each string,
    // compare the prefixes, and only then move forward.
    sz_size_t a_head_length = 64 - ((sz_size_t)a % 64); // From 1 to 64.
    sz_size_t b_head_length = 64 - ((sz_size_t)b % 64); // From 1 to 64.
    a_head_length = a_head_length < a_length ? a_head_length : a_length;
    b_head_length = b_head_length < b_length ? b_head_length : b_length;
    sz_size_t head_length = a_head_length < b_head_length ? a_head_length : b_head_length;
    __mmask64 head_mask = _sz_u64_mask_until(head_length);
    a_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, a);
    b_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, b);
    __mmask64 mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm);
    if (mask_not_equal != 0) {
        sz_u64_t first_diff = _tzcnt_u64(mask_not_equal);
        // Keep the bytes unsigned: the signedness of plain `char` is implementation-defined,
        // and a signed comparison would order bytes at or above 0x80 before ASCII characters.
        sz_u8_t a_byte = a_vec.u8s[first_diff];
        sz_u8_t b_byte = b_vec.u8s[first_diff];
        return _sz_order_scalars(a_byte, b_byte);
    }
    else if (head_length == a_length && head_length == b_length) { return sz_equal_k; }
    else { a += head_length, b += head_length, a_length -= head_length, b_length -= head_length; }

    // The rare case, when both strings are very long.
    __mmask64 a_mask, b_mask;
    while ((a_length >= 64) & (b_length >= 64)) {
        a_vec.zmm = _mm512_loadu_si512(a);
        b_vec.zmm = _mm512_loadu_si512(b);
        mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm);
        if (mask_not_equal != 0) {
            sz_u64_t first_diff = _tzcnt_u64(mask_not_equal);
            sz_u8_t a_byte = a_vec.u8s[first_diff];
            sz_u8_t b_byte = b_vec.u8s[first_diff];
            return _sz_order_scalars(a_byte, b_byte);
        }
        a += 64, b += 64, a_length -= 64, b_length -= 64;
    }

    // In most common scenarios at least one of the strings is under 64 bytes.
    if (a_length | b_length) {
        a_mask = _sz_u64_clamp_mask_until(a_length);
        b_mask = _sz_u64_clamp_mask_until(b_length);
        a_vec.zmm = _mm512_maskz_loadu_epi8(a_mask, a);
        b_vec.zmm = _mm512_maskz_loadu_epi8(b_mask, b);
        // The AVX-512 `_mm512_mask_cmpneq_epi8_mask` intrinsics are generally handy in such environments.
        // They, however, have latency 3 on most modern CPUs. Using AVX2: `_mm256_cmpeq_epi8` would have
        // been cheaper, if we didn't have to apply `_mm256_movemask_epi8` afterwards.
        mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm);
        if (mask_not_equal != 0) {
            sz_u64_t first_diff = _tzcnt_u64(mask_not_equal);
            // Lanes beyond the shorter string were loaded as zeros. Comparing unsigned bytes guarantees
            // that such a padding lane never orders above a real byte of the longer string.
            sz_u8_t a_byte = a_vec.u8s[first_diff];
            sz_u8_t b_byte = b_vec.u8s[first_diff];
            return _sz_order_scalars(a_byte, b_byte);
        }
        // From logic perspective, the hardest cases are "abc\0" and "abc".
        // The result must be `sz_greater_k`, as the latter is shorter.
        else { return _sz_order_scalars(a_length, b_length); }
    }
    return sz_equal_k;
}
/**
 * @brief Checks two equally long byte ranges for equality, 64 bytes at a time.
 *
 * @param a       First byte range.
 * @param b       Second byte range.
 * @param length  Number of bytes to compare in each range.
 * @return `sz_true_k` if all `length` bytes match (including when `length` is zero), `sz_false_k` otherwise.
 */
SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length) {
    sz_u512_vec_t a_vec, b_vec;

    // Whole 64-byte chunks: bail out on the first chunk with any differing byte.
    for (; length >= 64; a += 64, b += 64, length -= 64) {
        a_vec.zmm = _mm512_loadu_si512(a);
        b_vec.zmm = _mm512_loadu_si512(b);
        if (_mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm) != 0) return sz_false_k;
    }
    if (length == 0) return sz_true_k;

    // Fewer than 64 bytes remain: load and compare only the valid lanes.
    __mmask64 const tail_lanes = _sz_u64_mask_until(length);
    a_vec.zmm = _mm512_maskz_loadu_epi8(tail_lanes, a);
    b_vec.zmm = _mm512_maskz_loadu_epi8(tail_lanes, b);
    __mmask64 const differing_lanes = _mm512_mask_cmpneq_epi8_mask(tail_lanes, a_vec.zmm, b_vec.zmm);
    return (sz_bool_t)(differing_lanes == 0);
}
/**
 * @brief Fills `length` bytes starting at `target` with the byte `value`.
 *
 * Buffers of up to 64 bytes are written with one masked store. Longer buffers are split into
 * a masked head that reaches a 64-byte boundary, an aligned body, and a masked tail.
 *
 * @param target  Destination buffer.
 * @param length  Number of bytes to write.
 * @param value   Byte to replicate.
 */
SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value) {
    __m512i const pattern = _mm512_set1_epi8(value);

    // A naive version would simply loop over unaligned 64-byte stores and finish with one masked store.
    // For a short buffer a single masked store is all we need.
    if (length <= 64) {
        _mm512_mask_storeu_epi8(target, _sz_u64_mask_until(length), pattern);
        return;
    }

    // Over 64 bytes we are guaranteed to touch at least two cache lines. Unaligned stores are avoided
    // by writing the ragged edges with masks and everything in-between with aligned stores.
    sz_size_t const head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less.
    sz_size_t const tail_length = (sz_size_t)(target + length) % 64;    // 63 or less.
    sz_size_t aligned_bytes_left = length - head_length - tail_length;  // Multiple of 64.

    _mm512_mask_storeu_epi8(target, _sz_u64_mask_until(head_length), pattern);
    target += head_length;
    while (aligned_bytes_left >= 64) {
        _mm512_store_si512(target, pattern);
        target += 64;
        aligned_bytes_left -= 64;
    }
    _mm512_mask_storeu_epi8(target, _sz_u64_mask_until(tail_length), pattern);
}
/**
 * @brief Copies `length` bytes from `source` to `target`.
 *
 * Four strategies are used, depending on size and alignment:
 *  - up to 64 bytes: one masked load and one masked store;
 *  - under 1 MB with both pointers 64-byte aligned: aligned loads and stores plus a masked tail;
 *  - under 1 MB otherwise: masked head up to the target's 64-byte boundary, unaligned loads
 *    with aligned stores for the body, masked tail;
 *  - 1 MB or more: masked head and tail, then streaming stores for the body, advancing from both ends.
 *
 * @param target  Destination buffer.
 * @param source  Source buffer.
 * @param length  Number of bytes to copy.
 */
SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
    // When the buffer is small, there isn't much to innovate.
    if (length <= 64) {
        __mmask64 const lanes = _sz_u64_mask_until(length);
        _mm512_mask_storeu_epi8(target, lanes, _mm512_maskz_loadu_epi8(lanes, source));
        return;
    }

    // A typical AWS Sapphire Rapids instance can have 48 KB x 2 blocks of L1 data cache per core,
    // 2 MB x 2 blocks of L2 cache per core, and one shared 60 MB buffer of L3 cache.
    // With two strings, we may consider the overall workload huge, if each exceeds 1 MB in length.
    int const is_huge = length >= 1ull * 1024ull * 1024ull;
    int const both_aligned = (sz_size_t)target % 64 == 0 && (sz_size_t)source % 64 == 0;

    // The lucky case: both buffers start on a cache-line boundary, so every full-width access is aligned.
    if (both_aligned && !is_huge) {
        while (length >= 64) {
            _mm512_store_si512(target, _mm512_load_si512(source));
            target += 64, source += 64, length -= 64;
        }
        // At this point the length is guaranteed to be under 64.
        __mmask64 const lanes = _sz_u64_mask_until(length);
        // Aligned load and stores would work too, but it's not defined.
        _mm512_mask_storeu_epi8(target, lanes, _mm512_maskz_loadu_epi8(lanes, source));
        return;
    }

    // Both remaining strategies split the work identically, relative to the alignment of `target`.
    sz_size_t const head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less.
    sz_size_t const tail_length = (sz_size_t)(target + length) % 64;    // 63 or less.
    sz_size_t body_length = length - head_length - tail_length;         // Multiple of 64.
    __mmask64 const head_mask = _sz_u64_mask_until(head_length);
    __mmask64 const tail_mask = _sz_u64_mask_until(tail_length);

    // The trickiest case is when `source` and `target` are not both aligned.
    // Copy enough bytes into `target` to reach its cache-line boundary,
    // and then combine unaligned loads with aligned stores.
    if (!is_huge) {
        _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source));
        target += head_length, source += head_length;
        while (body_length >= 64) {
            _mm512_store_si512(target, _mm512_loadu_si512(source)); // Unaligned load, but aligned store!
            target += 64, source += 64, body_length -= 64;
        }
        _mm512_mask_storeu_epi8(target, tail_mask, _mm512_maskz_loadu_epi8(tail_mask, source));
        return;
    }

    // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use.
    //
    //      1. Moving in both directions to maximize the throughput, when fetching from multiple
    //         memory pages. Also helps with cache set-associativity issues, as we won't always
    //         be fetching the same entries in the lookup table.
    //      2. Using non-temporal stores to avoid polluting the cache.
    //      3. Prefetching the next cache line, to avoid stalling the CPU. This is generally useless
    //         for predictable patterns, so disregard this advice.
    //
    // Bidirectional traversal adds about 10%, accelerating from 11 GB/s to 12 GB/s.
    // Using "streaming stores" boosts us from 12 GB/s to 19 GB/s.
    _mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source));
    _mm512_mask_storeu_epi8(target + head_length + body_length, tail_mask,
                            _mm512_maskz_loadu_epi8(tail_mask, source + head_length + body_length));

    // In the main loop every iteration writes one cache line at the front of the remaining body
    // and one at its back, using unaligned loads and streaming (non-temporal) stores.
    target += head_length, source += head_length;
    while (body_length >= 128) {
        _mm512_stream_si512((__m512i *)(target), _mm512_loadu_si512(source));
        _mm512_stream_si512((__m512i *)(target + body_length - 64), _mm512_loadu_si512(source + body_length - 64));
        target += 64, source += 64, body_length -= 128;
    }
    // An odd number of cache lines leaves exactly one in the middle.
    if (body_length >= 64) _mm512_stream_si512((__m512i *)target, _mm512_loadu_si512(source));
}
/**
 * @brief Copies `length` bytes from `source` to `target`, where the two regions may overlap.
 *
 * Buffers of up to 256 bytes are loaded entirely into registers before anything is stored.
 * Longer buffers that don't overlap are forwarded to `sz_copy_avx512`. Longer overlapping
 * buffers are traversed in the direction that reads every chunk before it can be overwritten.
 *
 * @param target  Destination buffer.
 * @param source  Source buffer.
 * @param length  Number of bytes to move.
 */
SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
if (target == source) return; // Don't be silly, don't move the data if it's already there.
// On very short buffers, that are one cache line in width or less, we don't need any loops.
// We can also avoid any data-dependencies between iterations, assuming we have 32 registers
// to pre-load the data, before writing it back.
if (length <= 64) {
__mmask64 mask = _sz_u64_mask_until(length);
_mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source));
}
// From 65 to 128 bytes: one full register and one masked register, both loaded before any store.
else if (length <= 128) {
sz_size_t last_length = length - 64;
__mmask64 mask = _sz_u64_mask_until(last_length);
__m512i source0 = _mm512_loadu_epi8(source);
__m512i source1 = _mm512_maskz_loadu_epi8(mask, source + 64);
_mm512_storeu_epi8(target, source0);
_mm512_mask_storeu_epi8(target + 64, mask, source1);
}
// From 129 to 192 bytes: two full registers and one masked register, all loaded before any store.
else if (length <= 192) {
sz_size_t last_length = length - 128;
__mmask64 mask = _sz_u64_mask_until(last_length);
__m512i source0 = _mm512_loadu_epi8(source);
__m512i source1 = _mm512_loadu_epi8(source + 64);
__m512i source2 = _mm512_maskz_loadu_epi8(mask, source + 128);
_mm512_storeu_epi8(target, source0);
_mm512_storeu_epi8(target + 64, source1);
_mm512_mask_storeu_epi8(target + 128, mask, source2);
}
// From 193 to 256 bytes: three full registers and one masked register, all loaded before any store.
else if (length <= 256) {
sz_size_t last_length = length - 192;
__mmask64 mask = _sz_u64_mask_until(last_length);
__m512i source0 = _mm512_loadu_epi8(source);
__m512i source1 = _mm512_loadu_epi8(source + 64);
__m512i source2 = _mm512_loadu_epi8(source + 128);
__m512i source3 = _mm512_maskz_loadu_epi8(mask, source + 192);
_mm512_storeu_epi8(target, source0);
_mm512_storeu_epi8(target + 64, source1);
_mm512_storeu_epi8(target + 128, source2);
_mm512_mask_storeu_epi8(target + 192, mask, source3);
}
// If the regions don't overlap at all, just use "copy" and save some brain cells thinking about corner cases.
// The first check is strict, so a `target` that ends exactly where `source` begins takes the branch below.
else if (target + length < source || target >= source + length) { sz_copy_avx512(target, source, length); }
// At this point the buffer is over 256 bytes, so it touches at least two cache lines - the head and tail,
// and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores
// by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked
// for the body.
else {
sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less.
sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less.
sz_size_t body_length = length - head_length - tail_length; // Multiple of 64.
__mmask64 head_mask = _sz_u64_mask_until(head_length);
__mmask64 tail_mask = _sz_u64_mask_until(tail_length);
// The absolute most common case of using "moves" is shifting the data within a continuous buffer
// when adding or removing some values in it. In such cases, a typical shift is by 1, 2, 4, 8, 16,
// or 32 bytes, rarely larger. For small shifts, under the size of the ZMM register, we can use shuffles.
//
// Remember:
// - if we are shifting data left, that we are traversing to the right.
// - if we are shifting data right, that we are traversing to the left.
int const left_to_right_traversal = source > target;
// Now we guarantee, that the relative shift within registers is from 1 to 63 bytes and the output is aligned.
// Hopefully, we need to shift more than two ZMM registers, so we could consider `valignr` instruction.
// Sadly, using `_mm512_alignr_epi8` doesn't make sense, as it operates at a 128-bit granularity.
//
// - `_mm256_alignr_epi8` shifts entire 256-bit register, but we need many of them.
// - `_mm512_alignr_epi32` shifts 512-bit chunks, but only if the `shift` is a multiple of 4 bytes.
// - `_mm512_alignr_epi64` shifts 512-bit chunks by 8 bytes.
//
// All of those have a latency of 1 cycle, and the shift amount must be an immediate value!
// For 1-byte-shift granularity, the `_mm512_permutex2var_epi8` has a latency of 6 and needs VBMI!
// The most efficient and broadly compatible alternative could be to use a combination of align and shuffle.
// A similar approach was outlined in "Byte-wise alignr in AVX512F" by Wojciech Muła.
// http://0x80.pl/notesen/2016-10-16-avx512-byte-alignr.html
//
// That solution is quite a mouthful, assuming we need compile time constants for the shift amount.
// A cleaner one, with a latency of 3 cycles, is to use `_mm512_permutexvar_epi8` or
// `_mm512_mask_permutexvar_epi8`, which can be seen as combination of a cross-register shuffle and blend,
// and is available with VBMI. That solution is still noticeably slower than AVX2.
//
// The GLibC implementation also uses non-temporal stores for larger buffers, we don't.
// https://codebrowser.dev/glibc/glibc/sysdeps/x86_64/multiarch/memmove-avx512-no-vzeroupper.S.html
if (left_to_right_traversal) {
// Head, body, and tail. The source lies above the target, so each chunk is loaded
// before any store reaches its address.
_mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source));
for (target += head_length, source += head_length; body_length >= 64;
target += 64, source += 64, body_length -= 64)
_mm512_store_si512(target, _mm512_loadu_si512(source));
_mm512_mask_storeu_epi8(target, tail_mask, _mm512_maskz_loadu_epi8(tail_mask, source));
}
else {
// Tail, body, and head. The source lies below the target, so we walk backwards;
// `target` and `source` stay fixed, while the shrinking `body_length` selects the chunk.
_mm512_mask_storeu_epi8(target + head_length + body_length, tail_mask,
_mm512_maskz_loadu_epi8(tail_mask, source + head_length + body_length));
for (; body_length >= 64; body_length -= 64)
_mm512_store_si512(target + head_length + body_length - 64,
_mm512_loadu_si512(source + head_length + body_length - 64));
_mm512_mask_storeu_epi8(target, head_mask, _mm512_maskz_loadu_epi8(head_mask, source));
}
}
}
/**
 * @brief Locates the first occurrence of the byte `n[0]` in the haystack.
 *
 * @param h         Haystack.
 * @param h_length  Number of bytes in the haystack.
 * @param n         Pointer to the byte to look for; only `n[0]` is read.
 * @return Pointer to the first matching byte, or `SZ_NULL_CHAR` if there is none.
 */
SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
    sz_u512_vec_t haystack_vec, needle_vec;
    needle_vec.zmm = _mm512_set1_epi8(n[0]);

    // Whole 64-byte chunks, front to back.
    for (; h_length >= 64; h += 64, h_length -= 64) {
        haystack_vec.zmm = _mm512_loadu_si512(h);
        __mmask64 const hits = _mm512_cmpeq_epi8_mask(haystack_vec.zmm, needle_vec.zmm);
        if (hits) return h + sz_u64_ctz(hits);
    }
    if (h_length == 0) return SZ_NULL_CHAR;

    // Fewer than 64 bytes remain: load only the valid lanes and restrict the comparison to them.
    __mmask64 const valid_lanes = _sz_u64_mask_until(h_length);
    haystack_vec.zmm = _mm512_maskz_loadu_epi8(valid_lanes, h);
    __mmask64 const hits = _mm512_mask_cmpeq_epu8_mask(valid_lanes, haystack_vec.zmm, needle_vec.zmm);
    if (hits) return h + sz_u64_ctz(hits);
    return SZ_NULL_CHAR;
}
/**
 * @brief Finds the first occurrence of the needle `n` in the haystack `h` using AVX-512.
 *
 * Three needle bytes, at offsets picked by `_sz_locate_needle_anomalies`, are broadcast and compared
 * against 64 candidate positions at a time. Surviving candidates are then verified against the whole
 * needle, except for needles of 2 or 3 bytes, where a candidate is returned right away.
 *
 * @param h         Haystack.
 * @param h_length  Number of bytes in the haystack.
 * @param n         Needle.
 * @param n_length  Number of bytes in the needle.
 * @return Pointer to the first match, or `SZ_NULL_CHAR` if there is none, if the needle is empty,
 *         or if the needle is longer than the haystack.
 */
SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {

    // This almost never fires, but it's better to be safe than sorry.
    if (h_length < n_length || !n_length) return SZ_NULL_CHAR;
    if (n_length == 1) return sz_find_byte_avx512(h, h_length, n);

    // Pick the parts of the needle that are worth comparing.
    sz_size_t offset_first, offset_mid, offset_last;
    _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last);

    // Broadcast those characters into ZMM registers.
    __mmask64 matches;
    __mmask64 mask;
    sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec;
    n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]);
    n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]);
    n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]);

    // Scan through the string.
    // We have several optimized versions of the algorithm for shorter strings,
    // but they all mimic the default case for unbounded length needles.
    if (n_length >= 64) {
        for (; h_length >= n_length + 64; h += 64, h_length -= 64) {
            h_first_vec.zmm = _mm512_loadu_si512(h + offset_first);
            h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid);
            h_last_vec.zmm = _mm512_loadu_si512(h + offset_last);
            matches = _kand_mask64(_kand_mask64( // Intersect the masks
                                       _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm),
                                       _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)),
                                   _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm));
            while (matches) {
                int potential_offset = sz_u64_ctz(matches);
                if (sz_equal_avx512(h + potential_offset, n, n_length)) return h + potential_offset;
                matches &= matches - 1;
            }

            // TODO: If the last character contains a bad byte, we can reposition the start of the next iteration.
            // This will be very helpful for very long needles.
        }
    }
    // If there are only 2 or 3 characters in the needle, we don't even need the nested loop.
    else if (n_length <= 3) {
        for (; h_length >= n_length + 64; h += 64, h_length -= 64) {
            h_first_vec.zmm = _mm512_loadu_si512(h + offset_first);
            h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid);
            h_last_vec.zmm = _mm512_loadu_si512(h + offset_last);
            matches = _kand_mask64(_kand_mask64( // Intersect the masks
                                       _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm),
                                       _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)),
                                   _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm));
            if (matches) return h + sz_u64_ctz(matches);
        }
    }
    // If the needle is smaller than the size of the ZMM register, we can use masked comparisons
    // to avoid the inner-most nested loop and compare the entire needle against a haystack
    // slice in 3 CPU cycles.
    else {
        __mmask64 n_mask = _sz_u64_mask_until(n_length);
        sz_u512_vec_t n_full_vec, h_full_vec;
        n_full_vec.zmm = _mm512_maskz_loadu_epi8(n_mask, n);
        for (; h_length >= n_length + 64; h += 64, h_length -= 64) {
            h_first_vec.zmm = _mm512_loadu_si512(h + offset_first);
            h_mid_vec.zmm = _mm512_loadu_si512(h + offset_mid);
            h_last_vec.zmm = _mm512_loadu_si512(h + offset_last);
            matches = _kand_mask64(_kand_mask64( // Intersect the masks
                                       _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm),
                                       _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)),
                                   _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm));
            while (matches) {
                int potential_offset = sz_u64_ctz(matches);
                h_full_vec.zmm = _mm512_maskz_loadu_epi8(n_mask, h + potential_offset);
                if (_mm512_mask_cmpneq_epi8_mask(n_mask, h_full_vec.zmm, n_full_vec.zmm) == 0)
                    return h + potential_offset;
                matches &= matches - 1;
            }
        }
    }

    // The "tail" of the function uses masked loads to process the remaining bytes.
    // At most 64 candidate positions are left: `h_length - n_length + 1` of them.
    {
        mask = _sz_u64_mask_until(h_length - n_length + 1);
        h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first);
        h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid);
        h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last);
        matches = _kand_mask64(_kand_mask64( // Intersect the masks
                                   _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm),
                                   _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)),
                               _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm));
        // Masked-off lanes were loaded as zeros and would spuriously match a needle whose probed
        // bytes are all zero, pointing past the end of the haystack. Keep only the valid candidates.
        matches = _kand_mask64(matches, mask);
        while (matches) {
            int potential_offset = sz_u64_ctz(matches);
            if (n_length <= 3 || sz_equal_avx512(h + potential_offset, n, n_length)) return h + potential_offset;
            matches &= matches - 1;
        }
    }
    return SZ_NULL_CHAR;
}
/**
 * @brief Locates the last occurrence of the byte `n[0]` in the haystack.
 *
 * @param h         Haystack.
 * @param h_length  Number of bytes in the haystack.
 * @param n         Pointer to the byte to look for; only `n[0]` is read.
 * @return Pointer to the last matching byte, or `SZ_NULL_CHAR` if there is none.
 */
SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
    sz_u512_vec_t haystack_vec, needle_vec;
    needle_vec.zmm = _mm512_set1_epi8(n[0]);

    // Whole 64-byte chunks, back to front. The highest set bit marks the right-most match.
    for (; h_length >= 64; h_length -= 64) {
        haystack_vec.zmm = _mm512_loadu_si512(h + h_length - 64);
        __mmask64 const hits = _mm512_cmpeq_epi8_mask(haystack_vec.zmm, needle_vec.zmm);
        if (hits) return h + h_length - 1 - sz_u64_clz(hits);
    }
    if (h_length == 0) return SZ_NULL_CHAR;

    // Fewer than 64 bytes remain at the very start: load only the valid lanes
    // and restrict the comparison to them.
    __mmask64 const valid_lanes = _sz_u64_mask_until(h_length);
    haystack_vec.zmm = _mm512_maskz_loadu_epi8(valid_lanes, h);
    __mmask64 const hits = _mm512_mask_cmpeq_epu8_mask(valid_lanes, haystack_vec.zmm, needle_vec.zmm);
    if (hits) return h + 64 - sz_u64_clz(hits) - 1;
    return SZ_NULL_CHAR;
}
/**
 * @brief Finds the last occurrence of the needle `n` in the haystack `h` using AVX-512.
 *
 * Three needle bytes, at offsets picked by `_sz_locate_needle_anomalies`, are broadcast and compared
 * against 64 candidate positions at a time, walking the haystack from its end. Surviving candidates
 * are verified against the whole needle, except for needles of 2 or 3 bytes, where a candidate
 * is returned right away.
 *
 * @param h         Haystack.
 * @param h_length  Number of bytes in the haystack.
 * @param n         Needle.
 * @param n_length  Number of bytes in the needle.
 * @return Pointer to the last match, or `SZ_NULL_CHAR` if there is none, if the needle is empty,
 *         or if the needle is longer than the haystack.
 */
SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {

    // This almost never fires, but it's better to be safe than sorry.
    if (h_length < n_length || !n_length) return SZ_NULL_CHAR;
    if (n_length == 1) return sz_rfind_byte_avx512(h, h_length, n);

    // Pick the parts of the needle that are worth comparing.
    sz_size_t offset_first, offset_mid, offset_last;
    _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last);

    // Broadcast those characters into ZMM registers.
    __mmask64 mask;
    __mmask64 matches;
    sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec;
    n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]);
    n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]);
    n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]);

    // Scan through the string, 64 candidate positions at a time, starting from the end.
    // Lane `i` of every register corresponds to a match starting at `h_reversed + i`.
    sz_cptr_t h_reversed;
    for (; h_length >= n_length + 64; h_length -= 64) {
        h_reversed = h + h_length - n_length - 64 + 1;
        h_first_vec.zmm = _mm512_loadu_si512(h_reversed + offset_first);
        h_mid_vec.zmm = _mm512_loadu_si512(h_reversed + offset_mid);
        h_last_vec.zmm = _mm512_loadu_si512(h_reversed + offset_last);
        matches = _kand_mask64(_kand_mask64( // Intersect the masks
                                   _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm),
                                   _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)),
                               _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm));
        while (matches) {
            int potential_offset = sz_u64_clz(matches);
            if (n_length <= 3 || sz_equal_avx512(h + h_length - n_length - potential_offset, n, n_length))
                return h + h_length - n_length - potential_offset;
            sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 &&
                      "The bit must be set before we squash it");
            matches &= ~((sz_u64_t)1 << (63 - potential_offset));
        }
    }

    // The "tail" of the function uses masked loads to process the remaining bytes.
    // At most 64 candidate positions are left: `h_length - n_length + 1` of them.
    {
        mask = _sz_u64_mask_until(h_length - n_length + 1);
        h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first);
        h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid);
        h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last);
        matches = _kand_mask64(_kand_mask64( // Intersect the masks
                                   _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm),
                                   _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)),
                               _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm));
        // Masked-off lanes were loaded as zeros and would spuriously match a needle whose probed
        // bytes are all zero. As the highest lane is examined first, such a lane would win over
        // any real match and point past the end of the haystack. Keep only the valid candidates.
        matches = _kand_mask64(matches, mask);
        while (matches) {
            int potential_offset = sz_u64_clz(matches);
            if (n_length <= 3 || sz_equal_avx512(h + 64 - potential_offset - 1, n, n_length))
                return h + 64 - potential_offset - 1;
            sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 &&
                      "The bit must be set before we squash it");
            matches &= ~((sz_u64_t)1 << (63 - potential_offset));
        }
    }
    return SZ_NULL_CHAR;
}
SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( //
sz_cptr_t shorter, sz_size_t shorter_length, //
sz_cptr_t longer, sz_size_t longer_length, //
sz_size_t bound, sz_memory_allocator_t *alloc) {
// Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome.
sz_memory_allocator_t global_alloc;
if (!alloc) {
sz_memory_allocator_init_default(&global_alloc);
alloc = &global_alloc;
}
// TODO: Generalize!
sz_size_t max_length = 256u * 256u;
sz_assert(!bound && "For bounded search the method should only evaluate one band of the matrix.");
sz_assert(shorter_length == longer_length && "The method hasn't been generalized to different length inputs yet.");
sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant.");
sz_unused(longer_length && bound && max_length);
// We are going to store 3 diagonals of the matrix.
// The length of the longest (main) diagonal would be `n = (shorter_length + 1)`.
sz_size_t n = shorter_length + 1;
// Unlike the serial version, we also want to avoid reverse-order iteration over teh shorter string.
// So let's allocate a bit more memory and reverse-export our shorter string into that buffer.
sz_size_t buffer_length = sizeof(sz_u16_t) * n * 3 + shorter_length;
sz_u16_t *distances = (sz_u16_t *)alloc->allocate(buffer_length, alloc->handle);
if (!distances) return SZ_SIZE_MAX;
sz_u16_t *previous_distances = distances;
sz_u16_t *current_distances = previous_distances + n;
sz_u16_t *next_distances = current_distances + n;
sz_ptr_t shorter_reversed = (sz_ptr_t)(next_distances + n);
// Export the reversed string into the buffer.
for (sz_size_t i = 0; i != shorter_length; ++i) shorter_reversed[i] = shorter[shorter_length - 1 - i];
// Initialize the first two diagonals:
previous_distances[0] = 0;
current_distances[0] = current_distances[1] = 1;
// Using ZMM registers, we can process 32x 16-bit values at once,
// storing 16 bytes of each string in YMM registers.
sz_u512_vec_t insertions_vec, deletions_vec, substitutions_vec, next_vec;
sz_u512_vec_t ones_u16_vec;
ones_u16_vec.zmm = _mm512_set1_epi16(1);
// This is a mixed-precision implementation, using 8-bit representations for part of the operations.
// Even there, in case `SZ_USE_X86_AVX2=0`, let's use the `sz_u512_vec_t` type, addressing the first YMM halfs.
sz_u512_vec_t shorter_vec, longer_vec;
sz_u512_vec_t ones_u8_vec;
ones_u8_vec.ymms[0] = _mm256_set1_epi8(1);
// Progress through the upper triangle of the Levenshtein matrix.
sz_size_t next_skew_diagonal_index = 2;
for (; next_skew_diagonal_index != n; ++next_skew_diagonal_index) {
sz_size_t const next_skew_diagonal_length = next_skew_diagonal_index + 1;
for (sz_size_t i = 0; i + 2 < next_skew_diagonal_length;) {
sz_u32_t remaining_length = (sz_u32_t)(next_skew_diagonal_length - i - 2);
sz_u32_t register_length = remaining_length < 32 ? remaining_length : 32;
sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length);
longer_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, longer + i);
// Our original code addressed the shorter string `[next_skew_diagonal_index - i - 2]` for growing `i`.
// If the `shorter` string was reversed, the `[next_skew_diagonal_index - i - 2]` would
// be equal to `[shorter_length - 1 - next_skew_diagonal_index + i + 2]`.
// Which simplified would be equal to `[shorter_length - next_skew_diagonal_index + i + 1]`.
shorter_vec.ymms[0] = _mm256_maskz_loadu_epi8(
remaining_length_mask, shorter_reversed + shorter_length - next_skew_diagonal_index + i + 1);
// For substitutions, perform the equality comparison using AVX2 instead of AVX-512
// to get the result as a vector, instead of a bitmask. Adding 1 to every scalar we can overflow
// transforming from {0xFF, 0} values to {0, 1} values - exactly what we need. Then - upcast to 16-bit.
substitutions_vec.zmm = _mm512_cvtepi8_epi16( //
_mm256_add_epi8(_mm256_cmpeq_epi8(longer_vec.ymms[0], shorter_vec.ymms[0]), ones_u8_vec.ymms[0]));
substitutions_vec.zmm = _mm512_add_epi16( //
substitutions_vec.zmm, _mm512_maskz_loadu_epi16(remaining_length_mask, previous_distances + i));
// For insertions and deletions, on modern hardware, it's faster to issue two separate loads,
// than rotate the bytes in the ZMM register.
insertions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i);
deletions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i + 1);
// First get the minimum of insertions and deletions.
next_vec.zmm = _mm512_add_epi16(_mm512_min_epu16(insertions_vec.zmm, deletions_vec.zmm), ones_u16_vec.zmm);
next_vec.zmm = _mm512_min_epu16(next_vec.zmm, substitutions_vec.zmm);
_mm512_mask_storeu_epi16(next_distances + i + 1, remaining_length_mask, next_vec.zmm);
i += register_length;
}
// Don't forget to populate the first row and the first column of the Levenshtein matrix.
next_distances[0] = next_distances[next_skew_diagonal_length - 1] = (sz_u16_t)next_skew_diagonal_index;
// Perform a circular rotation of those buffers, to reuse the memory.
sz_u16_t *temporary = previous_distances;
previous_distances = current_distances;
current_distances = next_distances;
next_distances = temporary;
}
// By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a
// larger diagonal. From now onwards, we will be shrinking. Instead of adding value equal to the skewed diagonal
// index on either side, we will be cropping those values out.
sz_size_t total_diagonals = n + n - 1;
for (; next_skew_diagonal_index != total_diagonals; ++next_skew_diagonal_index) {
sz_size_t const next_skew_diagonal_length = total_diagonals - next_skew_diagonal_index;
for (sz_size_t i = 0; i != next_skew_diagonal_length;) {
sz_u32_t remaining_length = (sz_u32_t)(next_skew_diagonal_length - i);
sz_u32_t register_length = remaining_length < 32 ? remaining_length : 32;
sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length);
longer_vec.ymms[0] =
_mm256_maskz_loadu_epi8(remaining_length_mask, longer + next_skew_diagonal_index - n + i);
// Our original code addressed the shorter string `[shorter_length - 1 - i]` for growing `i`.
// If the `shorter` string was reversed, the `[shorter_length - 1 - i]` would
// be equal to `[shorter_length - 1 - shorter_length + 1 + i]`.
// Which simplified would be equal to just `[i]`. Beautiful!
shorter_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, shorter_reversed + i);
// For substitutions, perform the equality comparison using AVX2 instead of AVX-512
// to get the result as a vector, instead of a bitmask. The compare it against the accumulated
// substitution costs.
substitutions_vec.zmm = _mm512_cvtepi8_epi16( //
_mm256_add_epi8(_mm256_cmpeq_epi8(longer_vec.ymms[0], shorter_vec.ymms[0]), ones_u8_vec.ymms[0]));
substitutions_vec.zmm = _mm512_add_epi16( //
substitutions_vec.zmm, _mm512_maskz_loadu_epi16(remaining_length_mask, previous_distances + i));
// For insertions and deletions, on modern hardware, it's faster to issue two separate loads,
// than rotate the bytes in the ZMM register.
insertions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i);
deletions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i + 1);
// First get the minimum of insertions and deletions.
next_vec.zmm = _mm512_add_epi16(_mm512_min_epu16(insertions_vec.zmm, deletions_vec.zmm), ones_u16_vec.zmm);
next_vec.zmm = _mm512_min_epu16(next_vec.zmm, substitutions_vec.zmm);
_mm512_mask_storeu_epi16(next_distances + i, remaining_length_mask, next_vec.zmm);
i += register_length;
}
// Perform a circular rotation of those buffers, to reuse the memory, this time, with a shift,
// dropping the first element in the current array.
sz_u16_t *temporary = previous_distances;
previous_distances = current_distances + 1;
current_distances = next_distances;
next_distances = temporary;
}
// Cache scalar before `free` call.
sz_size_t result = current_distances[0];
alloc->free(distances, buffer_length, alloc->handle);
return result;
}
/**
 *  @brief  AVX-512 entry point for the edit distance. Routes the request either to the
 *          skewed-diagonals kernel or to the serial implementation.
 *
 *  The vectorized kernel is selected only when all of the following hold:
 *  - both inputs have the same length;
 *  - that length is non-zero and below 65536 (256 * 256);
 *  - @p bound is zero.
 *  Any other combination of arguments is forwarded unchanged to `sz_edit_distance_serial`.
 *
 *  @return Whatever the selected implementation returns.
 */
SZ_INTERNAL sz_size_t sz_edit_distance_avx512(   //
    sz_cptr_t shorter, sz_size_t shorter_length, //
    sz_cptr_t longer, sz_size_t longer_length,   //
    sz_size_t bound, sz_memory_allocator_t *alloc) {

    sz_size_t const kernel_length_limit = 256u * 256u;
    int const has_equal_lengths = shorter_length == longer_length;
    int const is_unbounded = bound == 0;
    int const fits_kernel = shorter_length != 0 && shorter_length < kernel_length_limit;

    // Anything the vectorized kernel can't take goes down the generic serial path.
    if (!has_equal_lengths || !is_unbounded || !fits_kernel)
        return sz_edit_distance_serial(shorter, shorter_length, longer, longer_length, bound, alloc);

    return _sz_edit_distance_skewed_diagonals_upto65k_avx512( //
        shorter, shorter_length, longer, longer_length, bound, alloc);
}
#pragma clang attribute pop
#pragma GCC pop_options
#pragma GCC push_options
#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "bmi", "bmi2")
#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,bmi,bmi2"))), \
apply_to = function)
/**
 *  @brief  Computes the sum of all bytes in @p text, each taken as an unsigned 8-bit value, using AVX-512.
 *
 *  The input is handled in one of several size classes:
 *  - up to 16, 32, or 64 bytes: one masked load into an XMM, YMM, or ZMM register;
 *  - more than 64 bytes, but under 1 MiB: a masked unaligned head, aligned 64-byte loads
 *    for the body, and a masked tail;
 *  - 1 MiB or more: the same head and tail, but the aligned body is consumed from both
 *    ends simultaneously using non-temporal loads.
 *
 *  @param text    Pointer to the first byte of the buffer.
 *  @param length  Number of bytes to sum.
 *  @return        The sum of all bytes, accumulated in 64-bit lanes.
 */
SZ_PUBLIC sz_u64_t sz_checksum_avx512(sz_cptr_t text, sz_size_t length) {
    // The naive implementation of this function is very simple.
    // It assumes the CPU is great at handling unaligned "loads".
    //
    // A typical AWS Sapphire Rapids instance can have 48 KB x 2 blocks of L1 data cache per core,
    // 2 MB x 2 blocks of L2 cache per core, and one shared 60 MB buffer of L3 cache.
    // We consider the workload huge if the buffer is 1 MB or longer.
    int const is_huge = length >= 1ull * 1024ull * 1024ull;
    sz_u512_vec_t text_vec, sums_vec;

    // When the buffer is small, there isn't much to innovate.
    // `SAD` against zero yields the sum of every 8 consecutive unsigned bytes in a 64-bit lane.
    if (length <= 16) {
        __mmask16 mask = _sz_u16_mask_until(length);
        text_vec.xmms[0] = _mm_maskz_loadu_epi8(mask, text);
        sums_vec.xmms[0] = _mm_sad_epu8(text_vec.xmms[0], _mm_setzero_si128());
        sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_vec.xmms[0]);
        sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_vec.xmms[0], 1);
        return low + high;
    }
    else if (length <= 32) {
        __mmask32 mask = _sz_u32_mask_until(length);
        text_vec.ymms[0] = _mm256_maskz_loadu_epi8(mask, text);
        sums_vec.ymms[0] = _mm256_sad_epu8(text_vec.ymms[0], _mm256_setzero_si256());
        // Accumulating 256 bits is harder, as we need to extract the 128-bit sums first.
        __m128i low_xmm = _mm256_castsi256_si128(sums_vec.ymms[0]);
        __m128i high_xmm = _mm256_extracti128_si256(sums_vec.ymms[0], 1);
        __m128i sums_xmm = _mm_add_epi64(low_xmm, high_xmm);
        sz_u64_t low = (sz_u64_t)_mm_cvtsi128_si64(sums_xmm);
        sz_u64_t high = (sz_u64_t)_mm_extract_epi64(sums_xmm, 1);
        return low + high;
    }
    else if (length <= 64) {
        __mmask64 mask = _sz_u64_mask_until(length);
        text_vec.zmm = _mm512_maskz_loadu_epi8(mask, text);
        sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512());
        return (sz_u64_t)_mm512_reduce_add_epi64(sums_vec.zmm);
    }
    else if (!is_huge) {
        sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64; // 63 or less.
        sz_size_t tail_length = (sz_size_t)(text + length) % 64;    // 63 or less.
        sz_size_t body_length = length - head_length - tail_length; // Multiple of 64.
        __mmask64 head_mask = _sz_u64_mask_until(head_length);
        __mmask64 tail_mask = _sz_u64_mask_until(tail_length);
        text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text);
        sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512());
        // After skipping the head, `text` is 64-byte aligned, so aligned loads are valid.
        for (text += head_length; body_length >= 64; text += 64, body_length -= 64) {
            text_vec.zmm = _mm512_load_si512((__m512i const *)text);
            sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
        }
        text_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text);
        sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
        return (sz_u64_t)_mm512_reduce_add_epi64(sums_vec.zmm);
    }
    // For gigantic buffers, exceeding typical L1 cache sizes, there are other tricks we can use:
    //
    //      1. Moving in both directions, fetching from two distant memory regions at once.
    //      2. Using non-temporal loads to avoid polluting the cache.
    //
    // Bidirectional traversal generally adds about 10% to such algorithms.
    else {
        sz_u512_vec_t text_reversed_vec, sums_reversed_vec;
        sz_size_t head_length = (64 - ((sz_size_t)text % 64)) % 64; // 63 or less.
        sz_size_t tail_length = (sz_size_t)(text + length) % 64;    // 63 or less.
        sz_size_t body_length = length - head_length - tail_length; // Multiple of 64.
        __mmask64 head_mask = _sz_u64_mask_until(head_length);
        __mmask64 tail_mask = _sz_u64_mask_until(tail_length);
        text_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, text);
        sums_vec.zmm = _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512());
        text_reversed_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, text + head_length + body_length);
        sums_reversed_vec.zmm = _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512());

        // Every iteration consumes one aligned 64-byte block from the front (`text`) and one from
        // the back (`text + body_length - 64`) of the still-unprocessed body. The front cursor must
        // therefore advance by 64 bytes only, while the remaining body shrinks by 128: advancing
        // the cursor by 128 would pin the back block in place and skip every second front block.
        for (text += head_length; body_length >= 128; text += 64, body_length -= 128) {
            text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text));
            sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
            text_reversed_vec.zmm = _mm512_stream_load_si512((__m512i *)(text + body_length - 64));
            sums_reversed_vec.zmm =
                _mm512_add_epi64(sums_reversed_vec.zmm, _mm512_sad_epu8(text_reversed_vec.zmm, _mm512_setzero_si512()));
        }
        // If the body had an odd number of 64-byte blocks, exactly one is left in the middle.
        if (body_length >= 64) {
            text_vec.zmm = _mm512_stream_load_si512((__m512i *)(text));
            sums_vec.zmm = _mm512_add_epi64(sums_vec.zmm, _mm512_sad_epu8(text_vec.zmm, _mm512_setzero_si512()));
        }
        return (sz_u64_t)_mm512_reduce_add_epi64(_mm512_add_epi64(sums_vec.zmm, sums_reversed_vec.zmm));
    }
}
/**
 *  @brief  Computes rolling hashes over @p window_length -byte windows of the input with AVX-512,
 *          reporting them through @p callback.
 *
 *  Inputs shorter than four windows are delegated to `sz_hashes_serial`. Otherwise the text is cut
 *  into four streams starting `(length - window_length + 1) / 4` bytes apart, which are rolled in
 *  lock-step. Two polynomial hashes (bases 31 and 257) are maintained per stream, occupying
 *  the 8 lanes of one ZMM register: lanes 0..3 hold the base-257 hashes, lanes 4..7 the base-31 ones.
 *  The value passed to the callback is the XOR of both hashes, each multiplied by a 64-bit constant.
 *
 *  @param start            Pointer to the first byte of the text.
 *  @param length           Number of bytes in the text.
 *  @param window_length    Number of bytes in each hashed window. Zero, or a value above @p length,
 *                          results in no callbacks.
 *  @param step             Callback stride filter. After the initial window, callbacks fire on iterations
 *                          where `cycle & (step - 1)` is zero, so this acts as "every `step`-th position"
 *                          only when @p step is a power of two.
 *  @param callback         Invoked with the current cursor of a stream, @p window_length, the mixed hash,
 *                          and @p callback_handle.
 *  @param callback_handle  Opaque pointer forwarded to every @p callback invocation.
 */
SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, //
                                sz_hash_callback_t callback, void *callback_handle) {

    if (length < window_length || !window_length) return;
    if (length < 4 * window_length) {
        sz_hashes_serial(start, length, window_length, step, callback, callback_handle);
        return;
    }

    // With 8x 64-bit lanes per ZMM register we can roll 4 streams with 2 hash functions each.
    // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel.
    sz_size_t const max_hashes = length - window_length + 1;
    sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads.
    sz_u8_t const *text_first = (sz_u8_t const *)start;
    sz_u8_t const *text_second = text_first + min_hashes_per_thread;
    sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2;
    sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3;
    sz_u8_t const *text_end = text_first + length;

    // Broadcast the global constants into the registers.
    // Both high and low hashes will work with the same prime and golden ratio.
    sz_u512_vec_t prime_vec, golden_ratio_vec;
    prime_vec.zmm = _mm512_set1_epi64(SZ_U64_MAX_PRIME);
    golden_ratio_vec.zmm = _mm512_set1_epi64(11400714819323198485ull);

    // Prepare the `base ^ (window_length - 1)` values, used to remove the outgoing character.
    sz_u64_t prime_power_low = 1, prime_power_high = 1;
    for (sz_size_t i = 0; i + 1 < window_length; ++i)
        prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME,
        prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME;

    // We will be evaluating 4 offsets at a time with 2 different hash functions.
    // We can fit all those 8 state variables in each of the following ZMM registers.
    // Mind that `_mm512_set_epi64` takes its arguments from the highest lane down to the lowest.
    sz_u512_vec_t base_vec, prime_power_vec, shift_vec;
    base_vec.zmm = _mm512_set_epi64(31ull, 31ull, 31ull, 31ull, 257ull, 257ull, 257ull, 257ull);
    shift_vec.zmm = _mm512_set_epi64(0ull, 0ull, 0ull, 0ull, 77ull, 77ull, 77ull, 77ull);
    prime_power_vec.zmm = _mm512_set_epi64(prime_power_low, prime_power_low, prime_power_low, prime_power_low,
                                           prime_power_high, prime_power_high, prime_power_high, prime_power_high);

    // Compute the initial hash values for every one of the four windows.
    sz_u512_vec_t hash_vec, chars_vec;
    hash_vec.zmm = _mm512_setzero_si512();
    for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end;
         ++text_first, ++text_second, ++text_third, ++text_fourth) {

        // 1. Multiply the hashes by the base.
        hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm);

        // 2. Load the four characters from `text_first`, `text_first + min_hashes_per_thread`,
        //   `text_first + min_hashes_per_thread * 2`, `text_first + min_hashes_per_thread * 3`...
        chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], //
                                         text_fourth[0], text_third[0], text_second[0], text_first[0]);
        chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm);

        // 3. Add the incoming characters.
        hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm);

        // 4. Compute the modulo. Assuming there are only 59 values between our prime
        //    and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime.
        //    The comparison must be unsigned (the prime doesn't fit into a signed 64-bit integer) and
        //    the resulting mask has one bit per 64-bit lane, so it must drive a 64-bit-granular operation.
        hash_vec.zmm = _mm512_mask_sub_epi64(hash_vec.zmm, _mm512_cmpge_epu64_mask(hash_vec.zmm, prime_vec.zmm),
                                             hash_vec.zmm, prime_vec.zmm);
    }

    // 5. Compute the hash mix, that will be used to index into the fingerprint.
    //    This includes a serial step at the end: XOR-ing the upper (base-31) and lower (base-257) halves.
    sz_u512_vec_t hash_mix_vec;
    hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm);
    hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), //
                                            _mm512_extracti64x4_epi64(hash_mix_vec.zmm, 0));

    callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle);
    callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle);
    callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle);
    callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle);

    // Now repeat that operation for the remaining characters, discarding older characters.
    sz_size_t cycle = 1;
    sz_size_t step_mask = step - 1;
    for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) {
        // 0. Load again the four characters we are dropping, shift them, and subtract.
        //    Explicit pointer subtraction is used instead of indexing with a negated unsigned length.
        sz_u8_t const dropped_first = *(text_first - window_length);
        sz_u8_t const dropped_second = *(text_second - window_length);
        sz_u8_t const dropped_third = *(text_third - window_length);
        sz_u8_t const dropped_fourth = *(text_fourth - window_length);
        chars_vec.zmm = _mm512_set_epi64(dropped_fourth, dropped_third, dropped_second, dropped_first, //
                                         dropped_fourth, dropped_third, dropped_second, dropped_first);
        chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm);
        hash_vec.zmm = _mm512_sub_epi64(hash_vec.zmm, _mm512_mullo_epi64(chars_vec.zmm, prime_power_vec.zmm));

        // 1. Multiply the hashes by the base.
        hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm);

        // 2. Load the four characters from `text_first`, `text_first + min_hashes_per_thread`,
        //   `text_first + min_hashes_per_thread * 2`, `text_first + min_hashes_per_thread * 3`.
        chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], //
                                         text_fourth[0], text_third[0], text_second[0], text_first[0]);
        chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm);

        // ... and prefetch the next four characters into Level 2 or higher.
        _mm_prefetch((sz_cptr_t)text_fourth + 1, _MM_HINT_T1);
        _mm_prefetch((sz_cptr_t)text_third + 1, _MM_HINT_T1);
        _mm_prefetch((sz_cptr_t)text_second + 1, _MM_HINT_T1);
        _mm_prefetch((sz_cptr_t)text_first + 1, _MM_HINT_T1);

        // 3. Add the incoming characters.
        hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm);

        // 4. Compute the modulo by conditionally subtracting the prime: unsigned comparison,
        //    applied per 64-bit lane, same as for the initial windows.
        hash_vec.zmm = _mm512_mask_sub_epi64(hash_vec.zmm, _mm512_cmpge_epu64_mask(hash_vec.zmm, prime_vec.zmm),
                                             hash_vec.zmm, prime_vec.zmm);

        // 5. Compute the hash mix, that will be used to index into the fingerprint.
        //    This includes a serial step at the end.
        hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm);
        hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), //
                                                _mm512_castsi512_si256(hash_mix_vec.zmm));

        if ((cycle & step_mask) == 0) {
            callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle);
            callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle);
            callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle);
            callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle);
        }
    }
}
#pragma clang attribute pop
#pragma GCC pop_options
#pragma GCC push_options
#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512vbmi", "avx512vbmi2", "bmi", "bmi2")
#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512vbmi,avx512vbmi2,bmi,bmi2"))), \
apply_to = function)
SZ_PUBLIC void sz_look_up_transform_avx512(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) {
// If the input is tiny (especially smaller than the look-up table itself), we may end up paying
// more for organizing the SIMD registers and changing the CPU state, than for the actual computation.
// But if at least 3 cache lines are touched, the AVX-512 implementation should be faster.
if (length <= 128) {
sz_look_up_transform_serial(source, length, lut, target);
return;
}
// When the buffer is over 64 bytes, it's guaranteed to touch at least two cache lines - the head and tail,
// and may include more cache-lines in-between. Knowing this, we can avoid expensive unaligned stores
// by computing 2 masks - for the head and tail, using masked stores for the head and tail, and unmasked
// for the body.
sz_size_t head_length = (64 - ((sz_size_t)target % 64)) % 64; // 63 or less.
sz_size_t tail_length = (sz_size_t)(target + length) % 64; // 63 or less.
__mmask64 head_mask = _sz_u64_mask_until(head_length);
__mmask64 tail_mask = _sz_u64_mask_until(tail_length);
// We need to pull the lookup table into 4x ZMM registers.
// We can use `vpermi2b` instruction to perform the look in two ZMM registers with `_mm512_permutex2var_epi8`
// intrinsics, but it has a 6-cycle latency on Sapphire Rapids and requires AVX512-VBMI. Assuming we need to
// operate on 4 registers, it might be cleaner to use 2x separate `_mm512_permutexvar_epi8` calls.
// Combining the results with 2x `_mm512_test_epi8_mask` and 3x blends afterwards.
//
// - 4x `_mm512_permutexvar_epi8` maps to "VPERMB (ZMM, ZMM, ZMM)":
// - On Ice Lake: 3 cycles latency, ports: 1*p5
// - On Genoa: 6 cycles latency, ports: 1*FP12
// - 3x `_mm512_mask_blend_epi8` maps to "VPBLENDMB_Z (ZMM, K, ZMM, ZMM)":
// - On Ice Lake: 3 cycles latency, ports: 1*p05
// - On Genoa: 1 cycle latency, ports: 1*FP0123
// - 2x `_mm512_test_epi8_mask` maps to "VPTESTMB (K, ZMM, ZMM)":
// - On Ice Lake: 3 cycles latency, ports: 1*p5
// - On Genoa: 4 cycles latency, ports: 1*FP01
//
sz_u512_vec_t lut_0_to_63_vec, lut_64_to_127_vec, lut_128_to_191_vec, lut_192_to_255_vec;
lut_0_to_63_vec.zmm = _mm512_loadu_si512((lut));
lut_64_to_127_vec.zmm = _mm512_loadu_si512((lut + 64));
lut_128_to_191_vec.zmm = _mm512_loadu_si512((lut + 128));
lut_192_to_255_vec.zmm = _mm512_loadu_si512((lut + 192));
sz_u512_vec_t first_bit_vec, second_bit_vec;
first_bit_vec.zmm = _mm512_set1_epi8((char)0x80);
second_bit_vec.zmm = _mm512_set1_epi8((char)0x40);
__mmask64 first_bit_mask, second_bit_mask;
sz_u512_vec_t source_vec;
// If the top bit is set in each word of `source_vec`, than we use `lookup_128_to_191_vec` or
// `lookup_192_to_255_vec`. If the second bit is set, we use `lookup_64_to_127_vec` or `lookup_192_to_255_vec`.
sz_u512_vec_t lookup_0_to_63_vec, lookup_64_to_127_vec, lookup_128_to_191_vec, lookup_192_to_255_vec;
sz_u512_vec_t blended_0_to_127_vec, blended_128_to_255_vec, blended_0_to_255_vec;
// Handling the head.
if (head_length) {
source_vec.zmm = _mm512_maskz_loadu_epi8(head_mask, source);
lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm);
lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm);
lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm);
lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm);
first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm);
second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm);
blended_0_to_127_vec.zmm =
_mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm);
blended_128_to_255_vec.zmm =
_mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm);
blended_0_to_255_vec.zmm =
_mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm);
_mm512_mask_storeu_epi8(target, head_mask, blended_0_to_255_vec.zmm);
source += head_length, target += head_length, length -= head_length;
}
// Handling the body in 64-byte chunks aligned to cache-line boundaries with respect to `target`.
while (length >= 64) {
source_vec.zmm = _mm512_loadu_si512(source);
lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm);
lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm);
lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm);
lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm);
first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm);
second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm);
blended_0_to_127_vec.zmm =
_mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm);
blended_128_to_255_vec.zmm =
_mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm);
blended_0_to_255_vec.zmm =
_mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm);
_mm512_store_si512(target, blended_0_to_255_vec.zmm); //! Aligned store, our main weapon!
source += 64, target += 64, length -= 64;
}
// Handling the tail.
if (tail_length) {
source_vec.zmm = _mm512_maskz_loadu_epi8(tail_mask, source);
lookup_0_to_63_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_0_to_63_vec.zmm);
lookup_64_to_127_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_64_to_127_vec.zmm);
lookup_128_to_191_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_128_to_191_vec.zmm);
lookup_192_to_255_vec.zmm = _mm512_permutexvar_epi8(source_vec.zmm, lut_192_to_255_vec.zmm);
first_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, first_bit_vec.zmm);
second_bit_mask = _mm512_test_epi8_mask(source_vec.zmm, second_bit_vec.zmm);
blended_0_to_127_vec.zmm =
_mm512_mask_blend_epi8(second_bit_mask, lookup_0_to_63_vec.zmm, lookup_64_to_127_vec.zmm);
blended_128_to_255_vec.zmm =
_mm512_mask_blend_epi8(second_bit_mask, lookup_128_to_191_vec.zmm, lookup_192_to_255_vec.zmm);
blended_0_to_255_vec.zmm =
_mm512_mask_blend_epi8(first_bit_mask, blended_0_to_127_vec.zmm, blended_128_to_255_vec.zmm);
_mm512_mask_storeu_epi8(target, tail_mask, blended_0_to_255_vec.zmm);
source += tail_length, target += tail_length, length -= tail_length;
}
}
/**
 *  @brief  Finds the first byte of @p text that is present in the @p filter character set, using AVX-512.
 *
 *  The 32 bytes at @p filter are treated as a 256-bit set: byte value `c` is a member if bit `c & 7`
 *  of byte `c >> 3` is set. The scalar equivalent of the vectorized check is:
 *
 *      sz_u8_t input = *(sz_u8_t const *)text;
 *      sz_u8_t lo_nibble = input & 0x0f;
 *      sz_u8_t hi_nibble = input >> 4;
 *      sz_u8_t bitset = lo_nibble < 8 ? filter_bytes[hi_nibble * 2] : filter_bytes[hi_nibble * 2 + 1];
 *      if (bitset & (1 << (lo_nibble & 0x7))) return text;
 *
 *  This is a transposed variant of the "SIMDized check which bytes are in a set" by Wojciech Muła:
 *  http://0x80.pl/articles/simd-byte-lookup.html#alternative-implementation-new
 *
 *  @return Pointer to the first matching byte, or `SZ_NULL_CHAR` if none of the @p length bytes match.
 */
SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) {

    // Running serial code over the first few bytes before touching the vector registers
    // was tried and only hurt, even with matches every 5-ish bytes, so we go vectorized right away.
    __m256i const charset_ymm = _mm256_lddqu_si256((__m256i const *)filter);

    // Unzip the even and odd bytes of the set and replicate each 16-byte half into all four 128-bit lanes.
    // That way one `_mm512_shuffle_epi8`, which operates within 128-bit lanes, can index either half
    // with the high nibble of every input byte.
    sz_u512_vec_t even_bytes_vec, odd_bytes_vec;
    even_bytes_vec.zmm = _mm512_broadcast_i32x4( //
        _mm256_castsi256_si128(_mm256_maskz_compress_epi8(0x55555555, charset_ymm)));
    odd_bytes_vec.zmm = _mm512_broadcast_i32x4( //
        _mm256_castsi256_si128(_mm256_maskz_compress_epi8(0xaaaaaaaa, charset_ymm)));

    // Loop-invariant constants. Every 8 bytes of the single-bit table read {1, 2, 4, ..., 128},
    // so indexing it with a low nibble `n` produces `1 << (n & 7)`.
    sz_u512_vec_t nibble_mask_vec, eight_vec, single_bit_lookup_vec;
    nibble_mask_vec.zmm = _mm512_set1_epi8(0x0f);
    eight_vec.zmm = _mm512_set1_epi8(8);
    single_bit_lookup_vec.zmm = _mm512_set1_epi64((long long)0x8040201008040201ull);

    sz_u512_vec_t chunk_vec, low_nibbles_vec, high_nibbles_vec;
    sz_u512_vec_t single_bit_vec, from_even_vec, from_odd_vec, bitset_vec;
    for (; length;) {
        sz_size_t chunk_length = sz_min_of_two(length, 64);
        __mmask64 chunk_mask = _sz_u64_mask_until(chunk_length);
        chunk_vec.zmm = _mm512_maskz_loadu_epi8(chunk_mask, text);

        // Split every byte into nibbles. There is no `_mm512_srli_epi8` intrinsic, so the high nibble
        // is obtained with a 16-bit shift followed by a mask clearing the bits that leaked in.
        low_nibbles_vec.zmm = _mm512_and_si512(chunk_vec.zmm, nibble_mask_vec.zmm);
        high_nibbles_vec.zmm = _mm512_and_si512(_mm512_srli_epi16(chunk_vec.zmm, 4), nibble_mask_vec.zmm);

        // The bit to probe, and both candidate set bytes for every input byte.
        single_bit_vec.zmm = _mm512_shuffle_epi8(single_bit_lookup_vec.zmm, low_nibbles_vec.zmm);
        from_even_vec.zmm = _mm512_shuffle_epi8(even_bytes_vec.zmm, high_nibbles_vec.zmm);
        from_odd_vec.zmm = _mm512_shuffle_epi8(odd_bytes_vec.zmm, high_nibbles_vec.zmm);

        // Low nibbles under 8 are covered by the even set byte, the rest by the odd one.
        __mmask64 uses_even = _mm512_cmplt_epi8_mask(low_nibbles_vec.zmm, eight_vec.zmm);
        bitset_vec.zmm = _mm512_mask_blend_epi8(uses_even, from_odd_vec.zmm, from_even_vec.zmm);

        // Only the bytes actually loaded take part in the test.
        __mmask64 found_mask = _mm512_mask_test_epi8_mask(chunk_mask, bitset_vec.zmm, single_bit_vec.zmm);
        if (found_mask) {
            int offset = sz_u64_ctz(found_mask);
            return text + offset;
        }
        text += chunk_length, length -= chunk_length;
    }
    return SZ_NULL_CHAR;
}
/**
 *  @brief  AVX-512 backend entry for `rfind_charset`. It contains no vectorized code of its own:
 *          all three arguments are forwarded unchanged to `sz_rfind_charset_serial`,
 *          and the result of that call is returned as is.
 */
SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) {
    return sz_rfind_charset_serial(text, length, filter);
}
/**
* Computes the Needleman Wunsch alignment score between two strings.
* The method uses 32-bit integers to accumulate the running score for every cell in the matrix.
* Assuming the costs of substitutions can be arbitrary signed 8-bit integers, the method is expected to be used
* on strings not exceeding 2^24 length or 16.7 million characters.
*
* Unlike the `_sz_edit_distance_skewed_diagonals_upto65k_avx512` method, this one uses signed integers to store
* the accumulated score. Moreover, its primary bottleneck is the latency of gathering the substitution costs
* from the substitution matrix. If we use the diagonal order, we will be comparing a slice of the first string with
* a slice of the second. If we stick to the conventional horizontal order, we will be comparing one character against
* a slice, which is much easier to optimize. In that case we are sampling costs not from arbitrary parts of
* a 256 x 256 matrix, but from a single row!
*/
SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_avx512( //
sz_cptr_t shorter, sz_size_t shorter_length, //
sz_cptr_t longer, sz_size_t longer_length, //
sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) {
// Computes the alignment score row by row, keeping only two rows of 32-bit scores, each `longer_length + 1`
// entries long. The `subs` table is addressed as 256 rows of 256 costs: the row is picked by a character of
// the shorter string, the column by a character of the longer one. The `gap` cost is added for every
// insertion or deletion, and every cell keeps the maximum of its three candidates.
//
// If one of the strings is empty, the score is the length of the other one multiplied by the gap cost.
if (longer_length == 0) return (sz_ssize_t)shorter_length * gap;
if (shorter_length == 0) return (sz_ssize_t)longer_length * gap;
// Let's make sure that we use the amount proportional to the
// number of elements in the shorter string, not the larger.
// NOTE(review): after this swap the rows are sized by the larger of the two lengths - confirm the intent.
if (shorter_length > longer_length) {
// The lengths are exchanged through `void **` casts, which relies on `sz_size_t` being pointer-sized.
sz_pointer_swap((void **)&longer_length, (void **)&shorter_length);
sz_pointer_swap((void **)&longer, (void **)&shorter);
}
// Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome.
sz_memory_allocator_t global_alloc;
if (!alloc) {
sz_memory_allocator_init_default(&global_alloc);
alloc = &global_alloc;
}
sz_size_t const max_length = 256ull * 256ull * 256ull;
sz_size_t const n = longer_length + 1;
sz_assert(n < max_length && "The length must fit into 24-bit integer. Otherwise use serial variant.");
sz_unused(longer_length && max_length);
// One allocation holds both rows: the previous one first, the current one right after it.
// NOTE(review): the returned pointer is used below without a null check - confirm how allocation
// failures are expected to be reported by this function.
sz_size_t buffer_length = sizeof(sz_i32_t) * n * 2;
sz_i32_t *distances = (sz_i32_t *)alloc->allocate(buffer_length, alloc->handle);
sz_i32_t *previous_distances = distances;
sz_i32_t *current_distances = previous_distances + n;
// Initialize the first row of the matrix with `iota`, scaled by the gap cost.
for (sz_size_t idx_longer = 0; idx_longer != n; ++idx_longer)
previous_distances[idx_longer] = (sz_i32_t)idx_longer * gap;
/// Contains up to 64 consecutive characters from the longer string.
sz_u512_vec_t longer_vec;
sz_u512_vec_t cost_deletion_vec, cost_substitution_vec, lookup_substitution_vec, current_vec;
sz_u512_vec_t row_first_subs_vec, row_second_subs_vec, row_third_subs_vec, row_fourth_subs_vec;
sz_u512_vec_t shuffled_first_subs_vec, shuffled_second_subs_vec, shuffled_third_subs_vec, shuffled_fourth_subs_vec;
// Prepare constants and masks.
sz_u512_vec_t is_third_or_fourth_vec, is_second_or_fourth_vec, gap_vec;
{
// The bit patterns 0x80 and 0x40 are written through an unsigned alias and then broadcast as `char`.
char is_third_or_fourth_check, is_second_or_fourth_check;
*(sz_u8_t *)&is_third_or_fourth_check = 0x80, *(sz_u8_t *)&is_second_or_fourth_check = 0x40;
is_third_or_fourth_vec.zmm = _mm512_set1_epi8(is_third_or_fourth_check);
is_second_or_fourth_vec.zmm = _mm512_set1_epi8(is_second_or_fourth_check);
gap_vec.zmm = _mm512_set1_epi32(gap);
}
sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter;
for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) {
// The first cell of every row is a pure run of gaps; `last_in_row` carries the running value along the row.
sz_i32_t last_in_row = current_distances[0] = (sz_i32_t)(idx_shorter + 1) * gap;
// Load one row of the substitution matrix into four ZMM registers.
sz_error_cost_t const *row_subs = subs + shorter_unsigned[idx_shorter] * 256u;
row_first_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 0);
row_second_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 1);
row_third_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 2);
row_fourth_subs_vec.zmm = _mm512_loadu_si512(row_subs + 64 * 3);
// In the serial version we have one forward pass, that computes the deletion,
// insertion, and substitution costs at once.
// for (sz_size_t idx_longer = 0; idx_longer < longer_length; ++idx_longer) {
// sz_ssize_t cost_deletion = previous_distances[idx_longer + 1] + gap;
// sz_ssize_t cost_insertion = current_distances[idx_longer] + gap;
// sz_ssize_t cost_substitution = previous_distances[idx_longer] + row_subs[longer_unsigned[idx_longer]];
// current_distances[idx_longer + 1] = sz_min_of_three(cost_deletion, cost_insertion, cost_substitution);
// }
// Note that the vectorized code below combines the three candidates with a maximum, not a minimum.
//
// Given the complexity of handling the data-dependency between consecutive insertion cost computations
// within a Levenshtein matrix, the simplest design would be to vectorize every kind of cost computation
// separately.
// 1. Compute substitution costs for up to 64 characters at once, upcasting from 8-bit integers to 32.
// 2. Compute the pairwise maximum with deletion costs.
// 3. Inclusive prefix maximum computation to combine with insertion costs.
// Proceeding with substitutions:
for (sz_size_t idx_longer = 0; idx_longer < longer_length; idx_longer += 64) {
sz_size_t register_length = sz_min_of_two(longer_length - idx_longer, 64);
__mmask64 mask = _sz_u64_mask_until(register_length);
longer_vec.zmm = _mm512_maskz_loadu_epi8(mask, longer + idx_longer);
// Blend the `row_(first|second|third|fourth)_subs_vec` into `current_vec`, picking the right source
// for every character in `longer_vec`. Before that, we need to permute the substitution vectors.
// Only the bottom 6 bits of a byte are used in VPERMB, so we don't even need to mask.
shuffled_first_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_first_subs_vec.zmm);
shuffled_second_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_second_subs_vec.zmm);
shuffled_third_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_third_subs_vec.zmm);
shuffled_fourth_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_fourth_subs_vec.zmm);
// To blend we can invoke three `_mm512_cmplt_epu8_mask`, but we can also achieve the same using
// the AND logical operation, checking the top two bits of every byte.
// Continuing this thought, we can use the VPTESTMB instruction to output the mask after the AND.
__mmask64 is_third_or_fourth = _mm512_mask_test_epi8_mask(mask, longer_vec.zmm, is_third_or_fourth_vec.zmm);
__mmask64 is_second_or_fourth =
_mm512_mask_test_epi8_mask(mask, longer_vec.zmm, is_second_or_fourth_vec.zmm);
lookup_substitution_vec.zmm = _mm512_mask_blend_epi8(
is_third_or_fourth,
// Choose between the first and the second.
_mm512_mask_blend_epi8(is_second_or_fourth, shuffled_first_subs_vec.zmm, shuffled_second_subs_vec.zmm),
// Choose between the third and the fourth.
_mm512_mask_blend_epi8(is_second_or_fourth, shuffled_third_subs_vec.zmm, shuffled_fourth_subs_vec.zmm));
// First, sign-extend the lower and the upper 32 bytes to 16-bit integers.
__m512i current_0_31_vec = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(lookup_substitution_vec.zmm, 0));
__m512i current_32_63_vec = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(lookup_substitution_vec.zmm, 1));
// Now extend those 16-bit integers to 32-bit.
// This isn't free, same as the subsequent store, so we only want to do that for the populated lanes.
// To minimize the number of loads and stores, we can combine our substitution costs with the previous
// distances, containing the deletion costs.
// Export the values from 0 to 15. Casting the 64-bit mask to `__mmask16` keeps its lowest 16 bits.
{
cost_substitution_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer);
cost_substitution_vec.zmm = _mm512_add_epi32(
cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_0_31_vec, 0)));
cost_deletion_vec.zmm = _mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer);
cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm);
current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm);
// Inclusive prefix maximum computation to combine with insertion costs.
// Simply disabling this operation results in 5x performance improvement, meaning
// that this operation is responsible for 80% of the total runtime.
// for (sz_size_t idx_longer = 0; idx_longer < longer_length; ++idx_longer) {
// current_distances[idx_longer + 1] =
// sz_max_of_two(current_distances[idx_longer] + gap, current_distances[idx_longer + 1]);
// }
//
// To perform the same operation in vectorized form, we need to perform a tree-like reduction,
// that will involve multiple steps. It's quite expensive and should be first tested in the
// "experimental" section.
//
// Another approach might be loop unrolling:
// current_vec.i32s[0] = last_in_row = sz_i32_max_of_two(current_vec.i32s[0], last_in_row + gap);
// current_vec.i32s[1] = last_in_row = sz_i32_max_of_two(current_vec.i32s[1], last_in_row + gap);
// current_vec.i32s[2] = last_in_row = sz_i32_max_of_two(current_vec.i32s[2], last_in_row + gap);
// ... yet this approach is also quite expensive.
// The scalar loop walks all 16 lanes, but the masked store below only writes the populated ones.
for (int i = 0; i != 16; ++i)
current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
_mm512_mask_storeu_epi32(current_distances + idx_longer + 1, (__mmask16)mask, current_vec.zmm);
}
// Export the values from 16 to 31.
if (register_length > 16) {
// Shift the mask so that its lowest 16 bits describe the next group of 16 lanes.
mask = _kshiftri_mask64(mask, 16);
cost_substitution_vec.zmm =
_mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 16);
cost_substitution_vec.zmm = _mm512_add_epi32(
cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_0_31_vec, 1)));
cost_deletion_vec.zmm =
_mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 16);
cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm);
current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm);
// Aggregate running insertion costs within the register.
for (int i = 0; i != 16; ++i)
current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
_mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 16, (__mmask16)mask, current_vec.zmm);
}
// Export the values from 32 to 47.
if (register_length > 32) {
mask = _kshiftri_mask64(mask, 16);
cost_substitution_vec.zmm =
_mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 32);
cost_substitution_vec.zmm = _mm512_add_epi32(
cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_32_63_vec, 0)));
cost_deletion_vec.zmm =
_mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 32);
cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm);
current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm);
// Aggregate running insertion costs within the register.
for (int i = 0; i != 16; ++i)
current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
_mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 32, (__mmask16)mask, current_vec.zmm);
}
// Export the values from 48 to 63.
if (register_length > 48) {
mask = _kshiftri_mask64(mask, 16);
cost_substitution_vec.zmm =
_mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + idx_longer + 48);
cost_substitution_vec.zmm = _mm512_add_epi32(
cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_32_63_vec, 1)));
cost_deletion_vec.zmm =
_mm512_maskz_loadu_epi32((__mmask16)mask, previous_distances + 1 + idx_longer + 48);
cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm);
current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm);
// Aggregate running insertion costs within the register.
for (int i = 0; i != 16; ++i)
current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
_mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 48, (__mmask16)mask, current_vec.zmm);
}
}
// Swap previous_distances and current_distances pointers
sz_pointer_swap((void **)&previous_distances, (void **)&current_distances);
}
// After the final swap the last computed row is `previous_distances`; its last cell is the score.
// Cache scalar before `free` call.
sz_ssize_t result = previous_distances[longer_length];
alloc->free(distances, buffer_length, alloc->handle);
return result;
}
/*  Entry point for the AVX-512 alignment score: picks between the vectorized and the serial kernel
 *  based on the length of the longest input. All arguments are forwarded untouched.
 */
SZ_INTERNAL sz_ssize_t sz_alignment_score_avx512( //
    sz_cptr_t shorter, sz_size_t shorter_length,  //
    sz_cptr_t longer, sz_size_t longer_length,    //
    sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) {
    // The vectorized kernel is only used when both lengths stay below 256^3.
    sz_size_t const longest_length = sz_max_of_two(shorter_length, longer_length);
    if (longest_length >= (256ull * 256ull * 256ull))
        return sz_alignment_score_serial(shorter, shorter_length, longer, longer_length, subs, gap, alloc);
    return _sz_alignment_score_wagner_fisher_upto17m_avx512(shorter, shorter_length, longer, longer_length, subs,
                                                            gap, alloc);
}
/*  Identifiers for the encodings that the detection logic distinguishes.
 *  Every enumerator carries an explicit, unique value. Previously `sz_jwt_k` and `sz_base64_k` were
 *  numbered implicitly after `sz_encoding_utf32_k`, which made them equal to 5 and 6 and thus
 *  indistinguishable from `sz_encoding_utf8bom_k` and `sz_encoding_utf16le_k`.
 */
enum sz_encoding_t {
    sz_encoding_unknown_k = 0,
    sz_encoding_ascii_k = 1,
    sz_encoding_utf8_k = 2,
    sz_encoding_utf16_k = 3,
    sz_encoding_utf32_k = 4,
    // Low priority encodings:
    sz_encoding_utf8bom_k = 5,
    sz_encoding_utf16le_k = 6,
    sz_encoding_utf16be_k = 7,
    sz_encoding_utf32le_k = 8,
    sz_encoding_utf32be_k = 9,
    // Numbered past the encodings above, so that no two enumerators share a value:
    sz_jwt_k = 10,
    sz_base64_k = 11,
};
// Character Set Detection is one of the most commonly performed operations in data processing with
// [Chardet](https://github.com/chardet/chardet), [Charset Normalizer](https://github.com/jawah/charset_normalizer),
// [cChardet](https://github.com/PyYoshi/cChardet) being the most commonly used options in the Python ecosystem.
// All of them are notoriously slow.
//
// Moreover, as of October 2024, UTF-8 is the dominant character encoding on the web, used by 98.4% of websites.
// Others have minimal usage, according to [W3Techs](https://w3techs.com/technologies/overview/character_encoding):
// - ISO-8859-1: 1.2%
// - Windows-1252: 0.3%
// - Windows-1251: 0.2%
// - EUC-JP: 0.1%
// - Shift JIS: 0.1%
// - EUC-KR: 0.1%
// - GB2312: 0.1%
// - Windows-1250: 0.1%
// Within programming language implementations and database management systems, 16-bit and 32-bit fixed-width encodings
// are also very popular, and we need a way to efficiently differentiate between the most common UTF flavors, ASCII,
// and the rest.
//
// One good solution is the [simdutf](https://github.com/simdutf/simdutf) library, but it depends on the C++ runtime
// and focuses more on incremental validation & transcoding, rather than detection.
//
// So we need a very fast and efficient way of determining the most likely encoding of a given text buffer.
SZ_PUBLIC sz_bool_t sz_detect_encoding(sz_cptr_t text, sz_size_t length) {
    // Placeholder: no detection is performed yet. The arguments and the candidate flags are only marked
    // as unused, and the function always reports `sz_false_k`.
    //
    // Reference implementations of UTF-8 validation and transcoding in simdutf:
    // https://github.com/simdutf/simdutf/blob/master/src/icelake/icelake_utf8_validation.inl.cpp
    // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_from_utf8.inl.cpp#L81
    // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_utf8_common.inl.cpp#L661
    // https://github.com/simdutf/simdutf/blob/603070affe68101e9e08ea2de19ea5f3f154cf5d/src/icelake/icelake_utf8_common.inl.cpp#L788
    //
    // The planned approach is simpler: assume that continuous chunks of memory mostly share one encoding.
    // Russian and many European languages are dominated by 2-byte codepoints with occasional 1-byte
    // punctuation marks; Chinese, Japanese, and Korean by 3-byte codepoints; emojis by 4-byte codepoints.
    // Misaligned reads are also quite cheap on modern CPUs, which the implementation can exploit.
    int can_be_ascii = 1;
    int can_be_utf8 = 1;
    int can_be_utf16 = 1;
    int can_be_utf32 = 1;
    sz_unused(can_be_ascii);
    sz_unused(can_be_utf8);
    sz_unused(can_be_utf16);
    sz_unused(can_be_utf32);
    sz_unused(text);
    sz_unused(length);
    return sz_false_k;
}
#pragma clang attribute pop
#pragma GCC pop_options
#endif
#pragma endregion
/* @brief Implementation of the string search algorithms using the Arm NEON instruction set, available on 64-bit
* Arm processors. Implements: {substring search, character search, character set search} x {forward, reverse}.
*/
#pragma region ARM NEON
#if SZ_USE_ARM_NEON
#pragma GCC push_options
#pragma GCC target("arch=armv8.2-a+simd")
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function)
/**
* @brief Helper union to view the same 128 bits either as one NEON vector of 8-, 16-, 32-, or 64-bit
* unsigned lanes, or as a plain array of those lanes for scalar access.
*/
typedef union sz_u128_vec_t {
uint8x16_t u8x16; // 16 lanes of 8 bits
uint16x8_t u16x8; // 8 lanes of 16 bits
uint32x4_t u32x4; // 4 lanes of 32 bits
uint64x2_t u64x2; // 2 lanes of 64 bits
sz_u64_t u64s[2];
sz_u32_t u32s[4];
sz_u16_t u16s[8];
sz_u8_t u8s[16];
} sz_u128_vec_t;
/*  Compresses a 16-byte comparison result into a 64-bit word with 4 bits per input byte, of which only
 *  the top bit of every nibble is kept. Serves the same purpose as `movemask` in SSE.
 *  https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
 */
SZ_INTERNAL sz_u64_t _sz_vreinterpretq_u8_u4(uint8x16_t vec) {
    // Shifting every 16-bit lane right by 4 and narrowing it to 8 bits leaves one nibble per source byte.
    uint16x8_t const as_u16x8 = vreinterpretq_u16_u8(vec);
    uint8x8_t const nibbles_u8x8 = vshrn_n_u16(as_u16x8, 4);
    sz_u64_t const nibbles = vget_lane_u64(vreinterpret_u64_u8(nibbles_u8x8), 0);
    return nibbles & 0x8888888888888888ull;
}
SZ_PUBLIC sz_ordering_t sz_order_neon(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) {
    //! There is no NEON-specific kernel here: the call is forwarded to the serial implementation.
    //! Before optimizing this, read the "Operations Not Worth Optimizing" in Contributions Guide:
    //! https://github.com/ashvardanian/StringZilla/blob/main/CONTRIBUTING.md#general-performance-observations
    sz_ordering_t const ordering = sz_order_serial(a, a_length, b, b_length);
    return ordering;
}
/*  Checks two buffers of the same `length` for byte-wise equality, 16 bytes per step,
 *  delegating the remaining bytes (fewer than 16) to the serial implementation.
 */
SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length) {
    sz_u128_vec_t a_vec, b_vec;
    while (length >= 16) {
        a_vec.u8x16 = vld1q_u8((sz_u8_t const *)a);
        b_vec.u8x16 = vld1q_u8((sz_u8_t const *)b);
        // Equal lanes become 0xFF, so the smallest lane is 255 only if all 16 bytes match.
        uint8x16_t lanes_equal = vceqq_u8(a_vec.u8x16, b_vec.u8x16);
        if (vminvq_u8(lanes_equal) != 255) return sz_false_k;
        a += 16, b += 16, length -= 16;
    }
    if (!length) return sz_true_k;
    return sz_equal_serial(a, b, length);
}
/*  Sums up all the bytes of `text` into a 64-bit integer, 16 bytes per step,
 *  delegating the remaining bytes (fewer than 16) to the serial implementation.
 */
SZ_PUBLIC sz_u64_t sz_checksum_neon(sz_cptr_t text, sz_size_t length) {
    uint64x2_t totals_vec = vdupq_n_u64(0);
    while (length >= 16) {
        uint8x16_t bytes_vec = vld1q_u8((sz_u8_t const *)text);
        // Three widening pairwise additions: 16 x 8-bit -> 8 x 16-bit -> 4 x 32-bit -> 2 x 64-bit.
        uint64x2_t widened_vec = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(bytes_vec)));
        totals_vec = vaddq_u64(totals_vec, widened_vec);
        text += 16, length -= 16;
    }
    // Fold the two 64-bit lanes into a scalar and add the leftover bytes.
    sz_u64_t checksum = vgetq_lane_u64(totals_vec, 0) + vgetq_lane_u64(totals_vec, 1);
    if (length) checksum += sz_checksum_serial(text, length);
    return checksum;
}
/*  Copies `length` bytes from `source` to `target`, 16 bytes per step,
 *  delegating the remaining bytes (fewer than 16) to the serial implementation.
 */
SZ_PUBLIC void sz_copy_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
    // In most cases the `source` and the `target` are not aligned. An alternative design would copy a serial
    // head until the `target` is 64-byte aligned, move 64 bytes per step with `vld1q_u8_x4` and `vst4q_u8`,
    // and finish with a serial tail. According to the original author's measurements, that variant ends up
    // being 20% slower than processing 16 bytes at a time, so the simple loop is used.
    while (length >= 16) {
        uint8x16_t chunk_vec = vld1q_u8((sz_u8_t const *)source);
        vst1q_u8((sz_u8_t *)target, chunk_vec);
        target += 16, source += 16, length -= 16;
    }
    if (length) sz_copy_serial(target, source, length);
}
/*  Copies `length` bytes from `source` to `target`, where the two ranges are allowed to overlap.
 *  The traversal direction is chosen so that no source byte is overwritten before it is read.
 */
SZ_PUBLIC void sz_move_neon(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
    // Only a `target` that starts inside of the `source` range requires walking from the end.
    int const target_starts_inside_source = target >= source && target < source + length;
    if (!target_starts_inside_source) {
        sz_copy_neon(target, source, length);
        return;
    }
    // Start right past the end of both ranges and walk back, 16 bytes per step.
    target += length;
    source += length;
    sz_u128_vec_t chunk_vec;
    for (; length >= 16; length -= 16) {
        target -= 16, source -= 16;
        chunk_vec.u8x16 = vld1q_u8((sz_u8_t const *)source);
        vst1q_u8((sz_u8_t *)target, chunk_vec.u8x16);
    }
    // The remaining bytes (fewer than 16) are moved one at a time, still backwards.
    for (; length; length -= 1) {
        target -= 1, source -= 1;
        *target = *source;
    }
}
/*  Overwrites `length` bytes at `target` with `value`, 16 bytes per step,
 *  delegating the remaining bytes (fewer than 16) to the serial implementation.
 */
SZ_PUBLIC void sz_fill_neon(sz_ptr_t target, sz_size_t length, sz_u8_t value) {
    uint8x16_t const pattern_vec = vdupq_n_u8(value); // The same byte in all 16 lanes
    for (; length >= 16; target += 16, length -= 16) vst1q_u8((sz_u8_t *)target, pattern_vec);
    if (length) sz_fill_serial(target, length, value);
}
/*  Maps every byte of `source` through the 256-entry `lut` and writes the results into `target`.
 *
 *  @param source  Input bytes.
 *  @param length  Number of bytes to transform.
 *  @param lut     Look-up table; 256 bytes are read from it.
 *  @param target  Output buffer of `length` bytes.
 *
 *  Inputs of 128 bytes or fewer are forwarded to the serial implementation. Longer inputs are split into
 *  a serial head (until `target` is 16-byte aligned), a vectorized body, and a serial tail.
 */
SZ_PUBLIC void sz_look_up_transform_neon(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) {
    // If the input is tiny (especially smaller than the look-up table itself), we may end up paying
    // more for organizing the SIMD registers and changing the CPU state, than for the actual computation.
    if (length <= 128) {
        sz_look_up_transform_serial(source, length, lut, target);
        return;
    }

    sz_size_t head_length = (16 - ((sz_size_t)target % 16)) % 16; // 15 or less.
    sz_size_t tail_length = (sz_size_t)(target + length) % 16;    // 15 or less.
    // Everything between the head and the tail starts and ends on a 16-byte boundary of `target`, so this is
    // a multiple of 16. It can't underflow, as `length` exceeds 128 here. It has to be computed before the
    // head loop below, which counts `head_length` down to zero.
    sz_size_t body_length = length - head_length - tail_length;

    // We need to pull the lookup table into 16x NEON registers. We have a total of 32 such registers.
    // According to the Neoverse V2 manual, the 4-table lookup has a latency of 6 cycles, and 4x throughput.
    uint8x16x4_t lut_0_to_63_vec, lut_64_to_127_vec, lut_128_to_191_vec, lut_192_to_255_vec;
    lut_0_to_63_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 0));
    lut_64_to_127_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 64));
    lut_128_to_191_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 128));
    lut_192_to_255_vec = vld1q_u8_x4((sz_u8_t const *)(lut + 192));

    sz_u128_vec_t source_vec;
    sz_u128_vec_t lookup_0_to_63_vec, lookup_64_to_127_vec, lookup_128_to_191_vec, lookup_192_to_255_vec;
    sz_u128_vec_t blended_0_to_255_vec;

    // Process the head with serial code
    for (; head_length; target += 1, source += 1, head_length -= 1) *target = lut[*(sz_u8_t const *)source];

    // Table lookups on Arm are much simpler to use than on x86, as we can use the `vqtbl4q_u8` instruction
    // to perform a 4-table lookup in a single instruction. The XORs are used to adjust the lookup position
    // within each 64-byte range of the table: only the quarter that a byte belongs to receives an index
    // under 64, and the out-of-range indices in the other three lookups are ORed in as zeros.
    // Details on the 4-table lookup: https://lemire.me/blog/2019/07/23/arbitrary-byte-to-byte-maps-using-arm-neon/
    for (; body_length >= 16; source += 16, target += 16, body_length -= 16) {
        source_vec.u8x16 = vld1q_u8((sz_u8_t const *)source);
        lookup_0_to_63_vec.u8x16 = vqtbl4q_u8(lut_0_to_63_vec, source_vec.u8x16);
        lookup_64_to_127_vec.u8x16 = vqtbl4q_u8(lut_64_to_127_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0x40)));
        lookup_128_to_191_vec.u8x16 = vqtbl4q_u8(lut_128_to_191_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0x80)));
        lookup_192_to_255_vec.u8x16 = vqtbl4q_u8(lut_192_to_255_vec, veorq_u8(source_vec.u8x16, vdupq_n_u8(0xc0)));
        blended_0_to_255_vec.u8x16 = vorrq_u8(vorrq_u8(lookup_0_to_63_vec.u8x16, lookup_64_to_127_vec.u8x16),
                                              vorrq_u8(lookup_128_to_191_vec.u8x16, lookup_192_to_255_vec.u8x16));
        vst1q_u8((sz_u8_t *)target, blended_0_to_255_vec.u8x16);
    }

    // Process the tail with serial code
    for (; tail_length; target += 1, source += 1, tail_length -= 1) *target = lut[*(sz_u8_t const *)source];
}
/*  Locates the first occurrence of the byte `*n` in the haystack `h`, 16 bytes per step,
 *  delegating the remaining bytes (fewer than 16) to the serial implementation.
 */
SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
    sz_u128_vec_t haystack_vec, needle_vec, equal_vec;
    needle_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n);
    for (; h_length >= 16; h += 16, h_length -= 16) {
        haystack_vec.u8x16 = vld1q_u8((sz_u8_t const *)h);
        equal_vec.u8x16 = vceqq_u8(haystack_vec.u8x16, needle_vec.u8x16);
        // In Arm NEON we don't have a `movemask`, so the comparison result is compressed into 4 bits
        // per byte. The number of trailing zeros divided by 4 is then the offset of the first match.
        sz_u64_t const nibbles = _sz_vreinterpretq_u8_u4(equal_vec.u8x16);
        if (nibbles) return h + sz_u64_ctz(nibbles) / 4;
    }
    return sz_find_byte_serial(h, h_length, n);
}
/*  Locates the last occurrence of the byte `*n` in the haystack `h`, walking from the end 16 bytes
 *  per step, and delegating the leading bytes (fewer than 16) to the serial implementation.
 */
SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
    sz_u128_vec_t haystack_vec, needle_vec, equal_vec;
    needle_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n);
    for (; h_length >= 16; h_length -= 16) {
        // The start of the haystack never moves: only its length shrinks, so we always load the last 16 bytes.
        haystack_vec.u8x16 = vld1q_u8((sz_u8_t const *)h + h_length - 16);
        equal_vec.u8x16 = vceqq_u8(haystack_vec.u8x16, needle_vec.u8x16);
        // With 4 bits per byte, the number of leading zeros divided by 4 is the distance from the last byte.
        sz_u64_t const nibbles = _sz_vreinterpretq_u8_u4(equal_vec.u8x16);
        if (nibbles) return h + h_length - 1 - sz_u64_clz(nibbles) / 4;
    }
    return sz_rfind_byte_serial(h, h_length, n);
}
/*  Tests 16 haystack bytes against a 256-bit set, passed as two 16-byte halves: `set_top_vec_u8x16` is
 *  addressed by byte indices 0..15 and `set_bottom_vec_u8x16` by 16..31. Returns 4 bits per haystack
 *  byte, non-zero for the bytes present in the set. The serial equivalent of the check is
 *  `(set_->_u8s[c >> 3] & (1u << (c & 7u))) != 0`.
 */
SZ_PUBLIC sz_u64_t _sz_find_charset_neon_register(sz_u128_vec_t h_vec, uint8x16_t set_top_vec_u8x16,
                                                  uint8x16_t set_bottom_vec_u8x16) {
    // The lowest 3 bits of a character select a bit within a byte of the set...
    uint8x16_t const bit_offsets_vec = vandq_u8(h_vec.u8x16, vdupq_n_u8(7));
    uint8x16_t const bit_masks_vec = vshlq_u8(vdupq_n_u8(1), vreinterpretq_s8_u8(bit_offsets_vec));
    // ... and the top 5 bits select the byte itself, so the values fall in [0; 32).
    uint8x16_t const byte_offsets_vec = vshrq_n_u8(h_vec.u8x16, 3);
    // The table lookup instruction in NEON replies to out-of-bound requests with zeros. That covers offsets
    // of 16 and above in the first lookup. In the second one, for offsets under 16, subtracting 16 wraps
    // around into [240; 256), which is out of bounds as well. So for every lane at most one of the lookups
    // is non-zero, and they can be safely merged with a bitwise OR.
    uint8x16_t const found_in_top_vec = vqtbl1q_u8(set_top_vec_u8x16, byte_offsets_vec);
    uint8x16_t const found_in_bottom_vec =
        vqtbl1q_u8(set_bottom_vec_u8x16, vsubq_u8(byte_offsets_vec, vdupq_n_u8(16)));
    uint8x16_t const set_bytes_vec = vorrq_u8(found_in_top_vec, found_in_bottom_vec);
    // Instead of a pure `vandq_u8`, `vtstq_u8` immediately broadcasts a match presence across each 8-bit lane.
    return _sz_vreinterpretq_u8_u4(vtstq_u8(set_bytes_vec, bit_masks_vec));
}
/*  Locates the first occurrence of the needle `n` in the haystack `h`.
 *  Returns `SZ_NULL_CHAR` for an empty needle or a needle longer than the haystack. Single-byte needles
 *  go to `sz_find_byte_neon`. Longer ones are scanned 16 candidate positions per step with a dedicated
 *  loop per needle length class, and whatever is left at the end is handed to `sz_find_serial`.
 */
SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
    // This almost never fires, but it's better to be safe than sorry.
    if (!n_length || h_length < n_length) return SZ_NULL_CHAR;
    if (n_length == 1) return sz_find_byte_neon(h, h_length, n);

    // Assuming how tiny the Arm NEON registers are, we should avoid internal branches at all costs.
    // That's why, for smaller needles, we use different loops.
    if (n_length == 2) {
        sz_u128_vec_t hay_first_vec, hay_second_vec, pin_first_vec, pin_second_vec, hits_vec;
        pin_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[0]);
        pin_second_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[1]);
        // Two loads, shifted by one byte, cover 17 bytes and check 16 candidate offsets per iteration.
        while (h_length >= 17) {
            hay_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 0));
            hay_second_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 1));
            hits_vec.u8x16 = vandq_u8(vceqq_u8(hay_first_vec.u8x16, pin_first_vec.u8x16),
                                      vceqq_u8(hay_second_vec.u8x16, pin_second_vec.u8x16));
            sz_u64_t const nibbles = _sz_vreinterpretq_u8_u4(hits_vec.u8x16);
            if (nibbles) return h + sz_u64_ctz(nibbles) / 4;
            h += 16, h_length -= 16;
        }
    }
    else if (n_length == 3) {
        sz_u128_vec_t hay_first_vec, hay_second_vec, hay_third_vec;
        sz_u128_vec_t pin_first_vec, pin_second_vec, pin_third_vec, hits_vec;
        pin_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[0]);
        pin_second_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[1]);
        pin_third_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[2]);
        // Same idea as for longer needles, but all three bytes are compared, so a hit needs no verification.
        while (h_length >= 18) {
            hay_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 0));
            hay_second_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 1));
            hay_third_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + 2));
            uint8x16_t const first_two_equal = vandq_u8(vceqq_u8(hay_first_vec.u8x16, pin_first_vec.u8x16),
                                                        vceqq_u8(hay_second_vec.u8x16, pin_second_vec.u8x16));
            hits_vec.u8x16 = vandq_u8(first_two_equal, vceqq_u8(hay_third_vec.u8x16, pin_third_vec.u8x16));
            sz_u64_t const nibbles = _sz_vreinterpretq_u8_u4(hits_vec.u8x16);
            if (nibbles) return h + sz_u64_ctz(nibbles) / 4;
            h += 16, h_length -= 16;
        }
    }
    else {
        // Pick the parts of the needle that are worth comparing.
        sz_size_t offset_first, offset_mid, offset_last;
        _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last);

        // Broadcast those characters into SIMD registers.
        sz_u128_vec_t hay_first_vec, hay_mid_vec, hay_last_vec;
        sz_u128_vec_t pin_first_vec, pin_mid_vec, pin_last_vec, hits_vec;
        pin_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_first]);
        pin_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_mid]);
        pin_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_last]);

        // Walk through the string, pre-filtering 16 candidate positions by the three chosen characters.
        while (h_length >= n_length + 16) {
            hay_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_first));
            hay_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_mid));
            hay_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_last));
            uint8x16_t const first_and_mid_equal = vandq_u8(vceqq_u8(hay_first_vec.u8x16, pin_first_vec.u8x16),
                                                            vceqq_u8(hay_mid_vec.u8x16, pin_mid_vec.u8x16));
            hits_vec.u8x16 = vandq_u8(first_and_mid_equal, vceqq_u8(hay_last_vec.u8x16, pin_last_vec.u8x16));
            // Every surviving candidate is verified with a full comparison, from the lowest offset up.
            // Clearing the lowest set bit moves on to the next candidate.
            sz_u64_t nibbles = _sz_vreinterpretq_u8_u4(hits_vec.u8x16);
            for (; nibbles; nibbles &= nibbles - 1) {
                int candidate_offset = sz_u64_ctz(nibbles) / 4;
                if (sz_equal(h + candidate_offset, n, n_length)) return h + candidate_offset;
            }
            h += 16, h_length -= 16;
        }
    }

    // Whatever is left is too short for another vectorized step.
    return sz_find_serial(h, h_length, n, n_length);
}
/*  Locates the last occurrence of the needle `n` in the haystack `h`.
 *  Returns `SZ_NULL_CHAR` for an empty needle or a needle longer than the haystack. Single-byte needles
 *  go to `sz_rfind_byte_neon`. Longer ones are scanned from the end, 16 candidate positions per step,
 *  and whatever is left at the front is handed to `sz_rfind_serial`.
 */
SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
    // This almost never fires, but it's better to be safe than sorry.
    if (!n_length || h_length < n_length) return SZ_NULL_CHAR;
    if (n_length == 1) return sz_rfind_byte_neon(h, h_length, n);

    // Pick the parts of the needle that are worth comparing.
    sz_size_t offset_first, offset_mid, offset_last;
    _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last);

    // Broadcast those characters into SIMD registers.
    sz_u128_vec_t hay_first_vec, hay_mid_vec, hay_last_vec;
    sz_u128_vec_t pin_first_vec, pin_mid_vec, pin_last_vec, hits_vec;
    pin_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_first]);
    pin_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_mid]);
    pin_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_last]);

    // The start of the haystack never moves: only its length shrinks by 16 after every step.
    while (h_length >= n_length + 16) {
        // The 16 candidates of this step are the last 16 positions where the needle still fits.
        sz_cptr_t const window = h + h_length - n_length - 16 + 1;
        hay_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(window + offset_first));
        hay_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(window + offset_mid));
        hay_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(window + offset_last));
        uint8x16_t const first_and_mid_equal = vandq_u8(vceqq_u8(hay_first_vec.u8x16, pin_first_vec.u8x16),
                                                        vceqq_u8(hay_mid_vec.u8x16, pin_mid_vec.u8x16));
        hits_vec.u8x16 = vandq_u8(first_and_mid_equal, vceqq_u8(hay_last_vec.u8x16, pin_last_vec.u8x16));

        // Will contain 4 bits per character. Candidates are verified from the highest offset down.
        sz_u64_t matches = _sz_vreinterpretq_u8_u4(hits_vec.u8x16);
        while (matches) {
            // The number of leading zeros divided by 4 is the distance from the last candidate position.
            int potential_offset = sz_u64_clz(matches) / 4;
            sz_cptr_t const candidate = h + h_length - n_length - potential_offset;
            if (sz_equal(candidate, n, n_length)) return candidate;
            sz_assert((matches & (1ull << (63 - potential_offset * 4))) != 0 &&
                      "The bit must be set before we squash it");
            // Drop the highest set bit to move on to the previous candidate.
            matches &= ~(1ull << (63 - potential_offset * 4));
        }
        h_length -= 16;
    }

    // Whatever is left is too short for another vectorized step.
    return sz_rfind_serial(h, h_length, n, n_length);
}
/*  Locates the first byte of the haystack `h` that belongs to `set`, 16 bytes per step,
 *  delegating the remaining bytes (fewer than 16) to the serial implementation.
 */
SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) {
    // The 32 bytes of the set are kept in two registers for the whole scan.
    uint8x16_t const set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]);
    uint8x16_t const set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]);
    sz_u128_vec_t haystack_vec;
    while (h_length >= 16) {
        haystack_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h));
        // With 4 bits per byte, the number of trailing zeros divided by 4 is the offset of the first match.
        sz_u64_t const nibbles = _sz_find_charset_neon_register(haystack_vec, set_top_vec_u8x16, set_bottom_vec_u8x16);
        if (nibbles) return h + sz_u64_ctz(nibbles) / 4;
        h += 16, h_length -= 16;
    }
    return sz_find_charset_serial(h, h_length, set);
}
/*  Locates the last byte of the haystack `h` that belongs to `set`, walking from the end 16 bytes
 *  per step, and delegating the leading bytes (fewer than 16) to the serial implementation.
 */
SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) {
    // The 32 bytes of the set are kept in two registers for the whole scan.
    uint8x16_t const set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]);
    uint8x16_t const set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]);
    sz_u128_vec_t haystack_vec;
    while (h_length >= 16) {
        // The start of the haystack never moves: only its length shrinks, so we always load the last 16 bytes.
        haystack_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h) + h_length - 16);
        // With 4 bits per byte, the number of leading zeros divided by 4 is the distance from the last byte.
        sz_u64_t const nibbles = _sz_find_charset_neon_register(haystack_vec, set_top_vec_u8x16, set_bottom_vec_u8x16);
        if (nibbles) return h + h_length - 1 - sz_u64_clz(nibbles) / 4;
        h_length -= 16;
    }
    return sz_rfind_charset_serial(h, h_length, set);
}
#pragma clang attribute pop
#pragma GCC pop_options
#endif // Arm Neon
#pragma endregion
/* @brief Implementation of the string search algorithms using the Arm SVE variable-length registers, available
* in Arm v9 processors.
*
* Implements:
* - memory: {copy, move, fill}
* - comparisons: {equal, order}
* - search: {substring, character, character set} x {forward, reverse}.
*/
#pragma region ARM SVE
#if SZ_USE_ARM_SVE
#pragma GCC push_options
#pragma GCC target("arch=armv8.2-a+sve")
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function)
/*  Overwrites `length` bytes at `target` with `value` using predicated SVE stores.
 *  Buffers that fit into one vector take a single masked store. Longer ones are split into a masked
 *  head that ends on a vector-length boundary of `target`, a body of full-vector stores, and a masked tail.
 */
SZ_PUBLIC void sz_fill_sve(sz_ptr_t target, sz_size_t length, sz_u8_t value) {
    svuint8_t const pattern_vec = svdup_u8(value);
    sz_size_t const step = svcntb(); // Vector length in bytes (scalable)

    // Small buffer case: a single store, with only the first `length` lanes enabled.
    if (length <= step) {
        svst1_u8(svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)length), (unsigned char *)target, pattern_vec);
        return;
    }

    // The head is a whole vector when `target` is already aligned, and 1 to `step - 1` bytes otherwise.
    sz_size_t const head_length = step - ((sz_size_t)target % step);
    sz_size_t const tail_length = (sz_size_t)(target + length) % step;
    sz_size_t body_length = length - head_length - tail_length;

    // Unaligned head
    svst1_u8(svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)head_length), (unsigned char *)target, pattern_vec);
    target += head_length;

    // Aligned body
    while (body_length >= step) {
        svst1_u8(svptrue_b8(), (unsigned char *)target, pattern_vec);
        target += step, body_length -= step;
    }

    // Unaligned tail: no lanes are enabled if `tail_length` is zero.
    svst1_u8(svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)tail_length), (unsigned char *)target, pattern_vec);
}
/**
 *  @brief  Copies `length` bytes from `source` to `target` using Arm SVE predicated loads and stores.
 *
 *  Buffers no longer than one vector are handled with a single masked load/store pair. Longer buffers
 *  are split into a head (up to the next 16-byte boundary of `target`), a body that is walked from both
 *  ends at once, and a tail. The tail is stored before the body is read, so an overlapping `target`
 *  can clobber bytes of `source` that have not been read yet.
 *
 *  @param  target  Destination buffer, at least `length` bytes long.
 *  @param  source  Source buffer, at least `length` bytes long.
 *  @param  length  Number of bytes to copy.
 */
SZ_PUBLIC void sz_copy_sve(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
    sz_size_t vec_len = svcntb(); // Vector length in bytes

    // Arm Neoverse V2 cores in Graviton 4, for example, come with 256 KB of L1 data cache per core,
    // and 8 MB of L2 cache per core. Moreover, the L1 cache is fully associative.
    // With two strings, we may consider the overall workload huge, if each exceeds 1 MB in length.
    //
    //      int is_huge = length >= 4ull * 1024ull * 1024ull;
    //
    // When the buffer is small, there isn't much to innovate.
    if (length <= vec_len) {
        // Small buffer case: use mask to handle small writes
        svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)length);
        svuint8_t data = svld1_u8(mask, (unsigned char *)source);
        svst1_u8(mask, (unsigned char *)target, data);
    }
    // When dealing with larger buffers, similar to AVX-512, we want minimize unaligned operations
    // and handle the head, body, and tail separately. We can also traverse the buffer in both directions
    // as Arm generally supports more simultaneous stores than x86 CPUs.
    //
    // For gigantic datasets, similar to AVX-512, non-temporal "loads" and "stores" can be used.
    // Sadly, if the register size (16 byte or larger) is smaller than a cache-line (64 bytes)
    // we will pay a huge penalty on loads, fetching the same content many times.
    // It may be better to allow caching (and subsequent eviction), in favor of using four-element
    // tuples, which will be guaranteed to be a multiple of a cache line.
    //
    // Another approach is to use the `LD4B` instructions, which will populate four registers at once.
    // This however, further decreases the performance from LibC-like 29 GB/s to 20 GB/s.
    else {
        // Calculating head, body, and tail sizes depends on the `vec_len`,
        // but it's runtime constant, and the modulo operation is expensive!
        // Instead we use the fact, that it's always a multiple of 128 bits or 16 bytes.
        sz_size_t head_length = 16 - ((sz_size_t)target % 16);
        sz_size_t tail_length = (sz_size_t)(target + length) % 16;
        sz_size_t body_length = length - head_length - tail_length;

        // Handle unaligned parts
        svbool_t head_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)head_length);
        svuint8_t head_data = svld1_u8(head_mask, (unsigned char *)source);
        svst1_u8(head_mask, (unsigned char *)target, head_data);
        svbool_t tail_mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)tail_length);
        svuint8_t tail_data = svld1_u8(tail_mask, (unsigned char *)source + head_length + body_length);
        svst1_u8(tail_mask, (unsigned char *)target + head_length + body_length, tail_data);
        target += head_length;
        source += head_length;

        // Aligned body loop, walking in two directions: every iteration consumes one vector
        // from the front and one from the back, so the pointers advance by a single `vec_len`
        // while the remaining length shrinks by two.
        for (; body_length >= vec_len * 2; target += vec_len, source += vec_len, body_length -= vec_len * 2) {
            svuint8_t forward_data = svld1_u8(svptrue_b8(), (unsigned char *)source);
            svuint8_t backward_data = svld1_u8(svptrue_b8(), (unsigned char *)source + body_length - vec_len);
            svst1_u8(svptrue_b8(), (unsigned char *)target, forward_data);
            svst1_u8(svptrue_b8(), (unsigned char *)target + body_length - vec_len, backward_data);
        }

        // Up to (vec_len * 2 - 1) bytes of data may be left in the body,
        // so we can unroll the last two optional loop iterations.
        if (body_length > vec_len) {
            // More than one vector remains, so this iteration moves exactly one full vector.
            svuint8_t data = svld1_u8(svptrue_b8(), (unsigned char *)source);
            svst1_u8(svptrue_b8(), (unsigned char *)target, data);
            // Advance by the number of bytes just copied. Stepping by the leftover length instead
            // would re-copy part of this vector and leave the last bytes of the body untouched
            // whenever the vector is wider than 16 bytes.
            body_length -= vec_len;
            source += vec_len;
            target += vec_len;
        }
        if (body_length) {
            // At most one vector is left; the predicate trims the access to `body_length` lanes.
            svbool_t mask = svwhilelt_b8((sz_u32_t)0ull, (sz_u32_t)body_length);
            svuint8_t data = svld1_u8(mask, (unsigned char *)source);
            svst1_u8(mask, (unsigned char *)target, data);
        }
    }
}
#pragma clang attribute pop
#pragma GCC pop_options
#endif // Arm SVE
#pragma endregion
/*
* @brief Pick the right implementation for the string search algorithms.
*/
#pragma region Compile - Time Dispatching
/**
 *  @brief  Returns the 64-bit value that `sz_hash_serial` produces for the `length` bytes at `ins`.
 *          No SIMD backend is selected here; the serial routine is always used.
 */
SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t ins, sz_size_t length) {
    return sz_hash_serial(ins, length);
}
/**
 *  @brief  Forwards `ins`, `length` and `outs` to `sz_tolower_serial`.
 *          No SIMD backend is selected here; the serial routine is always used.
 */
SZ_PUBLIC void sz_tolower(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) {
    sz_tolower_serial(ins, length, outs);
}
/**
 *  @brief  Forwards `ins`, `length` and `outs` to `sz_toupper_serial`.
 *          No SIMD backend is selected here; the serial routine is always used.
 */
SZ_PUBLIC void sz_toupper(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) {
    sz_toupper_serial(ins, length, outs);
}
/**
 *  @brief  Forwards `ins`, `length` and `outs` to `sz_toascii_serial`.
 *          No SIMD backend is selected here; the serial routine is always used.
 */
SZ_PUBLIC void sz_toascii(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) {
    sz_toascii_serial(ins, length, outs);
}
/**
 *  @brief  Returns the verdict of `sz_isascii_serial` for the `length` bytes at `ins`.
 *          No SIMD backend is selected here; the serial routine is always used.
 */
SZ_PUBLIC sz_bool_t sz_isascii(sz_cptr_t ins, sz_size_t length) {
    return sz_isascii_serial(ins, length);
}
/**
 *  @brief  Runs `sz_hashes` over the text with a window step of one, handing every hash to a callback
 *          that receives the `fingerprint` buffer (wrapped in a string view) as its handle.
 *          A dedicated callback is chosen when `fingerprint_bytes` has no more than one bit set.
 */
SZ_PUBLIC void sz_hashes_fingerprint(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_ptr_t fingerprint,
                                     sz_size_t fingerprint_bytes) {
    // The callbacks get the output buffer and its size through this view.
    sz_string_view_t sink = {fingerprint, fingerprint_bytes};

    // Clearing the lowest set bit leaves zero only when at most one bit was set.
    sz_bool_t size_is_pow2 = (sz_bool_t)((fingerprint_bytes & (fingerprint_bytes - 1)) == 0);

    // Fingerprinting has several known issues, the first being that the memory traversal order matters:
    // https://blog.stuffedcow.net/2015/08/pagewalk-coherence/
    // Power-of-two sizes are expected to be the common case.
    if (size_is_pow2 != sz_false_k) {
        sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_pow2_callback, &sink);
    }
    else {
        sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_non_pow2_callback, &sink);
    }
}
#if !SZ_DYNAMIC_DISPATCH
/**
 *  @brief  Compile-time dispatcher for the checksum routine, built only when `SZ_DYNAMIC_DISPATCH` is disabled.
 *          Forwards to the first enabled backend in the order AVX-512, AVX2, NEON, and otherwise to the serial one.
 *  @param  text    Start of the input buffer, passed through unchanged.
 *  @param  length  Length argument, passed through unchanged.
 *  @return The value returned by the selected backend.
 */
SZ_DYNAMIC sz_u64_t sz_checksum(sz_cptr_t text, sz_size_t length) {
#if SZ_USE_X86_AVX512
return sz_checksum_avx512(text, length);
#elif SZ_USE_X86_AVX2
return sz_checksum_avx2(text, length);
#elif SZ_USE_ARM_NEON
return sz_checksum_neon(text, length);
#else
return sz_checksum_serial(text, length);
#endif
}
/**
 *  @brief  Compile-time dispatcher for the equality check, built only when `SZ_DYNAMIC_DISPATCH` is disabled.
 *          Forwards to the first enabled backend in the order AVX-512, AVX2, NEON, and otherwise to the serial one.
 *  @param  a       First buffer, passed through unchanged.
 *  @param  b       Second buffer, passed through unchanged.
 *  @param  length  Single length argument shared by both buffers.
 *  @return The value returned by the selected backend.
 */
SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length) {
#if SZ_USE_X86_AVX512
return sz_equal_avx512(a, b, length);
#elif SZ_USE_X86_AVX2
return sz_equal_avx2(a, b, length);
#elif SZ_USE_ARM_NEON
return sz_equal_neon(a, b, length);
#else
return sz_equal_serial(a, b, length);
#endif
}
/**
 *  @brief  Compile-time dispatcher for the ordering comparison, built only when `SZ_DYNAMIC_DISPATCH` is disabled.
 *          Forwards to the first enabled backend in the order AVX-512, AVX2, NEON, and otherwise to the serial one.
 *  @param  a, a_length  First buffer and its length, passed through unchanged.
 *  @param  b, b_length  Second buffer and its length, passed through unchanged.
 *  @return The `sz_ordering_t` value returned by the selected backend.
 */
SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) {
#if SZ_USE_X86_AVX512
return sz_order_avx512(a, a_length, b, b_length);
#elif SZ_USE_X86_AVX2
return sz_order_avx2(a, a_length, b, b_length);
#elif SZ_USE_ARM_NEON
return sz_order_neon(a, a_length, b, b_length);
#else
return sz_order_serial(a, a_length, b, b_length);
#endif
}
/**
 *  @brief  Compile-time dispatcher for the copy routine, built only when `SZ_DYNAMIC_DISPATCH` is disabled.
 *          Forwards to the first enabled backend in the order AVX-512, AVX2, NEON, and otherwise to the serial one.
 *          The SVE variant is not part of this selection chain.
 *  @param  target  Destination buffer, passed through unchanged.
 *  @param  source  Source buffer, passed through unchanged.
 *  @param  length  Length argument, passed through unchanged.
 */
SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
#if SZ_USE_X86_AVX512
sz_copy_avx512(target, source, length);
#elif SZ_USE_X86_AVX2
sz_copy_avx2(target, source, length);
#elif SZ_USE_ARM_NEON
sz_copy_neon(target, source, length);
#else
sz_copy_serial(target, source, length);
#endif
}
/**
 *  @brief  Compile-time dispatcher for the move routine, built only when `SZ_DYNAMIC_DISPATCH` is disabled.
 *          Forwards to the first enabled backend in the order AVX-512, AVX2, NEON, and otherwise to the serial one.
 *  @param  target  Destination buffer, passed through unchanged.
 *  @param  source  Source buffer, passed through unchanged.
 *  @param  length  Length argument, passed through unchanged.
 */
SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
#if SZ_USE_X86_AVX512
sz_move_avx512(target, source, length);
#elif SZ_USE_X86_AVX2
sz_move_avx2(target, source, length);
#elif SZ_USE_ARM_NEON
sz_move_neon(target, source, length);
#else
sz_move_serial(target, source, length);
#endif
}
/**
 *  @brief  Compile-time dispatcher for the fill routine, built only when `SZ_DYNAMIC_DISPATCH` is disabled.
 *          Forwards to the first enabled backend in the order AVX-512, AVX2, NEON, and otherwise to the serial one.
 *          The SVE variant is not part of this selection chain.
 *  @param  target  Destination buffer, passed through unchanged.
 *  @param  length  Length argument, passed through unchanged.
 *  @param  value   Byte value, passed through unchanged.
 */
SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value) {
#if SZ_USE_X86_AVX512
sz_fill_avx512(target, length, value);
#elif SZ_USE_X86_AVX2
sz_fill_avx2(target, length, value);
#elif SZ_USE_ARM_NEON
sz_fill_neon(target, length, value);
#else
sz_fill_serial(target, length, value);
#endif
}
/**
 *  @brief  Compile-time dispatcher for the lookup-table transform, built only when `SZ_DYNAMIC_DISPATCH`
 *          is disabled. Forwards to the first enabled backend in the order AVX-512, AVX2, NEON, and otherwise
 *          to the serial one.
 *  @param  source  Input buffer, passed through unchanged.
 *  @param  length  Length argument, passed through unchanged.
 *  @param  lut     Lookup table pointer, passed through unchanged.
 *  @param  target  Output buffer, passed through unchanged.
 */
SZ_DYNAMIC void sz_look_up_transform(sz_cptr_t source, sz_size_t length, sz_cptr_t lut, sz_ptr_t target) {
#if SZ_USE_X86_AVX512
sz_look_up_transform_avx512(source, length, lut, target);
#elif SZ_USE_X86_AVX2
sz_look_up_transform_avx2(source, length, lut, target);
#elif SZ_USE_ARM_NEON
sz_look_up_transform_neon(source, length, lut, target);
#else
sz_look_up_transform_serial(source, length, lut, target);
#endif
}
/**
 *  @brief  Compile-time dispatcher for the forward byte search, built only when `SZ_DYNAMIC_DISPATCH` is disabled.
 *          Forwards to the first enabled backend in the order AVX-512, AVX2, NEON, and otherwise to the serial one.
 *  @param  haystack  Buffer to search, passed through unchanged.
 *  @param  h_length  Length of the haystack, passed through unchanged.
 *  @param  needle    Pointer handed to the backend as the needle; no length accompanies it.
 *  @return The pointer returned by the selected backend.
 */
SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) {
#if SZ_USE_X86_AVX512
return sz_find_byte_avx512(haystack, h_length, needle);
#elif SZ_USE_X86_AVX2
return sz_find_byte_avx2(haystack, h_length, needle);
#elif SZ_USE_ARM_NEON
return sz_find_byte_neon(haystack, h_length, needle);
#else
return sz_find_byte_serial(haystack, h_length, needle);
#endif
}
/**
 *  @brief  Compile-time dispatcher for the reverse byte search, built only when `SZ_DYNAMIC_DISPATCH` is disabled.
 *          Forwards to the first enabled backend in the order AVX-512, AVX2, NEON, and otherwise to the serial one.
 *  @param  haystack  Buffer to search, passed through unchanged.
 *  @param  h_length  Length of the haystack, passed through unchanged.
 *  @param  needle    Pointer handed to the backend as the needle; no length accompanies it.
 *  @return The pointer returned by the selected backend.
 */
SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) {
#if SZ_USE_X86_AVX512
return sz_rfind_byte_avx512(haystack, h_length, needle);
#elif SZ_USE_X86_AVX2
return sz_rfind_byte_avx2(haystack, h_length, needle);
#elif SZ_USE_ARM_NEON
return sz_rfind_byte_neon(haystack, h_length, needle);
#else
return sz_rfind_byte_serial(haystack, h_length, needle);
#endif
}
/**
 *  @brief  Compile-time dispatcher for the forward substring search, built only when `SZ_DYNAMIC_DISPATCH`
 *          is disabled. Forwards to the first enabled backend in the order AVX-512, AVX2, NEON, and otherwise
 *          to the serial one.
 *  @param  haystack, h_length  Buffer to search and its length, passed through unchanged.
 *  @param  needle, n_length    Pattern and its length, passed through unchanged.
 *  @return The pointer returned by the selected backend.
 */
SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) {
#if SZ_USE_X86_AVX512
return sz_find_avx512(haystack, h_length, needle, n_length);
#elif SZ_USE_X86_AVX2
return sz_find_avx2(haystack, h_length, needle, n_length);
#elif SZ_USE_ARM_NEON
return sz_find_neon(haystack, h_length, needle, n_length);
#else
return sz_find_serial(haystack, h_length, needle, n_length);
#endif
}
/**
 *  @brief  Compile-time dispatcher for the reverse substring search, built only when `SZ_DYNAMIC_DISPATCH`
 *          is disabled. Forwards to the first enabled backend in the order AVX-512, AVX2, NEON, and otherwise
 *          to the serial one.
 *  @param  haystack, h_length  Buffer to search and its length, passed through unchanged.
 *  @param  needle, n_length    Pattern and its length, passed through unchanged.
 *  @return The pointer returned by the selected backend.
 */
SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) {
#if SZ_USE_X86_AVX512
return sz_rfind_avx512(haystack, h_length, needle, n_length);
#elif SZ_USE_X86_AVX2
return sz_rfind_avx2(haystack, h_length, needle, n_length);
#elif SZ_USE_ARM_NEON
return sz_rfind_neon(haystack, h_length, needle, n_length);
#else
return sz_rfind_serial(haystack, h_length, needle, n_length);
#endif
}
/**
 *  @brief  Compile-time dispatcher for the forward character-set search, built only when `SZ_DYNAMIC_DISPATCH`
 *          is disabled. Forwards to the first enabled backend in the order AVX-512, AVX2, NEON, and otherwise
 *          to the serial one.
 *  @param  text    Buffer to search, passed through unchanged.
 *  @param  length  Length of the buffer, passed through unchanged.
 *  @param  set     Character set to look for, passed through unchanged.
 *  @return The pointer returned by the selected backend.
 */
SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) {
#if SZ_USE_X86_AVX512
return sz_find_charset_avx512(text, length, set);
#elif SZ_USE_X86_AVX2
return sz_find_charset_avx2(text, length, set);
#elif SZ_USE_ARM_NEON
return sz_find_charset_neon(text, length, set);
#else
return sz_find_charset_serial(text, length, set);
#endif
}
/**
 *  @brief  Compile-time dispatcher for the reverse character-set search, built only when `SZ_DYNAMIC_DISPATCH`
 *          is disabled. Forwards to the first enabled backend in the order AVX-512, AVX2, NEON, and otherwise
 *          to the serial one.
 *  @param  text    Buffer to search, passed through unchanged.
 *  @param  length  Length of the buffer, passed through unchanged.
 *  @param  set     Character set to look for, passed through unchanged.
 *  @return The pointer returned by the selected backend.
 */
SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) {
#if SZ_USE_X86_AVX512
return sz_rfind_charset_avx512(text, length, set);
#elif SZ_USE_X86_AVX2
return sz_rfind_charset_avx2(text, length, set);
#elif SZ_USE_ARM_NEON
return sz_rfind_charset_neon(text, length, set);
#else
return sz_rfind_charset_serial(text, length, set);
#endif
}
/**
 *  @brief  Static-dispatch entry point for the Hamming distance, built only when `SZ_DYNAMIC_DISPATCH` is disabled.
 *          No SIMD backend is selected: all arguments go straight to `sz_hamming_distance_serial`.
 *  @return The value returned by `sz_hamming_distance_serial`.
 */
SZ_DYNAMIC sz_size_t sz_hamming_distance( //
sz_cptr_t a, sz_size_t a_length, //
sz_cptr_t b, sz_size_t b_length, //
sz_size_t bound) {
return sz_hamming_distance_serial(a, a_length, b, b_length, bound);
}
/**
 *  @brief  Static-dispatch entry point for the UTF-8 Hamming distance, built only when `SZ_DYNAMIC_DISPATCH`
 *          is disabled. No SIMD backend is selected: all arguments go straight to
 *          `sz_hamming_distance_utf8_serial`.
 *  @return The value returned by `sz_hamming_distance_utf8_serial`.
 */
SZ_DYNAMIC sz_size_t sz_hamming_distance_utf8( //
sz_cptr_t a, sz_size_t a_length, //
sz_cptr_t b, sz_size_t b_length, //
sz_size_t bound) {
return sz_hamming_distance_utf8_serial(a, a_length, b, b_length, bound);
}
/**
 *  @brief  Compile-time dispatcher for the edit distance, built only when `SZ_DYNAMIC_DISPATCH` is disabled.
 *          Uses the AVX-512 backend when `SZ_USE_X86_AVX512` is set and the serial one otherwise;
 *          AVX2 and NEON are not part of this selection chain.
 *  @param  a, a_length  First string and its length, passed through unchanged.
 *  @param  b, b_length  Second string and its length, passed through unchanged.
 *  @param  bound        Bound argument, passed through unchanged.
 *  @param  alloc        Memory allocator handed to the backend.
 *  @return The value returned by the selected backend.
 */
SZ_DYNAMIC sz_size_t sz_edit_distance( //
sz_cptr_t a, sz_size_t a_length, //
sz_cptr_t b, sz_size_t b_length, //
sz_size_t bound, sz_memory_allocator_t *alloc) {
#if SZ_USE_X86_AVX512
return sz_edit_distance_avx512(a, a_length, b, b_length, bound, alloc);
#else
return sz_edit_distance_serial(a, a_length, b, b_length, bound, alloc);
#endif
}
/**
 *  @brief  Static-dispatch entry point for the UTF-8 edit distance, built only when `SZ_DYNAMIC_DISPATCH`
 *          is disabled. Always calls `_sz_edit_distance_wagner_fisher_serial`, regardless of SIMD macros.
 *  @return The value returned by `_sz_edit_distance_wagner_fisher_serial`.
 */
SZ_DYNAMIC sz_size_t sz_edit_distance_utf8( //
sz_cptr_t a, sz_size_t a_length, //
sz_cptr_t b, sz_size_t b_length, //
sz_size_t bound, sz_memory_allocator_t *alloc) {
// `sz_true_k` is the sixth argument - presumably the flag enabling UTF-8 handling; confirm against the helper.
return _sz_edit_distance_wagner_fisher_serial(a, a_length, b, b_length, bound, sz_true_k, alloc);
}
/**
 *  @brief  Compile-time dispatcher for the alignment score, built only when `SZ_DYNAMIC_DISPATCH` is disabled.
 *          Uses the AVX-512 backend when `SZ_USE_X86_AVX512` is set and the serial one otherwise;
 *          AVX2 and NEON are not part of this selection chain.
 *  @param  a, a_length  First string and its length, passed through unchanged.
 *  @param  b, b_length  Second string and its length, passed through unchanged.
 *  @param  subs         Pointer to the substitution costs, passed through unchanged.
 *  @param  gap          Gap cost, passed through unchanged.
 *  @param  alloc        Memory allocator handed to the backend.
 *  @return The signed score returned by the selected backend.
 */
SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length,
sz_error_cost_t const *subs, sz_error_cost_t gap,
sz_memory_allocator_t *alloc) {
#if SZ_USE_X86_AVX512
return sz_alignment_score_avx512(a, a_length, b, b_length, subs, gap, alloc);
#else
return sz_alignment_score_serial(a, a_length, b, b_length, subs, gap, alloc);
#endif
}
/**
 *  @brief  Compile-time dispatcher for the windowed hashing routine, built only when `SZ_DYNAMIC_DISPATCH`
 *          is disabled. Forwards to the AVX-512 backend, then AVX2, and otherwise to the serial one;
 *          NEON is not part of this selection chain.
 *  @param  text, length     Input buffer and its length, passed through unchanged.
 *  @param  window_length    Window length, passed through unchanged.
 *  @param  window_step      Window step, passed through unchanged.
 *  @param  callback         Callback handed to the backend.
 *  @param  callback_handle  Opaque pointer handed to the backend alongside the callback.
 */
SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, //
sz_hash_callback_t callback, void *callback_handle) {
#if SZ_USE_X86_AVX512
sz_hashes_avx512(text, length, window_length, window_step, callback, callback_handle);
#elif SZ_USE_X86_AVX2
sz_hashes_avx2(text, length, window_length, window_step, callback, callback_handle);
#else
sz_hashes_serial(text, length, window_length, window_step, callback, callback_handle);
#endif
}
/**
 *  @brief  Builds a character set from the `n_length` bytes at `n` and returns the result of
 *          `sz_find_charset` for that set over the haystack `h`.
 */
SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
    sz_charset_t accepted;
    sz_charset_init(&accepted);
    // Register every needle byte in the set.
    while (n_length) {
        sz_charset_add(&accepted, *n);
        ++n;
        --n_length;
    }
    return sz_find_charset(h, h_length, &accepted);
}
/**
 *  @brief  Builds a character set from the `n_length` bytes at `n`, inverts it, and returns the result of
 *          `sz_find_charset` for the inverted set over the haystack `h`.
 */
SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
    sz_charset_t rejected;
    sz_charset_init(&rejected);
    // Register every needle byte in the set.
    while (n_length) {
        sz_charset_add(&rejected, *n);
        ++n;
        --n_length;
    }
    // Flip the set so the search runs against the complement of the needle bytes.
    sz_charset_invert(&rejected);
    return sz_find_charset(h, h_length, &rejected);
}
/**
 *  @brief  Builds a character set from the `n_length` bytes at `n` and returns the result of
 *          `sz_rfind_charset` for that set over the haystack `h`.
 */
SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
    sz_charset_t accepted;
    sz_charset_init(&accepted);
    // Register every needle byte in the set.
    while (n_length) {
        sz_charset_add(&accepted, *n);
        ++n;
        --n_length;
    }
    return sz_rfind_charset(h, h_length, &accepted);
}
/**
 *  @brief  Builds a character set from the `n_length` bytes at `n`, inverts it, and returns the result of
 *          `sz_rfind_charset` for the inverted set over the haystack `h`.
 */
SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
    sz_charset_t rejected;
    sz_charset_init(&rejected);
    // Register every needle byte in the set.
    while (n_length) {
        sz_charset_add(&rejected, *n);
        ++n;
        --n_length;
    }
    // Flip the set so the search runs against the complement of the needle bytes.
    sz_charset_invert(&rejected);
    return sz_rfind_charset(h, h_length, &rejected);
}
/**
 *  @brief  Static-dispatch entry point for the string generator, built only when `SZ_DYNAMIC_DISPATCH`
 *          is disabled. No SIMD backend is selected: all arguments go straight to `sz_generate_serial`.
 *  @param  alphabet, alphabet_size  Alphabet buffer and its size, passed through unchanged.
 *  @param  result, result_length    Output buffer and its length, passed through unchanged.
 *  @param  generator                Random generator callback handed to the serial routine.
 *  @param  generator_user_data      Opaque pointer handed to the serial routine alongside the generator.
 */
SZ_DYNAMIC void sz_generate(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length,
sz_random_generator_t generator, void *generator_user_data) {
sz_generate_serial(alphabet, alphabet_size, result, result_length, generator, generator_user_data);
}
#endif
#pragma endregion
#ifdef __cplusplus
#pragma GCC diagnostic pop
}
#endif // __cplusplus
#endif // STRINGZILLA_H_