]> git.ipfire.org Git - thirdparty/postgresql.git/commitdiff
Optimize hex_encode() and hex_decode() using SIMD.
authorNathan Bossart <nathan@postgresql.org>
Mon, 6 Oct 2025 17:28:50 +0000 (12:28 -0500)
committerNathan Bossart <nathan@postgresql.org>
Mon, 6 Oct 2025 17:28:50 +0000 (12:28 -0500)
The hex_encode() and hex_decode() functions serve as the workhorses
for hexadecimal data for bytea's text format conversion functions,
and some workloads are sensitive to their performance.  This commit
adds new implementations that use routines from port/simd.h, which
testing indicates are much faster for larger inputs.  For small or
invalid inputs, we fall back on the existing scalar versions.
Since we are using port/simd.h, these optimizations apply to both
x86-64 and AArch64.

Author: Nathan Bossart <nathandbossart@gmail.com>
Co-authored-by: Chiranmoy Bhattacharya <chiranmoy.bhattacharya@fujitsu.com>
Co-authored-by: Susmitha Devanga <devanga.susmitha@fujitsu.com>
Reviewed-by: John Naylor <johncnaylorls@gmail.com>
Discussion: https://postgr.es/m/aLhVWTRy0QPbW2tl%40nathan

src/backend/utils/adt/encode.c
src/include/port/simd.h
src/test/regress/expected/strings.out
src/test/regress/sql/strings.sql

index 9a9c7e8da99d4135ee15b74aae197bd964fd4953..aabe9913eee3c00e50366a47965ff0724df0bcad 100644 (file)
@@ -16,6 +16,7 @@
 #include <ctype.h>
 
 #include "mb/pg_wchar.h"
+#include "port/simd.h"
 #include "utils/builtins.h"
 #include "utils/memutils.h"
 #include "varatt.h"
@@ -177,8 +178,8 @@ static const int8 hexlookup[128] = {
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
 };
 
-uint64
-hex_encode(const char *src, size_t len, char *dst)
+static inline uint64
+hex_encode_scalar(const char *src, size_t len, char *dst)
 {
        const char *end = src + len;
 
@@ -193,6 +194,55 @@ hex_encode(const char *src, size_t len, char *dst)
        return (uint64) len * 2;
 }
 
+uint64
+hex_encode(const char *src, size_t len, char *dst)
+{
+#ifdef USE_NO_SIMD
+       return hex_encode_scalar(src, len, dst);
+#else
+       const uint64 tail_idx = len & ~(sizeof(Vector8) - 1);
+       uint64          i;
+
+       /*
+        * This splits the high and low nibbles of each byte into separate
+        * vectors, adds the vectors to a mask that converts the nibbles to their
+        * equivalent ASCII bytes, and interleaves those bytes back together to
+        * form the final hex-encoded string.
+        */
+       for (i = 0; i < tail_idx; i += sizeof(Vector8))
+       {
+               Vector8         srcv;
+               Vector8         lo;
+               Vector8         hi;
+               Vector8         mask;
+
+               vector8_load(&srcv, (const uint8 *) &src[i]);
+
+               lo = vector8_and(srcv, vector8_broadcast(0x0f));
+               mask = vector8_gt(lo, vector8_broadcast(0x9));
+               mask = vector8_and(mask, vector8_broadcast('a' - '0' - 10));
+               mask = vector8_add(mask, vector8_broadcast('0'));
+               lo = vector8_add(lo, mask);
+
+               hi = vector8_and(srcv, vector8_broadcast(0xf0));
+               hi = vector8_shift_right(hi, 4);
+               mask = vector8_gt(hi, vector8_broadcast(0x9));
+               mask = vector8_and(mask, vector8_broadcast('a' - '0' - 10));
+               mask = vector8_add(mask, vector8_broadcast('0'));
+               hi = vector8_add(hi, mask);
+
+               vector8_store((uint8 *) &dst[i * 2],
+                                         vector8_interleave_low(hi, lo));
+               vector8_store((uint8 *) &dst[i * 2 + sizeof(Vector8)],
+                                         vector8_interleave_high(hi, lo));
+       }
+
+       (void) hex_encode_scalar(src + i, len - i, dst + i * 2);
+
+       return (uint64) len * 2;
+#endif
+}
+
 static inline bool
 get_hex(const char *cp, char *out)
 {
@@ -213,8 +263,8 @@ hex_decode(const char *src, size_t len, char *dst)
        return hex_decode_safe(src, len, dst, NULL);
 }
 
-uint64
-hex_decode_safe(const char *src, size_t len, char *dst, Node *escontext)
+static inline uint64
+hex_decode_safe_scalar(const char *src, size_t len, char *dst, Node *escontext)
 {
        const char *s,
                           *srcend;
@@ -254,6 +304,85 @@ hex_decode_safe(const char *src, size_t len, char *dst, Node *escontext)
        return p - dst;
 }
 
+/*
+ * This helper converts each byte to its binary-equivalent nibble by
+ * subtraction and combines them to form the return bytes (separated by zero
+ * bytes).  Returns false if any input bytes are outside the expected ranges of
+ * ASCII values.  Otherwise, returns true.
+ */
+#ifndef USE_NO_SIMD
+static inline bool
+hex_decode_simd_helper(const Vector8 src, Vector8 *dst)
+{
+       Vector8         sub;
+       Vector8         mask_hi = vector8_interleave_low(vector8_broadcast(0), vector8_broadcast(0x0f));
+       Vector8         mask_lo = vector8_interleave_low(vector8_broadcast(0x0f), vector8_broadcast(0));
+       Vector8         tmp;
+       bool            ret;
+
+       tmp = vector8_gt(vector8_broadcast('9' + 1), src);
+       sub = vector8_and(tmp, vector8_broadcast('0'));
+
+       tmp = vector8_gt(src, vector8_broadcast('A' - 1));
+       tmp = vector8_and(tmp, vector8_broadcast('A' - 10));
+       sub = vector8_add(sub, tmp);
+
+       tmp = vector8_gt(src, vector8_broadcast('a' - 1));
+       tmp = vector8_and(tmp, vector8_broadcast('a' - 'A'));
+       sub = vector8_add(sub, tmp);
+
+       *dst = vector8_issub(src, sub);
+       ret = !vector8_has_ge(*dst, 0x10);
+
+       tmp = vector8_and(*dst, mask_hi);
+       tmp = vector8_shift_right(tmp, 8);
+       *dst = vector8_and(*dst, mask_lo);
+       *dst = vector8_shift_left(*dst, 4);
+       *dst = vector8_or(*dst, tmp);
+       return ret;
+}
+#endif                                                 /* ! USE_NO_SIMD */
+
+uint64
+hex_decode_safe(const char *src, size_t len, char *dst, Node *escontext)
+{
+#ifdef USE_NO_SIMD
+       return hex_decode_safe_scalar(src, len, dst, escontext);
+#else
+       const uint64 tail_idx = len & ~(sizeof(Vector8) * 2 - 1);
+       uint64          i;
+       bool            success = true;
+
+       /*
+        * We must process 2 vectors at a time since the output will be half the
+        * length of the input.
+        */
+       for (i = 0; i < tail_idx; i += sizeof(Vector8) * 2)
+       {
+               Vector8         srcv;
+               Vector8         dstv1;
+               Vector8         dstv2;
+
+               vector8_load(&srcv, (const uint8 *) &src[i]);
+               success &= hex_decode_simd_helper(srcv, &dstv1);
+
+               vector8_load(&srcv, (const uint8 *) &src[i + sizeof(Vector8)]);
+               success &= hex_decode_simd_helper(srcv, &dstv2);
+
+               vector8_store((uint8 *) &dst[i / 2], vector8_pack_16(dstv1, dstv2));
+       }
+
+       /*
+        * If something didn't look right in the vector path, try again in the
+        * scalar path so that we can handle it correctly.
+        */
+       if (!success)
+               i = 0;
+
+       return i / 2 + hex_decode_safe_scalar(src + i, len - i, dst + i / 2, escontext);
+#endif
+}
+
 static uint64
 hex_enc_len(const char *src, size_t srclen)
 {
index 5f5737707a89b4a6ce34c1f172b3c51245477d03..b0165b458617b554f01611e4c88a4282738095c9 100644 (file)
@@ -127,6 +127,21 @@ vector32_load(Vector32 *v, const uint32 *s)
 }
 #endif                                                 /* ! USE_NO_SIMD */
 
+/*
+ * Store a vector into the given memory address.
+ */
+#ifndef USE_NO_SIMD
+static inline void
+vector8_store(uint8 *s, Vector8 v)
+{
+#ifdef USE_SSE2
+       _mm_storeu_si128((Vector8 *) s, v);
+#elif defined(USE_NEON)
+       vst1q_u8(s, v);
+#endif
+}
+#endif                                                 /* ! USE_NO_SIMD */
+
 /*
  * Create a vector with all elements set to the same value.
  */
@@ -265,6 +280,25 @@ vector8_has_le(const Vector8 v, const uint8 c)
        return result;
 }
 
+/*
+ * Returns true if any elements in the vector are greater than or equal to the
+ * given scalar.
+ */
+#ifndef USE_NO_SIMD
+static inline bool
+vector8_has_ge(const Vector8 v, const uint8 c)
+{
+#ifdef USE_SSE2
+       Vector8         umax = _mm_max_epu8(v, vector8_broadcast(c));
+       Vector8         cmpe = vector8_eq(umax, v);
+
+       return vector8_is_highbit_set(cmpe);
+#elif defined(USE_NEON)
+       return vmaxvq_u8(v) >= c;
+#endif
+}
+#endif                                                 /* ! USE_NO_SIMD */
+
 /*
  * Return true if the high bit of any element is set
  */
@@ -359,6 +393,55 @@ vector32_or(const Vector32 v1, const Vector32 v2)
 }
 #endif                                                 /* ! USE_NO_SIMD */
 
+/*
+ * Return the bitwise AND of the inputs.
+ */
+#ifndef USE_NO_SIMD
+static inline Vector8
+vector8_and(const Vector8 v1, const Vector8 v2)
+{
+#ifdef USE_SSE2
+       return _mm_and_si128(v1, v2);
+#elif defined(USE_NEON)
+       return vandq_u8(v1, v2);
+#endif
+}
+#endif                                                 /* ! USE_NO_SIMD */
+
+/*
+ * Return the result of adding the respective elements of the input vectors.
+ */
+#ifndef USE_NO_SIMD
+static inline Vector8
+vector8_add(const Vector8 v1, const Vector8 v2)
+{
+#ifdef USE_SSE2
+       return _mm_add_epi8(v1, v2);
+#elif defined(USE_NEON)
+       return vaddq_u8(v1, v2);
+#endif
+}
+#endif                                                 /* ! USE_NO_SIMD */
+
+/*
+ * Return the result of subtracting the respective elements of the input
+ * vectors using signed saturation (i.e., if the operation would yield a value
+ * less than -128, -128 is returned instead).  For more information on
+ * saturation arithmetic, see
+ * https://en.wikipedia.org/wiki/Saturation_arithmetic
+ */
+#ifndef USE_NO_SIMD
+static inline Vector8
+vector8_issub(const Vector8 v1, const Vector8 v2)
+{
+#ifdef USE_SSE2
+       return _mm_subs_epi8(v1, v2);
+#elif defined(USE_NEON)
+       return (Vector8) vqsubq_s8((int8x16_t) v1, (int8x16_t) v2);
+#endif
+}
+#endif                                                 /* ! USE_NO_SIMD */
+
 /*
  * Return a vector with all bits set in each lane where the corresponding
  * lanes in the inputs are equal.
@@ -387,6 +470,23 @@ vector32_eq(const Vector32 v1, const Vector32 v2)
 }
 #endif                                                 /* ! USE_NO_SIMD */
 
+/*
+ * Return a vector with all bits set for each lane of v1 that is greater than
+ * the corresponding lane of v2.  NB: The comparison treats the elements as
+ * signed.
+ */
+#ifndef USE_NO_SIMD
+static inline Vector8
+vector8_gt(const Vector8 v1, const Vector8 v2)
+{
+#ifdef USE_SSE2
+       return _mm_cmpgt_epi8(v1, v2);
+#elif defined(USE_NEON)
+       return vcgtq_s8((int8x16_t) v1, (int8x16_t) v2);
+#endif
+}
+#endif                                                 /* ! USE_NO_SIMD */
+
 /*
  * Given two vectors, return a vector with the minimum element of each.
  */
@@ -402,4 +502,115 @@ vector8_min(const Vector8 v1, const Vector8 v2)
 }
 #endif                                                 /* ! USE_NO_SIMD */
 
+/*
+ * Interleave elements of low halves (e.g., for SSE2, bits 0-63) of given
+ * vectors.  Bytes 0, 2, 4, etc. use v1, and bytes 1, 3, 5, etc. use v2.
+ */
+#ifndef USE_NO_SIMD
+static inline Vector8
+vector8_interleave_low(const Vector8 v1, const Vector8 v2)
+{
+#ifdef USE_SSE2
+       return _mm_unpacklo_epi8(v1, v2);
+#elif defined(USE_NEON)
+       return vzip1q_u8(v1, v2);
+#endif
+}
+#endif                                                 /* ! USE_NO_SIMD */
+
+/*
+ * Interleave elements of high halves (e.g., for SSE2, bits 64-127) of given
+ * vectors.  Bytes 0, 2, 4, etc. use v1, and bytes 1, 3, 5, etc. use v2.
+ */
+#ifndef USE_NO_SIMD
+static inline Vector8
+vector8_interleave_high(const Vector8 v1, const Vector8 v2)
+{
+#ifdef USE_SSE2
+       return _mm_unpackhi_epi8(v1, v2);
+#elif defined(USE_NEON)
+       return vzip2q_u8(v1, v2);
+#endif
+}
+#endif                                                 /* ! USE_NO_SIMD */
+
+/*
+ * Pack 16-bit elements in the given vectors into a single vector of 8-bit
+ * elements.  The first half of the return vector (e.g., for SSE2, bits 0-63)
+ * uses v1, and the second half (e.g., for SSE2, bits 64-127) uses v2.
+ *
+ * NB: The upper 8-bits of each 16-bit element must be zeros, else this will
+ * produce different results on different architectures.
+ */
+#ifndef USE_NO_SIMD
+static inline Vector8
+vector8_pack_16(const Vector8 v1, const Vector8 v2)
+{
+       Vector8         mask PG_USED_FOR_ASSERTS_ONLY;
+
+       mask = vector8_interleave_low(vector8_broadcast(0), vector8_broadcast(0xff));
+       Assert(!vector8_has_ge(vector8_and(v1, mask), 1));
+       Assert(!vector8_has_ge(vector8_and(v2, mask), 1));
+#ifdef USE_SSE2
+       return _mm_packus_epi16(v1, v2);
+#elif defined(USE_NEON)
+       return vuzp1q_u8(v1, v2);
+#endif
+}
+#endif                                                 /* ! USE_NO_SIMD */
+
+/*
+ * Unsigned shift left of each 32-bit element in the vector by "i" bits.
+ *
+ * XXX AArch64 requires an integer literal, so we have to list all expected
+ * values of "i" from all callers in a switch statement.  If you add a new
+ * caller, be sure your expected values of "i" are handled.
+ */
+#ifndef USE_NO_SIMD
+static inline Vector8
+vector8_shift_left(const Vector8 v1, int i)
+{
+#ifdef USE_SSE2
+       return _mm_slli_epi32(v1, i);
+#elif defined(USE_NEON)
+       switch (i)
+       {
+               case 4:
+                       return (Vector8) vshlq_n_u32((Vector32) v1, 4);
+               default:
+                       Assert(false);
+                       return vector8_broadcast(0);
+       }
+#endif
+}
+#endif                                                 /* ! USE_NO_SIMD */
+
+/*
+ * Unsigned shift right of each 32-bit element in the vector by "i" bits.
+ *
+ * XXX AArch64 requires an integer literal, so we have to list all expected
+ * values of "i" from all callers in a switch statement.  If you add a new
+ * caller, be sure your expected values of "i" are handled.
+ */
+#ifndef USE_NO_SIMD
+static inline Vector8
+vector8_shift_right(const Vector8 v1, int i)
+{
+#ifdef USE_SSE2
+       return _mm_srli_epi32(v1, i);
+#elif defined(USE_NEON)
+       switch (i)
+       {
+               case 4:
+                       return (Vector8) vshrq_n_u32((Vector32) v1, 4);
+               case 8:
+                       return (Vector8) vshrq_n_u32((Vector32) v1, 8);
+               default:
+                       Assert(false);
+                       return vector8_broadcast(0);
+       }
+#endif
+}
+#endif                                                 /* ! USE_NO_SIMD */
+
 #endif                                                 /* SIMD_H */
index 691e475bce375af50a119259c8adc74bd079c79c..b9dc08d5f617ea5ce06d0276dc2ec6b150e19f10 100644 (file)
@@ -260,6 +260,64 @@ SELECT reverse('\xabcd'::bytea);
  \xcdab
 (1 row)
 
+SELECT ('\x' || repeat(' ', 32))::bytea;
+ bytea 
+-------
+ \x
+(1 row)
+
+SELECT ('\x' || repeat('!', 32))::bytea;
+ERROR:  invalid hexadecimal digit: "!"
+SELECT ('\x' || repeat('/', 34))::bytea;
+ERROR:  invalid hexadecimal digit: "/"
+SELECT ('\x' || repeat('0', 34))::bytea;
+                bytea                 
+--------------------------------------
+ \x0000000000000000000000000000000000
+(1 row)
+
+SELECT ('\x' || repeat('9', 32))::bytea;
+               bytea                
+------------------------------------
+ \x99999999999999999999999999999999
+(1 row)
+
+SELECT ('\x' || repeat(':', 32))::bytea;
+ERROR:  invalid hexadecimal digit: ":"
+SELECT ('\x' || repeat('@', 34))::bytea;
+ERROR:  invalid hexadecimal digit: "@"
+SELECT ('\x' || repeat('A', 34))::bytea;
+                bytea                 
+--------------------------------------
+ \xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+(1 row)
+
+SELECT ('\x' || repeat('F', 32))::bytea;
+               bytea                
+------------------------------------
+ \xffffffffffffffffffffffffffffffff
+(1 row)
+
+SELECT ('\x' || repeat('G', 32))::bytea;
+ERROR:  invalid hexadecimal digit: "G"
+SELECT ('\x' || repeat('`', 34))::bytea;
+ERROR:  invalid hexadecimal digit: "`"
+SELECT ('\x' || repeat('a', 34))::bytea;
+                bytea                 
+--------------------------------------
+ \xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+(1 row)
+
+SELECT ('\x' || repeat('f', 32))::bytea;
+               bytea                
+------------------------------------
+ \xffffffffffffffffffffffffffffffff
+(1 row)
+
+SELECT ('\x' || repeat('g', 32))::bytea;
+ERROR:  invalid hexadecimal digit: "g"
+SELECT ('\x' || repeat('~', 34))::bytea;
+ERROR:  invalid hexadecimal digit: "~"
 SET bytea_output TO escape;
 SELECT E'\\xDeAdBeEf'::bytea;
       bytea       
index c05f34136990f6b6950a859d13e814e79309c3b6..a2a915234040ce6b16378de6698edaa96d146abc 100644 (file)
@@ -82,6 +82,22 @@ SELECT reverse(''::bytea);
 SELECT reverse('\xaa'::bytea);
 SELECT reverse('\xabcd'::bytea);
 
+SELECT ('\x' || repeat(' ', 32))::bytea;
+SELECT ('\x' || repeat('!', 32))::bytea;
+SELECT ('\x' || repeat('/', 34))::bytea;
+SELECT ('\x' || repeat('0', 34))::bytea;
+SELECT ('\x' || repeat('9', 32))::bytea;
+SELECT ('\x' || repeat(':', 32))::bytea;
+SELECT ('\x' || repeat('@', 34))::bytea;
+SELECT ('\x' || repeat('A', 34))::bytea;
+SELECT ('\x' || repeat('F', 32))::bytea;
+SELECT ('\x' || repeat('G', 32))::bytea;
+SELECT ('\x' || repeat('`', 34))::bytea;
+SELECT ('\x' || repeat('a', 34))::bytea;
+SELECT ('\x' || repeat('f', 32))::bytea;
+SELECT ('\x' || repeat('g', 32))::bytea;
+SELECT ('\x' || repeat('~', 34))::bytea;
+
 SET bytea_output TO escape;
 SELECT E'\\xDeAdBeEf'::bytea;
 SELECT E'\\x De Ad Be Ef '::bytea;