From 460bed7cfd8a6dd035caff5f5b1b33edf96fa3bb Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 28 May 2021 08:29:24 -0400 Subject: [PATCH] Fix adaption in AnnotatedLabel; repair needless expense in coercion Fixed regression involving clause adaption of labeled ORM compound elements, such as single-table inheritance discriminator expressions with conditionals or CASE expressions, which could cause aliased expressions such as those used in ORM join / joinedload operations to not be adapted correctly, such as referring to the wrong table in the ON clause in a join. This change also improves a performance bump that was located within the process of invoking :meth:`_sql.Select.join` given an ORM attribute as a target. Fixes: #6550 Change-Id: I98906476f0cce6f41ea00b77c789baa818e9d167 --- doc/build/changelog/unreleased_14/6550.rst | 13 ++++++++++ lib/sqlalchemy/sql/coercions.py | 26 ++++++++++++++++---- lib/sqlalchemy/sql/elements.py | 1 + lib/sqlalchemy/sql/traversals.py | 1 - test/aaa_profiling/test_orm.py | 16 +++++++++++++ test/profiles.txt | 4 ++++ test/sql/test_external_traversal.py | 28 ++++++++++++++++++++++ 7 files changed, 83 insertions(+), 6 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/6550.rst diff --git a/doc/build/changelog/unreleased_14/6550.rst b/doc/build/changelog/unreleased_14/6550.rst new file mode 100644 index 0000000000..8cb553f240 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6550.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: bug, orm, regression + :tickets: 6550 + + Fixed regression involving clause adaption of labeled ORM compound + elements, such as single-table inheritance discriminator expressions with + conditionals or CASE expressions, which could cause aliased expressions + such as those used in ORM join / joinedload operations to not be adapted + correctly, such as referring to the wrong table in the ON clause in a join. + + This change also improves a performance bump that was located within the + process of invoking :meth:`_sql.Select.join` given an ORM attribute + as a target. diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 36ac507ad3..82068d7683 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -151,12 +151,25 @@ def expect( is_clause_element = False - while hasattr(element, "__clause_element__"): + # this is a special performance optimization for ORM + # joins used by JoinTargetImpl that we don't go through the + # work of creating __clause_element__() when we only need the + # original QueryableAttribute, as the former will do clause + # adaption and all that which is just thrown away here. + if ( + impl._skip_clauseelement_for_target_match + and isinstance(element, role) + and hasattr(element, "__clause_element__") + ): is_clause_element = True - if not getattr(element, "is_clause_element", False): - element = element.__clause_element__() - else: - break + else: + while hasattr(element, "__clause_element__"): + is_clause_element = True + + if not getattr(element, "is_clause_element", False): + element = element.__clause_element__() + else: + break if not is_clause_element: if impl._use_inspection: @@ -230,6 +243,7 @@ class RoleImpl(object): _post_coercion = None _resolve_literal_only = False + _skip_clauseelement_for_target_match = False def __init__(self, role_class): self._role_class = role_class @@ -860,6 +874,8 @@ class HasCTEImpl(ReturnsRowsImpl): class JoinTargetImpl(RoleImpl): __slots__ = () + _skip_clauseelement_for_target_match = True + def _literal_coercion(self, element, legacy=False, **kw): if isinstance(element, str): return element diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index f212cb079c..213f47c409 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -4395,6 +4395,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): return self.element.foreign_keys def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw): + self._reset_memoizations() self._element = clone(self._element, **kw) if anonymize_labels: self.name = self._resolve_label = _anonymous_label.safe_construct( diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index e64eff6a41..35f2bd62f9 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -737,7 +737,6 @@ class HasCopyInternals(object): continue if obj is not None: - result = meth(attrname, self, obj, **kw) if result is not None: setattr(self, attrname, result) diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py index 8116e5f215..356ea252d8 100644 --- a/test/aaa_profiling/test_orm.py +++ b/test/aaa_profiling/test_orm.py @@ -2,6 +2,7 @@ from sqlalchemy import and_ from sqlalchemy import ForeignKey from sqlalchemy import Integer from sqlalchemy import join +from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.orm import aliased @@ -966,6 +967,21 @@ class JoinConditionTest(NoCache, fixtures.DeclarativeMappedTest): go() + def test_a_to_b_aliased_select_join(self): + A, B = self.classes("A", "B") + + b1 = aliased(B) + + stmt = select(A) + + @profiling.function_call_count(times=50, warmup=1) + def go(): + # should not do any adaption or aliasing, this is just getting + # the args. See #6550 where we also fixed this. + stmt.join(A.b.of_type(b1)) + + go() + def test_a_to_d(self): A, D = self.classes("A", "D") diff --git a/test/profiles.txt b/test/profiles.txt index dbf6612815..3b5b1aca3e 100644 --- a/test/profiles.txt +++ b/test/profiles.txt @@ -273,6 +273,10 @@ test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased x86_64_linux_c test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased x86_64_linux_cpython_3.9_sqlite_pysqlite_dbapiunicode_cextensions 10304 test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased x86_64_linux_cpython_3.9_sqlite_pysqlite_dbapiunicode_nocextensions 10454 +# TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased_select_join + +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased_select_join x86_64_linux_cpython_3.9_sqlite_pysqlite_dbapiunicode_cextensions 1104 + # TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_cextensions 4053 diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py index 9e829baeab..3469dcb372 100644 --- a/test/sql/test_external_traversal.py +++ b/test/sql/test_external_traversal.py @@ -458,6 +458,34 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): select(f), "SELECT t1_1.col1 * :col1_1 AS anon_1 FROM t1 AS t1_1" ) + @testing.combinations( + (lambda t1: t1.c.col1, "t1_1.col1"), + (lambda t1: t1.c.col1 == "foo", "t1_1.col1 = :col1_1"), + ( + lambda t1: case((t1.c.col1 == "foo", "bar"), else_=t1.c.col1), + "CASE WHEN (t1_1.col1 = :col1_1) THEN :param_1 ELSE t1_1.col1 END", + ), + argnames="case, expected", + ) + @testing.combinations(False, True, argnames="label_") + @testing.combinations(False, True, argnames="annotate") + def test_annotated_label_cases(self, case, expected, label_, annotate): + """test #6550""" + + t1 = table("t1", column("col1")) + a1 = t1.alias() + + expr = case(t1=t1) + + if label_: + expr = expr.label(None) + if annotate: + expr = expr._annotate({"foo": "bar"}) + + adapted = sql_util.ClauseAdapter(a1).traverse(expr) + + self.assert_compile(adapted, expected) + @testing.combinations((null(),), (true(),)) def test_dont_adapt_singleton_elements(self, elem): """test :ticket:`6259`""" -- 2.47.3