From: Vsevolod Stakhov Date: Fri, 9 Feb 2024 12:29:12 +0000 (+0000) Subject: [Feature] Add stringzilla library for faster strings operations X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=bcd768142090471eabbc254f469bba711bc11dde;p=thirdparty%2Frspamd.git [Feature] Add stringzilla library for faster strings operations --- diff --git a/CMakeLists.txt b/CMakeLists.txt index 75455aae86..955871c01c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -125,6 +125,7 @@ INCLUDE_DIRECTORIES("${CMAKE_SOURCE_DIR}/" "${CMAKE_SOURCE_DIR}/contrib/lua-lpeg" "${CMAKE_SOURCE_DIR}/contrib/frozen/include" "${CMAKE_SOURCE_DIR}/contrib/fu2/include" + "${CMAKE_SOURCE_DIR}/contrib/stringzilla/include" "${CMAKE_BINARY_DIR}/src" #Stored in the binary dir "${CMAKE_BINARY_DIR}/src/libcryptobox") @@ -663,6 +664,13 @@ IF (ENABLE_LUA_REPL MATCHES "ON") LIST(APPEND RSPAMD_REQUIRED_LIBRARIES rspamd-replxx) ENDIF () +ADD_SUBDIRECTORY(contrib/stringzilla) +LIST(APPEND RSPAMD_REQUIRED_LIBRARIES rspamd-stringzilla) +# Propagate to all targets, as we use those in the includes +FOREACH (DEFINITION ${SZ_DEFINITIONS}) + ADD_DEFINITIONS(${DEFINITION}) +ENDFOREACH () + IF (ENABLE_SNOWBALL MATCHES "ON") LIST(APPEND RSPAMD_REQUIRED_LIBRARIES stemmer) ENDIF () diff --git a/contrib/DEPENDENCY_INFO.md b/contrib/DEPENDENCY_INFO.md index 300a38cee7..bd4e6742b2 100644 --- a/contrib/DEPENDENCY_INFO.md +++ b/contrib/DEPENDENCY_INFO.md @@ -38,4 +38,5 @@ | ankerl/svector | 1.0.2 | MIT | NO | | | ankerl/unordered_dense | 4.4.0 | MIT | NO | | | backward-cpp | 1.6 | MIT | NO | | +| stringzilla | 3.0.0 | Apache2 | NO | | diff --git a/contrib/stringzilla/CMakeLists.txt b/contrib/stringzilla/CMakeLists.txt new file mode 100644 index 0000000000..779b6faf5d --- /dev/null +++ b/contrib/stringzilla/CMakeLists.txt @@ -0,0 +1,17 @@ +SET(STRINGZILLASRC lib.c) + +SET(SZ_DEFINITIONS + "-DSZ_DYNAMIC_DISPATCH=1" + "-DSZ_USE_MISALIGNED_LOADS=1" + "-DSZ_USE_X86_AVX512=1" + "-DSZ_USE_X86_AVX2=1" + 
"-DSZ_USE_ARM_NEON=1" + "-DSZ_USE_ARM_SVE=1" + PARENT_SCOPE) + +FOREACH (DEFINITION ${SZ_DEFINITIONS}) + ADD_DEFINITIONS(${DEFINITION}) +ENDFOREACH () + +ADD_LIBRARY(rspamd-stringzilla STATIC ${STRINGZILLASRC}) +SET_TARGET_PROPERTIES(rspamd-stringzilla PROPERTIES VERSION ${RSPAMD_VERSION}) \ No newline at end of file diff --git a/contrib/stringzilla/LICENSE b/contrib/stringzilla/LICENSE new file mode 100644 index 0000000000..261eeb9e9f --- /dev/null +++ b/contrib/stringzilla/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/contrib/stringzilla/include/stringzilla/stringzilla.h b/contrib/stringzilla/include/stringzilla/stringzilla.h new file mode 100644 index 0000000000..283b0ca76d --- /dev/null +++ b/contrib/stringzilla/include/stringzilla/stringzilla.h @@ -0,0 +1,4778 @@ +/** + * @brief StringZilla is a collection of simple string algorithms, designed to be used in Big Data applications. + * It may be slower than LibC, but has a broader & cleaner interface, and a very short implementation + * targeting modern x86 CPUs with AVX-512 and Arm NEON and older CPUs with SWAR and auto-vectorization. + * + * Consider overriding the following macros to customize the library: + * + * - `SZ_DEBUG=0` - whether to enable debug assertions and logging. + * - `SZ_DYNAMIC_DISPATCH=0` - whether to use runtime dispatching of the most advanced SIMD backend. + * - `SZ_USE_MISALIGNED_LOADS=0` - whether to use misaligned loads on platforms that support them. + * - `SZ_SWAR_THRESHOLD=24` - threshold for switching to SWAR backend over serial byte-level for-loops. + * - `SZ_USE_X86_AVX512=?` - whether to use AVX-512 instructions on x86_64. + * - `SZ_USE_X86_AVX2=?` - whether to use AVX2 instructions on x86_64. + * - `SZ_USE_ARM_NEON=?` - whether to use NEON instructions on ARM. + * - `SZ_USE_ARM_SVE=?` - whether to use SVE instructions on ARM. 
+ * + * @see StringZilla: https://github.com/ashvardanian/StringZilla/blob/main/README.md + * @see LibC String: https://pubs.opengroup.org/onlinepubs/009695399/basedefs/string.h.html + * + * @file stringzilla.h + * @author Ash Vardanian + */ +#ifndef STRINGZILLA_H_ +#define STRINGZILLA_H_ + +#define STRINGZILLA_VERSION_MAJOR 3 +#define STRINGZILLA_VERSION_MINOR 0 +#define STRINGZILLA_VERSION_PATCH 0 + +/** + * @brief When set to 1, the library will include the following LibC headers: and . + * In debug builds (SZ_DEBUG=1), the library will also include and . + * + * You may want to disable this compiling for use in the kernel, or in embedded systems. + * You may also avoid them, if you are very sensitive to compilation time and avoid pre-compiled headers. + * https://artificial-mind.net/projects/compile-health/ + */ +#ifndef SZ_AVOID_LIBC +#define SZ_AVOID_LIBC (0) // true or false +#endif + +/** + * @brief A misaligned load can be - trying to fetch eight consecutive bytes from an address + * that is not divisible by eight. + * + * Most platforms support it, but there is no industry standard way to check for those. + * This value will mostly affect the performance of the serial (SWAR) backend. + */ +#ifndef SZ_USE_MISALIGNED_LOADS +#define SZ_USE_MISALIGNED_LOADS (0) // true or false +#endif + +/** + * @brief Removes compile-time dispatching, and replaces it with runtime dispatching. + * So the `sz_find` function will invoke the most advanced backend supported by the CPU, + * that runs the program, rather than the most advanced backend supported by the CPU + * used to compile the library or the downstream application. + */ +#ifndef SZ_DYNAMIC_DISPATCH +#define SZ_DYNAMIC_DISPATCH (0) // true or false +#endif + +/** + * @brief Analogous to `size_t` and `std::size_t`, unsigned integer, identical to pointer size. + * 64-bit on most platforms where pointers are 64-bit. + * 32-bit on platforms where pointers are 32-bit. 
+ */ +#if defined(__LP64__) || defined(_LP64) || defined(__x86_64__) || defined(_WIN64) +#define SZ_DETECT_64_BIT (1) +#define SZ_SIZE_MAX (0xFFFFFFFFFFFFFFFFull) // Largest unsigned integer that fits into 64 bits. +#define SZ_SSIZE_MAX (0x7FFFFFFFFFFFFFFFull) // Largest signed integer that fits into 64 bits. +#else +#define SZ_DETECT_64_BIT (0) +#define SZ_SIZE_MAX (0xFFFFFFFFu) // Largest unsigned integer that fits into 32 bits. +#define SZ_SSIZE_MAX (0x7FFFFFFFu) // Largest signed integer that fits into 32 bits. +#endif + +/* + * Debugging and testing. + */ +#ifndef SZ_DEBUG +#if defined(DEBUG) || defined(_DEBUG) // This means "Not using DEBUG information". +#define SZ_DEBUG (1) +#else +#define SZ_DEBUG (0) +#endif +#endif + +/* Annotation for the public API symbols: + * + * - `SZ_PUBLIC` is used for functions that are part of the public API. + * - `SZ_INTERNAL` is used for internal helper functions with unstable APIs. + * - `SZ_DYNAMIC` is used for functions that are part of the public API, but are dispatched at runtime. + */ +#ifndef SZ_DYNAMIC +#if SZ_DYNAMIC_DISPATCH +#if defined(_WIN32) || defined(__CYGWIN__) +#define SZ_DYNAMIC __declspec(dllexport) +#define SZ_PUBLIC inline static +#define SZ_INTERNAL inline static +#else +#define SZ_DYNAMIC __attribute__((visibility("default"))) +#define SZ_PUBLIC __attribute__((unused)) inline static +#define SZ_INTERNAL __attribute__((always_inline)) inline static +#endif // _WIN32 || __CYGWIN__ +#else +#define SZ_DYNAMIC inline static +#define SZ_PUBLIC inline static +#define SZ_INTERNAL inline static +#endif // SZ_DYNAMIC_DISPATCH +#endif // SZ_DYNAMIC + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Let's infer the integer types or pull them from LibC, + * if that is allowed by the user. 
+ */ +#if !SZ_AVOID_LIBC +#include // `size_t` +#include // `uint8_t` +typedef int8_t sz_i8_t; // Always 8 bits +typedef uint8_t sz_u8_t; // Always 8 bits +typedef uint16_t sz_u16_t; // Always 16 bits +typedef int32_t sz_i32_t; // Always 32 bits +typedef uint32_t sz_u32_t; // Always 32 bits +typedef uint64_t sz_u64_t; // Always 64 bits +typedef int64_t sz_i64_t; // Always 64 bits +typedef size_t sz_size_t; // Pointer-sized unsigned integer, 32 or 64 bits +typedef ptrdiff_t sz_ssize_t; // Signed version of `sz_size_t`, 32 or 64 bits + +#else // if SZ_AVOID_LIBC: + +typedef signed char sz_i8_t; // Always 8 bits +typedef unsigned char sz_u8_t; // Always 8 bits +typedef unsigned short sz_u16_t; // Always 16 bits +typedef int sz_i32_t; // Always 32 bits +typedef unsigned int sz_u32_t; // Always 32 bits +typedef long long sz_i64_t; // Always 64 bits +typedef unsigned long long sz_u64_t; // Always 64 bits + +#if SZ_DETECT_64_BIT +typedef unsigned long long sz_size_t; // 64-bit. +typedef long long sz_ssize_t; // 64-bit. +#else +typedef unsigned sz_size_t; // 32-bit. +typedef unsigned sz_ssize_t; // 32-bit. +#endif // SZ_DETECT_64_BIT + +#endif // SZ_AVOID_LIBC + +/** + * @brief Compile-time assert macro similar to `static_assert` in C++. + */ +#define sz_static_assert(condition, name) \ + typedef struct { \ + int static_assert_##name : (condition) ? 
1 : -1; \ + } sz_static_assert_##name##_t + +sz_static_assert(sizeof(sz_size_t) == sizeof(void *), sz_size_t_must_be_pointer_size); +sz_static_assert(sizeof(sz_ssize_t) == sizeof(void *), sz_ssize_t_must_be_pointer_size); + +#pragma region Public API + +typedef char *sz_ptr_t; // A type alias for `char *` +typedef char const *sz_cptr_t; // A type alias for `char const *` +typedef sz_i8_t sz_error_cost_t; // Character mismatch cost for fuzzy matching functions + +typedef enum { sz_false_k = 0, sz_true_k = 1 } sz_bool_t; // Only one relevant bit +typedef enum { sz_less_k = -1, sz_equal_k = 0, sz_greater_k = 1 } sz_ordering_t; // Only three possible states: <=> + +/** + * @brief Tiny string-view structure. It's POD type, unlike the `std::string_view`. + */ +typedef struct sz_string_view_t { + sz_cptr_t start; + sz_size_t length; +} sz_string_view_t; + +/** + * @brief Enumeration of SIMD capabilities of the target architecture. + * Used to introspect the supported functionality of the dynamic library. + */ +typedef enum sz_capability_t { + sz_cap_serial_k = 1, /// Serial (non-SIMD) capability + sz_cap_any_k = 0x7FFFFFFF, /// Mask representing any capability + + sz_cap_arm_neon_k = 1 << 10, /// ARM NEON capability + sz_cap_arm_sve_k = 1 << 11, /// ARM SVE capability TODO: Not yet supported or used + + sz_cap_x86_avx2_k = 1 << 20, /// x86 AVX2 capability + sz_cap_x86_avx512f_k = 1 << 21, /// x86 AVX512 F capability + sz_cap_x86_avx512bw_k = 1 << 22, /// x86 AVX512 BW instruction capability + sz_cap_x86_avx512vl_k = 1 << 23, /// x86 AVX512 VL instruction capability + sz_cap_x86_avx512vbmi_k = 1 << 24, /// x86 AVX512 VBMI instruction capability + sz_cap_x86_gfni_k = 1 << 25, /// x86 AVX512 GFNI instruction capability + +} sz_capability_t; + +/** + * @brief Function to determine the SIMD capabilities of the current machine @b only at @b runtime. + * @return A bitmask of the SIMD capabilities represented as a `sz_capability_t` enum value. 
+ */ +SZ_DYNAMIC sz_capability_t sz_capabilities(void); + +/** + * @brief Bit-set structure for 256 possible byte values. Useful for filtering and search. + * @see sz_charset_init, sz_charset_add, sz_charset_contains, sz_charset_invert + */ +typedef union sz_charset_t { + sz_u64_t _u64s[4]; + sz_u32_t _u32s[8]; + sz_u16_t _u16s[16]; + sz_u8_t _u8s[32]; +} sz_charset_t; + +/** @brief Initializes a bit-set to an empty collection, meaning - all characters are banned. */ +SZ_PUBLIC void sz_charset_init(sz_charset_t *s) { s->_u64s[0] = s->_u64s[1] = s->_u64s[2] = s->_u64s[3] = 0; } + +/** @brief Adds a character to the set and accepts @b unsigned integers. */ +SZ_PUBLIC void sz_charset_add_u8(sz_charset_t *s, sz_u8_t c) { s->_u64s[c >> 6] |= (1ull << (c & 63u)); } + +/** @brief Adds a character to the set. Consider @b sz_charset_add_u8. */ +SZ_PUBLIC void sz_charset_add(sz_charset_t *s, char c) { sz_charset_add_u8(s, *(sz_u8_t *)(&c)); } // bitcast + +/** @brief Checks if the set contains a given character and accepts @b unsigned integers. */ +SZ_PUBLIC sz_bool_t sz_charset_contains_u8(sz_charset_t const *s, sz_u8_t c) { + // Checking the bit can be done in different ways: + // - (s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0 + // - (s->_u32s[c >> 5] & (1u << (c & 31u))) != 0 + // - (s->_u16s[c >> 4] & (1u << (c & 15u))) != 0 + // - (s->_u8s[c >> 3] & (1u << (c & 7u))) != 0 + return (sz_bool_t)((s->_u64s[c >> 6] & (1ull << (c & 63u))) != 0); +} + +/** @brief Checks if the set contains a given character. Consider @b sz_charset_contains_u8. */ +SZ_PUBLIC sz_bool_t sz_charset_contains(sz_charset_t const *s, char c) { + return sz_charset_contains_u8(s, *(sz_u8_t *)(&c)); // bitcast +} + +/** @brief Inverts the contents of the set, so allowed character get disallowed, and vice versa. 
*/ +SZ_PUBLIC void sz_charset_invert(sz_charset_t *s) { + s->_u64s[0] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[1] ^= 0xFFFFFFFFFFFFFFFFull, // + s->_u64s[2] ^= 0xFFFFFFFFFFFFFFFFull, s->_u64s[3] ^= 0xFFFFFFFFFFFFFFFFull; +} + +typedef void *(*sz_memory_allocate_t)(sz_size_t, void *); +typedef void (*sz_memory_free_t)(void *, sz_size_t, void *); +typedef sz_u64_t (*sz_random_generator_t)(void *); + +/** + * @brief Some complex pattern matching algorithms may require memory allocations. + * This structure is used to pass the memory allocator to those functions. + * @see sz_memory_allocator_init_fixed + */ +typedef struct sz_memory_allocator_t { + sz_memory_allocate_t allocate; + sz_memory_free_t free; + void *handle; +} sz_memory_allocator_t; + +/** + * @brief Initializes a memory allocator to use the system default `malloc` and `free`. + * @param alloc Memory allocator to initialize. + */ +SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc); + +/** + * @brief Initializes a memory allocator to use a static-capacity buffer. + * No dynamic allocations will be performed. + * + * @param alloc Memory allocator to initialize. + * @param buffer Buffer to use for allocations. + * @param length Length of the buffer. @b Must be greater than 8 bytes. Different values would be optimal for + * different algorithms and input lengths, but 4096 bytes (one RAM page) is a good default. + */ +SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length); + +/** + * @brief The number of bytes a stack-allocated string can hold, including the SZ_NULL termination character. + * ! This can't be changed from outside. Don't use the `#error` as it may already be included and set. + */ +#ifdef SZ_STRING_INTERNAL_SPACE +#undef SZ_STRING_INTERNAL_SPACE +#endif +#define SZ_STRING_INTERNAL_SPACE (23) + +/** + * @brief Tiny memory-owning string structure with a Small String Optimization (SSO). 
+ * Differs in layout from Folly, Clang, GCC, and probably most other implementations. + * It's designed to avoid any branches on read-only operations, and can store up + * to 22 characters on stack, followed by the SZ_NULL-termination character. + * + * @section Changing Length + * + * One nice thing about this design, is that you can, in many cases, change the length of the string + * without any branches, invoking a `+=` or `-=` on the 64-bit `length` field. If the string is on heap, + * the solution is obvious. If it's on stack, inplace decrement wouldn't affect the top bytes of the string, + * only changing the last byte containing the length. + */ +typedef union sz_string_t { + + struct internal { + sz_ptr_t start; + sz_u8_t length; + char chars[SZ_STRING_INTERNAL_SPACE]; + } internal; + + struct external { + sz_ptr_t start; + sz_size_t length; + /// @brief Number of bytes, that have been allocated for this string, equals to (capacity + 1). + sz_size_t space; + sz_size_t padding; + } external; + + sz_u64_t u64s[4]; + +} sz_string_t; + +typedef sz_u64_t (*sz_hash_t)(sz_cptr_t, sz_size_t); +typedef sz_bool_t (*sz_equal_t)(sz_cptr_t, sz_cptr_t, sz_size_t); +typedef sz_ordering_t (*sz_order_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); +typedef void (*sz_to_converter_t)(sz_cptr_t, sz_size_t, sz_ptr_t); + +/** + * @brief Computes the 64-bit unsigned hash of a string. Fairly fast for short strings, + * simple implementation, and supports rolling computation, reused in other APIs. + * Similar to `std::hash` in C++. + * + * @param text String to hash. + * @param length Number of bytes in the text. + * @return 64-bit hash value. + * + * @see sz_hashes, sz_hashes_fingerprint, sz_hashes_intersection + */ +SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t text, sz_size_t length); + +/** @copydoc sz_hash */ +SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t text, sz_size_t length); + +/** + * @brief Checks if two string are equal. 
+ * Similar to `memcmp(a, b, length) == 0` in LibC and `a == b` in STL. + * + * The implementation of this function is very similar to `sz_order`, but the usage patterns are different. + * This function is more often used in parsing, while `sz_order` is often used in sorting. + * It works best on platforms with cheap + * + * @param a First string to compare. + * @param b Second string to compare. + * @param length Number of bytes in both strings. + * @return 1 if strings match, 0 otherwise. + */ +SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length); + +/** @copydoc sz_equal */ +SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length); + +/** + * @brief Estimates the relative order of two strings. Equivalent to `memcmp(a, b, length)` in LibC. + * Can be used on different length strings. + * + * @param a First string to compare. + * @param a_length Number of bytes in the first string. + * @param b Second string to compare. + * @param b_length Number of bytes in the second string. + * @return Negative if (a < b), positive if (a > b), zero if they are equal. + */ +SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); + +/** @copydoc sz_order */ +SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); + +/** + * @brief Equivalent to `for (char & c : text) c = tolower(c)`. + * + * ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122]. + * So there are 26 english letters, shifted by 32 values, meaning that a conversion + * can be done by flipping the 5th bit each inappropriate character byte. This, however, + * breaks for extended ASCII, so a different solution is needed. + * http://0x80.pl/notesen/2016-01-06-swar-swap-case.html + * + * @param text String to be normalized. + * @param length Number of bytes in the string. + * @param result Output string, can point to the same address as ::text. 
 */
SZ_PUBLIC void sz_tolower(sz_cptr_t text, sz_size_t length, sz_ptr_t result);

/**
 *  @brief  Equivalent to `for (char & c : text) c = toupper(c)`.
 *
 *  ASCII characters [A, Z] map to decimals [65, 90], and [a, z] map to [97, 122].
 *  So there are 26 english letters, shifted by 32 values, meaning that a conversion
 *  can be done by flipping the 5th bit of each inappropriate character byte. This, however,
 *  breaks for extended ASCII, so a different solution is needed.
 *  http://0x80.pl/notesen/2016-01-06-swar-swap-case.html
 *
 *  @param text     String to be normalized.
 *  @param length   Number of bytes in the string.
 *  @param result   Output string, can point to the same address as ::text.
 */
SZ_PUBLIC void sz_toupper(sz_cptr_t text, sz_size_t length, sz_ptr_t result);

/**
 *  @brief  Equivalent to `for (char & c : text) c = toascii(c)`.
 *
 *  @param text     String to be normalized.
 *  @param length   Number of bytes in the string.
 *  @param result   Output string, can point to the same address as ::text.
 */
SZ_PUBLIC void sz_toascii(sz_cptr_t text, sz_size_t length, sz_ptr_t result);

/**
 *  @brief  Generates a random string for a given alphabet, avoiding integer division and modulo operations.
 *          Similar to `text[i] = alphabet[rand() % cardinality]`.
 *
 *  The modulo operation is expensive, and should be avoided in performance-critical code.
 *  We avoid it using small lookup tables and replacing it with a multiplication and shifts, similar to `libdivide`.
 *  Alternative algorithms would include:
 *      - Montgomery form: https://en.algorithmica.org/hpc/number-theory/montgomery/
 *      - Barret reduction: https://www.nayuki.io/page/barrett-reduction-algorithm
 *      - Lemire's trick: https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
 *
 *  @param alphabet     Set of characters to sample from.
 *  @param cardinality  Number of characters to sample from.
 *  @param text         Output string, can point to the same address as ::text.
 *  @param length       Number of bytes in the output string.
 *  @param generate     Callback producing random numbers given the generator state.
 *  @param generator    Generator state, can be a pointer to a seed, or a pointer to a random number generator.
 */
SZ_PUBLIC void sz_generate(sz_cptr_t alphabet, sz_size_t cardinality, sz_ptr_t text, sz_size_t length,
                           sz_random_generator_t generate, void *generator);

/**
 *  @brief  Similar to `memcpy`, copies contents of one string into another.
 *          The behavior is undefined if the strings overlap.
 *
 *  @param target   String to copy into.
 *  @param length   Number of bytes to copy.
 *  @param source   String to copy from.
 */
SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length);

/** @copydoc sz_copy */
SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length);

/**
 *  @brief  Similar to `memmove`, copies (moves) contents of one string into another.
 *          Unlike `sz_copy`, allows overlapping strings as arguments.
 *
 *  @param target   String to copy into.
 *  @param length   Number of bytes to copy.
 *  @param source   String to copy from.
 */
SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length);

/** @copydoc sz_move */
SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length);

typedef void (*sz_move_t)(sz_ptr_t, sz_cptr_t, sz_size_t);

/**
 *  @brief  Similar to `memset`, fills a string with a given value.
 *
 *  @param target   String to fill.
 *  @param length   Number of bytes to fill.
 *  @param value    Value to fill with.
 */
SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value);

/** @copydoc sz_fill */
SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value);

typedef void (*sz_fill_t)(sz_ptr_t, sz_size_t, sz_u8_t);

/**
 *  @brief  Initializes a string class instance to an empty value.
 */
SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length);

/**
 *  @brief  Shrinks the string to fit the current length, if it's allocated on the heap.
 *          The reverse operation of ::sz_string_reserve.
 *
 *  @param string       String to shrink.
 *  @param allocator    Memory allocator to use for the allocation.
 *  @return             SZ_NULL if the operation failed, pointer to the new start of the string otherwise.
 *                      The only failures can come from the allocator.
 */
SZ_PUBLIC sz_ptr_t sz_string_shrink_to_fit(sz_string_t *string, sz_memory_allocator_t *allocator);

/**
 *  @brief  Frees the string, if it's allocated on the heap.
 *          If the string is on the stack, the function clears/resets the state.
 */
SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *allocator);

#pragma endregion

#pragma region Fast Substring Search API

typedef sz_cptr_t (*sz_find_byte_t)(sz_cptr_t, sz_size_t, sz_cptr_t);
typedef sz_cptr_t (*sz_find_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t);
typedef sz_cptr_t (*sz_find_set_t)(sz_cptr_t, sz_size_t, sz_charset_t const *);

/**
 *  @brief  Locates first matching byte in a string. Equivalent to `memchr(haystack, *needle, h_length)` in LibC.
 *
 *  X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memchr.S
 *  Aarch64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/aarch64/memchr.S
 *
 *  @param haystack Haystack - the string to search in.
 *  @param h_length Number of bytes in the haystack.
 *  @param needle   Needle - single-byte substring to find.
 *  @return         Address of the first match.
 */
SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);

/** @copydoc sz_find_byte */
SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);

/**
 *  @brief  Locates last matching byte in a string. Equivalent to `memrchr(haystack, *needle, h_length)` in LibC.
 *
 *  X86_64 implementation: https://github.com/lattera/glibc/blob/master/sysdeps/x86_64/memrchr.S
 *  Aarch64 implementation: missing
 *
 *  @param haystack Haystack - the string to search in.
 *  @param h_length Number of bytes in the haystack.
 *  @param needle   Needle - single-byte substring to find.
 *  @return         Address of the last match.
 */
SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);

/** @copydoc sz_rfind_byte */
SZ_PUBLIC sz_cptr_t sz_rfind_byte_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);

/**
 *  @brief  Locates first matching substring.
 *          Equivalent to `memmem(haystack, h_length, needle, n_length)` in LibC.
 *          Similar to `strstr(haystack, needle)` in LibC, but requires known length.
 *
 *  @param haystack Haystack - the string to search in.
 *  @param h_length Number of bytes in the haystack.
 *  @param needle   Needle - substring to find.
 *  @param n_length Number of bytes in the needle.
 *  @return         Address of the first match.
 */
SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length);

/** @copydoc sz_find */
SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length);

/**
 *  @brief  Locates the last matching substring.
 *
 *  @param haystack Haystack - the string to search in.
 *  @param h_length Number of bytes in the haystack.
 *  @param needle   Needle - substring to find.
 *  @param n_length Number of bytes in the needle.
 *  @return         Address of the last match.
 */
SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length);

/** @copydoc sz_rfind */
SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length);

/**
 *  @brief  Finds the first character present from the ::set, present in ::text.
 *          Related to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC,
 *          but unlike those, returns the match address rather than a prefix length.
 *          May have identical implementation and performance to ::sz_rfind_charset.
 *
 *  @param text     String to be scanned.
 *  @param length   Number of bytes in the string.
 *  @param set      Set of relevant characters.
 *  @return         Address of the first match.
 */
SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set);

/** @copydoc sz_find_charset */
SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set);

/**
 *  @brief  Finds the last character present from the ::set, present in ::text.
 *          Related to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC,
 *          but unlike those, returns the match address rather than a prefix length.
 *          May have identical implementation and performance to ::sz_find_charset.
 *
 *  Useful for parsing, when we want to skip a set of characters. Examples:
 *  * 6 whitespaces: " \t\n\r\v\f".
 *  * 16 digits forming a float number: "0123456789,.eE+-".
 *  * 5 HTML reserved characters: "\"'&<>", of which "<>" can be useful for parsing.
 *  * 2 JSON string special characters useful to locate the end of the string: "\"\\".
 *
 *  @param text     String to be scanned.
 *  @param length   Number of bytes in the string.
 *  @param set      Set of relevant characters.
 *  @return         Address of the last match.
 */
SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set);

/** @copydoc sz_rfind_charset */
SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set);

#pragma endregion

#pragma region String Similarity Measures API

/**
 *  @brief  Computes the Levenshtein edit-distance between two strings using the Wagner-Fisher algorithm.
 *          Similar to the Needleman-Wunsch alignment algorithm. Often used in fuzzy string matching.
 *
 *  @param a        First string to compare.
 *  @param a_length Number of bytes in the first string.
 *  @param b        Second string to compare.
 *  @param b_length Number of bytes in the second string.
 *
 *  @param alloc    Temporary memory allocator. Only some of the rows of the matrix will be allocated,
 *                  so the memory usage is linear in relation to ::a_length and ::b_length.
 *                  If SZ_NULL is passed, will initialize to the systems default `malloc`.
 *  @param bound    Upper bound on the distance, that allows us to exit early.
 *                  If zero is passed, the maximum possible distance will be equal to the length of the longer input.
 *  @return         Unsigned integer for edit distance, the `bound` if was exceeded or `SZ_SIZE_MAX`
 *                  if the memory allocation failed.
 *
 *  @see    sz_memory_allocator_init_fixed, sz_memory_allocator_init_default
 *  @see    https://en.wikipedia.org/wiki/Levenshtein_distance
 */
SZ_DYNAMIC sz_size_t sz_edit_distance(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
                                      sz_size_t bound, sz_memory_allocator_t *alloc);

/** @copydoc sz_edit_distance */
SZ_PUBLIC sz_size_t sz_edit_distance_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
                                            sz_size_t bound, sz_memory_allocator_t *alloc);

typedef sz_size_t (*sz_edit_distance_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_size_t, sz_memory_allocator_t *);

/**
 *  @brief  Computes Needleman–Wunsch alignment score for two strings. Often used in bioinformatics and cheminformatics.
 *          Similar to the Levenshtein edit-distance, parameterized for gap and substitution penalties.
 *
 *  Not commutative in the general case, as the order of the strings matters, as `sz_alignment_score(a, b)` may
 *  not be equal to `sz_alignment_score(b, a)`. Becomes @b commutative, if the substitution costs are symmetric.
 *  Equivalent to the negative Levenshtein distance, if: `gap == -1` and `subs[i][j] == (i == j ? 0: -1)`.
 *
 *  @param a        First string to compare.
 *  @param a_length Number of bytes in the first string.
 *  @param b        Second string to compare.
 *  @param b_length Number of bytes in the second string.
 *  @param gap      Penalty cost for gaps - insertions and removals.
 *  @param subs     Substitution costs matrix with 256 x 256 values for all pairs of characters.
 *
 *  @param alloc    Temporary memory allocator. Only some of the rows of the matrix will be allocated,
 *                  so the memory usage is linear in relation to ::a_length and ::b_length.
 *                  If SZ_NULL is passed, will initialize to the systems default `malloc`.
 *  @return         Signed similarity score. Can be negative, depending on the substitution costs.
 *                  If the memory allocation fails, the function returns `SZ_SSIZE_MAX`.
 *
 *  @see    sz_memory_allocator_init_fixed, sz_memory_allocator_init_default
 *  @see    https://en.wikipedia.org/wiki/Needleman%E2%80%93Wunsch_algorithm
 */
SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
                                         sz_error_cost_t const *subs, sz_error_cost_t gap,                 //
                                         sz_memory_allocator_t *alloc);

/** @copydoc sz_alignment_score */
SZ_PUBLIC sz_ssize_t sz_alignment_score_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, //
                                               sz_error_cost_t const *subs, sz_error_cost_t gap,                 //
                                               sz_memory_allocator_t *alloc);

typedef sz_ssize_t (*sz_alignment_score_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t, sz_error_cost_t const *,
                                           sz_error_cost_t, sz_memory_allocator_t *);

typedef void (*sz_hash_callback_t)(sz_cptr_t, sz_size_t, sz_u64_t, void *user);

/**
 *  @brief  Computes the Karp-Rabin rolling hashes of a string supplying them to the provided `callback`.
 *          Can be used for similarity scores, search, ranking, etc.
 *
 *  Rabin-Karp-like rolling hashes can have very high-level of collisions and depend
 *  on the choice of bases and the prime number. That's why, often two hashes from the same
 *  family are used with different bases.
 *
 *  1. Kernighan and Ritchie's function uses 31, a prime close to the size of English alphabet.
 *  2. To be friendlier to byte-arrays and UTF8, we use 257 for the second function.
 *
 *  Choosing the right ::window_length is task- and domain-dependant. For example, most English words are
 *  between 3 and 7 characters long, so a window of 4 bytes would be a good choice. For DNA sequences,
 *  the ::window_length might be a multiple of 3, as the codons are 3 (nucleotides) bytes long.
 *  With such minimalistic alphabets of just four characters (AGCT) longer windows might be needed.
 *  For protein sequences the alphabet is 20 characters long, so the window can be shorter, than for DNAs.
 *
 *  @param text             String to hash.
 *  @param length           Number of bytes in the string.
 *  @param window_length    Length of the rolling window in bytes.
 *  @param window_step      Step of reported hashes. @b Must be power of two. Should be smaller than `window_length`.
 *  @param callback         Function receiving the start & length of a substring, the hash, and the `callback_handle`.
 *  @param callback_handle  Optional user-provided pointer to be passed to the `callback`.
 *  @see                    sz_hashes_fingerprint, sz_hashes_intersection
 */
SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, //
                          sz_hash_callback_t callback, void *callback_handle);

/** @copydoc sz_hashes */
SZ_PUBLIC void sz_hashes_serial(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, //
                                sz_hash_callback_t callback, void *callback_handle);

typedef void (*sz_hashes_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_size_t, sz_hash_callback_t, void *);

/**
 *  @brief  Computes the Karp-Rabin rolling hashes of a string outputting a binary fingerprint.
 *          Such fingerprints can be compared with Hamming or Jaccard (Tanimoto) distance for similarity.
 *
 *  The algorithm doesn't clear the fingerprint buffer on start, so it can be invoked multiple times
 *  to produce a fingerprint of a longer string, by passing the previous fingerprint as the ::fingerprint.
 *  It can also be reused to produce multi-resolution fingerprints by changing the ::window_length
 *  and calling the same function multiple times for the same input ::text.
 *
 *  Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer,
 *  avoiding cache-coherency penalties of remote on-heap buffers.
 *
 *  @param text                 String to hash.
 *  @param length               Number of bytes in the string.
 *  @param fingerprint          Output fingerprint buffer.
 *  @param fingerprint_bytes    Number of bytes in the fingerprint buffer.
 *  @param window_length        Length of the rolling window in bytes.
 *  @see                        sz_hashes, sz_hashes_intersection
 */
SZ_PUBLIC void sz_hashes_fingerprint(sz_cptr_t text, sz_size_t length, sz_size_t window_length, //
                                     sz_ptr_t fingerprint, sz_size_t fingerprint_bytes);

typedef void (*sz_hashes_fingerprint_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_ptr_t, sz_size_t);

/**
 *  @brief  Given a hash-fingerprint of a textual document, computes the number of intersecting hashes
 *          of the incoming document. Can be used for document scoring and search.
 *
 *  Processes large strings in parts to maximize the cache utilization, using a small on-stack buffer,
 *  avoiding cache-coherency penalties of remote on-heap buffers.
 *
 *  @param text                 Input document.
 *  @param length               Number of bytes in the input document.
 *  @param fingerprint          Reference document fingerprint.
 *  @param fingerprint_bytes    Number of bytes in the reference documents fingerprint.
 *  @param window_length        Length of the rolling window in bytes.
 *  @see                        sz_hashes, sz_hashes_fingerprint
 */
SZ_PUBLIC sz_size_t sz_hashes_intersection(sz_cptr_t text, sz_size_t length, sz_size_t window_length, //
                                           sz_cptr_t fingerprint, sz_size_t fingerprint_bytes);

typedef sz_size_t (*sz_hashes_intersection_t)(sz_cptr_t, sz_size_t, sz_size_t, sz_cptr_t, sz_size_t);

#pragma endregion

#pragma region Convenience API

/**
 *  @brief  Finds the first character in the haystack, that is present in the needle.
 *          Convenience function, reused across different language bindings.
 *  @see    sz_find_charset
 */
SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length);

/**
 *  @brief  Finds the first character in the haystack, that is @b not present in the needle.
 *          Convenience function, reused across different language bindings.
 *  @see    sz_find_charset
 */
SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length);

/**
 *  @brief  Finds the last character in the haystack, that is present in the needle.
 *          Convenience function, reused across different language bindings.
 *  @see    sz_find_charset
 */
SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length);

/**
 *  @brief  Finds the last character in the haystack, that is @b not present in the needle.
 *          Convenience function, reused across different language bindings.
+ * @see sz_find_charset + */ +SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length); + +#pragma endregion + +#pragma region String Sequences API + +struct sz_sequence_t; + +typedef sz_cptr_t (*sz_sequence_member_start_t)(struct sz_sequence_t const *, sz_size_t); +typedef sz_size_t (*sz_sequence_member_length_t)(struct sz_sequence_t const *, sz_size_t); +typedef sz_bool_t (*sz_sequence_predicate_t)(struct sz_sequence_t const *, sz_size_t); +typedef sz_bool_t (*sz_sequence_comparator_t)(struct sz_sequence_t const *, sz_size_t, sz_size_t); +typedef sz_bool_t (*sz_string_is_less_t)(sz_cptr_t, sz_size_t, sz_cptr_t, sz_size_t); + +typedef struct sz_sequence_t { + sz_u64_t *order; + sz_size_t count; + sz_sequence_member_start_t get_start; + sz_sequence_member_length_t get_length; + void const *handle; +} sz_sequence_t; + +/** + * @brief Initiates the sequence structure from a tape layout, used by Apache Arrow. + * Expects ::offsets to contains `count + 1` entries, the last pointing at the end + * of the last string, indicating the total length of the ::tape. + */ +SZ_PUBLIC void sz_sequence_from_u32tape(sz_cptr_t *start, sz_u32_t const *offsets, sz_size_t count, + sz_sequence_t *sequence); + +/** + * @brief Initiates the sequence structure from a tape layout, used by Apache Arrow. + * Expects ::offsets to contains `count + 1` entries, the last pointing at the end + * of the last string, indicating the total length of the ::tape. + */ +SZ_PUBLIC void sz_sequence_from_u64tape(sz_cptr_t *start, sz_u64_t const *offsets, sz_size_t count, + sz_sequence_t *sequence); + +/** + * @brief Similar to `std::partition`, given a predicate splits the sequence into two parts. + * The algorithm is unstable, meaning that elements may change relative order, as long + * as they are in the right partition. This is the simpler algorithm for partitioning. 
+ */ +SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate); + +/** + * @brief Inplace `std::set_union` for two consecutive chunks forming the same continuous `sequence`. + * + * @param partition The number of elements in the first sub-sequence in `sequence`. + * @param less Comparison function, to determine the lexicographic ordering. + */ +SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less); + +/** + * @brief Sorting algorithm, combining Radix Sort for the first 32 bits of every word + * and a follow-up by a more conventional sorting procedure on equally prefixed parts. + */ +SZ_PUBLIC void sz_sort(sz_sequence_t *sequence); + +/** + * @brief Partial sorting algorithm, combining Radix Sort for the first 32 bits of every word + * and a follow-up by a more conventional sorting procedure on equally prefixed parts. + */ +SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t n); + +/** + * @brief Intro-Sort algorithm that supports custom comparators. + */ +SZ_PUBLIC void sz_sort_intro(sz_sequence_t *sequence, sz_sequence_comparator_t less); + +#pragma endregion + +/* + * Hardware feature detection. + * All of those can be controlled by the user. + */ +#ifndef SZ_USE_X86_AVX512 +#ifdef __AVX512BW__ +#define SZ_USE_X86_AVX512 1 +#else +#define SZ_USE_X86_AVX512 0 +#endif +#endif + +#ifndef SZ_USE_X86_AVX2 +#ifdef __AVX2__ +#define SZ_USE_X86_AVX2 1 +#else +#define SZ_USE_X86_AVX2 0 +#endif +#endif + +#ifndef SZ_USE_ARM_NEON +#ifdef __ARM_NEON +#define SZ_USE_ARM_NEON 1 +#else +#define SZ_USE_ARM_NEON 0 +#endif +#endif + +#ifndef SZ_USE_ARM_SVE +#ifdef __ARM_FEATURE_SVE +#define SZ_USE_ARM_SVE 1 +#else +#define SZ_USE_ARM_SVE 0 +#endif +#endif + +/* + * Include hardware-specific headers. + */ +#if SZ_USE_X86_AVX512 || SZ_USE_X86_AVX2 +#include +#endif // SZ_USE_X86... 
+#if SZ_USE_ARM_NEON +#include +#include +#endif // SZ_USE_ARM_NEON +#if SZ_USE_ARM_SVE +#include +#endif // SZ_USE_ARM_SVE + +#pragma region Hardware-Specific API + +#if SZ_USE_X86_AVX512 + +/** @copydoc sz_equal_serial */ +SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length); +/** @copydoc sz_order_serial */ +SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length); +/** @copydoc sz_copy_serial */ +SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length); +/** @copydoc sz_move_serial */ +SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length); +/** @copydoc sz_fill_serial */ +SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value); +/** @copydoc sz_find_byte */ +SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +/** @copydoc sz_rfind_byte */ +SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +/** @copydoc sz_find */ +SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); +/** @copydoc sz_rfind */ +SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); +/** @copydoc sz_find_charset */ +SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); +/** @copydoc sz_rfind_charset */ +SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); +/** @copydoc sz_edit_distance */ +SZ_PUBLIC sz_size_t sz_edit_distance_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_memory_allocator_t *alloc); +/** @copydoc sz_alignment_score */ +SZ_PUBLIC sz_ssize_t sz_alignment_score_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // + sz_error_cost_t const *subs, 
sz_error_cost_t gap, // + sz_memory_allocator_t *alloc); +/** @copydoc sz_hashes */ +SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, // + sz_hash_callback_t callback, void *callback_handle); +#endif + +#if SZ_USE_X86_AVX2 +/** @copydoc sz_equal */ +SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length); +/** @copydoc sz_move */ +SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length); +/** @copydoc sz_fill */ +SZ_PUBLIC void sz_fill_avx2(sz_ptr_t target, sz_size_t length, sz_u8_t value); +/** @copydoc sz_find_byte */ +SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +/** @copydoc sz_rfind_byte */ +SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +/** @copydoc sz_find */ +SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); +/** @copydoc sz_rfind */ +SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); +/** @copydoc sz_hashes */ +SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, // + sz_hash_callback_t callback, void *callback_handle); +#endif + +#if SZ_USE_ARM_NEON +/** @copydoc sz_equal */ +SZ_PUBLIC sz_bool_t sz_equal_neon(sz_cptr_t a, sz_cptr_t b, sz_size_t length); +/** @copydoc sz_find_byte */ +SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +/** @copydoc sz_rfind_byte */ +SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle); +/** @copydoc sz_find */ +SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length); +/** @copydoc sz_rfind */ +SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t 
n_length); +/** @copydoc sz_find_charset */ +SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); +/** @copydoc sz_rfind_charset */ +SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t text, sz_size_t length, sz_charset_t const *set); +#endif + +#pragma endregion + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" + +/* + ********************************************************************************************************************** + ********************************************************************************************************************** + ********************************************************************************************************************** + * + * This is where we the actual implementation begins. + * The rest of the file is hidden from the public API. + * + ********************************************************************************************************************** + ********************************************************************************************************************** + ********************************************************************************************************************** + */ + +#pragma region Compiler Extensions and Helper Functions + +#pragma GCC visibility push(hidden) + +/** + * @brief Helper-macro to mark potentially unused variables. + */ +#define sz_unused(x) ((void)(x)) + +/** + * @brief Helper-macro casting a variable to another type of the same size. + */ +#define sz_bitcast(type, value) (*((type *)&(value))) + +/** + * @brief Defines `SZ_NULL`, analogous to `NULL`. + * The default often comes from locale.h, stddef.h, + * stdio.h, stdlib.h, string.h, time.h, or wchar.h. + */ +#ifdef __GNUG__ +#define SZ_NULL __null +#else +#define SZ_NULL ((void *)0) +#endif + +/** + * @brief Cache-line width, that will affect the execution of some algorithms, + * like equality checks and relative order computing. 
+ */ +#define SZ_CACHE_LINE_WIDTH (64) // bytes + +/** + * @brief Similar to `assert`, the `sz_assert` is used in the SZ_DEBUG mode + * to check the invariants of the library. It's a no-op in the SZ_RELEASE mode. + * @note If you want to catch it, put a breakpoint at @b `__GI_exit` + */ +#if SZ_DEBUG +#include // `fprintf` +#include // `EXIT_FAILURE` +#define sz_assert(condition) \ + do { \ + if (!(condition)) { \ + fprintf(stderr, "Assertion failed: %s, in file %s, line %d\n", #condition, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) +#else +#define sz_assert(condition) ((void)0) +#endif + +/* + * Intrinsics aliases for MSVC, GCC, and Clang. + */ +#if defined(_MSC_VER) +#include +SZ_INTERNAL sz_size_t sz_u64_popcount(sz_u64_t x) { return __popcnt64(x); } +SZ_INTERNAL sz_size_t sz_u64_ctz(sz_u64_t x) { return _tzcnt_u64(x); } +SZ_INTERNAL sz_size_t sz_u64_clz(sz_u64_t x) { return _lzcnt_u64(x); } +SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return _byteswap_uint64(val); } +SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return __popcnt(x); } +SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return _tzcnt_u32(x); } +SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return _lzcnt_u32(x); } +SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return _byteswap_ulong(val); } +#else +SZ_INTERNAL int sz_u64_popcount(sz_u64_t x) { return __builtin_popcountll(x); } +SZ_INTERNAL int sz_u64_ctz(sz_u64_t x) { return __builtin_ctzll(x); } +SZ_INTERNAL int sz_u64_clz(sz_u64_t x) { return __builtin_clzll(x); } +SZ_INTERNAL sz_u64_t sz_u64_bytes_reverse(sz_u64_t val) { return __builtin_bswap64(val); } +SZ_INTERNAL int sz_u32_popcount(sz_u32_t x) { return __builtin_popcount(x); } +SZ_INTERNAL int sz_u32_ctz(sz_u32_t x) { return __builtin_ctz(x); } // ! Undefined if `x == 0` +SZ_INTERNAL int sz_u32_clz(sz_u32_t x) { return __builtin_clz(x); } // ! 
Undefined if `x == 0` +SZ_INTERNAL sz_u32_t sz_u32_bytes_reverse(sz_u32_t val) { return __builtin_bswap32(val); } +#endif + +SZ_INTERNAL sz_u64_t sz_u64_rotl(sz_u64_t x, sz_u64_t r) { return (x << r) | (x >> (64 - r)); } + +/** + * @brief Select bits from either ::a or ::b depending on the value of ::mask bits. + * + * Similar to `_mm_blend_epi16` intrinsic on x86. + * Described in the "Bit Twiddling Hacks" by Sean Eron Anderson. + * https://graphics.stanford.edu/~seander/bithacks.html#ConditionalSetOrClearBitsWithoutBranching + */ +SZ_INTERNAL sz_u64_t sz_u64_blend(sz_u64_t a, sz_u64_t b, sz_u64_t mask) { return a ^ ((a ^ b) & mask); } + +/* + * Efficiently computing the minimum and maximum of two or three values can be tricky. + * The simple branching baseline would be: + * + * x < y ? x : y // can replace with 1 conditional move + * + * Branchless approach is well known for signed integers, but it doesn't apply to unsigned ones. + * https://stackoverflow.com/questions/514435/templatized-branchless-int-max-min-function + * https://graphics.stanford.edu/~seander/bithacks.html#IntegerMinOrMax + * Using only bit-shifts for singed integers it would be: + * + * y + ((x - y) & (x - y) >> 31) // 4 unique operations + * + * Alternatively, for any integers using multiplication: + * + * (x > y) * y + (x <= y) * x // 5 operations + * + * Alternatively, to avoid multiplication: + * + * x & ~((x < y) - 1) + y & ((x < y) - 1) // 6 unique operations + */ +#define sz_min_of_two(x, y) (x < y ? x : y) +#define sz_max_of_two(x, y) (x < y ? y : x) +#define sz_min_of_three(x, y, z) sz_min_of_two(x, sz_min_of_two(y, z)) +#define sz_max_of_three(x, y, z) sz_max_of_two(x, sz_max_of_two(y, z)) + +/** @brief Branchless minimum function for two signed 32-bit integers. */ +SZ_INTERNAL sz_i32_t sz_i32_min_of_two(sz_i32_t x, sz_i32_t y) { return y + ((x - y) & (x - y) >> 31); } + +/** @brief Branchless minimum function for two signed 32-bit integers. 
 */
// Relies on arithmetic right-shift of a negative value; `x - y` may overflow for extreme arguments.
SZ_INTERNAL sz_i32_t sz_i32_max_of_two(sz_i32_t x, sz_i32_t y) { return x - ((x - y) & (x - y) >> 31); }

/**
 * @brief Clamps signed offsets in a string to a valid range. Used for Pythonic-style slicing.
 *        Negative `start`/`end` count from the end of the string, like in Python.
 *        On return, `[*normalized_offset, *normalized_offset + *normalized_length)` is
 *        guaranteed to lie within `[0, length]` with `start <= end`.
 */
SZ_INTERNAL void sz_ssize_clamp_interval(sz_size_t length, sz_ssize_t start, sz_ssize_t end,
                                         sz_size_t *normalized_offset, sz_size_t *normalized_length) {
    // TODO: Remove branches.
    // Normalize negative indices
    if (start < 0) start += length;
    if (end < 0) end += length;

    // Clamp indices to a valid range
    if (start < 0) start = 0;
    if (end < 0) end = 0;
    if (start > (sz_ssize_t)length) start = length;
    if (end > (sz_ssize_t)length) end = length;

    // Ensure start <= end
    if (start > end) start = end;

    *normalized_offset = start;
    *normalized_length = end - start;
}

/**
 * @brief Compute the logarithm base 2 of a positive integer, rounding down.
 *        NOTE(review): `63 - clz` assumes `sz_size_t` is 64 bits wide — confirm for 32-bit targets.
 */
SZ_INTERNAL sz_size_t sz_size_log2i_nonzero(sz_size_t x) {
    sz_assert(x > 0 && "Non-positive numbers have no defined logarithm");
    sz_size_t leading_zeros = sz_u64_clz(x);
    return 63 - leading_zeros;
}

/**
 * @brief Compute the smallest power of two greater than or equal to ::x.
 *        NOTE(review): the `x >> 32` step assumes a 64-bit `sz_size_t`; on a 32-bit
 *        `sz_size_t` that shift would be undefined — confirm the platform matrix.
 */
SZ_INTERNAL sz_size_t sz_size_bit_ceil(sz_size_t x) {
    // Unlike the commonly used trick with `clz` intrinsics, is valid across the whole range of `x`.
    // https://stackoverflow.com/a/10143264
    x--;
    // Smear the highest set bit into every lower position, then increment.
    x |= x >> 1;
    x |= x >> 2;
    x |= x >> 4;
    x |= x >> 8;
    x |= x >> 16;
    x |= x >> 32;
    x++;
    return x;
}

/**
 * @brief Transposes an 8x8 bit matrix packed in a `sz_u64_t`.
 *
 * There is a well known SWAR sequence for that known to chess programmers,
 * willing to flip a bit-matrix of pieces along the main A1-H8 diagonal.
 * https://www.chessprogramming.org/Flipping_Mirroring_and_Rotating
 * https://lukas-prokop.at/articles/2021-07-23-transpose
 */
SZ_INTERNAL sz_u64_t sz_u64_transpose(sz_u64_t x) {
    sz_u64_t t;
    // A sequence of masked delta-swaps; see the links above for the derivation.
    t = x ^ (x << 36);
    x ^= 0xf0f0f0f00f0f0f0full & (t ^ (x >> 36));
    t = 0xcccc0000cccc0000ull & (x ^ (x << 18));
    x ^= t ^ (t >> 18);
    t = 0xaa00aa00aa00aa00ull & (x ^ (x << 9));
    x ^= t ^ (t >> 9);
    return x;
}

/**
 * @brief Helper, that swaps two 64-bit integers representing the order of elements in the sequence.
 */
SZ_INTERNAL void sz_u64_swap(sz_u64_t *a, sz_u64_t *b) {
    sz_u64_t t = *a;
    *a = *b;
    *b = t;
}

/**
 * @brief Helper structure to simplify work with 16-bit words.
 * @see sz_u16_load
 */
typedef union sz_u16_vec_t {
    sz_u16_t u16;
    sz_u8_t u8s[2];
} sz_u16_vec_t;

/**
 * @brief Load a 16-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms.
 */
SZ_INTERNAL sz_u16_vec_t sz_u16_load(sz_cptr_t ptr) {
#if !SZ_USE_MISALIGNED_LOADS
    // Byte-by-byte copy is safe regardless of the pointer's alignment.
    sz_u16_vec_t result;
    result.u8s[0] = ptr[0];
    result.u8s[1] = ptr[1];
    return result;
#elif defined(_MSC_VER)
    return *((__unaligned sz_u16_vec_t *)ptr);
#else
    __attribute__((aligned(1))) sz_u16_vec_t const *result = (sz_u16_vec_t const *)ptr;
    return *result;
#endif
}

/**
 * @brief Helper structure to simplify work with 32-bit words.
 * @see sz_u32_load
 */
typedef union sz_u32_vec_t {
    sz_u32_t u32;
    sz_u16_t u16s[2];
    sz_u8_t u8s[4];
} sz_u32_vec_t;

/**
 * @brief Load a 32-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms.
 */
SZ_INTERNAL sz_u32_vec_t sz_u32_load(sz_cptr_t ptr) {
#if !SZ_USE_MISALIGNED_LOADS
    // Byte-by-byte copy is safe regardless of the pointer's alignment.
    sz_u32_vec_t result;
    result.u8s[0] = ptr[0];
    result.u8s[1] = ptr[1];
    result.u8s[2] = ptr[2];
    result.u8s[3] = ptr[3];
    return result;
#elif defined(_MSC_VER)
    return *((__unaligned sz_u32_vec_t *)ptr);
#else
    __attribute__((aligned(1))) sz_u32_vec_t const *result = (sz_u32_vec_t const *)ptr;
    return *result;
#endif
}

/**
 * @brief Helper structure to simplify work with 64-bit words.
 * @see sz_u64_load
 */
typedef union sz_u64_vec_t {
    sz_u64_t u64;
    sz_u32_t u32s[2];
    sz_u16_t u16s[4];
    sz_u8_t u8s[8];
} sz_u64_vec_t;

/**
 * @brief Load a 64-bit unsigned integer from a potentially unaligned pointer, can be expensive on some platforms.
 */
SZ_INTERNAL sz_u64_vec_t sz_u64_load(sz_cptr_t ptr) {
#if !SZ_USE_MISALIGNED_LOADS
    // Byte-by-byte copy is safe regardless of the pointer's alignment.
    sz_u64_vec_t result;
    result.u8s[0] = ptr[0];
    result.u8s[1] = ptr[1];
    result.u8s[2] = ptr[2];
    result.u8s[3] = ptr[3];
    result.u8s[4] = ptr[4];
    result.u8s[5] = ptr[5];
    result.u8s[6] = ptr[6];
    result.u8s[7] = ptr[7];
    return result;
#elif defined(_MSC_VER)
    return *((__unaligned sz_u64_vec_t *)ptr);
#else
    __attribute__((aligned(1))) sz_u64_vec_t const *result = (sz_u64_vec_t const *)ptr;
    return *result;
#endif
}

/** @brief Helper function, using the supplied fixed-capacity buffer to allocate memory.
 *         The buffer's total capacity is stored in its first `sizeof(sz_size_t)` bytes.
 *         NOTE(review): no allocation cursor is kept, so every call returns the same region —
 *         presumably only one live allocation is expected at a time; confirm with callers. */
SZ_INTERNAL sz_ptr_t _sz_memory_allocate_fixed(sz_size_t length, void *handle) {
    sz_size_t capacity;
    sz_copy((sz_ptr_t)&capacity, (sz_cptr_t)handle, sizeof(sz_size_t));
    sz_size_t consumed_capacity = sizeof(sz_size_t);
    if (consumed_capacity + length > capacity) return SZ_NULL;
    return (sz_ptr_t)handle + consumed_capacity;
}

/** @brief Helper "no-op" function, simulating memory deallocation when we use a "static" memory buffer.
 */
SZ_INTERNAL void _sz_memory_free_fixed(sz_ptr_t start, sz_size_t length, void *handle) {
    sz_unused(start && length && handle);
}

/** @brief An internal callback used to set a bit in a power-of-two length binary fingerprint of a string. */
SZ_INTERNAL void _sz_hashes_fingerprint_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash, void *handle) {
    sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle;
    sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start;
    sz_size_t fingerprint_bytes = fingerprint_buffer->length;
    // The power-of-two length allows masking instead of a modulo division.
    fingerprint_u8s[(hash / 8) & (fingerprint_bytes - 1)] |= (1 << (hash & 7));
    sz_unused(start && length);
}

/** @brief An internal callback used to set a bit in a @b non power-of-two length binary fingerprint of a string. */
SZ_INTERNAL void _sz_hashes_fingerprint_non_pow2_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash,
                                                          void *handle) {
    sz_string_view_t *fingerprint_buffer = (sz_string_view_t *)handle;
    sz_u8_t *fingerprint_u8s = (sz_u8_t *)fingerprint_buffer->start;
    sz_size_t fingerprint_bytes = fingerprint_buffer->length;
    fingerprint_u8s[(hash / 8) % fingerprint_bytes] |= (1 << (hash & 7));
    sz_unused(start && length);
}

/** @brief An internal callback, used to mix all the running hashes into one pointer-size value. */
SZ_INTERNAL void _sz_hashes_fingerprint_scalar_callback(sz_cptr_t start, sz_size_t length, sz_u64_t hash,
                                                        void *scalar_handle) {
    sz_unused(start && length && hash && scalar_handle);
    sz_size_t *scalar_ptr = (sz_size_t *)scalar_handle;
    *scalar_ptr ^= hash;
}

/**
 * @brief Chooses the offsets of the most interesting characters in a search needle.
 *
 * Search throughput can significantly deteriorate if we are matching the wrong characters.
 * Say the needle is "aXaYa", and we are comparing the first, second, and last character.
 * If we use SIMD and compare many offsets at a time, comparing against "a" in every register is a waste.
 *
 * Similarly, dealing with UTF8 inputs, we know that the lower bits of each character code carry more information.
 * Cyrillic alphabet, for example, falls into [0x0410, 0x042F] code range for uppercase [А, Я], and
 * into [0x0430, 0x044F] for lowercase [а, я]. Scanning through a text written in Russian, half of the
 * bytes will carry absolutely no value and will be equal to 0x04.
 */
SZ_INTERNAL void _sz_locate_needle_anomalies(sz_cptr_t start, sz_size_t length, //
                                             sz_size_t *first, sz_size_t *second, sz_size_t *third) {
    // Default to the first, middle, and last characters.
    *first = 0;
    *second = length / 2;
    *third = length - 1;

    // Check whether any of the three default picks collide.
    int has_duplicates =                   //
        start[*first] == start[*second] || //
        start[*first] == start[*third] ||  //
        start[*second] == start[*third];

    // Loop through letters to find non-colliding variants.
    if (length > 3 && has_duplicates) {
        // Pivot the middle point left, until we find a character different from the first one.
        for (; start[*second] == start[*first] && *second; --(*second)) {}
        // Pivot the middle point right, until we find a character different from the first one.
        for (; start[*second] == start[*first] && *second + 1 < *third; ++(*second)) {}
        // Pivot the third (last) point left, until we find a different character.
+ for (; (start[*third] == start[*second] || start[*third] == start[*first]) && *third > (*second + 1); + --(*third)) {} + } +} + +#pragma GCC visibility pop +#pragma endregion + +#pragma region Serial Implementation + +#if !SZ_AVOID_LIBC +#include // `fprintf` +#include // `malloc`, `EXIT_FAILURE` +#else +extern void *malloc(size_t); +extern void free(void *); +#endif + +SZ_PUBLIC void sz_memory_allocator_init_default(sz_memory_allocator_t *alloc) { + alloc->allocate = (sz_memory_allocate_t)malloc; + alloc->free = (sz_memory_free_t)free; + alloc->handle = SZ_NULL; +} + +SZ_PUBLIC void sz_memory_allocator_init_fixed(sz_memory_allocator_t *alloc, void *buffer, sz_size_t length) { + // The logic here is simple - put the buffer length in the first slots of the buffer. + // Later use it for bounds checking. + alloc->allocate = (sz_memory_allocate_t)_sz_memory_allocate_fixed; + alloc->free = (sz_memory_free_t)_sz_memory_free_fixed; + alloc->handle = &buffer; + sz_copy((sz_ptr_t)buffer, (sz_cptr_t)&length, sizeof(sz_size_t)); +} + +/** + * @brief Byte-level equality comparison between two strings. + * If unaligned loads are allowed, uses a switch-table to avoid loops on short strings. 
 */
SZ_PUBLIC sz_bool_t sz_equal_serial(sz_cptr_t a, sz_cptr_t b, sz_size_t length) {
    // Walk both strings until the first mismatch, or until `length` bytes agree.
    sz_cptr_t const a_end = a + length;
    while (a != a_end && *a == *b) a++, b++;
    return (sz_bool_t)(a_end == a);
}

// Returns the first byte of `text` that belongs to `set`, or SZ_NULL if none does.
SZ_PUBLIC sz_cptr_t sz_find_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) {
    for (sz_cptr_t const end = text + length; text != end; ++text)
        if (sz_charset_contains(set, *text)) return text;
    return SZ_NULL;
}

// Returns the last byte of `text` that belongs to `set`, or SZ_NULL if none does.
SZ_PUBLIC sz_cptr_t sz_rfind_charset_serial(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) {
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds"
    // Walk backwards, pre-decrementing before each dereference.
    sz_cptr_t const end = text;
    for (text += length; text != end;)
        if (sz_charset_contains(set, *(text -= 1))) return text;
    return SZ_NULL;
#pragma GCC diagnostic pop
}

// Lexicographic three-way comparison, like `memcmp`, but aware of different lengths.
SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) {
    // Index 0 -> `a` orders after `b`; index 1 -> `a` orders before `b`.
    sz_ordering_t ordering_lookup[2] = {sz_greater_k, sz_less_k};
    sz_bool_t a_shorter = (sz_bool_t)(a_length < b_length);
    sz_size_t min_length = a_shorter ? a_length : b_length;
    sz_cptr_t min_end = a + min_length;
#if SZ_USE_MISALIGNED_LOADS
    // Compare eight bytes at a time; the byte-swap makes the integer comparison
    // lexicographic. NOTE(review): presumably assumes a little-endian target — confirm.
    for (sz_u64_vec_t a_vec, b_vec; a + 8 <= min_end; a += 8, b += 8) {
        a_vec.u64 = sz_u64_bytes_reverse(sz_u64_load(a).u64);
        b_vec.u64 = sz_u64_bytes_reverse(sz_u64_load(b).u64);
        if (a_vec.u64 != b_vec.u64) return ordering_lookup[a_vec.u64 < b_vec.u64];
    }
#endif
    for (; a != min_end; ++a, ++b)
        if (*a != *b) return ordering_lookup[*a < *b];
    // Equal prefixes: the shorter string orders first.
    return a_length != b_length ? ordering_lookup[a_shorter] : sz_equal_k;
}

/**
 * @brief Byte-level equality comparison between two 64-bit integers.
 * @return 64-bit integer, where every top bit in each byte signifies a match.
 */
SZ_INTERNAL sz_u64_vec_t _sz_u64_each_byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) {
    sz_u64_vec_t vec;
    vec.u64 = ~(a.u64 ^ b.u64);
    // The match is valid, if every bit within each byte is set.
+ // For that take the bottom 7 bits of each byte, add one to them, + // and if this sets the top bit to one, then all the 7 bits are ones as well. + vec.u64 = ((vec.u64 & 0x7F7F7F7F7F7F7F7Full) + 0x0101010101010101ull) & ((vec.u64 & 0x8080808080808080ull)); + return vec; +} + +/** + * @brief Find the first occurrence of a @b single-character needle in an arbitrary length haystack. + * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. + * Identical to `memchr(haystack, needle[0], haystack_length)`. + */ +SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { + + if (!h_length) return SZ_NULL; + sz_cptr_t const h_end = h + h_length; + +#if !SZ_USE_MISALIGNED_LOADS + // Process the misaligned head, to void UB on unaligned 64-bit loads. + for (; ((sz_size_t)h & 7ull) && h < h_end; ++h) + if (*h == *n) return h; +#endif + + // Broadcast the n into every byte of a 64-bit integer to use SWAR + // techniques and process eight characters at a time. + sz_u64_vec_t h_vec, n_vec, match_vec; + match_vec.u64 = 0; + n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull; + for (; h + 8 <= h_end; h += 8) { + h_vec.u64 = *(sz_u64_t const *)h; + match_vec = _sz_u64_each_byte_equal(h_vec, n_vec); + if (match_vec.u64) return h + sz_u64_ctz(match_vec.u64) / 8; + } + + // Handle the misaligned tail. + for (; h < h_end; ++h) + if (*h == *n) return h; + return SZ_NULL; +} + +/** + * @brief Find the last occurrence of a @b single-character needle in an arbitrary length haystack. + * This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time. + * Identical to `memrchr(haystack, needle[0], haystack_length)`. + */ +sz_cptr_t sz_rfind_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { + + if (!h_length) return SZ_NULL; + sz_cptr_t const h_start = h; + + // Reposition the `h` pointer to the end, as we will be walking backwards. 
    h = h + h_length - 1;

#if !SZ_USE_MISALIGNED_LOADS
    // Process the misaligned head, to void UB on unaligned 64-bit loads.
    for (; ((sz_size_t)(h + 1) & 7ull) && h >= h_start; --h)
        if (*h == *n) return h;
#endif

    // Broadcast the n into every byte of a 64-bit integer to use SWAR
    // techniques and process eight characters at a time.
    sz_u64_vec_t h_vec, n_vec, match_vec;
    n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull;
    for (; h >= h_start + 7; h -= 8) {
        h_vec.u64 = *(sz_u64_t const *)(h - 7);
        match_vec = _sz_u64_each_byte_equal(h_vec, n_vec);
        // The highest matching byte is the latest occurrence.
        if (match_vec.u64) return h - sz_u64_clz(match_vec.u64) / 8;
    }

    // Handle the remaining head byte-by-byte.
    // NOTE(review): on the no-match path `--h` forms a pointer one before `h_start`,
    // which is formally undefined — confirm whether upstream accepts this.
    for (; h >= h_start; --h)
        if (*h == *n) return h;
    return SZ_NULL;
}

/**
 * @brief 2Byte-level equality comparison between two 64-bit integers.
 * @return 64-bit integer, where every top bit in each 2byte signifies a match.
 */
SZ_INTERNAL sz_u64_vec_t _sz_u64_each_2byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) {
    sz_u64_vec_t vec;
    vec.u64 = ~(a.u64 ^ b.u64);
    // The match is valid, if every bit within each 2byte is set.
    // For that take the bottom 15 bits of each 2byte, add one to them,
    // and if this sets the top bit to one, then all the 15 bits are ones as well.
    vec.u64 = ((vec.u64 & 0x7FFF7FFF7FFF7FFFull) + 0x0001000100010001ull) & ((vec.u64 & 0x8000800080008000ull));
    return vec;
}

/**
 * @brief Find the first occurrence of a @b two-character needle in an arbitrary length haystack.
 *        This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time.
 */
SZ_INTERNAL sz_cptr_t _sz_find_2byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {

    // This is an internal method, and the haystack is guaranteed to be at least 2 bytes long.
    sz_assert(h_length >= 2 && "The haystack is too short.");
    sz_cptr_t const h_end = h + h_length;

#if !SZ_USE_MISALIGNED_LOADS
    // Process the misaligned head, to void UB on unaligned 64-bit loads.
    for (; ((sz_size_t)h & 7ull) && h + 2 <= h_end; ++h)
        if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h;
#endif

    sz_u64_vec_t h_even_vec, h_odd_vec, n_vec, matches_even_vec, matches_odd_vec;
    n_vec.u64 = 0;
    n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1];
    n_vec.u64 *= 0x0001000100010001ull; // broadcast

    // This code simulates hyper-scalar execution, analyzing 8 offsets at a time.
    // The "even" word covers offsets 0,2,4,6; the "odd" word is shifted by one byte
    // (borrowing `h[8]`) to cover offsets 1,3,5,7.
    for (; h + 9 <= h_end; h += 8) {
        h_even_vec.u64 = *(sz_u64_t *)h;
        h_odd_vec.u64 = (h_even_vec.u64 >> 8) | ((sz_u64_t)h[8] << 56);
        matches_even_vec = _sz_u64_each_2byte_equal(h_even_vec, n_vec);
        matches_odd_vec = _sz_u64_each_2byte_equal(h_odd_vec, n_vec);

        matches_even_vec.u64 >>= 8;
        if (matches_even_vec.u64 + matches_odd_vec.u64) {
            sz_u64_t match_indicators = matches_even_vec.u64 | matches_odd_vec.u64;
            return h + sz_u64_ctz(match_indicators) / 8;
        }
    }

    for (; h + 2 <= h_end; ++h)
        if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h;
    return SZ_NULL;
}

/**
 * @brief 4Byte-level equality comparison between two 64-bit integers.
 * @return 64-bit integer, where every top bit in each 4byte signifies a match.
 */
SZ_INTERNAL sz_u64_vec_t _sz_u64_each_4byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) {
    sz_u64_vec_t vec;
    vec.u64 = ~(a.u64 ^ b.u64);
    // The match is valid, if every bit within each 4byte is set.
    // For that take the bottom 31 bits of each 4byte, add one to them,
    // and if this sets the top bit to one, then all the 31 bits are ones as well.
    vec.u64 = ((vec.u64 & 0x7FFFFFFF7FFFFFFFull) + 0x0000000100000001ull) & ((vec.u64 & 0x8000000080000000ull));
    return vec;
}

/**
 * @brief Find the first occurrence of a @b four-character needle in an arbitrary length haystack.
 *        This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time.
 */
SZ_INTERNAL sz_cptr_t _sz_find_4byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {

    // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long.
    sz_assert(h_length >= 4 && "The haystack is too short.");
    sz_cptr_t const h_end = h + h_length;

#if !SZ_USE_MISALIGNED_LOADS
    // Process the misaligned head, to void UB on unaligned 64-bit loads.
    for (; ((sz_size_t)h & 7ull) && h + 4 <= h_end; ++h)
        if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h;
#endif

    sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, n_vec, matches0_vec, matches1_vec, matches2_vec, matches3_vec;
    n_vec.u64 = 0;
    n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2], n_vec.u8s[3] = n[3];
    n_vec.u64 *= 0x0000000100000001ull; // broadcast

    // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using four 64-bit words.
    // We load the subsequent four-byte word as well, taking its first bytes. Think of it as a glorified prefetch :)
    sz_u64_t h_page_current, h_page_next;
    for (; h + sizeof(sz_u64_t) + sizeof(sz_u32_t) <= h_end; h += sizeof(sz_u64_t)) {
        h_page_current = *(sz_u64_t *)h;
        h_page_next = *(sz_u32_t *)(h + 8);
        h0_vec.u64 = (h_page_current);
        h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56);
        h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48);
        h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40);
        matches0_vec = _sz_u64_each_4byte_equal(h0_vec, n_vec);
        matches1_vec = _sz_u64_each_4byte_equal(h1_vec, n_vec);
        matches2_vec = _sz_u64_each_4byte_equal(h2_vec, n_vec);
        matches3_vec = _sz_u64_each_4byte_equal(h3_vec, n_vec);

        if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64) {
            // Re-align the per-word indicators before merging, so the lowest set
            // bit corresponds to the earliest matching offset.
            matches0_vec.u64 >>= 24;
            matches1_vec.u64 >>= 16;
            matches2_vec.u64 >>= 8;
            sz_u64_t match_indicators = matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64;
            return h + sz_u64_ctz(match_indicators) / 8;
        }
+ } + + for (; h + 4 <= h_end; ++h) + if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h; + return SZ_NULL; +} + +/** + * @brief 3Byte-level equality comparison between two 64-bit integers. + * @return 64-bit integer, where every top bit in each 3byte signifies a match. + */ +SZ_INTERNAL sz_u64_vec_t _sz_u64_each_3byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) { + sz_u64_vec_t vec; + vec.u64 = ~(a.u64 ^ b.u64); + // The match is valid, if every bit within each 4byte is set. + // For that take the bottom 31 bits of each 4byte, add one to them, + // and if this sets the top bit to one, then all the 31 bits are ones as well. + vec.u64 = ((vec.u64 & 0xFFFF7FFFFF7FFFFFull) + 0x0000000001000001ull) & ((vec.u64 & 0x0000800000800000ull)); + return vec; +} + +/** + * @brief Find the first occurrence of a @b three-character needle in an arbitrary length haystack. + * This implementation uses hardware-agnostic SWAR technique, to process 8 possible offsets at a time. + */ +SZ_INTERNAL sz_cptr_t _sz_find_3byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { + + // This is an internal method, and the haystack is guaranteed to be at least 4 bytes long. + sz_assert(h_length >= 3 && "The haystack is too short."); + sz_cptr_t const h_end = h + h_length; + +#if !SZ_USE_MISALIGNED_LOADS + // Process the misaligned head, to void UB on unaligned 64-bit loads. + for (; ((sz_size_t)h & 7ull) && h + 3 <= h_end; ++h) + if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h; +#endif + + // We fetch 12 + sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, h4_vec; + sz_u64_vec_t matches0_vec, matches1_vec, matches2_vec, matches3_vec, matches4_vec; + sz_u64_vec_t n_vec; + n_vec.u64 = 0; + n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2], n_vec.u8s[3] = n[3]; + n_vec.u64 *= 0x0000000001000001ull; // broadcast + + // This code simulates hyper-scalar execution, analyzing 8 offsets at a time using three 64-bit words. 
+ // We load the subsequent two-byte word as well. + sz_u64_t h_page_current, h_page_next; + for (; h + sizeof(sz_u64_t) + sizeof(sz_u16_t) <= h_end; h += sizeof(sz_u64_t)) { + h_page_current = *(sz_u64_t *)h; + h_page_next = *(sz_u16_t *)(h + 8); + h0_vec.u64 = (h_page_current); + h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56); + h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48); + h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40); + h4_vec.u64 = (h_page_current >> 32) | (h_page_next << 32); + matches0_vec = _sz_u64_each_3byte_equal(h0_vec, n_vec); + matches1_vec = _sz_u64_each_3byte_equal(h1_vec, n_vec); + matches2_vec = _sz_u64_each_3byte_equal(h2_vec, n_vec); + matches3_vec = _sz_u64_each_3byte_equal(h3_vec, n_vec); + matches4_vec = _sz_u64_each_3byte_equal(h4_vec, n_vec); + + if (matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64) { + matches0_vec.u64 >>= 16; + matches1_vec.u64 >>= 8; + matches3_vec.u64 <<= 8; + matches4_vec.u64 <<= 16; + sz_u64_t match_indicators = + matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64 | matches4_vec.u64; + return h + sz_u64_ctz(match_indicators) / 8; + } + } + + for (; h + 3 <= h_end; ++h) + if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) == 3) return h; + return SZ_NULL; +} + +/** + * @brief Boyer-Moore-Horspool algorithm for exact matching of patterns up to @b 256-bytes long. + * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. + */ +SZ_INTERNAL sz_cptr_t _sz_find_horspool_upto_256bytes_serial(sz_cptr_t h_chars, sz_size_t h_length, // + sz_cptr_t n_chars, sz_size_t n_length) { + sz_assert(n_length <= 256 && "The pattern is too long."); + // Several popular string matching algorithms are using a bad-character shift table. 
+ // Boyer Moore: https://www-igm.univ-mlv.fr/~lecroq/string/node14.html + // Quick Search: https://www-igm.univ-mlv.fr/~lecroq/string/node19.html + // Smith: https://www-igm.univ-mlv.fr/~lecroq/string/node21.html + union { + sz_u8_t jumps[256]; + sz_u64_vec_t vecs[64]; + } bad_shift_table; + + // Let's initialize the table using SWAR to the total length of the string. + sz_u8_t const *h = (sz_u8_t const *)h_chars; + sz_u8_t const *n = (sz_u8_t const *)n_chars; + { + sz_u64_vec_t n_length_vec; + n_length_vec.u64 = n_length; + n_length_vec.u64 *= 0x0101010101010101ull; // broadcast + for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; + for (sz_size_t i = 0; i + 1 < n_length; ++i) bad_shift_table.jumps[n[i]] = (sz_u8_t)(n_length - i - 1); + } + + // Another common heuristic is to match a few characters from different parts of a string. + // Raita suggests to use the first two, the last, and the middle character of the pattern. + sz_u32_vec_t h_vec, n_vec; + + // Pick the parts of the needle that are worth comparing. + sz_size_t offset_first, offset_mid, offset_last; + _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); + + // Broadcast those characters into an unsigned integer. + n_vec.u8s[0] = n[offset_first]; + n_vec.u8s[1] = n[offset_first + 1]; + n_vec.u8s[2] = n[offset_mid]; + n_vec.u8s[3] = n[offset_last]; + + // Scan through the whole haystack, skipping the last `n_length - 1` bytes. + for (sz_size_t i = 0; i <= h_length - n_length;) { + h_vec.u8s[0] = h[i + offset_first]; + h_vec.u8s[1] = h[i + offset_first + 1]; + h_vec.u8s[2] = h[i + offset_mid]; + h_vec.u8s[3] = h[i + offset_last]; + if (h_vec.u32 == n_vec.u32 && sz_equal((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; + i += bad_shift_table.jumps[h[i + n_length - 1]]; + } + return SZ_NULL; +} + +/** + * @brief Boyer-Moore-Horspool algorithm for @b reverse-order exact matching of patterns up to @b 256-bytes long. 
+ * Uses the Raita heuristic to match the first two, the last, and the middle character of the pattern. + */ +SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_upto_256bytes_serial(sz_cptr_t h_chars, sz_size_t h_length, // + sz_cptr_t n_chars, sz_size_t n_length) { + sz_assert(n_length <= 256 && "The pattern is too long."); + union { + sz_u8_t jumps[256]; + sz_u64_vec_t vecs[64]; + } bad_shift_table; + + // Let's initialize the table using SWAR to the total length of the string. + sz_u8_t const *h = (sz_u8_t const *)h_chars; + sz_u8_t const *n = (sz_u8_t const *)n_chars; + { + sz_u64_vec_t n_length_vec; + n_length_vec.u64 = n_length; + n_length_vec.u64 *= 0x0101010101010101ull; // broadcast + for (sz_size_t i = 0; i != 64; ++i) bad_shift_table.vecs[i].u64 = n_length_vec.u64; + for (sz_size_t i = 0; i + 1 < n_length; ++i) + bad_shift_table.jumps[n[n_length - i - 1]] = (sz_u8_t)(n_length - i - 1); + } + + // Another common heuristic is to match a few characters from different parts of a string. + // Raita suggests to use the first two, the last, and the middle character of the pattern. + sz_u32_vec_t h_vec, n_vec; + + // Pick the parts of the needle that are worth comparing. + sz_size_t offset_first, offset_mid, offset_last; + _sz_locate_needle_anomalies(n_chars, n_length, &offset_first, &offset_mid, &offset_last); + + // Broadcast those characters into an unsigned integer. + n_vec.u8s[0] = n[offset_first]; + n_vec.u8s[1] = n[offset_first + 1]; + n_vec.u8s[2] = n[offset_mid]; + n_vec.u8s[3] = n[offset_last]; + + // Scan through the whole haystack, skipping the first `n_length - 1` bytes. 
+ for (sz_size_t j = 0; j <= h_length - n_length;) { + sz_size_t i = h_length - n_length - j; + h_vec.u8s[0] = h[i + offset_first]; + h_vec.u8s[1] = h[i + offset_first + 1]; + h_vec.u8s[2] = h[i + offset_mid]; + h_vec.u8s[3] = h[i + offset_last]; + if (h_vec.u32 == n_vec.u32 && sz_equal((sz_cptr_t)h + i, n_chars, n_length)) return (sz_cptr_t)h + i; + j += bad_shift_table.jumps[h[i]]; + } + return SZ_NULL; +} + +/** + * @brief Exact substring search helper function, that finds the first occurrence of a prefix of the needle + * using a given search function, and then verifies the remaining part of the needle. + */ +SZ_INTERNAL sz_cptr_t _sz_find_with_prefix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, + sz_find_t find_prefix, sz_size_t prefix_length) { + + sz_size_t suffix_length = n_length - prefix_length; + while (1) { + sz_cptr_t found = find_prefix(h, h_length, n, prefix_length); + if (!found) return SZ_NULL; + + // Verify the remaining part of the needle + sz_size_t remaining = h_length - (found - h); + if (remaining < suffix_length) return SZ_NULL; + if (sz_equal(found + prefix_length, n + prefix_length, suffix_length)) return found; + + // Adjust the position. + h = found + 1; + h_length = remaining - 1; + } + + // Unreachable, but helps silence compiler warnings: + return SZ_NULL; +} + +/** + * @brief Exact reverse-order substring search helper function, that finds the last occurrence of a suffix of the + * needle using a given search function, and then verifies the remaining part of the needle. 
+ */ +SZ_INTERNAL sz_cptr_t _sz_rfind_with_suffix(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length, + sz_find_t find_suffix, sz_size_t suffix_length) { + + sz_size_t prefix_length = n_length - suffix_length; + while (1) { + sz_cptr_t found = find_suffix(h, h_length, n + prefix_length, suffix_length); + if (!found) return SZ_NULL; + + // Verify the remaining part of the needle + sz_size_t remaining = found - h; + if (remaining < prefix_length) return SZ_NULL; + if (sz_equal(found - prefix_length, n, prefix_length)) return found - prefix_length; + + // Adjust the position. + h_length = remaining - 1; + } + + // Unreachable, but helps silence compiler warnings: + return SZ_NULL; +} + +SZ_INTERNAL sz_cptr_t _sz_find_over_4bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + return _sz_find_with_prefix(h, h_length, n, n_length, (sz_find_t)_sz_find_4byte_serial, 4); +} + +SZ_INTERNAL sz_cptr_t _sz_find_horspool_over_256bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, + sz_size_t n_length) { + return _sz_find_with_prefix(h, h_length, n, n_length, _sz_find_horspool_upto_256bytes_serial, 256); +} + +SZ_INTERNAL sz_cptr_t _sz_rfind_horspool_over_256bytes_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, + sz_size_t n_length) { + return _sz_rfind_with_suffix(h, h_length, n, n_length, _sz_rfind_horspool_upto_256bytes_serial, 256); +} + +SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + + // This almost never fires, but it's better to be safe than sorry. + if (h_length < n_length || !n_length) return SZ_NULL; + + sz_find_t backends[] = { + // For very short strings brute-force SWAR makes sense. + (sz_find_t)sz_find_byte_serial, + (sz_find_t)_sz_find_2byte_serial, + (sz_find_t)_sz_find_3byte_serial, + (sz_find_t)_sz_find_4byte_serial, + // To avoid constructing the skip-table, let's use the prefixed approach. 
+ (sz_find_t)_sz_find_over_4bytes_serial, + // For longer needles - use skip tables. + (sz_find_t)_sz_find_horspool_upto_256bytes_serial, + (sz_find_t)_sz_find_horspool_over_256bytes_serial, + }; + + return backends[ + // For very short strings brute-force SWAR makes sense. + (n_length > 1) + (n_length > 2) + (n_length > 3) + + // To avoid constructing the skip-table, let's use the prefixed approach. + (n_length > 4) + + // For longer needles - use skip tables. + (n_length > 8) + (n_length > 256)](h, h_length, n, n_length); +} + +SZ_PUBLIC sz_cptr_t sz_rfind_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + + // This almost never fires, but it's better to be safe than sorry. + if (h_length < n_length || !n_length) return SZ_NULL; + + sz_find_t backends[] = { + // For very short strings brute-force SWAR makes sense. + (sz_find_t)sz_rfind_byte_serial, + // TODO: implement reverse-order SWAR for 2/3/4 byte variants. + // TODO: (sz_find_t)_sz_rfind_2byte_serial, + // TODO: (sz_find_t)_sz_rfind_3byte_serial, + // TODO: (sz_find_t)_sz_rfind_4byte_serial, + // To avoid constructing the skip-table, let's use the prefixed approach. + // (sz_find_t)_sz_rfind_over_4bytes_serial, + // For longer needles - use skip tables. + (sz_find_t)_sz_rfind_horspool_upto_256bytes_serial, + (sz_find_t)_sz_rfind_horspool_over_256bytes_serial, + }; + + return backends[ + // For very short strings brute-force SWAR makes sense. + 0 + + // To avoid constructing the skip-table, let's use the prefixed approach. + (n_length > 1) + + // For longer needles - use skip tables. + (n_length > 256)](h, h_length, n, n_length); +} + +SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_serial( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // + sz_size_t bound, sz_memory_allocator_t *alloc) { + + // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. 
+ sz_memory_allocator_t global_alloc; + if (!alloc) { + sz_memory_allocator_init_default(&global_alloc); + alloc = &global_alloc; + } + + // TODO: Generalize to remove the following asserts! + sz_assert(!bound && "For bounded search the method should only evaluate one band of the matrix."); + sz_assert(shorter_length == longer_length && "The method hasn't been generalized to different length inputs yet."); + sz_unused(longer_length && bound); + + // We are going to store 3 diagonals of the matrix. + // The length of the longest (main) diagonal would be `n = (shorter_length + 1)`. + sz_size_t n = shorter_length + 1; + sz_size_t buffer_length = sizeof(sz_size_t) * n * 3; + sz_size_t *distances = (sz_size_t *)alloc->allocate(buffer_length, alloc->handle); + if (!distances) return SZ_SIZE_MAX; + + sz_size_t *previous_distances = distances; + sz_size_t *current_distances = previous_distances + n; + sz_size_t *next_distances = previous_distances + n * 2; + + // Initialize the first two diagonals: + previous_distances[0] = 0; + current_distances[0] = current_distances[1] = 1; + + // Progress through the upper triangle of the Levenshtein matrix. + sz_size_t next_skew_diagonal_index = 2; + for (; next_skew_diagonal_index != n; ++next_skew_diagonal_index) { + sz_size_t const next_skew_diagonal_length = next_skew_diagonal_index + 1; + for (sz_size_t i = 0; i + 2 < next_skew_diagonal_length; ++i) { + sz_size_t cost_of_substitution = shorter[next_skew_diagonal_index - i - 2] != longer[i]; + sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; + sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; + next_distances[i + 1] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); + } + // Don't forget to populate the first row and the fiest column of the Levenshtein matrix. 
+ next_distances[0] = next_distances[next_skew_diagonal_length - 1] = next_skew_diagonal_index; + // Perform a circular rotarion of those buffers, to reuse the memory. + sz_size_t *temporary = previous_distances; + previous_distances = current_distances; + current_distances = next_distances; + next_distances = temporary; + } + + // By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a + // larger diagonal. From now onwards, we will be shrinking. Instead of adding value equal to the skewed diagonal + // index on either side, we will be cropping those values out. + sz_size_t total_diagonals = n + n - 1; + for (; next_skew_diagonal_index != total_diagonals; ++next_skew_diagonal_index) { + sz_size_t const next_skew_diagonal_length = total_diagonals - next_skew_diagonal_index; + for (sz_size_t i = 0; i != next_skew_diagonal_length; ++i) { + sz_size_t cost_of_substitution = + shorter[shorter_length - 1 - i] != longer[next_skew_diagonal_index - n + i]; + sz_size_t cost_if_substitution = previous_distances[i] + cost_of_substitution; + sz_size_t cost_if_deletion_or_insertion = sz_min_of_two(current_distances[i], current_distances[i + 1]) + 1; + next_distances[i] = sz_min_of_two(cost_if_deletion_or_insertion, cost_if_substitution); + } + // Perform a circular rotarion of those buffers, to reuse the memory, this time, with a shift, + // dropping the first element in the current array. + sz_size_t *temporary = previous_distances; + previous_distances = current_distances + 1; + current_distances = next_distances; + next_distances = temporary; + } + + // Cache scalar before `free` call. 
+ sz_size_t result = current_distances[0]; + alloc->free(distances, buffer_length, alloc->handle); + return result; +} + +SZ_INTERNAL sz_size_t _sz_edit_distance_wagner_fisher_serial( // + sz_cptr_t longer, sz_size_t longer_length, // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_size_t bound, sz_memory_allocator_t *alloc) { + + // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. + sz_memory_allocator_t global_alloc; + if (!alloc) { + sz_memory_allocator_init_default(&global_alloc); + alloc = &global_alloc; + } + + // If a buffering memory-allocator is provided, this operation is practically free, + // and cheaper than allocating even 512 bytes (for small distance matrices) on stack. + sz_size_t buffer_length = sizeof(sz_size_t) * ((shorter_length + 1) * 2); + sz_size_t *distances = (sz_size_t *)alloc->allocate(buffer_length, alloc->handle); + if (!distances) return SZ_SIZE_MAX; + + sz_size_t *previous_distances = distances; + sz_size_t *current_distances = previous_distances + shorter_length + 1; + + for (sz_size_t idx_shorter = 0; idx_shorter != (shorter_length + 1); ++idx_shorter) + previous_distances[idx_shorter] = idx_shorter; + + // Keeping track of the bound parameter introduces a very noticeable performance penalty. + // So if it's not provided, we can skip the check altogether. + if (!bound) { + for (sz_size_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { + current_distances[0] = idx_longer + 1; + for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) { + sz_size_t cost_deletion = previous_distances[idx_shorter + 1] + 1; + sz_size_t cost_insertion = current_distances[idx_shorter] + 1; + sz_size_t cost_substitution = + previous_distances[idx_shorter] + (longer[idx_longer] != shorter[idx_shorter]); + // ? It might be a good idea to enforce branchless execution here. + // ? The caveat being that the benchmarks on longer sequences backfire and more research is needed. 
+ current_distances[idx_shorter + 1] = sz_min_of_three(cost_deletion, cost_insertion, cost_substitution); + } + sz_u64_swap((sz_u64_t *)&previous_distances, (sz_u64_t *)¤t_distances); + } + // Cache scalar before `free` call. + sz_size_t result = previous_distances[shorter_length]; + alloc->free(distances, buffer_length, alloc->handle); + return result; + } + // + else { + for (sz_size_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { + current_distances[0] = idx_longer + 1; + + // Initialize min_distance with a value greater than bound + sz_size_t min_distance = bound - 1; + + for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) { + sz_size_t cost_deletion = previous_distances[idx_shorter + 1] + 1; + sz_size_t cost_insertion = current_distances[idx_shorter] + 1; + sz_size_t cost_substitution = + previous_distances[idx_shorter] + (longer[idx_longer] != shorter[idx_shorter]); + current_distances[idx_shorter + 1] = sz_min_of_three(cost_deletion, cost_insertion, cost_substitution); + + // Keep track of the minimum distance seen so far in this row + min_distance = sz_min_of_two(current_distances[idx_shorter + 1], min_distance); + } + + // If the minimum distance in this row exceeded the bound, return early + if (min_distance >= bound) { + alloc->free(distances, buffer_length, alloc->handle); + return bound; + } + + // Swap previous_distances and current_distances pointers + sz_u64_swap((sz_u64_t *)&previous_distances, (sz_u64_t *)¤t_distances); + } + // Cache scalar before `free` call. + sz_size_t result = previous_distances[shorter_length] < bound ? 
previous_distances[shorter_length] : bound; + alloc->free(distances, buffer_length, alloc->handle); + return result; + } +} + +SZ_PUBLIC sz_size_t sz_edit_distance_serial( // + sz_cptr_t longer, sz_size_t longer_length, // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_size_t bound, sz_memory_allocator_t *alloc) { + + // Let's make sure that we use the amount proportional to the + // number of elements in the shorter string, not the larger. + if (shorter_length > longer_length) { + sz_u64_swap((sz_u64_t *)&longer_length, (sz_u64_t *)&shorter_length); + sz_u64_swap((sz_u64_t *)&longer, (sz_u64_t *)&shorter); + } + + // Skip the matching prefixes and suffixes, they won't affect the distance. + for (sz_cptr_t a_end = longer + longer_length, b_end = shorter + shorter_length; + longer != a_end && shorter != b_end && *longer == *shorter; + ++longer, ++shorter, --longer_length, --shorter_length) + ; + for (; longer_length && shorter_length && longer[longer_length - 1] == shorter[shorter_length - 1]; + --longer_length, --shorter_length) + ; + + // Bounded computations may exit early. + if (bound) { + // If one of the strings is empty - the edit distance is equal to the length of the other one. + if (longer_length == 0) return shorter_length <= bound ? shorter_length : bound; + if (shorter_length == 0) return longer_length <= bound ? longer_length : bound; + // If the difference in length is beyond the `bound`, there is no need to check at all. + if (longer_length - shorter_length > bound) return bound; + } + + if (shorter_length == 0) return longer_length; // If no mismatches were found - the distance is zero. 
+ if (shorter_length == longer_length && !bound) + return _sz_edit_distance_skewed_diagonals_serial(longer, longer_length, shorter, shorter_length, bound, alloc); + return _sz_edit_distance_wagner_fisher_serial(longer, longer_length, shorter, shorter_length, bound, alloc); +} + +SZ_PUBLIC sz_ssize_t sz_alignment_score_serial( // + sz_cptr_t longer, sz_size_t longer_length, // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_error_cost_t const *subs, sz_error_cost_t gap, // + sz_memory_allocator_t *alloc) { + + // If one of the strings is empty - the edit distance is equal to the length of the other one + if (longer_length == 0) return (sz_ssize_t)shorter_length * gap; + if (shorter_length == 0) return (sz_ssize_t)longer_length * gap; + + // Let's make sure that we use the amount proportional to the + // number of elements in the shorter string, not the larger. + if (shorter_length > longer_length) { + sz_u64_swap((sz_u64_t *)&longer_length, (sz_u64_t *)&shorter_length); + sz_u64_swap((sz_u64_t *)&longer, (sz_u64_t *)&shorter); + } + + // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. 
+ sz_memory_allocator_t global_alloc; + if (!alloc) { + sz_memory_allocator_init_default(&global_alloc); + alloc = &global_alloc; + } + + sz_size_t n = shorter_length + 1; + sz_size_t buffer_length = sizeof(sz_ssize_t) * n * 2; + sz_ssize_t *distances = (sz_ssize_t *)alloc->allocate(buffer_length, alloc->handle); + sz_ssize_t *previous_distances = distances; + sz_ssize_t *current_distances = previous_distances + n; + + for (sz_size_t idx_shorter = 0; idx_shorter != n; ++idx_shorter) + previous_distances[idx_shorter] = (sz_ssize_t)idx_shorter * gap; + + sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter; + sz_u8_t const *longer_unsigned = (sz_u8_t const *)longer; + for (sz_size_t idx_longer = 0; idx_longer != longer_length; ++idx_longer) { + current_distances[0] = ((sz_ssize_t)idx_longer + 1) * gap; + + // Initialize min_distance with a value greater than bound + sz_error_cost_t const *a_subs = subs + longer_unsigned[idx_longer] * 256ul; + for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) { + sz_ssize_t cost_deletion = previous_distances[idx_shorter + 1] + gap; + sz_ssize_t cost_insertion = current_distances[idx_shorter] + gap; + sz_ssize_t cost_substitution = previous_distances[idx_shorter] + a_subs[shorter_unsigned[idx_shorter]]; + current_distances[idx_shorter + 1] = sz_max_of_three(cost_deletion, cost_insertion, cost_substitution); + } + + // Swap previous_distances and current_distances pointers + sz_u64_swap((sz_u64_t *)&previous_distances, (sz_u64_t *)¤t_distances); + } + + // Cache scalar before `free` call. + sz_ssize_t result = previous_distances[shorter_length]; + alloc->free(distances, buffer_length, alloc->handle); + return result; +} + +/** + * @brief Largest prime number that fits into 31 bits. + * @see https://mersenneforum.org/showthread.php?t=3471 + */ +#define SZ_U32_MAX_PRIME (2147483647u) + +/** + * @brief Largest prime number that fits into 64 bits. 
+ * @see https://mersenneforum.org/showthread.php?t=3471 + * + * 2^64 = 18,446,744,073,709,551,616 + * this = 18,446,744,073,709,551,557 + * diff = 59 + */ +#define SZ_U64_MAX_PRIME (18446744073709551557ull) + +/* + * One hardware-accelerated way of mixing hashes can be CRC, but it's only implemented for 32-bit values. + * Using a Boost-like mixer works very poorly in such case: + * + * hash_first ^ (hash_second + 0x517cc1b727220a95 + (hash_first << 6) + (hash_first >> 2)); + * + * Let's stick to the Fibonacci hash trick using the golden ratio. + * https://probablydance.com/2018/06/16/fibonacci-hashing-the-optimization-that-the-world-forgot-or-a-better-alternative-to-integer-modulo/ + */ +#define _sz_hash_mix(first, second) ((first * 11400714819323198485ull) ^ (second * 11400714819323198485ull)) +#define _sz_shift_low(x) (x) +#define _sz_shift_high(x) ((x + 77ull) & 0xFFull) +#define _sz_prime_mod(x) (x % SZ_U64_MAX_PRIME) + +SZ_PUBLIC sz_u64_t sz_hash_serial(sz_cptr_t start, sz_size_t length) { + + sz_u64_t hash_low = 0; + sz_u64_t hash_high = 0; + sz_u8_t const *text = (sz_u8_t const *)start; + sz_u8_t const *text_end = text + length; + + switch (length) { + case 0: return 0; + + // Texts under 7 bytes long are definitely below the largest prime. 
+ case 1: + hash_low = _sz_shift_low(text[0]); + hash_high = _sz_shift_high(text[0]); + break; + case 2: + hash_low = _sz_shift_low(text[0]) * 31ull + _sz_shift_low(text[1]); + hash_high = _sz_shift_high(text[0]) * 257ull + _sz_shift_high(text[1]); + break; + case 3: + hash_low = _sz_shift_low(text[0]) * 31ull * 31ull + // + _sz_shift_low(text[1]) * 31ull + // + _sz_shift_low(text[2]); + hash_high = _sz_shift_high(text[0]) * 257ull * 257ull + // + _sz_shift_high(text[1]) * 257ull + // + _sz_shift_high(text[2]); + break; + case 4: + hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull + // + _sz_shift_low(text[1]) * 31ull * 31ull + // + _sz_shift_low(text[2]) * 31ull + // + _sz_shift_low(text[3]); + hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull + // + _sz_shift_high(text[1]) * 257ull * 257ull + // + _sz_shift_high(text[2]) * 257ull + // + _sz_shift_high(text[3]); + break; + case 5: + hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull + // + _sz_shift_low(text[1]) * 31ull * 31ull * 31ull + // + _sz_shift_low(text[2]) * 31ull * 31ull + // + _sz_shift_low(text[3]) * 31ull + // + _sz_shift_low(text[4]); + hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull + // + _sz_shift_high(text[1]) * 257ull * 257ull * 257ull + // + _sz_shift_high(text[2]) * 257ull * 257ull + // + _sz_shift_high(text[3]) * 257ull + // + _sz_shift_high(text[4]); + break; + case 6: + hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull + // + _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull + // + _sz_shift_low(text[2]) * 31ull * 31ull * 31ull + // + _sz_shift_low(text[3]) * 31ull * 31ull + // + _sz_shift_low(text[4]) * 31ull + // + _sz_shift_low(text[5]); + hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull + // + _sz_shift_high(text[1]) * 257ull * 257ull * 257ull * 257ull + // + _sz_shift_high(text[2]) * 257ull * 257ull * 257ull + // + _sz_shift_high(text[3]) * 257ull * 257ull + // + 
_sz_shift_high(text[4]) * 257ull + // + _sz_shift_high(text[5]); + break; + case 7: + hash_low = _sz_shift_low(text[0]) * 31ull * 31ull * 31ull * 31ull * 31ull * 31ull + // + _sz_shift_low(text[1]) * 31ull * 31ull * 31ull * 31ull * 31ull + // + _sz_shift_low(text[2]) * 31ull * 31ull * 31ull * 31ull + // + _sz_shift_low(text[3]) * 31ull * 31ull * 31ull + // + _sz_shift_low(text[4]) * 31ull * 31ull + // + _sz_shift_low(text[5]) * 31ull + // + _sz_shift_low(text[6]); + hash_high = _sz_shift_high(text[0]) * 257ull * 257ull * 257ull * 257ull * 257ull * 257ull + // + _sz_shift_high(text[1]) * 257ull * 257ull * 257ull * 257ull * 257ull + // + _sz_shift_high(text[2]) * 257ull * 257ull * 257ull * 257ull + // + _sz_shift_high(text[3]) * 257ull * 257ull * 257ull + // + _sz_shift_high(text[4]) * 257ull * 257ull + // + _sz_shift_high(text[5]) * 257ull + // + _sz_shift_high(text[6]); + break; + default: + // Unroll the first seven cycles: + hash_low = hash_low * 31ull + _sz_shift_low(text[0]); + hash_high = hash_high * 257ull + _sz_shift_high(text[0]); + hash_low = hash_low * 31ull + _sz_shift_low(text[1]); + hash_high = hash_high * 257ull + _sz_shift_high(text[1]); + hash_low = hash_low * 31ull + _sz_shift_low(text[2]); + hash_high = hash_high * 257ull + _sz_shift_high(text[2]); + hash_low = hash_low * 31ull + _sz_shift_low(text[3]); + hash_high = hash_high * 257ull + _sz_shift_high(text[3]); + hash_low = hash_low * 31ull + _sz_shift_low(text[4]); + hash_high = hash_high * 257ull + _sz_shift_high(text[4]); + hash_low = hash_low * 31ull + _sz_shift_low(text[5]); + hash_high = hash_high * 257ull + _sz_shift_high(text[5]); + hash_low = hash_low * 31ull + _sz_shift_low(text[6]); + hash_high = hash_high * 257ull + _sz_shift_high(text[6]); + text += 7; + + // Iterate throw the rest with the modulus: + for (; text != text_end; ++text) { + hash_low = hash_low * 31ull + _sz_shift_low(text[0]); + hash_high = hash_high * 257ull + _sz_shift_high(text[0]); + // Wrap the hashes around: + 
hash_low = _sz_prime_mod(hash_low); + hash_high = _sz_prime_mod(hash_high); + } + break; + } + + return _sz_hash_mix(hash_low, hash_high); +} + +SZ_PUBLIC void sz_hashes_serial(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // + sz_hash_callback_t callback, void *callback_handle) { + + if (length < window_length || !window_length) return; + sz_u8_t const *text = (sz_u8_t const *)start; + sz_u8_t const *text_end = text + length; + + // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. + sz_u64_t prime_power_low = 1, prime_power_high = 1; + for (sz_size_t i = 0; i + 1 < window_length; ++i) + prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, + prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; + + // Compute the initial hash value for the first window. + sz_u64_t hash_low = 0, hash_high = 0, hash_mix; + for (sz_u8_t const *first_end = text + window_length; text < first_end; ++text) + hash_low = (hash_low * 31ull + _sz_shift_low(*text)) % SZ_U64_MAX_PRIME, + hash_high = (hash_high * 257ull + _sz_shift_high(*text)) % SZ_U64_MAX_PRIME; + + // In most cases the fingerprint length will be a power of two. + hash_mix = _sz_hash_mix(hash_low, hash_high); + callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); + + // Compute the hash value for every window, exporting into the fingerprint, + // using the expensive modulo operation. 
+ sz_size_t cycles = 1; + sz_size_t const step_mask = step - 1; + for (; text < text_end; ++text, ++cycles) { + // Discard one character: + hash_low -= _sz_shift_low(*(text - window_length)) * prime_power_low; + hash_high -= _sz_shift_high(*(text - window_length)) * prime_power_high; + // And add a new one: + hash_low = 31ull * hash_low + _sz_shift_low(*text); + hash_high = 257ull * hash_high + _sz_shift_high(*text); + // Wrap the hashes around: + hash_low = _sz_prime_mod(hash_low); + hash_high = _sz_prime_mod(hash_high); + // Mix only if we've skipped enough hashes. + if ((cycles & step_mask) == 0) { + hash_mix = _sz_hash_mix(hash_low, hash_high); + callback((sz_cptr_t)text, window_length, hash_mix, callback_handle); + } + } +} + +#undef _sz_shift_low +#undef _sz_shift_high +#undef _sz_hash_mix +#undef _sz_prime_mod + +/** + * @brief Uses a small lookup-table to convert a lowercase character to uppercase. + */ +SZ_INTERNAL sz_u8_t sz_u8_tolower(sz_u8_t c) { + static sz_u8_t const lowered[256] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, // + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, // + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, // + 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // + 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 91, 92, 93, 94, 95, // + 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // + 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, // + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, // + 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, // + 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, // + 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, // + 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 
234, 235, 236, 237, 238, 239, // + 240, 241, 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, // + 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // + 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, // + }; + return lowered[c]; +} + +/** + * @brief Uses a small lookup-table to convert an uppercase character to lowercase. + */ +SZ_INTERNAL sz_u8_t sz_u8_toupper(sz_u8_t c) { + static sz_u8_t const upped[256] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, // + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, // + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, // + 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // + 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 91, 92, 93, 94, 95, // + 96, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, // + 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 123, 124, 125, 126, 127, // + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, // + 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, // + 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, // + 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, // + 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // + 240, 241, 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, // + 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, // + 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, // + }; + return upped[c]; +} + +/** + * @brief Uses two small lookup tables (768 bytes total) to accelerate division by a small + * unsigned integer. Performs two lookups, one multiplication, two shifts, and two accumulations. 
+ * + * @param divisor Integral value larger than one. + * @param number Integral value to divide. + */ +SZ_INTERNAL sz_u8_t sz_u8_divide(sz_u8_t number, sz_u8_t divisor) { + static sz_u16_t const multipliers[256] = { + 0, 0, 0, 21846, 0, 39322, 21846, 9363, 0, 50973, 39322, 29790, 21846, 15124, 9363, 4370, + 0, 57826, 50973, 44841, 39322, 34329, 29790, 25645, 21846, 18351, 15124, 12137, 9363, 6780, 4370, 2115, + 0, 61565, 57826, 54302, 50973, 47824, 44841, 42011, 39322, 36765, 34329, 32006, 29790, 27671, 25645, 23705, + 21846, 20063, 18351, 16706, 15124, 13602, 12137, 10725, 9363, 8049, 6780, 5554, 4370, 3224, 2115, 1041, + 0, 63520, 61565, 59668, 57826, 56039, 54302, 52614, 50973, 49377, 47824, 46313, 44841, 43407, 42011, 40649, + 39322, 38028, 36765, 35532, 34329, 33154, 32006, 30885, 29790, 28719, 27671, 26647, 25645, 24665, 23705, 22766, + 21846, 20945, 20063, 19198, 18351, 17520, 16706, 15907, 15124, 14356, 13602, 12863, 12137, 11424, 10725, 10038, + 9363, 8700, 8049, 7409, 6780, 6162, 5554, 4957, 4370, 3792, 3224, 2665, 2115, 1573, 1041, 517, + 0, 64520, 63520, 62535, 61565, 60609, 59668, 58740, 57826, 56926, 56039, 55164, 54302, 53452, 52614, 51788, + 50973, 50169, 49377, 48595, 47824, 47063, 46313, 45572, 44841, 44120, 43407, 42705, 42011, 41326, 40649, 39982, + 39322, 38671, 38028, 37392, 36765, 36145, 35532, 34927, 34329, 33738, 33154, 32577, 32006, 31443, 30885, 30334, + 29790, 29251, 28719, 28192, 27671, 27156, 26647, 26143, 25645, 25152, 24665, 24182, 23705, 23233, 22766, 22303, + 21846, 21393, 20945, 20502, 20063, 19628, 19198, 18772, 18351, 17933, 17520, 17111, 16706, 16305, 15907, 15514, + 15124, 14738, 14356, 13977, 13602, 13231, 12863, 12498, 12137, 11779, 11424, 11073, 10725, 10380, 10038, 9699, + 9363, 9030, 8700, 8373, 8049, 7727, 7409, 7093, 6780, 6470, 6162, 5857, 5554, 5254, 4957, 4662, + 4370, 4080, 3792, 3507, 3224, 2943, 2665, 2388, 2115, 1843, 1573, 1306, 1041, 778, 517, 258, + }; + // This table can be avoided using a single addition 
and counting trailing zeros. + static sz_u8_t const shifts[256] = { + 0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, // + 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, // + 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, // + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, // + 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // + }; + sz_u32_t multiplier = multipliers[divisor]; + sz_u8_t shift = shifts[divisor]; + + sz_u16_t q = (sz_u16_t)((multiplier * number) >> 16); + sz_u16_t t = ((number - q) >> 1) + q; + return (sz_u8_t)(t >> shift); +} + +SZ_PUBLIC void sz_tolower_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { + sz_u8_t *unsigned_result = (sz_u8_t *)result; + sz_u8_t const *unsigned_text = (sz_u8_t const *)text; + sz_u8_t const *end = unsigned_text + length; + for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = sz_u8_tolower(*unsigned_text); +} + +SZ_PUBLIC void sz_toupper_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { + sz_u8_t *unsigned_result = (sz_u8_t *)result; + sz_u8_t const *unsigned_text = (sz_u8_t const *)text; + sz_u8_t const *end = unsigned_text + length; + for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = sz_u8_toupper(*unsigned_text); +} + +SZ_PUBLIC void sz_toascii_serial(sz_cptr_t text, sz_size_t length, sz_ptr_t result) { + sz_u8_t *unsigned_result = (sz_u8_t *)result; + sz_u8_t const *unsigned_text = (sz_u8_t const 
*)text; + sz_u8_t const *end = unsigned_text + length; + for (; unsigned_text != end; ++unsigned_text, ++unsigned_result) *unsigned_result = *unsigned_text & 0x7F; +} + +SZ_PUBLIC void sz_generate(sz_cptr_t alphabet, sz_size_t alphabet_size, sz_ptr_t result, sz_size_t result_length, + sz_random_generator_t generator, void *generator_user_data) { + + sz_assert(alphabet_size > 0 && alphabet_size <= 256 && "Inadequate alphabet size"); + + if (alphabet_size == 1) + for (sz_cptr_t end = result + result_length; result != end; ++result) *result = *alphabet; + + else { + sz_assert(generator && "Expects a valid random generator"); + for (sz_cptr_t end = result + result_length; result != end; ++result) + *result = alphabet[sz_u8_divide(generator(generator_user_data) & 0xFF, (sz_u8_t)alphabet_size)]; + } +} + +#pragma endregion + +/* + * Serial implementation of string class operations. + */ +#pragma region Serial Implementation for the String Class + +/** + * @brief Threshold for switching to SWAR (8-bytes at a time) backend over serial byte-level for-loops. + * On very short strings, under 16 bytes long, at most a single word will be processed with SWAR. + * Assuming potentially misaligned loads, SWAR makes sense only after ~24 bytes. + */ +#ifndef SZ_SWAR_THRESHOLD +#define SZ_SWAR_THRESHOLD (24) // bytes +#endif + +SZ_PUBLIC sz_bool_t sz_string_is_on_stack(sz_string_t const *string) { + // It doesn't matter if it's on stack or heap, the pointer location is the same. + return (sz_bool_t)((sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]); +} + +SZ_PUBLIC void sz_string_range(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length) { + sz_size_t is_small = (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]; + sz_size_t is_big_mask = is_small - 1ull; + *start = string->external.start; // It doesn't matter if it's on stack or heap, the pointer location is the same. 
+ // If the string is small, use branch-less approach to mask-out the top 7 bytes of the length. + *length = string->external.length & (0x00000000000000FFull | is_big_mask); +} + +SZ_PUBLIC void sz_string_unpack(sz_string_t const *string, sz_ptr_t *start, sz_size_t *length, sz_size_t *space, + sz_bool_t *is_external) { + sz_size_t is_small = (sz_cptr_t)string->internal.start == (sz_cptr_t)&string->internal.chars[0]; + sz_size_t is_big_mask = is_small - 1ull; + *start = string->external.start; // It doesn't matter if it's on stack or heap, the pointer location is the same. + // If the string is small, use branch-less approach to mask-out the top 7 bytes of the length. + *length = string->external.length & (0x00000000000000FFull | is_big_mask); + // In case the string is small, the `is_small - 1ull` will become 0xFFFFFFFFFFFFFFFFull. + *space = sz_u64_blend(SZ_STRING_INTERNAL_SPACE, string->external.space, is_big_mask); + *is_external = (sz_bool_t)!is_small; +} + +SZ_PUBLIC sz_bool_t sz_string_equal(sz_string_t const *a, sz_string_t const *b) { + // Tempting to say that the external.length is bitwise the same even if it includes + // some bytes of the on-stack payload, but we don't at this writing maintain that invariant. + // (An on-stack string includes noise bytes in the high-order bits of external.length. So do this + // the hard/correct way. + +#if SZ_USE_MISALIGNED_LOADS + // Dealing with StringZilla strings, we know that the `start` pointer always points + // to a word at least 8 bytes long. Therefore, we can compare the first 8 bytes at once. + +#endif + // Alternatively, fall back to byte-by-byte comparison. 
+ sz_ptr_t a_start, b_start; + sz_size_t a_length, b_length; + sz_string_range(a, &a_start, &a_length); + sz_string_range(b, &b_start, &b_length); + return (sz_bool_t)(a_length == b_length && sz_equal(a_start, b_start, b_length)); +} + +SZ_PUBLIC sz_ordering_t sz_string_order(sz_string_t const *a, sz_string_t const *b) { +#if SZ_USE_MISALIGNED_LOADS + // Dealing with StringZilla strings, we know that the `start` pointer always points + // to a word at least 8 bytes long. Therefore, we can compare the first 8 bytes at once. + +#endif + // Alternatively, fall back to byte-by-byte comparison. + sz_ptr_t a_start, b_start; + sz_size_t a_length, b_length; + sz_string_range(a, &a_start, &a_length); + sz_string_range(b, &b_start, &b_length); + return sz_order(a_start, a_length, b_start, b_length); +} + +SZ_PUBLIC void sz_string_init(sz_string_t *string) { + sz_assert(string && "String can't be SZ_NULL."); + + // Only 8 + 1 + 1 need to be initialized. + string->internal.start = &string->internal.chars[0]; + // But for safety let's initialize the entire structure to zeros. + // string->internal.chars[0] = 0; + // string->internal.length = 0; + string->u64s[1] = 0; + string->u64s[2] = 0; + string->u64s[3] = 0; +} + +SZ_PUBLIC sz_ptr_t sz_string_init_length(sz_string_t *string, sz_size_t length, sz_memory_allocator_t *allocator) { + sz_size_t space_needed = length + 1; // space for trailing \0 + sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); + // Initialize the string to zeros for safety. + string->u64s[1] = 0; + string->u64s[2] = 0; + string->u64s[3] = 0; + // If we are lucky, no memory allocations will be needed. + if (space_needed <= SZ_STRING_INTERNAL_SPACE) { + string->internal.start = &string->internal.chars[0]; + string->internal.length = (sz_u8_t)length; + } + else { + // If we are not lucky, we need to allocate memory. 
+ string->external.start = (sz_ptr_t)allocator->allocate(space_needed, allocator->handle); + if (!string->external.start) return SZ_NULL; + string->external.length = length; + string->external.space = space_needed; + } + sz_assert(&string->internal.start == &string->external.start && "Alignment confusion"); + string->external.start[length] = 0; + return string->external.start; +} + +SZ_PUBLIC sz_ptr_t sz_string_reserve(sz_string_t *string, sz_size_t new_capacity, sz_memory_allocator_t *allocator) { + + sz_assert(string && "String can't be SZ_NULL."); + + sz_size_t new_space = new_capacity + 1; + if (new_space <= SZ_STRING_INTERNAL_SPACE) return string->external.start; + + sz_ptr_t string_start; + sz_size_t string_length; + sz_size_t string_space; + sz_bool_t string_is_external; + sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); + sz_assert(new_space > string_space && "New space must be larger than current."); + + sz_ptr_t new_start = (sz_ptr_t)allocator->allocate(new_space, allocator->handle); + if (!new_start) return SZ_NULL; + + sz_copy(new_start, string_start, string_length); + string->external.start = new_start; + string->external.space = new_space; + string->external.padding = 0; + string->external.length = string_length; + + // Deallocate the old string. + if (string_is_external) allocator->free(string_start, string_space, allocator->handle); + return string->external.start; +} + +SZ_PUBLIC sz_ptr_t sz_string_expand(sz_string_t *string, sz_size_t offset, sz_size_t added_length, + sz_memory_allocator_t *allocator) { + + sz_assert(string && allocator && "String and allocator can't be SZ_NULL."); + + sz_ptr_t string_start; + sz_size_t string_length; + sz_size_t string_space; + sz_bool_t string_is_external; + sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); + + // The user intended to extend the string. 
+ offset = sz_min_of_two(offset, string_length); + + // If we are lucky, no memory allocations will be needed. + if (offset + string_length + added_length < string_space) { + sz_move(string_start + offset + added_length, string_start + offset, string_length - offset); + string_start[string_length + added_length] = 0; + // Even if the string is on the stack, the `+=` won't affect the tail of the string. + string->external.length += added_length; + } + // If we are not lucky, we need to allocate more memory. + else { + sz_size_t next_planned_size = sz_max_of_two(SZ_CACHE_LINE_WIDTH, string_space * 2ull); + sz_size_t min_needed_space = sz_size_bit_ceil(offset + string_length + added_length + 1); + sz_size_t new_space = sz_max_of_two(min_needed_space, next_planned_size); + string_start = sz_string_reserve(string, new_space - 1, allocator); + if (!string_start) return SZ_NULL; + + // Copy into the new buffer. + sz_move(string_start + offset + added_length, string_start + offset, string_length - offset); + string_start[string_length + added_length] = 0; + string->external.length = string_length + added_length; + } + + return string_start; +} + +SZ_PUBLIC sz_size_t sz_string_erase(sz_string_t *string, sz_size_t offset, sz_size_t length) { + + sz_assert(string && "String can't be SZ_NULL."); + + sz_ptr_t string_start; + sz_size_t string_length; + sz_size_t string_space; + sz_bool_t string_is_external; + sz_string_unpack(string, &string_start, &string_length, &string_space, &string_is_external); + + // Normalize the offset, it can't be larger than the length. + offset = sz_min_of_two(offset, string_length); + + // We shouldn't normalize the length, to avoid overflowing on `offset + length >= string_length`, + // if receiving `length == SZ_SIZE_MAX`. After following expression the `length` will contain + // exactly the delta between original and final length of this `string`. 
+ length = sz_min_of_two(length, string_length - offset); + + // There are 2 common cases, that wouldn't even require a `memmove`: + // 1. Erasing the entire contents of the string. + // In that case `length` argument will be equal or greater than `length` member. + // 2. Removing the tail of the string with something like `string.pop_back()` in C++. + // + // In both of those, regardless of the location of the string - stack or heap, + // the erasing is as easy as setting the length to the offset. + // In every other case, we must `memmove` the tail of the string to the left. + if (offset + length < string_length) + sz_move(string_start + offset, string_start + offset + length, string_length - offset - length); + + // The `string->external.length = offset` assignment would discard last characters + // of the on-the-stack string, but inplace subtraction would work. + string->external.length -= length; + string_start[string_length - length] = 0; + return length; +} + +SZ_PUBLIC void sz_string_free(sz_string_t *string, sz_memory_allocator_t *allocator) { + if (!sz_string_is_on_stack(string)) + allocator->free(string->external.start, string->external.space, allocator->handle); + sz_string_init(string); +} + +SZ_PUBLIC void sz_fill_serial(sz_ptr_t target, sz_size_t length, sz_u8_t value) { + sz_ptr_t end = target + length; + // Dealing with short strings, a single sequential pass would be faster. + // If the size is larger than 2 words, then at least 1 of them will be aligned. + // But just one aligned word may not be worth SWAR. + if (length < SZ_SWAR_THRESHOLD) + while (target != end) *(target++) = value; + + // In case of long strings, skip unaligned bytes, and then fill the rest in 64-bit chunks. 
+ else { + sz_u64_t value64 = (sz_u64_t)(value) * 0x0101010101010101ull; + while ((sz_size_t)target & 7ull) *(target++) = value; + while (target + 8 <= end) *(sz_u64_t *)target = value64, target += 8; + while (target != end) *(target++) = value; + } +} + +SZ_PUBLIC void sz_copy_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { +#if SZ_USE_MISALIGNED_LOADS + while (length >= 8) *(sz_u64_t *)target = *(sz_u64_t *)source, target += 8, source += 8, length -= 8; +#endif + while (length--) *(target++) = *(source++); +} + +SZ_PUBLIC void sz_move_serial(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { + // Implementing `memmove` is trickier, than `memcpy`, as the ranges may overlap. + // Existing implementations often have two passes, in normal and reversed order, + // depending on the relation of `target` and `source` addresses. + // https://student.cs.uwaterloo.ca/~cs350/common/os161-src-html/doxygen/html/memmove_8c_source.html + // https://marmota.medium.com/c-language-making-memmove-def8792bb8d5 + // + // We can use the `memcpy` like left-to-right pass if we know that the `target` is before `source`. + // Or if we know that they don't intersect! In that case the traversal order is irrelevant, + // but older CPUs may predict and fetch forward-passes better. + if (target < source || target >= source + length) { +#if SZ_USE_MISALIGNED_LOADS + while (length >= 8) *(sz_u64_t *)target = *(sz_u64_t *)source, target += 8, source += 8, length -= 8; +#endif + while (length--) *(target++) = *(source++); + } + else { + // Jump to the end and walk backwards. + target += length, source += length; +#if SZ_USE_MISALIGNED_LOADS + while (length >= 8) *(sz_u64_t *)(target -= 8) = *(sz_u64_t *)(source -= 8), length -= 8; +#endif + while (length--) *(--target) = *(--source); + } +} + +#pragma endregion + +/* + * @brief Serial implementation for strings sequence processing. 
+ */ +#pragma region Serial Implementation for Sequences + +SZ_PUBLIC sz_size_t sz_partition(sz_sequence_t *sequence, sz_sequence_predicate_t predicate) { + + sz_size_t matches = 0; + while (matches != sequence->count && predicate(sequence, sequence->order[matches])) ++matches; + + for (sz_size_t i = matches + 1; i < sequence->count; ++i) + if (predicate(sequence, sequence->order[i])) + sz_u64_swap(sequence->order + i, sequence->order + matches), ++matches; + + return matches; +} + +SZ_PUBLIC void sz_merge(sz_sequence_t *sequence, sz_size_t partition, sz_sequence_comparator_t less) { + + sz_size_t start_b = partition + 1; + + // If the direct merge is already sorted + if (!less(sequence, sequence->order[start_b], sequence->order[partition])) return; + + sz_size_t start_a = 0; + while (start_a <= partition && start_b <= sequence->count) { + + // If element 1 is in right place + if (!less(sequence, sequence->order[start_b], sequence->order[start_a])) { start_a++; } + else { + sz_size_t value = sequence->order[start_b]; + sz_size_t index = start_b; + + // Shift all the elements between element 1 + // element 2, right by 1. 
+ while (index != start_a) { sequence->order[index] = sequence->order[index - 1], index--; } + sequence->order[start_a] = value; + + // Update all the pointers + start_a++; + partition++; + start_b++; + } + } +} + +SZ_PUBLIC void sz_sort_insertion(sz_sequence_t *sequence, sz_sequence_comparator_t less) { + sz_u64_t *keys = sequence->order; + sz_size_t keys_count = sequence->count; + for (sz_size_t i = 1; i < keys_count; i++) { + sz_u64_t i_key = keys[i]; + sz_size_t j = i; + for (; j > 0 && less(sequence, i_key, keys[j - 1]); --j) keys[j] = keys[j - 1]; + keys[j] = i_key; + } +} + +SZ_INTERNAL void _sz_sift_down(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t start, + sz_size_t end) { + sz_size_t root = start; + while (2 * root + 1 <= end) { + sz_size_t child = 2 * root + 1; + if (child + 1 <= end && less(sequence, order[child], order[child + 1])) { child++; } + if (!less(sequence, order[root], order[child])) { return; } + sz_u64_swap(order + root, order + child); + root = child; + } +} + +SZ_INTERNAL void _sz_heapify(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_u64_t *order, sz_size_t count) { + sz_size_t start = (count - 2) / 2; + while (1) { + _sz_sift_down(sequence, less, order, start, count - 1); + if (start == 0) return; + start--; + } +} + +SZ_INTERNAL void _sz_heapsort(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, sz_size_t last) { + sz_u64_t *order = sequence->order; + sz_size_t count = last - first; + _sz_heapify(sequence, less, order + first, count); + sz_size_t end = count - 1; + while (end > 0) { + sz_u64_swap(order + first, order + first + end); + end--; + _sz_sift_down(sequence, less, order + first, 0, end); + } +} + +SZ_PUBLIC void sz_sort_introsort_recursion(sz_sequence_t *sequence, sz_sequence_comparator_t less, sz_size_t first, + sz_size_t last, sz_size_t depth) { + + sz_size_t length = last - first; + switch (length) { + case 0: + case 1: return; + case 2: + if 
(less(sequence, sequence->order[first + 1], sequence->order[first])) + sz_u64_swap(&sequence->order[first], &sequence->order[first + 1]); + return; + case 3: { + sz_u64_t a = sequence->order[first]; + sz_u64_t b = sequence->order[first + 1]; + sz_u64_t c = sequence->order[first + 2]; + if (less(sequence, b, a)) sz_u64_swap(&a, &b); + if (less(sequence, c, b)) sz_u64_swap(&c, &b); + if (less(sequence, b, a)) sz_u64_swap(&a, &b); + sequence->order[first] = a; + sequence->order[first + 1] = b; + sequence->order[first + 2] = c; + return; + } + } + // Until a certain length, the quadratic-complexity insertion-sort is fine + if (length <= 16) { + sz_sequence_t sub_seq = *sequence; + sub_seq.order += first; + sub_seq.count = length; + sz_sort_insertion(&sub_seq, less); + return; + } + + // Fallback to N-logN-complexity heap-sort + if (depth == 0) { + _sz_heapsort(sequence, less, first, last); + return; + } + + --depth; + + // Median-of-three logic to choose pivot + sz_size_t median = first + length / 2; + if (less(sequence, sequence->order[median], sequence->order[first])) + sz_u64_swap(&sequence->order[first], &sequence->order[median]); + if (less(sequence, sequence->order[last - 1], sequence->order[first])) + sz_u64_swap(&sequence->order[first], &sequence->order[last - 1]); + if (less(sequence, sequence->order[median], sequence->order[last - 1])) + sz_u64_swap(&sequence->order[median], &sequence->order[last - 1]); + + // Partition using the median-of-three as the pivot + sz_u64_t pivot = sequence->order[median]; + sz_size_t left = first; + sz_size_t right = last - 1; + while (1) { + while (less(sequence, sequence->order[left], pivot)) left++; + while (less(sequence, pivot, sequence->order[right])) right--; + if (left >= right) break; + sz_u64_swap(&sequence->order[left], &sequence->order[right]); + left++; + right--; + } + + // Recursively sort the partitions + sz_sort_introsort_recursion(sequence, less, first, left, depth); + sz_sort_introsort_recursion(sequence, less, 
+ right + 1, last, depth);
+}
+
+SZ_PUBLIC void sz_sort_introsort(sz_sequence_t *sequence, sz_sequence_comparator_t less) {
+ if (sequence->count == 0) return;
+ sz_size_t size_is_not_power_of_two = (sequence->count & (sequence->count - 1)) != 0;
+ sz_size_t depth_limit = sz_size_log2i_nonzero(sequence->count) + size_is_not_power_of_two;
+ sz_sort_introsort_recursion(sequence, less, 0, sequence->count, depth_limit);
+}
+
+SZ_PUBLIC void sz_sort_recursion( //
+ sz_sequence_t *sequence, sz_size_t bit_idx, sz_size_t bit_max, sz_sequence_comparator_t comparator,
+ sz_size_t partial_order_length) {
+
+ if (!sequence->count) return;
+
+ // Array of size one doesn't need sorting - only needs the prefix to be discarded.
+ if (sequence->count == 1) {
+ sz_u32_t *order_half_words = (sz_u32_t *)sequence->order;
+ order_half_words[1] = 0;
+ return;
+ }
+
+ // Partition a range of integers according to a specific bit value
+ sz_size_t split = 0;
+ sz_u64_t mask = (1ull << 63) >> bit_idx;
+
+ // The clean approach would be to perform a single pass over the sequence.
+ //
+ // while (split != sequence->count && !(sequence->order[split] & mask)) ++split;
+ // for (sz_size_t i = split + 1; i < sequence->count; ++i)
+ // if (!(sequence->order[i] & mask)) sz_u64_swap(sequence->order + i, sequence->order + split), ++split;
+ //
+ // This, however, doesn't take into account the high relative cost of writes and swaps.
+ // To circumvent that, we can first count the total number of entries to be mapped into either part.
+ // And then walk through both parts, swapping the entries that are in the wrong part.
+ // This would often lead to ~15% performance gain.
+ sz_size_t count_with_bit_set = 0;
+ for (sz_size_t i = 0; i != sequence->count; ++i) count_with_bit_set += (sequence->order[i] & mask) != 0;
+ split = sequence->count - count_with_bit_set;
+
+ // It's possible that the sequence is already partitioned.
+ if (split != 0 && split != sequence->count) {
+ // Use two pointers to efficiently reposition elements.
+ // One pointer walks left-to-right from the start, and the other walks right-to-left from the end.
+ sz_size_t left = 0;
+ sz_size_t right = sequence->count - 1;
+ while (1) {
+ // Find the next element with the bit set on the left side.
+ while (left < split && !(sequence->order[left] & mask)) ++left;
+ // Find the next element without the bit set on the right side.
+ while (right >= split && (sequence->order[right] & mask)) --right;
+ // Swap the mispositioned elements.
+ if (left < split && right >= split) {
+ sz_u64_swap(sequence->order + left, sequence->order + right);
+ ++left;
+ --right;
+ }
+ else { break; }
+ }
+ }
+
+ // Go down recursively.
+ if (bit_idx < bit_max) {
+ sz_sequence_t a = *sequence;
+ a.count = split;
+ sz_sort_recursion(&a, bit_idx + 1, bit_max, comparator, partial_order_length);
+
+ sz_sequence_t b = *sequence;
+ b.order += split;
+ b.count -= split;
+ sz_sort_recursion(&b, bit_idx + 1, bit_max, comparator, partial_order_length);
+ }
+ // Reached the end of recursion.
+ else {
+ // Discard the prefixes.
+ sz_u32_t *order_half_words = (sz_u32_t *)sequence->order; + for (sz_size_t i = 0; i != sequence->count; ++i) { order_half_words[i * 2 + 1] = 0; } + + sz_sequence_t a = *sequence; + a.count = split; + sz_sort_introsort(&a, comparator); + + sz_sequence_t b = *sequence; + b.order += split; + b.count -= split; + sz_sort_introsort(&b, comparator); + } +} + +SZ_INTERNAL sz_bool_t _sz_sort_is_less(sz_sequence_t *sequence, sz_size_t i_key, sz_size_t j_key) { + sz_cptr_t i_str = sequence->get_start(sequence, i_key); + sz_cptr_t j_str = sequence->get_start(sequence, j_key); + sz_size_t i_len = sequence->get_length(sequence, i_key); + sz_size_t j_len = sequence->get_length(sequence, j_key); + return (sz_bool_t)(sz_order_serial(i_str, i_len, j_str, j_len) == sz_less_k); +} + +SZ_PUBLIC void sz_sort_partial(sz_sequence_t *sequence, sz_size_t partial_order_length) { + + // Export up to 4 bytes into the `sequence` bits themselves + for (sz_size_t i = 0; i != sequence->count; ++i) { + sz_cptr_t begin = sequence->get_start(sequence, sequence->order[i]); + sz_size_t length = sequence->get_length(sequence, sequence->order[i]); + length = length > 4ull ? 4ull : length; + sz_ptr_t prefix = (sz_ptr_t)&sequence->order[i]; + for (sz_size_t j = 0; j != length; ++j) prefix[7 - j] = begin[j]; + } + + // Perform optionally-parallel radix sort on them + sz_sort_recursion(sequence, 0, 32, (sz_sequence_comparator_t)_sz_sort_is_less, partial_order_length); +} + +SZ_PUBLIC void sz_sort(sz_sequence_t *sequence) { sz_sort_partial(sequence, sequence->count); } + +#pragma endregion + +/* + * @brief AVX2 implementation of the string search algorithms. + * Very minimalistic, but still faster than the serial implementation. 
+ */ +#pragma region AVX2 Implementation + +#if SZ_USE_X86_AVX2 +#pragma GCC push_options +#pragma GCC target("avx2") +#pragma clang attribute push(__attribute__((target("avx2"))), apply_to = function) +#include + +/** + * @brief Helper structure to simplify work with 256-bit registers. + */ +typedef union sz_u256_vec_t { + __m256i ymm; + __m128i xmms[2]; + sz_u64_t u64s[4]; + sz_u32_t u32s[8]; + sz_u16_t u16s[16]; + sz_u8_t u8s[32]; +} sz_u256_vec_t; + +SZ_PUBLIC void sz_fill_avx2(sz_ptr_t target, sz_size_t length, sz_u8_t value) { + for (; length >= 32; target += 32, length -= 32) _mm256_storeu_si256((__m256i *)target, _mm256_set1_epi8(value)); + sz_fill_serial(target, length, value); +} + +SZ_PUBLIC void sz_copy_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { + for (; length >= 32; target += 32, source += 32, length -= 32) + _mm256_storeu_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); + sz_copy_serial(target, source, length); +} + +SZ_PUBLIC void sz_move_avx2(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { + if (target < source || target >= source + length) { + for (; length >= 32; target += 32, source += 32, length -= 32) + _mm256_storeu_si256((__m256i *)target, _mm256_lddqu_si256((__m256i const *)source)); + while (length--) *(target++) = *(source++); + } + else { + // Jump to the end and walk backwards. 
+ for (target += length, source += length; length >= 32; length -= 32) + _mm256_storeu_si256((__m256i *)(target -= 32), _mm256_lddqu_si256((__m256i const *)(source -= 32))); + while (length--) *(--target) = *(--source); + } +} + +SZ_PUBLIC sz_cptr_t sz_find_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { + int mask; + sz_u256_vec_t h_vec, n_vec; + n_vec.ymm = _mm256_set1_epi8(n[0]); + + while (h_length >= 32) { + h_vec.ymm = _mm256_lddqu_si256((__m256i const *)h); + mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_vec.ymm, n_vec.ymm)); + if (mask) return h + sz_u32_ctz(mask); + h += 32, h_length -= 32; + } + + return sz_find_byte_serial(h, h_length, n); +} + +SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { + int mask; + sz_u256_vec_t h_vec, n_vec; + n_vec.ymm = _mm256_set1_epi8(n[0]); + + while (h_length >= 32) { + h_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + h_length - 32)); + mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_vec.ymm, n_vec.ymm)); + if (mask) return h + h_length - 1 - sz_u32_clz(mask); + h_length -= 32; + } + + return sz_rfind_byte_serial(h, h_length, n); +} + +SZ_PUBLIC sz_cptr_t sz_find_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + + // This almost never fires, but it's better to be safe than sorry. + if (h_length < n_length || !n_length) return SZ_NULL; + if (n_length == 1) return sz_find_byte_avx2(h, h_length, n); + + // Pick the parts of the needle that are worth comparing. + sz_size_t offset_first, offset_mid, offset_last; + _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); + + // Broadcast those characters into YMM registers. + int matches; + sz_u256_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; + n_first_vec.ymm = _mm256_set1_epi8(n[offset_first]); + n_mid_vec.ymm = _mm256_set1_epi8(n[offset_mid]); + n_last_vec.ymm = _mm256_set1_epi8(n[offset_last]); + + // Scan through the string. 
+ for (; h_length >= n_length + 32; h += 32, h_length -= 32) { + h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_first)); + h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_mid)); + h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h + offset_last)); + matches = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & + _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & + _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); + while (matches) { + int potential_offset = sz_u32_ctz(matches); + if (sz_equal(h + potential_offset, n, n_length)) return h + potential_offset; + matches &= matches - 1; + } + } + + return sz_find_serial(h, h_length, n, n_length); +} + +SZ_PUBLIC sz_cptr_t sz_rfind_avx2(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + + // This almost never fires, but it's better to be safe than sorry. + if (h_length < n_length || !n_length) return SZ_NULL; + if (n_length == 1) return sz_rfind_byte_avx2(h, h_length, n); + + // Pick the parts of the needle that are worth comparing. + sz_size_t offset_first, offset_mid, offset_last; + _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); + + // Broadcast those characters into YMM registers. + int matches; + sz_u256_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec; + n_first_vec.ymm = _mm256_set1_epi8(n[offset_first]); + n_mid_vec.ymm = _mm256_set1_epi8(n[offset_mid]); + n_last_vec.ymm = _mm256_set1_epi8(n[offset_last]); + + // Scan through the string. 
+ sz_cptr_t h_reversed; + for (; h_length >= n_length + 32; h_length -= 32) { + h_reversed = h + h_length - n_length - 32 + 1; + h_first_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_first)); + h_mid_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_mid)); + h_last_vec.ymm = _mm256_lddqu_si256((__m256i const *)(h_reversed + offset_last)); + matches = _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_first_vec.ymm, n_first_vec.ymm)) & + _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_mid_vec.ymm, n_mid_vec.ymm)) & + _mm256_movemask_epi8(_mm256_cmpeq_epi8(h_last_vec.ymm, n_last_vec.ymm)); + while (matches) { + int potential_offset = sz_u32_clz(matches); + if (sz_equal(h + h_length - n_length - potential_offset, n, n_length)) + return h + h_length - n_length - potential_offset; + matches &= ~(1 << (31 - potential_offset)); + } + } + + return sz_rfind_serial(h, h_length, n, n_length); +} + +/** + * @brief There is no AVX2 instruction for fast multiplication of 64-bit integers. + * This implementation is coming from Agner Fog's Vector Class Library. + */ +SZ_INTERNAL __m256i _mm256_mul_epu64(__m256i a, __m256i b) { + __m256i bswap = _mm256_shuffle_epi32(b, 0xB1); + __m256i prodlh = _mm256_mullo_epi32(a, bswap); + __m256i zero = _mm256_setzero_si256(); + __m256i prodlh2 = _mm256_hadd_epi32(prodlh, zero); + __m256i prodlh3 = _mm256_shuffle_epi32(prodlh2, 0x73); + __m256i prodll = _mm256_mul_epu32(a, b); + __m256i prod = _mm256_add_epi64(prodll, prodlh3); + return prod; +} + +SZ_PUBLIC void sz_hashes_avx2(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, // + sz_hash_callback_t callback, void *callback_handle) { + + if (length < window_length || !window_length) return; + if (length < 4 * window_length) { + sz_hashes_serial(start, length, window_length, step, callback, callback_handle); + return; + } + + // Using AVX2, we can perform 4 long integer multiplications and additions within one register. 
+ // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel. + sz_size_t const max_hashes = length - window_length + 1; + sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads. + sz_u8_t const *text_first = (sz_u8_t const *)start; + sz_u8_t const *text_second = text_first + min_hashes_per_thread; + sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2; + sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3; + sz_u8_t const *text_end = text_first + length; + + // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic. + sz_u64_t prime_power_low = 1, prime_power_high = 1; + for (sz_size_t i = 0; i + 1 < window_length; ++i) + prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME, + prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME; + + // Broadcast the constants into the registers. + sz_u256_vec_t prime_vec, golden_ratio_vec; + sz_u256_vec_t base_low_vec, base_high_vec, prime_power_low_vec, prime_power_high_vec, shift_high_vec; + base_low_vec.ymm = _mm256_set1_epi64x(31ull); + base_high_vec.ymm = _mm256_set1_epi64x(257ull); + shift_high_vec.ymm = _mm256_set1_epi64x(77ull); + prime_vec.ymm = _mm256_set1_epi64x(SZ_U64_MAX_PRIME); + golden_ratio_vec.ymm = _mm256_set1_epi64x(11400714819323198485ull); + prime_power_low_vec.ymm = _mm256_set1_epi64x(prime_power_low); + prime_power_high_vec.ymm = _mm256_set1_epi64x(prime_power_high); + + // Compute the initial hash values for every one of the four windows. + sz_u256_vec_t hash_low_vec, hash_high_vec, hash_mix_vec, chars_low_vec, chars_high_vec; + hash_low_vec.ymm = _mm256_setzero_si256(); + hash_high_vec.ymm = _mm256_setzero_si256(); + for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end; + ++text_first, ++text_second, ++text_third, ++text_fourth) { + + // 1. Multiply the hashes by the base. 
+ hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); + hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); + + // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, + // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. + chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); + chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); + + // 3. Add the incoming characters. + hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); + hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); + + // 4. Compute the modulo. Assuming there are only 59 values between our prime + // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. + hash_low_vec.ymm = _mm256_blendv_epi8(hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), + _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); + hash_high_vec.ymm = _mm256_blendv_epi8(hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), + _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); + } + + // 5. Compute the hash mix, that will be used to index into the fingerprint. + // This includes a serial step at the end. 
+ hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); + hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); + hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); + callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); + callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); + callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); + callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); + + // Now repeat that operation for the remaining characters, discarding older characters. + sz_size_t cycle = 1; + sz_size_t const step_mask = step - 1; + for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) { + // 0. Load again the four characters we are dropping, shift them, and subtract. + chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[-window_length], text_third[-window_length], + text_second[-window_length], text_first[-window_length]); + chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); + hash_low_vec.ymm = + _mm256_sub_epi64(hash_low_vec.ymm, _mm256_mul_epu64(chars_low_vec.ymm, prime_power_low_vec.ymm)); + hash_high_vec.ymm = + _mm256_sub_epi64(hash_high_vec.ymm, _mm256_mul_epu64(chars_high_vec.ymm, prime_power_high_vec.ymm)); + + // 1. Multiply the hashes by the base. + hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, base_low_vec.ymm); + hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, base_high_vec.ymm); + + // 2. Load the four characters from `text_first`, `text_first + max_hashes_per_thread`, + // `text_first + max_hashes_per_thread * 2`, `text_first + max_hashes_per_thread * 3`. + chars_low_vec.ymm = _mm256_set_epi64x(text_fourth[0], text_third[0], text_second[0], text_first[0]); + chars_high_vec.ymm = _mm256_add_epi8(chars_low_vec.ymm, shift_high_vec.ymm); + + // 3. 
Add the incoming characters. + hash_low_vec.ymm = _mm256_add_epi64(hash_low_vec.ymm, chars_low_vec.ymm); + hash_high_vec.ymm = _mm256_add_epi64(hash_high_vec.ymm, chars_high_vec.ymm); + + // 4. Compute the modulo. Assuming there are only 59 values between our prime + // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime. + hash_low_vec.ymm = _mm256_blendv_epi8(hash_low_vec.ymm, _mm256_sub_epi64(hash_low_vec.ymm, prime_vec.ymm), + _mm256_cmpgt_epi64(hash_low_vec.ymm, prime_vec.ymm)); + hash_high_vec.ymm = _mm256_blendv_epi8(hash_high_vec.ymm, _mm256_sub_epi64(hash_high_vec.ymm, prime_vec.ymm), + _mm256_cmpgt_epi64(hash_high_vec.ymm, prime_vec.ymm)); + + // 5. Compute the hash mix, that will be used to index into the fingerprint. + // This includes a serial step at the end. + hash_low_vec.ymm = _mm256_mul_epu64(hash_low_vec.ymm, golden_ratio_vec.ymm); + hash_high_vec.ymm = _mm256_mul_epu64(hash_high_vec.ymm, golden_ratio_vec.ymm); + hash_mix_vec.ymm = _mm256_xor_si256(hash_low_vec.ymm, hash_high_vec.ymm); + if ((cycle & step_mask) == 0) { + callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle); + callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle); + callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle); + callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle); + } + } +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif +#pragma endregion + +/* + * @brief AVX-512 implementation of the string search algorithms. 
 *
 * Different subsets of AVX-512 were introduced in different years:
 * * 2017 SkyLake: F, CD, ER, PF, VL, DQ, BW
 * * 2018 CannonLake: IFMA, VBMI
 * * 2019 IceLake: VPOPCNTDQ, VNNI, VBMI2, BITALG, GFNI, VPCLMULQDQ, VAES
 * * 2020 TigerLake: VP2INTERSECT
 */
#pragma region AVX-512 Implementation

#if SZ_USE_X86_AVX512
#pragma GCC push_options
#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "bmi", "bmi2")
#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,bmi,bmi2"))), apply_to = function)
#include <immintrin.h>

/**
 * @brief Helper structure to simplify work with 512-bit registers.
 *        All views alias the same 64 bytes of storage.
 */
typedef union sz_u512_vec_t {
    __m512i zmm;
    __m256i ymms[2];
    __m128i xmms[4];
    sz_u64_t u64s[8];
    sz_u32_t u32s[16];
    sz_u16_t u16s[32];
    sz_u8_t u8s[64];
    sz_i64_t i64s[8];
    sz_i32_t i32s[16];
} sz_u512_vec_t;

/** @brief Builds a 64-bit load/store mask with the lowest `min(n, 64)` bits set. */
SZ_INTERNAL __mmask64 _sz_u64_clamp_mask_until(sz_size_t n) {
    // The simplest approach to compute this if we know that `n` is below or equal to 64:
    //      return (1ull << n) - 1;
    // A slightly more complex approach, if we don't know that `n` is under 64:
    return _bzhi_u64(0xFFFFFFFFFFFFFFFF, n < 64 ? n : 64);
}

/** @brief Builds a 32-bit load/store mask with the lowest `min(n, 32)` bits set. */
SZ_INTERNAL __mmask32 _sz_u32_clamp_mask_until(sz_size_t n) {
    // The simplest approach to compute this if we know that `n` is below or equal to 32:
    //      return (1ull << n) - 1;
    // A slightly more complex approach, if we don't know that `n` is under 32:
    return _bzhi_u32(0xFFFFFFFF, n < 32 ? n : 32);
}

/** @brief Builds a 16-bit load/store mask with the lowest `min(n, 16)` bits set. */
SZ_INTERNAL __mmask16 _sz_u16_clamp_mask_until(sz_size_t n) {
    // The simplest approach to compute this if we know that `n` is below or equal to 16:
    //      return (1ull << n) - 1;
    // A slightly more complex approach, if we don't know that `n` is under 16:
    return _bzhi_u32(0xFFFFFFFF, n < 16 ? n : 16);
}

/** @brief Builds a 64-bit mask with the lowest `n` bits set; caller must guarantee `n <= 64`. */
SZ_INTERNAL __mmask64 _sz_u64_mask_until(sz_size_t n) {
    // The simplest approach to compute this if we know that `n` is below or equal to 64:
    //      return (1ull << n) - 1;
    // A slightly more complex approach, if we don't know that `n` is under 64:
    return _bzhi_u64(0xFFFFFFFFFFFFFFFF, n);
}

/**
 * @brief Three-way lexicographic comparison of two byte strings, 64 bytes at a time,
 *        using masked loads for the sub-64-byte tails.
 */
SZ_PUBLIC sz_ordering_t sz_order_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) {
    // Indexed by the boolean `a_char < b_char`, so no branch is needed on the comparison result.
    sz_ordering_t ordering_lookup[2] = {sz_greater_k, sz_less_k};
    sz_u512_vec_t a_vec, b_vec;
    __mmask64 a_mask, b_mask, mask_not_equal;

    // The rare case, when both strings are very long.
    while ((a_length >= 64) & (b_length >= 64)) {
        a_vec.zmm = _mm512_loadu_epi8(a);
        b_vec.zmm = _mm512_loadu_epi8(b);
        mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm);
        if (mask_not_equal != 0) {
            int first_diff = _tzcnt_u64(mask_not_equal);
            char a_char = a[first_diff];
            char b_char = b[first_diff];
            return ordering_lookup[a_char < b_char];
        }
        a += 64, b += 64, a_length -= 64, b_length -= 64;
    }

    // In most common scenarios at least one of the strings is under 64 bytes.
    if (a_length | b_length) {
        a_mask = _sz_u64_clamp_mask_until(a_length);
        b_mask = _sz_u64_clamp_mask_until(b_length);
        a_vec.zmm = _mm512_maskz_loadu_epi8(a_mask, a);
        b_vec.zmm = _mm512_maskz_loadu_epi8(b_mask, b);
        // The AVX-512 `_mm512_mask_cmpneq_epi8_mask` intrinsics are generally handy in such environments.
        // They, however, have latency 3 on most modern CPUs. Using AVX2: `_mm256_cmpeq_epi8` would have
        // been cheaper, if we didn't have to apply `_mm256_movemask_epi8` afterwards.
        mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm);
        if (mask_not_equal != 0) {
            int first_diff = _tzcnt_u64(mask_not_equal);
            char a_char = a[first_diff];
            char b_char = b[first_diff];
            return ordering_lookup[a_char < b_char];
        }
        else
            // From logic perspective, the hardest cases are "abc\0" and "abc".
            // The result must be `sz_greater_k`, as the latter is shorter.
            return a_length != b_length ? ordering_lookup[a_length < b_length] : sz_equal_k;
    }
    else
        return sz_equal_k;
}

/** @brief Byte-wise equality check of two equally-long buffers; returns on the first 64-byte mismatch. */
SZ_PUBLIC sz_bool_t sz_equal_avx512(sz_cptr_t a, sz_cptr_t b, sz_size_t length) {
    __mmask64 mask;
    sz_u512_vec_t a_vec, b_vec;

    while (length >= 64) {
        a_vec.zmm = _mm512_loadu_epi8(a);
        b_vec.zmm = _mm512_loadu_epi8(b);
        mask = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm);
        if (mask != 0) return sz_false_k;
        a += 64, b += 64, length -= 64;
    }

    if (length) {
        mask = _sz_u64_mask_until(length);
        a_vec.zmm = _mm512_maskz_loadu_epi8(mask, a);
        b_vec.zmm = _mm512_maskz_loadu_epi8(mask, b);
        // Reuse the same `mask` variable to find the bit that doesn't match
        mask = _mm512_mask_cmpneq_epi8_mask(mask, a_vec.zmm, b_vec.zmm);
        return (sz_bool_t)(mask == 0);
    }
    else
        return sz_true_k;
}

/** @brief `memset`-equivalent: fills `length` bytes at `target` with `value`. */
SZ_PUBLIC void sz_fill_avx512(sz_ptr_t target, sz_size_t length, sz_u8_t value) {
    for (; length >= 64; target += 64, length -= 64) _mm512_storeu_epi8(target, _mm512_set1_epi8(value));
    // At this point the length is guaranteed to be under 64.
    _mm512_mask_storeu_epi8(target, _sz_u64_mask_until(length), _mm512_set1_epi8(value));
}

/** @brief `memcpy`-equivalent for non-overlapping buffers. */
SZ_PUBLIC void sz_copy_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
    for (; length >= 64; target += 64, source += 64, length -= 64)
        _mm512_storeu_epi8(target, _mm512_loadu_epi8(source));
    // At this point the length is guaranteed to be under 64.
    __mmask64 mask = _sz_u64_mask_until(length);
    _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source));
}

/** @brief `memmove`-equivalent: copies forward or backward depending on buffer overlap. */
SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t length) {
    if (target < source || target >= source + length) {
        // No dangerous overlap: a plain forward copy is safe.
        for (; length >= 64; target += 64, source += 64, length -= 64)
            _mm512_storeu_epi8(target, _mm512_loadu_epi8(source));
        // At this point the length is guaranteed to be under 64.
        __mmask64 mask = _sz_u64_mask_until(length);
        _mm512_mask_storeu_epi8(target, mask, _mm512_maskz_loadu_epi8(mask, source));
    }
    else {
        // Jump to the end and walk backwards.
        for (target += length, source += length; length >= 64; length -= 64)
            _mm512_storeu_epi8(target -= 64, _mm512_loadu_epi8(source -= 64));
        // At this point the length is guaranteed to be under 64.
        __mmask64 mask = _sz_u64_mask_until(length);
        _mm512_mask_storeu_epi8(target - length, mask, _mm512_maskz_loadu_epi8(mask, source - length));
    }
}

/** @brief Finds the first occurrence of the byte `n[0]` in the haystack `h`, or SZ_NULL. */
SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
    __mmask64 mask;
    sz_u512_vec_t h_vec, n_vec;
    n_vec.zmm = _mm512_set1_epi8(n[0]);

    while (h_length >= 64) {
        h_vec.zmm = _mm512_loadu_epi8(h);
        mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm);
        if (mask) return h + sz_u64_ctz(mask);
        h += 64, h_length -= 64;
    }

    if (h_length) {
        mask = _sz_u64_mask_until(h_length);
        h_vec.zmm = _mm512_maskz_loadu_epi8(mask, h);
        // Reuse the same `mask` variable to find the bit that doesn't match
        mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm);
        if (mask) return h + sz_u64_ctz(mask);
    }

    return SZ_NULL;
}

/**
 * @brief Finds the first occurrence of needle `n` in haystack `h`, or SZ_NULL.
 *        Probes three "anomalous" needle bytes per candidate position; only candidates
 *        matching all three are verified with a full `sz_equal_avx512` comparison.
 */
SZ_PUBLIC sz_cptr_t sz_find_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {

    // This almost never fires, but it's better to be safe than sorry.
    if (h_length < n_length || !n_length) return SZ_NULL;
    if (n_length == 1) return sz_find_byte_avx512(h, h_length, n);

    // Pick the parts of the needle that are worth comparing.
    sz_size_t offset_first, offset_mid, offset_last;
    _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last);

    // Broadcast those characters into ZMM registers.
    __mmask64 matches;
    __mmask64 mask;
    sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec;
    n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]);
    n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]);
    n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]);

    // Scan through the string.
    for (; h_length >= n_length + 64; h += 64, h_length -= 64) {
        h_first_vec.zmm = _mm512_loadu_epi8(h + offset_first);
        h_mid_vec.zmm = _mm512_loadu_epi8(h + offset_mid);
        h_last_vec.zmm = _mm512_loadu_epi8(h + offset_last);
        matches = _kand_mask64(_kand_mask64( // Intersect the masks
                                   _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm),
                                   _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)),
                               _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm));
        while (matches) {
            int potential_offset = sz_u64_ctz(matches);
            // For needles up to 3 bytes the three probed characters already cover the whole needle.
            if (n_length <= 3 || sz_equal_avx512(h + potential_offset, n, n_length)) return h + potential_offset;
            matches &= matches - 1; // Clear the lowest set bit and retry.
        }

        // TODO: If the last character contains a bad byte, we can reposition the start of the next iteration.
        // This will be very helpful for very long needles.
    }

    // The "tail" of the function uses masked loads to process the remaining bytes.
    {
        // Only `h_length - n_length + 1` candidate positions remain.
        mask = _sz_u64_mask_until(h_length - n_length + 1);
        h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first);
        h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid);
        h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last);
        matches = _kand_mask64(_kand_mask64( // Intersect the masks
                                   _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm),
                                   _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)),
                               _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm));
        while (matches) {
            int potential_offset = sz_u64_ctz(matches);
            if (n_length <= 3 || sz_equal_avx512(h + potential_offset, n, n_length)) return h + potential_offset;
            matches &= matches - 1;
        }
    }
    return SZ_NULL;
}

/** @brief Finds the last occurrence of the byte `n[0]` in the haystack `h`, or SZ_NULL. */
SZ_PUBLIC sz_cptr_t sz_rfind_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
    __mmask64 mask;
    sz_u512_vec_t h_vec, n_vec;
    n_vec.zmm = _mm512_set1_epi8(n[0]);

    while (h_length >= 64) {
        h_vec.zmm = _mm512_loadu_epi8(h + h_length - 64);
        mask = _mm512_cmpeq_epi8_mask(h_vec.zmm, n_vec.zmm);
        if (mask) return h + h_length - 1 - sz_u64_clz(mask);
        h_length -= 64;
    }

    if (h_length) {
        mask = _sz_u64_mask_until(h_length);
        h_vec.zmm = _mm512_maskz_loadu_epi8(mask, h);
        // Reuse the same `mask` variable to find the bit that doesn't match
        mask = _mm512_mask_cmpeq_epu8_mask(mask, h_vec.zmm, n_vec.zmm);
        // Highest set bit is at position `63 - clz(mask)`.
        if (mask) return h + 64 - sz_u64_clz(mask) - 1;
    }

    return SZ_NULL;
}

/**
 * @brief Finds the last occurrence of needle `n` in haystack `h`, or SZ_NULL.
 *        Mirror image of `sz_find_avx512`: walks the haystack backwards in 64-byte strides
 *        and iterates candidate bits from the most-significant end.
 */
SZ_PUBLIC sz_cptr_t sz_rfind_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {

    // This almost never fires, but it's better to be safe than sorry.
    if (h_length < n_length || !n_length) return SZ_NULL;
    if (n_length == 1) return sz_rfind_byte_avx512(h, h_length, n);

    // Pick the parts of the needle that are worth comparing.
    sz_size_t offset_first, offset_mid, offset_last;
    _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last);

    // Broadcast those characters into ZMM registers.
    __mmask64 mask;
    __mmask64 matches;
    sz_u512_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec;
    n_first_vec.zmm = _mm512_set1_epi8(n[offset_first]);
    n_mid_vec.zmm = _mm512_set1_epi8(n[offset_mid]);
    n_last_vec.zmm = _mm512_set1_epi8(n[offset_last]);

    // Scan through the string.
    sz_cptr_t h_reversed;
    for (; h_length >= n_length + 64; h_length -= 64) {
        h_reversed = h + h_length - n_length - 64 + 1;
        h_first_vec.zmm = _mm512_loadu_epi8(h_reversed + offset_first);
        h_mid_vec.zmm = _mm512_loadu_epi8(h_reversed + offset_mid);
        h_last_vec.zmm = _mm512_loadu_epi8(h_reversed + offset_last);
        matches = _kand_mask64(_kand_mask64( // Intersect the masks
                                   _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm),
                                   _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)),
                               _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm));
        while (matches) {
            // Walk candidates from the highest set bit, as we want the LAST match.
            int potential_offset = sz_u64_clz(matches);
            if (n_length <= 3 || sz_equal_avx512(h + h_length - n_length - potential_offset, n, n_length))
                return h + h_length - n_length - potential_offset;
            sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 &&
                      "The bit must be set before we squash it");
            matches &= ~((sz_u64_t)1 << (63 - potential_offset));
        }
    }

    // The "tail" of the function uses masked loads to process the remaining bytes.
    {
        mask = _sz_u64_mask_until(h_length - n_length + 1);
        h_first_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_first);
        h_mid_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_mid);
        h_last_vec.zmm = _mm512_maskz_loadu_epi8(mask, h + offset_last);
        matches = _kand_mask64(_kand_mask64( // Intersect the masks
                                   _mm512_cmpeq_epi8_mask(h_first_vec.zmm, n_first_vec.zmm),
                                   _mm512_cmpeq_epi8_mask(h_mid_vec.zmm, n_mid_vec.zmm)),
                               _mm512_cmpeq_epi8_mask(h_last_vec.zmm, n_last_vec.zmm));
        while (matches) {
            int potential_offset = sz_u64_clz(matches);
            if (n_length <= 3 || sz_equal_avx512(h + 64 - potential_offset - 1, n, n_length))
                return h + 64 - potential_offset - 1;
            sz_assert((matches & ((sz_u64_t)1 << (63 - potential_offset))) != 0 &&
                      "The bit must be set before we squash it");
            matches &= ~((sz_u64_t)1 << (63 - potential_offset));
        }
    }

    return SZ_NULL;
}

/**
 * @brief Levenshtein edit distance over skewed (anti-)diagonals of the dynamic-programming
 *        matrix, using 16-bit accumulators — hence the "upto65k" length limit.
 *        Currently restricted to equal-length inputs and an unbounded search (see asserts).
 *
 * @return The edit distance, or SZ_SIZE_MAX if allocation fails.
 */
SZ_INTERNAL sz_size_t _sz_edit_distance_skewed_diagonals_upto65k_avx512( //
    sz_cptr_t shorter, sz_size_t shorter_length,                         //
    sz_cptr_t longer, sz_size_t longer_length,                           //
    sz_size_t bound, sz_memory_allocator_t *alloc) {

    // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome.
    sz_memory_allocator_t global_alloc;
    if (!alloc) {
        sz_memory_allocator_init_default(&global_alloc);
        alloc = &global_alloc;
    }

    // TODO: Generalize!
    sz_size_t max_length = 256u * 256u;
    sz_assert(!bound && "For bounded search the method should only evaluate one band of the matrix.");
    sz_assert(shorter_length == longer_length && "The method hasn't been generalized to different length inputs yet.");
    sz_assert(shorter_length < max_length && "The length must fit into 16-bit integer. Otherwise use serial variant.");
    sz_unused(longer_length && bound && max_length);

    // We are going to store 3 diagonals of the matrix.
    // The length of the longest (main) diagonal would be `n = (shorter_length + 1)`.
    sz_size_t n = shorter_length + 1;
    // Unlike the serial version, we also want to avoid reverse-order iteration over the shorter string.
    // So let's allocate a bit more memory and reverse-export our shorter string into that buffer.
    sz_size_t buffer_length = sizeof(sz_u16_t) * n * 3 + shorter_length;
    sz_u16_t *distances = (sz_u16_t *)alloc->allocate(buffer_length, alloc->handle);
    if (!distances) return SZ_SIZE_MAX;

    sz_u16_t *previous_distances = distances;
    sz_u16_t *current_distances = previous_distances + n;
    sz_u16_t *next_distances = current_distances + n;
    sz_ptr_t shorter_reversed = (sz_ptr_t)(next_distances + n);

    // Export the reversed string into the buffer.
    for (sz_size_t i = 0; i != shorter_length; ++i) shorter_reversed[i] = shorter[shorter_length - 1 - i];

    // Initialize the first two diagonals:
    previous_distances[0] = 0;
    current_distances[0] = current_distances[1] = 1;

    // Using ZMM registers, we can process 32x 16-bit values at once,
    // storing 16 bytes of each string in YMM registers.
    sz_u512_vec_t insertions_vec, deletions_vec, substitutions_vec, next_vec;
    sz_u512_vec_t ones_u16_vec;
    ones_u16_vec.zmm = _mm512_set1_epi16(1);
    // This is a mixed-precision implementation, using 8-bit representations for part of the operations.
    // Even there, in case `SZ_USE_X86_AVX2=0`, let's use the `sz_u512_vec_t` type, addressing the first YMM halfs.
    sz_u512_vec_t shorter_vec, longer_vec;
    sz_u512_vec_t ones_u8_vec;
    ones_u8_vec.ymms[0] = _mm256_set1_epi8(1);

    // Progress through the upper triangle of the Levenshtein matrix.
    sz_size_t next_skew_diagonal_index = 2;
    for (; next_skew_diagonal_index != n; ++next_skew_diagonal_index) {
        sz_size_t const next_skew_diagonal_length = next_skew_diagonal_index + 1;
        for (sz_size_t i = 0; i + 2 < next_skew_diagonal_length;) {
            sz_size_t remaining_length = next_skew_diagonal_length - i - 2;
            sz_size_t register_length = remaining_length < 32 ? remaining_length : 32;
            sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length);
            longer_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, longer + i);
            // Our original code addressed the shorter string `[next_skew_diagonal_index - i - 2]` for growing `i`.
            // If the `shorter` string was reversed, the `[next_skew_diagonal_index - i - 2]` would
            // be equal to `[shorter_length - 1 - next_skew_diagonal_index + i + 2]`.
            // Which simplified would be equal to `[shorter_length - next_skew_diagonal_index + i + 1]`.
            shorter_vec.ymms[0] = _mm256_maskz_loadu_epi8(
                remaining_length_mask, shorter_reversed + shorter_length - next_skew_diagonal_index + i + 1);
            // For substitutions, perform the equality comparison using AVX2 instead of AVX-512
            // to get the result as a vector, instead of a bitmask. Adding 1 to every scalar we can overflow
            // transforming from {0xFF, 0} values to {0, 1} values - exactly what we need. Then - upcast to 16-bit.
            substitutions_vec.zmm = _mm512_cvtepi8_epi16( //
                _mm256_add_epi8(_mm256_cmpeq_epi8(longer_vec.ymms[0], shorter_vec.ymms[0]), ones_u8_vec.ymms[0]));
            substitutions_vec.zmm = _mm512_add_epi16( //
                substitutions_vec.zmm, _mm512_maskz_loadu_epi16(remaining_length_mask, previous_distances + i));
            // For insertions and deletions, on modern hardware, it's faster to issue two separate loads,
            // than rotate the bytes in the ZMM register.
            insertions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i);
            deletions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i + 1);
            // First get the minimum of insertions and deletions.
            next_vec.zmm = _mm512_add_epi16(_mm512_min_epu16(insertions_vec.zmm, deletions_vec.zmm), ones_u16_vec.zmm);
            next_vec.zmm = _mm512_min_epu16(next_vec.zmm, substitutions_vec.zmm);
            _mm512_mask_storeu_epi16(next_distances + i + 1, remaining_length_mask, next_vec.zmm);
            i += register_length;
        }
        // Don't forget to populate the first row and the first column of the Levenshtein matrix.
        next_distances[0] = next_distances[next_skew_diagonal_length - 1] = next_skew_diagonal_index;
        // Perform a circular rotation of those buffers, to reuse the memory.
        sz_u16_t *temporary = previous_distances;
        previous_distances = current_distances;
        current_distances = next_distances;
        next_distances = temporary;
    }

    // By now we've scanned through the upper triangle of the matrix, where each subsequent iteration results in a
    // larger diagonal. From now onwards, we will be shrinking. Instead of adding value equal to the skewed diagonal
    // index on either side, we will be cropping those values out.
    sz_size_t total_diagonals = n + n - 1;
    for (; next_skew_diagonal_index != total_diagonals; ++next_skew_diagonal_index) {
        sz_size_t const next_skew_diagonal_length = total_diagonals - next_skew_diagonal_index;
        for (sz_size_t i = 0; i != next_skew_diagonal_length;) {
            sz_size_t remaining_length = next_skew_diagonal_length - i;
            sz_size_t register_length = remaining_length < 32 ? remaining_length : 32;
            sz_u32_t remaining_length_mask = _bzhi_u32(0xFFFFFFFFu, register_length);
            longer_vec.ymms[0] =
                _mm256_maskz_loadu_epi8(remaining_length_mask, longer + next_skew_diagonal_index - n + i);
            // Our original code addressed the shorter string `[shorter_length - 1 - i]` for growing `i`.
            // If the `shorter` string was reversed, the `[shorter_length - 1 - i]` would
            // be equal to `[shorter_length - 1 - shorter_length + 1 + i]`.
            // Which simplified would be equal to just `[i]`. Beautiful!
            shorter_vec.ymms[0] = _mm256_maskz_loadu_epi8(remaining_length_mask, shorter_reversed + i);
            // For substitutions, perform the equality comparison using AVX2 instead of AVX-512
            // to get the result as a vector, instead of a bitmask. Then compare it against the accumulated
            // substitution costs.
            substitutions_vec.zmm = _mm512_cvtepi8_epi16( //
                _mm256_add_epi8(_mm256_cmpeq_epi8(longer_vec.ymms[0], shorter_vec.ymms[0]), ones_u8_vec.ymms[0]));
            substitutions_vec.zmm = _mm512_add_epi16( //
                substitutions_vec.zmm, _mm512_maskz_loadu_epi16(remaining_length_mask, previous_distances + i));
            // For insertions and deletions, on modern hardware, it's faster to issue two separate loads,
            // than rotate the bytes in the ZMM register.
            insertions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i);
            deletions_vec.zmm = _mm512_maskz_loadu_epi16(remaining_length_mask, current_distances + i + 1);
            // First get the minimum of insertions and deletions.
            next_vec.zmm = _mm512_add_epi16(_mm512_min_epu16(insertions_vec.zmm, deletions_vec.zmm), ones_u16_vec.zmm);
            next_vec.zmm = _mm512_min_epu16(next_vec.zmm, substitutions_vec.zmm);
            _mm512_mask_storeu_epi16(next_distances + i, remaining_length_mask, next_vec.zmm);
            i += register_length;
        }

        // Perform a circular rotation of those buffers, to reuse the memory, this time, with a shift,
        // dropping the first element in the current array.
        sz_u16_t *temporary = previous_distances;
        previous_distances = current_distances + 1;
        current_distances = next_distances;
        next_distances = temporary;
    }

    // Cache scalar before `free` call.
    sz_size_t result = current_distances[0];
    alloc->free(distances, buffer_length, alloc->handle);
    return result;
}

/**
 * @brief Edit-distance dispatcher: routes equal-length, unbounded, sub-65536-byte inputs to the
 *        skewed-diagonals SIMD kernel and everything else to the serial fallback.
 */
SZ_INTERNAL sz_size_t sz_edit_distance_avx512(     //
    sz_cptr_t shorter, sz_size_t shorter_length,   //
    sz_cptr_t longer, sz_size_t longer_length,     //
    sz_size_t bound, sz_memory_allocator_t *alloc) {

    if (shorter_length == longer_length && !bound && shorter_length && shorter_length < 256u * 256u)
        return _sz_edit_distance_skewed_diagonals_upto65k_avx512(shorter, shorter_length, longer, longer_length, bound,
                                                                 alloc);
    else
        return sz_edit_distance_serial(shorter, shorter_length, longer, longer_length, bound, alloc);
}

#pragma clang attribute pop
#pragma GCC pop_options

#pragma GCC push_options
#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512dq", "bmi", "bmi2")
#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512dq,bmi,bmi2"))), \
                             apply_to = function)

/**
 * @brief Rolling-hash fingerprinting: slides a `window_length`-byte window over `start`,
 *        reporting every `step`-th hash via `callback`. Processes four overlapping slices
 *        of the input in parallel, with two hash functions (bases 31 and 257) per slice.
 *        NOTE(review): `(cycle & step_mask)` assumes `step` is a power of two — confirm against callers.
 */
SZ_PUBLIC void sz_hashes_avx512(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_size_t step, //
                                sz_hash_callback_t callback, void *callback_handle) {

    if (length < window_length || !window_length) return;
    if (length < 4 * window_length) {
        sz_hashes_serial(start, length, window_length, step, callback, callback_handle);
        return;
    }

    // Using AVX2, we can perform 4 long integer multiplications and additions within one register.
    // So let's slice the entire string into 4 overlapping windows, to slide over them in parallel.
    sz_size_t const max_hashes = length - window_length + 1;
    sz_size_t const min_hashes_per_thread = max_hashes / 4; // At most one sequence can overlap between 2 threads.
    sz_u8_t const *text_first = (sz_u8_t const *)start;
    sz_u8_t const *text_second = text_first + min_hashes_per_thread;
    sz_u8_t const *text_third = text_first + min_hashes_per_thread * 2;
    sz_u8_t const *text_fourth = text_first + min_hashes_per_thread * 3;
    sz_u8_t const *text_end = text_first + length;

    // Broadcast the global constants into the registers.
    // Both high and low hashes will work with the same prime and golden ratio.
    sz_u512_vec_t prime_vec, golden_ratio_vec;
    prime_vec.zmm = _mm512_set1_epi64(SZ_U64_MAX_PRIME);
    golden_ratio_vec.zmm = _mm512_set1_epi64(11400714819323198485ull);

    // Prepare the `prime ^ window_length` values, that we are going to use for modulo arithmetic.
    sz_u64_t prime_power_low = 1, prime_power_high = 1;
    for (sz_size_t i = 0; i + 1 < window_length; ++i)
        prime_power_low = (prime_power_low * 31ull) % SZ_U64_MAX_PRIME,
        prime_power_high = (prime_power_high * 257ull) % SZ_U64_MAX_PRIME;

    // We will be evaluating 4 offsets at a time with 2 different hash functions.
    // We can fit all those 8 state variables in each of the following ZMM registers.
    sz_u512_vec_t base_vec, prime_power_vec, shift_vec;
    base_vec.zmm = _mm512_set_epi64(31ull, 31ull, 31ull, 31ull, 257ull, 257ull, 257ull, 257ull);
    shift_vec.zmm = _mm512_set_epi64(0ull, 0ull, 0ull, 0ull, 77ull, 77ull, 77ull, 77ull);
    prime_power_vec.zmm = _mm512_set_epi64(prime_power_low, prime_power_low, prime_power_low, prime_power_low,
                                           prime_power_high, prime_power_high, prime_power_high, prime_power_high);

    // Compute the initial hash values for every one of the four windows.
    sz_u512_vec_t hash_vec, chars_vec;
    hash_vec.zmm = _mm512_setzero_si512();
    for (sz_u8_t const *prefix_end = text_first + window_length; text_first < prefix_end;
         ++text_first, ++text_second, ++text_third, ++text_fourth) {

        // 1. Multiply the hashes by the base.
        hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm);

        // 2. Load the four characters from `text_first`, `text_first + min_hashes_per_thread`,
        //    `text_first + min_hashes_per_thread * 2`, `text_first + min_hashes_per_thread * 3`...
        chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], //
                                         text_fourth[0], text_third[0], text_second[0], text_first[0]);
        chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm);

        // 3. Add the incoming characters.
        hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm);

        // 4. Compute the modulo. Assuming there are only 59 values between our prime
        // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime.
        // NOTE(review): `_mm512_cmpgt_epi64_mask` yields one bit per 64-bit lane (__mmask8), while
        // `_mm512_mask_blend_epi8` consumes one bit per byte — verify upstream whether
        // `_mm512_mask_blend_epi64` was intended here.
        hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm,
                                              _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm));
    }

    // 5. Compute the hash mix, that will be used to index into the fingerprint.
    // This includes a serial step at the end.
    sz_u512_vec_t hash_mix_vec;
    hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm);
    // XOR the high-base hashes (upper 256 bits) into the low-base hashes (lower 256 bits).
    hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), //
                                            _mm512_extracti64x4_epi64(hash_mix_vec.zmm, 0));

    callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle);
    callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle);
    callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle);
    callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle);

    // Now repeat that operation for the remaining characters, discarding older characters.
    sz_size_t cycle = 1;
    sz_size_t step_mask = step - 1;
    for (; text_fourth != text_end; ++text_first, ++text_second, ++text_third, ++text_fourth, ++cycle) {
        // 0. Load again the four characters we are dropping, shift them, and subtract.
        chars_vec.zmm = _mm512_set_epi64(text_fourth[-window_length], text_third[-window_length],
                                         text_second[-window_length], text_first[-window_length], //
                                         text_fourth[-window_length], text_third[-window_length],
                                         text_second[-window_length], text_first[-window_length]);
        chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm);
        hash_vec.zmm = _mm512_sub_epi64(hash_vec.zmm, _mm512_mullo_epi64(chars_vec.zmm, prime_power_vec.zmm));

        // 1. Multiply the hashes by the base.
        hash_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, base_vec.zmm);

        // 2. Load the four characters from `text_first`, `text_first + min_hashes_per_thread`,
        //    `text_first + min_hashes_per_thread * 2`, `text_first + min_hashes_per_thread * 3`.
        chars_vec.zmm = _mm512_set_epi64(text_fourth[0], text_third[0], text_second[0], text_first[0], //
                                         text_fourth[0], text_third[0], text_second[0], text_first[0]);
        chars_vec.zmm = _mm512_add_epi8(chars_vec.zmm, shift_vec.zmm);

        // ... and prefetch the next four characters into Level 2 or higher.
        _mm_prefetch(text_fourth + 1, _MM_HINT_T1);
        _mm_prefetch(text_third + 1, _MM_HINT_T1);
        _mm_prefetch(text_second + 1, _MM_HINT_T1);
        _mm_prefetch(text_first + 1, _MM_HINT_T1);

        // 3. Add the incoming characters.
        hash_vec.zmm = _mm512_add_epi64(hash_vec.zmm, chars_vec.zmm);

        // 4. Compute the modulo. Assuming there are only 59 values between our prime
        // and the 2^64 value, we can simply compute the modulo by conditionally subtracting the prime.
        // NOTE(review): same per-lane vs per-byte mask concern as in the prefix loop above.
        hash_vec.zmm = _mm512_mask_blend_epi8(_mm512_cmpgt_epi64_mask(hash_vec.zmm, prime_vec.zmm), hash_vec.zmm,
                                              _mm512_sub_epi64(hash_vec.zmm, prime_vec.zmm));

        // 5. Compute the hash mix, that will be used to index into the fingerprint.
        // This includes a serial step at the end.
        hash_mix_vec.zmm = _mm512_mullo_epi64(hash_vec.zmm, golden_ratio_vec.zmm);
        hash_mix_vec.ymms[0] = _mm256_xor_si256(_mm512_extracti64x4_epi64(hash_mix_vec.zmm, 1), //
                                                _mm512_castsi512_si256(hash_mix_vec.zmm));

        if ((cycle & step_mask) == 0) {
            callback((sz_cptr_t)text_first, window_length, hash_mix_vec.u64s[0], callback_handle);
            callback((sz_cptr_t)text_second, window_length, hash_mix_vec.u64s[1], callback_handle);
            callback((sz_cptr_t)text_third, window_length, hash_mix_vec.u64s[2], callback_handle);
            callback((sz_cptr_t)text_fourth, window_length, hash_mix_vec.u64s[3], callback_handle);
        }
    }
}

#pragma clang attribute pop
#pragma GCC pop_options

#pragma GCC push_options
#pragma GCC target("avx", "avx512f", "avx512vl", "avx512bw", "avx512vbmi", "bmi", "bmi2", "gfni")
#pragma clang attribute push(__attribute__((target("avx,avx512f,avx512vl,avx512bw,avx512vbmi,bmi,bmi2,gfni"))), \
                             apply_to = function)

/**
 * @brief Finds the first byte of `text` that belongs to the 256-bit character set `filter`,
 *        32 bytes at a time, or SZ_NULL if none does.
 */
SZ_PUBLIC sz_cptr_t sz_find_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) {

    sz_size_t load_length;
    __mmask32 load_mask, matches_mask;
    // To store the set in the register we need just 256 bits, but the `VPERMB` instruction
    // we are going to invoke is surprisingly cheaper on ZMM registers.
    sz_u512_vec_t text_vec, filter_vec;
    filter_vec.ymms[0] = _mm256_loadu_epi64(&filter->_u64s[0]);

    // We are going to view the `filter` at 8-bit word granularity.
    sz_u512_vec_t filter_slice_offsets_vec;
    sz_u512_vec_t filter_slice_vec;
    sz_u512_vec_t offset_within_slice_vec;
    sz_u512_vec_t mask_in_filter_slice_vec;
    sz_u512_vec_t matches_vec;

    while (length) {
        // For every byte:
        // 1. Find corresponding word in a set.
        // 2. Produce a bitmask to check against that word.
        load_length = sz_min_of_two(length, 32);
        load_mask = _sz_u64_mask_until(load_length);
        text_vec.ymms[0] = _mm256_maskz_loadu_epi8(load_mask, text);

        // To shift right every byte by 3 bits we can use the GF2 affine transformations.
        // https://wunkolo.github.io/post/2020/11/gf2p8affineqb-int8-shifting/
        // After next line, all 8-bit offsets in the `filter_slice_offsets_vec` should be under 32.
        filter_slice_offsets_vec.ymms[0] =
            _mm256_gf2p8affine_epi64_epi8(text_vec.ymms[0], _mm256_set1_epi64x(0x0102040810204080ull << (3 * 8)), 0);

        // After next line, `filter_slice_vec` will contain the right word from the set,
        // needed to filter the presence of the byte in the set.
        filter_slice_vec.ymms[0] = _mm256_permutexvar_epi8(filter_slice_offsets_vec.ymms[0], filter_vec.ymms[0]);

        // After next line, all 8-bit offsets in the `filter_slice_offsets_vec` should be under 8.
        offset_within_slice_vec.ymms[0] = _mm256_and_si256(text_vec.ymms[0], _mm256_set1_epi64x(0x0707070707070707ull));

        // Instead of performing one more Galois Field operation, we can upcast to 16-bit integers,
        // and perform the shift and intersection there.
        filter_slice_vec.zmm = _mm512_cvtepi8_epi16(filter_slice_vec.ymms[0]);
        offset_within_slice_vec.zmm = _mm512_cvtepi8_epi16(offset_within_slice_vec.ymms[0]);
        mask_in_filter_slice_vec.zmm = _mm512_sllv_epi16(_mm512_set1_epi16(1), offset_within_slice_vec.zmm);
        matches_vec.zmm = _mm512_and_si512(filter_slice_vec.zmm, mask_in_filter_slice_vec.zmm);

        matches_mask = _mm512_mask_cmpneq_epi16_mask(load_mask, matches_vec.zmm, _mm512_setzero_si512());
        if (matches_mask) {
            int offset = sz_u32_ctz(matches_mask);
            return text + offset;
        }
        else { text += load_length, length -= load_length; }
    }

    return SZ_NULL;
}

/**
 * @brief Finds the last byte of `text` that belongs to the 256-bit character set `filter`,
 *        scanning from the end in 32-byte chunks, or SZ_NULL if none does.
 */
SZ_PUBLIC sz_cptr_t sz_rfind_charset_avx512(sz_cptr_t text, sz_size_t length, sz_charset_t const *filter) {

    sz_size_t load_length;
    __mmask32 load_mask, matches_mask;
    // To store the set in the register we need just 256 bits, but the `VPERMB` instruction
    // we are going to invoke is surprisingly cheaper on ZMM registers.
    sz_u512_vec_t text_vec, filter_vec;
    filter_vec.ymms[0] = _mm256_loadu_epi64(&filter->_u64s[0]);

    // We are going to view the `filter` at 8-bit word granularity.
    sz_u512_vec_t filter_slice_offsets_vec;
    sz_u512_vec_t filter_slice_vec;
    sz_u512_vec_t offset_within_slice_vec;
    sz_u512_vec_t mask_in_filter_slice_vec;
    sz_u512_vec_t matches_vec;

    while (length) {
        // For every byte:
        // 1. Find corresponding word in a set.
        // 2. Produce a bitmask to check against that word.
        load_length = sz_min_of_two(length, 32);
        load_mask = _sz_u64_mask_until(load_length);
        text_vec.ymms[0] = _mm256_maskz_loadu_epi8(load_mask, text + length - load_length);

        // To shift right every byte by 3 bits we can use the GF2 affine transformations.
        // https://wunkolo.github.io/post/2020/11/gf2p8affineqb-int8-shifting/
        // After next line, all 8-bit offsets in the `filter_slice_offsets_vec` should be under 32.
        filter_slice_offsets_vec.ymms[0] =
            _mm256_gf2p8affine_epi64_epi8(text_vec.ymms[0], _mm256_set1_epi64x(0x0102040810204080ull << (3 * 8)), 0);

        // After next line, `filter_slice_vec` will contain the right word from the set,
        // needed to filter the presence of the byte in the set.
        filter_slice_vec.ymms[0] = _mm256_permutexvar_epi8(filter_slice_offsets_vec.ymms[0], filter_vec.ymms[0]);

        // After next line, all 8-bit offsets in the `filter_slice_offsets_vec` should be under 8.
        offset_within_slice_vec.ymms[0] = _mm256_and_si256(text_vec.ymms[0], _mm256_set1_epi64x(0x0707070707070707ull));

        // Instead of performing one more Galois Field operation, we can upcast to 16-bit integers,
        // and perform the shift and intersection there.
        filter_slice_vec.zmm = _mm512_cvtepi8_epi16(filter_slice_vec.ymms[0]);
        offset_within_slice_vec.zmm = _mm512_cvtepi8_epi16(offset_within_slice_vec.ymms[0]);
        mask_in_filter_slice_vec.zmm = _mm512_sllv_epi16(_mm512_set1_epi16(1), offset_within_slice_vec.zmm);
        matches_vec.zmm = _mm512_and_si512(filter_slice_vec.zmm, mask_in_filter_slice_vec.zmm);

        matches_mask = _mm512_mask_cmpneq_epi16_mask(load_mask, matches_vec.zmm, _mm512_setzero_si512());
        if (matches_mask) {
            // Highest set bit in the 32-bit mask marks the last matching byte of this chunk.
            int offset = sz_u32_clz(matches_mask);
            return text + length - load_length + 32 - offset - 1;
        }
        else { length -= load_length; }
    }

    return SZ_NULL;
}

/**
 * Computes the Needleman Wunsch alignment score between two strings.
 * The method uses 32-bit integers to accumulate the running score for every cell in the matrix.
 * Assuming the costs of substitutions can be arbitrary signed 8-bit integers, the method is expected to be used
 * on strings not exceeding 2^24 length or 16.7 million characters.
 *
 * Unlike the `_sz_edit_distance_skewed_diagonals_upto65k_avx512` method, this one uses signed integers to store
 * the accumulated score.
Moreover, it's primary bottleneck is the latency of gathering the substitution costs + * from the substitution matrix. If we use the diagonal order, we will be comparing a slice of the first string with + * a slice of the second. If we stick to the conventional horizontal order, we will be comparing one character against + * a slice, which is much easier to optimize. In that case we are sampling costs not from arbitrary parts of + * a 256 x 256 matrix, but from a single row! + */ +SZ_INTERNAL sz_ssize_t _sz_alignment_score_wagner_fisher_upto17m_avx512( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // + sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { + + // If one of the strings is empty - the edit distance is equal to the length of the other one + if (longer_length == 0) return (sz_ssize_t)shorter_length * gap; + if (shorter_length == 0) return (sz_ssize_t)longer_length * gap; + + // Let's make sure that we use the amount proportional to the + // number of elements in the shorter string, not the larger. + if (shorter_length > longer_length) { + sz_u64_swap((sz_u64_t *)&longer_length, (sz_u64_t *)&shorter_length); + sz_u64_swap((sz_u64_t *)&longer, (sz_u64_t *)&shorter); + } + + // Simplify usage in higher-level libraries, where wrapping custom allocators may be troublesome. + sz_memory_allocator_t global_alloc; + if (!alloc) { + sz_memory_allocator_init_default(&global_alloc); + alloc = &global_alloc; + } + + sz_size_t const max_length = 256ull * 256ull * 256ull; + sz_size_t const n = longer_length + 1; + sz_assert(n < max_length && "The length must fit into 24-bit integer. 
Otherwise use serial variant.");
+    sz_unused(longer_length && max_length);
+
+    sz_size_t buffer_length = sizeof(sz_i32_t) * n * 2;
+    sz_i32_t *distances = (sz_i32_t *)alloc->allocate(buffer_length, alloc->handle);
+    sz_i32_t *previous_distances = distances;
+    sz_i32_t *current_distances = previous_distances + n;
+
+    // Initialize the first row of the Levenshtein matrix with `iota`.
+    for (sz_size_t idx_longer = 0; idx_longer != n; ++idx_longer)
+        previous_distances[idx_longer] = (sz_ssize_t)idx_longer * gap;
+
+    /// Contains up to 16 consecutive characters from the longer string.
+    sz_u512_vec_t longer_vec;
+    sz_u512_vec_t cost_deletion_vec, cost_substitution_vec, lookup_substitution_vec, current_vec;
+    sz_u512_vec_t row_first_subs_vec, row_second_subs_vec, row_third_subs_vec, row_fourth_subs_vec;
+    sz_u512_vec_t shuffled_first_subs_vec, shuffled_second_subs_vec, shuffled_third_subs_vec, shuffled_fourth_subs_vec;
+
+    // Prepare constants and masks.
+    sz_u512_vec_t is_third_or_fourth_vec, is_second_or_fourth_vec, gap_vec;
+    {
+        char is_third_or_fourth_check, is_second_or_fourth_check;
+        *(sz_u8_t *)&is_third_or_fourth_check = 0x80, *(sz_u8_t *)&is_second_or_fourth_check = 0x40;
+        is_third_or_fourth_vec.zmm = _mm512_set1_epi8(is_third_or_fourth_check);
+        is_second_or_fourth_vec.zmm = _mm512_set1_epi8(is_second_or_fourth_check);
+        gap_vec.zmm = _mm512_set1_epi32(gap);
+    }
+
+    sz_u8_t const *shorter_unsigned = (sz_u8_t const *)shorter;
+    for (sz_size_t idx_shorter = 0; idx_shorter != shorter_length; ++idx_shorter) {
+        sz_i32_t last_in_row = current_distances[0] = (sz_ssize_t)(idx_shorter + 1) * gap;
+
+        // Load one row of the substitution matrix into four ZMM registers.
+ sz_error_cost_t const *row_subs = subs + shorter_unsigned[idx_shorter] * 256u; + row_first_subs_vec.zmm = _mm512_loadu_epi8(row_subs + 64 * 0); + row_second_subs_vec.zmm = _mm512_loadu_epi8(row_subs + 64 * 1); + row_third_subs_vec.zmm = _mm512_loadu_epi8(row_subs + 64 * 2); + row_fourth_subs_vec.zmm = _mm512_loadu_epi8(row_subs + 64 * 3); + + // In the serial version we have one forward pass, that computes the deletion, + // insertion, and substitution costs at once. + // for (sz_size_t idx_longer = 0; idx_longer < longer_length; ++idx_longer) { + // sz_ssize_t cost_deletion = previous_distances[idx_longer + 1] + gap; + // sz_ssize_t cost_insertion = current_distances[idx_longer] + gap; + // sz_ssize_t cost_substitution = previous_distances[idx_longer] + row_subs[longer_unsigned[idx_longer]]; + // current_distances[idx_longer + 1] = sz_min_of_three(cost_deletion, cost_insertion, cost_substitution); + // } + // + // Given the complexity of handling the data-dependency between consecutive insertion cost computations + // within a Levenshtein matrix, the simplest design would be to vectorize every kind of cost computation + // separately. + // 1. Compute substitution costs for up to 64 characters at once, upcasting from 8-bit integers to 32. + // 2. Compute the pairwise minimum with deletion costs. + // 3. Inclusive prefix minimum computation to combine with addition costs. + // Proceeding with substitutions: + for (sz_size_t idx_longer = 0; idx_longer < longer_length; idx_longer += 64) { + sz_size_t register_length = sz_min_of_two(longer_length - idx_longer, 64); + __mmask64 mask = _sz_u64_mask_until(register_length); + longer_vec.zmm = _mm512_maskz_loadu_epi8(mask, longer + idx_longer); + + // Blend the `row_(first|second|third|fourth)_subs_vec` into `current_vec`, picking the right source + // for every character in `longer_vec`. Before that, we need to permute the subsititution vectors. 
+ // Only the bottom 6 bits of a byte are used in VPERB, so we don't even need to mask. + shuffled_first_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_first_subs_vec.zmm); + shuffled_second_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_second_subs_vec.zmm); + shuffled_third_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_third_subs_vec.zmm); + shuffled_fourth_subs_vec.zmm = _mm512_maskz_permutexvar_epi8(mask, longer_vec.zmm, row_fourth_subs_vec.zmm); + + // To blend we can invoke three `_mm512_cmplt_epu8_mask`, but we can also achieve the same using + // the AND logical operation, checking the top two bits of every byte. + // Continuing this thought, we can use the VPTESTMB instruction to output the mask after the AND. + __mmask64 is_third_or_fourth = _mm512_mask_test_epi8_mask(mask, longer_vec.zmm, is_third_or_fourth_vec.zmm); + __mmask64 is_second_or_fourth = + _mm512_mask_test_epi8_mask(mask, longer_vec.zmm, is_second_or_fourth_vec.zmm); + lookup_substitution_vec.zmm = _mm512_mask_blend_epi8( + is_third_or_fourth, + // Choose between the first and the second. + _mm512_mask_blend_epi8(is_second_or_fourth, shuffled_first_subs_vec.zmm, shuffled_second_subs_vec.zmm), + // Choose between the third and the fourth. + _mm512_mask_blend_epi8(is_second_or_fourth, shuffled_third_subs_vec.zmm, shuffled_fourth_subs_vec.zmm)); + + // First, sign-extend lower and upper 16 bytes to 16-bit integers. + __m512i current_0_31_vec = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(lookup_substitution_vec.zmm, 0)); + __m512i current_32_63_vec = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(lookup_substitution_vec.zmm, 1)); + + // Now extend those 16-bit integers to 32-bit. + // This isn't free, same as the subsequent store, so we only want to do that for the populated lanes. 
+ // To minimize the number of loads and stores, we can combine our substitution costs with the previous + // distances, containing the deletion costs. + { + cost_substitution_vec.zmm = _mm512_maskz_loadu_epi32(mask, previous_distances + idx_longer); + cost_substitution_vec.zmm = _mm512_add_epi32( + cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_0_31_vec, 0))); + cost_deletion_vec.zmm = _mm512_maskz_loadu_epi32(mask, previous_distances + 1 + idx_longer); + cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm); + current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm); + + // Inclusive prefix minimum computation to combine with insertion costs. + // Simply disabling this operation results in 5x performance improvement, meaning + // that this operation is responsible for 80% of the total runtime. + // for (sz_size_t idx_longer = 0; idx_longer < longer_length; ++idx_longer) { + // current_distances[idx_longer + 1] = + // sz_max_of_two(current_distances[idx_longer] + gap, current_distances[idx_longer + 1]); + // } + // + // To perform the same operation in vectorized form, we need to perform a tree-like reduction, + // that will involve multiple steps. It's quite expensive and should be first tested in the + // "experimental" section. + // + // Another approach might be loop unrolling: + // current_vec.i32s[0] = last_in_row = sz_i32_max_of_two(current_vec.i32s[0], last_in_row + gap); + // current_vec.i32s[1] = last_in_row = sz_i32_max_of_two(current_vec.i32s[1], last_in_row + gap); + // current_vec.i32s[2] = last_in_row = sz_i32_max_of_two(current_vec.i32s[2], last_in_row + gap); + // ... yet this approach is also quite expensive. + for (int i = 0; i != 16; ++i) + current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap); + _mm512_mask_storeu_epi32(current_distances + idx_longer + 1, mask, current_vec.zmm); + } + + // Export the values from 16 to 31. 
+            if (register_length > 16) {
+                mask = _kshiftri_mask64(mask, 16);
+                cost_substitution_vec.zmm = _mm512_maskz_loadu_epi32(mask, previous_distances + idx_longer + 16);
+                cost_substitution_vec.zmm = _mm512_add_epi32(
+                    cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_0_31_vec, 1)));
+                cost_deletion_vec.zmm = _mm512_maskz_loadu_epi32(mask, previous_distances + 1 + idx_longer + 16);
+                cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm);
+                current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm);
+
+                // Aggregate running insertion costs within the register.
+                for (int i = 0; i != 16; ++i)
+                    current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
+                _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 16, mask, current_vec.zmm);
+            }
+
+            // Export the values from 32 to 47.
+            if (register_length > 32) {
+                mask = _kshiftri_mask64(mask, 16);
+                cost_substitution_vec.zmm = _mm512_maskz_loadu_epi32(mask, previous_distances + idx_longer + 32);
+                cost_substitution_vec.zmm = _mm512_add_epi32(
+                    cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_32_63_vec, 0)));
+                cost_deletion_vec.zmm = _mm512_maskz_loadu_epi32(mask, previous_distances + 1 + idx_longer + 32);
+                cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm);
+                current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm);
+
+                // Aggregate running insertion costs within the register.
+                for (int i = 0; i != 16; ++i)
+                    current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap);
+                _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 32, mask, current_vec.zmm);
+            }
+
+            // Export the values from 48 to 63.
+ if (register_length > 48) { + mask = _kshiftri_mask64(mask, 16); + cost_substitution_vec.zmm = _mm512_maskz_loadu_epi32(mask, previous_distances + idx_longer + 48); + cost_substitution_vec.zmm = _mm512_add_epi32( + cost_substitution_vec.zmm, _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(current_32_63_vec, 1))); + cost_deletion_vec.zmm = _mm512_maskz_loadu_epi32(mask, previous_distances + 1 + idx_longer + 48); + cost_deletion_vec.zmm = _mm512_add_epi32(cost_deletion_vec.zmm, gap_vec.zmm); + current_vec.zmm = _mm512_max_epi32(cost_substitution_vec.zmm, cost_deletion_vec.zmm); + + // Aggregate running insertion costs within the register. + for (int i = 0; i != 16; ++i) + current_vec.i32s[i] = last_in_row = sz_max_of_two(current_vec.i32s[i], last_in_row + gap); + _mm512_mask_storeu_epi32(current_distances + idx_longer + 1 + 48, mask, current_vec.zmm); + } + } + + // Swap previous_distances and current_distances pointers + sz_u64_swap((sz_u64_t *)&previous_distances, (sz_u64_t *)¤t_distances); + } + + // Cache scalar before `free` call. + sz_ssize_t result = previous_distances[longer_length]; + alloc->free(distances, buffer_length, alloc->handle); + return result; +} + +SZ_INTERNAL sz_ssize_t sz_alignment_score_avx512( // + sz_cptr_t shorter, sz_size_t shorter_length, // + sz_cptr_t longer, sz_size_t longer_length, // + sz_error_cost_t const *subs, sz_error_cost_t gap, sz_memory_allocator_t *alloc) { + + if (sz_max_of_two(shorter_length, longer_length) < (256ull * 256ull * 256ull)) + return _sz_alignment_score_wagner_fisher_upto17m_avx512(shorter, shorter_length, longer, longer_length, subs, + gap, alloc); + else + return sz_alignment_score_serial(shorter, shorter_length, longer, longer_length, subs, gap, alloc); +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif + +#pragma endregion + +/* @brief Implementation of the string search algorithms using the Arm NEON instruction set, available on 64-bit + * Arm processors. 
Implements: {substring search, character search, character set search} x {forward, reverse}. + */ +#pragma region ARM NEON + +#if SZ_USE_ARM_NEON + +/** + * @brief Helper structure to simplify work with 64-bit words. + */ +typedef union sz_u128_vec_t { + uint8x16_t u8x16; + uint16x8_t u16x8; + uint32x4_t u32x4; + uint64x2_t u64x2; + sz_u64_t u64s[2]; + sz_u32_t u32s[4]; + sz_u16_t u16s[8]; + sz_u8_t u8s[16]; +} sz_u128_vec_t; + +SZ_INTERNAL sz_u64_t vreinterpretq_u8_u4(uint8x16_t vec) { + // Use `vshrn` to produce a bitmask, similar to `movemask` in SSE. + // https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon + return vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(vec), 4)), 0) & 0x8888888888888888ull; +} + +SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { + sz_u64_t matches; + sz_u128_vec_t h_vec, n_vec, matches_vec; + n_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n); + + while (h_length >= 16) { + h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h); + matches_vec.u8x16 = vceqq_u8(h_vec.u8x16, n_vec.u8x16); + // In Arm NEON we don't have a `movemask` to combine it with `ctz` and get the offset of the match. + // But assuming the `vmaxvq` is cheap, we can use it to find the first match, by blending (bitwise selecting) + // the vector with a relative offsets array. 
+ matches = vreinterpretq_u8_u4(matches_vec.u8x16); + if (matches) return h + sz_u64_ctz(matches) / 4; + + h += 16, h_length -= 16; + } + + return sz_find_byte_serial(h, h_length, n); +} + +SZ_PUBLIC sz_cptr_t sz_rfind_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) { + sz_u64_t matches; + sz_u128_vec_t h_vec, n_vec, matches_vec; + n_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n); + + while (h_length >= 16) { + h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h + h_length - 16); + matches_vec.u8x16 = vceqq_u8(h_vec.u8x16, n_vec.u8x16); + matches = vreinterpretq_u8_u4(matches_vec.u8x16); + if (matches) return h + h_length - 1 - sz_u64_clz(matches) / 4; + h_length -= 16; + } + + return sz_rfind_byte_serial(h, h_length, n); +} + +SZ_PUBLIC sz_u64_t _sz_find_charset_neon_register(sz_u128_vec_t h_vec, uint8x16_t set_top_vec_u8x16, + uint8x16_t set_bottom_vec_u8x16) { + + // Once we've read the characters in the haystack, we want to + // compare them against our bitset. The serial version of that code + // would look like: `(set_->_u8s[c >> 3] & (1u << (c & 7u))) != 0`. + uint8x16_t byte_index_vec = vshrq_n_u8(h_vec.u8x16, 3); + uint8x16_t byte_mask_vec = vshlq_u8(vdupq_n_u8(1), vreinterpretq_s8_u8(vandq_u8(h_vec.u8x16, vdupq_n_u8(7)))); + uint8x16_t matches_top_vec = vqtbl1q_u8(set_top_vec_u8x16, byte_index_vec); + // The table lookup instruction in NEON replies to out-of-bound requests with zeros. + // The values in `byte_index_vec` all fall in [0; 32). So for values under 16, substracting 16 will underflow + // and map into interval [240, 256). Meaning that those will be populated with zeros and we can safely + // merge `matches_top_vec` and `matches_bottom_vec` with a bitwise OR. + uint8x16_t matches_bottom_vec = vqtbl1q_u8(set_bottom_vec_u8x16, vsubq_u8(byte_index_vec, vdupq_n_u8(16))); + uint8x16_t matches_vec = vorrq_u8(matches_top_vec, matches_bottom_vec); + // Istead of pure `vandq_u8`, we can immediately broadcast a match presence across each 8-bit word. 
+ matches_vec = vtstq_u8(matches_vec, byte_mask_vec); + return vreinterpretq_u8_u4(matches_vec); +} + +SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + + // This almost never fires, but it's better to be safe than sorry. + if (h_length < n_length || !n_length) return SZ_NULL; + if (n_length == 1) return sz_find_byte_neon(h, h_length, n); + + // Pick the parts of the needle that are worth comparing. + sz_size_t offset_first, offset_mid, offset_last; + _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); + + // Broadcast those characters into SIMD registers. + sz_u64_t matches; + sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; + n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_first]); + n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_mid]); + n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_last]); + + // Scan through the string. + for (; h_length >= n_length + 16; h += 16, h_length -= 16) { + h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_first)); + h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_mid)); + h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + offset_last)); + matches_vec.u8x16 = vandq_u8( // + vandq_u8( // + vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // + vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), + vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); + matches = vreinterpretq_u8_u4(matches_vec.u8x16); + while (matches) { + int potential_offset = sz_u64_ctz(matches) / 4; + if (sz_equal(h + potential_offset, n, n_length)) return h + potential_offset; + matches &= matches - 1; + } + } + + return sz_find_serial(h, h_length, n, n_length); +} + +SZ_PUBLIC sz_cptr_t sz_rfind_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + + // This almost never fires, but it's better to be safe than sorry. 
+ if (h_length < n_length || !n_length) return SZ_NULL; + if (n_length == 1) return sz_rfind_byte_neon(h, h_length, n); + + // Pick the parts of the needle that are worth comparing. + sz_size_t offset_first, offset_mid, offset_last; + _sz_locate_needle_anomalies(n, n_length, &offset_first, &offset_mid, &offset_last); + + // Will contain 4 bits per character. + sz_u64_t matches; + sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec; + n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_first]); + n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_mid]); + n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[offset_last]); + + sz_cptr_t h_reversed; + for (; h_length >= n_length + 16; h_length -= 16) { + h_reversed = h + h_length - n_length - 16 + 1; + h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_first)); + h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_mid)); + h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h_reversed + offset_last)); + matches_vec.u8x16 = vandq_u8( // + vandq_u8( // + vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), // + vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)), + vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16)); + matches = vreinterpretq_u8_u4(matches_vec.u8x16); + while (matches) { + int potential_offset = sz_u64_clz(matches) / 4; + if (sz_equal(h + h_length - n_length - potential_offset, n, n_length)) + return h + h_length - n_length - potential_offset; + sz_assert((matches & (1ull << (63 - potential_offset * 4))) != 0 && + "The bit must be set before we squash it"); + matches &= ~(1ull << (63 - potential_offset * 4)); + } + } + + return sz_rfind_serial(h, h_length, n, n_length); +} + +SZ_PUBLIC sz_cptr_t sz_find_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) { + sz_u64_t matches; + sz_u128_vec_t h_vec; + uint8x16_t set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]); + uint8x16_t set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]); + + 
for (; h_length >= 16; h += 16, h_length -= 16) { + h_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h)); + matches = _sz_find_charset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); + if (matches) return h + sz_u64_ctz(matches) / 4; + } + + return sz_find_charset_serial(h, h_length, set); +} + +SZ_PUBLIC sz_cptr_t sz_rfind_charset_neon(sz_cptr_t h, sz_size_t h_length, sz_charset_t const *set) { + sz_u64_t matches; + sz_u128_vec_t h_vec; + uint8x16_t set_top_vec_u8x16 = vld1q_u8(&set->_u8s[0]); + uint8x16_t set_bottom_vec_u8x16 = vld1q_u8(&set->_u8s[16]); + + // Check `sz_find_charset_neon` for explanations. + for (; h_length >= 16; h_length -= 16) { + h_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h) + h_length - 16); + matches = _sz_find_charset_neon_register(h_vec, set_top_vec_u8x16, set_bottom_vec_u8x16); + if (matches) return h + h_length - 1 - sz_u64_clz(matches) / 4; + } + + return sz_rfind_charset_serial(h, h_length, set); +} + +#endif // Arm Neon + +#pragma endregion + +/* + * @brief Pick the right implementation for the string search algorithms. 
+ */ +#pragma region Compile-Time Dispatching + +SZ_PUBLIC sz_u64_t sz_hash(sz_cptr_t ins, sz_size_t length) { return sz_hash_serial(ins, length); } +SZ_PUBLIC void sz_tolower(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_tolower_serial(ins, length, outs); } +SZ_PUBLIC void sz_toupper(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_toupper_serial(ins, length, outs); } +SZ_PUBLIC void sz_toascii(sz_cptr_t ins, sz_size_t length, sz_ptr_t outs) { sz_toascii_serial(ins, length, outs); } + +SZ_PUBLIC void sz_hashes_fingerprint(sz_cptr_t start, sz_size_t length, sz_size_t window_length, sz_ptr_t fingerprint, + sz_size_t fingerprint_bytes) { + + sz_bool_t fingerprint_length_is_power_of_two = (sz_bool_t)((fingerprint_bytes & (fingerprint_bytes - 1)) == 0); + sz_string_view_t fingerprint_buffer = {fingerprint, fingerprint_bytes}; + + // https://blog.stuffedcow.net/2015/08/pagewalk-coherence/ + + // In most cases the fingerprint length will be a power of two. + if (fingerprint_length_is_power_of_two == sz_false_k) + sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_non_pow2_callback, &fingerprint_buffer); + else + sz_hashes(start, length, window_length, 1, _sz_hashes_fingerprint_pow2_callback, &fingerprint_buffer); +} + +#if !SZ_DYNAMIC_DISPATCH + +SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { +#if SZ_USE_X86_AVX512 + return sz_equal_avx512(a, b, length); +#else + return sz_equal_serial(a, b, length); +#endif +} + +SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { +#if SZ_USE_X86_AVX512 + return sz_order_avx512(a, a_length, b, b_length); +#else + return sz_order_serial(a, a_length, b, b_length); +#endif +} + +SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { +#if SZ_USE_X86_AVX512 + sz_copy_avx512(target, source, length); +#elif SZ_USE_X86_AVX2 + sz_copy_avx2(target, source, length); +#else + sz_copy_serial(target, source, length); 
+#endif +} + +SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { +#if SZ_USE_X86_AVX512 + sz_move_avx512(target, source, length); +#elif SZ_USE_X86_AVX2 + sz_move_avx2(target, source, length); +#else + sz_move_serial(target, source, length); +#endif +} + +SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value) { +#if SZ_USE_X86_AVX512 + sz_fill_avx512(target, length, value); +#elif SZ_USE_X86_AVX2 + sz_fill_avx2(target, length, value); +#else + sz_fill_serial(target, length, value); +#endif +} + +SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { +#if SZ_USE_X86_AVX512 + return sz_find_byte_avx512(haystack, h_length, needle); +#elif SZ_USE_X86_AVX2 + return sz_find_byte_avx2(haystack, h_length, needle); +#elif SZ_USE_ARM_NEON + return sz_find_byte_neon(haystack, h_length, needle); +#else + return sz_find_byte_serial(haystack, h_length, needle); +#endif +} + +SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { +#if SZ_USE_X86_AVX512 + return sz_rfind_byte_avx512(haystack, h_length, needle); +#elif SZ_USE_X86_AVX2 + return sz_rfind_byte_avx2(haystack, h_length, needle); +#elif SZ_USE_ARM_NEON + return sz_rfind_byte_neon(haystack, h_length, needle); +#else + return sz_rfind_byte_serial(haystack, h_length, needle); +#endif +} + +SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { +#if SZ_USE_X86_AVX512 + return sz_find_avx512(haystack, h_length, needle, n_length); +#elif SZ_USE_X86_AVX2 + return sz_find_avx2(haystack, h_length, needle, n_length); +#elif SZ_USE_ARM_NEON + return sz_find_neon(haystack, h_length, needle, n_length); +#else + return sz_find_serial(haystack, h_length, needle, n_length); +#endif +} + +SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { +#if SZ_USE_X86_AVX512 + return sz_rfind_avx512(haystack, 
h_length, needle, n_length); +#elif SZ_USE_X86_AVX2 + return sz_rfind_avx2(haystack, h_length, needle, n_length); +#elif SZ_USE_ARM_NEON + return sz_rfind_neon(haystack, h_length, needle, n_length); +#else + return sz_rfind_serial(haystack, h_length, needle, n_length); +#endif +} + +SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { +#if SZ_USE_X86_AVX512 + return sz_find_charset_avx512(text, length, set); +#elif SZ_USE_ARM_NEON + return sz_find_charset_neon(text, length, set); +#else + return sz_find_charset_serial(text, length, set); +#endif +} + +SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { +#if SZ_USE_X86_AVX512 + return sz_rfind_charset_avx512(text, length, set); +#elif SZ_USE_ARM_NEON + return sz_rfind_charset_neon(text, length, set); +#else + return sz_rfind_charset_serial(text, length, set); +#endif +} + +SZ_DYNAMIC sz_size_t sz_edit_distance( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_memory_allocator_t *alloc) { +#if SZ_USE_X86_AVX512 + return sz_edit_distance_avx512(a, a_length, b, b_length, bound, alloc); +#else + return sz_edit_distance_serial(a, a_length, b, b_length, bound, alloc); +#endif +} + +SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, + sz_error_cost_t const *subs, sz_error_cost_t gap, + sz_memory_allocator_t *alloc) { +#if SZ_USE_X86_AVX512 + return sz_alignment_score_avx512(a, a_length, b, b_length, subs, gap, alloc); +#else + return sz_alignment_score_serial(a, a_length, b, b_length, subs, gap, alloc); +#endif +} + +SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t window_step, // + sz_hash_callback_t callback, void *callback_handle) { +#if SZ_USE_X86_AVX512 + sz_hashes_avx512(text, length, window_length, window_step, callback, callback_handle); +#elif SZ_USE_X86_AVX2 + 
sz_hashes_avx2(text, length, window_length, window_step, callback, callback_handle); +#else + sz_hashes_serial(text, length, window_length, window_step, callback, callback_handle); +#endif +} + +SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + sz_charset_t set; + sz_charset_init(&set); + for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); + return sz_find_charset(h, h_length, &set); +} + +SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + sz_charset_t set; + sz_charset_init(&set); + for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); + sz_charset_invert(&set); + return sz_find_charset(h, h_length, &set); +} + +SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + sz_charset_t set; + sz_charset_init(&set); + for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); + return sz_rfind_charset(h, h_length, &set); +} + +SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + sz_charset_t set; + sz_charset_init(&set); + for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); + sz_charset_invert(&set); + return sz_rfind_charset(h, h_length, &set); +} + +#endif +#pragma endregion + +#ifdef __cplusplus +#pragma GCC diagnostic pop +} +#endif // __cplusplus + +#endif // STRINGZILLA_H_ diff --git a/contrib/stringzilla/include/stringzilla/stringzilla.hpp b/contrib/stringzilla/include/stringzilla/stringzilla.hpp new file mode 100644 index 0000000000..47c0e8ec30 --- /dev/null +++ b/contrib/stringzilla/include/stringzilla/stringzilla.hpp @@ -0,0 +1,3755 @@ +/** + * @brief StringZilla C++ wrapper improving over the performance of `std::string_view` and `std::string`, + * mostly for substring search, adding approximate matching functionality, and C++23 functionality + * to a C++11 compatible implementation. 
+ *
+ * This implementation is aiming to be compatible with C++11, while implementing the C++23 functionality.
+ * By default, it includes C++ STL headers, but that can be avoided to minimize compilation overhead.
+ * https://artificial-mind.net/projects/compile-health/
+ *
+ * @see StringZilla: https://github.com/ashvardanian/StringZilla/blob/main/README.md
+ * @see C++ Standard String: https://en.cppreference.com/w/cpp/header/string
+ *
+ * @file stringzilla.hpp
+ * @author Ash Vardanian
+ */
+#ifndef STRINGZILLA_HPP_
+#define STRINGZILLA_HPP_
+
+/**
+ * @brief When set to 0 (the default), the library will include the C++ STL headers and implement
+ * automatic conversion from and to `std::string_view` and `std::basic_string`.
+ */
+#ifndef SZ_AVOID_STL
+#define SZ_AVOID_STL (0) // true or false
+#endif
+
+/* We need to detect the version of the C++ language we are compiled with.
+ * This will affect recent features like `operator<=>` and tests against STL.
+ */
+#define SZ_DETECT_CPP_23 (__cplusplus >= 202101L)
+#define SZ_DETECT_CPP20 (__cplusplus >= 202002L)
+#define SZ_DETECT_CPP_17 (__cplusplus >= 201703L)
+#define SZ_DETECT_CPP14 (__cplusplus >= 201402L)
+#define SZ_DETECT_CPP_11 (__cplusplus >= 201103L)
+#define SZ_DETECT_CPP_98 (__cplusplus >= 199711L)
+
+/**
+ * @brief The `constexpr` keyword has different applicability scope in different C++ versions.
+ * Useful for STL conversion operators, as several `std::string` members are `constexpr` in C++20.
+ */ +#if SZ_DETECT_CPP20 +#define sz_constexpr_if_cpp20 constexpr +#else +#define sz_constexpr_if_cpp20 +#endif + +#if !SZ_AVOID_STL +#include +#include +#if SZ_DETECT_CPP_17 && __cpp_lib_string_view +#include +#endif +#endif + +#include // `assert` +#include // `std::size_t` +#include // `std::basic_ostream` +#include // `std::out_of_range` +#include // `std::swap` + +#include + +namespace ashvardanian { +namespace stringzilla { + +template +class basic_charset; +template +class basic_string_slice; +template +class basic_string; + +using string_span = basic_string_slice; +using string_view = basic_string_slice; + +template +using carray = char[count_characters]; + +#pragma region Character Sets + +/** + * @brief The concatenation of the `ascii_lowercase` and `ascii_uppercase`. This value is not locale-dependent. + * https://docs.python.org/3/library/string.html#string.ascii_letters + */ +inline carray<52> const &ascii_letters() noexcept { + static carray<52> const all = { + // + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', + 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', + 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', + }; + return all; +} + +/** + * @brief The lowercase letters "abcdefghijklmnopqrstuvwxyz". This value is not locale-dependent. + * https://docs.python.org/3/library/string.html#string.ascii_lowercase + */ +inline carray<26> const &ascii_lowercase() noexcept { + static carray<26> const all = { + // + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', + 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + }; + return all; +} + +/** + * @brief The uppercase letters "ABCDEFGHIJKLMNOPQRSTUVWXYZ". This value is not locale-dependent. 
+ * https://docs.python.org/3/library/string.html#string.ascii_uppercase + */ +inline carray<26> const &ascii_uppercase() noexcept { + static carray<26> const all = { + // + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', + 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', + }; + return all; +} + +/** + * @brief ASCII characters which are considered printable. + * A combination of `digits`, `ascii_letters`, `punctuation`, and `whitespace`. + * https://docs.python.org/3/library/string.html#string.printable + */ +inline carray<100> const &ascii_printables() noexcept { + static carray<100> const all = { + // + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', + 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'A', 'B', 'C', 'D', + 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', + 'Y', 'Z', '!', '"', '#', '$', '%', '&', '\'', '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', + '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', ' ', '\t', '\n', '\r', '\f', '\v', + }; + return all; +} + +/** + * @brief Non-printable ASCII control characters. + * Includes all codes from 0 to 31 and 127. + */ +inline carray<33> const &ascii_controls() noexcept { + static carray<33> const all = { + // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 127, + }; + return all; +} + +/** + * @brief The digits "0123456789". + * https://docs.python.org/3/library/string.html#string.digits + */ +inline carray<10> const &digits() noexcept { + static carray<10> const all = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}; + return all; +} + +/** + * @brief The letters "0123456789abcdefABCDEF". 
+ * https://docs.python.org/3/library/string.html#string.hexdigits + */ +inline carray<22> const &hexdigits() noexcept { + static carray<22> const all = { + // + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', // + 'a', 'b', 'c', 'd', 'e', 'f', 'A', 'B', 'C', 'D', 'E', 'F', + }; + return all; +} + +/** + * @brief The letters "01234567". + * https://docs.python.org/3/library/string.html#string.octdigits + */ +inline carray<8> const &octdigits() noexcept { + static carray<8> const all = {'0', '1', '2', '3', '4', '5', '6', '7'}; + return all; +} + +/** + * @brief ASCII characters considered punctuation characters in the C locale: + * !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~. + * https://docs.python.org/3/library/string.html#string.punctuation + */ +inline carray<32> const &punctuation() noexcept { + static carray<32> const all = { + // + '!', '"', '#', '$', '%', '&', '\'', '(', ')', '*', '+', ',', '-', '.', '/', ':', + ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', + }; + return all; +} + +/** + * @brief ASCII characters that are considered whitespace. + * This includes space, tab, linefeed, return, formfeed, and vertical tab. + * https://docs.python.org/3/library/string.html#string.whitespace + */ +inline carray<6> const &whitespaces() noexcept { + static carray<6> const all = {' ', '\t', '\n', '\r', '\f', '\v'}; + return all; +} + +/** + * @brief ASCII characters that are considered line delimiters. + * https://docs.python.org/3/library/stdtypes.html#str.splitlines + */ +inline carray<8> const &newlines() noexcept { + static carray<8> const all = {'\n', '\r', '\f', '\v', '\x1C', '\x1D', '\x1E', '\x85'}; + return all; +} + +/** + * @brief ASCII characters forming the BASE64 encoding alphabet. 
+ */ +inline carray<64> const &base64() noexcept { + static carray<64> const all = { + // + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', + 'w', 'x', 'y', 'z', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', + 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/', + }; + return all; +} + +/** + * @brief A set of characters represented as a bitset with 256 slots. + */ +template +class basic_charset { + sz_charset_t bitset_; + + public: + using char_type = char_type_; + + basic_charset() noexcept { + // ! Instead of relying on the `sz_charset_init`, we have to reimplement it to support `constexpr`. + bitset_._u64s[0] = 0, bitset_._u64s[1] = 0, bitset_._u64s[2] = 0, bitset_._u64s[3] = 0; + } + explicit basic_charset(std::initializer_list chars) noexcept : basic_charset() { + // ! Instead of relying on the `sz_charset_add(&bitset_, c)`, we have to reimplement it to support `constexpr`. 
+ for (auto c : chars) bitset_._u64s[sz_bitcast(sz_u8_t, c) >> 6] |= (1ull << (sz_bitcast(sz_u8_t, c) & 63u)); + } + template + explicit basic_charset(char_type const (&chars)[count_characters]) noexcept : basic_charset() { + static_assert(count_characters > 0, "Character array cannot be empty"); + for (std::size_t i = 0; i < count_characters - 1; ++i) { // count_characters - 1 to exclude the null terminator + char_type c = chars[i]; + bitset_._u64s[sz_bitcast(sz_u8_t, c) >> 6] |= (1ull << (sz_bitcast(sz_u8_t, c) & 63u)); + } + } + + template + explicit basic_charset(std::array const &chars) noexcept : basic_charset() { + static_assert(count_characters > 0, "Character array cannot be empty"); + for (std::size_t i = 0; i < count_characters - 1; ++i) { // count_characters - 1 to exclude the null terminator + char_type c = chars[i]; + bitset_._u64s[sz_bitcast(sz_u8_t, c) >> 6] |= (1ull << (sz_bitcast(sz_u8_t, c) & 63u)); + } + } + + basic_charset(basic_charset const &other) noexcept : bitset_(other.bitset_) {} + basic_charset &operator=(basic_charset const &other) noexcept { + bitset_ = other.bitset_; + return *this; + } + + basic_charset operator|(basic_charset other) const noexcept { + basic_charset result = *this; + result.bitset_._u64s[0] |= other.bitset_._u64s[0], result.bitset_._u64s[1] |= other.bitset_._u64s[1], + result.bitset_._u64s[2] |= other.bitset_._u64s[2], result.bitset_._u64s[3] |= other.bitset_._u64s[3]; + return *this; + } + + inline basic_charset &add(char_type c) noexcept { + sz_charset_add(&bitset_, sz_bitcast(sz_u8_t, c)); + return *this; + } + inline sz_charset_t &raw() noexcept { return bitset_; } + inline sz_charset_t const &raw() const noexcept { return bitset_; } + inline bool contains(char_type c) const noexcept { return sz_charset_contains(&bitset_, sz_bitcast(sz_u8_t, c)); } + inline basic_charset inverted() const noexcept { + basic_charset result = *this; + sz_charset_invert(&result.bitset_); + return result; + } +}; + +using char_set = 
basic_charset; + +inline char_set ascii_letters_set() { return char_set {ascii_letters()}; } +inline char_set ascii_lowercase_set() { return char_set {ascii_lowercase()}; } +inline char_set ascii_uppercase_set() { return char_set {ascii_uppercase()}; } +inline char_set ascii_printables_set() { return char_set {ascii_printables()}; } +inline char_set ascii_controls_set() { return char_set {ascii_controls()}; } +inline char_set digits_set() { return char_set {digits()}; } +inline char_set hexdigits_set() { return char_set {hexdigits()}; } +inline char_set octdigits_set() { return char_set {octdigits()}; } +inline char_set punctuation_set() { return char_set {punctuation()}; } +inline char_set whitespaces_set() { return char_set {whitespaces()}; } +inline char_set newlines_set() { return char_set {newlines()}; } +inline char_set base64_set() { return char_set {base64()}; } + +#pragma endregion + +#pragma region Ranges of Search Matches + +struct end_sentinel_type {}; +struct include_overlaps_type {}; +struct exclude_overlaps_type {}; + +#if SZ_DETECT_CPP_17 +inline static constexpr end_sentinel_type end_sentinel; +inline static constexpr include_overlaps_type include_overlaps; +inline static constexpr exclude_overlaps_type exclude_overlaps; +#endif + +/** + * @brief Zero-cost wrapper around the `.find` member function of string-like classes. + */ +template +struct matcher_find { + using size_type = typename string_type_::size_type; + string_type_ needle_; + + matcher_find(string_type_ needle = {}) noexcept : needle_(needle) {} + size_type needle_length() const noexcept { return needle_.length(); } + size_type operator()(string_type_ haystack) const noexcept { return haystack.find(needle_); } + size_type skip_length() const noexcept { + // TODO: Apply Galil rule to match repetitive patterns in strictly linear time. + return std::is_same() ? 1 : needle_.length(); + } +}; + +/** + * @brief Zero-cost wrapper around the `.rfind` member function of string-like classes. 
+ */ +template +struct matcher_rfind { + using size_type = typename string_type_::size_type; + string_type_ needle_; + + matcher_rfind(string_type_ needle = {}) noexcept : needle_(needle) {} + size_type needle_length() const noexcept { return needle_.length(); } + size_type operator()(string_type_ haystack) const noexcept { return haystack.rfind(needle_); } + size_type skip_length() const noexcept { + // TODO: Apply Galil rule to match repetitive patterns in strictly linear time. + return std::is_same() ? 1 : needle_.length(); + } +}; + +/** + * @brief Zero-cost wrapper around the `.find_first_of` member function of string-like classes. + */ +template +struct matcher_find_first_of { + using size_type = typename haystack_type::size_type; + needles_type needles_; + constexpr size_type needle_length() const noexcept { return 1; } + constexpr size_type skip_length() const noexcept { return 1; } + size_type operator()(haystack_type haystack) const noexcept { return haystack.find_first_of(needles_); } +}; + +/** + * @brief Zero-cost wrapper around the `.find_last_of` member function of string-like classes. + */ +template +struct matcher_find_last_of { + using size_type = typename haystack_type::size_type; + needles_type needles_; + constexpr size_type needle_length() const noexcept { return 1; } + constexpr size_type skip_length() const noexcept { return 1; } + size_type operator()(haystack_type haystack) const noexcept { return haystack.find_last_of(needles_); } +}; + +/** + * @brief Zero-cost wrapper around the `.find_first_not_of` member function of string-like classes. 
+ */ +template +struct matcher_find_first_not_of { + using size_type = typename haystack_type::size_type; + needles_type needles_; + constexpr size_type needle_length() const noexcept { return 1; } + constexpr size_type skip_length() const noexcept { return 1; } + size_type operator()(haystack_type haystack) const noexcept { return haystack.find_first_not_of(needles_); } +}; + +/** + * @brief Zero-cost wrapper around the `.find_last_not_of` member function of string-like classes. + */ +template +struct matcher_find_last_not_of { + using size_type = typename haystack_type::size_type; + needles_type needles_; + constexpr size_type needle_length() const noexcept { return 1; } + constexpr size_type skip_length() const noexcept { return 1; } + size_type operator()(haystack_type haystack) const noexcept { return haystack.find_last_not_of(needles_); } +}; + +/** + * @brief A range of string slices representing the matches of a substring search. + * Compatible with C++23 ranges, C++11 string views, and of course, StringZilla. + * Similar to a pair of `boost::algorithm::find_iterator`. + */ +template +class range_matches { + public: + using string_type = string_type_; + using matcher_type = matcher_type_; + + private: + matcher_type matcher_; + string_type haystack_; + + public: + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using value_type = string_type; + using pointer = string_type; // Needed for compatibility with STL container constructors. + using reference = string_type; // Needed for compatibility with STL container constructors. 
+ + range_matches(string_type haystack, matcher_type needle) noexcept : matcher_(needle), haystack_(haystack) {} + + class iterator { + matcher_type matcher_; + string_type remaining_; + + public: + using iterator_category = std::forward_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = string_type; + using pointer = string_type; // Needed for compatibility with STL container constructors. + using reference = string_type; // Needed for compatibility with STL container constructors. + + iterator(string_type haystack, matcher_type matcher) noexcept : matcher_(matcher), remaining_(haystack) { + auto position = matcher_(remaining_); + remaining_.remove_prefix(position != string_type::npos ? position : remaining_.size()); + } + + pointer operator->() const noexcept = delete; + value_type operator*() const noexcept { return remaining_.substr(0, matcher_.needle_length()); } + + iterator &operator++() noexcept { + remaining_.remove_prefix(matcher_.skip_length()); + auto position = matcher_(remaining_); + remaining_.remove_prefix(position != string_type::npos ? 
position : remaining_.size()); + return *this; + } + + iterator operator++(int) noexcept { + iterator temp = *this; + ++(*this); + return temp; + } + + bool operator!=(iterator const &other) const noexcept { return remaining_.begin() != other.remaining_.begin(); } + bool operator==(iterator const &other) const noexcept { return remaining_.begin() == other.remaining_.begin(); } + bool operator!=(end_sentinel_type) const noexcept { return !remaining_.empty(); } + bool operator==(end_sentinel_type) const noexcept { return remaining_.empty(); } + }; + + iterator begin() const noexcept { return {haystack_, matcher_}; } + iterator end() const noexcept { return {string_type {haystack_.data() + haystack_.size(), 0ull}, matcher_}; } + size_type size() const noexcept { return static_cast(ssize()); } + difference_type ssize() const noexcept { return std::distance(begin(), end()); } + bool empty() const noexcept { return begin() == end_sentinel_type {}; } + bool include_overlaps() const noexcept { return matcher_.skip_length() < matcher_.needle_length(); } + + /** + * @brief Copies the matches into a container. + */ + template + void to(container_ &container) { + for (auto match : *this) { container.push_back(match); } + } + + /** + * @brief Copies the matches into a consumed container, returning it at the end. + */ + template + container_ to() { + return container_ {begin(), end()}; + } +}; + +/** + * @brief A range of string slices representing the matches of a @b reverse-order substring search. + * Compatible with C++23 ranges, C++11 string views, and of course, StringZilla. + * Similar to a pair of `boost::algorithm::find_iterator`. + */ +template +class range_rmatches { + public: + using string_type = string_type_; + using matcher_type = matcher_type_; + + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using value_type = string_type; + using pointer = string_type; // Needed for compatibility with STL container constructors. 
+ using reference = string_type; // Needed for compatibility with STL container constructors. + + private: + matcher_type matcher_; + string_type haystack_; + + public: + range_rmatches(string_type haystack, matcher_type needle) : matcher_(needle), haystack_(haystack) {} + + class iterator { + matcher_type matcher_; + string_type remaining_; + + public: + using iterator_category = std::forward_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = string_type; + using pointer = string_type; // Needed for compatibility with STL container constructors. + using reference = string_type; // Needed for compatibility with STL container constructors. + + iterator(string_type haystack, matcher_type matcher) noexcept : matcher_(matcher), remaining_(haystack) { + auto position = matcher_(remaining_); + remaining_.remove_suffix(position != string_type::npos + ? remaining_.size() - position - matcher_.needle_length() + : remaining_.size()); + } + + pointer operator->() const noexcept = delete; + value_type operator*() const noexcept { + return remaining_.substr(remaining_.size() - matcher_.needle_length()); + } + + iterator &operator++() noexcept { + remaining_.remove_suffix(matcher_.skip_length()); + auto position = matcher_(remaining_); + remaining_.remove_suffix(position != string_type::npos + ? 
remaining_.size() - position - matcher_.needle_length() + : remaining_.size()); + return *this; + } + + iterator operator++(int) noexcept { + iterator temp = *this; + ++(*this); + return temp; + } + + bool operator!=(iterator const &other) const noexcept { return remaining_.end() != other.remaining_.end(); } + bool operator==(iterator const &other) const noexcept { return remaining_.end() == other.remaining_.end(); } + bool operator!=(end_sentinel_type) const noexcept { return !remaining_.empty(); } + bool operator==(end_sentinel_type) const noexcept { return remaining_.empty(); } + }; + + iterator begin() const noexcept { return {haystack_, matcher_}; } + iterator end() const noexcept { return {string_type {haystack_.data(), 0ull}, matcher_}; } + size_type size() const noexcept { return static_cast(ssize()); } + difference_type ssize() const noexcept { return std::distance(begin(), end()); } + bool empty() const noexcept { return begin() == end_sentinel_type {}; } + bool include_overlaps() const noexcept { return matcher_.skip_length() < matcher_.needle_length(); } + + /** + * @brief Copies the matches into a container. + */ + template + void to(container_ &container) { + for (auto match : *this) { container.push_back(match); } + } + + /** + * @brief Copies the matches into a consumed container, returning it at the end. + */ + template + container_ to() { + return container_ {begin(), end()}; + } +}; + +/** + * @brief A range of string slices for different splits of the data. + * Compatible with C++23 ranges, C++11 string views, and of course, StringZilla. + * Similar to a pair of `boost::algorithm::split_iterator`. + * + * In some sense, represents the inverse operation to `range_matches`, as it reports not the search matches + * but the data between them. Meaning that for `N` search matches, there will be `N+1` elements in the range. + * Unlike ::range_matches, this range can't be empty. It also can't report overlapping intervals. 
+ */ +template +class range_splits { + public: + using string_type = string_type_; + using matcher_type = matcher_type_; + + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using value_type = string_type; + using pointer = string_type; // Needed for compatibility with STL container constructors. + using reference = string_type; // Needed for compatibility with STL container constructors. + + private: + matcher_type matcher_; + string_type haystack_; + + public: + range_splits(string_type haystack, matcher_type needle) noexcept : matcher_(needle), haystack_(haystack) {} + + class iterator { + matcher_type matcher_; + string_type remaining_; + std::size_t length_within_remaining_; + bool reached_tail_; + + public: + using iterator_category = std::forward_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = string_type; + using pointer = string_type; // Needed for compatibility with STL container constructors. + using reference = string_type; // Needed for compatibility with STL container constructors. + + iterator(string_type haystack, matcher_type matcher) noexcept : matcher_(matcher), remaining_(haystack) { + auto position = matcher_(remaining_); + length_within_remaining_ = position != string_type::npos ? position : remaining_.size(); + reached_tail_ = false; + } + + iterator(string_type haystack, matcher_type matcher, end_sentinel_type) noexcept + : matcher_(matcher), remaining_(haystack), length_within_remaining_(0), reached_tail_(true) {} + + pointer operator->() const noexcept = delete; + value_type operator*() const noexcept { return remaining_.substr(0, length_within_remaining_); } + + iterator &operator++() noexcept { + remaining_.remove_prefix(length_within_remaining_); + reached_tail_ = remaining_.empty(); + remaining_.remove_prefix(matcher_.needle_length() * !reached_tail_); + auto position = matcher_(remaining_); + length_within_remaining_ = position != string_type::npos ? 
position : remaining_.size(); + return *this; + } + + iterator operator++(int) noexcept { + iterator temp = *this; + ++(*this); + return temp; + } + + bool operator!=(iterator const &other) const noexcept { + return (remaining_.begin() != other.remaining_.begin()) || (reached_tail_ != other.reached_tail_); + } + bool operator==(iterator const &other) const noexcept { + return (remaining_.begin() == other.remaining_.begin()) && (reached_tail_ == other.reached_tail_); + } + bool operator!=(end_sentinel_type) const noexcept { return !remaining_.empty() || !reached_tail_; } + bool operator==(end_sentinel_type) const noexcept { return remaining_.empty() && reached_tail_; } + bool is_last() const noexcept { return remaining_.size() == length_within_remaining_; } + }; + + iterator begin() const noexcept { return {haystack_, matcher_}; } + iterator end() const noexcept { return {string_type {haystack_.end(), 0}, matcher_, end_sentinel_type {}}; } + size_type size() const noexcept { return static_cast(ssize()); } + difference_type ssize() const noexcept { return std::distance(begin(), end()); } + constexpr bool empty() const noexcept { return false; } + + /** + * @brief Copies the matches into a container. + */ + template + void to(container_ &container) { + for (auto match : *this) { container.push_back(match); } + } + + /** + * @brief Copies the matches into a consumed container, returning it at the end. + */ + template + container_ to(container_ &&container = {}) { + for (auto match : *this) { container.push_back(match); } + return std::move(container); + } +}; + +/** + * @brief A range of string slices for different splits of the data in @b reverse-order. + * Compatible with C++23 ranges, C++11 string views, and of course, StringZilla. + * Similar to a pair of `boost::algorithm::split_iterator`. + * + * In some sense, represents the inverse operation to `range_matches`, as it reports not the search matches + * but the data between them. 
Meaning that for `N` search matches, there will be `N+1` elements in the range. + * Unlike ::range_matches, this range can't be empty. It also can't report overlapping intervals. + */ +template +class range_rsplits { + public: + using string_type = string_type_; + using matcher_type = matcher_type_; + + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using value_type = string_type; + using pointer = string_type; // Needed for compatibility with STL container constructors. + using reference = string_type; // Needed for compatibility with STL container constructors. + + private: + matcher_type matcher_; + string_type haystack_; + + public: + range_rsplits(string_type haystack, matcher_type needle) noexcept : matcher_(needle), haystack_(haystack) {} + + class iterator { + matcher_type matcher_; + string_type remaining_; + std::size_t length_within_remaining_; + bool reached_tail_; + + public: + using iterator_category = std::forward_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = string_type; + using pointer = string_type; // Needed for compatibility with STL container constructors. + using reference = string_type; // Needed for compatibility with STL container constructors. + + iterator(string_type haystack, matcher_type matcher) noexcept : matcher_(matcher), remaining_(haystack) { + auto position = matcher_(remaining_); + length_within_remaining_ = position != string_type::npos + ? 
remaining_.size() - position - matcher_.needle_length() + : remaining_.size(); + reached_tail_ = false; + } + + iterator(string_type haystack, matcher_type matcher, end_sentinel_type) noexcept + : matcher_(matcher), remaining_(haystack), length_within_remaining_(0), reached_tail_(true) {} + + pointer operator->() const noexcept = delete; + value_type operator*() const noexcept { + return remaining_.substr(remaining_.size() - length_within_remaining_); + } + + iterator &operator++() noexcept { + remaining_.remove_suffix(length_within_remaining_); + reached_tail_ = remaining_.empty(); + remaining_.remove_suffix(matcher_.needle_length() * !reached_tail_); + auto position = matcher_(remaining_); + length_within_remaining_ = position != string_type::npos + ? remaining_.size() - position - matcher_.needle_length() + : remaining_.size(); + return *this; + } + + iterator operator++(int) noexcept { + iterator temp = *this; + ++(*this); + return temp; + } + + bool operator!=(iterator const &other) const noexcept { + return (remaining_.end() != other.remaining_.end()) || (reached_tail_ != other.reached_tail_); + } + bool operator==(iterator const &other) const noexcept { + return (remaining_.end() == other.remaining_.end()) && (reached_tail_ == other.reached_tail_); + } + bool operator!=(end_sentinel_type) const noexcept { return !remaining_.empty() || !reached_tail_; } + bool operator==(end_sentinel_type) const noexcept { return remaining_.empty() && reached_tail_; } + bool is_last() const noexcept { return remaining_.size() == length_within_remaining_; } + }; + + iterator begin() const noexcept { return {haystack_, matcher_}; } + iterator end() const noexcept { return {{haystack_.data(), 0ull}, matcher_, end_sentinel_type {}}; } + size_type size() const noexcept { return static_cast(ssize()); } + difference_type ssize() const noexcept { return std::distance(begin(), end()); } + constexpr bool empty() const noexcept { return false; } + + /** + * @brief Copies the matches 
into a container. + */ + template + void to(container_ &container) { + for (auto match : *this) { container.push_back(match); } + } + + /** + * @brief Copies the matches into a consumed container, returning it at the end. + */ + template + container_ to(container_ &&container = {}) { + for (auto match : *this) { container.push_back(match); } + return std::move(container); + } +}; + +/** + * @brief Find all potentially @b overlapping inclusions of a needle substring. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + */ +template +range_matches> find_all(string const &h, string const &n, + include_overlaps_type = {}) noexcept { + return {h, n}; +} + +/** + * @brief Find all potentially @b overlapping inclusions of a needle substring in @b reverse order. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + */ +template +range_rmatches> rfind_all(string const &h, string const &n, + include_overlaps_type = {}) noexcept { + return {h, n}; +} + +/** + * @brief Find all @b non-overlapping inclusions of a needle substring. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + */ +template +range_matches> find_all(string const &h, string const &n, + exclude_overlaps_type) noexcept { + return {h, n}; +} + +/** + * @brief Find all @b non-overlapping inclusions of a needle substring in @b reverse order. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + */ +template +range_rmatches> rfind_all(string const &h, string const &n, + exclude_overlaps_type) noexcept { + return {h, n}; +} + +/** + * @brief Find all inclusions of characters from the second string. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. 
+ */ +template +range_matches> find_all_characters(string const &h, string const &n) noexcept { + return {h, n}; +} + +/** + * @brief Find all inclusions of characters from the second string in @b reverse order. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + */ +template +range_rmatches> rfind_all_characters(string const &h, string const &n) noexcept { + return {h, n}; +} + +/** + * @brief Find all characters except the ones in the second string. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + */ +template +range_matches> find_all_other_characters(string const &h, + string const &n) noexcept { + return {h, n}; +} + +/** + * @brief Find all characters except the ones in the second string in @b reverse order. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + */ +template +range_rmatches> rfind_all_other_characters(string const &h, + string const &n) noexcept { + return {h, n}; +} + +/** + * @brief Splits a string around every @b non-overlapping inclusion of the second string. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + */ +template +range_splits> split(string const &h, string const &n) noexcept { + return {h, n}; +} + +/** + * @brief Splits a string around every @b non-overlapping inclusion of the second string in @b reverse order. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + */ +template +range_rsplits> rsplit(string const &h, string const &n) noexcept { + return {h, n}; +} + +/** + * @brief Splits a string around every character from the second string. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. 
+ */ +template +range_splits> split_characters(string const &h, string const &n) noexcept { + return {h, n}; +} + +/** + * @brief Splits a string around every character from the second string in @b reverse order. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + */ +template +range_rsplits> rsplit_characters(string const &h, string const &n) noexcept { + return {h, n}; +} + +/** + * @brief Splits a string around every character except the ones from the second string. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + */ +template +range_splits> split_other_characters(string const &h, + string const &n) noexcept { + return {h, n}; +} + +/** + * @brief Splits a string around every character except the ones from the second string in @b reverse order. + * @tparam string A string-like type, ideally a view, like StringZilla or STL `string_view`. + */ +template +range_rsplits> rsplit_other_characters(string const &h, + string const &n) noexcept { + return {h, n}; +} + +/** @brief Helper function using `std::advance` iterator and return it back. */ +template +iterator_type advanced(iterator_type &&it, distance_type n) { + std::advance(it, n); + return it; +} + +/** @brief Helper function using `range_length` to compute the unsigned distance. 
*/ +template +std::size_t range_length(iterator_type first, iterator_type last) { + return static_cast(std::distance(first, last)); +} + +#pragma endregion + +#pragma region Global Operations with Dynamic Memory + +template +static void *_call_allocate(sz_size_t n, void *allocator_state) noexcept { + return reinterpret_cast(allocator_state)->allocate(n); +} + +template +static void _call_free(void *ptr, sz_size_t n, void *allocator_state) noexcept { + return reinterpret_cast(allocator_state)->deallocate(reinterpret_cast(ptr), n); +} + +template +static sz_u64_t _call_random_generator(void *state) noexcept { + generator_type_ &generator = *reinterpret_cast(state); + return generator(); +} + +template +static bool _with_alloc(allocator_type_ &allocator, allocator_callback_ &&callback) noexcept { + sz_memory_allocator_t alloc; + alloc.allocate = &_call_allocate; + alloc.free = &_call_free; + alloc.handle = &allocator; + return callback(alloc); +} + +template +static bool _with_alloc(allocator_callback_ &&callback) noexcept { + allocator_type_ allocator; + return _with_alloc(allocator, std::forward(callback)); +} + +#pragma endregion + +#pragma region Helper Template Classes + +/** + * @brief A result of split a string once, containing the string slice ::before, + * the ::match itself, and the slice ::after. + */ +template +struct string_partition_result { + string_ before; + string_ match; + string_ after; +}; + +/** + * @brief A reverse iterator for mutable and immutable character buffers. + * Replaces `std::reverse_iterator` to avoid including ``. 
+ */ +template +class reversed_iterator_for { + public: + using iterator_category = std::random_access_iterator_tag; + using value_type = value_type_; + using difference_type = std::ptrdiff_t; + using pointer = value_type_ *; + using reference = value_type_ &; + + reversed_iterator_for(pointer ptr) noexcept : ptr_(ptr) {} + reference operator*() const noexcept { return *ptr_; } + + bool operator==(reversed_iterator_for const &other) const noexcept { return ptr_ == other.ptr_; } + bool operator!=(reversed_iterator_for const &other) const noexcept { return ptr_ != other.ptr_; } + reference operator[](difference_type n) const noexcept { return *(*this + n); } + reversed_iterator_for operator+(difference_type n) const noexcept { return reversed_iterator_for(ptr_ - n); } + reversed_iterator_for operator-(difference_type n) const noexcept { return reversed_iterator_for(ptr_ + n); } + difference_type operator-(reversed_iterator_for const &other) const noexcept { return other.ptr_ - ptr_; } + + reversed_iterator_for &operator++() noexcept { + --ptr_; + return *this; + } + + reversed_iterator_for operator++(int) const noexcept { + reversed_iterator_for temp = *this; + --ptr_; + return temp; + } + + reversed_iterator_for &operator--() const noexcept { + ++ptr_; + return *this; + } + + reversed_iterator_for operator--(int) const noexcept { + reversed_iterator_for temp = *this; + ++ptr_; + return temp; + } + + private: + value_type_ *ptr_; +}; + +/** + * @brief An "expression template" for lazy concatenation of strings using the `operator|`. + * + * TODO: Ensure eqnership passing and move semantics are preserved. 
+ */ +template +struct concatenation { + + using value_type = typename first_type::value_type; + using pointer = value_type *; + using const_pointer = value_type const *; + using size_type = typename first_type::size_type; + using difference_type = typename first_type::difference_type; + + first_type first; + second_type second; + + std::size_t size() const noexcept { return first.size() + second.size(); } + std::size_t length() const noexcept { return first.size() + second.size(); } + + size_type copy(pointer destination) const noexcept { + first.copy(destination); + second.copy(destination + first.length()); + return first.length() + second.length(); + } + + size_type copy(pointer destination, size_type length) const noexcept { + auto first_length = std::min(first.length(), length); + auto second_length = std::min(second.length(), length - first_length); + first.copy(destination, first_length); + second.copy(destination + first_length, second_length); + return first_length + second_length; + } + + template + concatenation, last_type> operator|(last_type &&last) const { + return {*this, last}; + } +}; + +#pragma endregion + +#pragma region String Views/Spans + +/** + * @brief A string slice (view/span) class implementing a superset of C++23 functionality + * with much faster SIMD-accelerated substring search and approximate matching. + * Constructors are `constexpr` enabling `_sz` literals. + * + * @tparam char_type_ The character type, usually `char const` or `char`. Must be a single byte long. 
+ */ +template +class basic_string_slice { + + static_assert(sizeof(char_type_) == 1, "Characters must be a single byte long"); + static_assert(std::is_reference::value == false, "Characters can't be references"); + + using char_type = char_type_; + using mutable_char_type = typename std::remove_const::type; + using immutable_char_type = typename std::add_const::type; + + char_type *start_; + std::size_t length_; + + public: + // STL compatibility + using traits_type = std::char_traits; + using value_type = mutable_char_type; + using pointer = char_type *; + using const_pointer = immutable_char_type *; + using reference = char_type &; + using const_reference = immutable_char_type &; + using const_iterator = const_pointer; + using iterator = pointer; + using reverse_iterator = reversed_iterator_for; + using const_reverse_iterator = reversed_iterator_for; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + + // Non-STL type definitions + using string_slice = basic_string_slice; + using string_span = basic_string_slice; + using string_view = basic_string_slice; + using partition_type = string_partition_result; + + /** @brief Special value for missing matches. + * We take the largest 63-bit unsigned integer. + */ + static constexpr size_type npos = 0x7FFFFFFFFFFFFFFFull; + +#pragma region Constructors and STL Utilities + + constexpr basic_string_slice() noexcept : start_(nullptr), length_(0) {} + constexpr basic_string_slice(pointer c_string) noexcept + : start_(c_string), length_(null_terminated_length(c_string)) {} + constexpr basic_string_slice(pointer c_string, size_type length) noexcept : start_(c_string), length_(length) {} + + sz_constexpr_if_cpp20 basic_string_slice(basic_string_slice const &other) noexcept = default; + sz_constexpr_if_cpp20 basic_string_slice &operator=(basic_string_slice const &other) noexcept = default; + basic_string_slice(std::nullptr_t) = delete; + + /** @brief Exchanges the view with that of the `other`. 
*/ + void swap(string_slice &other) noexcept { std::swap(start_, other.start_), std::swap(length_, other.length_); } + +#if !SZ_AVOID_STL + + template ::value, int>::type = 0> + sz_constexpr_if_cpp20 basic_string_slice(std::string const &other) noexcept + : basic_string_slice(other.data(), other.size()) {} + + template ::value, int>::type = 0> + sz_constexpr_if_cpp20 basic_string_slice(std::string &other) noexcept + : basic_string_slice(&other[0], other.size()) {} // The `.data()` has mutable variant only since C++17 + + template ::value, int>::type = 0> + sz_constexpr_if_cpp20 string_slice &operator=(std::string const &other) noexcept { + return assign({other.data(), other.size()}); + } + + template ::value, int>::type = 0> + sz_constexpr_if_cpp20 string_slice &operator=(std::string &other) noexcept { + return assign({other.data(), other.size()}); + } + + operator std::string() const { return {data(), size()}; } + + /** + * @brief Formatted output function for compatibility with STL's `std::basic_ostream`. + * @throw `std::ios_base::failure` if an exception occurred during output. 
+ */ + template + friend std::basic_ostream &operator<<(std::basic_ostream &os, + string_slice const &str) noexcept(false) { + return os.write(str.data(), str.size()); + } + +#if SZ_DETECT_CPP_17 && __cpp_lib_string_view + + template ::value, int>::type = 0> + sz_constexpr_if_cpp20 basic_string_slice(std::string_view const &other) noexcept + : basic_string_slice(other.data(), other.size()) {} + + template ::value, int>::type = 0> + sz_constexpr_if_cpp20 string_slice &operator=(std::string_view const &other) noexcept { + return assign({other.data(), other.size()}); + } + operator std::string_view() const noexcept { return {data(), size()}; } + +#endif + +#endif + +#pragma endregion + +#pragma region Iterators and Element Access + + iterator begin() const noexcept { return iterator(start_); } + iterator end() const noexcept { return iterator(start_ + length_); } + const_iterator cbegin() const noexcept { return const_iterator(start_); } + const_iterator cend() const noexcept { return const_iterator(start_ + length_); } + reverse_iterator rbegin() const noexcept { return reverse_iterator(start_ + length_ - 1); } + reverse_iterator rend() const noexcept { return reverse_iterator(start_ - 1); } + const_reverse_iterator crbegin() const noexcept { return const_reverse_iterator(start_ + length_ - 1); } + const_reverse_iterator crend() const noexcept { return const_reverse_iterator(start_ - 1); } + + reference operator[](size_type pos) const noexcept { return start_[pos]; } + reference at(size_type pos) const noexcept { return start_[pos]; } + reference front() const noexcept { return start_[0]; } + reference back() const noexcept { return start_[length_ - 1]; } + pointer data() const noexcept { return start_; } + + difference_type ssize() const noexcept { return static_cast(length_); } + size_type size() const noexcept { return length_; } + size_type length() const noexcept { return length_; } + size_type max_size() const noexcept { return npos - 1; } + bool empty() const 
noexcept { return length_ == 0; } + +#pragma endregion + +#pragma region Slicing + +#pragma region Safe and Signed Extensions + + /** + * @brief Equivalent to Python's `"abc"[-3:-1]`. Exception-safe, unlike STL's `substr`. + * Supports signed and unsigned intervals. + */ + string_slice operator[](std::initializer_list signed_offsets) const noexcept { + assert(signed_offsets.size() == 2 && "operator[] can't take more than 2 offsets"); + return sub(signed_offsets.begin()[0], signed_offsets.begin()[1]); + } + + /** + * @brief Signed alternative to `at()`. Handy if you often write `str[str.size() - 2]`. + * @warning The behavior is @b undefined if the position is beyond bounds. + */ + reference sat(difference_type signed_offset) const noexcept { + size_type pos = (signed_offset < 0) ? size() + signed_offset : signed_offset; + assert(pos < size() && "string_slice::sat(i) out of bounds"); + return start_[pos]; + } + + /** + * @brief The opposite operation to `remove_prefix`, that does no bounds checking. + * @warning The behavior is @b undefined if `n > size()`. + */ + string_slice front(size_type n) const noexcept { + assert(n <= size() && "string_slice::front(n) out of bounds"); + return {start_, n}; + } + + /** + * @brief The opposite operation to `remove_prefix`, that does no bounds checking. + * @warning The behavior is @b undefined if `n > size()`. + */ + string_slice back(size_type n) const noexcept { + assert(n <= size() && "string_slice::back(n) out of bounds"); + return {start_ + length_ - n, n}; + } + + /** + * @brief Equivalent to Python's `"abc"[-3:-1]`. Exception-safe, unlike STL's `substr`. + * Supports signed and unsigned intervals. 
+ */ + string_slice sub(difference_type signed_start_offset, difference_type signed_end_offset = npos) const noexcept { + sz_size_t normalized_offset, normalized_length; + sz_ssize_clamp_interval(length_, signed_start_offset, signed_end_offset, &normalized_offset, + &normalized_length); + return string_slice(start_ + normalized_offset, normalized_length); + } + + /** + * @brief Exports this entire view. Not an STL function, but useful for concatenations. + * The STL variant expects at least two arguments. + */ + size_type copy(value_type *destination) const noexcept { + sz_copy((sz_ptr_t)destination, start_, length_); + return length_; + } + +#pragma endregion + +#pragma region STL Style + + /** + * @brief Removes the first `n` characters from the view. + * @warning The behavior is @b undefined if `n > size()`. + */ + void remove_prefix(size_type n) noexcept { assert(n <= size()), start_ += n, length_ -= n; } + + /** + * @brief Removes the last `n` characters from the view. + * @warning The behavior is @b undefined if `n > size()`. + */ + void remove_suffix(size_type n) noexcept { assert(n <= size()), length_ -= n; } + + /** @brief Added for STL compatibility. */ + string_slice substr() const noexcept { return *this; } + + /** + * @brief Return a slice of this view after first `skip` bytes. + * @throws `std::out_of_range` if `skip > size()`. + * @see `sub` for a cleaner exception-less alternative. + */ + string_slice substr(size_type skip) const noexcept(false) { + if (skip > size()) throw std::out_of_range("string_slice::substr"); + return string_slice(start_ + skip, length_ - skip); + } + + /** + * @brief Return a slice of this view after first `skip` bytes, taking at most `count` bytes. + * @throws `std::out_of_range` if `skip > size()`. + * @see `sub` for a cleaner exception-less alternative. 
+ */ + string_slice substr(size_type skip, size_type count) const noexcept(false) { + if (skip > size()) throw std::out_of_range("string_slice::substr"); + return string_slice(start_ + skip, sz_min_of_two(count, length_ - skip)); + } + + /** + * @brief Exports a slice of this view after first `skip` bytes, taking at most `count` bytes. + * @throws `std::out_of_range` if `skip > size()`. + * @see `sub` for a cleaner exception-less alternative. + */ + size_type copy(value_type *destination, size_type count, size_type skip = 0) const noexcept(false) { + if (skip > size()) throw std::out_of_range("string_slice::copy"); + count = sz_min_of_two(count, length_ - skip); + sz_copy((sz_ptr_t)destination, start_ + skip, count); + return count; + } + +#pragma endregion + +#pragma endregion + +#pragma region Comparisons + +#pragma region Whole String Comparisons + + /** + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. + * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. + */ + int compare(string_view other) const noexcept { + return (int)sz_order(start_, length_, other.start_, other.length_); + } + + /** + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. + * Equivalent to `substr(pos1, count1).compare(other)`. + * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. + * @throw `std::out_of_range` if `pos1 > size()`. + */ + int compare(size_type pos1, size_type count1, string_view other) const noexcept(false) { + return substr(pos1, count1).compare(other); + } + + /** + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. + * Equivalent to `substr(pos1, count1).compare(other.substr(pos2, count2))`. + * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. 
+ * @throw `std::out_of_range` if `pos1 > size()` or if `pos2 > other.size()`. + */ + int compare(size_type pos1, size_type count1, string_view other, size_type pos2, size_type count2) const + noexcept(false) { + return substr(pos1, count1).compare(other.substr(pos2, count2)); + } + + /** + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. + * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. + */ + int compare(const_pointer other) const noexcept { return compare(string_view(other)); } + + /** + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. + * Equivalent to substr(pos1, count1).compare(other). + * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. + * @throw `std::out_of_range` if `pos1 > size()`. + */ + int compare(size_type pos1, size_type count1, const_pointer other) const noexcept(false) { + return substr(pos1, count1).compare(string_view(other)); + } + + /** + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. + * Equivalent to `substr(pos1, count1).compare({s, count2})`. + * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. + * @throw `std::out_of_range` if `pos1 > size()`. + */ + int compare(size_type pos1, size_type count1, const_pointer other, size_type count2) const noexcept(false) { + return substr(pos1, count1).compare(string_view(other, count2)); + } + + /** @brief Checks if the string is equal to the other string. */ + bool operator==(string_view other) const noexcept { + return size() == other.size() && sz_equal(data(), other.data(), other.size()) == sz_true_k; + } + + /** @brief Checks if the string is equal to a concatenation of two strings. 
*/ + bool operator==(concatenation const &other) const noexcept { + return size() == other.size() && sz_equal(data(), other.first.data(), other.first.size()) == sz_true_k && + sz_equal(data() + other.first.size(), other.second.data(), other.second.size()) == sz_true_k; + } + +#if SZ_DETECT_CPP20 + + /** @brief Computes the lexicographic ordering between this and the ::other string. */ + std::strong_ordering operator<=>(string_view other) const noexcept { + std::strong_ordering orders[3] {std::strong_ordering::less, std::strong_ordering::equal, + std::strong_ordering::greater}; + return orders[compare(other) + 1]; + } + +#else + + /** @brief Checks if the string is not equal to the other string. */ + bool operator!=(string_view other) const noexcept { return !operator==(other); } + + /** @brief Checks if the string is lexicographically smaller than the other string. */ + bool operator<(string_view other) const noexcept { return compare(other) == sz_less_k; } + + /** @brief Checks if the string is lexicographically equal or smaller than the other string. */ + bool operator<=(string_view other) const noexcept { return compare(other) != sz_greater_k; } + + /** @brief Checks if the string is lexicographically greater than the other string. */ + bool operator>(string_view other) const noexcept { return compare(other) == sz_greater_k; } + + /** @brief Checks if the string is lexicographically equal or greater than the other string. */ + bool operator>=(string_view other) const noexcept { return compare(other) != sz_less_k; } + +#endif + +#pragma endregion +#pragma region Prefix and Suffix Comparisons + + /** @brief Checks if the string starts with the other string. */ + bool starts_with(string_view other) const noexcept { + return length_ >= other.length_ && sz_equal(start_, other.start_, other.length_) == sz_true_k; + } + + /** @brief Checks if the string starts with the other string. 
*/ + bool starts_with(const_pointer other) const noexcept { + auto other_length = null_terminated_length(other); + return length_ >= other_length && sz_equal(start_, other, other_length) == sz_true_k; + } + + /** @brief Checks if the string starts with the other character. */ + bool starts_with(value_type other) const noexcept { return length_ && start_[0] == other; } + + /** @brief Checks if the string ends with the other string. */ + bool ends_with(string_view other) const noexcept { + return length_ >= other.length_ && + sz_equal(start_ + length_ - other.length_, other.start_, other.length_) == sz_true_k; + } + + /** @brief Checks if the string ends with the other string. */ + bool ends_with(const_pointer other) const noexcept { + auto other_length = null_terminated_length(other); + return length_ >= other_length && sz_equal(start_ + length_ - other_length, other, other_length) == sz_true_k; + } + + /** @brief Checks if the string ends with the other character. */ + bool ends_with(value_type other) const noexcept { return length_ && start_[length_ - 1] == other; } + + /** @brief Python-like convenience function, dropping the matching prefix. */ + string_slice remove_prefix(string_view other) const noexcept { + return starts_with(other) ? string_slice {start_ + other.length_, length_ - other.length_} : *this; + } + + /** @brief Python-like convenience function, dropping the matching suffix. */ + string_slice remove_suffix(string_view other) const noexcept { + return ends_with(other) ? 
string_slice {start_, length_ - other.length_} : *this; + } + +#pragma endregion +#pragma endregion + +#pragma region Matching Substrings + + bool contains(string_view other) const noexcept { return find(other) != npos; } + bool contains(value_type character) const noexcept { return find(character) != npos; } + bool contains(const_pointer other) const noexcept { return find(other) != npos; } + +#pragma region Returning offsets + + /** + * @brief Find the first occurrence of a substring, skipping the first `skip` characters. + * The behavior is @b undefined if `skip > size()`. + * @return The offset of the first character of the match, or `npos` if not found. + */ + size_type find(string_view other, size_type skip = 0) const noexcept { + auto ptr = sz_find(start_ + skip, length_ - skip, other.start_, other.length_); + return ptr ? ptr - start_ : npos; + } + + /** + * @brief Find the first occurrence of a character, skipping the first `skip` characters. + * The behavior is @b undefined if `skip > size()`. + * @return The offset of the match, or `npos` if not found. + */ + size_type find(value_type character, size_type skip = 0) const noexcept { + auto ptr = sz_find_byte(start_ + skip, length_ - skip, &character); + return ptr ? ptr - start_ : npos; + } + + /** + * @brief Find the first occurrence of a substring, skipping the first `skip` characters. + * The behavior is @b undefined if `skip > size()`. + * @return The offset of the first character of the match, or `npos` if not found. + */ + size_type find(const_pointer other, size_type pos, size_type count) const noexcept { + return find(string_view(other, count), pos); + } + + /** + * @brief Find the last occurrence of a substring. + * @return The offset of the first character of the match, or `npos` if not found. + */ + size_type rfind(string_view other) const noexcept { + auto ptr = sz_rfind(start_, length_, other.start_, other.length_); + return ptr ? 
ptr - start_ : npos; + } + + /** + * @brief Find the last occurrence of a substring, within first `until` characters. + * @return The offset of the first character of the match, or `npos` if not found. + */ + size_type rfind(string_view other, size_type until) const noexcept(false) { + return until + other.size() < length_ ? substr(0, until + other.size()).rfind(other) : rfind(other); + } + + /** + * @brief Find the last occurrence of a character. + * @return The offset of the match, or `npos` if not found. + */ + size_type rfind(value_type character) const noexcept { + auto ptr = sz_rfind_byte(start_, length_, &character); + return ptr ? ptr - start_ : npos; + } + + /** + * @brief Find the last occurrence of a character, within first `until` characters. + * @return The offset of the match, or `npos` if not found. + */ + size_type rfind(value_type character, size_type until) const noexcept { + return until < length_ ? substr(0, until + 1).rfind(character) : rfind(character); + } + + /** + * @brief Find the last occurrence of a substring, within first `until` characters. + * @return The offset of the first character of the match, or `npos` if not found. + */ + size_type rfind(const_pointer other, size_type until, size_type count) const noexcept { + return rfind(string_view(other, count), until); + } + + /** @brief Find the first occurrence of a character from a set. */ + size_type find(char_set set) const noexcept { return find_first_of(set); } + + /** @brief Find the last occurrence of a character from a set. */ + size_type rfind(char_set set) const noexcept { return find_last_of(set); } + +#pragma endregion +#pragma region Returning Partitions + + /** @brief Split the string into three parts, before the match, the match itself, and after it. */ + partition_type partition(string_view pattern) const noexcept { return partition_(pattern, pattern.length()); } + + /** @brief Split the string into three parts, before the match, the match itself, and after it. 
*/ + partition_type partition(char_set pattern) const noexcept { return partition_(pattern, 1); } + + /** @brief Split the string into three parts, before the @b last match, the last match itself, and after it. */ + partition_type rpartition(string_view pattern) const noexcept { return rpartition_(pattern, pattern.length()); } + + /** @brief Split the string into three parts, before the @b last match, the last match itself, and after it. */ + partition_type rpartition(char_set pattern) const noexcept { return rpartition_(pattern, 1); } + +#pragma endregion +#pragma endregion + +#pragma region Matching Character Sets + + // `isascii` is a macro in MSVC headers + bool contains_only(char_set set) const noexcept { return find_first_not_of(set) == npos; } + bool is_alpha() const noexcept { return !empty() && contains_only(ascii_letters_set()); } + bool is_alnum() const noexcept { return !empty() && contains_only(ascii_letters_set() | digits_set()); } + bool is_ascii() const noexcept { return empty() || contains_only(ascii_controls_set() | ascii_printables_set()); } + bool is_digit() const noexcept { return !empty() && contains_only(digits_set()); } + bool is_lower() const noexcept { return !empty() && contains_only(ascii_lowercase_set()); } + bool is_space() const noexcept { return !empty() && contains_only(whitespaces_set()); } + bool is_upper() const noexcept { return !empty() && contains_only(ascii_uppercase_set()); } + bool is_printable() const noexcept { return empty() || contains_only(ascii_printables_set()); } + +#pragma region Character Set Arguments + /** + * @brief Find the first occurrence of a character from a set. + * @param skip Number of characters to skip before the search. + * @warning The behavior is @b undefined if `skip > size()`. + */ + size_type find_first_of(char_set set, size_type skip = 0) const noexcept { + auto ptr = sz_find_charset(start_ + skip, length_ - skip, &set.raw()); + return ptr ? 
ptr - start_ : npos; + } + + /** + * @brief Find the first occurrence of a character outside a set. + * @param skip The number of first characters to be skipped. + * @warning The behavior is @b undefined if `skip > size()`. + */ + size_type find_first_not_of(char_set set, size_type skip = 0) const noexcept { + return find_first_of(set.inverted(), skip); + } + + /** + * @brief Find the last occurrence of a character from a set. + */ + size_type find_last_of(char_set set) const noexcept { + auto ptr = sz_rfind_charset(start_, length_, &set.raw()); + return ptr ? ptr - start_ : npos; + } + + /** + * @brief Find the last occurrence of a character outside a set. + */ + size_type find_last_not_of(char_set set) const noexcept { return find_last_of(set.inverted()); } + + /** + * @brief Find the last occurrence of a character from a set. + * @param until The offset of the last character to be considered. + */ + size_type find_last_of(char_set set, size_type until) const noexcept { + auto len = sz_min_of_two(until + 1, length_); + auto ptr = sz_rfind_charset(start_, len, &set.raw()); + return ptr ? ptr - start_ : npos; + } + + /** + * @brief Find the last occurrence of a character outside a set. + * @param until The offset of the last character to be considered. + */ + size_type find_last_not_of(char_set set, size_type until) const noexcept { + return find_last_of(set.inverted(), until); + } + +#pragma endregion +#pragma region String Arguments + + /** + * @brief Find the first occurrence of a character from a ::set. + * @param skip The number of first characters to be skipped. + */ + size_type find_first_of(string_view other, size_type skip = 0) const noexcept { + return find_first_of(other.as_set(), skip); + } + + /** + * @brief Find the first occurrence of a character outside a ::set. + * @param skip The number of first characters to be skipped. 
+ */ + size_type find_first_not_of(string_view other, size_type skip = 0) const noexcept { + return find_first_not_of(other.as_set(), skip); + } + + /** + * @brief Find the last occurrence of a character from a ::set. + * @param until The offset of the last character to be considered. + */ + size_type find_last_of(string_view other, size_type until = npos) const noexcept { + return find_last_of(other.as_set(), until); + } + + /** + * @brief Find the last occurrence of a character outside a ::set. + * @param until The offset of the last character to be considered. + */ + size_type find_last_not_of(string_view other, size_type until = npos) const noexcept { + return find_last_not_of(other.as_set(), until); + } + +#pragma endregion +#pragma region C-Style Arguments + + /** + * @brief Find the first occurrence of a character from a set. + * @param skip The number of first characters to be skipped. + * @warning The behavior is @b undefined if `skip > size()`. + */ + size_type find_first_of(const_pointer other, size_type skip, size_type count) const noexcept { + return find_first_of(string_view(other, count), skip); + } + + /** + * @brief Find the first occurrence of a character outside a set. + * @param skip The number of first characters to be skipped. + * @warning The behavior is @b undefined if `skip > size()`. + */ + size_type find_first_not_of(const_pointer other, size_type skip, size_type count) const noexcept { + return find_first_not_of(string_view(other, count), skip); + } + + /** + * @brief Find the last occurrence of a character from a set. + * @param until The number of first characters to be considered. + */ + size_type find_last_of(const_pointer other, size_type until, size_type count) const noexcept { + return find_last_of(string_view(other, count), until); + } + + /** + * @brief Find the last occurrence of a character outside a set. + * @param until The number of first characters to be considered. 
+ */ + size_type find_last_not_of(const_pointer other, size_type until, size_type count) const noexcept { + return find_last_not_of(string_view(other, count), until); + } + +#pragma endregion +#pragma region Slicing + + /** + * @brief Python-like convenience function, dropping prefix formed of given characters. + * Similar to `boost::algorithm::trim_left_if(str, is_any_of(set))`. + */ + string_slice lstrip(char_set set) const noexcept { + set = set.inverted(); + auto new_start = sz_find_charset(start_, length_, &set.raw()); + return new_start ? string_slice {new_start, length_ - static_cast(new_start - start_)} + : string_slice(); + } + + /** + * @brief Python-like convenience function, dropping suffix formed of given characters. + * Similar to `boost::algorithm::trim_right_if(str, is_any_of(set))`. + */ + string_slice rstrip(char_set set) const noexcept { + set = set.inverted(); + auto new_end = sz_rfind_charset(start_, length_, &set.raw()); + return new_end ? string_slice {start_, static_cast(new_end - start_ + 1)} : string_slice(); + } + + /** + * @brief Python-like convenience function, dropping both the prefix & the suffix formed of given characters. + * Similar to `boost::algorithm::trim_if(str, is_any_of(set))`. + */ + string_slice strip(char_set set) const noexcept { + set = set.inverted(); + auto new_start = sz_find_charset(start_, length_, &set.raw()); + return new_start ? string_slice {new_start, + static_cast( + sz_rfind_charset(new_start, length_ - (new_start - start_), &set.raw()) - + new_start + 1)} + : string_slice(); + } + +#pragma endregion +#pragma endregion + +#pragma region Search Ranges + + using find_all_type = range_matches>; + using rfind_all_type = range_rmatches>; + + using find_disjoint_type = range_matches>; + using rfind_disjoint_type = range_rmatches>; + + using find_all_chars_type = range_matches>; + using rfind_all_chars_type = range_rmatches>; + + /** @brief Find all potentially @b overlapping occurrences of a given string. 
*/ + find_all_type find_all(string_view needle, include_overlaps_type = {}) const noexcept { return {*this, needle}; } + + /** @brief Find all potentially @b overlapping occurrences of a given string in @b reverse order. */ + rfind_all_type rfind_all(string_view needle, include_overlaps_type = {}) const noexcept { return {*this, needle}; } + + /** @brief Find all @b non-overlapping occurrences of a given string. */ + find_disjoint_type find_all(string_view needle, exclude_overlaps_type) const noexcept { return {*this, needle}; } + + /** @brief Find all @b non-overlapping occurrences of a given string in @b reverse order. */ + rfind_disjoint_type rfind_all(string_view needle, exclude_overlaps_type) const noexcept { return {*this, needle}; } + + /** @brief Find all occurrences of given characters. */ + find_all_chars_type find_all(char_set set) const noexcept { return {*this, {set}}; } + + /** @brief Find all occurrences of given characters in @b reverse order. */ + rfind_all_chars_type rfind_all(char_set set) const noexcept { return {*this, {set}}; } + + using split_type = range_splits>; + using rsplit_type = range_rsplits>; + + using split_chars_type = range_splits>; + using rsplit_chars_type = range_rsplits>; + + /** @brief Split around occurrences of a given string. */ + split_type split(string_view delimiter) const noexcept { return {*this, delimiter}; } + + /** @brief Split around occurrences of a given string in @b reverse order. */ + rsplit_type rsplit(string_view delimiter) const noexcept { return {*this, delimiter}; } + + /** @brief Split around occurrences of given characters. */ + split_chars_type split(char_set set = whitespaces_set()) const noexcept { return {*this, {set}}; } + + /** @brief Split around occurrences of given characters in @b reverse order. */ + rsplit_chars_type rsplit(char_set set = whitespaces_set()) const noexcept { return {*this, {set}}; } + + /** @brief Split around the occurrences of all newline characters. 
*/ + split_chars_type splitlines() const noexcept { return split(newlines_set); } + +#pragma endregion + + /** @brief Hashes the string, equivalent to `std::hash{}(str)`. */ + size_type hash() const noexcept { return static_cast(sz_hash(start_, length_)); } + + /** @brief Populate a character set with characters present in this string. */ + char_set as_set() const noexcept { + char_set set; + for (auto c : *this) set.add(c); + return set; + } + + private: + sz_constexpr_if_cpp20 string_view &assign(string_view const &other) noexcept { + start_ = other.start_; + length_ = other.length_; + return *this; + } + + sz_constexpr_if_cpp20 static size_type null_terminated_length(const_pointer s) noexcept { + const_pointer p = s; + while (*p) ++p; + return p - s; + } + + template + partition_type partition_(pattern_ &&pattern, std::size_t pattern_length) const noexcept { + size_type pos = find(pattern); + if (pos == npos) return {*this, string_view(), string_view()}; + return {string_view(start_, pos), string_view(start_ + pos, pattern_length), + string_view(start_ + pos + pattern_length, length_ - pos - pattern_length)}; + } + + template + partition_type rpartition_(pattern_ &&pattern, std::size_t pattern_length) const noexcept { + size_type pos = rfind(pattern); + if (pos == npos) return {*this, string_view(), string_view()}; + return {string_view(start_, pos), string_view(start_ + pos, pattern_length), + string_view(start_ + pos + pattern_length, length_ - pos - pattern_length)}; + } +}; + +#pragma endregion + +/** + * @brief Memory-owning string class with a Small String Optimization. + * + * @section API + * + * Some APIs are different from `basic_string_slice`: + * * `lstrip`, `rstrip`, `strip` modify the string in-place, instead of returning a new view. + * * `sat`, `sub`, and element access has non-const overloads returning references to mutable objects. 
+ * + * Functions defined for `basic_string`, but not present in `basic_string_slice`: + * * `replace`, `insert`, `erase`, `append`, `push_back`, `pop_back`, `resize`, `shrink_to_fit`... from STL, + * * `try_` exception-free "try" operations that returning non-zero values on succces, + * * `replace_all` and `erase_all` similar to Boost, + * * `edit_distance` - Levenshtein distance computation reusing the allocator, + * * `randomize`, `random` - for fast random string generation. + * + * Functions defined for `basic_string_slice`, but not present in `basic_string`: + * * `[r]partition`, `[r]split`, `[r]find_all` missing to enforce lifetime on long operations. + * * `remove_prefix`, `remove_suffix` for now. + * + * @section Exceptions + * + * Default constructor is `constexpr`. Move constructor and move assignment operator are `noexcept`. + * Copy constructor and copy assignment operator are not! They may throw `std::bad_alloc` if the memory + * allocation fails. Similar to STL `std::out_of_range` if the position argument to some of the functions + * is out of bounds. Same as with STL, the bound checks are often assymetric, so pay attention to docs. + * If exceptions are disabled, on failure, `std::terminate` is called. + */ +template > +class basic_string { + + static_assert(sizeof(char_type_) == 1, "Characters must be a single byte long"); + static_assert(std::is_reference::value == false, "Characters can't be references"); + static_assert(std::is_const::value == false, "Characters must be mutable"); + + using char_type = char_type_; + using sz_alloc_type = sz_memory_allocator_t; + + sz_string_t string_; + + /** + * Stateful allocators and their support in C++ strings is extremely error-prone by design. + * Depending on traits like `propagate_on_container_copy_assignment` and `propagate_on_container_move_assignment`, + * its state will be copied from one string to another. 
It goes against the design of most string constructors, + * as they also receive allocator as the last argument! + */ + static_assert(std::is_empty::value, "We currently only support stateless allocators"); + + template + static bool _with_alloc(allocator_callback &&callback) noexcept { + return ashvardanian::stringzilla::_with_alloc(callback); + } + + bool is_internal() const noexcept { return sz_string_is_on_stack(&string_); } + + void init(std::size_t length, char_type value) noexcept(false) { + sz_ptr_t start; + if (!_with_alloc( + [&](sz_alloc_type &alloc) { return (start = sz_string_init_length(&string_, length, &alloc)); })) + throw std::bad_alloc(); + sz_fill(start, length, *(sz_u8_t *)&value); + } + + void init(string_view other) noexcept(false) { + sz_ptr_t start; + if (!_with_alloc( + [&](sz_alloc_type &alloc) { return (start = sz_string_init_length(&string_, other.size(), &alloc)); })) + throw std::bad_alloc(); + sz_copy(start, (sz_cptr_t)other.data(), other.size()); + } + + void move(basic_string &other) noexcept { + // We can't just assign the other string state, as its start address may be somewhere else on the stack. + sz_ptr_t string_start; + sz_size_t string_length; + sz_size_t string_space; + sz_bool_t string_is_external; + sz_string_unpack(&other.string_, &string_start, &string_length, &string_space, &string_is_external); + + // Acquire the old string's value bitwise + *(&string_) = *(&other.string_); + // Reposition the string start pointer to the stack if it fits. + // Ternary condition may be optimized to a branchless version. + string_.internal.start = string_is_external ? string_.internal.start : &string_.internal.chars[0]; + sz_string_init(&other.string_); // Discard the other string. 
+ } + + public: + // STL compatibility + using traits_type = std::char_traits; + using value_type = char_type; + using pointer = char_type *; + using const_pointer = char_type const *; + using reference = char_type &; + using const_reference = char_type const &; + using const_iterator = const_pointer; + using iterator = pointer; + using const_reverse_iterator = reversed_iterator_for; + using reverse_iterator = reversed_iterator_for; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + + // Non-STL type definitions + using allocator_type = allocator_type_; + using string_span = basic_string_slice; + using string_view = basic_string_slice::type>; + using partition_type = string_partition_result; + + /** @brief Special value for missing matches. + * We take the largest 63-bit unsigned integer. + */ + static constexpr size_type npos = 0x7FFFFFFFFFFFFFFFull; + +#pragma region Constructors and STL Utilities + + sz_constexpr_if_cpp20 basic_string() noexcept { + // ! Instead of relying on the `sz_string_init`, we have to reimplement it to support `constexpr`. 
+ string_.internal.start = &string_.internal.chars[0]; + string_.u64s[1] = 0; + string_.u64s[2] = 0; + string_.u64s[3] = 0; + } + + ~basic_string() noexcept { + _with_alloc([&](sz_alloc_type &alloc) { + sz_string_free(&string_, &alloc); + return true; + }); + } + + basic_string(basic_string &&other) noexcept { move(other); } + basic_string &operator=(basic_string &&other) noexcept { + if (!is_internal()) { + _with_alloc([&](sz_alloc_type &alloc) { + sz_string_free(&string_, &alloc); + return true; + }); + } + move(other); + return *this; + } + + basic_string(basic_string const &other) noexcept(false) { init(other); } + basic_string &operator=(basic_string const &other) noexcept(false) { return assign(other); } + basic_string(string_view view) noexcept(false) { init(view); } + basic_string &operator=(string_view view) noexcept(false) { return assign(view); } + + basic_string(const_pointer c_string) noexcept(false) : basic_string(string_view(c_string)) {} + basic_string(const_pointer c_string, size_type length) noexcept(false) + : basic_string(string_view(c_string, length)) {} + basic_string &operator=(const_pointer other) noexcept(false) { return assign(string_view(other)); } + + basic_string(std::nullptr_t) = delete; + + /** @brief Construct a string by repeating a certain ::character ::count times. 
*/ + basic_string(size_type count, value_type character) noexcept(false) { init(count, character); } + + basic_string(basic_string const &other, size_type pos) noexcept(false) { init(string_view(other).substr(pos)); } + basic_string(basic_string const &other, size_type pos, size_type count) noexcept(false) { + init(string_view(other).substr(pos, count)); + } + + basic_string(std::initializer_list list) noexcept(false) { + init(string_view(list.begin(), list.size())); + } + + operator string_view() const noexcept { return view(); } + string_view view() const noexcept { + sz_ptr_t string_start; + sz_size_t string_length; + sz_string_range(&string_, &string_start, &string_length); + return {string_start, string_length}; + } + + operator string_span() noexcept { return span(); } + string_span span() noexcept { + sz_ptr_t string_start; + sz_size_t string_length; + sz_string_range(&string_, &string_start, &string_length); + return {string_start, string_length}; + } + + /** @brief Exchanges the string contents witt the `other` string. */ + void swap(basic_string &other) noexcept { + // If at least one of the strings is on the stack, a basic `std::swap(string_, other.string_)` won't work, + // as the pointer to the stack-allocated memory will be swapped, instead of the contents. 
+ sz_ptr_t first_start, second_start; + sz_size_t first_length, second_length; + sz_size_t first_space, second_space; + sz_bool_t first_is_external, second_is_external; + sz_string_unpack(&string_, &first_start, &first_length, &first_space, &first_is_external); + sz_string_unpack(&other.string_, &second_start, &second_length, &second_space, &second_is_external); + std::swap(string_, other.string_); + if (!first_is_external) other.string_.internal.start = &other.string_.internal.chars[0]; + if (!second_is_external) string_.internal.start = &string_.internal.chars[0]; + } + +#if !SZ_AVOID_STL + + basic_string(std::string const &other) noexcept(false) : basic_string(other.data(), other.size()) {} + basic_string &operator=(std::string const &other) noexcept(false) { return assign({other.data(), other.size()}); } + + // As we are need both `data()` and `size()`, going through `operator string_view()` + // and `sz_string_unpack` is faster than separate invocations. + operator std::string() const { return view(); } + + /** + * @brief Formatted output function for compatibility with STL's `std::basic_ostream`. + * @throw `std::ios_base::failure` if an exception occurred during output. 
+ */ + template + friend std::basic_ostream &operator<<(std::basic_ostream &os, + basic_string const &str) noexcept(false) { + return os.write(str.data(), str.size()); + } + +#if SZ_DETECT_CPP_17 && __cpp_lib_string_view + + basic_string(std::string_view other) noexcept(false) : basic_string(other.data(), other.size()) {} + basic_string &operator=(std::string_view other) noexcept(false) { return assign({other.data(), other.size()}); } + operator std::string_view() const noexcept { return view(); } + +#endif + +#endif + + template + explicit basic_string(concatenation const &expression) noexcept(false) { + _with_alloc([&](sz_alloc_type &alloc) { + sz_ptr_t ptr = sz_string_init_length(&string_, expression.length(), &alloc); + if (!ptr) return false; + expression.copy(ptr); + return true; + }); + } + + template + basic_string &operator=(concatenation const &expression) noexcept(false) { + if (!try_assign(expression)) throw std::bad_alloc(); + return *this; + } + +#pragma endregion + +#pragma region Iterators and Accessors + + iterator begin() noexcept { return iterator(data()); } + const_iterator begin() const noexcept { return const_iterator(data()); } + const_iterator cbegin() const noexcept { return const_iterator(data()); } + + // As we are need both `data()` and `size()`, going through `operator string_view()` + // and `sz_string_unpack` is faster than separate invocations. 
+ iterator end() noexcept { return span().end(); } + const_iterator end() const noexcept { return view().end(); } + const_iterator cend() const noexcept { return view().end(); } + + reverse_iterator rbegin() noexcept { return span().rbegin(); } + const_reverse_iterator rbegin() const noexcept { return view().rbegin(); } + const_reverse_iterator crbegin() const noexcept { return view().crbegin(); } + + reverse_iterator rend() noexcept { return span().rend(); } + const_reverse_iterator rend() const noexcept { return view().rend(); } + const_reverse_iterator crend() const noexcept { return view().crend(); } + + reference operator[](size_type pos) noexcept { return string_.internal.start[pos]; } + const_reference operator[](size_type pos) const noexcept { return string_.internal.start[pos]; } + + reference front() noexcept { return string_.internal.start[0]; } + const_reference front() const noexcept { return string_.internal.start[0]; } + reference back() noexcept { return string_.internal.start[size() - 1]; } + const_reference back() const noexcept { return string_.internal.start[size() - 1]; } + pointer data() noexcept { return string_.internal.start; } + const_pointer data() const noexcept { return string_.internal.start; } + pointer c_str() noexcept { return string_.internal.start; } + const_pointer c_str() const noexcept { return string_.internal.start; } + + reference at(size_type pos) noexcept(false) { + if (pos >= size()) throw std::out_of_range("sz::basic_string::at"); + return string_.internal.start[pos]; + } + const_reference at(size_type pos) const noexcept(false) { + if (pos >= size()) throw std::out_of_range("sz::basic_string::at"); + return string_.internal.start[pos]; + } + + difference_type ssize() const noexcept { return static_cast(size()); } + size_type size() const noexcept { return view().size(); } + size_type length() const noexcept { return size(); } + size_type max_size() const noexcept { return npos - 1; } + bool empty() const noexcept { 
return string_.external.length == 0; } + size_type capacity() const noexcept { + sz_ptr_t string_start; + sz_size_t string_length; + sz_size_t string_space; + sz_bool_t string_is_external; + sz_string_unpack(&string_, &string_start, &string_length, &string_space, &string_is_external); + return string_space - 1; + } + + allocator_type get_allocator() const noexcept { return {}; } + +#pragma endregion + +#pragma region Slicing + +#pragma region Safe and Signed Extensions + + /** + * @brief Equivalent to Python's `"abc"[-3:-1]`. Exception-safe, unlike STL's `substr`. + * Supports signed and unsigned intervals. + */ + string_view operator[](std::initializer_list offsets) const noexcept { return view()[offsets]; } + string_span operator[](std::initializer_list offsets) noexcept { return span()[offsets]; } + + /** + * @brief Signed alternative to `at()`. Handy if you often write `str[str.size() - 2]`. + * @warning The behavior is @b undefined if the position is beyond bounds. + */ + value_type sat(difference_type offset) const noexcept { return view().sat(offset); } + reference sat(difference_type offset) noexcept { return span().sat(offset); } + + /** + * @brief The opposite operation to `remove_prefix`, that does no bounds checking. + * @warning The behavior is @b undefined if `n > size()`. + */ + string_view front(size_type n) const noexcept { return view().front(n); } + string_span front(size_type n) noexcept { return span().front(n); } + + /** + * @brief The opposite operation to `remove_prefix`, that does no bounds checking. + * @warning The behavior is @b undefined if `n > size()`. + */ + string_view back(size_type n) const noexcept { return view().back(n); } + string_span back(size_type n) noexcept { return span().back(n); } + + /** + * @brief Equivalent to Python's `"abc"[-3:-1]`. Exception-safe, unlike STL's `substr`. + * Supports signed and unsigned intervals. @b Doesn't copy or allocate memory! 
+ */ + string_view sub(difference_type start, difference_type end = npos) const noexcept { return view().sub(start, end); } + string_span sub(difference_type start, difference_type end = npos) noexcept { return span().sub(start, end); } + + /** + * @brief Exports this entire view. Not an STL function, but useful for concatenations. + * The STL variant expects at least two arguments. + */ + size_type copy(value_type *destination) const noexcept { return view().copy(destination); } + +#pragma endregion + +#pragma region STL Style + + /** + * @brief Removes the first `n` characters from the view. + * @warning The behavior is @b undefined if `n > size()`. + */ + void remove_prefix(size_type n) noexcept { + assert(n <= size()); + sz_string_erase(&string_, 0, n); + } + + /** + * @brief Removes the last `n` characters from the view. + * @warning The behavior is @b undefined if `n > size()`. + */ + void remove_suffix(size_type n) noexcept { + assert(n <= size()); + sz_string_erase(&string_, size() - n, n); + } + + /** @brief Added for STL compatibility. */ + basic_string substr() const noexcept { return *this; } + + /** + * @brief Return a slice of this view after first `skip` bytes. + * @throws `std::out_of_range` if `skip > size()`. + * @see `sub` for a cleaner exception-less alternative. + */ + basic_string substr(size_type skip) const noexcept(false) { return view().substr(skip); } + + /** + * @brief Return a slice of this view after first `skip` bytes, taking at most `count` bytes. + * @throws `std::out_of_range` if `skip > size()`. + * @see `sub` for a cleaner exception-less alternative. + */ + basic_string substr(size_type skip, size_type count) const noexcept(false) { return view().substr(skip, count); } + + /** + * @brief Exports a slice of this view after first `skip` bytes, taking at most `count` bytes. + * @throws `std::out_of_range` if `skip > size()`. + * @see `sub` for a cleaner exception-less alternative. 
+ */ + size_type copy(value_type *destination, size_type count, size_type skip = 0) const noexcept(false) { + return view().copy(destination, count, skip); + } + +#pragma endregion + +#pragma endregion + +#pragma region Comparisons + +#pragma region Whole String Comparisons + + /** + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. + * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. + */ + int compare(string_view other) const noexcept { return view().compare(other); } + + /** + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. + * Equivalent to `substr(pos1, count1).compare(other)`. + * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. + * @throw `std::out_of_range` if `pos1 > size()`. + */ + int compare(size_type pos1, size_type count1, string_view other) const noexcept(false) { + return view().compare(pos1, count1, other); + } + + /** + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. + * Equivalent to `substr(pos1, count1).compare(other.substr(pos2, count2))`. + * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. + * @throw `std::out_of_range` if `pos1 > size()` or if `pos2 > other.size()`. + */ + int compare(size_type pos1, size_type count1, string_view other, size_type pos2, size_type count2) const + noexcept(false) { + return view().compare(pos1, count1, other, pos2, count2); + } + + /** + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. + * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. + */ + int compare(const_pointer other) const noexcept { return view().compare(other); } + + /** + * @brief Compares two strings lexicographically. 
If prefix matches, lengths are compared. + * Equivalent to substr(pos1, count1).compare(other). + * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. + * @throw `std::out_of_range` if `pos1 > size()`. + */ + int compare(size_type pos1, size_type count1, const_pointer other) const noexcept(false) { + return view().compare(pos1, count1, other); + } + + /** + * @brief Compares two strings lexicographically. If prefix matches, lengths are compared. + * Equivalent to `substr(pos1, count1).compare({s, count2})`. + * @return 0 if equal, negative if `*this` is less than `other`, positive if `*this` is greater than `other`. + * @throw `std::out_of_range` if `pos1 > size()`. + */ + int compare(size_type pos1, size_type count1, const_pointer other, size_type count2) const noexcept(false) { + return view().compare(pos1, count1, other, count2); + } + + /** @brief Checks if the string is equal to the other string. */ + bool operator==(basic_string const &other) const noexcept { return view() == other.view(); } + bool operator==(string_view other) const noexcept { return view() == other; } + bool operator==(const_pointer other) const noexcept { return view() == string_view(other); } + +#if SZ_DETECT_CPP20 + + /** @brief Computes the lexicographic ordering between this and the ::other string. */ + std::strong_ordering operator<=>(basic_string const &other) const noexcept { return view() <=> other.view(); } + std::strong_ordering operator<=>(string_view other) const noexcept { return view() <=> other; } + std::strong_ordering operator<=>(const_pointer other) const noexcept { return view() <=> string_view(other); } + +#else + + /** @brief Checks if the string is not equal to the other string. */ + bool operator!=(string_view other) const noexcept { return !operator==(other); } + + /** @brief Checks if the string is lexicographically smaller than the other string. 
*/ + bool operator<(string_view other) const noexcept { return compare(other) == sz_less_k; } + + /** @brief Checks if the string is lexicographically equal or smaller than the other string. */ + bool operator<=(string_view other) const noexcept { return compare(other) != sz_greater_k; } + + /** @brief Checks if the string is lexicographically greater than the other string. */ + bool operator>(string_view other) const noexcept { return compare(other) == sz_greater_k; } + + /** @brief Checks if the string is lexicographically equal or greater than the other string. */ + bool operator>=(string_view other) const noexcept { return compare(other) != sz_less_k; } + +#endif + +#pragma endregion +#pragma region Prefix and Suffix Comparisons + + /** @brief Checks if the string starts with the other string. */ + bool starts_with(string_view other) const noexcept { return view().starts_with(other); } + + /** @brief Checks if the string starts with the other string. */ + bool starts_with(const_pointer other) const noexcept { return view().starts_with(other); } + + /** @brief Checks if the string starts with the other character. */ + bool starts_with(value_type other) const noexcept { return view().starts_with(other); } + + /** @brief Checks if the string ends with the other string. */ + bool ends_with(string_view other) const noexcept { return view().ends_with(other); } + + /** @brief Checks if the string ends with the other string. */ + bool ends_with(const_pointer other) const noexcept { return view().ends_with(other); } + + /** @brief Checks if the string ends with the other character. 
*/ + bool ends_with(value_type other) const noexcept { return view().ends_with(other); } + +#pragma endregion +#pragma endregion + +#pragma region Matching Substrings + + bool contains(string_view other) const noexcept { return view().contains(other); } + bool contains(value_type character) const noexcept { return view().contains(character); } + bool contains(const_pointer other) const noexcept { return view().contains(other); } + +#pragma region Returning offsets + + /** + * @brief Find the first occurrence of a substring, skipping the first `skip` characters. + * The behavior is @b undefined if `skip > size()`. + * @return The offset of the first character of the match, or `npos` if not found. + */ + size_type find(string_view other, size_type skip = 0) const noexcept { return view().find(other, skip); } + + /** + * @brief Find the first occurrence of a character, skipping the first `skip` characters. + * The behavior is @b undefined if `skip > size()`. + * @return The offset of the match, or `npos` if not found. + */ + size_type find(value_type character, size_type skip = 0) const noexcept { return view().find(character, skip); } + + /** + * @brief Find the first occurrence of a substring, skipping the first `skip` characters. + * The behavior is @b undefined if `skip > size()`. + * @return The offset of the first character of the match, or `npos` if not found. + */ + size_type find(const_pointer other, size_type pos, size_type count) const noexcept { + return view().find(other, pos, count); + } + + /** + * @brief Find the last occurrence of a substring. + * @return The offset of the first character of the match, or `npos` if not found. + */ + size_type rfind(string_view other) const noexcept { return view().rfind(other); } + + /** + * @brief Find the last occurrence of a substring, within first `until` characters. + * @return The offset of the first character of the match, or `npos` if not found. 
+ */ + size_type rfind(string_view other, size_type until) const noexcept { return view().rfind(other, until); } + + /** + * @brief Find the last occurrence of a character. + * @return The offset of the match, or `npos` if not found. + */ + size_type rfind(value_type character) const noexcept { return view().rfind(character); } + + /** + * @brief Find the last occurrence of a character, within first `until` characters. + * @return The offset of the match, or `npos` if not found. + */ + size_type rfind(value_type character, size_type until) const noexcept { return view().rfind(character, until); } + + /** + * @brief Find the last occurrence of a substring, within first `until` characters. + * @return The offset of the first character of the match, or `npos` if not found. + */ + size_type rfind(const_pointer other, size_type until, size_type count) const noexcept { + return view().rfind(other, until, count); + } + + /** @brief Find the first occurrence of a character from a set. */ + size_type find(char_set set) const noexcept { return view().find(set); } + + /** @brief Find the last occurrence of a character from a set. 
*/ + size_type rfind(char_set set) const noexcept { return view().rfind(set); } + +#pragma endregion +#pragma endregion + +#pragma region Matching Character Sets + + bool contains_only(char_set set) const noexcept { return find_first_not_of(set) == npos; } + bool is_alpha() const noexcept { return !empty() && contains_only(ascii_letters_set()); } + bool is_alnum() const noexcept { return !empty() && contains_only(ascii_letters_set() | digits_set()); } + bool is_ascii() const noexcept { return empty() || contains_only(ascii_controls_set() | ascii_printables_set()); } + bool is_digit() const noexcept { return !empty() && contains_only(digits_set()); } + bool is_lower() const noexcept { return !empty() && contains_only(ascii_lowercase_set()); } + bool is_space() const noexcept { return !empty() && contains_only(whitespaces_set()); } + bool is_upper() const noexcept { return !empty() && contains_only(ascii_uppercase_set()); } + bool is_printable() const noexcept { return empty() || contains_only(ascii_printables_set()); } + +#pragma region Character Set Arguments + + /** + * @brief Find the first occurrence of a character from a set. + * @param skip Number of characters to skip before the search. + * @warning The behavior is @b undefined if `skip > size()`. + */ + size_type find_first_of(char_set set, size_type skip = 0) const noexcept { return view().find_first_of(set, skip); } + + /** + * @brief Find the first occurrence of a character outside a set. + * @param skip The number of first characters to be skipped. + * @warning The behavior is @b undefined if `skip > size()`. + */ + size_type find_first_not_of(char_set set, size_type skip = 0) const noexcept { + return view().find_first_not_of(set, skip); + } + + /** + * @brief Find the last occurrence of a character from a set. + */ + size_type find_last_of(char_set set) const noexcept { return view().find_last_of(set); } + + /** + * @brief Find the last occurrence of a character outside a set. 
+ */ + size_type find_last_not_of(char_set set) const noexcept { return view().find_last_not_of(set); } + + /** + * @brief Find the last occurrence of a character from a set. + * @param until The offset of the last character to be considered. + */ + size_type find_last_of(char_set set, size_type until) const noexcept { return view().find_last_of(set, until); } + + /** + * @brief Find the last occurrence of a character outside a set. + * @param until The offset of the last character to be considered. + */ + size_type find_last_not_of(char_set set, size_type until) const noexcept { + return view().find_last_not_of(set, until); + } + +#pragma endregion +#pragma region String Arguments + + /** + * @brief Find the first occurrence of a character from a ::set. + * @param skip The number of first characters to be skipped. + */ + size_type find_first_of(string_view other, size_type skip = 0) const noexcept { + return view().find_first_of(other, skip); + } + + /** + * @brief Find the first occurrence of a character outside a ::set. + * @param skip The number of first characters to be skipped. + */ + size_type find_first_not_of(string_view other, size_type skip = 0) const noexcept { + return view().find_first_not_of(other, skip); + } + + /** + * @brief Find the last occurrence of a character from a ::set. + * @param until The offset of the last character to be considered. + */ + size_type find_last_of(string_view other, size_type until = npos) const noexcept { + return view().find_last_of(other, until); + } + + /** + * @brief Find the last occurrence of a character outside a ::set. + * @param until The offset of the last character to be considered. + */ + size_type find_last_not_of(string_view other, size_type until = npos) const noexcept { + return view().find_last_not_of(other, until); + } + +#pragma endregion +#pragma region C-Style Arguments + + /** + * @brief Find the first occurrence of a character from a set. 
+ * @param skip The number of first characters to be skipped. + * @warning The behavior is @b undefined if `skip > size()`. + */ + size_type find_first_of(const_pointer other, size_type skip, size_type count) const noexcept { + return view().find_first_of(other, skip, count); + } + + /** + * @brief Find the first occurrence of a character outside a set. + * @param skip The number of first characters to be skipped. + * @warning The behavior is @b undefined if `skip > size()`. + */ + size_type find_first_not_of(const_pointer other, size_type skip, size_type count) const noexcept { + return view().find_first_not_of(other, skip, count); + } + + /** + * @brief Find the last occurrence of a character from a set. + * @param until The number of first characters to be considered. + */ + size_type find_last_of(const_pointer other, size_type until, size_type count) const noexcept { + return view().find_last_of(other, until, count); + } + + /** + * @brief Find the last occurrence of a character outside a set. + * @param until The number of first characters to be considered. + */ + size_type find_last_not_of(const_pointer other, size_type until, size_type count) const noexcept { + return view().find_last_not_of(other, until, count); + } + +#pragma endregion +#pragma region Slicing + + /** + * @brief Python-like convenience function, dropping prefix formed of given characters. + * Similar to `boost::algorithm::trim_left_if(str, is_any_of(set))`. + */ + basic_string &lstrip(char_set set) noexcept { + auto remaining = view().lstrip(set); + remove_prefix(size() - remaining.size()); + return *this; + } + + /** + * @brief Python-like convenience function, dropping suffix formed of given characters. + * Similar to `boost::algorithm::trim_right_if(str, is_any_of(set))`. 
+ */ + basic_string &rstrip(char_set set) noexcept { + auto remaining = view().rstrip(set); + remove_suffix(size() - remaining.size()); + return *this; + } + + /** + * @brief Python-like convenience function, dropping both the prefix & the suffix formed of given characters. + * Similar to `boost::algorithm::trim_if(str, is_any_of(set))`. + */ + basic_string &strip(char_set set) noexcept { return lstrip(set).rstrip(set); } + +#pragma endregion +#pragma endregion + +#pragma region Modifiers +#pragma region Non-STL API + + bool try_resize(size_type count, value_type character = '\0') noexcept; + + bool try_reserve(size_type capacity) noexcept { + return _with_alloc([&](sz_alloc_type &alloc) { return sz_string_reserve(&string_, capacity, &alloc); }); + } + + bool try_assign(string_view other) noexcept; + + template + bool try_assign(concatenation const &other) noexcept; + + bool try_push_back(char_type c) noexcept; + + bool try_append(const_pointer str, size_type length) noexcept; + + bool try_append(string_view str) noexcept { return try_append(str.data(), str.size()); } + + /** + * @brief Erases ( @b in-place ) a range of characters defined with signed offsets. + * @return Number of characters removed. + */ + size_type try_erase(difference_type signed_start_offset = 0, difference_type signed_end_offset = npos) noexcept { + sz_size_t normalized_offset, normalized_length; + sz_ssize_clamp_interval(size(), signed_start_offset, signed_end_offset, &normalized_offset, &normalized_length); + if (!normalized_length) return false; + sz_string_erase(&string_, normalized_offset, normalized_length); + return normalized_length; + } + + /** + * @brief Inserts ( @b in-place ) a range of characters at a given signed offset. + * @return `true` if the insertion was successful, `false` otherwise. 
+ */ + bool try_insert(difference_type signed_offset, string_view string) noexcept { + sz_size_t normalized_offset, normalized_length; + sz_ssize_clamp_interval(size(), signed_offset, 0, &normalized_offset, &normalized_length); + if (!_with_alloc([&](sz_alloc_type &alloc) { + return sz_string_expand(&string_, normalized_offset, string.size(), &alloc); + })) + return false; + + sz_copy(data() + normalized_offset, string.data(), string.size()); + return true; + } + + /** + * @brief Replaces ( @b in-place ) a range of characters with a given string. + * @return `true` if the replacement was successful, `false` otherwise. + */ + bool try_replace(difference_type signed_start_offset, difference_type signed_end_offset, + string_view replacement) noexcept { + + sz_size_t normalized_offset, normalized_length; + sz_ssize_clamp_interval(size(), signed_start_offset, signed_end_offset, &normalized_offset, &normalized_length); + if (!try_preparing_replacement(normalized_offset, normalized_length, replacement)) return false; + sz_copy(data() + normalized_offset, replacement.data(), replacement.size()); + return true; + } + +#pragma endregion + +#pragma region STL Interfaces + + /** + * @brief Clears the string contents, but @b no deallocations happen. + */ + void clear() noexcept { sz_string_erase(&string_, 0, SZ_SIZE_MAX); } + + /** + * @brief Resizes the string to the given size, filling the new space with the given character, + * or NULL-character if nothing is provided. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + */ + void resize(size_type count, value_type character = '\0') noexcept(false) { + if (count > max_size()) throw std::length_error("sz::basic_string::resize"); + if (!try_resize(count, character)) throw std::bad_alloc(); + } + + /** + * @brief Informs the string object of a planned change in size, so that it pre-allocate once. + * @throw `std::length_error` if the string is too long. 
+ */ + void reserve(size_type capacity) noexcept(false) { + if (capacity > max_size()) throw std::length_error("sz::basic_string::reserve"); + if (!try_reserve(capacity)) throw std::bad_alloc(); + } + + /** + * @brief Inserts ( @b in-place ) a ::character multiple times at the given offset. + * @throw `std::out_of_range` if `offset > size()`. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + */ + basic_string &insert(size_type offset, size_type repeats, char_type character) noexcept(false) { + if (offset > size()) throw std::out_of_range("sz::basic_string::insert"); + if (size() + repeats > max_size()) throw std::length_error("sz::basic_string::insert"); + if (!_with_alloc([&](sz_alloc_type &alloc) { return sz_string_expand(&string_, offset, repeats, &alloc); })) + throw std::bad_alloc(); + + sz_fill(data() + offset, repeats, character); + return *this; + } + + /** + * @brief Inserts ( @b in-place ) a range of characters at the given offset. + * @throw `std::out_of_range` if `offset > size()`. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + */ + basic_string &insert(size_type offset, string_view other) noexcept(false) { + if (offset > size()) throw std::out_of_range("sz::basic_string::insert"); + if (size() + other.size() > max_size()) throw std::length_error("sz::basic_string::insert"); + if (!_with_alloc( + [&](sz_alloc_type &alloc) { return sz_string_expand(&string_, offset, other.size(), &alloc); })) + throw std::bad_alloc(); + + sz_copy(data() + offset, other.data(), other.size()); + return *this; + } + + /** + * @brief Inserts ( @b in-place ) a range of characters at the given offset. + * @throw `std::out_of_range` if `offset > size()`. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. 
+ */ + basic_string &insert(size_type offset, const_pointer start, size_type length) noexcept(false) { + return insert(offset, string_view(start, length)); + } + + /** + * @brief Inserts ( @b in-place ) a slice of another string at the given offset. + * @throw `std::out_of_range` if `offset > size()` or `other_index > other.size()`. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + */ + basic_string &insert(size_type offset, string_view other, size_type other_index, + size_type count = npos) noexcept(false) { + return insert(offset, other.substr(other_index, count)); + } + + /** + * @brief Inserts ( @b in-place ) one ::character at the given iterator position. + * @throw `std::out_of_range` if `pos > size()` or `other_index > other.size()`. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + */ + iterator insert(const_iterator it, char_type character) noexcept(false) { + auto pos = range_length(cbegin(), it); + insert(pos, string_view(&character, 1)); + return begin() + pos; + } + + /** + * @brief Inserts ( @b in-place ) a ::character multiple times at the given iterator position. + * @throw `std::out_of_range` if `pos > size()` or `other_index > other.size()`. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + */ + iterator insert(const_iterator it, size_type repeats, char_type character) noexcept(false) { + auto pos = range_length(cbegin(), it); + insert(pos, repeats, character); + return begin() + pos; + } + + /** + * @brief Inserts ( @b in-place ) a range at the given iterator position. + * @throw `std::out_of_range` if `pos > size()` or `other_index > other.size()`. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. 
+ */ + template + iterator insert(const_iterator it, input_iterator first, input_iterator last) noexcept(false) { + + auto pos = range_length(cbegin(), it); + if (pos > size()) throw std::out_of_range("sz::basic_string::insert"); + + auto added_length = range_length(first, last); + if (size() + added_length > max_size()) throw std::length_error("sz::basic_string::insert"); + + if (!_with_alloc([&](sz_alloc_type &alloc) { return sz_string_expand(&string_, pos, added_length, &alloc); })) + throw std::bad_alloc(); + + iterator result = begin() + pos; + for (iterator output = result; first != last; ++first, ++output) *output = *first; + return result; + } + + /** + * @brief Inserts ( @b in-place ) an initializer list of characters. + * @throw `std::out_of_range` if `pos > size()` or `other_index > other.size()`. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + */ + iterator insert(const_iterator it, std::initializer_list ilist) noexcept(false) { + return insert(it, ilist.begin(), ilist.end()); + } + + /** + * @brief Erases ( @b in-place ) the given range of characters. + * @throws `std::out_of_range` if `pos > size()`. + * @see `try_erase_slice` for a cleaner exception-less alternative. + */ + basic_string &erase(size_type pos = 0, size_type count = npos) noexcept(false) { + if (!count || empty()) return *this; + if (pos >= size()) throw std::out_of_range("sz::basic_string::erase"); + sz_string_erase(&string_, pos, count); + return *this; + } + + /** + * @brief Erases ( @b in-place ) the given range of characters. + * @return Iterator pointing following the erased character, or end() if no such character exists. + */ + iterator erase(const_iterator first, const_iterator last) noexcept { + auto start = begin(); + auto offset = first - start; + sz_string_erase(&string_, offset, last - first); + return start + offset; + } + + /** + * @brief Erases ( @b in-place ) the one character at a given postion. 
+ * @return Iterator pointing following the erased character, or end() if no such character exists. + */ + iterator erase(const_iterator pos) noexcept { return erase(pos, pos + 1); } + + /** + * @brief Replaces ( @b in-place ) a range of characters with a given string. + * @throws `std::out_of_range` if `pos > size()`. + * @throws `std::length_error` if the string is too long. + * @see `try_replace` for a cleaner exception-less alternative. + */ + basic_string &replace(size_type pos, size_type count, string_view const &str) noexcept(false) { + if (pos > size()) throw std::out_of_range("sz::basic_string::replace"); + if (size() - count + str.size() > max_size()) throw std::length_error("sz::basic_string::replace"); + if (!try_preparing_replacement(pos, count, str.size())) throw std::bad_alloc(); + sz_copy(data() + pos, str.data(), str.size()); + return *this; + } + + /** + * @brief Replaces ( @b in-place ) a range of characters with a given string. + * @throws `std::out_of_range` if `pos > size()`. + * @throws `std::length_error` if the string is too long. + * @see `try_replace` for a cleaner exception-less alternative. + */ + basic_string &replace(const_iterator first, const_iterator last, string_view const &str) noexcept(false) { + return replace(range_length(cbegin(), first), last - first, str); + } + + /** + * @brief Replaces ( @b in-place ) a range of characters with a given string. + * @throws `std::out_of_range` if `pos > size()` or `pos2 > str.size()`. + * @throws `std::length_error` if the string is too long. + * @see `try_replace` for a cleaner exception-less alternative. + */ + basic_string &replace(size_type pos, size_type count, string_view const &str, size_type pos2, + size_type count2 = npos) noexcept(false) { + return replace(pos, count, str.substr(pos2, count2)); + } + + /** + * @brief Replaces ( @b in-place ) a range of characters with a given string. + * @throws `std::out_of_range` if `pos > size()`. 
+ * @throws `std::length_error` if the string is too long. + * @see `try_replace` for a cleaner exception-less alternative. + */ + basic_string &replace(size_type pos, size_type count, const_pointer cstr, size_type count2) noexcept(false) { + return replace(pos, count, string_view(cstr, count2)); + } + + /** + * @brief Replaces ( @b in-place ) a range of characters with a given string. + * @throws `std::out_of_range` if `pos > size()`. + * @throws `std::length_error` if the string is too long. + * @see `try_replace` for a cleaner exception-less alternative. + */ + basic_string &replace(const_iterator first, const_iterator last, const_pointer cstr, + size_type count2) noexcept(false) { + return replace(range_length(cbegin(), first), last - first, string_view(cstr, count2)); + } + + /** + * @brief Replaces ( @b in-place ) a range of characters with a given string. + * @throws `std::out_of_range` if `pos > size()`. + * @throws `std::length_error` if the string is too long. + * @see `try_replace` for a cleaner exception-less alternative. + */ + basic_string &replace(size_type pos, size_type count, const_pointer cstr) noexcept(false) { + return replace(pos, count, string_view(cstr)); + } + + /** + * @brief Replaces ( @b in-place ) a range of characters with a given string. + * @throws `std::out_of_range` if `pos > size()`. + * @throws `std::length_error` if the string is too long. + * @see `try_replace` for a cleaner exception-less alternative. + */ + basic_string &replace(const_iterator first, const_iterator last, const_pointer cstr) noexcept(false) { + return replace(range_length(cbegin(), first), last - first, string_view(cstr)); + } + + /** + * @brief Replaces ( @b in-place ) a range of characters with a repetition of given characters. + * @throws `std::out_of_range` if `pos > size()`. + * @throws `std::length_error` if the string is too long. + * @see `try_replace` for a cleaner exception-less alternative. 
+ */ + basic_string &replace(size_type pos, size_type count, size_type count2, char_type character) noexcept(false) { + if (pos > size()) throw std::out_of_range("sz::basic_string::replace"); + if (size() - count + count2 > max_size()) throw std::length_error("sz::basic_string::replace"); + if (!try_preparing_replacement(pos, count, count2)) throw std::bad_alloc(); + sz_fill(data() + pos, count2, character); + return *this; + } + + /** + * @brief Replaces ( @b in-place ) a range of characters with a repetition of given characters. + * @throws `std::out_of_range` if `pos > size()`. + * @throws `std::length_error` if the string is too long. + * @see `try_replace` for a cleaner exception-less alternative. + */ + basic_string &replace(const_iterator first, const_iterator last, size_type count2, + char_type character) noexcept(false) { + return replace(range_length(cbegin(), first), last - first, count2, character); + } + + /** + * @brief Replaces ( @b in-place ) a range of characters with a given string. + * @throws `std::out_of_range` if `pos > size()`. + * @throws `std::length_error` if the string is too long. + * @see `try_replace` for a cleaner exception-less alternative. + */ + template + basic_string &replace(const_iterator first, const_iterator last, input_iterator first2, + input_iterator last2) noexcept(false) { + auto pos = range_length(cbegin(), first); + auto count = range_length(first, last); + auto count2 = range_length(first2, last2); + if (pos > size()) throw std::out_of_range("sz::basic_string::replace"); + if (size() - count + count2 > max_size()) throw std::length_error("sz::basic_string::replace"); + if (!try_preparing_replacement(pos, count, count2)) throw std::bad_alloc(); + for (iterator output = begin() + pos; first2 != last2; ++first2, ++output) *output = *first2; + return *this; + } + + /** + * @brief Replaces ( @b in-place ) a range of characters with a given initializer list. + * @throws `std::out_of_range` if `pos > size()`. 
+ * @throws `std::length_error` if the string is too long. + * @see `try_replace` for a cleaner exception-less alternative. + */ + basic_string &replace(const_iterator first, const_iterator last, + std::initializer_list ilist) noexcept(false) { + return replace(first, last, ilist.begin(), ilist.end()); + } + + /** + * @brief Appends the given character at the end. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + */ + void push_back(char_type ch) noexcept(false) { + if (size() == max_size()) throw std::length_error("string::push_back"); + if (!try_push_back(ch)) throw std::bad_alloc(); + } + + /** + * @brief Removes the last character from the string. + * @warning The behavior is @b undefined if the string is empty. + */ + void pop_back() noexcept { sz_string_erase(&string_, size() - 1, 1); } + + /** + * @brief Overwrites the string with the given string. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + * @see `try_assign` for a cleaner exception-less alternative. + */ + basic_string &assign(string_view other) noexcept(false) { + if (!try_assign(other)) throw std::bad_alloc(); + return *this; + } + + /** + * @brief Overwrites the string with the given repeated character. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + * @see `try_assign` for a cleaner exception-less alternative. + */ + basic_string &assign(size_type repeats, char_type character) noexcept(false) { + resize(repeats, character); + sz_fill(data(), repeats, *(sz_u8_t *)&character); + return *this; + } + + /** + * @brief Overwrites the string with the given string. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + * @see `try_assign` for a cleaner exception-less alternative. 
+ */ + basic_string &assign(const_pointer other, size_type length) noexcept(false) { return assign({other, length}); } + + /** + * @brief Overwrites the string with the given string. + * @throw `std::length_error` if the string is too long or `pos > str.size()`. + * @throw `std::bad_alloc` if the allocation fails. + * @see `try_assign` for a cleaner exception-less alternative. + */ + basic_string &assign(string_view str, size_type pos, size_type count = npos) noexcept(false) { + return assign(str.substr(pos, count)); + } + + /** + * @brief Overwrites the string with the given iterator range. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + * @see `try_assign` for a cleaner exception-less alternative. + */ + template + basic_string &assign(input_iterator first, input_iterator last) noexcept(false) { + resize(range_length(first, last)); + for (iterator output = begin(); first != last; ++first, ++output) *output = *first; + return *this; + } + + /** + * @brief Overwrites the string with the given initializer list. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + * @see `try_assign` for a cleaner exception-less alternative. + */ + basic_string &assign(std::initializer_list ilist) noexcept(false) { + return assign(ilist.begin(), ilist.end()); + } + + /** + * @brief Appends to the end of the current string. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + * @see `try_append` for a cleaner exception-less alternative. + */ + basic_string &append(string_view str) noexcept(false) { + if (!try_append(str)) throw std::bad_alloc(); + return *this; + } + + /** + * @brief Appends to the end of the current string. + * @throw `std::length_error` if the string is too long or `pos > str.size()`. + * @throw `std::bad_alloc` if the allocation fails. 
+ * @see `try_append` for a cleaner exception-less alternative. + */ + basic_string &append(string_view str, size_type pos, size_type length = npos) noexcept(false) { + return append(str.substr(pos, length)); + } + + /** + * @brief Appends to the end of the current string. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + * @see `try_append` for a cleaner exception-less alternative. + */ + basic_string &append(const_pointer str, size_type length) noexcept(false) { return append({str, length}); } + + /** + * @brief Appends to the end of the current string. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + * @see `try_append` for a cleaner exception-less alternative. + */ + basic_string &append(const_pointer str) noexcept(false) { return append(string_view(str)); } + + /** + * @brief Appends a repeated character to the end of the current string. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + * @see `try_append` for a cleaner exception-less alternative. + */ + basic_string &append(size_type repeats, char_type ch) noexcept(false) { + resize(size() + repeats, ch); + return *this; + } + + /** + * @brief Appends to the end of the current string. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + * @see `try_append` for a cleaner exception-less alternative. + */ + basic_string &append(std::initializer_list other) noexcept(false) { + return append(other.begin(), other.end()); + } + + /** + * @brief Appends to the end of the current string. + * @throw `std::length_error` if the string is too long. + * @throw `std::bad_alloc` if the allocation fails. + * @see `try_append` for a cleaner exception-less alternative. 
+ */ + template + basic_string &append(input_iterator first, input_iterator last) noexcept(false) { + insert(cend(), first, last); + return *this; + } + + basic_string &operator+=(string_view other) noexcept(false) { return append(other); } + basic_string &operator+=(std::initializer_list other) noexcept(false) { return append(other); } + basic_string &operator+=(char_type character) noexcept(false) { return operator+=(string_view(&character, 1)); } + basic_string &operator+=(const_pointer other) noexcept(false) { return operator+=(string_view(other)); } + + basic_string operator+(char_type character) const noexcept(false) { return operator+(string_view(&character, 1)); } + basic_string operator+(const_pointer other) const noexcept(false) { return operator+(string_view(other)); } + basic_string operator+(string_view other) const noexcept(false) { + return basic_string {concatenation {view(), other}}; + } + basic_string operator+(std::initializer_list other) const noexcept(false) { + return basic_string {concatenation {view(), other}}; + } + +#pragma endregion +#pragma endregion + + concatenation operator|(string_view other) const noexcept { return {view(), other}; } + + size_type edit_distance(string_view other, size_type bound = 0) const noexcept { + size_type distance; + _with_alloc([&](sz_alloc_type &alloc) { + distance = sz_edit_distance(data(), size(), other.data(), other.size(), bound, &alloc); + return true; + }); + return distance; + } + + /** @brief Hashes the string, equivalent to `std::hash{}(str)`. */ + size_type hash() const noexcept { return view().hash(); } + + /** + * @brief Overwrites the string with random characters from the given alphabet using the random generator. + * + * @param generator A random generator function object that returns a random number in the range [0, 2^64). + * @param alphabet A string of characters to choose from. 
+ */ + template + basic_string &randomize(generator_type &generator, string_view alphabet = "abcdefghijklmnopqrstuvwxyz") noexcept { + sz_ptr_t start; + sz_size_t length; + sz_string_range(&string_, &start, &length); + sz_random_generator_t generator_callback = &_call_random_generator; + sz_generate(alphabet.data(), alphabet.size(), start, length, generator_callback, &generator); + return *this; + } + + /** + * @brief Overwrites the string with random characters from the given alphabet + * using `std::rand` as the random generator. + * + * @param alphabet A string of characters to choose from. + */ + basic_string &randomize(string_view alphabet = "abcdefghijklmnopqrstuvwxyz") noexcept { + return randomize(&std::rand, alphabet); + } + + /** + * @brief Generate a new random string of given length using `std::rand` as the random generator. + * May throw exceptions if the memory allocation fails. + * + * @param length The length of the generated string. + * @param alphabet A string of characters to choose from. + */ + static basic_string random(size_type length, string_view alphabet = "abcdefghijklmnopqrstuvwxyz") noexcept(false) { + return basic_string(length, '\0').randomize(alphabet); + } + + /** + * @brief Generate a new random string of given length using the provided random number generator. + * May throw exceptions if the memory allocation fails. + * + * @param generator A random generator function object that returns a random number in the range [0, 2^64). + * @param length The length of the generated string. + * @param alphabet A string of characters to choose from. + */ + template + static basic_string random(generator_type &generator, size_type length, + string_view alphabet = "abcdefghijklmnopqrstuvwxyz") noexcept(false) { + return basic_string(length, '\0').randomize(generator, alphabet); + } + + /** + * @brief Replaces ( @b in-place ) all occurrences of a given string with the ::replacement string. 
+ * Similar to `boost::algorithm::replace_all` and Python's `str.replace`. + * + * The implementation is not as composable, as using search ranges combined with a replacing mapping for matches, + * and might be suboptimal, if you are exporting the cleaned-up string to another buffer. + * The algorithm is suboptimal when this string is made exclusively of the pattern. + */ + basic_string &replace_all(string_view pattern, string_view replacement) noexcept(false) { + if (!try_replace_all(pattern, replacement)) throw std::bad_alloc(); + return *this; + } + + /** + * @brief Replaces ( @b in-place ) all occurrences of a given character set with the ::replacement string. + * Similar to `boost::algorithm::replace_all` and Python's `str.replace`. + * + * The implementation is not as composable, as using search ranges combined with a replacing mapping for matches, + * and might be suboptimal, if you are exporting the cleaned-up string to another buffer. + * The algorithm is suboptimal when this string is made exclusively of the pattern. + */ + basic_string &replace_all(char_set pattern, string_view replacement) noexcept(false) { + if (!try_replace_all(pattern, replacement)) throw std::bad_alloc(); + return *this; + } + + /** + * @brief Replaces ( @b in-place ) all occurrences of a given string with the ::replacement string. + * Similar to `boost::algorithm::replace_all` and Python's `str.replace`. + * + * The implementation is not as composable, as using search ranges combined with a replacing mapping for matches, + * and might be suboptimal, if you are exporting the cleaned-up string to another buffer. + * The algorithm is suboptimal when this string is made exclusively of the pattern. + */ + bool try_replace_all(string_view pattern, string_view replacement) noexcept { + return try_replace_all_(pattern, replacement); + } + + /** + * @brief Replaces ( @b in-place ) all occurrences of a given character set with the ::replacement string. 
+ * Similar to `boost::algorithm::replace_all` and Python's `str.replace`. + * + * The implementation is not as composable, as using search ranges combined with a replacing mapping for matches, + * and might be suboptimal, if you are exporting the cleaned-up string to another buffer. + * The algorithm is suboptimal when this string is made exclusively of the pattern. + */ + bool try_replace_all(char_set pattern, string_view replacement) noexcept { + return try_replace_all_(pattern, replacement); + } + + private: + template + bool try_replace_all_(pattern_type pattern, string_view replacement) noexcept; + + /** + * @brief Tries to prepare the string for a replacement of a given range with a new string. + * The allocation may occur, if the replacement is longer than the replaced range. + */ + bool try_preparing_replacement(size_type offset, size_type length, size_type new_length) noexcept; +}; + +using string = basic_string>; + +static_assert(sizeof(string) == 4 * sizeof(void *), "String size must be 4 pointers."); + +namespace literals { +constexpr string_view operator""_sz(char const *str, std::size_t length) noexcept { return {str, length}; } +} // namespace literals + +template +bool basic_string::try_resize(size_type count, value_type character) noexcept { + sz_ptr_t string_start; + sz_size_t string_length; + sz_size_t string_space; + sz_bool_t string_is_external; + sz_string_unpack(&string_, &string_start, &string_length, &string_space, &string_is_external); + + // Allocate more space if needed. + if (count >= string_space) { + if (!_with_alloc( + [&](sz_alloc_type &alloc) { return sz_string_expand(&string_, SZ_SIZE_MAX, count, &alloc) != NULL; })) + return false; + sz_string_unpack(&string_, &string_start, &string_length, &string_space, &string_is_external); + } + + // Fill the trailing characters. 
+ if (count > string_length) { + sz_fill(string_start + string_length, count - string_length, character); + string_start[count] = '\0'; + // Knowing the layout of the string, we can perform this operation safely, + // even if its located on stack. + string_.external.length += count - string_length; + } + else { sz_string_erase(&string_, count, SZ_SIZE_MAX); } + return true; +} + +template +bool basic_string::try_assign(string_view other) noexcept { + // We can't just assign the other string state, as its start address may be somewhere else on the stack. + sz_ptr_t string_start; + sz_size_t string_length; + sz_string_range(&string_, &string_start, &string_length); + + if (string_length >= other.length()) { + other.copy(string_start, other.length()); + sz_string_erase(&string_, other.length(), SZ_SIZE_MAX); + } + else { + if (!_with_alloc([&](sz_alloc_type &alloc) { + string_start = sz_string_expand(&string_, SZ_SIZE_MAX, other.length(), &alloc); + if (!string_start) return false; + other.copy(string_start, other.length()); + return true; + })) + return false; + } + return true; +} + +template +bool basic_string::try_push_back(char_type c) noexcept { + return _with_alloc([&](sz_alloc_type &alloc) { + auto old_size = size(); + sz_ptr_t start = sz_string_expand(&string_, SZ_SIZE_MAX, 1, &alloc); + if (!start) return false; + start[old_size] = c; + return true; + }); +} + +template +bool basic_string::try_append(const_pointer str, size_type length) noexcept { + return _with_alloc([&](sz_alloc_type &alloc) { + auto old_size = size(); + sz_ptr_t start = sz_string_expand(&string_, SZ_SIZE_MAX, length, &alloc); + if (!start) return false; + sz_copy(start + old_size, str, length); + return true; + }); +} + +template +template +bool basic_string::try_replace_all_(pattern_type pattern, string_view replacement) noexcept { + // Depending on the size of the pattern and the replacement, we may need to allocate more space. + // There are 3 cases to consider: + // 1. 
The pattern and the replacement are of the same length. Piece of cake! + // 2. The pattern is longer than the replacement. We need to compact the strings. + // 3. The pattern is shorter than the replacement. We may have to allocate more memory. + using matcher_type = typename std::conditional::value, + matcher_find_first_of, + matcher_find>::type; + matcher_type matcher({pattern}); + string_view this_view = view(); + + // 1. The pattern and the replacement are of the same length. + if (matcher.needle_length() == replacement.length()) { + using matches_type = range_matches; + // Instead of iterating with `begin()` and `end()`, we could use the cheaper sentinel-based approach. + // for (string_view match : matches) { ... } + matches_type matches = matches_type(this_view, {pattern}); + for (auto matches_iterator = matches.begin(); matches_iterator != end_sentinel_type {}; ++matches_iterator) { + replacement.copy(const_cast((*matches_iterator).data())); + } + return true; + } + + // 2. The pattern is longer than the replacement. We need to compact the strings. + else if (matcher.needle_length() > replacement.length()) { + // Dealing with shorter replacements, we will avoid memory allocations, but we can also mimnimize the number + // of `memmove`-s, by keeping one more iterator, pointing to the end of the last compacted area. + // Having the split-ranges, however, we reuse their logic. + using splits_type = range_splits; + splits_type splits = splits_type(this_view, {pattern}); + auto matches_iterator = splits.begin(); + auto compacted_end = (*matches_iterator).end(); + if (compacted_end == end()) return true; // No matches. + + ++matches_iterator; // Skip the first match. 
+ do { + string_view match_view = *matches_iterator; + replacement.copy(const_cast(compacted_end)); + compacted_end += replacement.length(); + sz_move((sz_ptr_t)compacted_end, match_view.begin(), match_view.length()); + compacted_end += match_view.length(); + ++matches_iterator; + } while (matches_iterator != end_sentinel_type {}); + + // Can't fail, so let's just return true :) + try_resize(compacted_end - begin()); + return true; + } + + // 3. The pattern is shorter than the replacement. We may have to allocate more memory. + else { + using rmatcher_type = typename std::conditional::value, + matcher_find_last_of, + matcher_rfind>::type; + using rmatches_type = range_rmatches; + rmatches_type rmatches = rmatches_type(this_view, {pattern}); + + // It's cheaper to iterate through the whole string once, counting the number of matches, + // reserving memory once, than re-allocating and copying the string multiple times. + auto matches_count = rmatches.size(); + if (matches_count == 0) return true; // No matches. + + // TODO: Resize without initializing the memory. + auto replacement_delta_length = replacement.length() - matcher.needle_length(); + auto added_length = matches_count * replacement_delta_length; + auto old_length = size(); + auto new_length = old_length + added_length; + if (!try_resize(new_length)) return false; + this_view = view().front(old_length); + + // Now iterate through splits similarly to the 2nd case, but in reverse order. + using rsplits_type = range_rsplits; + rsplits_type splits = rsplits_type(this_view, {pattern}); + auto splits_iterator = splits.begin(); + + // Put the compacted pointer to the end of the new string, and walk left. + auto compacted_begin = this_view.data() + new_length; + + // By now we know that at least one match exists, which means the splits . 
+ do { + string_view slice_view = *splits_iterator; + compacted_begin -= slice_view.length(); + sz_move((sz_ptr_t)compacted_begin, slice_view.begin(), slice_view.length()); + compacted_begin -= replacement.length(); + replacement.copy(const_cast(compacted_begin)); + ++splits_iterator; + } while (!splits_iterator.is_last()); + + return true; + } +} + +template +template +bool basic_string::try_assign(concatenation const &other) noexcept { + // We can't just assign the other string state, as its start address may be somewhere else on the stack. + sz_ptr_t string_start; + sz_size_t string_length; + sz_string_range(&string_, &string_start, &string_length); + + if (string_length >= other.length()) { + sz_string_erase(&string_, other.length(), SZ_SIZE_MAX); + other.copy(string_start, other.length()); + } + else { + if (!_with_alloc([&](sz_alloc_type &alloc) { + string_start = sz_string_expand(&string_, SZ_SIZE_MAX, other.length(), &alloc); + if (!string_start) return false; + other.copy(string_start, other.length()); + return true; + })) + return false; + } + return true; +} + +template +bool basic_string::try_preparing_replacement(size_type offset, size_type length, + size_type replacement_length) noexcept { + // There are three cases: + // 1. The replacement is the same length as the replaced range. + // 2. The replacement is shorter than the replaced range. + // 3. The replacement is longer than the replaced range. An allocation may occur. + assert(offset + length <= size()); + + // 1. The replacement is the same length as the replaced range. + if (replacement_length == length) { return true; } + + // 2. The replacement is shorter than the replaced range. + else if (replacement_length < length) { + sz_string_erase(&string_, offset + replacement_length, length - replacement_length); + return true; + } + // 3. The replacement is longer than the replaced range. An allocation may occur. 
+ else { + return _with_alloc([&](sz_alloc_type &alloc) { + return sz_string_expand(&string_, offset + length, replacement_length - length, &alloc); + }); + } +} + +/** @brief SFINAE-type used to infer the resulting type of concatenating multiple string together. */ +template +struct concatenation_result {}; + +template +struct concatenation_result { + using type = concatenation; +}; + +template +struct concatenation_result { + using type = concatenation::type>; +}; + +/** + * @brief Concatenates two strings into a template expression. + * @see `concatenation` class for more details. + */ +template +concatenation concatenate(first_type &&first, second_type &&second) noexcept(false) { + return {first, second}; +} + +/** + * @brief Concatenates two or more strings into a template expression. + * @see `concatenation` class for more details. + */ +template +typename concatenation_result::type concatenate( + first_type &&first, second_type &&second, following_types &&...following) noexcept(false) { + // Fold expression like the one below would result in faster compile times, + // but would incur the penalty of additional `if`-statements in every `append` call. + // Moreover, those are only supported in C++17 and later. + // std::size_t total_size = (strings.size() + ... + 0); + // std::string result; + // result.reserve(total_size); + // (result.append(strings), ...); + return ashvardanian::stringzilla::concatenate( + std::forward(first), + ashvardanian::stringzilla::concatenate(std::forward(second), + std::forward(following)...)); +} + +/** + * @brief Calculates the Levenshtein edit distance between two strings. 
+ * @see sz_edit_distance + */ +template ::type>> +std::size_t edit_distance(basic_string_slice const &a, basic_string_slice const &b, + allocator_type_ &&allocator = allocator_type_ {}) noexcept(false) { + std::size_t result; + if (!_with_alloc(allocator, [&](sz_memory_allocator_t &alloc) { + result = sz_edit_distance(a.data(), a.size(), b.data(), b.size(), SZ_SIZE_MAX, &alloc); + return result != SZ_SIZE_MAX; + })) + throw std::bad_alloc(); + return result; +} + +/** + * @brief Calculates the Levenshtein edit distance between two strings. + * @see sz_edit_distance + */ +template > +std::size_t edit_distance(basic_string const &a, + basic_string const &b) noexcept(false) { + return ashvardanian::stringzilla::edit_distance(a.view(), b.view(), a.get_allocator()); +} + +/** + * @brief Calculates the Needleman-Wunsch alignment score between two strings. + * @see sz_alignment_score + */ +template ::type>> +std::ptrdiff_t alignment_score(basic_string_slice const &a, basic_string_slice const &b, + std::int8_t const (&subs)[256][256], std::int8_t gap = -1, + allocator_type_ &&allocator = allocator_type_ {}) noexcept(false) { + + static_assert(sizeof(sz_error_cost_t) == sizeof(std::int8_t), "sz_error_cost_t must be 8-bit."); + static_assert(std::is_signed() == std::is_signed(), + "sz_error_cost_t must be signed."); + + std::ptrdiff_t result; + if (!_with_alloc(allocator, [&](sz_memory_allocator_t &alloc) { + result = sz_alignment_score(a.data(), a.size(), b.data(), b.size(), &subs[0][0], gap, &alloc); + return result != SZ_SSIZE_MAX; + })) + throw std::bad_alloc(); + return result; +} + +/** + * @brief Calculates the Needleman-Wunsch alignment score between two strings. 
+ * @see sz_alignment_score + */ +template > +std::ptrdiff_t alignment_score(basic_string const &a, + basic_string const &b, // + std::int8_t const (&subs)[256][256], std::int8_t gap = -1) noexcept(false) { + return ashvardanian::stringzilla::alignment_score(a.view(), b.view(), subs, gap, a.get_allocator()); +} + +/** + * @brief Overwrites the string slice with random characters from the given alphabet using the random generator. + * + * @param string The string to overwrite. + * @param generator A random generator function object that returns a random number in the range [0, 2^64). + * @param alphabet A string of characters to choose from. + */ +template +void randomize(basic_string_slice string, generator_type_ &generator, + string_view alphabet = "abcdefghijklmnopqrstuvwxyz") noexcept { + static_assert(!std::is_const::value, "The string must be mutable."); + sz_random_generator_t generator_callback = &_call_random_generator; + sz_generate(alphabet.data(), alphabet.size(), string.data(), string.size(), generator_callback, &generator); +} + +/** + * @brief Overwrites the string slice with random characters from the given alphabet + * using `std::rand` as the random generator. + * + * @param string The string to overwrite. + * @param alphabet A string of characters to choose from. + */ +template +void randomize(basic_string_slice string, string_view alphabet = "abcdefghijklmnopqrstuvwxyz") noexcept { + randomize(string, &std::rand, alphabet); +} + +/** + * @brief Internal data-structure used to forward the arguments to the `sz_sort` function. 
+ * @see sorted_order + */ +template +struct _sequence_args { + objects_type_ const *begin; + std::size_t count; + std::size_t *order; + string_extractor_ extractor; +}; + +template +sz_cptr_t _call_sequence_member_start(struct sz_sequence_t const *sequence, sz_size_t i) { + using handle_type = _sequence_args; + handle_type const *args = reinterpret_cast(sequence->handle); + string_view member = args->extractor(args->begin[i]); + return member.data(); +} + +template +sz_size_t _call_sequence_member_length(struct sz_sequence_t const *sequence, sz_size_t i) { + using handle_type = _sequence_args; + handle_type const *args = reinterpret_cast(sequence->handle); + string_view member = args->extractor(args->begin[i]); + return static_cast(member.size()); +} + +/** + * @brief Computes the permutation of an array, that would lead to sorted order. + * The elements of the array must be convertible to a `string_view` with the given extractor. + * Unlike the `sz_sort` C interface, overwrites the output array. + * + * @param[in] begin The pointer to the first element of the array. + * @param[in] end The pointer to the element after the last element of the array. + * @param[out] order The pointer to the output array of indices, that will be populated with the permutation. + * @param[in] extractor The function object that extracts the string from the object. + * + * @see sz_sort + */ +template +void sorted_order(objects_type_ const *begin, objects_type_ const *end, std::size_t *order, + string_extractor_ &&extractor) noexcept { + + // Pack the arguments into a single structure to reference it from the callback. + _sequence_args args = {begin, static_cast(end - begin), order, + std::forward(extractor)}; + // Populate the array with `iota`-style order. 
+ for (std::size_t i = 0; i != args.count; ++i) order[i] = i; + + sz_sequence_t array; + array.order = reinterpret_cast(order); + array.count = args.count; + array.handle = &args; + array.get_start = _call_sequence_member_start; + array.get_length = _call_sequence_member_length; + sz_sort(&array); +} + +#if !SZ_AVOID_STL + +/** + * @brief Computes the Rabin-Karp-like rolling binary fingerprint of a string. + * @see sz_hashes + */ +template +void hashes_fingerprint(basic_string_slice const &str, std::size_t window_length, + std::bitset &fingerprint) noexcept { + constexpr std::size_t fingerprint_bytes = sizeof(std::bitset); + return sz_hashes_fingerprint(str.data(), str.size(), window_length, (sz_ptr_t)&fingerprint, fingerprint_bytes); +} + +/** + * @brief Computes the Rabin-Karp-like rolling binary fingerprint of a string. + * @see sz_hashes + */ +template +std::bitset hashes_fingerprint(basic_string_slice const &str, + std::size_t window_length) noexcept { + std::bitset fingerprint; + ashvardanian::stringzilla::hashes_fingerprint(str, window_length, fingerprint); + return fingerprint; +} + +/** + * @brief Computes the Rabin-Karp-like rolling binary fingerprint of a string. + * @see sz_hashes + */ +template +std::bitset hashes_fingerprint(basic_string const &str, std::size_t window_length) noexcept { + return ashvardanian::stringzilla::hashes_fingerprint(str.view(), window_length); +} + +/** + * @brief Computes the permutation of an array, that would lead to sorted order. + * @return The array of indices, that will be populated with the permutation. + * @throw `std::bad_alloc` if the allocation fails. + */ +template +std::vector sorted_order(objects_type_ const *begin, objects_type_ const *end, + string_extractor_ &&extractor) noexcept(false) { + std::vector order(end - begin); + sorted_order(begin, end, order.data(), std::forward(extractor)); + return order; +} + +/** + * @brief Computes the permutation of an array, that would lead to sorted order. 
+ * @return The array of indices, that will be populated with the permutation. + * @throw `std::bad_alloc` if the allocation fails. + */ +template +std::vector sorted_order(string_like_type_ const *begin, string_like_type_ const *end) noexcept(false) { + static_assert(std::is_convertible::value, + "The type must be convertible to string_view."); + return sorted_order(begin, end, [](string_like_type_ const &s) -> string_view { return s; }); +} + +/** + * @brief Computes the permutation of an array, that would lead to sorted order. + * @return The array of indices, that will be populated with the permutation. + * @throw `std::bad_alloc` if the allocation fails. + */ +template +std::vector sorted_order(std::vector const &array) noexcept(false) { + static_assert(std::is_convertible::value, + "The type must be convertible to string_view."); + return sorted_order(array.data(), array.data() + array.size(), + [](string_like_type_ const &s) -> string_view { return s; }); +} + +#endif + +} // namespace stringzilla +} // namespace ashvardanian + +#pragma region STL Specializations + +namespace std { + +template <> +struct hash { + size_t operator()(ashvardanian::stringzilla::string_view str) const noexcept { return str.hash(); } +}; + +template <> +struct hash { + size_t operator()(ashvardanian::stringzilla::string const &str) const noexcept { return str.hash(); } +}; + +} // namespace std + +#pragma endregion + +#endif // STRINGZILLA_HPP_ diff --git a/contrib/stringzilla/lib.c b/contrib/stringzilla/lib.c new file mode 100644 index 0000000000..61cfed97c0 --- /dev/null +++ b/contrib/stringzilla/lib.c @@ -0,0 +1,304 @@ +/** + * @file lib.c + * @brief StringZilla C library with dynamic backed dispatch for the most appropriate implementation. 
+ * @author Ash Vardanian + * @date January 16, 2024 + * @copyright Copyright (c) 2024 + */ +#if defined(_WIN32) || defined(__CYGWIN__) +#include // `DllMain` +#endif + +// If we don't have the LibC, the `malloc` definition in `stringzilla.h` will be illformed. +#if SZ_AVOID_LIBC +typedef __SIZE_TYPE__ size_t; +#endif + +// Overwrite `SZ_DYNAMIC_DISPATCH` before including StringZilla. +#ifdef SZ_DYNAMIC_DISPATCH +#undef SZ_DYNAMIC_DISPATCH +#endif +#define SZ_DYNAMIC_DISPATCH 1 +#include + +#if SZ_AVOID_LIBC +void free(void *start) { sz_unused(start); } +void *malloc(size_t length) { + sz_unused(length); + return SZ_NULL; +} +#endif + +SZ_DYNAMIC sz_capability_t sz_capabilities(void) { + +#if SZ_USE_X86_AVX512 || SZ_USE_X86_AVX2 + + /// The states of 4 registers populated for a specific "cpuid" assembly call + union four_registers_t { + int array[4]; + struct separate_t { + unsigned eax, ebx, ecx, edx; + } named; + } info1, info7; + +#ifdef _MSC_VER + __cpuidex(info1.array, 1, 0); + __cpuidex(info7.array, 7, 0); +#else + __asm__ __volatile__("cpuid" + : "=a"(info1.named.eax), "=b"(info1.named.ebx), "=c"(info1.named.ecx), "=d"(info1.named.edx) + : "a"(1), "c"(0)); + __asm__ __volatile__("cpuid" + : "=a"(info7.named.eax), "=b"(info7.named.ebx), "=c"(info7.named.ecx), "=d"(info7.named.edx) + : "a"(7), "c"(0)); +#endif + + // Check for AVX2 (Function ID 7, EBX register) + // https://github.com/llvm/llvm-project/blob/50598f0ff44f3a4e75706f8c53f3380fe7faa896/clang/lib/Headers/cpuid.h#L148 + unsigned supports_avx2 = (info7.named.ebx & 0x00000020) != 0; + // Check for AVX512F (Function ID 7, EBX register) + // https://github.com/llvm/llvm-project/blob/50598f0ff44f3a4e75706f8c53f3380fe7faa896/clang/lib/Headers/cpuid.h#L155 + unsigned supports_avx512f = (info7.named.ebx & 0x00010000) != 0; + // Check for AVX512BW (Function ID 7, EBX register) + // https://github.com/llvm/llvm-project/blob/50598f0ff44f3a4e75706f8c53f3380fe7faa896/clang/lib/Headers/cpuid.h#L166 + unsigned 
supports_avx512bw = (info7.named.ebx & 0x40000000) != 0; + // Check for AVX512VL (Function ID 7, EBX register) + // https://github.com/llvm/llvm-project/blob/50598f0ff44f3a4e75706f8c53f3380fe7faa896/clang/lib/Headers/cpuid.h#L167C25-L167C35 + unsigned supports_avx512vl = (info7.named.ebx & 0x80000000) != 0; + // Check for GFNI (Function ID 1, ECX register) + // https://github.com/llvm/llvm-project/blob/50598f0ff44f3a4e75706f8c53f3380fe7faa896/clang/lib/Headers/cpuid.h#L171C30-L171C40 + unsigned supports_avx512vbmi = (info1.named.ecx & 0x00000002) != 0; + // Check for GFNI (Function ID 1, ECX register) + // https://github.com/llvm/llvm-project/blob/50598f0ff44f3a4e75706f8c53f3380fe7faa896/clang/lib/Headers/cpuid.h#L177C30-L177C40 + unsigned supports_gfni = (info1.named.ecx & 0x00000100) != 0; + + return (sz_capability_t)( // + (sz_cap_x86_avx2_k * supports_avx2) | // + (sz_cap_x86_avx512f_k * supports_avx512f) | // + (sz_cap_x86_avx512vl_k * supports_avx512vl) | // + (sz_cap_x86_avx512bw_k * supports_avx512bw) | // + (sz_cap_x86_avx512vbmi_k * supports_avx512vbmi) | // + (sz_cap_x86_gfni_k * (supports_gfni)) | // + (sz_cap_serial_k)); + +#endif // SIMSIMD_TARGET_X86 + +#if SZ_USE_ARM_NEON || SZ_USE_ARM_SVE + + // Every 64-bit Arm CPU supports NEON + unsigned supports_neon = 1; + unsigned supports_sve = 0; + unsigned supports_sve2 = 0; + sz_unused(supports_sve); + sz_unused(supports_sve2); + + return (sz_capability_t)( // + (sz_cap_arm_neon_k * supports_neon) | // + (sz_cap_serial_k)); + +#endif // SIMSIMD_TARGET_ARM + + return sz_cap_serial_k; +} + +typedef struct sz_implementations_t { + sz_equal_t equal; + sz_order_t order; + + sz_move_t copy; + sz_move_t move; + sz_fill_t fill; + + sz_find_byte_t find_byte; + sz_find_byte_t rfind_byte; + sz_find_t find; + sz_find_t rfind; + sz_find_set_t find_from_set; + sz_find_set_t rfind_from_set; + + // TODO: Upcoming vectorization + sz_edit_distance_t edit_distance; + sz_alignment_score_t alignment_score; + sz_hashes_t 
hashes; + +} sz_implementations_t; +static sz_implementations_t sz_dispatch_table; + +/** + * @brief Initializes a global static "virtual table" of supported backends + * Run it just once to avoiding unnecessary `if`-s. + */ +static void sz_dispatch_table_init(void) { + sz_implementations_t *impl = &sz_dispatch_table; + sz_capability_t caps = sz_capabilities(); + sz_unused(caps); //< Unused when compiling on pre-SIMD machines. + + impl->equal = sz_equal_serial; + impl->order = sz_order_serial; + impl->copy = sz_copy_serial; + impl->move = sz_move_serial; + impl->fill = sz_fill_serial; + + impl->find = sz_find_serial; + impl->rfind = sz_rfind_serial; + impl->find_byte = sz_find_byte_serial; + impl->rfind_byte = sz_rfind_byte_serial; + impl->find_from_set = sz_find_charset_serial; + impl->rfind_from_set = sz_rfind_charset_serial; + + impl->edit_distance = sz_edit_distance_serial; + impl->alignment_score = sz_alignment_score_serial; + impl->hashes = sz_hashes_serial; + +#if SZ_USE_X86_AVX2 + if (caps & sz_cap_x86_avx2_k) { + impl->copy = sz_copy_avx2; + impl->move = sz_move_avx2; + impl->fill = sz_fill_avx2; + impl->find_byte = sz_find_byte_avx2; + impl->rfind_byte = sz_rfind_byte_avx2; + impl->find = sz_find_avx2; + impl->rfind = sz_rfind_avx2; + } +#endif + +#if SZ_USE_X86_AVX512 + if (caps & sz_cap_x86_avx512f_k) { + impl->equal = sz_equal_avx512; + impl->order = sz_order_avx512; + impl->copy = sz_copy_avx512; + impl->move = sz_move_avx512; + impl->fill = sz_fill_avx512; + + impl->find = sz_find_avx512; + impl->rfind = sz_rfind_avx512; + impl->find_byte = sz_find_byte_avx512; + impl->rfind_byte = sz_rfind_byte_avx512; + + impl->edit_distance = sz_edit_distance_avx512; + } + + if ((caps & sz_cap_x86_avx512f_k) && (caps & sz_cap_x86_avx512vl_k) && (caps & sz_cap_x86_gfni_k) && + (caps & sz_cap_x86_avx512bw_k) && (caps & sz_cap_x86_avx512vbmi_k)) { + impl->find_from_set = sz_find_charset_avx512; + impl->rfind_from_set = sz_rfind_charset_avx512; + impl->alignment_score 
= sz_alignment_score_avx512; + } +#endif + +#if SZ_USE_ARM_NEON + if (caps & sz_cap_arm_neon_k) { + impl->find = sz_find_neon; + impl->rfind = sz_rfind_neon; + impl->find_byte = sz_find_byte_neon; + impl->rfind_byte = sz_rfind_byte_neon; + impl->find_from_set = sz_find_charset_neon; + impl->rfind_from_set = sz_rfind_charset_neon; + } +#endif +} + +#if defined(_MSC_VER) +BOOL WINAPI DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpReserved) { + switch (fdwReason) { + case DLL_PROCESS_ATTACH: sz_dispatch_table_init(); return TRUE; + case DLL_THREAD_ATTACH: return TRUE; + case DLL_THREAD_DETACH: return TRUE; + case DLL_PROCESS_DETACH: return TRUE; + } +} +#else +__attribute__((constructor)) static void sz_dispatch_table_init_on_gcc_or_clang(void) { sz_dispatch_table_init(); } +#endif + +SZ_DYNAMIC sz_bool_t sz_equal(sz_cptr_t a, sz_cptr_t b, sz_size_t length) { + return sz_dispatch_table.equal(a, b, length); +} + +SZ_DYNAMIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) { + return sz_dispatch_table.order(a, a_length, b, b_length); +} + +SZ_DYNAMIC void sz_copy(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { + sz_dispatch_table.copy(target, source, length); +} + +SZ_DYNAMIC void sz_move(sz_ptr_t target, sz_cptr_t source, sz_size_t length) { + sz_dispatch_table.move(target, source, length); +} + +SZ_DYNAMIC void sz_fill(sz_ptr_t target, sz_size_t length, sz_u8_t value) { + sz_dispatch_table.fill(target, length, value); +} + +SZ_DYNAMIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { + return sz_dispatch_table.find_byte(haystack, h_length, needle); +} + +SZ_DYNAMIC sz_cptr_t sz_rfind_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) { + return sz_dispatch_table.rfind_byte(haystack, h_length, needle); +} + +SZ_DYNAMIC sz_cptr_t sz_find(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { + return sz_dispatch_table.find(haystack, h_length, 
needle, n_length); +} + +SZ_DYNAMIC sz_cptr_t sz_rfind(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle, sz_size_t n_length) { + return sz_dispatch_table.rfind(haystack, h_length, needle, n_length); +} + +SZ_DYNAMIC sz_cptr_t sz_find_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { + return sz_dispatch_table.find_from_set(text, length, set); +} + +SZ_DYNAMIC sz_cptr_t sz_rfind_charset(sz_cptr_t text, sz_size_t length, sz_charset_t const *set) { + return sz_dispatch_table.rfind_from_set(text, length, set); +} + +SZ_DYNAMIC sz_size_t sz_edit_distance( // + sz_cptr_t a, sz_size_t a_length, // + sz_cptr_t b, sz_size_t b_length, // + sz_size_t bound, sz_memory_allocator_t *alloc) { + return sz_dispatch_table.edit_distance(a, a_length, b, b_length, bound, alloc); +} + +SZ_DYNAMIC sz_ssize_t sz_alignment_score(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, + sz_error_cost_t const *subs, sz_error_cost_t gap, + sz_memory_allocator_t *alloc) { + return sz_dispatch_table.alignment_score(a, a_length, b, b_length, subs, gap, alloc); +} + +SZ_DYNAMIC void sz_hashes(sz_cptr_t text, sz_size_t length, sz_size_t window_length, sz_size_t step, // + sz_hash_callback_t callback, void *callback_handle) { + sz_dispatch_table.hashes(text, length, window_length, step, callback, callback_handle); +} + +SZ_DYNAMIC sz_cptr_t sz_find_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + sz_charset_t set; + sz_charset_init(&set); + for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); + return sz_find_charset(h, h_length, &set); +} + +SZ_DYNAMIC sz_cptr_t sz_find_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + sz_charset_t set; + sz_charset_init(&set); + for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); + sz_charset_invert(&set); + return sz_find_charset(h, h_length, &set); +} + +SZ_DYNAMIC sz_cptr_t sz_rfind_char_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, 
sz_size_t n_length) { + sz_charset_t set; + sz_charset_init(&set); + for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); + return sz_rfind_charset(h, h_length, &set); +} + +SZ_DYNAMIC sz_cptr_t sz_rfind_char_not_from(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) { + sz_charset_t set; + sz_charset_init(&set); + for (; n_length; ++n, --n_length) sz_charset_add(&set, *n); + sz_charset_invert(&set); + return sz_rfind_charset(h, h_length, &set); +}