]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
ensure correlate_except is checked for empty tuple
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Jan 2022 22:28:52 +0000 (17:28 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Jan 2022 22:39:33 +0000 (17:39 -0500)
Fixed issue where :meth:`_sql.Select.correlate_except` method, when passed
either the ``None`` value or no arguments, would not correlate any elements
when used in an ORM context (that is, passing ORM entities as FROM
clauses), rather than causing all FROM elements to be considered as
"correlated" in the same way which occurs when using Core-only constructs.

Fixes: #7514
Change-Id: Ic4a5252c8f3c1140aba6c308264948f3a91f33f5

doc/build/changelog/unreleased_14/7514.rst [new file with mode: 0644]
lib/sqlalchemy/orm/context.py
test/orm/test_core_compilation.py
test/sql/test_compiler.py

diff --git a/doc/build/changelog/unreleased_14/7514.rst b/doc/build/changelog/unreleased_14/7514.rst
new file mode 100644 (file)
index 0000000..bf6fd47
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 7514
+
+    Fixed issue where :meth:`_sql.Select.correlate_except` method, when passed
+    either the ``None`` value or no arguments, would not correlate any elements
+    when used in an ORM context (that is, passing ORM entities as FROM
+    clauses), rather than causing all FROM elements to be considered as
+    "correlated" in the same way which occurs when using Core-only constructs.
index 7c2d7295486794675a9546986ed6df00710c65e0..0c1d16d0e2747f158609b893a023d89668f3eb38 100644 (file)
@@ -774,7 +774,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
                     for s in query._correlate
                 )
             )
-        elif query._correlate_except:
+        elif query._correlate_except is not None:
             self.correlate_except = tuple(
                 util.flatten_iterator(
                     sql_util.surface_selectables(s) if s is not None else None
@@ -1192,7 +1192,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
         if correlate:
             statement.correlate.non_generative(statement, *correlate)
 
-        if correlate_except:
+        if correlate_except is not None:
             statement.correlate_except.non_generative(
                 statement, *correlate_except
             )
index 000a96a422a5d0fa24ca1120df219ead5865e0da..28f42797e41532dee649cf9414edf585763c1ad5 100644 (file)
@@ -1,8 +1,10 @@
 from sqlalchemy import bindparam
+from sqlalchemy import Column
 from sqlalchemy import exc
 from sqlalchemy import func
 from sqlalchemy import insert
 from sqlalchemy import inspect
+from sqlalchemy import Integer
 from sqlalchemy import literal_column
 from sqlalchemy import null
 from sqlalchemy import or_
@@ -31,12 +33,15 @@ from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing.assertions import expect_raises_message
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.util import resolve_lambda
+from sqlalchemy.util.langhelpers import hybridproperty
 from .inheritance import _poly_fixtures
 from .test_query import QueryTest
+from ..sql.test_compiler import CorrelateTest as _CoreCorrelateTest
 
 # TODO:
 # composites / unions, etc.
@@ -2320,3 +2325,29 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL):
         )
         self.assert_compile(stmt1, expected)
         self.assert_compile(stmt2, expected)
+
+
+class CorrelateTest(fixtures.DeclarativeMappedTest, _CoreCorrelateTest):
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class T1(Base):
+            __tablename__ = "t1"
+            a = Column(Integer, primary_key=True)
+
+            @hybridproperty
+            def c(self):
+                return self
+
+        class T2(Base):
+            __tablename__ = "t2"
+            a = Column(Integer, primary_key=True)
+
+            @hybridproperty
+            def c(self):
+                return self
+
+    def _fixture(self):
+        t1, t2 = self.classes("T1", "T2")
+        return t1, t2, select(t1).where(t1.c.a == t2.c.a)
index c0fa5748430ce8d2f2aa7404409350f9c90e8242..5ea1110c6f8b3ca97ad48fc8a3a7d55b26b2c9ae 100644 (file)
@@ -5920,6 +5920,14 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL):
             )
         )
 
+    def test_correlate_except_empty(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_where_all_correlated(
+            select(t1, t2).where(
+                t2.c.a == s1.correlate_except().scalar_subquery()
+            )
+        )
+
     def test_correlate_except_having(self):
         t1, t2, s1 = self._fixture()
         self._assert_having_correlated(