]> git.ipfire.org Git - thirdparty/bind9.git/commitdiff
General-purpose unrolled ASCII tolower() loops
authorTony Finch <fanf@isc.org>
Mon, 27 Jun 2022 11:57:28 +0000 (12:57 +0100)
committerTony Finch <fanf@isc.org>
Mon, 12 Sep 2022 11:18:57 +0000 (12:18 +0100)
When converting a string to lower case, the compiler is able to
autovectorize nicely, so a nice simple implementation is also very
fast, comparable to memcpy().

Comparisons are more difficult for the compiler, so we convert eight
bytes at a time using "SIMD within a register" tricks. Experiments
indicate it's best to stick to simple loops for shorter strings and
the remainder of long strings.

.gitlab-ci.yml
lib/dns/compress.c
lib/dns/name.c
lib/dns/rbtdb.c
lib/isc/hash.c
lib/isc/include/isc/ascii.h
tests/dns/name_test.c

index 6ab49f9b2463613ddc9fd6f21396f72f98f00d2d..bdabe4715e2eef4e99708844f5ec46d379782b49 100644 (file)
@@ -1464,9 +1464,11 @@ gcov:
     # source files from lib/dns/rdata/*/, using an even nastier trick.
     - find lib/dns/rdata/* -name "*.c" -execdir cp -f "{}" ../../ \;
     # Help gcovr process inline functions in headers
-    - cp -f lib/isc/include/isc/*.h lib/dns/
     - cp -f lib/dns/include/dns/*.h lib/dns/
     - cp -f lib/dns/include/dns/*.h lib/ns/
+    - cp -f lib/isc/include/isc/*.h lib/isc/
+    - cp -f lib/isc/include/isc/*.h lib/dns/
+    - cp -f lib/isc/include/isc/*.h lib/ns/
     # Generate XML file in the Cobertura XML format suitable for use by GitLab
     # for the purpose of displaying code coverage information in the diff view
     # of a given merge request.
index 08699c6e3c9cdebc587142d027086aeab787b743..58cfdd82ee525aac136684bef09b47bafb04b4d7 100644 (file)
@@ -237,7 +237,6 @@ dns_compress_find(dns_compress_t *cctx, const dns_name_t *name,
                        for (node = cctx->table[i]; node != NULL;
                             node = node->next) {
                                unsigned int l, count;
-                               unsigned char c;
                                unsigned char *p1, *p2;
 
                                if (node->name.length != length) {
@@ -260,39 +259,12 @@ dns_compress_find(dns_compress_t *cctx, const dns_name_t *name,
                                        /* no bitstring support */
                                        INSIST(count <= 63);
 
-                                       /* Loop unrolled for performance */
-                                       while (count > 3) {
-                                               c = isc_ascii_tolower(p1[0]);
-                                               if (c !=
-                                                   isc_ascii_tolower(p2[0])) {
-                                                       goto cont1;
-                                               }
-                                               c = isc_ascii_tolower(p1[1]);
-                                               if (c !=
-                                                   isc_ascii_tolower(p2[1])) {
-                                                       goto cont1;
-                                               }
-                                               c = isc_ascii_tolower(p1[2]);
-                                               if (c !=
-                                                   isc_ascii_tolower(p2[2])) {
-                                                       goto cont1;
-                                               }
-                                               c = isc_ascii_tolower(p1[3]);
-                                               if (c !=
-                                                   isc_ascii_tolower(p2[3])) {
-                                                       goto cont1;
-                                               }
-                                               count -= 4;
-                                               p1 += 4;
-                                               p2 += 4;
-                                       }
-                                       while (count-- > 0) {
-                                               c = isc_ascii_tolower(*p1++);
-                                               if (c !=
-                                                   isc_ascii_tolower(*p2++)) {
-                                                       goto cont1;
-                                               }
+                                       if (!isc_ascii_lowerequal(p1, p2,
+                                                                 count)) {
+                                               goto cont1;
                                        }
+                                       p1 += count;
+                                       p2 += count;
                                }
                                break;
                        cont1:
index 10075a8582890e126a94fa45ac2583237e021312..0ed6be964d273e76045753900ca27ae48c66f29b 100644 (file)
@@ -442,7 +442,7 @@ dns_namereln_t
 dns_name_fullcompare(const dns_name_t *name1, const dns_name_t *name2,
                     int *orderp, unsigned int *nlabelsp) {
        unsigned int l1, l2, l, count1, count2, count, nlabels;
-       int cdiff, ldiff, chdiff;
+       int cdiff, ldiff, diff;
        unsigned char *label1, *label2;
        unsigned char *offsets1, *offsets2;
        dns_offsets_t odata1, odata2;
@@ -492,8 +492,7 @@ dns_name_fullcompare(const dns_name_t *name1, const dns_name_t *name2,
        offsets1 += l1;
        offsets2 += l2;
 
-       while (l > 0) {
-               l--;
+       while (l-- > 0) {
                offsets1--;
                offsets2--;
                label1 = &name1->ndata[*offsets1];
@@ -501,12 +500,6 @@ dns_name_fullcompare(const dns_name_t *name1, const dns_name_t *name2,
                count1 = *label1++;
                count2 = *label2++;
 
-               /*
-                * We dropped bitstring labels, and we don't support any
-                * other extended label types.
-                */
-               INSIST(count1 <= 63 && count2 <= 63);
-
                cdiff = (int)count1 - (int)count2;
                if (cdiff < 0) {
                        count = count1;
@@ -514,44 +507,12 @@ dns_name_fullcompare(const dns_name_t *name1, const dns_name_t *name2,
                        count = count2;
                }
 
-               /* Loop unrolled for performance */
-               while (count > 3) {
-                       chdiff = (int)isc_ascii_tolower(label1[0]) -
-                                (int)isc_ascii_tolower(label2[0]);
-                       if (chdiff != 0) {
-                               *orderp = chdiff;
-                               goto done;
-                       }
-                       chdiff = (int)isc_ascii_tolower(label1[1]) -
-                                (int)isc_ascii_tolower(label2[1]);
-                       if (chdiff != 0) {
-                               *orderp = chdiff;
-                               goto done;
-                       }
-                       chdiff = (int)isc_ascii_tolower(label1[2]) -
-                                (int)isc_ascii_tolower(label2[2]);
-                       if (chdiff != 0) {
-                               *orderp = chdiff;
-                               goto done;
-                       }
-                       chdiff = (int)isc_ascii_tolower(label1[3]) -
-                                (int)isc_ascii_tolower(label2[3]);
-                       if (chdiff != 0) {
-                               *orderp = chdiff;
-                               goto done;
-                       }
-                       count -= 4;
-                       label1 += 4;
-                       label2 += 4;
-               }
-               while (count-- > 0) {
-                       chdiff = (int)isc_ascii_tolower(*label1++) -
-                                (int)isc_ascii_tolower(*label2++);
-                       if (chdiff != 0) {
-                               *orderp = chdiff;
-                               goto done;
-                       }
+               diff = isc_ascii_lowercmp(label1, label2, count);
+               if (diff != 0) {
+                       *orderp = diff;
+                       goto done;
                }
+
                if (cdiff != 0) {
                        *orderp = cdiff;
                        goto done;
@@ -601,9 +562,7 @@ dns_name_compare(const dns_name_t *name1, const dns_name_t *name2) {
 
 bool
 dns_name_equal(const dns_name_t *name1, const dns_name_t *name2) {
-       unsigned int l, count;
-       unsigned char c;
-       unsigned char *label1, *label2;
+       unsigned int length;
 
        /*
         * Are 'name1' and 'name2' equal?
@@ -626,57 +585,13 @@ dns_name_equal(const dns_name_t *name1, const dns_name_t *name2) {
                return (true);
        }
 
-       if (name1->length != name2->length) {
-               return (false);
-       }
-
-       l = name1->labels;
-
-       if (l != name2->labels) {
+       length = name1->length;
+       if (length != name2->length) {
                return (false);
        }
 
-       label1 = name1->ndata;
-       label2 = name2->ndata;
-       while (l-- > 0) {
-               count = *label1++;
-               if (count != *label2++) {
-                       return (false);
-               }
-
-               INSIST(count <= 63); /* no bitstring support */
-
-               /* Loop unrolled for performance */
-               while (count > 3) {
-                       c = isc_ascii_tolower(label1[0]);
-                       if (c != isc_ascii_tolower(label2[0])) {
-                               return (false);
-                       }
-                       c = isc_ascii_tolower(label1[1]);
-                       if (c != isc_ascii_tolower(label2[1])) {
-                               return (false);
-                       }
-                       c = isc_ascii_tolower(label1[2]);
-                       if (c != isc_ascii_tolower(label2[2])) {
-                               return (false);
-                       }
-                       c = isc_ascii_tolower(label1[3]);
-                       if (c != isc_ascii_tolower(label2[3])) {
-                               return (false);
-                       }
-                       count -= 4;
-                       label1 += 4;
-                       label2 += 4;
-               }
-               while (count-- > 0) {
-                       c = isc_ascii_tolower(*label1++);
-                       if (c != isc_ascii_tolower(*label2++)) {
-                               return (false);
-                       }
-               }
-       }
-
-       return (true);
+       /* label lengths are < 64 so tolower() does not affect them */
+       return (isc_ascii_lowerequal(name1->ndata, name2->ndata, length));
 }
 
 bool
@@ -711,10 +626,6 @@ dns_name_caseequal(const dns_name_t *name1, const dns_name_t *name2) {
 
 int
 dns_name_rdatacompare(const dns_name_t *name1, const dns_name_t *name2) {
-       unsigned int l1, l2, l, count1, count2, count;
-       unsigned char c1, c2;
-       unsigned char *label1, *label2;
-
        /*
         * Compare two absolute names as rdata.
         */
@@ -726,47 +637,9 @@ dns_name_rdatacompare(const dns_name_t *name1, const dns_name_t *name2) {
        REQUIRE(name2->labels > 0);
        REQUIRE((name2->attributes & DNS_NAMEATTR_ABSOLUTE) != 0);
 
-       l1 = name1->labels;
-       l2 = name2->labels;
-
-       l = (l1 < l2) ? l1 : l2;
-
-       label1 = name1->ndata;
-       label2 = name2->ndata;
-       while (l > 0) {
-               l--;
-               count1 = *label1++;
-               count2 = *label2++;
-
-               /* no bitstring support */
-               INSIST(count1 <= 63 && count2 <= 63);
-
-               if (count1 != count2) {
-                       return ((count1 < count2) ? -1 : 1);
-               }
-               count = count1;
-               while (count > 0) {
-                       count--;
-                       c1 = isc_ascii_tolower(*label1++);
-                       c2 = isc_ascii_tolower(*label2++);
-                       if (c1 < c2) {
-                               return (-1);
-                       } else if (c1 > c2) {
-                               return (1);
-                       }
-               }
-       }
-
-       /*
-        * If one name had more labels than the other, their common
-        * prefix must have been different because the shorter name
-        * ended with the root label and the longer one can't have
-        * a root label in the middle of it.  Therefore, if we get
-        * to this point, the lengths must be equal.
-        */
-       INSIST(l1 == l2);
-
-       return (0);
+       /* label lengths are < 64 so tolower() does not affect them */
+       return (isc_ascii_lowercmp(name1->ndata, name2->ndata,
+                                  ISC_MIN(name1->length, name2->length)));
 }
 
 bool
@@ -1572,8 +1445,7 @@ dns_name_tofilenametext(const dns_name_t *name, bool omit_final_dot,
 isc_result_t
 dns_name_downcase(const dns_name_t *source, dns_name_t *name,
                  isc_buffer_t *target) {
-       unsigned char *sndata, *ndata;
-       unsigned int nlen, count, labels;
+       unsigned char *ndata;
        isc_buffer_t buffer;
 
        /*
@@ -1599,33 +1471,13 @@ dns_name_downcase(const dns_name_t *source, dns_name_t *name,
                name->ndata = ndata;
        }
 
-       sndata = source->ndata;
-       nlen = source->length;
-       labels = source->labels;
-
-       if (nlen > (target->length - target->used)) {
+       if (source->length > (target->length - target->used)) {
                MAKE_EMPTY(name);
                return (ISC_R_NOSPACE);
        }
 
-       while (labels > 0 && nlen > 0) {
-               labels--;
-               count = *sndata++;
-               *ndata++ = count;
-               nlen--;
-               if (count < 64) {
-                       INSIST(nlen >= count);
-                       while (count > 0) {
-                               *ndata++ = isc_ascii_tolower(*sndata++);
-                               nlen--;
-                               count--;
-                       }
-               } else {
-                       FATAL_ERROR(__FILE__, __LINE__,
-                                   "Unexpected label type %02x", count);
-                       /* Does not return. */
-               }
-       }
+       /* label lengths are < 64 so tolower() does not affect them */
+       isc_ascii_lowercopy(ndata, source->ndata, source->length);
 
        if (source != name) {
                name->labels = source->labels;
index 38220271ff49d339afb5ad5faf529628dac4ee85..48e80d806c253b38828d43449444ff5ccd21fb52 100644 (file)
@@ -9374,9 +9374,7 @@ rdataset_getownercase(const dns_rdataset_t *rdataset, dns_name_t *name) {
        }
 
        if (CASEFULLYLOWER(header)) {
-               for (size_t i = 0; i < name->length; i++) {
-                       name->ndata[i] = isc_ascii_tolower(name->ndata[i]);
-               }
+               isc_ascii_lowercopy(name->ndata, name->ndata, name->length);
        } else {
                uint8_t *nd = name->ndata;
                for (size_t i = 0; i < name->length; i++) {
index 7e2b17c2df6dd5be35c3194bdc8c4fd52e9f1a46..37622ea8c590e23a6cd45f8f97222c75727c3969 100644 (file)
@@ -89,12 +89,9 @@ isc_hash64(const void *data, const size_t length, const bool case_sensitive) {
        if (case_sensitive) {
                isc_siphash24(isc_hash_key, data, length, (uint8_t *)&hval);
        } else {
-               const uint8_t *byte = data;
                uint8_t lower[1024];
-               REQUIRE(length <= 1024);
-               for (unsigned i = 0; i < length; i++) {
-                       lower[i] = isc_ascii_tolower(byte[i]);
-               }
+               REQUIRE(length <= sizeof(lower));
+               isc_ascii_lowercopy(lower, data, length);
                isc_siphash24(isc_hash_key, lower, length, (uint8_t *)&hval);
        }
 
@@ -113,12 +110,9 @@ isc_hash32(const void *data, const size_t length, const bool case_sensitive) {
        if (case_sensitive) {
                isc_halfsiphash24(isc_hash_key, data, length, (uint8_t *)&hval);
        } else {
-               const uint8_t *byte = data;
                uint8_t lower[1024];
-               REQUIRE(length <= 1024);
-               for (unsigned i = 0; i < length; i++) {
-                       lower[i] = isc_ascii_tolower(byte[i]);
-               }
+               REQUIRE(length <= sizeof(lower));
+               isc_ascii_lowercopy(lower, data, length);
                isc_halfsiphash24(isc_hash_key, lower, length,
                                  (uint8_t *)&hval);
        }
index 486e1c540e9d135d62782918f4ee48b96b377bcf..c11641b750482f2889ac031cfd89deea41b6075f 100644 (file)
 
 #pragma once
 
+#include <stdbool.h>
 #include <stdint.h>
+#include <string.h>
+
+#include <isc/endian.h>
 
 /*
  * ASCII case conversion
@@ -27,12 +31,144 @@ extern const uint8_t isc__ascii_toupper[256];
 #define isc_ascii_tolower(c) isc__ascii_tolower[(uint8_t)(c)]
 #define isc_ascii_toupper(c) isc__ascii_toupper[(uint8_t)(c)]
 
+/*
+ * A variant tolower() implementation with no memory accesses,
+ * for use when the compiler is able to autovectorize.
+ */
+static inline uint8_t
+isc__ascii_tolower1(uint8_t c) {
+       return (c + ('a' - 'A') * ('A' <= c && c <= 'Z'));
+}
+
+/*
+ * Copy `len` bytes from `src` to `dst`, converting to lower case.
+ */
+static inline void
+isc_ascii_lowercopy(uint8_t *dst, const uint8_t *src, unsigned len) {
+       while (len-- > 0) {
+               *dst++ = isc__ascii_tolower1(*src++);
+       }
+}
+
 /*
  * Convert a string to lower case in place
  */
 static inline void
 isc_ascii_strtolower(char *str) {
-       for (size_t len = strlen(str); len > 0; len--, str++) {
-               *str = isc_ascii_tolower(*str);
+       isc_ascii_lowercopy((uint8_t *)str, (uint8_t *)str,
+                           (unsigned)strlen(str));
+}
+
+/*
+ * Convert 8 bytes to lower case, using SWAR tricks (SIMD within a register).
+ * Based on "Hacker's Delight" by Henry S. Warren, "searching for a value in a
+ * given range", p. 95. Eight bytes is wider than many labels in DNS names, so
+ * it does not seem worth dealing with the portability issues of wide vector
+ * registers. If there was a vector string load instruction (analogous to
+ * memove() below) the balance might be different.
+ */
+static inline uint64_t
+isc__ascii_tolower8(uint64_t octets) {
+       /*
+        * Multiply a single-byte constant by `all_bytes` to replicate
+        * it to all eight bytes in a word.
+        */
+       uint64_t all_bytes = 0x0101010101010101;
+       /*
+        * Clear the top bit of each byte to make space for a per-byte flag.
+        */
+       uint64_t heptets = octets & (0x7F * all_bytes);
+       /*
+        * We will need to avoid going wrong if our flag bits were originally
+        * set, and clear calculation leftovers in our non-flag bits
+        */
+       uint64_t is_ascii = ~octets & (0x80 * all_bytes);
+       /*
+        * To compare a heptet to `N`, we can add `0x7F - N` so that carry
+        * propagation will set the flag when our heptet is greater than `N`
+        */
+       uint64_t is_gt_Z = heptets + (0x7F - 'Z') * all_bytes;
+       /*
+        * Add one for greater-than-or-equal comparison
+        */
+       uint64_t is_ge_A = heptets + (0x80 - 'A') * all_bytes;
+       /*
+        * Now we have what we need to identify the ascii uppercase bytes
+        */
+       uint64_t is_upper = (is_ge_A ^ is_gt_Z) & is_ascii;
+       /*
+        * Move the is_upper flag bits to bit 0x20 (which is 'a' - 'A')
+        * and use them to adjust each byte as required
+        */
+       return (octets | (is_upper >> 2));
+}
+
+/*
+ * Helper function to do an unaligned load of 8 bytes in host byte order
+ */
+static inline uint64_t
+isc__ascii_load8(const uint8_t *ptr) {
+       uint64_t bytes = 0;
+       memmove(&bytes, ptr, sizeof(bytes));
+       return (bytes);
+}
+
+/*
+ * Compare `len` bytes at `a` and `b` for case-insensitive equality
+ */
+static inline bool
+isc_ascii_lowerequal(const uint8_t *a, const uint8_t *b, unsigned len) {
+       uint64_t a8 = 0, b8 = 0;
+       while (len >= 8) {
+               a8 = isc__ascii_tolower8(isc__ascii_load8(a));
+               b8 = isc__ascii_tolower8(isc__ascii_load8(b));
+               if (a8 != b8) {
+                       return (false);
+               }
+               len -= 8;
+               a += 8;
+               b += 8;
+       }
+       while (len-- > 0) {
+               if (isc_ascii_tolower(*a++) != isc_ascii_tolower(*b++)) {
+                       return (false);
+               }
+       }
+       return (true);
+}
+
+/*
+ * Compare `len` bytes at `a` and `b` for case-insensitive order.
+ * Unlike the previous functions (which do not need to care about byte
+ * order) here we need to ensure the comparisons are lexicographic,
+ * i.e. they treat the strings as big-endian numbers.
+ */
+static inline int
+isc_ascii_lowercmp(const uint8_t *a, const uint8_t *b, unsigned len) {
+       uint64_t a8 = 0, b8 = 0;
+       while (len >= 8) {
+               a8 = isc__ascii_tolower8(htobe64(isc__ascii_load8(a)));
+               b8 = isc__ascii_tolower8(htobe64(isc__ascii_load8(b)));
+               if (a8 != b8) {
+                       goto ret;
+               }
+               len -= 8;
+               a += 8;
+               b += 8;
+       }
+       while (len-- > 0) {
+               a8 = isc_ascii_tolower(*a++);
+               b8 = isc_ascii_tolower(*b++);
+               if (a8 != b8) {
+                       goto ret;
+               }
+       }
+ret:
+       if (a8 < b8) {
+               return (-1);
+       }
+       if (a8 > b8) {
+               return (+1);
        }
+       return (0);
 }
index 42fa631e85faeb8ec9443fa76bbc4fd31a4db407..721544707705d67d7244c436dc58d6e43b45c29d 100644 (file)
@@ -61,21 +61,21 @@ ISC_RUN_TEST_IMPL(fullcompare) {
                { "", "", dns_namereln_equal, 0, 0 },
                { "foo", "", dns_namereln_subdomain, 1, 0 },
                { "", "foo", dns_namereln_contains, -1, 0 },
-               { "foo", "bar", dns_namereln_none, 4, 0 },
-               { "bar", "foo", dns_namereln_none, -4, 0 },
+               { "foo", "bar", dns_namereln_none, 1, 0 },
+               { "bar", "foo", dns_namereln_none, -1, 0 },
                { "bar.foo", "foo", dns_namereln_subdomain, 1, 1 },
                { "foo", "bar.foo", dns_namereln_contains, -1, 1 },
                { "baz.bar.foo", "bar.foo", dns_namereln_subdomain, 1, 2 },
                { "bar.foo", "baz.bar.foo", dns_namereln_contains, -1, 2 },
-               { "foo.example", "bar.example", dns_namereln_commonancestor, 4,
+               { "foo.example", "bar.example", dns_namereln_commonancestor, 1,
                  1 },
 
                /* absolute */
                { ".", ".", dns_namereln_equal, 0, 1 },
-               { "foo.", "bar.", dns_namereln_commonancestor, 4, 1 },
-               { "bar.", "foo.", dns_namereln_commonancestor, -4, 1 },
+               { "foo.", "bar.", dns_namereln_commonancestor, 1, 1 },
+               { "bar.", "foo.", dns_namereln_commonancestor, -1, 1 },
                { "foo.example.", "bar.example.", dns_namereln_commonancestor,
-                 4, 2 },
+                 1, 2 },
                { "bar.foo.", "foo.", dns_namereln_subdomain, 1, 2 },
                { "foo.", "bar.foo.", dns_namereln_contains, -1, 2 },
                { "baz.bar.foo.", "bar.foo.", dns_namereln_subdomain, 1, 3 },