]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
improve column targeting issues with query_expression
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 25 Nov 2022 21:49:28 +0000 (16:49 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 26 Nov 2022 02:31:24 +0000 (21:31 -0500)
Fixed issues in :func:`_orm.with_expression` where expressions that were
composed of columns within a subquery being SELECTed from, or when using
``.from_statement()``, would not render correct SQL **if** the expression
had a label name that matched the attribute which used
:func:`_orm.query_expression`, even when :func:`_orm.query_expression` had
no default expression. For the moment, if the :func:`_orm.query_expression`
**does** have a default expression, that label name is still used for that
default, and an additional label with the same name will be ignored.
Overall, this case is pretty thorny so further adjustments might be
warranted.

Fixes: #8881
Change-Id: Ie939b1470cb2e824717384be65f4cd8edd619942

doc/build/changelog/unreleased_14/8881.rst [new file with mode: 0644]
lib/sqlalchemy/orm/_orm_constructors.py
lib/sqlalchemy/orm/properties.py
test/orm/test_core_compilation.py
test/orm/test_deferred.py

diff --git a/doc/build/changelog/unreleased_14/8881.rst b/doc/build/changelog/unreleased_14/8881.rst
new file mode 100644 (file)
index 0000000..f3fe5e6
--- /dev/null
@@ -0,0 +1,14 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 8881
+
+    Fixed issues in :func:`_orm.with_expression` where expressions that were
+    composed of columns within a subquery being SELECTed from, or when using
+    ``.from_statement()``, would not render correct SQL **if** the expression
+    had a label name that matched the attribute which used
+    :func:`_orm.query_expression`, even when :func:`_orm.query_expression` had
+    no default expression. For the moment, if the :func:`_orm.query_expression`
+    **does** have a default expression, that label name is still used for that
+    default, and an additional label with the same name will be ignored.
+    Overall, this case is pretty thorny so further adjustments might be
+    warranted.
index 30119d9d79e4813ed83d90e648cd342b42fea807..c4abb1c8e3bfb89f20ec57cad8727b1a4298c09f 100644 (file)
@@ -1970,6 +1970,7 @@ def query_expression(
         info=info,
         doc=doc,
     )
+
     prop.strategy_key = (("query_expression", True),)
     return prop
 
index b8e1521a238f4fef3d9b66f80a4b9fd5deb84dc4..f942ad092f43c6299a4aead049edecd731fb41a2 100644 (file)
@@ -221,6 +221,9 @@ class ColumnProperty(
         ]
 
     def _memoized_attr__renders_in_subqueries(self) -> bool:
+        if ("query_expression", True) in self.strategy_key:
+            return self.strategy._have_default_expression  # type: ignore
+
         return ("deferred", True) not in self.strategy_key or (
             self not in self.parent._readonly_props  # type: ignore
         )
index 5c2f107f45ce8410023e744459796e97c1bd12fc..b71d6447348ebe2033ecd5b6786325bc2e719c02 100644 (file)
@@ -7,6 +7,7 @@ from sqlalchemy import func
 from sqlalchemy import insert
 from sqlalchemy import inspect
 from sqlalchemy import Integer
+from sqlalchemy import literal
 from sqlalchemy import literal_column
 from sqlalchemy import null
 from sqlalchemy import or_
@@ -979,6 +980,10 @@ class ExtraColsTest(QueryTest, AssertsCompiledSQL):
             properties=util.OrderedDict(
                 [
                     ("value", query_expression()),
+                    (
+                        "value_w_default",
+                        query_expression(default_expr=literal(15)),
+                    ),
                 ]
             ),
         )
@@ -986,6 +991,24 @@ class ExtraColsTest(QueryTest, AssertsCompiledSQL):
 
         return User
 
+    @testing.fixture
+    def deferred_fixture(self):
+        User = self.classes.User
+        users = self.tables.users
+
+        self.mapper_registry.map_imperatively(
+            User,
+            users,
+            properties={
+                "name": deferred(users.c.name),
+                "name_upper": column_property(
+                    func.upper(users.c.name), deferred=True
+                ),
+            },
+        )
+
+        return User
+
     @testing.fixture
     def query_expression_w_joinedload_fixture(self):
         users, User = (
@@ -1126,10 +1149,71 @@ class ExtraColsTest(QueryTest, AssertsCompiledSQL):
 
         self.assert_compile(
             stmt,
-            "SELECT users.name || :name_1 AS anon_1, users.id, "
+            "SELECT users.name || :name_1 AS anon_1, :param_1 AS anon_2, "
+            "users.id, "
             "users.name FROM users",
         )
 
+    def test_exported_columns_query_expression(self, query_expression_fixture):
+        """test behaviors related to #8881"""
+        User = query_expression_fixture
+
+        stmt = select(User)
+
+        eq_(
+            stmt.selected_columns.keys(),
+            ["value_w_default", "id", "name"],
+        )
+
+        stmt = select(User).options(
+            with_expression(User.value, User.name + "foo")
+        )
+
+        # bigger problem.  we still don't include 'value', because we dont
+        # run query options here.  not "correct", but is at least consistent
+        # with deferred
+        eq_(
+            stmt.selected_columns.keys(),
+            ["value_w_default", "id", "name"],
+        )
+
+    def test_exported_columns_colprop(self, column_property_fixture):
+        """test behaviors related to #8881"""
+        User, _ = column_property_fixture
+
+        stmt = select(User)
+
+        # we get all the cols because they are not deferred and have a value
+        eq_(
+            stmt.selected_columns.keys(),
+            ["concat", "count", "id", "name"],
+        )
+
+    def test_exported_columns_deferred(self, deferred_fixture):
+        """test behaviors related to #8881"""
+        User = deferred_fixture
+
+        stmt = select(User)
+
+        # don't include 'name_upper' as it's deferred and readonly.
+        # "name" however is a column on the table, so even though it is
+        # deferred, it gets special treatment (related to #6661)
+        eq_(
+            stmt.selected_columns.keys(),
+            ["id", "name"],
+        )
+
+        stmt = select(User).options(
+            undefer(User.name), undefer(User.name_upper)
+        )
+
+        # undefer doesn't affect the readonly col because we dont look
+        # at options when we do selected_columns
+        eq_(
+            stmt.selected_columns.keys(),
+            ["id", "name"],
+        )
+
     def test_with_expr_two(self, query_expression_fixture):
         User = query_expression_fixture
 
@@ -1142,7 +1226,8 @@ class ExtraColsTest(QueryTest, AssertsCompiledSQL):
 
         self.assert_compile(
             stmt,
-            "SELECT anon_1.foo, anon_1.id, anon_1.name FROM "
+            "SELECT anon_1.foo, :param_1 AS anon_2, anon_1.id, "
+            "anon_1.name FROM "
             "(SELECT users.id AS id, users.name AS name, "
             "users.name || :name_1 AS foo FROM users) AS anon_1",
         )
index a8317671c772283df1287c2a263733966145fe2d..0f2bb2013078cd6ab748910ada8d87e01956ad79 100644 (file)
@@ -6,6 +6,7 @@ from sqlalchemy import null
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
+from sqlalchemy import union_all
 from sqlalchemy import util
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import attributes
@@ -2054,6 +2055,14 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest):
 
             bs = relationship("B", order_by="B.id")
 
+        class A_default(fixtures.ComparableEntity, Base):
+            __tablename__ = "a_default"
+            id = Column(Integer, primary_key=True)
+            x = Column(Integer)
+            y = Column(Integer)
+
+            my_expr = query_expression(default_expr=literal(15))
+
         class B(fixtures.ComparableEntity, Base):
             __tablename__ = "b"
             id = Column(Integer, primary_key=True)
@@ -2072,7 +2081,7 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest):
 
     @classmethod
     def insert_data(cls, connection):
-        A, B, C = cls.classes("A", "B", "C")
+        A, A_default, B, C = cls.classes("A", "A_default", "B", "C")
         s = Session(connection)
 
         s.add_all(
@@ -2083,6 +2092,8 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest):
                 A(id=4, x=2, y=10, bs=[B(id=4, p=19, q=8), B(id=5, p=5, q=5)]),
                 C(id=1, x=1),
                 C(id=2, x=2),
+                A_default(id=1, x=1, y=2),
+                A_default(id=2, x=2, y=3),
             ]
         )
 
@@ -2257,6 +2268,149 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest):
         q.first()
         eq_(a1.my_expr, 5)
 
+    @testing.combinations("core", "orm", argnames="use_core")
+    @testing.combinations(
+        "from_statement", "aliased", argnames="use_from_statement"
+    )
+    @testing.combinations(
+        "same_name", "different_name", argnames="use_same_labelname"
+    )
+    @testing.combinations(
+        "has_default", "no_default", argnames="attr_has_default"
+    )
+    def test_expr_from_subq_plain(
+        self,
+        use_core,
+        use_from_statement,
+        use_same_labelname,
+        attr_has_default,
+    ):
+        """test #8881"""
+
+        if attr_has_default == "has_default":
+            A = self.classes.A_default
+        else:
+            A = self.classes.A
+
+        s = fixture_session()
+
+        if use_same_labelname == "same_name":
+            labelname = "my_expr"
+        else:
+            labelname = "hi"
+
+        if use_core == "core":
+            stmt = select(A.__table__, literal(12).label(labelname))
+        else:
+            stmt = select(A, literal(12).label(labelname))
+
+        if use_from_statement == "aliased":
+            subq = stmt.subquery()
+            a1 = aliased(A, subq)
+            stmt = select(a1).options(
+                with_expression(a1.my_expr, subq.c[labelname])
+            )
+        else:
+            subq = stmt
+            stmt = (
+                select(A)
+                .options(
+                    with_expression(
+                        A.my_expr, subq.selected_columns[labelname]
+                    )
+                )
+                .from_statement(subq)
+            )
+
+        a_obj = s.scalars(stmt).first()
+
+        if (
+            use_same_labelname == "same_name"
+            and attr_has_default == "has_default"
+            and use_core == "orm"
+        ):
+            eq_(a_obj.my_expr, 15)
+        else:
+            eq_(a_obj.my_expr, 12)
+
+    @testing.combinations("core", "orm", argnames="use_core")
+    @testing.combinations(
+        "from_statement", "aliased", argnames="use_from_statement"
+    )
+    @testing.combinations(
+        "same_name", "different_name", argnames="use_same_labelname"
+    )
+    @testing.combinations(
+        "has_default", "no_default", argnames="attr_has_default"
+    )
+    def test_expr_from_subq_union(
+        self,
+        use_core,
+        use_from_statement,
+        use_same_labelname,
+        attr_has_default,
+    ):
+        """test #8881"""
+
+        if attr_has_default == "has_default":
+            A = self.classes.A_default
+        else:
+            A = self.classes.A
+
+        s = fixture_session()
+
+        if use_same_labelname == "same_name":
+            labelname = "my_expr"
+        else:
+            labelname = "hi"
+
+        if use_core == "core":
+            stmt = union_all(
+                select(A.__table__, literal(12).label(labelname)).where(
+                    A.__table__.c.id == 1
+                ),
+                select(A.__table__, literal(18).label(labelname)).where(
+                    A.__table__.c.id == 2
+                ),
+            )
+
+        else:
+            stmt = union_all(
+                select(A, literal(12).label(labelname)).where(A.id == 1),
+                select(A, literal(18).label(labelname)).where(A.id == 2),
+            )
+
+        if use_from_statement == "aliased":
+            subq = stmt.subquery()
+            a1 = aliased(A, subq)
+            stmt = select(a1).options(
+                with_expression(a1.my_expr, subq.c[labelname])
+            )
+        else:
+            subq = stmt
+            stmt = (
+                select(A)
+                .options(
+                    with_expression(
+                        A.my_expr, subq.selected_columns[labelname]
+                    )
+                )
+                .from_statement(subq)
+            )
+
+        a_objs = s.scalars(stmt).all()
+
+        if (
+            use_same_labelname == "same_name"
+            and attr_has_default == "has_default"
+            and use_core == "orm"
+        ):
+            eq_(a_objs[0].my_expr, 15)
+            eq_(a_objs[1].my_expr, 15)
+        else:
+            eq_(a_objs[0].my_expr, 12)
+            eq_(a_objs[1].my_expr, 18)
+
 
 class RaiseLoadTest(fixtures.DeclarativeMappedTest):
     @classmethod