]> git.ipfire.org Git - thirdparty/vectorscan.git/commitdiff
smallwrite: aho-corasick construction for literals
authorJustin Viiret <justin.viiret@intel.com>
Fri, 31 Mar 2017 03:04:44 +0000 (14:04 +1100)
committerMatthew Barr <matthew.barr@intel.com>
Wed, 26 Apr 2017 05:19:51 +0000 (15:19 +1000)
src/smallwrite/smallwrite_build.cpp
src/smallwrite/smallwrite_build.h

index f7c9ad8c4aefcc79afe8a91918156fb21c600a0f..a27db736875fc30d36d654c30df8236857e33c4a 100644 (file)
  * POSSIBILITY OF SUCH DAMAGE.
  */
 
+/**
+ * \file
+ * \brief Small-write engine build code.
+ */
+
 #include "smallwrite/smallwrite_build.h"
 
 #include "grey.h"
@@ -48,6 +53,7 @@
 #include "util/alloc.h"
 #include "util/bytecode_ptr.h"
 #include "util/charreach.h"
+#include "util/compare.h"
 #include "util/compile_context.h"
 #include "util/container.h"
 #include "util/make_unique.h"
 #include <vector>
 #include <utility>
 
+#include <boost/graph/breadth_first_search.hpp>
+
 using namespace std;
 
 namespace ue2 {
 
-#define LITERAL_MERGE_CHUNK_SIZE 25
 #define DFA_MERGE_MAX_STATES 8000
 #define MAX_TRIE_VERTICES 8000
 
-namespace { // unnamed
-
 struct LitTrieVertexProps {
     LitTrieVertexProps() = default;
-    explicit LitTrieVertexProps(char c_in) : c(c_in) {}
-    char c = 0;
+    explicit LitTrieVertexProps(u8 c_in) : c(c_in) {}
     size_t index; // managed by ue2_graph
+    u8 c = 0; //!< character reached on this vertex
+    flat_set<ReportID> reports; //!< managed reports fired on this vertex
 };
 
 struct LitTrieEdgeProps {
-    LitTrieEdgeProps() = default;
     size_t index; // managed by ue2_graph
 };
 
+/**
+ * \brief BGL graph used to store a trie of literals (for later AC construction
+ * into a DFA).
+ */
 struct LitTrie
     : public ue2_graph<LitTrie, LitTrieVertexProps, LitTrieEdgeProps> {
 
     LitTrie() : root(add_vertex(*this)) {}
 
-    const vertex_descriptor root;
+    const vertex_descriptor root; //!< Root vertex for the trie.
 };
 
+static
+bool is_empty(const LitTrie &trie) {
+    return num_vertices(trie) <= 1;
+}
+
+static
+std::set<ReportID> all_reports(const LitTrie &trie) {
+    std::set<ReportID> reports;
+    for (auto v : vertices_range(trie)) {
+        insert(&reports, trie[v].reports);
+    }
+    return reports;
+}
+
+using LitTrieVertex = LitTrie::vertex_descriptor;
+using LitTrieEdge = LitTrie::edge_descriptor;
+
+namespace { // unnamed
+
 // Concrete impl class
 class SmallWriteBuildImpl : public SmallWriteBuild {
 public:
@@ -110,15 +138,15 @@ public:
     const CompileContext &cc;
 
     unique_ptr<raw_dfa> rdfa;
-    vector<pair<ue2_literal, ReportID> > cand_literals;
     LitTrie lit_trie;
     LitTrie lit_trie_nocase;
+    size_t num_literals = 0;
     bool poisoned;
 };
 
 } // namespace
 
-SmallWriteBuild::~SmallWriteBuild() { }
+SmallWriteBuild::~SmallWriteBuild() = default;
 
 SmallWriteBuildImpl::SmallWriteBuildImpl(size_t num_patterns,
                                          const ReportManager &rm_in,
@@ -272,25 +300,27 @@ void SmallWriteBuildImpl::add(const NGHolder &g, const ExpressionInfo &expr) {
 }
 
 static
-bool add_to_trie(const ue2_literal &literal, LitTrie &trie) {
+bool add_to_trie(const ue2_literal &literal, ReportID report, LitTrie &trie) {
     auto u = trie.root;
-    for (auto &c : literal) {
+    for (const auto &c : literal) {
         auto next = LitTrie::null_vertex();
         for (auto v : adjacent_vertices_range(u, trie)) {
-            if (trie[v].c == c.c) {
+            if (trie[v].c == (u8)c.c) {
                 next = v;
                 break;
             }
         }
-        if (next == LitTrie::null_vertex()) {
-            next = add_vertex(LitTrieVertexProps(c.c), trie);
+        if (!next) {
+            next = add_vertex(LitTrieVertexProps((u8)c.c), trie);
             add_edge(u, next, trie);
         }
         u = next;
     }
 
-    DEBUG_PRINTF("added '%s' to trie, now %zu vertices\n",
-                  escapeString(literal).c_str(), num_vertices(trie));
+    trie[u].reports.insert(report);
+
+    DEBUG_PRINTF("added '%s' (report %u) to trie, now %zu vertices\n",
+                  escapeString(literal).c_str(), report, num_vertices(trie));
     return num_vertices(trie) <= MAX_TRIE_VERTICES;
 }
 
@@ -298,103 +328,308 @@ void SmallWriteBuildImpl::add(const ue2_literal &literal, ReportID r) {
     // If the graph is poisoned (i.e. we can't build a SmallWrite version),
     // we don't even try.
     if (poisoned) {
+        DEBUG_PRINTF("poisoned\n");
         return;
     }
 
     if (literal.length() > cc.grey.smallWriteLargestBuffer) {
+        DEBUG_PRINTF("exceeded length limit\n");
         return; /* too long */
     }
 
-    cand_literals.push_back(make_pair(literal, r));
-
-    if (!add_to_trie(literal,
-                     literal.any_nocase() ? lit_trie_nocase : lit_trie)) {
+    if (++num_literals > cc.grey.smallWriteMaxLiterals) {
+        DEBUG_PRINTF("exceeded literal limit\n");
         poisoned = true;
         return;
     }
 
-    if (cand_literals.size() > cc.grey.smallWriteMaxLiterals) {
+    auto &trie = literal.any_nocase() ? lit_trie_nocase : lit_trie;
+    if (!add_to_trie(literal, r, trie)) {
+        DEBUG_PRINTF("trie add failed\n");
         poisoned = true;
     }
 }
 
-static
-void lit_to_graph(NGHolder *h, const ue2_literal &literal, ReportID r) {
-    NFAVertex u = h->startDs;
-    for (const auto &c : literal) {
-        NFAVertex v = add_vertex(*h);
-        add_edge(u, v, *h);
-        (*h)[v].char_reach = c;
-        u = v;
+namespace {
+
+/**
+ * \brief BFS visitor for Aho-Corasick automaton construction.
+ *
+ * This is doing two things:
+ *
+ *   - Computing the failure edges (also called fall or supply edges) for each
+ *     vertex, giving the longest suffix of the path to that point that is also
+ *     a prefix in the trie reached on the same character. The BFS traversal
+ *     makes it possible to build these from earlier failure paths.
+ *
+ *   - Computing the output function for each vertex, which is done by
+ *     propagating the reports from failure paths as well. This ensures that
+ *     substrings of the current path also report correctly.
+ */
+struct ACVisitor : public boost::default_bfs_visitor {
+    ACVisitor(LitTrie &trie_in,
+              map<LitTrieVertex, LitTrieVertex> &failure_map_in,
+              vector<LitTrieVertex> &ordering_in)
+        : mutable_trie(trie_in), failure_map(failure_map_in),
+          ordering(ordering_in) {}
+
+    LitTrieVertex find_failure_target(LitTrieVertex u, LitTrieVertex v,
+                                      const LitTrie &trie) {
+        assert(u == trie.root || contains(failure_map, u));
+        assert(!contains(failure_map, v));
+
+        const auto &c = trie[v].c;
+
+        while (u != trie.root) {
+            auto f = failure_map.at(u);
+            for (auto w : adjacent_vertices_range(f, trie)) {
+                if (trie[w].c == c) {
+                    return w;
+                }
+            }
+            u = f;
+        }
+
+        DEBUG_PRINTF("no failure edge\n");
+        return LitTrie::null_vertex();
+    }
+
+    void tree_edge(LitTrieEdge e, const LitTrie &trie) {
+        auto u = source(e, trie);
+        auto v = target(e, trie);
+        DEBUG_PRINTF("bfs (%zu, %zu) on '%c'\n", trie[u].index, trie[v].index,
+                     trie[v].c);
+        ordering.push_back(v);
+
+        auto f = find_failure_target(u, v, trie);
+
+        if (f) {
+            DEBUG_PRINTF("final failure vertex %zu\n", trie[f].index);
+            failure_map.emplace(v, f);
+
+            // Propagate reports from failure path to ensure we correctly
+            // report substrings.
+            insert(&mutable_trie[v].reports, mutable_trie[f].reports);
+        } else {
+            DEBUG_PRINTF("final failure vertex root\n");
+            failure_map.emplace(v, trie.root);
+        }
     }
-    (*h)[u].reports.insert(r);
-    add_edge(u, h->accept, *h);
+
+private:
+    LitTrie &mutable_trie; //!< For setting reports property.
+    map<LitTrieVertex, LitTrieVertex> &failure_map;
+    vector<LitTrieVertex> &ordering; //!< BFS ordering for vertices.
+};
 }
 
-bool SmallWriteBuildImpl::determiniseLiterals() {
-    DEBUG_PRINTF("handling literals\n");
-    assert(!poisoned);
-    assert(cand_literals.size() <= cc.grey.smallWriteMaxLiterals);
+static UNUSED
+bool isSaneTrie(const LitTrie &trie) {
+    CharReach seen;
+    for (auto u : vertices_range(trie)) {
+        seen.clear();
+        for (auto v : adjacent_vertices_range(u, trie)) {
+            if (seen.test(trie[v].c)) {
+                return false;
+            }
+            seen.set(trie[v].c);
+        }
+    }
+    return true;
+}
 
-    if (cand_literals.empty()) {
-        return true; /* nothing to do */
+/**
+ * \brief Turn the given literal trie into an AC automaton by adding additional
+ * edges and reports.
+ */
+static
+void buildAutomaton(LitTrie &trie) {
+    assert(isSaneTrie(trie));
+
+    // Find our failure transitions and reports.
+    map<LitTrieVertex, LitTrieVertex> failure_map;
+    vector<LitTrieVertex> ordering;
+    ACVisitor ac_vis(trie, failure_map, ordering);
+    boost::breadth_first_search(trie, trie.root, visitor(ac_vis));
+
+    // Compute missing edges from failure map.
+    for (auto v : ordering) {
+        DEBUG_PRINTF("vertex %zu\n", trie[v].index);
+        CharReach seen;
+        for (auto w : adjacent_vertices_range(v, trie)) {
+            DEBUG_PRINTF("edge to %zu with reach 0x%02x\n", trie[w].index,
+                         trie[w].c);
+            assert(!seen.test(trie[w].c));
+            seen.set(trie[w].c);
+        }
+        auto parent = failure_map.at(v);
+        for (auto w : adjacent_vertices_range(parent, trie)) {
+            if (!seen.test(trie[w].c)) {
+                add_edge(v, w, trie);
+            }
+        }
     }
+}
+
+static
+vector<CharReach> getAlphabet(const LitTrie &trie, bool nocase) {
+    vector<CharReach> esets = {CharReach::dot()};
+    for (auto v : vertices_range(trie)) {
+        if (v == trie.root) {
+            continue;
+        }
 
-    vector<unique_ptr<raw_dfa> > temp_dfas;
+        CharReach cr;
+        if (nocase) {
+            cr.set(mytoupper(trie[v].c));
+            cr.set(mytolower(trie[v].c));
+        } else {
+            cr.set(trie[v].c);
+        }
 
-    for (const auto &cand : cand_literals) {
-        NGHolder h;
-        DEBUG_PRINTF("determinising %s\n", dumpString(cand.first).c_str());
-        lit_to_graph(&h, cand.first, cand.second);
-        temp_dfas.push_back(buildMcClellan(h, &rm, cc.grey));
+        for (size_t i = 0; i < esets.size(); i++) {
+            if (esets[i].count() == 1) {
+                continue;
+            }
 
-        // If we couldn't build a McClellan DFA for this portion, then we
-        // can't SmallWrite optimize the entire graph, so we can't
-        // optimize any of it
-        if (!temp_dfas.back()) {
-            DEBUG_PRINTF("failed to determinise\n");
-            poisoned = true;
-            return false;
+            CharReach t = cr & esets[i];
+            if (t.any() && t != esets[i]) {
+                esets[i] &= ~t;
+                esets.push_back(t);
+            }
         }
     }
 
-    if (!rdfa && temp_dfas.size() == 1) {
-        /* no need to merge there is only one dfa */
-        rdfa = move(temp_dfas[0]);
-        return true;
+    // For deterministic compiles.
+    sort(esets.begin(), esets.end());
+    return esets;
+}
+
+static
+u16 buildAlphabet(const LitTrie &trie, bool nocase,
+                  array<u16, ALPHABET_SIZE> &alpha,
+                  array<u16, ALPHABET_SIZE> &unalpha) {
+    const auto &esets = getAlphabet(trie, nocase);
+
+    u16 i = 0;
+    for (const auto &cr : esets) {
+        u16 leader = cr.find_first();
+        for (size_t s = cr.find_first(); s != cr.npos; s = cr.find_next(s)) {
+            alpha[s] = i;
+        }
+        unalpha[i] = leader;
+        i++;
     }
 
-    /* do a merge of the new dfas */
+    for (u16 j = N_CHARS; j < ALPHABET_SIZE; j++, i++) {
+        alpha[j] = i;
+        unalpha[i] = j;
+    }
 
-    vector<const raw_dfa *> to_merge;
+    DEBUG_PRINTF("alphabet size %u\n", i);
+    return i;
+}
+
+/** \brief Construct a raw_dfa from a literal trie. */
+static
+unique_ptr<raw_dfa> buildDfa(LitTrie &trie, bool nocase) {
+    DEBUG_PRINTF("trie has %zu states\n", num_vertices(trie));
+
+    buildAutomaton(trie);
+
+    auto rdfa = make_unique<raw_dfa>(NFA_OUTFIX);
+
+    // Calculate alphabet.
+    array<u16, ALPHABET_SIZE> unalpha;
+    auto &alpha = rdfa->alpha_remap;
+    rdfa->alpha_size = buildAlphabet(trie, nocase, alpha, unalpha);
+
+    // Construct states and transitions.
+    const u16 root_state = DEAD_STATE + 1;
+    rdfa->start_anchored = root_state;
+    rdfa->start_floating = root_state;
+    rdfa->states.resize(num_vertices(trie) + 1, dstate(rdfa->alpha_size));
+
+    // Dead state.
+    fill(rdfa->states[DEAD_STATE].next.begin(),
+         rdfa->states[DEAD_STATE].next.end(), DEAD_STATE);
+
+    for (auto u : vertices_range(trie)) {
+        auto u_state = trie[u].index + 1;
+        DEBUG_PRINTF("state %zu\n", u_state);
+        assert(u_state < rdfa->states.size());
+        auto &ds = rdfa->states[u_state];
+        ds.daddy = root_state;
+        ds.reports = trie[u].reports;
+
+        if (!ds.reports.empty()) {
+            DEBUG_PRINTF("reports: %s\n", as_string_list(ds.reports).c_str());
+        }
+
+        // By default, transition back to the root.
+        fill(ds.next.begin(), ds.next.end(), root_state);
+        // TOP should be a self-loop.
+        ds.next[alpha[TOP]] = u_state;
 
-    if (rdfa) {/* also include the existing dfa */
-        to_merge.push_back(rdfa.get());
+        // Add in the real transitions.
+        for (auto v : adjacent_vertices_range(u, trie)) {
+            if (v == trie.root) {
+                continue;
+            }
+            auto v_state = trie[v].index + 1;
+            assert((u16)trie[v].c < alpha.size());
+            u16 sym = alpha[trie[v].c];
+            DEBUG_PRINTF("edge to %zu on 0x%02x (sym %u)\n", v_state,
+                         trie[v].c, sym);
+            assert(sym < ds.next.size());
+            assert(ds.next[sym] == root_state);
+            ds.next[sym] = v_state;
+        }
     }
 
-    for (const auto &d : temp_dfas) {
-        to_merge.push_back(d.get());
+    return rdfa;
+}
+
+bool SmallWriteBuildImpl::determiniseLiterals() {
+    DEBUG_PRINTF("handling literals\n");
+    assert(!poisoned);
+    assert(num_literals <= cc.grey.smallWriteMaxLiterals);
+
+    if (is_empty(lit_trie) && is_empty(lit_trie_nocase)) {
+        DEBUG_PRINTF("no literals\n");
+        return true; /* nothing to do */
     }
 
-    assert(to_merge.size() > 1);
+    vector<unique_ptr<raw_dfa>> dfas;
 
-    while (to_merge.size() > LITERAL_MERGE_CHUNK_SIZE) {
-        vector<const raw_dfa *> small_merge;
-        small_merge.insert(small_merge.end(), to_merge.begin(),
-                           to_merge.begin() + LITERAL_MERGE_CHUNK_SIZE);
+    if (!is_empty(lit_trie)) {
+        dfas.push_back(buildDfa(lit_trie, false));
+        DEBUG_PRINTF("caseful literal dfa with %zu states\n",
+                     dfas.back()->states.size());
+    }
+    if (!is_empty(lit_trie_nocase)) {
+        dfas.push_back(buildDfa(lit_trie_nocase, true));
+        DEBUG_PRINTF("nocase literal dfa with %zu states\n",
+                     dfas.back()->states.size());
+    }
 
-        temp_dfas.push_back(
-            mergeAllDfas(small_merge, DFA_MERGE_MAX_STATES, &rm, cc.grey));
+    if (rdfa) {
+        dfas.push_back(move(rdfa));
+        DEBUG_PRINTF("general dfa with %zu states\n",
+                     dfas.back()->states.size());
+    }
 
-        if (!temp_dfas.back()) {
-            DEBUG_PRINTF("merge failed\n");
-            poisoned = true;
-            return false;
-        }
+    // If we only have one DFA, no merging is necessary.
+    if (dfas.size() == 1) {
+        DEBUG_PRINTF("only one dfa\n");
+        rdfa = move(dfas.front());
+        return true;
+    }
 
-        to_merge.erase(to_merge.begin(),
-                       to_merge.begin() + LITERAL_MERGE_CHUNK_SIZE);
-        to_merge.push_back(temp_dfas.back().get());
+    // Merge all DFAs.
+    vector<const raw_dfa *> to_merge;
+    for (const auto &d : dfas) {
+        to_merge.push_back(d.get());
     }
 
     auto merged = mergeAllDfas(to_merge, DFA_MERGE_MAX_STATES, &rm, cc.grey);
@@ -405,11 +640,11 @@ bool SmallWriteBuildImpl::determiniseLiterals() {
         return false;
     }
 
-    DEBUG_PRINTF("merge succeeded, built %p\n", merged.get());
+    DEBUG_PRINTF("merge succeeded, built dfa with %zu states\n",
+                 merged->states.size());
 
-    // Replace our only DFA with the merged one
+    // Replace our only DFA with the merged one.
     rdfa = move(merged);
-
     return true;
 }
 
@@ -527,7 +762,7 @@ unique_ptr<SmallWriteBuild> makeSmallWriteBuilder(size_t num_patterns,
 }
 
 bytecode_ptr<SmallWriteEngine> SmallWriteBuildImpl::build(u32 roseQuality) {
-    if (!rdfa && cand_literals.empty()) {
+    if (!rdfa && is_empty(lit_trie) && is_empty(lit_trie_nocase)) {
         DEBUG_PRINTF("no smallwrite engine\n");
         poisoned = true;
         return nullptr;
@@ -579,9 +814,10 @@ set<ReportID> SmallWriteBuildImpl::all_reports() const {
     if (rdfa) {
         insert(&reports, ::ue2::all_reports(*rdfa));
     }
-    for (const auto &cand : cand_literals) {
-        reports.insert(cand.second);
-    }
+
+    insert(&reports, ::ue2::all_reports(lit_trie));
+    insert(&reports, ::ue2::all_reports(lit_trie_nocase));
+
     return reports;
 }
 
index 92222d62b2fd7a5e41e3b463fccadb468535ef28..648b13db794d3751a7aff653c6e0ae0efc977221 100644 (file)
 #define SMWR_BUILD_H
 
 /**
- * SmallWrite Build interface. Everything you ever needed to feed literals in
- * and get a SmallWriteEngine out. This header should be everything needed by
- * the rest of UE2.
+ * \file
+ * \brief Small-write engine build interface.
+ *
+ * Everything you ever needed to feed literals in and get a SmallWriteEngine
+ * out. This header should be everything needed by the rest of UE2.
  */
 
 #include "ue2common.h"
-#include "util/alloc.h"
 #include "util/bytecode_ptr.h"
 #include "util/noncopyable.h"
 
@@ -53,14 +54,14 @@ class ExpressionInfo;
 class NGHolder;
 class ReportManager;
 
-// Abstract interface intended for callers from elsewhere in the tree, real
-// underlying implementation is SmallWriteBuildImpl in smwr_build_impl.h.
+/**
+ * Abstract interface intended for callers from elsewhere in the tree, real
+ * underlying implementation is SmallWriteBuildImpl in smwr_build_impl.h.
+ */
 class SmallWriteBuild : noncopyable {
 public:
-    // Destructor
     virtual ~SmallWriteBuild();
 
-    // Construct a runtime implementation.
     virtual bytecode_ptr<SmallWriteEngine> build(u32 roseQuality) = 0;
 
     virtual void add(const NGHolder &g, const ExpressionInfo &expr) = 0;
@@ -69,7 +70,7 @@ public:
     virtual std::set<ReportID> all_reports() const = 0;
 };
 
-// Construct a usable SmallWrite builder.
+/** \brief Construct a usable SmallWrite builder. */
 std::unique_ptr<SmallWriteBuild>
 makeSmallWriteBuilder(size_t num_patterns, const ReportManager &rm,
                       const CompileContext &cc);