]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support filter_by() from columns, functions, Core SQL
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 May 2021 12:41:09 +0000 (08:41 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 May 2021 15:54:21 +0000 (11:54 -0400)
Fixed regression where :meth:`_orm.Query.filter_by` would not work if the
lead entity were a SQL function or other expression derived from the
primary entity in question, rather than a simple entity or column of that
entity. Additionally, improved the behavior of
:meth:`_sql.Select.filter_by` overall to work with column expressions even
in a non-ORM context.

Fixes: #6414
Change-Id: I316b5bf98293bec1ede08787f6181dd14be85419

doc/build/changelog/unreleased_14/6414.rst [new file with mode: 0644]
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/functions.py
test/orm/test_query.py
test/sql/test_select.py

diff --git a/doc/build/changelog/unreleased_14/6414.rst b/doc/build/changelog/unreleased_14/6414.rst
new file mode 100644 (file)
index 0000000..353f526
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, orm, regression
+    :tickets: 6414
+
+    Fixed regression where :meth:`_orm.Query.filter_by` would not work if the
+    lead entity were a SQL function or other expression derived from the
+    primary entity in question, rather than a simple entity or column of that
+    entity. Additionally, improved the behavior of
+    :meth:`_sql.Select.filter_by` overall to work with column expressions even
+    in a non-ORM context.
index d9f05e823bed56f7e583e310eb3f20284c97839c..6d65d9061b257c7a5b85c3d3dc6c8b2c6ef07529 100644 (file)
@@ -15,6 +15,7 @@ import operator
 import re
 
 from . import roles
+from . import visitors
 from .traversals import HasCacheKey  # noqa
 from .traversals import HasCopyInternals  # noqa
 from .traversals import MemoizedHasCacheKey  # noqa
@@ -1575,6 +1576,23 @@ def _bind_or_error(schemaitem, msg=None):
     return bind
 
 
+def _entity_namespace(entity):
+    """Return the nearest .entity_namespace for the given entity.
+
+    If not immediately available, does an iterate to find a sub-element
+    that has one, if any.
+
+    """
+    try:
+        return entity.entity_namespace
+    except AttributeError:
+        for elem in visitors.iterate(entity):
+            if hasattr(elem, "entity_namespace"):
+                return elem.entity_namespace
+        else:
+            raise
+
+
 def _entity_namespace_key(entity, key):
     """Return an entry from an entity_namespace.
 
@@ -1584,8 +1602,8 @@ def _entity_namespace_key(entity, key):
 
     """
 
-    ns = entity.entity_namespace
     try:
+        ns = _entity_namespace(entity)
         return getattr(ns, key)
     except AttributeError as err:
         util.raise_(
index e27b978021557bc67d3a4ed46317a2a18c9632cd..416a4e82ea24fbd9d45ad7e2850a1cc24de2415c 100644 (file)
@@ -300,6 +300,13 @@ class ClauseElement(
             f = f._is_clone_of
         return s
 
+    @property
+    def entity_namespace(self):
+        raise AttributeError(
+            "This SQL expression has no entity namespace "
+            "with which to filter from."
+        )
+
     def __getstate__(self):
         d = self.__dict__.copy()
         d.pop("_is_clone_of", None)
@@ -4664,6 +4671,13 @@ class ColumnClause(
         # expect the columns of tables and subqueries to be leaf nodes.
         return []
 
+    @property
+    def entity_namespace(self):
+        if self.table is not None:
+            return self.table.entity_namespace
+        else:
+            return super(ColumnClause, self).entity_namespace
+
     @HasMemoized.memoized_attribute
     def _from_objects(self):
         t = self.table
index 02ed55100158864a473e8d9500046db6530e928c..0aa870ce4b8511707951b3d95bf535de8817e3db 100644 (file)
@@ -15,6 +15,7 @@ from . import roles
 from . import schema
 from . import sqltypes
 from . import util as sqlutil
+from .base import _entity_namespace
 from .base import ColumnCollection
 from .base import Executable
 from .base import Generative
@@ -618,6 +619,16 @@ class FunctionElement(Executable, ColumnElement, FromClause, Generative):
         else:
             return super(FunctionElement, self).self_group(against=against)
 
+    @property
+    def entity_namespace(self):
+        """overrides FromClause.entity_namespace as functions are generally
+        column expressions and not FromClauses.
+
+        """
+        # ideally functions would not be fromclauses but we failed to make
+        # this adjustment in 1.4
+        return _entity_namespace(self.clause_expr)
+
 
 class FunctionAsBinary(BinaryExpression):
     _traverse_internals = [
index d1723bcf162baa0080faa1de5169a805ca30494e..d26f94bb885870db4f3453ae7febbbe7909a51b2 100644 (file)
@@ -3357,6 +3357,63 @@ class FilterTest(QueryTest, AssertsCompiledSQL):
             checkparams={"email_address_1": "ed@ed.com", "name_1": "ed"},
         )
 
+    def test_filter_by_against_function(self):
+        """test #6414
+
+        this is related to #6401 where the fact that Function is a
+        FromClause, an architectural mistake that we unfortunately did not
+        fix, is confusing the use of entity_namespace etc.
+
+        """
+        User = self.classes.User
+        sess = fixture_session()
+
+        q1 = sess.query(func.count(User.id)).filter_by(name="ed")
+
+        self.assert_compile(
+            q1,
+            "SELECT count(users.id) AS count_1 FROM users "
+            "WHERE users.name = :name_1",
+        )
+
+    def test_filter_by_against_cast(self):
+        """test #6414"""
+        User = self.classes.User
+        sess = fixture_session()
+
+        q1 = sess.query(cast(User.id, Integer)).filter_by(name="ed")
+
+        self.assert_compile(
+            q1,
+            "SELECT CAST(users.id AS INTEGER) AS users_id FROM users "
+            "WHERE users.name = :name_1",
+        )
+
+    def test_filter_by_against_binary(self):
+        """test #6414"""
+        User = self.classes.User
+        sess = fixture_session()
+
+        q1 = sess.query(User.id == 5).filter_by(name="ed")
+
+        self.assert_compile(
+            q1,
+            "SELECT users.id = :id_1 AS anon_1 FROM users "
+            "WHERE users.name = :name_1",
+        )
+
+    def test_filter_by_against_label(self):
+        """test #6414"""
+        User = self.classes.User
+        sess = fixture_session()
+
+        q1 = sess.query(User.id.label("foo")).filter_by(name="ed")
+
+        self.assert_compile(
+            q1,
+            "SELECT users.id AS foo FROM users " "WHERE users.name = :name_1",
+        )
+
     def test_empty_filters(self):
         User = self.classes.User
         sess = fixture_session()
index 1dfb4cd19e96147cb561bd934a42163a2bdd8123..f9f1acfa0193685fe20ffea3d17db79833264746 100644 (file)
@@ -1,6 +1,8 @@
+from sqlalchemy import cast
 from sqlalchemy import Column
 from sqlalchemy import exc
 from sqlalchemy import ForeignKey
+from sqlalchemy import func
 from sqlalchemy import Integer
 from sqlalchemy import MetaData
 from sqlalchemy import select
@@ -13,6 +15,7 @@ from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import fixtures
 
+
 table1 = table(
     "mytable",
     column("myid", Integer),
@@ -296,7 +299,59 @@ class FutureSelectTest(fixtures.TestBase, AssertsCompiledSQL):
             checkparams={"data_1": "p1", "data_2": "c1", "otherid_1": 5},
         )
 
-    def test_filter_by_no_property(self):
+    def test_filter_by_from_col(self):
+        stmt = select(table1.c.myid).filter_by(name="foo")
+        self.assert_compile(
+            stmt,
+            "SELECT mytable.myid FROM mytable WHERE mytable.name = :name_1",
+        )
+
+    def test_filter_by_from_func(self):
+        """test #6414"""
+        stmt = select(func.count(table1.c.myid)).filter_by(name="foo")
+        self.assert_compile(
+            stmt,
+            "SELECT count(mytable.myid) AS count_1 "
+            "FROM mytable WHERE mytable.name = :name_1",
+        )
+
+    def test_filter_by_from_func_not_the_first_arg(self):
+        """test #6414"""
+        stmt = select(func.bar(True, table1.c.myid)).filter_by(name="foo")
+        self.assert_compile(
+            stmt,
+            "SELECT bar(:bar_2, mytable.myid) AS bar_1 "
+            "FROM mytable WHERE mytable.name = :name_1",
+        )
+
+    def test_filter_by_from_cast(self):
+        """test #6414"""
+        stmt = select(cast(table1.c.myid, Integer)).filter_by(name="foo")
+        self.assert_compile(
+            stmt,
+            "SELECT CAST(mytable.myid AS INTEGER) AS myid "
+            "FROM mytable WHERE mytable.name = :name_1",
+        )
+
+    def test_filter_by_from_binary(self):
+        """test #6414"""
+        stmt = select(table1.c.myid == 5).filter_by(name="foo")
+        self.assert_compile(
+            stmt,
+            "SELECT mytable.myid = :myid_1 AS anon_1 "
+            "FROM mytable WHERE mytable.name = :name_1",
+        )
+
+    def test_filter_by_from_label(self):
+        """test #6414"""
+        stmt = select(table1.c.myid.label("some_id")).filter_by(name="foo")
+        self.assert_compile(
+            stmt,
+            "SELECT mytable.myid AS some_id "
+            "FROM mytable WHERE mytable.name = :name_1",
+        )
+
+    def test_filter_by_no_property_from_table(self):
         assert_raises_message(
             exc.InvalidRequestError,
             'Entity namespace for "mytable" has no property "foo"',
@@ -304,6 +359,14 @@ class FutureSelectTest(fixtures.TestBase, AssertsCompiledSQL):
             foo="bar",
         )
 
+    def test_filter_by_no_property_from_col(self):
+        assert_raises_message(
+            exc.InvalidRequestError,
+            'Entity namespace for "mytable.myid" has no property "foo"',
+            select(table1.c.myid).filter_by,
+            foo="bar",
+        )
+
     def test_select_tuple_outer(self):
         stmt = select(tuple_(table1.c.myid, table1.c.name))