git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fix race conditions in lambda statements
author: Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 7 Jun 2022 19:00:20 +0000 (15:00 -0400)
committer: Mike Bayer <mike_mp@zzzcomputing.com>
Wed, 8 Jun 2022 20:59:12 +0000 (16:59 -0400)
Fixed multiple observed race conditions related to :func:`.lambda_stmt`,
including an initial "dogpile" issue when a new Python code object is
initially analyzed among multiple simultaneous threads which created both a
performance issue as well as some internal corruption of state.
Additionally repaired observed race condition which could occur when
"cloning" an expression construct that is also in the process of being
compiled or otherwise accessed in a different thread due to memoized
attributes altering the ``__dict__`` while iterated, for Python versions
prior to 3.10; in particular the lambda SQL construct is sensitive to this
as it holds onto a single statement object persistently. The iteration has
been refined to use ``dict.copy()`` with or without an additional iteration
instead.

Fixes: #8098
Change-Id: I4e0b627bfa187f1780dc68ec81b94db1c78f846a
(cherry picked from commit 117878f7870377f143917a22160320a891eb0211)

doc/build/changelog/unreleased_14/8098.rst [new file with mode: 0644]
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/lambdas.py

diff --git a/doc/build/changelog/unreleased_14/8098.rst b/doc/build/changelog/unreleased_14/8098.rst
new file mode 100644
index 0000000..0267817
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/8098.rst
@@ -0,0 +1,16 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 8098
+
+    Fixed multiple observed race conditions related to :func:`.lambda_stmt`,
+    including an initial "dogpile" issue when a new Python code object is
+    initially analyzed among multiple simultaneous threads which created both a
+    performance issue as well as some internal corruption of state.
+    Additionally repaired observed race condition which could occur when
+    "cloning" an expression construct that is also in the process of being
+    compiled or otherwise accessed in a different thread due to memoized
+    attributes altering the ``__dict__`` while iterated, for Python versions
+    prior to 3.10; in particular the lambda SQL construct is sensitive to this
+    as it holds onto a single statement object persistently. The iteration has
+    been refined to use ``dict.copy()`` with or without an additional iteration
+    instead.
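diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 52339e35a73ef5aa3f06a6bac066d27c4a4c5d98..ec685d1fac144d9b0c99dfc58229d330434b6e89 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py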
@@ -559,8 +559,9 @@ class Generative(HasMemoized):
         cls = self.__class__
         s = cls.__new__(cls)
         if skip:
+            # ensure this iteration remains atomic
             s.__dict__ = {
-                k: v for k, v in self.__dict__.items() if k not in skip
+                k: v for k, v in self.__dict__.copy().items() if k not in skip
             }
         else:
             s.__dict__ = self.__dict__.copy()
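diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 42ec3e0e7d2ec28913f6d46bd5ee19f973e95523..a1891f19cabe9a096cc63fe463bd6c94b73fdb12 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py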
@@ -241,7 +241,14 @@ class ClauseElement(
         """
         skip = self._memoized_keys
         c = self.__class__.__new__(self.__class__)
-        c.__dict__ = {k: v for k, v in self.__dict__.items() if k not in skip}
+
+        if skip:
+            # ensure this iteration remains atomic
+            c.__dict__ = {
+                k: v for k, v in self.__dict__.copy().items() if k not in skip
+            }
+        else:
+            c.__dict__ = self.__dict__.copy()
 
         # this is a marker that helps to "equate" clauses to each other
         # when a Select returns its list of FROM clauses.  the cloning
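diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py
index 5f91559987d3f2553114095c9833d0ab75a046d0..584efe4c68896d0f5b5af40d91e056b3756ea03a 100644
--- a/lib/sqlalchemy/sql/lambdas.py
+++ b/lib/sqlalchemy/sql/lambdas.py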
@@ -9,6 +9,7 @@ import inspect
 import itertools
 import operator
 import sys
+import threading
 import types
 import weakref
 
@@ -218,11 +219,17 @@ class LambdaElement(elements.ClauseElement):
 
         if rec is None:
             if cache_key is not traversals.NO_CACHE:
-                rec = AnalyzedFunction(
-                    tracker, self, apply_propagate_attrs, fn
-                )
-                rec.closure_bindparams = bindparams
-                lambda_cache[tracker_key + cache_key] = rec
+
+                with AnalyzedCode._generation_mutex:
+                    key = tracker_key + cache_key
+                    if key not in lambda_cache:
+                        rec = AnalyzedFunction(
+                            tracker, self, apply_propagate_attrs, fn
+                        )
+                        rec.closure_bindparams = bindparams
+                        lambda_cache[key] = rec
+                    else:
+                        rec = lambda_cache[key]
             else:
                 rec = NonAnalyzedFunction(self._invoke_user_fn(fn))
 
@@ -607,6 +614,8 @@ class AnalyzedCode(object):
     )
     _fns = weakref.WeakKeyDictionary()
 
+    _generation_mutex = threading.RLock()
+
     @classmethod
     def get(cls, fn, lambda_element, lambda_kw, **kw):
         try:
@@ -614,10 +623,16 @@ class AnalyzedCode(object):
             return cls._fns[fn.__code__]
         except KeyError:
             pass
-        cls._fns[fn.__code__] = analyzed = AnalyzedCode(
-            fn, lambda_element, lambda_kw, **kw
-        )
-        return analyzed
+
+        with cls._generation_mutex:
+            # check for other thread already created object
+            if fn.__code__ in cls._fns:
+                return cls._fns[fn.__code__]
+
+            cls._fns[fn.__code__] = analyzed = AnalyzedCode(
+                fn, lambda_element, lambda_kw, **kw
+            )
+            return analyzed
 
     def __init__(self, fn, lambda_element, opts):
         if inspect.ismethod(fn):