From: Tony Finch Date: Mon, 27 Jun 2022 11:57:28 +0000 (+0100) Subject: General-purpose unrolled ASCII tolower() loops X-Git-Tag: v9.19.6~69^2~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=21a383a8fd91111004ccaecd562aa1ea944ef161;p=thirdparty%2Fbind9.git General-purpose unrolled ASCII tolower() loops When converting a string to lower case, the compiler is able to autovectorize nicely, so a nice simple implementation is also very fast, comparable to memcpy(). Comparisons are more difficult for the compiler, so we convert eight bytes at a time using "SIMD within a register" tricks. Experiments indicate it's best to stick to simple loops for shorter strings and the remainder of long strings. --- diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 6ab49f9b246..bdabe4715e2 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1464,9 +1464,11 @@ gcov: # source files from lib/dns/rdata/*/, using an even nastier trick. - find lib/dns/rdata/* -name "*.c" -execdir cp -f "{}" ../../ \; # Help gcovr process inline functions in headers - - cp -f lib/isc/include/isc/*.h lib/dns/ - cp -f lib/dns/include/dns/*.h lib/dns/ - cp -f lib/dns/include/dns/*.h lib/ns/ + - cp -f lib/isc/include/isc/*.h lib/isc/ + - cp -f lib/isc/include/isc/*.h lib/dns/ + - cp -f lib/isc/include/isc/*.h lib/ns/ # Generate XML file in the Cobertura XML format suitable for use by GitLab # for the purpose of displaying code coverage information in the diff view # of a given merge request. diff --git a/lib/dns/compress.c b/lib/dns/compress.c index 08699c6e3c9..58cfdd82ee5 100644 --- a/lib/dns/compress.c +++ b/lib/dns/compress.c @@ -237,7 +237,6 @@ dns_compress_find(dns_compress_t *cctx, const dns_name_t *name, for (node = cctx->table[i]; node != NULL; node = node->next) { unsigned int l, count; - unsigned char c; unsigned char *p1, *p2; if (node->name.length != length) { @@ -260,39 +259,12 @@ dns_compress_find(dns_compress_t *cctx, const dns_name_t *name, /* no bitstring support */ INSIST(count <= 63); - /* Loop unrolled for performance */ - while (count > 3) { - c = isc_ascii_tolower(p1[0]); - if (c != - isc_ascii_tolower(p2[0])) { - goto cont1; - } - c = isc_ascii_tolower(p1[1]); - if (c != - isc_ascii_tolower(p2[1])) { - goto cont1; - } - c = isc_ascii_tolower(p1[2]); - if (c != - isc_ascii_tolower(p2[2])) { - goto cont1; - } - c = isc_ascii_tolower(p1[3]); - if (c != - isc_ascii_tolower(p2[3])) { - goto cont1; - } - count -= 4; - p1 += 4; - p2 += 4; - } - while (count-- > 0) { - c = isc_ascii_tolower(*p1++); - if (c != - isc_ascii_tolower(*p2++)) { - goto cont1; - } + if (!isc_ascii_lowerequal(p1, p2, + count)) { + goto cont1; } + p1 += count; + p2 += count; } break; cont1: diff --git a/lib/dns/name.c b/lib/dns/name.c index 10075a85828..0ed6be964d2 100644 --- a/lib/dns/name.c +++ b/lib/dns/name.c @@ -442,7 +442,7 @@ dns_namereln_t dns_name_fullcompare(const dns_name_t *name1, const dns_name_t *name2, int *orderp, unsigned int *nlabelsp) { unsigned int l1, l2, l, count1, count2, count, nlabels; - int cdiff, ldiff, chdiff; + int cdiff, ldiff, diff; unsigned char *label1, *label2; unsigned char *offsets1, *offsets2; dns_offsets_t odata1, odata2; @@ -492,8 +492,7 @@ dns_name_fullcompare(const dns_name_t *name1, const dns_name_t *name2, offsets1 += l1; offsets2 += l2; - while (l > 0) { - l--; + while (l-- > 0) { offsets1--; offsets2--; label1 = &name1->ndata[*offsets1]; @@ -501,12 +500,6 @@ dns_name_fullcompare(const dns_name_t *name1, const dns_name_t *name2, count1 = *label1++; count2 = *label2++; - /* - * We dropped bitstring labels, and we don't support any - * other extended label types. - */ - INSIST(count1 <= 63 && count2 <= 63); - cdiff = (int)count1 - (int)count2; if (cdiff < 0) { count = count1; @@ -514,44 +507,12 @@ dns_name_fullcompare(const dns_name_t *name1, const dns_name_t *name2, count = count2; } - /* Loop unrolled for performance */ - while (count > 3) { - chdiff = (int)isc_ascii_tolower(label1[0]) - - (int)isc_ascii_tolower(label2[0]); - if (chdiff != 0) { - *orderp = chdiff; - goto done; - } - chdiff = (int)isc_ascii_tolower(label1[1]) - - (int)isc_ascii_tolower(label2[1]); - if (chdiff != 0) { - *orderp = chdiff; - goto done; - } - chdiff = (int)isc_ascii_tolower(label1[2]) - - (int)isc_ascii_tolower(label2[2]); - if (chdiff != 0) { - *orderp = chdiff; - goto done; - } - chdiff = (int)isc_ascii_tolower(label1[3]) - - (int)isc_ascii_tolower(label2[3]); - if (chdiff != 0) { - *orderp = chdiff; - goto done; - } - count -= 4; - label1 += 4; - label2 += 4; - } - while (count-- > 0) { - chdiff = (int)isc_ascii_tolower(*label1++) - - (int)isc_ascii_tolower(*label2++); - if (chdiff != 0) { - *orderp = chdiff; - goto done; - } + diff = isc_ascii_lowercmp(label1, label2, count); + if (diff != 0) { + *orderp = diff; + goto done; } + if (cdiff != 0) { *orderp = cdiff; goto done; @@ -601,9 +562,7 @@ dns_name_compare(const dns_name_t *name1, const dns_name_t *name2) { bool dns_name_equal(const dns_name_t *name1, const dns_name_t *name2) { - unsigned int l, count; - unsigned char c; - unsigned char *label1, *label2; + unsigned int length; /* * Are 'name1' and 'name2' equal? @@ -626,57 +585,13 @@ dns_name_equal(const dns_name_t *name1, const dns_name_t *name2) { return (true); } - if (name1->length != name2->length) { - return (false); - } - - l = name1->labels; - - if (l != name2->labels) { + length = name1->length; + if (length != name2->length) { return (false); } - label1 = name1->ndata; - label2 = name2->ndata; - while (l-- > 0) { - count = *label1++; - if (count != *label2++) { - return (false); - } - - INSIST(count <= 63); /* no bitstring support */ - - /* Loop unrolled for performance */ - while (count > 3) { - c = isc_ascii_tolower(label1[0]); - if (c != isc_ascii_tolower(label2[0])) { - return (false); - } - c = isc_ascii_tolower(label1[1]); - if (c != isc_ascii_tolower(label2[1])) { - return (false); - } - c = isc_ascii_tolower(label1[2]); - if (c != isc_ascii_tolower(label2[2])) { - return (false); - } - c = isc_ascii_tolower(label1[3]); - if (c != isc_ascii_tolower(label2[3])) { - return (false); - } - count -= 4; - label1 += 4; - label2 += 4; - } - while (count-- > 0) { - c = isc_ascii_tolower(*label1++); - if (c != isc_ascii_tolower(*label2++)) { - return (false); - } - } - } - - return (true); + /* label lengths are < 64 so tolower() does not affect them */ + return (isc_ascii_lowerequal(name1->ndata, name2->ndata, length)); } bool @@ -711,10 +626,6 @@ dns_name_caseequal(const dns_name_t *name1, const dns_name_t *name2) { int dns_name_rdatacompare(const dns_name_t *name1, const dns_name_t *name2) { - unsigned int l1, l2, l, count1, count2, count; - unsigned char c1, c2; - unsigned char *label1, *label2; - /* * Compare two absolute names as rdata. */ @@ -726,47 +637,9 @@ dns_name_rdatacompare(const dns_name_t *name1, const dns_name_t *name2) { REQUIRE(name2->labels > 0); REQUIRE((name2->attributes & DNS_NAMEATTR_ABSOLUTE) != 0); - l1 = name1->labels; - l2 = name2->labels; - - l = (l1 < l2) ? l1 : l2; - - label1 = name1->ndata; - label2 = name2->ndata; - while (l > 0) { - l--; - count1 = *label1++; - count2 = *label2++; - - /* no bitstring support */ - INSIST(count1 <= 63 && count2 <= 63); - - if (count1 != count2) { - return ((count1 < count2) ? -1 : 1); - } - count = count1; - while (count > 0) { - count--; - c1 = isc_ascii_tolower(*label1++); - c2 = isc_ascii_tolower(*label2++); - if (c1 < c2) { - return (-1); - } else if (c1 > c2) { - return (1); - } - } - } - - /* - * If one name had more labels than the other, their common - * prefix must have been different because the shorter name - * ended with the root label and the longer one can't have - * a root label in the middle of it. Therefore, if we get - * to this point, the lengths must be equal. - */ - INSIST(l1 == l2); - - return (0); + /* label lengths are < 64 so tolower() does not affect them */ + return (isc_ascii_lowercmp(name1->ndata, name2->ndata, + ISC_MIN(name1->length, name2->length))); } bool @@ -1572,8 +1445,7 @@ dns_name_tofilenametext(const dns_name_t *name, bool omit_final_dot, isc_result_t dns_name_downcase(const dns_name_t *source, dns_name_t *name, isc_buffer_t *target) { - unsigned char *sndata, *ndata; - unsigned int nlen, count, labels; + unsigned char *ndata; isc_buffer_t buffer; /* @@ -1599,33 +1471,13 @@ dns_name_downcase(const dns_name_t *source, dns_name_t *name, name->ndata = ndata; } - sndata = source->ndata; - nlen = source->length; - labels = source->labels; - - if (nlen > (target->length - target->used)) { + if (source->length > (target->length - target->used)) { MAKE_EMPTY(name); return (ISC_R_NOSPACE); } - while (labels > 0 && nlen > 0) { - labels--; - count = *sndata++; - *ndata++ = count; - nlen--; - if (count < 64) { - INSIST(nlen >= count); - while (count > 0) { - *ndata++ = isc_ascii_tolower(*sndata++); - nlen--; - count--; - } - } else { - FATAL_ERROR(__FILE__, __LINE__, - "Unexpected label type %02x", count); - /* Does not return. */ - } - } + /* label lengths are < 64 so tolower() does not affect them */ + isc_ascii_lowercopy(ndata, source->ndata, source->length); if (source != name) { name->labels = source->labels; diff --git a/lib/dns/rbtdb.c b/lib/dns/rbtdb.c index 38220271ff4..48e80d806c2 100644 --- a/lib/dns/rbtdb.c +++ b/lib/dns/rbtdb.c @@ -9374,9 +9374,7 @@ rdataset_getownercase(const dns_rdataset_t *rdataset, dns_name_t *name) { } if (CASEFULLYLOWER(header)) { - for (size_t i = 0; i < name->length; i++) { - name->ndata[i] = isc_ascii_tolower(name->ndata[i]); - } + isc_ascii_lowercopy(name->ndata, name->ndata, name->length); } else { uint8_t *nd = name->ndata; for (size_t i = 0; i < name->length; i++) { diff --git a/lib/isc/hash.c b/lib/isc/hash.c index 7e2b17c2df6..37622ea8c59 100644 --- a/lib/isc/hash.c +++ b/lib/isc/hash.c @@ -89,12 +89,9 @@ isc_hash64(const void *data, const size_t length, const bool case_sensitive) { if (case_sensitive) { isc_siphash24(isc_hash_key, data, length, (uint8_t *)&hval); } else { - const uint8_t *byte = data; uint8_t lower[1024]; - REQUIRE(length <= 1024); - for (unsigned i = 0; i < length; i++) { - lower[i] = isc_ascii_tolower(byte[i]); - } + REQUIRE(length <= sizeof(lower)); + isc_ascii_lowercopy(lower, data, length); isc_siphash24(isc_hash_key, lower, length, (uint8_t *)&hval); } @@ -113,12 +110,9 @@ isc_hash32(const void *data, const size_t length, const bool case_sensitive) { if (case_sensitive) { isc_halfsiphash24(isc_hash_key, data, length, (uint8_t *)&hval); } else { - const uint8_t *byte = data; uint8_t lower[1024]; - REQUIRE(length <= 1024); - for (unsigned i = 0; i < length; i++) { - lower[i] = isc_ascii_tolower(byte[i]); - } + REQUIRE(length <= sizeof(lower)); + isc_ascii_lowercopy(lower, data, length); isc_halfsiphash24(isc_hash_key, lower, length, (uint8_t *)&hval); } diff --git a/lib/isc/include/isc/ascii.h b/lib/isc/include/isc/ascii.h index 486e1c540e9..c11641b7504 100644 --- a/lib/isc/include/isc/ascii.h +++ b/lib/isc/include/isc/ascii.h @@ -13,7 +13,11 @@ #pragma once +#include #include +#include + +#include /* * ASCII case conversion @@ -27,12 +31,144 @@ extern const uint8_t isc__ascii_toupper[256]; #define isc_ascii_tolower(c) isc__ascii_tolower[(uint8_t)(c)] #define isc_ascii_toupper(c) isc__ascii_toupper[(uint8_t)(c)] +/* + * A variant tolower() implementation with no memory accesses, + * for use when the compiler is able to autovectorize. + */ +static inline uint8_t +isc__ascii_tolower1(uint8_t c) { + return (c + ('a' - 'A') * ('A' <= c && c <= 'Z')); +} + +/* + * Copy `len` bytes from `src` to `dst`, converting to lower case. + */ +static inline void +isc_ascii_lowercopy(uint8_t *dst, const uint8_t *src, unsigned len) { + while (len-- > 0) { + *dst++ = isc__ascii_tolower1(*src++); + } +} + /* * Convert a string to lower case in place */ static inline void isc_ascii_strtolower(char *str) { - for (size_t len = strlen(str); len > 0; len--, str++) { - *str = isc_ascii_tolower(*str); + isc_ascii_lowercopy((uint8_t *)str, (uint8_t *)str, + (unsigned)strlen(str)); +} + +/* + * Convert 8 bytes to lower case, using SWAR tricks (SIMD within a register). + * Based on "Hacker's Delight" by Henry S. Warren, "searching for a value in a + * given range", p. 95. Eight bytes is wider than many labels in DNS names, so + * it does not seem worth dealing with the portability issues of wide vector + * registers. If there was a vector string load instruction (analogous to + * memove() below) the balance might be different. + */ +static inline uint64_t +isc__ascii_tolower8(uint64_t octets) { + /* + * Multiply a single-byte constant by `all_bytes` to replicate + * it to all eight bytes in a word. + */ + uint64_t all_bytes = 0x0101010101010101; + /* + * Clear the top bit of each byte to make space for a per-byte flag. + */ + uint64_t heptets = octets & (0x7F * all_bytes); + /* + * We will need to avoid going wrong if our flag bits were originally + * set, and clear calculation leftovers in our non-flag bits + */ + uint64_t is_ascii = ~octets & (0x80 * all_bytes); + /* + * To compare a heptet to `N`, we can add `0x7F - N` so that carry + * propagation will set the flag when our heptet is greater than `N` + */ + uint64_t is_gt_Z = heptets + (0x7F - 'Z') * all_bytes; + /* + * Add one for greater-than-or-equal comparison + */ + uint64_t is_ge_A = heptets + (0x80 - 'A') * all_bytes; + /* + * Now we have what we need to identify the ascii uppercase bytes + */ + uint64_t is_upper = (is_ge_A ^ is_gt_Z) & is_ascii; + /* + * Move the is_upper flag bits to bit 0x20 (which is 'a' - 'A') + * and use them to adjust each byte as required + */ + return (octets | (is_upper >> 2)); +} + +/* + * Helper function to do an unaligned load of 8 bytes in host byte order + */ +static inline uint64_t +isc__ascii_load8(const uint8_t *ptr) { + uint64_t bytes = 0; + memmove(&bytes, ptr, sizeof(bytes)); + return (bytes); +} + +/* + * Compare `len` bytes at `a` and `b` for case-insensitive equality + */ +static inline bool +isc_ascii_lowerequal(const uint8_t *a, const uint8_t *b, unsigned len) { + uint64_t a8 = 0, b8 = 0; + while (len >= 8) { + a8 = isc__ascii_tolower8(isc__ascii_load8(a)); + b8 = isc__ascii_tolower8(isc__ascii_load8(b)); + if (a8 != b8) { + return (false); + } + len -= 8; + a += 8; + b += 8; + } + while (len-- > 0) { + if (isc_ascii_tolower(*a++) != isc_ascii_tolower(*b++)) { + return (false); + } + } + return (true); +} + +/* + * Compare `len` bytes at `a` and `b` for case-insensitive order. + * Unlike the previous functions (which do not need to care about byte + * order) here we need to ensure the comparisons are lexicographic, + * i.e. they treat the strings as big-endian numbers. + */ +static inline int +isc_ascii_lowercmp(const uint8_t *a, const uint8_t *b, unsigned len) { + uint64_t a8 = 0, b8 = 0; + while (len >= 8) { + a8 = isc__ascii_tolower8(htobe64(isc__ascii_load8(a))); + b8 = isc__ascii_tolower8(htobe64(isc__ascii_load8(b))); + if (a8 != b8) { + goto ret; + } + len -= 8; + a += 8; + b += 8; + } + while (len-- > 0) { + a8 = isc_ascii_tolower(*a++); + b8 = isc_ascii_tolower(*b++); + if (a8 != b8) { + goto ret; + } + } +ret: + if (a8 < b8) { + return (-1); + } + if (a8 > b8) { + return (+1); } + return (0); } diff --git a/tests/dns/name_test.c b/tests/dns/name_test.c index 42fa631e85f..72154470770 100644 --- a/tests/dns/name_test.c +++ b/tests/dns/name_test.c @@ -61,21 +61,21 @@ ISC_RUN_TEST_IMPL(fullcompare) { { "", "", dns_namereln_equal, 0, 0 }, { "foo", "", dns_namereln_subdomain, 1, 0 }, { "", "foo", dns_namereln_contains, -1, 0 }, - { "foo", "bar", dns_namereln_none, 4, 0 }, - { "bar", "foo", dns_namereln_none, -4, 0 }, + { "foo", "bar", dns_namereln_none, 1, 0 }, + { "bar", "foo", dns_namereln_none, -1, 0 }, { "bar.foo", "foo", dns_namereln_subdomain, 1, 1 }, { "foo", "bar.foo", dns_namereln_contains, -1, 1 }, { "baz.bar.foo", "bar.foo", dns_namereln_subdomain, 1, 2 }, { "bar.foo", "baz.bar.foo", dns_namereln_contains, -1, 2 }, - { "foo.example", "bar.example", dns_namereln_commonancestor, 4, + { "foo.example", "bar.example", dns_namereln_commonancestor, 1, 1 }, /* absolute */ { ".", ".", dns_namereln_equal, 0, 1 }, - { "foo.", "bar.", dns_namereln_commonancestor, 4, 1 }, - { "bar.", "foo.", dns_namereln_commonancestor, -4, 1 }, + { "foo.", "bar.", dns_namereln_commonancestor, 1, 1 }, + { "bar.", "foo.", dns_namereln_commonancestor, -1, 1 }, { "foo.example.", "bar.example.", dns_namereln_commonancestor, - 4, 2 }, + 1, 2 }, { "bar.foo.", "foo.", dns_namereln_subdomain, 1, 2 }, { "foo.", "bar.foo.", dns_namereln_contains, -1, 2 }, { "baz.bar.foo.", "bar.foo.", dns_namereln_subdomain, 1, 3 },