]> git.ipfire.org Git - thirdparty/vectorscan.git/commitdiff
Make key 64 bits where large shifts may be used.
authorJustin Viiret <justin.viiret@intel.com>
Fri, 27 Nov 2015 02:30:59 +0000 (13:30 +1100)
committerMatthew Barr <matthew.barr@intel.com>
Sun, 6 Dec 2015 22:38:32 +0000 (09:38 +1100)
This fixes a long-standing issue with large multibit structures.

src/util/multibit.c
src/util/multibit.h
unit/internal/multi_bit.cpp

index ca5e5656b3415be0358d38eaa3bbb5eeb5b4534d..c22b73ffeebc05bd8b46216f4d23c17104d3714e 100644 (file)
@@ -142,23 +142,25 @@ const u32 mmbit_root_offset_from_level[7] = {
 u32 mmbit_size(u32 total_bits) {
     MDEBUG_PRINTF("%u\n", total_bits);
 
-    // UE-2228: multibit has bugs in very, very large cases that we should be
-    // protected against at compile time by resource limits.
-    assert(total_bits <= 1U << 30);
-
     // Flat model multibit structures are just stored as a bit vector.
     if (total_bits <= MMB_FLAT_MAX_BITS) {
         return ROUNDUP_N(total_bits, 8) / 8;
     }
 
-    u32 current_level = 1;
-    u32 total = 0;
+    u64a current_level = 1; // Number of blocks on current level.
+    u64a total = 0;         // Total number of blocks.
     while (current_level * MMB_KEY_BITS < total_bits) {
         total += current_level;
         current_level <<= MMB_KEY_SHIFT;
     }
-    total += (total_bits + MMB_KEY_BITS - 1)/MMB_KEY_BITS;
-    return sizeof(MMB_TYPE) * total;
+
+    // Last level is a one-for-one bit vector. It needs room for total_bits
+    // elements, rounded up to the nearest block.
+    u64a last_level = ((u64a)total_bits + MMB_KEY_BITS - 1) / MMB_KEY_BITS;
+    total += last_level;
+
+    assert(total * sizeof(MMB_TYPE) <= UINT32_MAX);
+    return (u32)(total * sizeof(MMB_TYPE));
 }
 
 #ifdef DUMP_SUPPORT
index 251316b3cc6045d1585aba81fb7ad0a60b5afb94..771c158d2ddb8bca73a22630180c4af7efe0773d 100644 (file)
@@ -235,18 +235,18 @@ const u8 *mmbit_get_level_root_const(const u8 *bits, u32 level) {
 /** \brief get the block for this key on the current level as a u8 ptr */
 static really_inline
 u8 *mmbit_get_block_ptr(u8 *bits, u32 max_level, u32 level, u32 key) {
-    return mmbit_get_level_root(bits, level) +
-           (key >> (mmbit_get_ks(max_level, level) + MMB_KEY_SHIFT)) *
-               sizeof(MMB_TYPE);
+    u8 *level_root = mmbit_get_level_root(bits, level);
+    u32 ks = mmbit_get_ks(max_level, level);
+    return level_root + ((u64a)key >> (ks + MMB_KEY_SHIFT)) * sizeof(MMB_TYPE);
 }
 
 /** \brief get the block for this key on the current level as a const u8 ptr */
 static really_inline
 const u8 *mmbit_get_block_ptr_const(const u8 *bits, u32 max_level, u32 level,
                                     u32 key) {
-    return mmbit_get_level_root_const(bits, level) +
-           (key >> (mmbit_get_ks(max_level, level) + MMB_KEY_SHIFT)) *
-               sizeof(MMB_TYPE);
+    const u8 *level_root = mmbit_get_level_root_const(bits, level);
+    u32 ks = mmbit_get_ks(max_level, level);
+    return level_root + ((u64a)key >> (ks + MMB_KEY_SHIFT)) * sizeof(MMB_TYPE);
 }
 
 /** \brief get the _byte_ for this key on the current level as a u8 ptr */
@@ -254,7 +254,7 @@ static really_inline
 u8 *mmbit_get_byte_ptr(u8 *bits, u32 max_level, u32 level, u32 key) {
     u8 *level_root = mmbit_get_level_root(bits, level);
     u32 ks = mmbit_get_ks(max_level, level);
-    return level_root + (key >> (ks + MMB_KEY_SHIFT - 3));
+    return level_root + ((u64a)key >> (ks + MMB_KEY_SHIFT - 3));
 }
 
 /** \brief get our key value for the current level */
@@ -721,11 +721,11 @@ u32 mmbit_iterate_bounded_flat(const u8 *bits, u32 total_bits, u32 begin,
 }
 
 static really_inline
-MMB_TYPE get_lowhi_masks(u32 level, u32 max_level, u32 block_min, u32 block_max,
-                         u32 block_base) {
+MMB_TYPE get_lowhi_masks(u32 level, u32 max_level, u64a block_min, u64a block_max,
+                         u64a block_base) {
     const u32 level_shift = (max_level - level) * MMB_KEY_SHIFT;
-    u32 lshift = (block_min - block_base) >> level_shift;
-    u32 ushift = (block_max - block_base) >> level_shift;
+    u64a lshift = (block_min - block_base) >> level_shift;
+    u64a ushift = (block_max - block_base) >> level_shift;
     MMB_TYPE lmask = lshift < 64 ? ~mmb_mask_zero_to_nocheck(lshift) : 0;
     MMB_TYPE umask =
         ushift < 63 ? mmb_mask_zero_to_nocheck(ushift + 1) : MMB_ALL_ONES;
@@ -734,7 +734,7 @@ MMB_TYPE get_lowhi_masks(u32 level, u32 max_level, u32 block_min, u32 block_max,
 
 static really_inline
 u32 mmbit_iterate_bounded_big(const u8 *bits, u32 total_bits, u32 it_start, u32 it_end) {
-    u32 key = 0;
+    u64a key = 0;
     u32 ks = mmbit_keyshift(total_bits);
     const u32 max_level = mmbit_maxlevel_from_keyshift(ks);
     u32 level = 0;
@@ -743,9 +743,9 @@ u32 mmbit_iterate_bounded_big(const u8 *bits, u32 total_bits, u32 it_start, u32
         assert(level <= max_level);
 
         u32 block_width = MMB_KEY_BITS << ks;
-        u32 block_base = key*block_width;
-        u32 block_min = MAX(it_start, block_base);
-        u32 block_max = MIN(it_end, block_base + block_width - 1);
+        u64a block_base = key * block_width;
+        u64a block_min = MAX(it_start, block_base);
+        u64a block_max = MIN(it_end, block_base + block_width - 1);
         const u8 *block_ptr =
             mmbit_get_level_root_const(bits, level) + key * sizeof(MMB_TYPE);
         MMB_TYPE block = mmb_load(block_ptr);
@@ -761,13 +761,14 @@ u32 mmbit_iterate_bounded_big(const u8 *bits, u32 total_bits, u32 it_start, u32
             // No bit found, go up a level
             // we know that this block didn't have any answers, so we can push
             // our start iterator forward.
-            it_start = block_base + block_width;
-            if (it_start > it_end) {
+            u64a next_start = block_base + block_width;
+            if (next_start > it_end) {
                 break;
             }
             if (level-- == 0) {
                 break;
             }
+            it_start = next_start;
             key >>= MMB_KEY_SHIFT;
             ks += MMB_KEY_SHIFT;
         }
index 2772521960eab7e08c0069257a0f8636bc6dd486..3f5c590856265329e475522c750191ecf5600cb1 100644 (file)
@@ -363,7 +363,9 @@ TEST_P(MultiBitTest, BoundedIteratorSingle) {
     ASSERT_TRUE(ba != nullptr);
 
     // Set one bit on and run some checks.
-    for (u32 i = 0; i < test_size; i += stride) {
+    for (u64a i = 0; i < test_size; i += stride) {
+        SCOPED_TRACE(i);
+
         mmbit_clear(ba, test_size);
         mmbit_set(ba, test_size, i);
 
@@ -381,7 +383,12 @@ TEST_P(MultiBitTest, BoundedIteratorSingle) {
 
         // Scanning from one past our bit to the end should find nothing.
         if (i != test_size - 1) {
-            ASSERT_EQ(MMB_INVALID, mmbit_iterate_bounded(ba, test_size, i + 1, test_size));
+            // Ordinary iterator.
+            ASSERT_EQ(MMB_INVALID, mmbit_iterate(ba, test_size, i));
+
+            // Bounded iterator.
+            ASSERT_EQ(MMB_INVALID,
+                      mmbit_iterate_bounded(ba, test_size, i + 1, test_size));
         }
     }
 }
@@ -393,7 +400,7 @@ TEST_P(MultiBitTest, BoundedIteratorAll) {
     // Switch everything on.
     fill_mmbit(ba, test_size);
 
-    for (u32 i = 0; i < test_size; i += stride) {
+    for (u64a i = 0; i < test_size; i += stride) {
         if (i != 0) {
             ASSERT_EQ(0U, mmbit_iterate_bounded(ba, test_size, 0, i));
         }
@@ -408,13 +415,13 @@ TEST_P(MultiBitTest, BoundedIteratorEven) {
 
     // Set every even-numbered bit and see what we can see.
     mmbit_clear(ba, test_size);
-    for (u32 i = 0; i < test_size; i += 2) {
+    for (u64a i = 0; i < test_size; i += 2) {
         mmbit_set(ba, test_size, i);
     }
 
     u32 even_stride = stride % 2 ? stride + 1 : stride;
 
-    for (u32 i = 0; i < test_size; i += even_stride) {
+    for (u64a i = 0; i < test_size; i += even_stride) {
         // Scanning from each even bit to the end should find itself.
         ASSERT_EQ(i, mmbit_iterate_bounded(ba, test_size, i, test_size));
 
@@ -439,13 +446,13 @@ TEST_P(MultiBitTest, BoundedIteratorOdd) {
 
     // Set every odd-numbered bit and see what we can see.
     mmbit_clear(ba, test_size);
-    for (u32 i = 1; i < test_size; i += 2) {
+    for (u64a i = 1; i < test_size; i += 2) {
         mmbit_set(ba, test_size, i);
     }
 
     u32 even_stride = stride % 2 ? stride + 1 : stride;
 
-    for (u32 i = 0; i < test_size; i += even_stride) {
+    for (u64a i = 0; i < test_size; i += even_stride) {
         // Scanning from each even bit to the end should find i+1.
         if (i+1 < test_size) {
             ASSERT_EQ(i+1, mmbit_iterate_bounded(ba, test_size, i, test_size));
@@ -473,7 +480,7 @@ TEST_P(MultiBitTest, Set) {
     mmbit_clear(ba, test_size);
     ASSERT_FALSE(mmbit_any(ba, test_size));
 
-    for (u32 i = 0; i < test_size; i += stride) {
+    for (u64a i = 0; i < test_size; i += stride) {
         SCOPED_TRACE(i);
 
         // set a bit that wasn't set before
@@ -500,7 +507,7 @@ TEST_P(MultiBitTest, Iter) {
     mmbit_clear(ba, test_size);
     ASSERT_EQ(MMB_INVALID, mmbit_iterate(ba, test_size, MMB_INVALID));
 
-    for (u32 i = 0; i < test_size; i += stride) {
+    for (u64a i = 0; i < test_size; i += stride) {
         SCOPED_TRACE(i);
         mmbit_clear(ba, test_size);
         mmbit_set(ba, test_size, i);
@@ -517,13 +524,13 @@ TEST_P(MultiBitTest, IterAll) {
     ASSERT_EQ(MMB_INVALID, mmbit_iterate(ba, test_size, MMB_INVALID));
 
     // Set all bits.
-    for (u32 i = 0; i < test_size; i += stride) {
+    for (u64a i = 0; i < test_size; i += stride) {
         mmbit_set(ba, test_size, i);
     }
 
     // Find all bits.
     u32 it = MMB_INVALID;
-    for (u32 i = 0; i < test_size; i += stride) {
+    for (u64a i = 0; i < test_size; i += stride) {
         ASSERT_EQ(i, mmbit_iterate(ba, test_size, it));
         it = i;
     }
@@ -536,7 +543,7 @@ TEST_P(MultiBitTest, AnyPrecise) {
     mmbit_clear(ba, test_size);
     ASSERT_FALSE(mmbit_any_precise(ba, test_size));
 
-    for (u32 i = 0; i < test_size; i += stride) {
+    for (u64a i = 0; i < test_size; i += stride) {
         SCOPED_TRACE(i);
         mmbit_clear(ba, test_size);
         mmbit_set(ba, test_size, i);
@@ -551,7 +558,7 @@ TEST_P(MultiBitTest, Any) {
     mmbit_clear(ba, test_size);
     ASSERT_FALSE(mmbit_any(ba, test_size));
 
-    for (u32 i = 0; i < test_size; i += stride) {
+    for (u64a i = 0; i < test_size; i += stride) {
         SCOPED_TRACE(i);
         mmbit_clear(ba, test_size);
         mmbit_set(ba, test_size, i);
@@ -567,7 +574,7 @@ TEST_P(MultiBitTest, UnsetRange1) {
     fill_mmbit(ba, test_size);
 
     // Use mmbit_unset_range to switch off any single bit.
-    for (u32 i = 0; i < test_size; i += stride) {
+    for (u64a i = 0; i < test_size; i += stride) {
         SCOPED_TRACE(i);
         ASSERT_TRUE(mmbit_isset(ba, test_size, i));
         mmbit_unset_range(ba, test_size, i, i + 1);
@@ -590,7 +597,7 @@ TEST_P(MultiBitTest, UnsetRange2) {
     // Use mmbit_unset_range to switch off all bits.
     mmbit_unset_range(ba, test_size, 0, test_size);
 
-    for (u32 i = 0; i < test_size; i += stride) {
+    for (u64a i = 0; i < test_size; i += stride) {
         SCOPED_TRACE(i);
         ASSERT_FALSE(mmbit_isset(ba, test_size, i));
     }
@@ -601,12 +608,12 @@ TEST_P(MultiBitTest, UnsetRange3) {
     ASSERT_TRUE(ba != nullptr);
 
     // Use mmbit_unset_range to switch off bits in chunks of 3.
-    for (u32 i = 0; i < test_size - 3; i += stride) {
+    for (u64a i = 0; i < test_size - 3; i += stride) {
         // Switch on the bit before, the bits in question, and the bit after.
         if (i > 0) {
             mmbit_set(ba, test_size, i - 1);
         }
-        for (u32 j = i; j < min(i + 4, test_size); j++) {
+        for (u64a j = i; j < min(i + 4, (u64a)test_size); j++) {
             mmbit_set(ba, test_size, j);
         }
 
@@ -635,7 +642,7 @@ TEST_P(MultiBitTest, InitRangeAll) {
     mmbit_init_range(ba, test_size, 0, test_size);
 
     // Make sure they're all set.
-    for (u32 i = 0; i < test_size; i += stride) {
+    for (u64a i = 0; i < test_size; i += stride) {
         SCOPED_TRACE(i);
         ASSERT_TRUE(mmbit_isset(ba, test_size, i));
     }
@@ -656,7 +663,7 @@ TEST_P(MultiBitTest, InitRangeOne) {
     SCOPED_TRACE(test_size);
     ASSERT_TRUE(ba != nullptr);
 
-    for (u32 i = 0; i < test_size; i += stride) {
+    for (u64a i = 0; i < test_size; i += stride) {
         mmbit_init_range(ba, test_size, i, i + 1);
 
         // Only bit 'i' should be on.
@@ -685,7 +692,7 @@ TEST_P(MultiBitTest, InitRangeChunked) {
             ASSERT_EQ(chunk_begin, mmbit_iterate(ba, test_size, MMB_INVALID));
 
             // All bits in the chunk should be on.
-            for (u32 i = chunk_begin; i < chunk_end; i += stride) {
+            for (u64a i = chunk_begin; i < chunk_end; i += stride) {
                 SCOPED_TRACE(i);
                 ASSERT_TRUE(mmbit_isset(ba, test_size, i));
             }
@@ -985,7 +992,7 @@ TEST_P(MultiBitTest, SparseIteratorBeginAll) {
     vector<mmbit_sparse_iter> it;
     vector<u32> bits;
     bits.reserve(test_size / stride);
-    for (u32 i = 0; i < test_size; i += stride) {
+    for (u64a i = 0; i < test_size; i += stride) {
         bits.push_back(i);
     }
     mmbBuildSparseIterator(it, bits, test_size);
@@ -1032,7 +1039,7 @@ TEST_P(MultiBitTest, SparseIteratorBeginThirds) {
     // Switch every third bits on in state
     mmbit_clear(ba, test_size);
     ASSERT_FALSE(mmbit_any(ba, test_size));
-    for (u32 i = 0; i < test_size; i += 3) {
+    for (u64a i = 0; i < test_size; i += 3) {
         mmbit_set(ba, test_size, i);
     }
 
@@ -1044,7 +1051,7 @@ TEST_P(MultiBitTest, SparseIteratorBeginThirds) {
     ASSERT_EQ(0U, val);
     ASSERT_EQ(0U, idx);
 
-    for (u32 i = 0; i < test_size - 3; i += 3) {
+    for (u64a i = 0; i < test_size - 3; i += 3) {
         mmbit_unset(ba, test_size, i);
         val = mmbit_sparse_iter_begin(ba, test_size, &idx, &it[0], &state[0]);
         ASSERT_EQ(i+3, val);
@@ -1060,7 +1067,7 @@ TEST_P(MultiBitTest, SparseIteratorNextAll) {
     vector<mmbit_sparse_iter> it;
     vector<u32> bits;
     bits.reserve(test_size / stride);
-    for (u32 i = 0; i < test_size; i += stride) {
+    for (u64a i = 0; i < test_size; i += stride) {
         bits.push_back(i);
     }
     mmbBuildSparseIterator(it, bits, test_size);
@@ -1103,7 +1110,7 @@ TEST_P(MultiBitTest, SparseIteratorNextExactStrided) {
     vector<mmbit_sparse_iter> it;
     vector<u32> bits;
     bits.reserve(test_size / stride);
-    for (u32 i = 0; i < test_size; i += stride) {
+    for (u64a i = 0; i < test_size; i += stride) {
         bits.push_back(i);
         mmbit_set(ba, test_size, i);
     }
@@ -1135,7 +1142,7 @@ TEST_P(MultiBitTest, SparseIteratorNextNone) {
     vector<mmbit_sparse_iter> it;
     vector<u32> bits;
     bits.reserve(test_size / stride);
-    for (u32 i = 0; i < test_size; i += stride) {
+    for (u64a i = 0; i < test_size; i += stride) {
         bits.push_back(i);
     }
     mmbBuildSparseIterator(it, bits, test_size);
@@ -1164,7 +1171,7 @@ TEST_P(MultiBitTest, SparseIteratorUnsetAll) {
     vector<mmbit_sparse_iter> it;
     vector<u32> bits;
     bits.reserve(test_size / stride);
-    for (u32 i = 0; i < test_size; i += stride) {
+    for (u64a i = 0; i < test_size; i += stride) {
         bits.push_back(i);
     }
     mmbBuildSparseIterator(it, bits, test_size);
@@ -1194,10 +1201,10 @@ TEST_P(MultiBitTest, SparseIteratorUnsetHalves) {
 
     // Two sparse iterators: one for even bits, one for odd ones
     vector<u32> even, odd;
-    for (u32 i = 0; i < test_size; i += 2) {
+    for (u64a i = 0; i < test_size; i += 2) {
         even.push_back(i);
     }
-    for (u32 i = 1; i < test_size; i += 2) {
+    for (u64a i = 1; i < test_size; i += 2) {
         odd.push_back(i);
     }
 
@@ -1277,9 +1284,9 @@ static const MultiBitTestParam multibitTests[] = {
     { 1U << 28, 15073 },
     { 1U << 29, 24413 },
     { 1U << 30, 50377 },
+    { 1U << 31, 104729 },
 
-    // XXX: cases this large segfault in mmbit_set, FIXME NOW
-    //{ 1U << 31, 3701 },
+    // { UINT32_MAX, 104729 }, // Very slow
 };
 
 INSTANTIATE_TEST_CASE_P(MultiBit, MultiBitTest, ValuesIn(multibitTests));