]> git.ipfire.org Git - thirdparty/vectorscan.git/commitdiff
AVX512 Reinforced FAT teddy.
authorChang, Harry <harry.chang@intel.com>
Thu, 13 Jul 2017 06:38:06 +0000 (14:38 +0800)
committerMatthew Barr <matthew.barr@intel.com>
Mon, 21 Aug 2017 01:14:59 +0000 (11:14 +1000)
src/fdr/teddy.c
src/fdr/teddy_avx2.c
src/fdr/teddy_compile.cpp
src/util/simd_utils.h

index db68749a7eb484f684c29435cb7c844a09741e54..0b3fe28f0e30cd1e86ae0a49d0f56b46139beb2c 100644 (file)
@@ -298,7 +298,7 @@ do {                                                                          \
     const u8 *ptr = a->buf + a->start_offset;                                 \
     u32 floodBackoff = FLOOD_BACKOFF_START;                                   \
     const u8 *tryFloodDetect = a->firstFloodDetect;                           \
-    u32 last_match = (u32)-1;                                                 \
+    u32 last_match = ones_u32;                                                \
     const struct Teddy *teddy = (const struct Teddy *)fdr;                    \
     const size_t iterBytes = 128;                                             \
     DEBUG_PRINTF("params: buf %p len %zu start_offset %zu\n",                 \
@@ -533,7 +533,7 @@ do {                                                                          \
     const u8 *ptr = a->buf + a->start_offset;                                 \
     u32 floodBackoff = FLOOD_BACKOFF_START;                                   \
     const u8 *tryFloodDetect = a->firstFloodDetect;                           \
-    u32 last_match = (u32)-1;                                                 \
+    u32 last_match = ones_u32;                                                \
     const struct Teddy *teddy = (const struct Teddy *)fdr;                    \
     const size_t iterBytes = 64;                                              \
     DEBUG_PRINTF("params: buf %p len %zu start_offset %zu\n",                 \
@@ -712,7 +712,7 @@ do {                                                                          \
     const u8 *ptr = a->buf + a->start_offset;                                 \
     u32 floodBackoff = FLOOD_BACKOFF_START;                                   \
     const u8 *tryFloodDetect = a->firstFloodDetect;                           \
-    u32 last_match = (u32)-1;                                                 \
+    u32 last_match = ones_u32;                                                \
     const struct Teddy *teddy = (const struct Teddy *)fdr;                    \
     const size_t iterBytes = 32;                                              \
     DEBUG_PRINTF("params: buf %p len %zu start_offset %zu\n",                 \
index 1d037028192c53c21535325bfe892e4eb35d2b0a..8f98344c930982c3eed5753eb8fa456b38dc3cff 100644 (file)
@@ -134,6 +134,300 @@ const m256 *getMaskBase_avx2(const struct Teddy *teddy) {
     return (const m256 *)((const u8 *)teddy + ROUNDUP_CL(sizeof(struct Teddy)));
 }
 
+#if defined(HAVE_AVX512)
+
+static really_inline
+const u64a *getReinforcedMaskBase_avx2(const struct Teddy *teddy, u8 numMask) {
+    return (const u64a *)((const u8 *)getMaskBase_avx2(teddy)
+                          + ROUNDUP_CL(2 * numMask * sizeof(m256)));
+}
+
+#ifdef ARCH_64_BIT
+#define CONFIRM_FAT_TEDDY(var, bucket, offset, reason, conf_fn)             \
+do {                                                                        \
+    if (unlikely(diff512(var, ones512()))) {                                \
+        m512 swap = swap256in512(var);                                      \
+        m512 r = interleave512lo(var, swap);                                \
+        m128 r0 = extract128from512(r, 0);                                  \
+        m128 r1 = extract128from512(r, 1);                                  \
+        u64a part1 = movq(r0);                                              \
+        u64a part2 = extract64from128(r0, 1);                               \
+        u64a part5 = movq(r1);                                              \
+        u64a part6 = extract64from128(r1, 1);                               \
+        r = interleave512hi(var, swap);                                     \
+        r0 = extract128from512(r, 0);                                       \
+        r1 = extract128from512(r, 1);                                       \
+        u64a part3 = movq(r0);                                              \
+        u64a part4 = extract64from128(r0, 1);                               \
+        u64a part7 = movq(r1);                                              \
+        u64a part8 = extract64from128(r1, 1);                               \
+        CONF_FAT_CHUNK_64(part1, bucket, offset, reason, conf_fn);          \
+        CONF_FAT_CHUNK_64(part2, bucket, offset + 4, reason, conf_fn);      \
+        CONF_FAT_CHUNK_64(part3, bucket, offset + 8, reason, conf_fn);      \
+        CONF_FAT_CHUNK_64(part4, bucket, offset + 12, reason, conf_fn);     \
+        CONF_FAT_CHUNK_64(part5, bucket, offset + 16, reason, conf_fn);     \
+        CONF_FAT_CHUNK_64(part6, bucket, offset + 20, reason, conf_fn);     \
+        CONF_FAT_CHUNK_64(part7, bucket, offset + 24, reason, conf_fn);     \
+        CONF_FAT_CHUNK_64(part8, bucket, offset + 28, reason, conf_fn);     \
+    }                                                                       \
+} while(0)
+#else
+#define CONFIRM_FAT_TEDDY(var, bucket, offset, reason, conf_fn)             \
+do {                                                                        \
+    if (unlikely(diff512(var, ones512()))) {                                \
+        m512 swap = swap256in512(var);                                      \
+        m512 r = interleave512lo(var, swap);                                \
+        m128 r0 = extract128from512(r, 0);                                  \
+        m128 r1 = extract128from512(r, 1);                                  \
+        u32 part1 = movd(r0);                                               \
+        u32 part2 = extract32from128(r0, 1);                                \
+        u32 part3 = extract32from128(r0, 2);                                \
+        u32 part4 = extract32from128(r0, 3);                                \
+        u32 part9 = movd(r1);                                               \
+        u32 part10 = extract32from128(r1, 1);                               \
+        u32 part11 = extract32from128(r1, 2);                               \
+        u32 part12 = extract32from128(r1, 3);                               \
+        r = interleave512hi(var, swap);                                     \
+        r0 = extract128from512(r, 0);                                       \
+        r1 = extract128from512(r, 1);                                       \
+        u32 part5 = movd(r0);                                               \
+        u32 part6 = extract32from128(r0, 1);                                \
+        u32 part7 = extract32from128(r0, 2);                                \
+        u32 part8 = extract32from128(r0, 3);                                \
+        u32 part13 = movd(r1);                                              \
+        u32 part14 = extract32from128(r1, 1);                               \
+        u32 part15 = extract32from128(r1, 2);                               \
+        u32 part16 = extract32from128(r1, 3);                               \
+        CONF_FAT_CHUNK_32(part1, bucket, offset, reason, conf_fn);          \
+        CONF_FAT_CHUNK_32(part2, bucket, offset + 2, reason, conf_fn);      \
+        CONF_FAT_CHUNK_32(part3, bucket, offset + 4, reason, conf_fn);      \
+        CONF_FAT_CHUNK_32(part4, bucket, offset + 6, reason, conf_fn);      \
+        CONF_FAT_CHUNK_32(part5, bucket, offset + 8, reason, conf_fn);      \
+        CONF_FAT_CHUNK_32(part6, bucket, offset + 10, reason, conf_fn);     \
+        CONF_FAT_CHUNK_32(part7, bucket, offset + 12, reason, conf_fn);     \
+        CONF_FAT_CHUNK_32(part8, bucket, offset + 14, reason, conf_fn);     \
+        CONF_FAT_CHUNK_32(part9, bucket, offset + 16, reason, conf_fn);     \
+        CONF_FAT_CHUNK_32(part10, bucket, offset + 18, reason, conf_fn);    \
+        CONF_FAT_CHUNK_32(part11, bucket, offset + 20, reason, conf_fn);    \
+        CONF_FAT_CHUNK_32(part12, bucket, offset + 22, reason, conf_fn);    \
+        CONF_FAT_CHUNK_32(part13, bucket, offset + 24, reason, conf_fn);    \
+        CONF_FAT_CHUNK_32(part14, bucket, offset + 26, reason, conf_fn);    \
+        CONF_FAT_CHUNK_32(part15, bucket, offset + 28, reason, conf_fn);    \
+        CONF_FAT_CHUNK_32(part16, bucket, offset + 30, reason, conf_fn);    \
+    }                                                                       \
+} while(0)
+#endif
+
+static really_inline
+m512 vectoredLoad2x256(m512 *p_mask, const u8 *ptr, const size_t start_offset,
+                       const u8 *lo, const u8 *hi,
+                       const u8 *buf_history, size_t len_history,
+                       const u32 nMasks) {
+    m256 p_mask256;
+    m512 ret = set2x256(vectoredLoad256(&p_mask256, ptr, start_offset, lo, hi,
+                                        buf_history, len_history, nMasks));
+    *p_mask = set2x256(p_mask256);
+    return ret;
+}
+
+#define PREP_FAT_SHUF_MASK_NO_REINFORCEMENT(val)                            \
+    m512 lo = and512(val, *lo_mask);                                        \
+    m512 hi = and512(rshift64_m512(val, 4), *lo_mask)
+
+#define PREP_FAT_SHUF_MASK                                                  \
+    PREP_FAT_SHUF_MASK_NO_REINFORCEMENT(set2x256(load256(ptr)));            \
+    *c_16 = *(ptr + 15);                                                    \
+    m512 r_msk = set512_64(0ULL, r_msk_base_hi[*c_16],                      \
+                           0ULL, r_msk_base_hi[*c_0],                       \
+                           0ULL, r_msk_base_lo[*c_16],                      \
+                           0ULL, r_msk_base_lo[*c_0]);                      \
+    *c_0 = *(ptr + 31)
+
+#define FAT_SHIFT_OR_M1                                                     \
+    or512(pshufb_m512(dup_mask[0], lo), pshufb_m512(dup_mask[1], hi))
+
+#define FAT_SHIFT_OR_M2                                                     \
+    or512(lshift128_m512(or512(pshufb_m512(dup_mask[2], lo),                \
+                               pshufb_m512(dup_mask[3], hi)),               \
+                         1), FAT_SHIFT_OR_M1)
+
+#define FAT_SHIFT_OR_M3                                                     \
+    or512(lshift128_m512(or512(pshufb_m512(dup_mask[4], lo),                \
+                               pshufb_m512(dup_mask[5], hi)),               \
+                         2), FAT_SHIFT_OR_M2)
+
+#define FAT_SHIFT_OR_M4                                                     \
+    or512(lshift128_m512(or512(pshufb_m512(dup_mask[6], lo),                \
+                               pshufb_m512(dup_mask[7], hi)),               \
+                         3), FAT_SHIFT_OR_M3)
+
+static really_inline
+m512 prep_conf_fat_teddy_no_reinforcement_m1(const m512 *lo_mask,
+                                             const m512 *dup_mask,
+                                             const m512 val) {
+    PREP_FAT_SHUF_MASK_NO_REINFORCEMENT(val);
+    return FAT_SHIFT_OR_M1;
+}
+
+static really_inline
+m512 prep_conf_fat_teddy_no_reinforcement_m2(const m512 *lo_mask,
+                                             const m512 *dup_mask,
+                                             const m512 val) {
+    PREP_FAT_SHUF_MASK_NO_REINFORCEMENT(val);
+    return FAT_SHIFT_OR_M2;
+}
+
+static really_inline
+m512 prep_conf_fat_teddy_no_reinforcement_m3(const m512 *lo_mask,
+                                             const m512 *dup_mask,
+                                             const m512 val) {
+    PREP_FAT_SHUF_MASK_NO_REINFORCEMENT(val);
+    return FAT_SHIFT_OR_M3;
+}
+
+static really_inline
+m512 prep_conf_fat_teddy_no_reinforcement_m4(const m512 *lo_mask,
+                                             const m512 *dup_mask,
+                                             const m512 val) {
+    PREP_FAT_SHUF_MASK_NO_REINFORCEMENT(val);
+    return FAT_SHIFT_OR_M4;
+}
+
+static really_inline
+m512 prep_conf_fat_teddy_m1(const m512 *lo_mask, const m512 *dup_mask,
+                            const u8 *ptr, const u64a *r_msk_base_lo,
+                            const u64a *r_msk_base_hi, u32 *c_0, u32 *c_16) {
+    PREP_FAT_SHUF_MASK;
+    return or512(FAT_SHIFT_OR_M1, r_msk);
+}
+
+static really_inline
+m512 prep_conf_fat_teddy_m2(const m512 *lo_mask, const m512 *dup_mask,
+                            const u8 *ptr, const u64a *r_msk_base_lo,
+                            const u64a *r_msk_base_hi, u32 *c_0, u32 *c_16) {
+    PREP_FAT_SHUF_MASK;
+    return or512(FAT_SHIFT_OR_M2, r_msk);
+}
+
+static really_inline
+m512 prep_conf_fat_teddy_m3(const m512 *lo_mask, const m512 *dup_mask,
+                            const u8 *ptr, const u64a *r_msk_base_lo,
+                            const u64a *r_msk_base_hi, u32 *c_0, u32 *c_16) {
+    PREP_FAT_SHUF_MASK;
+    return or512(FAT_SHIFT_OR_M3, r_msk);
+}
+
+static really_inline
+m512 prep_conf_fat_teddy_m4(const m512 *lo_mask, const m512 *dup_mask,
+                            const u8 *ptr, const u64a *r_msk_base_lo,
+                            const u64a *r_msk_base_hi, u32 *c_0, u32 *c_16) {
+    PREP_FAT_SHUF_MASK;
+    return or512(FAT_SHIFT_OR_M4, r_msk);
+}
+
+#define PREP_CONF_FAT_FN_NO_REINFORCEMENT(val, n)                             \
+    prep_conf_fat_teddy_no_reinforcement_m##n(&lo_mask, dup_mask, val)
+
+#define PREP_CONF_FAT_FN(ptr, n)                                              \
+    prep_conf_fat_teddy_m##n(&lo_mask, dup_mask, ptr,                         \
+                             r_msk_base_lo, r_msk_base_hi, &c_0, &c_16)
+
+#define DUP_FAT_MASK(a) mask_set2x256(set2x256(swap128in256(a)), 0xC3, a)
+
+#define PREPARE_FAT_MASKS_1                                                   \
+    dup_mask[0] = DUP_FAT_MASK(maskBase[0]);                                  \
+    dup_mask[1] = DUP_FAT_MASK(maskBase[1]);
+
+#define PREPARE_FAT_MASKS_2                                                   \
+    PREPARE_FAT_MASKS_1                                                       \
+    dup_mask[2] = DUP_FAT_MASK(maskBase[2]);                                  \
+    dup_mask[3] = DUP_FAT_MASK(maskBase[3]);
+
+#define PREPARE_FAT_MASKS_3                                                   \
+    PREPARE_FAT_MASKS_2                                                       \
+    dup_mask[4] = DUP_FAT_MASK(maskBase[4]);                                  \
+    dup_mask[5] = DUP_FAT_MASK(maskBase[5]);
+
+#define PREPARE_FAT_MASKS_4                                                   \
+    PREPARE_FAT_MASKS_3                                                       \
+    dup_mask[6] = DUP_FAT_MASK(maskBase[6]);                                  \
+    dup_mask[7] = DUP_FAT_MASK(maskBase[7]);
+
+#define PREPARE_FAT_MASKS(n)                                                  \
+    m512 lo_mask = set64x8(0xf);                                              \
+    m512 dup_mask[n * 2];                                                     \
+    PREPARE_FAT_MASKS_##n
+
+#define FDR_EXEC_FAT_TEDDY(fdr, a, control, n_msk, conf_fn)                   \
+do {                                                                          \
+    const u8 *buf_end = a->buf + a->len;                                      \
+    const u8 *ptr = a->buf + a->start_offset;                                 \
+    u32 floodBackoff = FLOOD_BACKOFF_START;                                   \
+    const u8 *tryFloodDetect = a->firstFloodDetect;                           \
+    u32 last_match = ones_u32;                                                \
+    const struct Teddy *teddy = (const struct Teddy *)fdr;                    \
+    const size_t iterBytes = 64;                                              \
+    DEBUG_PRINTF("params: buf %p len %zu start_offset %zu\n",                 \
+                 a->buf, a->len, a->start_offset);                            \
+                                                                              \
+    const m256 *maskBase = getMaskBase_avx2(teddy);                           \
+    PREPARE_FAT_MASKS(n_msk);                                                 \
+    const u32 *confBase = getConfBase(teddy);                                 \
+                                                                              \
+    const u64a *r_msk_base_lo = getReinforcedMaskBase_avx2(teddy, n_msk);     \
+    const u64a *r_msk_base_hi = r_msk_base_lo + (N_CHARS + 1);                \
+    u32 c_0 = 0x100;                                                          \
+    u32 c_16 = 0x100;                                                         \
+    const u8 *mainStart = ROUNDUP_PTR(ptr, 32);                               \
+    DEBUG_PRINTF("derive: ptr: %p mainstart %p\n", ptr, mainStart);           \
+    if (ptr < mainStart) {                                                    \
+        ptr = mainStart - 32;                                                 \
+        m512 p_mask;                                                          \
+        m512 val_0 = vectoredLoad2x256(&p_mask, ptr, a->start_offset,         \
+                                     a->buf, buf_end,                         \
+                                     a->buf_history, a->len_history, n_msk);  \
+        m512 r_0 = PREP_CONF_FAT_FN_NO_REINFORCEMENT(val_0, n_msk);           \
+        r_0 = or512(r_0, p_mask);                                             \
+        CONFIRM_FAT_TEDDY(r_0, 16, 0, VECTORING, conf_fn);                    \
+        ptr += 32;                                                            \
+    }                                                                         \
+                                                                              \
+    if (ptr + 32 <= buf_end) {                                                \
+        m512 r_0 = PREP_CONF_FAT_FN(ptr, n_msk);                              \
+        CONFIRM_FAT_TEDDY(r_0, 16, 0, VECTORING, conf_fn);                    \
+        ptr += 32;                                                            \
+    }                                                                         \
+                                                                              \
+    for (; ptr + iterBytes <= buf_end; ptr += iterBytes) {                    \
+        __builtin_prefetch(ptr + (iterBytes * 4));                            \
+        CHECK_FLOOD;                                                          \
+        m512 r_0 = PREP_CONF_FAT_FN(ptr, n_msk);                              \
+        CONFIRM_FAT_TEDDY(r_0, 16, 0, NOT_CAUTIOUS, conf_fn);                 \
+        m512 r_1 = PREP_CONF_FAT_FN(ptr + 32, n_msk);                         \
+        CONFIRM_FAT_TEDDY(r_1, 16, 32, NOT_CAUTIOUS, conf_fn);                \
+    }                                                                         \
+                                                                              \
+    if (ptr + 32 <= buf_end) {                                                \
+        m512 r_0 = PREP_CONF_FAT_FN(ptr, n_msk);                              \
+        CONFIRM_FAT_TEDDY(r_0, 16, 0, NOT_CAUTIOUS, conf_fn);                 \
+        ptr += 32;                                                            \
+    }                                                                         \
+                                                                              \
+    assert(ptr + 32 > buf_end);                                               \
+    if (ptr < buf_end) {                                                      \
+        m512 p_mask;                                                          \
+        m512 val_0 = vectoredLoad2x256(&p_mask, ptr, 0, ptr, buf_end,         \
+                                     a->buf_history, a->len_history, n_msk);  \
+        m512 r_0 = PREP_CONF_FAT_FN_NO_REINFORCEMENT(val_0, n_msk);           \
+        r_0 = or512(r_0, p_mask);                                             \
+        CONFIRM_FAT_TEDDY(r_0, 16, 0, VECTORING, conf_fn);                    \
+    }                                                                         \
+                                                                              \
+    return HWLM_SUCCESS;                                                      \
+} while(0)
+
+#else // HAVE_AVX512
+
 #ifdef ARCH_64_BIT
 #define CONFIRM_FAT_TEDDY(var, bucket, offset, reason, conf_fn)             \
 do {                                                                        \
@@ -282,7 +576,7 @@ do {                                                                        \
     const u8 *ptr = a->buf + a->start_offset;                               \
     u32 floodBackoff = FLOOD_BACKOFF_START;                                 \
     const u8 *tryFloodDetect = a->firstFloodDetect;                         \
-    u32 last_match = (u32)-1;                                               \
+    u32 last_match = ones_u32;                                              \
     const struct Teddy *teddy = (const struct Teddy *)fdr;                  \
     const size_t iterBytes = 32;                                            \
     DEBUG_PRINTF("params: buf %p len %zu start_offset %zu\n",               \
@@ -342,6 +636,8 @@ do {                                                                        \
     return HWLM_SUCCESS;                                                    \
 } while(0)
 
+#endif // HAVE_AVX512
+
 hwlm_error_t fdr_exec_teddy_avx2_msks1_fat(const struct FDR *fdr,
                                            const struct FDR_Runtime_Args *a,
                                            hwlm_group_t control) {
index 987361347edb6d110849429a16912227e8db71bb..8b8a64200bd250205dbb3da0610c1f978314d007 100644 (file)
@@ -325,44 +325,56 @@ bool pack(const vector<hwlmLiteral> &lits,
 #define REINFORCED_MSK_LEN 8
 
 static
-void initReinforcedTable(u8 *reinforcedMsk) {
-    u64a *mask = (u64a *)reinforcedMsk;
-    fill_n(mask, N_CHARS, 0x00ffffffffffffffULL);
+void initReinforcedTable(u8 *rmsk, const size_t rmsklen,
+                         const u32 maskWidth) {
+    for (u32 b = 0; b < maskWidth; b++) {
+        u64a *mask = (u64a *)(rmsk + b * (rmsklen / maskWidth));
+        fill_n(mask, N_CHARS, 0x00ffffffffffffffULL);
+    }
 }
 
 static
-void fillReinforcedMskZero(u8 *reinforcedMsk) {
-    u8 *mc = reinforcedMsk + NO_REINFORCEMENT * REINFORCED_MSK_LEN;
-    fill_n(mc, REINFORCED_MSK_LEN, 0x00);
+void fillReinforcedMskZero(u8 *rmsk, const size_t rmsklen,
+                           const u32 maskWidth) {
+    for (u32 b = 0; b < maskWidth; b++) {
+        u8 *mc = rmsk + b * (rmsklen / maskWidth) +
+                 NO_REINFORCEMENT * REINFORCED_MSK_LEN;
+        fill_n(mc, REINFORCED_MSK_LEN, 0x00);
+    }
 }
 
 static
-void fillReinforcedMsk(u8 *reinforcedMsk, u16 c, u32 j, u8 bmsk) {
+void fillReinforcedMsk(u8 *rmsk, u32 boff, u16 c, u32 j, u8 bmsk) {
     assert(j > 0);
     if (c == ALL_CHAR_SET) {
         for (size_t i = 0; i < N_CHARS; i++) {
-            u8 *mc = reinforcedMsk + i * REINFORCED_MSK_LEN;
+            u8 *mc = rmsk + boff + i * REINFORCED_MSK_LEN;
             mc[j - 1] &= ~bmsk;
         }
     } else {
-        u8 *mc = reinforcedMsk + c * REINFORCED_MSK_LEN;
+        u8 *mc = rmsk + boff + c * REINFORCED_MSK_LEN;
         mc[j - 1] &= ~bmsk;
     }
 }
 
 #ifdef TEDDY_DEBUG
 static
-void dumpReinforcedMaskTable(const u8 *msks) {
-    for (u32 i = 0; i <= N_CHARS; i++) {
-        printf("0x%02x: ", i);
-        for (u32 j = 0; j < REINFORCED_MSK_LEN; j++) {
-            u8 val = msks[i * REINFORCED_MSK_LEN + j];
-            for (u32 k = 0; k < 8; k++) {
-                printf("%s", ((val >> k) & 0x1) ? "1" : "0");
+void dumpReinforcedMaskTable(const u8 *rmsk, const size_t rmsklen,
+                             const u32 maskWidth) {
+    for (u32 b = 0; b < maskWidth; b++) {
+        printf("reinforcement table for bucket %u..%u:\n", b * 8, b * 8 + 7);
+        for (u32 i = 0; i <= N_CHARS; i++) {
+            printf("0x%02x: ", i);
+            for (u32 j = 0; j < REINFORCED_MSK_LEN; j++) {
+                u8 val = rmsk[b * (rmsklen / maskWidth) +
+                              i * REINFORCED_MSK_LEN + j];
+                for (u32 k = 0; k < 8; k++) {
+                    printf("%s", ((val >> k) & 0x1) ? "1" : "0");
+                }
+                printf(" ");
             }
-            printf(" ");
+            printf("\n");
         }
-        printf("\n");
     }
 }
 #endif
@@ -443,12 +455,13 @@ static
 void fillReinforcedTable(const map<BucketIndex,
                                    vector<LiteralIndex>> &bucketToLits,
                          const vector<hwlmLiteral> &lits,
-                         u8 *reinforcedMsk) {
-    initReinforcedTable(reinforcedMsk);
+                         u8 *rmsk, const size_t rmsklen, const u32 maskWidth) {
+    initReinforcedTable(rmsk, rmsklen, maskWidth);
 
     for (const auto &b2l : bucketToLits) {
         const u32 &bucket_id = b2l.first;
         const vector<LiteralIndex> &ids = b2l.second;
+        const u32 boff = (bucket_id / 8) * (rmsklen / maskWidth);
         const u8 bmsk = 1U << (bucket_id % 8);
 
         for (const LiteralIndex &lit_id : ids) {
@@ -459,23 +472,23 @@ void fillReinforcedTable(const map<BucketIndex,
             // fill in reinforced masks
             for (u32 j = 1; j < REINFORCED_MSK_LEN; j++) {
                 if (sz - 1 < j) {
-                    fillReinforcedMsk(reinforcedMsk, ALL_CHAR_SET, j, bmsk);
+                    fillReinforcedMsk(rmsk, boff, ALL_CHAR_SET, j, bmsk);
                 } else {
                     u8 c = l.s[sz - 1 - j];
                     if (l.nocase && ourisalpha(c)) {
                         u8 c_up = c & 0xdf;
-                        fillReinforcedMsk(reinforcedMsk, c_up, j, bmsk);
+                        fillReinforcedMsk(rmsk, boff, c_up, j, bmsk);
                         u8 c_lo = c | 0x20;
-                        fillReinforcedMsk(reinforcedMsk, c_lo, j, bmsk);
+                        fillReinforcedMsk(rmsk, boff, c_lo, j, bmsk);
                     } else {
-                        fillReinforcedMsk(reinforcedMsk, c, j, bmsk);
+                        fillReinforcedMsk(rmsk, boff, c, j, bmsk);
                     }
                 }
             }
         }
     }
 
-    fillReinforcedMskZero(reinforcedMsk);
+    fillReinforcedMskZero(rmsk, rmsklen, maskWidth);
 }
 
 bytecode_ptr<FDR> TeddyCompiler::build() {
@@ -483,7 +496,7 @@ bytecode_ptr<FDR> TeddyCompiler::build() {
 
     size_t headerSize = sizeof(Teddy);
     size_t maskLen = eng.numMasks * 16 * 2 * maskWidth;
-    size_t reinforcedMaskLen = (N_CHARS + 1) * REINFORCED_MSK_LEN;
+    size_t reinforcedMaskLen = (N_CHARS + 1) * REINFORCED_MSK_LEN * maskWidth;
 
     auto floodTable = setupFDRFloodControl(lits, eng, grey);
     auto confirmTable = setupFullConfs(lits, eng, bucketToLits, make_small);
@@ -525,7 +538,8 @@ bytecode_ptr<FDR> TeddyCompiler::build() {
 
     // Write reinforcement masks.
     u8 *reinforcedMsk = baseMsk + ROUNDUP_CL(maskLen);
-    fillReinforcedTable(bucketToLits, lits, reinforcedMsk);
+    fillReinforcedTable(bucketToLits, lits, reinforcedMsk,
+                        reinforcedMaskLen, maskWidth);
 
 #ifdef TEDDY_DEBUG
     for (u32 i = 0; i < eng.numMasks * 2; i++) {
@@ -541,7 +555,7 @@ bytecode_ptr<FDR> TeddyCompiler::build() {
 
     printf("\n===============================================\n"
            "reinforced mask table for low boundary (original)\n\n");
-    dumpReinforcedMaskTable(reinforcedMsk);
+    dumpReinforcedMaskTable(reinforcedMsk, reinforcedMaskLen, maskWidth);
 #endif
 
     return fdr;
index 8c469d1631cb6205e904ccd5f5c7537a32479991..c1449711b5ac92a13afa7af7bade9b0d021845c6 100644 (file)
@@ -169,16 +169,24 @@ m128 load_m128_from_u64a(const u64a *p) {
 #define rshiftbyte_m128(a, count_immed) _mm_srli_si128(a, count_immed)
 #define lshiftbyte_m128(a, count_immed) _mm_slli_si128(a, count_immed)
 
+#if defined(HAVE_SSE41)
+#define extract32from128(a, imm) _mm_extract_epi32(a, imm)
+#define extract64from128(a, imm) _mm_extract_epi64(a, imm)
+#else
+#define extract32from128(a, imm) movd(_mm_srli_si128(a, imm << 2))
+#define extract64from128(a, imm) movq(_mm_srli_si128(a, imm << 3))
+#endif
+
 #if !defined(HAVE_AVX2)
 // TODO: this entire file needs restructuring - this carveout is awful
 #define extractlow64from256(a) movq(a.lo)
 #define extractlow32from256(a) movd(a.lo)
 #if defined(HAVE_SSE41)
 #define extract32from256(a, imm) _mm_extract_epi32((imm >> 2) ? a.hi : a.lo, imm % 4)
-#define extract64from256(a, imm) _mm_extract_epi64((imm >> 2) ? a.hi : a.lo, imm % 2)
+#define extract64from256(a, imm) _mm_extract_epi64((imm >> 1) ? a.hi : a.lo, imm % 2)
 #else
-#define extract32from256(a, imm) movd(_mm_srli_si128((imm >> 2) ? a.hi : a.lo, (imm % 4) * 8))
-#define extract64from256(a, imm) movq(_mm_srli_si128((imm >> 2) ? a.hi : a.lo, (imm % 2) * 8))
+#define extract32from256(a, imm) movd(_mm_srli_si128((imm >> 2) ? a.hi : a.lo, (imm % 4) * 4))
+#define extract64from256(a, imm) movq(_mm_srli_si128((imm >> 1) ? a.hi : a.lo, (imm % 2) * 8))
 #endif
 
 #endif // !AVX2
@@ -741,8 +749,8 @@ m128 movdq_lo(m256 x) {
 #define extract32from256(a, imm) _mm_extract_epi32(_mm256_extracti128_si256(a, imm >> 2), imm % 4)
 #define extractlow64from256(a) _mm_cvtsi128_si64(cast256to128(a))
 #define extractlow32from256(a) movd(cast256to128(a))
-#define interleave256hi(a, b) _mm256_unpackhi_epi8(a, b);
-#define interleave256lo(a, b) _mm256_unpacklo_epi8(a, b);
+#define interleave256hi(a, b) _mm256_unpackhi_epi8(a, b)
+#define interleave256lo(a, b) _mm256_unpacklo_epi8(a, b)
 #define vpalignr(r, l, offset) _mm256_alignr_epi8(r, l, offset)
 
 static really_inline
@@ -757,6 +765,11 @@ m256 combine2x128(m128 hi, m128 lo) {
 
 #if defined(HAVE_AVX512)
 #define extract128from512(a, imm) _mm512_extracti32x4_epi32(a, imm)
+#define interleave512hi(a, b) _mm512_unpackhi_epi8(a, b)
+#define interleave512lo(a, b) _mm512_unpacklo_epi8(a, b)
+#define set2x256(a) _mm512_broadcast_i64x4(a)
+#define mask_set2x256(src, k, a) _mm512_mask_broadcast_i64x4(src, k, a)
+#define vpermq512(idx, a) _mm512_permutexvar_epi64(idx, a)
 #endif
 
 /****
@@ -980,6 +993,12 @@ m512 set512_64(u64a hi_3, u64a hi_2, u64a hi_1, u64a hi_0,
                             lo_3, lo_2, lo_1, lo_0);
 }
 
+static really_inline
+m512 swap256in512(m512 a) {
+    m512 idx = set512_64(3ULL, 2ULL, 1ULL, 0ULL, 7ULL, 6ULL, 5ULL, 4ULL);
+    return vpermq512(idx, a);
+}
+
 static really_inline
 m512 set4x128(m128 a) {
     return _mm512_broadcast_i32x4(a);