]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure entity or None returned from _entity_from_pre_ent_zero()
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 17 Mar 2021 12:59:09 +0000 (08:59 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 17 Mar 2021 13:00:30 +0000 (09:00 -0400)
Fixed regression where the :meth:`_orm.Query.exists` method would fail to
create an expression if the entity list of the :class:`_orm.Query` were
an arbitrary SQL column expression.

Fixes: #6076
Change-Id: I292dd5f527b2cbc1b76ca765b4ea321ef8535709

doc/build/changelog/unreleased_14/6076.rst [new file with mode: 0644]
lib/sqlalchemy/orm/query.py
test/orm/test_query.py

diff --git a/doc/build/changelog/unreleased_14/6076.rst b/doc/build/changelog/unreleased_14/6076.rst
new file mode 100644 (file)
index 0000000..c23b5eb
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 6076
+
+    Fixed regression where the :meth:`_orm.Query.exists` method would fail to
+    create an expression if the entity list of the :class:`_orm.Query` were
+    an arbitrary SQL column expression.
+
index 4e2b4cdebbcff8e6878ef4e8958e76486c5cfa37..8ad4092e689ffb0cc1f8446e805b4de6f95ef4b1 100644 (file)
@@ -22,6 +22,7 @@ import itertools
 import operator
 import types
 
+from sqlalchemy.sql import visitors
 from . import exc as orm_exc
 from . import interfaces
 from . import loading
@@ -197,7 +198,12 @@ class Query(
         elif "bundle" in ent._annotations:
             return ent._annotations["bundle"]
         else:
-            return ent
+            # label, other SQL expression
+            for element in visitors.iterate(ent):
+                if "parententity" in element._annotations:
+                    return element._annotations["parententity"]
+            else:
+                return None
 
     def _only_full_mapper_zero(self, methname):
         if (
index 05ab160074c9d1fc6ea1aa36c3e459aaa8b02b60..800c0ad5df833fb820d1ecfc2727cb0374b75914 100644 (file)
@@ -3925,6 +3925,36 @@ class ExistsTest(QueryTest, AssertsCompiledSQL):
             ") AS anon_1",
         )
 
+    def test_exists_col_expression(self):
+        User = self.classes.User
+        sess = fixture_session()
+
+        q1 = sess.query(User.id)
+        self.assert_compile(
+            sess.query(q1.exists()),
+            "SELECT EXISTS (" "SELECT 1 FROM users" ") AS anon_1",
+        )
+
+    def test_exists_labeled_col_expression(self):
+        User = self.classes.User
+        sess = fixture_session()
+
+        q1 = sess.query(User.id.label("foo"))
+        self.assert_compile(
+            sess.query(q1.exists()),
+            "SELECT EXISTS (" "SELECT 1 FROM users" ") AS anon_1",
+        )
+
+    def test_exists_arbitrary_col_expression(self):
+        User = self.classes.User
+        sess = fixture_session()
+
+        q1 = sess.query(func.foo(User.id))
+        self.assert_compile(
+            sess.query(q1.exists()),
+            "SELECT EXISTS (" "SELECT 1 FROM users" ") AS anon_1",
+        )
+
     def test_exists_col_warning(self):
         User = self.classes.User
         Address = self.classes.Address