]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
de-clone FROM objects placed into from_linter.froms
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 20 Jul 2023 16:36:35 +0000 (12:36 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 24 Jul 2023 15:24:07 +0000 (11:24 -0400)
Fixed issue where internal cloning used by the ORM for expressions like
:meth:`_orm.relationship.Comparator.any` to produce correlated EXISTS
constructs would interfere with the "cartesian product warning" feature of
the SQL compiler, leading the SQL compiler to warn when all elements of the
statement were correctly joined.

Fixes: #10124
Change-Id: I31c1ba538e2b943278e8cc0b7fddc107968a0826

doc/build/changelog/unreleased_20/10124.rst [new file with mode: 0644]
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/elements.py
test/sql/test_from_linter.py

diff --git a/doc/build/changelog/unreleased_20/10124.rst b/doc/build/changelog/unreleased_20/10124.rst
new file mode 100644 (file)
index 0000000..65b5584
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 10124
+
+    Fixed issue where internal cloning used by the ORM for expressions like
+    :meth:`_orm.relationship.Comparator.any` to produce correlated EXISTS
+    constructs would interfere with the "cartesian product warning" feature of
+    the SQL compiler, leading the SQL compiler to warn when all elements of the
+    statement were correctly joined.
index 8ff11cc78108c354290897233d1184c28e770365..913ab4300d0d04b128658badf38657f9d893040d 100644 (file)
@@ -333,6 +333,15 @@ def _expand_cloned(
     return itertools.chain(*[x._cloned_set for x in elements])
 
 
+def _de_clone(
+    elements: Iterable[_CLE],
+) -> Iterable[_CLE]:
+    for x in elements:
+        while x._is_clone_of is not None:
+            x = x._is_clone_of
+        yield x
+
+
 def _cloned_intersection(a: Iterable[_CLE], b: Iterable[_CLE]) -> Set[_CLE]:
     """return the intersection of sets a and b, counting
     any overlap between 'cloned' predecessors.
index df93161461a52332da87fe7727dbc150e7d536f1..2304bde7542482358794990a26fd98aed165ede9 100644 (file)
@@ -69,6 +69,7 @@ from . import sqltypes
 from . import util as sql_util
 from ._typing import is_column_element
 from ._typing import is_dml
+from .base import _de_clone
 from .base import _from_objects
 from .base import _NONE_NAME
 from .base import _SentinelDefaultCharacterization
@@ -3289,14 +3290,19 @@ class SQLCompiler(Compiled):
                 enclosing_lateral = kw["enclosing_lateral"]
                 lateral_from_linter.edges.update(
                     itertools.product(
-                        binary.left._from_objects + [enclosing_lateral],
-                        binary.right._from_objects + [enclosing_lateral],
+                        _de_clone(
+                            binary.left._from_objects + [enclosing_lateral]
+                        ),
+                        _de_clone(
+                            binary.right._from_objects + [enclosing_lateral]
+                        ),
                     )
                 )
             else:
                 from_linter.edges.update(
                     itertools.product(
-                        binary.left._from_objects, binary.right._from_objects
+                        _de_clone(binary.left._from_objects),
+                        _de_clone(binary.right._from_objects),
                     )
                 )
 
@@ -4080,7 +4086,7 @@ class SQLCompiler(Compiled):
 
         if asfrom:
             if from_linter:
-                from_linter.froms[cte] = cte_name
+                from_linter.froms[cte._de_clone()] = cte_name
 
             if not is_new_cte and embedded_in_current_named_cte:
                 return self.preparer.format_alias(cte, cte_name)  # type: ignore[no-any-return]  # noqa: E501
@@ -4164,7 +4170,7 @@ class SQLCompiler(Compiled):
             return self.preparer.format_alias(alias, alias_name)
         elif asfrom:
             if from_linter:
-                from_linter.froms[alias] = alias_name
+                from_linter.froms[alias._de_clone()] = alias_name
 
             inner = alias.element._compiler_dispatch(
                 self, asfrom=True, lateral=lateral, **kwargs
@@ -4257,7 +4263,7 @@ class SQLCompiler(Compiled):
 
         if asfrom:
             if from_linter:
-                from_linter.froms[element] = (
+                from_linter.froms[element._de_clone()] = (
                     name if name is not None else "(unnamed VALUES element)"
                 )
 
@@ -5160,7 +5166,8 @@ class SQLCompiler(Compiled):
         if from_linter:
             from_linter.edges.update(
                 itertools.product(
-                    join.left._from_objects, join.right._from_objects
+                    _de_clone(join.left._from_objects),
+                    _de_clone(join.right._from_objects),
                 )
             )
 
index 8381ee7601831cb328bab049b28c63a4225cff54..1c86c1669e85bd220d0bb15459f2cc7ed453eea6 100644 (file)
@@ -323,7 +323,7 @@ class ClauseElement(
     def description(self) -> Optional[str]:
         return None
 
-    _is_clone_of: Optional[ClauseElement] = None
+    _is_clone_of: Optional[Self] = None
 
     is_clause_element = True
     is_selectable = False
@@ -458,6 +458,11 @@ class ClauseElement(
             f = f._is_clone_of
         return s
 
+    def _de_clone(self):
+        while self._is_clone_of is not None:
+            self = self._is_clone_of
+        return self
+
     @property
     def entity_namespace(self):
         raise AttributeError(
index 9a471d57126477f50f16909c1d970ab89b908054..139499d941e76006bda0721e25db1702d2527bc1 100644 (file)
@@ -38,6 +38,35 @@ class TestFindUnmatchingFroms(fixtures.TablesTest):
         self.c = self.tables.table_c
         self.d = self.tables.table_d
 
+    @testing.variation(
+        "what_to_clone", ["nothing", "fromclause", "whereclause", "both"]
+    )
+    def test_cloned_aliases(self, what_to_clone):
+        a1 = self.a.alias()
+        b1 = self.b.alias()
+        c = self.c
+
+        j1 = a1.join(b1, a1.c.col_a == b1.c.col_b)
+        j1_from = j1
+        b1_where = b1
+
+        if what_to_clone.fromclause or what_to_clone.both:
+            a1c = a1._clone()
+            b1c = b1._clone()
+            j1_from = a1c.join(b1c, a1c.c.col_a == b1c.c.col_b)
+
+        if what_to_clone.whereclause or what_to_clone.both:
+            b1_where = b1_where._clone()
+
+        query = (
+            select(c)
+            .select_from(c, j1_from)
+            .where(b1_where.c.col_b == c.c.col_c)
+        )
+        for start in None, c:
+            froms, start = find_unmatching_froms(query, start)
+            assert not froms
+
     def test_everything_is_connected(self):
         query = (
             select(self.a)