]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix adaption in AnnotatedLabel; repair needless expense in coercion
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 28 May 2021 12:29:24 +0000 (08:29 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 28 May 2021 14:35:51 +0000 (10:35 -0400)
Fixed regression involving clause adaption of labeled ORM compound
elements, such as single-table inheritance discriminator expressions with
conditionals or CASE expressions, which could cause aliased expressions
such as those used in ORM join / joinedload operations to not be adapted
correctly, such as referring to the wrong table in the ON clause in a join.

This change also improves a performance bump that was located within the
process of invoking :meth:`_sql.Select.join` given an ORM attribute
as a target.

Fixes: #6550
Change-Id: I98906476f0cce6f41ea00b77c789baa818e9d167

doc/build/changelog/unreleased_14/6550.rst [new file with mode: 0644]
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/traversals.py
test/aaa_profiling/test_orm.py
test/profiles.txt
test/sql/test_external_traversal.py

diff --git a/doc/build/changelog/unreleased_14/6550.rst b/doc/build/changelog/unreleased_14/6550.rst
new file mode 100644 (file)
index 0000000..8cb553f
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: bug, orm, regression
+    :tickets: 6550
+
+    Fixed regression involving clause adaption of labeled ORM compound
+    elements, such as single-table inheritance discriminator expressions with
+    conditionals or CASE expressions, which could cause aliased expressions
+    such as those used in ORM join / joinedload operations to not be adapted
+    correctly, such as referring to the wrong table in the ON clause in a join.
+
+    This change also improves a performance bump that was located within the
+    process of invoking :meth:`_sql.Select.join` given an ORM attribute
+    as a target.
index 36ac507ad3fd40e2433ccb06dcc52cd2e53aa609..82068d7683d552f00c15d890d7160a5479705bc6 100644 (file)
@@ -151,12 +151,25 @@ def expect(
 
             is_clause_element = False
 
-            while hasattr(element, "__clause_element__"):
+            # this is a special performance optimization for ORM
+            # joins used by JoinTargetImpl that we don't go through the
+            # work of creating __clause_element__() when we only need the
+            # original QueryableAttribute, as the former will do clause
+            # adaption and all that which is just thrown away here.
+            if (
+                impl._skip_clauseelement_for_target_match
+                and isinstance(element, role)
+                and hasattr(element, "__clause_element__")
+            ):
                 is_clause_element = True
-                if not getattr(element, "is_clause_element", False):
-                    element = element.__clause_element__()
-                else:
-                    break
+            else:
+                while hasattr(element, "__clause_element__"):
+                    is_clause_element = True
+
+                    if not getattr(element, "is_clause_element", False):
+                        element = element.__clause_element__()
+                    else:
+                        break
 
             if not is_clause_element:
                 if impl._use_inspection:
@@ -230,6 +243,7 @@ class RoleImpl(object):
 
     _post_coercion = None
     _resolve_literal_only = False
+    _skip_clauseelement_for_target_match = False
 
     def __init__(self, role_class):
         self._role_class = role_class
@@ -860,6 +874,8 @@ class HasCTEImpl(ReturnsRowsImpl):
 class JoinTargetImpl(RoleImpl):
     __slots__ = ()
 
+    _skip_clauseelement_for_target_match = True
+
     def _literal_coercion(self, element, legacy=False, **kw):
         if isinstance(element, str):
             return element
index f212cb079c2752d76c7fc46abd3a08fef0042294..213f47c4097f5deda43dfe6e6f0feb08114cccd9 100644 (file)
@@ -4395,6 +4395,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
         return self.element.foreign_keys
 
     def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw):
+        self._reset_memoizations()
         self._element = clone(self._element, **kw)
         if anonymize_labels:
             self.name = self._resolve_label = _anonymous_label.safe_construct(
index e64eff6a419797c0161b544d3d326a95813e3555..35f2bd62f94cf08f499195596dbe3dba2762831e 100644 (file)
@@ -737,7 +737,6 @@ class HasCopyInternals(object):
                 continue
 
             if obj is not None:
-
                 result = meth(attrname, self, obj, **kw)
                 if result is not None:
                     setattr(self, attrname, result)
index 8116e5f215d6f8ada5bbd8fcdc6266c6fd41f19b..356ea252d8d9ff55d32f9ae6531e2615e3b0d458 100644 (file)
@@ -2,6 +2,7 @@ from sqlalchemy import and_
 from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
 from sqlalchemy import join
+from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy.orm import aliased
@@ -966,6 +967,21 @@ class JoinConditionTest(NoCache, fixtures.DeclarativeMappedTest):
 
         go()
 
+    def test_a_to_b_aliased_select_join(self):
+        A, B = self.classes("A", "B")
+
+        b1 = aliased(B)
+
+        stmt = select(A)
+
+        @profiling.function_call_count(times=50, warmup=1)
+        def go():
+            # should not do any adaption or aliasing, this is just getting
+            # the args.  See #6550 where we also fixed this.
+            stmt.join(A.b.of_type(b1))
+
+        go()
+
     def test_a_to_d(self):
         A, D = self.classes("A", "D")
 
index dbf6612815fa89989cf0e6c9baf8264075317fc3..3b5b1aca3e025b144670eac1cab42309647468f5 100644 (file)
@@ -273,6 +273,10 @@ test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased x86_64_linux_c
 test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased x86_64_linux_cpython_3.9_sqlite_pysqlite_dbapiunicode_cextensions 10304
 test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased x86_64_linux_cpython_3.9_sqlite_pysqlite_dbapiunicode_nocextensions 10454
 
+# TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased_select_join
+
+test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased_select_join x86_64_linux_cpython_3.9_sqlite_pysqlite_dbapiunicode_cextensions 1104
+
 # TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain
 
 test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_cextensions 4053
index 9e829baeabf353214d5bdca4352a5987eb1ae5b5..3469dcb372ecc2452c91c23be97406dcf76c8dbc 100644 (file)
@@ -458,6 +458,34 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
             select(f), "SELECT t1_1.col1 * :col1_1 AS anon_1 FROM t1 AS t1_1"
         )
 
+    @testing.combinations(
+        (lambda t1: t1.c.col1, "t1_1.col1"),
+        (lambda t1: t1.c.col1 == "foo", "t1_1.col1 = :col1_1"),
+        (
+            lambda t1: case((t1.c.col1 == "foo", "bar"), else_=t1.c.col1),
+            "CASE WHEN (t1_1.col1 = :col1_1) THEN :param_1 ELSE t1_1.col1 END",
+        ),
+        argnames="case, expected",
+    )
+    @testing.combinations(False, True, argnames="label_")
+    @testing.combinations(False, True, argnames="annotate")
+    def test_annotated_label_cases(self, case, expected, label_, annotate):
+        """test #6550"""
+
+        t1 = table("t1", column("col1"))
+        a1 = t1.alias()
+
+        expr = case(t1=t1)
+
+        if label_:
+            expr = expr.label(None)
+        if annotate:
+            expr = expr._annotate({"foo": "bar"})
+
+        adapted = sql_util.ClauseAdapter(a1).traverse(expr)
+
+        self.assert_compile(adapted, expected)
+
     @testing.combinations((null(),), (true(),))
     def test_dont_adapt_singleton_elements(self, elem):
         """test :ticket:`6259`"""