]> git.ipfire.org Git - thirdparty/squid.git/commitdiff
Rewrite SplayNode to eliminate recursive calls (#1431)
authorMartin Grimm <magri@web.de>
Mon, 9 Oct 2023 17:10:43 +0000 (17:10 +0000)
committerSquid Anubis <squid-anubis@squid-cache.org>
Mon, 9 Oct 2023 17:10:54 +0000 (17:10 +0000)
Recursive method calls in SplayNode can lead to a stack overflow with
large (degenerate) trees, e.g. after creating a large dst acl from a
sorted ip list.

include/splay.h
test-suite/splay.cc

index 8c0f788f423c35efa30646c59f1e0025344ce4b2..5a113261e919889eac105ffc530be5d02c39f0da 100644 (file)
@@ -20,15 +20,13 @@ class SplayNode
 public:
     typedef V Value;
     typedef int SPLAYCMP(Value const &a, Value const &b);
-    typedef void SPLAYFREE(Value &);
-    typedef void SPLAYWALKEE(Value const & nodedata, void *state);
-    static void DefaultFree (Value &aValue) {delete aValue;}
 
     SplayNode<V> (Value const &);
     Value data;
     mutable SplayNode<V> *left;
     mutable SplayNode<V> *right;
-    void destroy(SPLAYFREE * = DefaultFree);
+    mutable SplayNode<V> *visitThreadUp;
+
     SplayNode<V> const * start() const;
     SplayNode<V> const * finish() const;
 
@@ -39,13 +37,8 @@ public:
     /// look in the splay for data for where compare(data,candidate) == true.
     /// return NULL if not found, a pointer to the sought data if found.
     template <class FindValue> SplayNode<V> * splay(const FindValue &data, int( * compare)(FindValue const &a, Value const &b)) const;
-
-    /// recursively visit left nodes, this node, and then right nodes
-    template <class Visitor> void visit(Visitor &v) const;
 };
 
-typedef SplayNode<void *> splayNode;
-
 template <class V>
 class SplayConstIterator;
 
@@ -61,6 +54,9 @@ public:
     typedef void SPLAYFREE(Value &);
     typedef SplayIterator<V> iterator;
     typedef const SplayConstIterator<V> const_iterator;
+
+    static void DefaultFree(Value &v) { delete v; }
+
     Splay():head(nullptr), elements (0) {}
 
     template <class FindValue> Value const *find (FindValue const &, int( * compare)(FindValue const &a, Value const &b)) const;
@@ -69,7 +65,7 @@ public:
 
     void remove(Value const &, SPLAYCMP *compare);
 
-    void destroy(SPLAYFREE * = SplayNode<V>::DefaultFree);
+    void destroy(SPLAYFREE * = DefaultFree);
 
     SplayNode<V> const * start() const;
 
@@ -83,10 +79,13 @@ public:
 
     const_iterator end() const;
 
-    /// recursively visit all nodes, in left-to-right order
-    template <class Visitor> void visit(Visitor &v) const;
+    /// left-to-right visit of all stored Values
+    template <typename ValueVisitor> void visit(ValueVisitor &) const;
 
 private:
+    /// left-to-right walk through all nodes
+    template <typename NodeVisitor> void visitEach(NodeVisitor &) const;
+
     mutable SplayNode<V> * head;
     size_t elements;
 };
@@ -94,41 +93,26 @@ private:
 SQUIDCEXTERN int splayLastResult;
 
 template<class V>
-SplayNode<V>::SplayNode (Value const &someData) : data(someData), left(nullptr), right (nullptr) {}
+SplayNode<V>::SplayNode(const Value &someData): data(someData), left(nullptr), right(nullptr), visitThreadUp(nullptr) {}
 
 template<class V>
 SplayNode<V> const *
 SplayNode<V>::start() const
 {
-    if (left)
-        return left->start();
-
-    return this;
+    auto cur = this;
+    while (cur->left)
+        cur = cur->left;
+    return cur;
 }
 
 template<class V>
 SplayNode<V> const *
 SplayNode<V>::finish() const
 {
-    if (right)
-        return right->finish();
-
-    return this;
-}
-
-template<class V>
-void
-SplayNode<V>::destroy(SPLAYFREE * free_func)
-{
-    if (left)
-        left->destroy(free_func);
-
-    if (right)
-        right->destroy(free_func);
-
-    free_func(data);
-
-    delete this;
+    auto cur = this;
+    while (cur->right)
+        cur = cur->right;
+    return cur;
 }
 
 template<class V>
@@ -248,13 +232,60 @@ SplayNode<V>::splay(FindValue const &dataToFind, int( * compare)(FindValue const
 template <class V>
 template <class Visitor>
 void
-SplayNode<V>::visit(Visitor &visitor) const
+Splay<V>::visitEach(Visitor &visitor) const
 {
-    if (left)
-        left->visit(visitor);
-    visitor(data);
-    if (right)
-        right->visit(visitor);
+    // In-order walk through tree using modified Morris Traversal: To avoid a
+    // leftover thread up (and, therefore, a fatal loop in the tree) due to a
+    // visitor exception, we use an extra pointer visitThreadUp instead of
+    // manipulating the right child link and interfering with other methods
+    // that use that link.
+    // This also helps to distinguish between up and down movements, eliminating
+    // the need to descent into left subtree a second time after traversing the
+    // thread to find the loop and remove the temporary thread.
+
+    if (!head)
+        return;
+
+    auto cur = head;
+    auto movedUp = false;
+    cur->visitThreadUp = nullptr;
+
+    while (cur) {
+        if (!cur->left || movedUp) {
+            // no (unvisited) left subtree, so handle current node ...
+            const auto old = cur;
+            if (cur->right) {
+                // ... and descent into right subtree
+                cur = cur->right;
+                movedUp = false;
+            }
+            else if (cur->visitThreadUp) {
+                // ... or back up the thread
+                cur = cur->visitThreadUp;
+                movedUp = true;
+            } else {
+                // end of traversal
+                cur = nullptr;
+            }
+            visitor(old);
+            // old may be destroyed here
+        } else {
+            // first descent into left subtree
+
+            // find right-most child in left tree
+            auto rmc = cur->left;
+            while (rmc->right) {
+                rmc->visitThreadUp = nullptr; // cleanup old threads on the way
+                rmc = rmc->right;
+            }
+            // create thread up back to cur
+            rmc->visitThreadUp = cur;
+
+            // finally descent into left subtree
+            cur = cur->left;
+            movedUp = false;
+        }
+    }
 }
 
 template <class V>
@@ -262,8 +293,8 @@ template <class Visitor>
 void
 Splay<V>::visit(Visitor &visitor) const
 {
-    if (head)
-        head->visit(visitor);
+    const auto internalVisitor = [&visitor](const SplayNode<V> *node) { visitor(node->data); };
+    visitEach(internalVisitor);
 }
 
 template <class V>
@@ -333,8 +364,8 @@ template <class V>
 void
 Splay<V>:: destroy(SPLAYFREE *free_func)
 {
-    if (head)
-        head->destroy(free_func);
+    const auto destroyer = [free_func](SplayNode<V> *node) { free_func(node->data); delete node; };
+    visitEach(destroyer);
 
     head = nullptr;
 
index e5c966b3c6405e3b0e22d3f6a51a69b5c119069d..4299479d74bd9b0e05274d3eb67534bb8b8269ca 100644 (file)
@@ -136,15 +136,12 @@ main(int, char *[])
 
     {
         /* test void * splay containers */
-        splayNode *top = nullptr;
+        const auto top = new Splay<void *>();
 
         for (int i = 0; i < 100; ++i) {
             intnode *I = (intnode *)xcalloc(sizeof(intnode), 1);
             I->i = nextRandom();
-            if (top)
-                top = top->insert(I, compareintvoid);
-            else
-                top = new splayNode(static_cast<void*>(new intnode(101)));
+            top->insert(I, compareintvoid);
         }
 
         SplayCheck::BeginWalk();
@@ -158,13 +155,13 @@ main(int, char *[])
     /* test typesafe splay containers */
     {
         /* intnode* */
-        SplayNode<intnode *> *safeTop = new SplayNode<intnode *>(new intnode(101));
+        const auto safeTop = new Splay<intnode *>();
 
         for ( int i = 0; i < 100; ++i) {
             intnode *I;
             I = new intnode;
             I->i = nextRandom();
-            safeTop = safeTop->insert(I, compareint);
+            safeTop->insert(I, compareint);
         }
 
         SplayCheck::BeginWalk();
@@ -174,12 +171,12 @@ main(int, char *[])
     }
     {
         /* intnode */
-        SplayNode<intnode> *safeTop = new SplayNode<intnode>(101);
+        const auto safeTop = new Splay<intnode>();
 
         for (int i = 0; i < 100; ++i) {
             intnode I;
             I.i = nextRandom();
-            safeTop = safeTop->insert(I, compareintref);
+            safeTop->insert(I, compareintref);
         }
 
         SplayCheck::BeginWalk();