From d4e7ea7abcf9c95de30cfaacb54477d50ce61c1f Mon Sep 17 00:00:00 2001 From: Martin Grimm Date: Mon, 9 Oct 2023 17:10:43 +0000 Subject: [PATCH] Rewrite SplayNode to eliminate recursive calls (#1431) Recursive method calls in SplayNode can lead to a stack overflow with large (degenerate) trees, e.g. after creating a large dst acl from a sorted ip list. --- include/splay.h | 123 +++++++++++++++++++++++++++----------------- test-suite/splay.cc | 15 +++--- 2 files changed, 83 insertions(+), 55 deletions(-) diff --git a/include/splay.h b/include/splay.h index 8c0f788f42..5a113261e9 100644 --- a/include/splay.h +++ b/include/splay.h @@ -20,15 +20,13 @@ class SplayNode public: typedef V Value; typedef int SPLAYCMP(Value const &a, Value const &b); - typedef void SPLAYFREE(Value &); - typedef void SPLAYWALKEE(Value const & nodedata, void *state); - static void DefaultFree (Value &aValue) {delete aValue;} SplayNode (Value const &); Value data; mutable SplayNode *left; mutable SplayNode *right; - void destroy(SPLAYFREE * = DefaultFree); + mutable SplayNode *visitThreadUp; + SplayNode const * start() const; SplayNode const * finish() const; @@ -39,13 +37,8 @@ public: /// look in the splay for data for where compare(data,candidate) == true. /// return NULL if not found, a pointer to the sought data if found. template SplayNode * splay(const FindValue &data, int( * compare)(FindValue const &a, Value const &b)) const; - - /// recursively visit left nodes, this node, and then right nodes - template void visit(Visitor &v) const; }; -typedef SplayNode splayNode; - template class SplayConstIterator; @@ -61,6 +54,9 @@ public: typedef void SPLAYFREE(Value &); typedef SplayIterator iterator; typedef const SplayConstIterator const_iterator; + + static void DefaultFree(Value &v) { delete v; } + Splay():head(nullptr), elements (0) {} template Value const *find (FindValue const &, int( * compare)(FindValue const &a, Value const &b)) const; @@ -69,7 +65,7 @@ public: void remove(Value const &, SPLAYCMP *compare); - void destroy(SPLAYFREE * = SplayNode::DefaultFree); + void destroy(SPLAYFREE * = DefaultFree); SplayNode const * start() const; @@ -83,10 +79,13 @@ public: const_iterator end() const; - /// recursively visit all nodes, in left-to-right order - template void visit(Visitor &v) const; + /// left-to-right visit of all stored Values + template void visit(ValueVisitor &) const; private: + /// left-to-right walk through all nodes + template void visitEach(NodeVisitor &) const; + mutable SplayNode * head; size_t elements; }; @@ -94,41 +93,26 @@ private: SQUIDCEXTERN int splayLastResult; template -SplayNode::SplayNode (Value const &someData) : data(someData), left(nullptr), right (nullptr) {} +SplayNode::SplayNode(const Value &someData): data(someData), left(nullptr), right(nullptr), visitThreadUp(nullptr) {} template SplayNode const * SplayNode::start() const { - if (left) - return left->start(); - - return this; + auto cur = this; + while (cur->left) + cur = cur->left; + return cur; } template SplayNode const * SplayNode::finish() const { - if (right) - return right->finish(); - - return this; -} - -template -void -SplayNode::destroy(SPLAYFREE * free_func) -{ - if (left) - left->destroy(free_func); - - if (right) - right->destroy(free_func); - - free_func(data); - - delete this; + auto cur = this; + while (cur->right) + cur = cur->right; + return cur; } template @@ -248,13 +232,60 @@ SplayNode::splay(FindValue const &dataToFind, int( * compare)(FindValue const template template void -SplayNode::visit(Visitor &visitor) const +Splay::visitEach(Visitor &visitor) const { - if (left) - left->visit(visitor); - visitor(data); - if (right) - right->visit(visitor); + // In-order walk through tree using modified Morris Traversal: To avoid a + // leftover thread up (and, therefore, a fatal loop in the tree) due to a + // visitor exception, we use an extra pointer visitThreadUp instead of + // manipulating the right child link and interfering with other methods + // that use that link. + // This also helps to distinguish between up and down movements, eliminating + // the need to descent into left subtree a second time after traversing the + // thread to find the loop and remove the temporary thread. + + if (!head) + return; + + auto cur = head; + auto movedUp = false; + cur->visitThreadUp = nullptr; + + while (cur) { + if (!cur->left || movedUp) { + // no (unvisited) left subtree, so handle current node ... + const auto old = cur; + if (cur->right) { + // ... and descent into right subtree + cur = cur->right; + movedUp = false; + } + else if (cur->visitThreadUp) { + // ... or back up the thread + cur = cur->visitThreadUp; + movedUp = true; + } else { + // end of traversal + cur = nullptr; + } + visitor(old); + // old may be destroyed here + } else { + // first descent into left subtree + + // find right-most child in left tree + auto rmc = cur->left; + while (rmc->right) { + rmc->visitThreadUp = nullptr; // cleanup old threads on the way + rmc = rmc->right; + } + // create thread up back to cur + rmc->visitThreadUp = cur; + + // finally descent into left subtree + cur = cur->left; + movedUp = false; + } + } } template @@ -262,8 +293,8 @@ template void Splay::visit(Visitor &visitor) const { - if (head) - head->visit(visitor); + const auto internalVisitor = [&visitor](const SplayNode *node) { visitor(node->data); }; + visitEach(internalVisitor); } template @@ -333,8 +364,8 @@ template void Splay:: destroy(SPLAYFREE *free_func) { - if (head) - head->destroy(free_func); + const auto destroyer = [free_func](SplayNode *node) { free_func(node->data); delete node; }; + visitEach(destroyer); head = nullptr; diff --git a/test-suite/splay.cc b/test-suite/splay.cc index e5c966b3c6..4299479d74 100644 --- a/test-suite/splay.cc +++ b/test-suite/splay.cc @@ -136,15 +136,12 @@ main(int, char *[]) { /* test void * splay containers */ - splayNode *top = nullptr; + const auto top = new Splay(); for (int i = 0; i < 100; ++i) { intnode *I = (intnode *)xcalloc(sizeof(intnode), 1); I->i = nextRandom(); - if (top) - top = top->insert(I, compareintvoid); - else - top = new splayNode(static_cast(new intnode(101))); + top->insert(I, compareintvoid); } SplayCheck::BeginWalk(); @@ -158,13 +155,13 @@ main(int, char *[]) /* test typesafe splay containers */ { /* intnode* */ - SplayNode *safeTop = new SplayNode(new intnode(101)); + const auto safeTop = new Splay(); for ( int i = 0; i < 100; ++i) { intnode *I; I = new intnode; I->i = nextRandom(); - safeTop = safeTop->insert(I, compareint); + safeTop->insert(I, compareint); } SplayCheck::BeginWalk(); @@ -174,12 +171,12 @@ main(int, char *[]) } { /* intnode */ - SplayNode *safeTop = new SplayNode(101); + const auto safeTop = new Splay(); for (int i = 0; i < 100; ++i) { intnode I; I.i = nextRandom(); - safeTop = safeTop->insert(I, compareintref); + safeTop->insert(I, compareintref); } SplayCheck::BeginWalk(); -- 2.47.2