]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add default expression to query_expression()
authorHaoyu Sun <raptorsun@gmail.com>
Fri, 29 May 2020 18:31:07 +0000 (14:31 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 1 Jun 2020 19:15:03 +0000 (15:15 -0400)
Added a new parameter :paramref:`_orm.query_expression.default_expr` to the
:func:`_orm.query_expression` construct, which will be appled to queries
automatically if the :func:`_orm.with_expression` option is not used. Pull
request courtesy Haoyu Sun.

Fixes: #5198
Closes: #5354
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/5354
Pull-request-sha: 57dd4922a3ae4e5fe56dcc541d85ce42256b38b9

Change-Id: I3400f2c00b58bf161f31c74c579feb9ac0f03356

doc/build/changelog/unreleased_13/5198.rst [new file with mode: 0644]
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/strategies.py
test/orm/test_deferred.py

diff --git a/doc/build/changelog/unreleased_13/5198.rst b/doc/build/changelog/unreleased_13/5198.rst
new file mode 100644 (file)
index 0000000..b19da1e
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: usecase, orm
+    :tickets: 5198
+
+    Added a new parameter :paramref:`_orm.query_expression.default_expr` to the
+    :func:`_orm.query_expression` construct, which will be appled to queries
+    automatically if the :func:`_orm.with_expression` option is not used. Pull
+    request courtesy Haoyu Sun.
\ No newline at end of file
index 110c27811d26a584e08a98dc6d58bec80b6f8dda..fabb095a2c32c88b57009c00d22e2166e2b978bf 100644 (file)
@@ -180,9 +180,22 @@ def deferred(*columns, **kw):
     return ColumnProperty(deferred=True, *columns, **kw)
 
 
-def query_expression():
+def query_expression(default_expr=_sql.null()):
     """Indicate an attribute that populates from a query-time SQL expression.
 
+    :param default_expr: Optional SQL expression object that will be used in
+        all cases if not assigned later with :func:`_orm.with_expression`.
+        E.g.::
+
+            from sqlalchemy.sql import literal
+
+            class C(Base):
+                #...
+                my_expr = query_expression(literal(1))
+
+        .. versionadded:: 1.3.18
+
+
     .. versionadded:: 1.2
 
     .. seealso::
@@ -190,7 +203,7 @@ def query_expression():
         :ref:`mapper_querytime_expression`
 
     """
-    prop = ColumnProperty(_sql.null())
+    prop = ColumnProperty(default_expr)
     prop.strategy_key = (("query_expression", True),)
     return prop
 
index 626018997a61e4da2fa98d6c3f470987dcf63067..47cb9ded4b109be91a709f7efbe490e14fb860e2 100644 (file)
@@ -245,6 +245,11 @@ class ExpressionColumnLoader(ColumnLoader):
     def __init__(self, parent, strategy_key):
         super(ExpressionColumnLoader, self).__init__(parent, strategy_key)
 
+        null = sql.null()
+        self._have_default_expression = any(
+            not c.compare(null) for c in self.parent_property.columns
+        )
+
     def setup_query(
         self,
         compile_state,
@@ -256,19 +261,24 @@ class ExpressionColumnLoader(ColumnLoader):
         memoized_populators,
         **kwargs
     ):
-
+        columns = None
         if loadopt and "expression" in loadopt.local_opts:
             columns = [loadopt.local_opts["expression"]]
+        elif self._have_default_expression:
+            columns = self.parent_property.columns
 
-            for c in columns:
-                if adapter:
-                    c = adapter.columns[c]
-                column_collection.append(c)
+        if columns is None:
+            return
 
-            fetch = columns[0]
+        for c in columns:
             if adapter:
-                fetch = adapter.columns[fetch]
-            memoized_populators[self.parent_property] = fetch
+                c = adapter.columns[c]
+            column_collection.append(c)
+
+        fetch = columns[0]
+        if adapter:
+            fetch = adapter.columns[fetch]
+        memoized_populators[self.parent_property] = fetch
 
     def create_row_processor(
         self, context, path, loadopt, mapper, result, adapter, populators
index e0eba3d1117e85c12b4d1da867d619334e60a6e3..a0388ded7a3eff8fdbe41dabec639e29fe30ec59 100644 (file)
@@ -24,6 +24,7 @@ from sqlalchemy.orm import undefer
 from sqlalchemy.orm import undefer_group
 from sqlalchemy.orm import with_expression
 from sqlalchemy.orm import with_polymorphic
+from sqlalchemy.sql import literal
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
@@ -1725,9 +1726,16 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest):
 
             b_expr = query_expression()
 
+        class C(fixtures.ComparableEntity, Base):
+            __tablename__ = "c"
+            id = Column(Integer, primary_key=True)
+            x = Column(Integer)
+
+            c_expr = query_expression(literal(1))
+
     @classmethod
     def insert_data(cls, connection):
-        A, B = cls.classes("A", "B")
+        A, B, C = cls.classes("A", "B", "C")
         s = Session(connection)
 
         s.add_all(
@@ -1736,6 +1744,8 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest):
                 A(id=2, x=2, y=3),
                 A(id=3, x=5, y=10, bs=[B(id=3, p=5, q=0)]),
                 A(id=4, x=2, y=10, bs=[B(id=4, p=19, q=8), B(id=5, p=5, q=5)]),
+                C(id=1, x=1),
+                C(id=2, x=2),
             ]
         )
 
@@ -1754,6 +1764,25 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest):
 
         eq_(a1.all(), [A(my_expr=5), A(my_expr=15), A(my_expr=12)])
 
+    def test_expr_default_value(self):
+        A = self.classes.A
+        C = self.classes.C
+        s = Session()
+
+        a1 = s.query(A).order_by(A.id).filter(A.x > 1)
+        eq_(a1.all(), [A(my_expr=None), A(my_expr=None), A(my_expr=None)])
+
+        c1 = s.query(C).order_by(C.id)
+        eq_(c1.all(), [C(c_expr=1), C(c_expr=1)])
+
+        c2 = (
+            s.query(C)
+            .options(with_expression(C.c_expr, C.x * 2))
+            .filter(C.x > 1)
+            .order_by(C.id)
+        )
+        eq_(c2.all(), [C(c_expr=4)])
+
     def test_reuse_expr(self):
         A = self.classes.A