]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Don't pass vistor to immutables in cloned traverse
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 12 Apr 2022 13:41:55 +0000 (09:41 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 12 Apr 2022 14:09:56 +0000 (10:09 -0400)
Saw someone using cloned_traverse to move columns around
(changing their .table) and not surprisingly having poor results.
As cloned traversal is to provide a hook for in-place mutation
of elements, it should not be given Immutable objects as these
should not be changed once they are structurally composed.

Change-Id: I43b22f52f243ef481a75d2cf5ecc73d50f110a81

lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/visitors.py
test/sql/test_external_traversal.py

index d5874334087b8e433051083805af47acb674eb49..bb51693cfeb1406c583c40889e3eb2d90fb503d5 100644 (file)
@@ -121,7 +121,17 @@ def _is_has_entity_namespace(element: Any) -> TypeGuard[_HasEntityNamespace]:
 
 
 class Immutable:
-    """mark a ClauseElement as 'immutable' when expressions are cloned."""
+    """mark a ClauseElement as 'immutable' when expressions are cloned.
+
+    "immutable" objects refers to the "mutability" of an object in the
+    context of SQL DQL and DML generation.   Such as, in DQL, one can
+    compose a SELECT or subquery of varied forms, but one cannot modify
+    the structure of a specific table or column within DQL.
+    :class:`.Immutable` is mostly intended to follow this concept, and as
+    such the primary "immutable" objects are :class:`.ColumnClause`,
+    :class:`.Column`, :class:`.TableClause`, :class:`.Table`.
+
+    """
 
     _is_immutable = True
 
index 5ff746bb3ac01822a2ebe8d916437abd7492bc44..7363f9ddc28adecd09e90329b8b84bed91886dd6 100644 (file)
@@ -446,6 +446,8 @@ class HasTraverseInternals:
 
     _traverse_internals: _TraverseInternalsType
 
+    _is_immutable: bool = False
+
     @util.preload_module("sqlalchemy.sql.traversals")
     def get_children(
         self, omit_attrs: Tuple[str, ...] = (), **kw: Any
@@ -974,12 +976,26 @@ def cloned_traverse(
     visitors: Mapping[str, _TraverseCallableType[Any]],
 ) -> Optional[ExternallyTraversible]:
     """Clone the given expression structure, allowing modifications by
-    visitors.
+    visitors for mutable objects.
 
     Traversal usage is the same as that of :func:`.visitors.traverse`.
     The visitor functions present in the ``visitors`` dictionary may also
     modify the internals of the given structure as the traversal proceeds.
 
+    The :func:`.cloned_traverse` function does **not** provide objects that are
+    part of the :class:`.Immutable` interface to the visit methods (this
+    primarily includes :class:`.ColumnClause`, :class:`.Column`,
+    :class:`.TableClause` and :class:`.Table` objects). As this traversal is
+    only intended to allow in-place mutation of objects, :class:`.Immutable`
+    objects are skipped. The :meth:`.Immutable._clone` method is still called
+    on each object to allow for objects to replace themselves with a different
+    object based on a clone of their sub-internals (e.g. a
+    :class:`.ColumnClause` that clones its subquery to return a new
+    :class:`.ColumnClause`).
+
+    .. versionchanged:: 2.0  The :func:`.cloned_traverse` function omits
+       objects that are part of the :class:`.Immutable` interface.
+
     The central API feature used by the :func:`.visitors.cloned_traverse`
     and :func:`.visitors.replacement_traverse` functions, in addition to the
     :meth:`_expression.ClauseElement.get_children`
@@ -1021,11 +1037,21 @@ def cloned_traverse(
                         cloned[id(elem)] = newelem
                         return newelem
 
+                # the _clone method for immutable normally returns "self".
+                # however, the method is still allowed to return a
+                # different object altogether; ColumnClause._clone() will
+                # based on options clone the subquery to which it is associated
+                # and return the new corresponding column.
                 cloned[id(elem)] = newelem = elem._clone(clone=clone, **kw)
                 newelem._copy_internals(clone=clone, **kw)
-                meth = visitors.get(newelem.__visit_name__, None)
-                if meth:
-                    meth(newelem)
+
+                # however, visit methods which are tasked with in-place
+                # mutation of the object should not get access to the immutable
+                # object.
+                if not elem._is_immutable:
+                    meth = visitors.get(newelem.__visit_name__, None)
+                    if meth:
+                        meth(newelem)
             return cloned[id(elem)]
 
     if obj is not None:
index ace618e501f62911d1bf0bb585713cd7f3beb654..30d25be90dcfaf0f0a54afa4b6f2d086f9d4628c 100644 (file)
@@ -1,5 +1,6 @@
 import pickle
 import re
+from unittest import mock
 
 from sqlalchemy import and_
 from sqlalchemy import bindparam
@@ -827,6 +828,34 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
             sel._generate_cache_key()[1],
         )
 
+    def test_dont_traverse_immutables(self):
+        meta = MetaData()
+
+        b = Table("b", meta, Column("id", Integer), Column("data", String))
+
+        subq = select(b.c.id).where(b.c.data == "some data").subquery()
+
+        check = mock.Mock()
+
+        class Vis(dict):
+            def get(self, key, default=None):
+                return getattr(check, key)
+
+            def __missing__(self, key):
+                return getattr(check, key)
+
+        visitors.cloned_traverse(subq, {}, Vis())
+
+        eq_(
+            check.mock_calls,
+            [
+                mock.call.bindparam(mock.ANY),
+                mock.call.binary(mock.ANY),
+                mock.call.select(mock.ANY),
+                mock.call.subquery(mock.ANY),
+            ],
+        )
+
     def test_params_on_expr_against_subquery(self):
         """test #7489"""