]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixes: #5198 add default expression param to query_expression() 5354/head
authorHaoyu Sun <raptorsun@gmail.com>
Wed, 27 May 2020 22:42:07 +0000 (00:42 +0200)
committerHaoyu Sun <raptorsun@gmail.com>
Fri, 29 May 2020 10:08:08 +0000 (12:08 +0200)
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/strategies.py
test/orm/test_deferred.py

index 24945ef52baf20443a713485fcc6505e7798781d..7fb94e6e9a2bfd2c2eb903a8314a88a11ebcf313 100644 (file)
@@ -177,9 +177,23 @@ def deferred(*columns, **kw):
     return ColumnProperty(deferred=True, *columns, **kw)
 
 
-def query_expression():
+def query_expression(default_expr=_sql.null()):
     """Indicate an attribute that populates from a query-time SQL expression.
 
+    :param default_expr:
+        SQL clause object. The default query expression if
+        not assigned later by `with_expression`. Here is an example of an
+        attribute my_expr defaulting to 1::
+
+            from sqlalchemy.sql import literal
+
+            class C(Base):
+                #...
+                my_expr = query_expression(literal(1))
+
+        .. versionadded:: 1.3.18
+
+
     .. versionadded:: 1.2
 
     .. seealso::
@@ -187,7 +201,7 @@ def query_expression():
         :ref:`mapper_querytime_expression`
 
     """
-    prop = ColumnProperty(_sql.null())
+    prop = ColumnProperty(default_expr)
     prop.strategy_key = (("query_expression", True),)
     return prop
 
index 7edac8990e8b2770ba1e32190f017e98e4803522..a84ae87c4155414ac6ea653019cea5c66db7cd32 100644 (file)
@@ -235,6 +235,9 @@ class ColumnLoader(LoaderStrategy):
 class ExpressionColumnLoader(ColumnLoader):
     def __init__(self, parent, strategy_key):
         super(ExpressionColumnLoader, self).__init__(parent, strategy_key)
+        self.have_default_expression = any([
+            not c.expression.type._isnull for c in self.parent_property.columns
+        ])
 
     def setup_query(
         self,
@@ -247,19 +250,24 @@ class ExpressionColumnLoader(ColumnLoader):
         memoized_populators,
         **kwargs
     ):
-
+        columns = None
         if loadopt and "expression" in loadopt.local_opts:
             columns = [loadopt.local_opts["expression"]]
+        elif self.have_default_expression:
+            columns = self.parent_property.columns
 
-            for c in columns:
-                if adapter:
-                    c = adapter.columns[c]
-                column_collection.append(c)
+        if columns is None:
+             return
 
-            fetch = columns[0]
+        for c in columns:
             if adapter:
-                fetch = adapter.columns[fetch]
-            memoized_populators[self.parent_property] = fetch
+                c = adapter.columns[c]
+            column_collection.append(c)
+
+        fetch = columns[0]
+        if adapter:
+            fetch = adapter.columns[fetch]
+        memoized_populators[self.parent_property] = fetch
 
     def create_row_processor(
         self, context, path, loadopt, mapper, result, adapter, populators
index b9198033b9052cd9d25fb01d8fe272d29bb7ee12..1a430f8e6a4b58c8acdb9152afe5888a673c231c 100644 (file)
@@ -24,6 +24,7 @@ from sqlalchemy.orm import undefer
 from sqlalchemy.orm import undefer_group
 from sqlalchemy.orm import with_expression
 from sqlalchemy.orm import with_polymorphic
+from sqlalchemy.sql import literal
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
@@ -1720,9 +1721,16 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest):
 
             b_expr = query_expression()
 
+        class C(fixtures.ComparableEntity, Base):
+            __tablename__ = "c"
+            id = Column(Integer, primary_key=True)
+            x = Column(Integer)
+
+            c_expr = query_expression(literal(1))
+
     @classmethod
     def insert_data(cls, connection):
-        A, B = cls.classes("A", "B")
+        A, B, C = cls.classes("A", "B", "C")
         s = Session(connection)
 
         s.add_all(
@@ -1731,6 +1739,8 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest):
                 A(id=2, x=2, y=3),
                 A(id=3, x=5, y=10, bs=[B(id=3, p=5, q=0)]),
                 A(id=4, x=2, y=10, bs=[B(id=4, p=19, q=8), B(id=5, p=5, q=5)]),
+                C(id=1, x=1),
+                C(id=2, x=2),
             ]
         )
 
@@ -1749,6 +1759,25 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest):
 
         eq_(a1.all(), [A(my_expr=5), A(my_expr=15), A(my_expr=12)])
 
+    def test_expr_default_value(self):
+        A = self.classes.A
+        C = self.classes.C
+        s = Session()
+
+        a1 = s.query(A).order_by(A.id).filter(A.x > 1)
+        eq_(a1.all(), [A(my_expr=None), A(my_expr=None), A(my_expr=None),])
+
+        c1 = s.query(C).order_by(C.id)
+        eq_(c1.all(), [C(c_expr=1), C(c_expr=1)])
+
+        c2 = (
+            s.query(C)
+            .options(with_expression(C.c_expr, C.x * 2))
+            .filter(C.x > 1)
+            .order_by(C.id)
+        )
+        eq_(c2.all(), [C(c_expr=4),])
+
     def test_reuse_expr(self):
         A = self.classes.A