]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
improve targeting and labeling for unary() in columns clause
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 6 Mar 2021 02:52:03 +0000 (21:52 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 6 Mar 2021 16:15:38 +0000 (11:15 -0500)
Fixed regression where usage of the standalone :func:`_sql.distinct()` used
in the form of being directly SELECTed would fail to be locatable in the
result set by column identity, which is how the ORM locates columns. While
standalone :func:`_sql.distinct()` is not oriented towards being directly
SELECTed (use :meth:`_sql.select.distinct` for a regular
``SELECT DISTINCT..``) , it was usable to a limited extent in this way
previously (but wouldn't work in subqueries, for example). The column
targeting for unary expressions such as "DISTINCT <col>" has been improved
so that this case works again, and an additional improvement has been made
so that usage of this form in a subquery at least generates valid SQL which
was not the case previously.

The change additionally enhances the ability to target elements in
``row._mapping`` based on SQL expression objects in ORM-enabled
SELECT statements, including whether the statement was invoked by
``connection.execute()`` or ``session.execute()``.

Fixes: #6008
Change-Id: I5cfa39435f5418861d70a7db8f52ab4ced6a792e

doc/build/changelog/unreleased_14/6008.rst [new file with mode: 0644]
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/sql/compiler.py
test/orm/test_cache_key.py
test/orm/test_eager_relations.py
test/orm/test_query.py
test/sql/test_compiler.py

diff --git a/doc/build/changelog/unreleased_14/6008.rst b/doc/build/changelog/unreleased_14/6008.rst
new file mode 100644 (file)
index 0000000..87b4c64
--- /dev/null
@@ -0,0 +1,20 @@
+.. change::
+    :tags: bug, regression, sql
+    :tickets: 6008
+
+    Fixed regression where usage of the standalone :func:`_sql.distinct()` used
+    in the form of being directly SELECTed would fail to be locatable in the
+    result set by column identity, which is how the ORM locates columns. While
+    standalone :func:`_sql.distinct()` is not oriented towards being directly
+    SELECTed (use :meth:`_sql.select.distinct` for a regular
+    ``SELECT DISTINCT..``) , it was usable to a limited extent in this way
+    previously (but wouldn't work in subqueries, for example). The column
+    targeting for unary expressions such as "DISTINCT <col>" has been improved
+    so that this case works again, and an additional improvement has been made
+    so that usage of this form in a subquery at least generates valid SQL which
+    was not the case previously.
+
+    The change additionally enhances the ability to target elements in
+    ``row._mapping`` based on SQL expression objects in ORM-enabled
+    SELECT statements, including whether the statement was invoked by
+    ``connection.execute()`` or ``session.execute()``.
index a9933c7ae77527a9432ac42c934c351ec5717beb..b5b92b1714492f138ccd2f13ad7b74cb79712a15 100644 (file)
@@ -1290,7 +1290,9 @@ class BaseCursorResult(object):
                 compiled
                 and compiled._result_columns
                 and context.cache_hit is context.dialect.CACHE_HIT
-                and not compiled._rewrites_selected_columns
+                and not context.execution_options.get(
+                    "_result_disable_adapt_to_context", False
+                )
                 and compiled.statement is not context.invoked_statement
             ):
                 metadata = metadata._adapt_to_context(context)
index 23bae5cc0821fa74c076a8d2575a14cdc604fd31..9cfa0fdc6d7125deceb42fea00c73646f7362e37 100644 (file)
@@ -127,6 +127,11 @@ class QueryContext(object):
             )
 
 
+_result_disable_adapt_to_context = util.immutabledict(
+    {"_result_disable_adapt_to_context": True}
+)
+
+
 class ORMCompileState(CompileState):
     # note this is a dictionary, but the
     # default_compile_options._with_polymorphic_adapt_map is a tuple
@@ -234,6 +239,17 @@ class ORMCompileState(CompileState):
             statement._execution_options,
         )
 
+        # add _result_disable_adapt_to_context=True to execution options.
+        # this will disable the ResultSetMetadata._adapt_to_context()
+        # step which we don't need, as we have result processors cached
+        # against the original SELECT statement before caching.
+        if not execution_options:
+            execution_options = _result_disable_adapt_to_context
+        else:
+            execution_options = execution_options.union(
+                _result_disable_adapt_to_context
+            )
+
         if "yield_per" in execution_options or load_options._yield_per:
             execution_options = execution_options.union(
                 {
@@ -343,7 +359,6 @@ class ORMFromStatementCompileState(ORMCompileState):
     def create_for_statement(cls, statement_container, compiler, **kw):
 
         if compiler is not None:
-            compiler._rewrites_selected_columns = True
             toplevel = not compiler.stack
         else:
             toplevel = True
@@ -475,7 +490,6 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
 
         if compiler is not None:
             toplevel = not compiler.stack
-            compiler._rewrites_selected_columns = True
             self.global_attributes = compiler._global_attributes
         else:
             toplevel = True
@@ -2160,7 +2174,7 @@ class _QueryEntity(object):
 
     @classmethod
     def to_compile_state(cls, compile_state, entities):
-        for entity in entities:
+        for idx, entity in enumerate(entities):
             if entity._is_lambda_element:
                 if entity._is_sequence:
                     cls.to_compile_state(compile_state, entity._resolved)
@@ -2174,7 +2188,7 @@ class _QueryEntity(object):
                         _MapperEntity(compile_state, entity)
                     else:
                         _ColumnEntity._for_columns(
-                            compile_state, entity._select_iterable
+                            compile_state, entity._select_iterable, idx
                         )
                 else:
                     if entity._annotations.get("bundle", False):
@@ -2183,10 +2197,12 @@ class _QueryEntity(object):
                         # this is legacy only - test_composites.py
                         # test_query_cols_legacy
                         _ColumnEntity._for_columns(
-                            compile_state, entity._select_iterable
+                            compile_state, entity._select_iterable, idx
                         )
                     else:
-                        _ColumnEntity._for_columns(compile_state, [entity])
+                        _ColumnEntity._for_columns(
+                            compile_state, [entity], idx
+                        )
             elif entity.is_bundle:
                 _BundleEntity(compile_state, entity)
 
@@ -2411,7 +2427,7 @@ class _BundleEntity(_QueryEntity):
                     _BundleEntity(compile_state, expr, parent_bundle=self)
                 else:
                     _ORMColumnEntity._for_columns(
-                        compile_state, [expr], parent_bundle=self
+                        compile_state, [expr], None, parent_bundle=self
                     )
 
         self.supports_single_entity = self.bundle.single_entity
@@ -2470,10 +2486,17 @@ class _BundleEntity(_QueryEntity):
 
 
 class _ColumnEntity(_QueryEntity):
-    __slots__ = ("_fetch_column", "_row_processor")
+    __slots__ = (
+        "_fetch_column",
+        "_row_processor",
+        "raw_column_index",
+        "translate_raw_column",
+    )
 
     @classmethod
-    def _for_columns(cls, compile_state, columns, parent_bundle=None):
+    def _for_columns(
+        cls, compile_state, columns, raw_column_index, parent_bundle=None
+    ):
         for column in columns:
             annotations = column._annotations
             if "parententity" in annotations:
@@ -2489,6 +2512,7 @@ class _ColumnEntity(_QueryEntity):
                         compile_state,
                         column,
                         _entity,
+                        raw_column_index,
                         parent_bundle=parent_bundle,
                     )
                 else:
@@ -2496,11 +2520,15 @@ class _ColumnEntity(_QueryEntity):
                         compile_state,
                         column,
                         _entity,
+                        raw_column_index,
                         parent_bundle=parent_bundle,
                     )
             else:
                 _RawColumnEntity(
-                    compile_state, column, parent_bundle=parent_bundle
+                    compile_state,
+                    column,
+                    raw_column_index,
+                    parent_bundle=parent_bundle,
                 )
 
     @property
@@ -2517,7 +2545,15 @@ class _ColumnEntity(_QueryEntity):
         # the resulting callable is entirely cacheable so just return
         # it if we already made one
         if self._row_processor is not None:
-            return self._row_processor
+            getter, label_name, extra_entities = self._row_processor
+            if self.translate_raw_column:
+                extra_entities += (
+                    result.context.invoked_statement._raw_columns[
+                        self.raw_column_index
+                    ],
+                )
+
+            return getter, label_name, extra_entities
 
         # retrieve the column that would have been set up in
         # setup_compile_state, to avoid doing redundant work
@@ -2547,7 +2583,16 @@ class _ColumnEntity(_QueryEntity):
 
         ret = getter, self._label_name, self._extra_entities
         self._row_processor = ret
-        return ret
+
+        if self.translate_raw_column:
+            extra_entities = self._extra_entities + (
+                result.context.invoked_statement._raw_columns[
+                    self.raw_column_index
+                ],
+            )
+            return getter, self._label_name, extra_entities
+        else:
+            return ret
 
 
 class _RawColumnEntity(_ColumnEntity):
@@ -2563,9 +2608,12 @@ class _RawColumnEntity(_ColumnEntity):
         "_extra_entities",
     )
 
-    def __init__(self, compile_state, column, parent_bundle=None):
+    def __init__(
+        self, compile_state, column, raw_column_index, parent_bundle=None
+    ):
         self.expr = column
-
+        self.raw_column_index = raw_column_index
+        self.translate_raw_column = raw_column_index is not None
         self._label_name = compile_state._label_convention(column)
 
         if parent_bundle:
@@ -2619,9 +2667,9 @@ class _ORMColumnEntity(_ColumnEntity):
         compile_state,
         column,
         parententity,
+        raw_column_index,
         parent_bundle=None,
     ):
-
         annotations = column._annotations
 
         _entity = parententity
@@ -2634,9 +2682,17 @@ class _ORMColumnEntity(_ColumnEntity):
         orm_key = annotations.get("proxy_key", None)
         if orm_key:
             self.expr = getattr(_entity.entity, orm_key)
+            self.translate_raw_column = False
         else:
+            # if orm_key is not present, that means this is an ad-hoc
+            # SQL ColumnElement, like a CASE() or other expression.
+            # include this column position from the invoked statement
+            # in the ORM-level ResultSetMetaData on each execute, so that
+            # it can be targeted by identity after caching
             self.expr = column
+            self.translate_raw_column = raw_column_index is not None
 
+        self.raw_column_index = raw_column_index
         self._label_name = compile_state._label_convention(
             column, col_name=orm_key
         )
@@ -2715,6 +2771,8 @@ class _ORMColumnEntity(_ColumnEntity):
 
 
 class _IdentityTokenEntity(_ORMColumnEntity):
+    translate_raw_column = False
+
     def setup_compile_state(self, compile_state):
         pass
 
index c6fa6072ea32b664e9029a2b09136c38226aef3f..f635c1ee41cb74880c5b5bf58bb813ba8c16aa1d 100644 (file)
@@ -404,35 +404,6 @@ class Compiled(object):
 
     .. versionadded:: 1.4
 
-    """
-
-    _rewrites_selected_columns = False
-    """if True, indicates the compile_state object rewrites an incoming
-    ReturnsRows (like a Select) so that the columns we compile against in the
-    result set are not what were expressed on the outside.   this is a hint to
-    the execution context to not link the statement.selected_columns to the
-    columns mapped in the result object.
-
-    That is, when this flag is False::
-
-        stmt = some_statement()
-
-        result = conn.execute(stmt)
-        row = result.first()
-
-        # selected_columns are in a 1-1 relationship with the
-        # columns in the result, and are targetable in mapping
-        for col in stmt.selected_columns:
-            assert col in row._mapping
-
-    When True::
-
-        # selected columns are not what are in the rows.  the context
-        # rewrote the statement for some other set of selected_columns.
-        for col in stmt.selected_columns:
-            assert col not in row._mapping
-
-
     """
 
     cache_key = None
@@ -1858,7 +1829,15 @@ class SQLCompiler(Compiled):
         )
         return getattr(self, attrname, None)
 
-    def visit_unary(self, unary, **kw):
+    def visit_unary(
+        self, unary, add_to_result_map=None, result_map_targets=(), **kw
+    ):
+
+        if add_to_result_map is not None:
+            result_map_targets += (unary,)
+            kw["add_to_result_map"] = add_to_result_map
+            kw["result_map_targets"] = result_map_targets
+
         if unary.operator:
             if unary.modifier:
                 raise exc.CompileError(
@@ -2870,6 +2849,7 @@ class SQLCompiler(Compiled):
             and (
                 not isinstance(column, elements.UnaryExpression)
                 or column.wraps_column_expression
+                or asfrom
             )
             and (
                 not hasattr(column, "name")
index 8b1d185382009d40b6d3770d4336d17ff3e419b3..cd06ce56a12ba3d8e272550e8a7075bcb5703ff0 100644 (file)
@@ -19,6 +19,7 @@ from sqlalchemy.orm import subqueryload
 from sqlalchemy.orm import with_loader_criteria
 from sqlalchemy.orm import with_polymorphic
 from sqlalchemy.sql.base import CacheableOptions
+from sqlalchemy.sql.expression import case
 from sqlalchemy.sql.visitors import InternalTraversal
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
@@ -599,3 +600,65 @@ class RoundTripTest(QueryTest, AssertsCompiledSQL):
         for i in range(5):
             fn = random.choice([go1, go2])
             self.assert_sql_count(testing.db, fn, 2)
+
+    @testing.combinations((True,), (False,), argnames="use_core")
+    @testing.combinations((True,), (False,), argnames="arbitrary_element")
+    @testing.combinations((True,), (False,), argnames="exercise_caching")
+    def test_column_targeting_core_execute(
+        self,
+        plain_fixture,
+        connection,
+        use_core,
+        arbitrary_element,
+        exercise_caching,
+    ):
+        """test that CursorResultSet will do a column rewrite for any core
+        execute even if the ORM compiled the statement.
+
+        This translates the current stmt.selected_columns to the cached
+        ResultSetMetaData._keymap.      The ORM skips this because loading.py
+        has also cached the selected_columns that are used.   But for
+        an outside-facing Core execute, this has to remain turned on.
+
+        Additionally, we want targeting of SQL expressions to work with both
+        Core and ORM statement executions. So the ORM still has to do some
+        translation here for these elements to be supported.
+
+        """
+        User, Address = plain_fixture
+        user_table = inspect(User).persist_selectable
+
+        def go():
+
+            my_thing = case((User.id > 9, 1), else_=2)
+
+            # include entities in the statement so that we test that
+            # the column indexing from
+            # ORM select()._raw_columns -> Core select()._raw_columns is
+            # translated appropriately
+            stmt = (
+                select(User, Address.email_address, my_thing, User.name)
+                .join(Address)
+                .where(User.name == "ed")
+            )
+
+            if arbitrary_element:
+                target, exp = (my_thing, 2)
+            elif use_core:
+                target, exp = (user_table.c.name, "ed")
+            else:
+                target, exp = (User.name, "ed")
+
+            if use_core:
+                row = connection.execute(stmt).first()
+
+            else:
+                row = Session(connection).execute(stmt).first()
+
+            eq_(row._mapping[target], exp)
+
+        if exercise_caching:
+            for i in range(3):
+                go()
+        else:
+            go()
index d575b85f1b1bf32b22e0935970221a75e75dce35..ca0508db4d946ab556efb5d0f4ebf1e7184dea06 100644 (file)
@@ -3537,6 +3537,8 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
         # test a different unary operator
         # TODO: there is no test in Core that asserts what is happening
         # here as far as the label generation for the ORDER BY
+        # NOTE: this very old test was in fact producing invalid SQL
+        # until #6008 was fixed
         self.assert_compile(
             fixture_session()
             .query(A)
@@ -3547,7 +3549,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
             "AS anon_1_anon_2, b_1.id AS b_1_id, b_1.a_id AS "
             "b_1_a_id, b_1.value AS b_1_value FROM (SELECT a.id "
             "AS a_id, NOT (SELECT sum(b.value) AS sum_1 FROM b "
-            "WHERE b.a_id = a.id) FROM a ORDER BY NOT (SELECT "
+            "WHERE b.a_id = a.id) AS anon_2 FROM a ORDER BY NOT (SELECT "
             "sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) "
             "LIMIT :param_1) AS anon_1 LEFT OUTER JOIN b AS b_1 "
             "ON anon_1.a_id = b_1.a_id ORDER BY anon_1.anon_2",
index f58a1161e47139cc57701fc6bbe321dc8ffa612a..c36946b6ed5ec3e3f48e044a9854d93907896266 100644 (file)
@@ -4141,6 +4141,26 @@ class DistinctTest(QueryTest, AssertsCompiledSQL):
             .all(),
         )
 
+    def test_basic_standalone(self):
+        User = self.classes.User
+
+        # issue 6008.  the UnaryExpression now places itself into the
+        # result map so that it can be matched positionally without the need
+        # for any label.
+        q = fixture_session().query(distinct(User.id)).order_by(User.id)
+        self.assert_compile(
+            q, "SELECT DISTINCT users.id FROM users ORDER BY users.id"
+        )
+        eq_([(7,), (8,), (9,), (10,)], q.all())
+
+    def test_standalone_w_subquery(self):
+        User = self.classes.User
+        q = fixture_session().query(distinct(User.id))
+
+        subq = q.subquery()
+        q = fixture_session().query(subq).order_by(subq.c[0])
+        eq_([(7,), (8,), (9,), (10,)], q.all())
+
     def test_no_automatic_distinct_thing_w_future(self):
         User = self.classes.User
 
index 7d1817b295ebc4433edff1333b38f7b500a61c76..b2d44343849896c426234d79756a9fd6ce6c2c9f 100644 (file)
@@ -1719,6 +1719,23 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT DISTINCT mytable.myid FROM mytable",
         )
 
+        self.assert_compile(
+            select(distinct(table1.c.myid)).set_label_style(
+                LABEL_STYLE_TABLENAME_PLUS_COL
+            ),
+            "SELECT DISTINCT mytable.myid FROM mytable",
+        )
+
+        # the bug fixed here as part of #6008 is the same bug that's
+        # in 1.3 as well, producing
+        # "SELECT anon_2.anon_1 FROM (SELECT distinct mytable.myid
+        # FROM mytable) AS anon_2"
+        self.assert_compile(
+            select(select(distinct(table1.c.myid)).subquery()),
+            "SELECT anon_2.anon_1 FROM (SELECT "
+            "DISTINCT mytable.myid AS anon_1 FROM mytable) AS anon_2",
+        )
+
         self.assert_compile(
             select(table1.c.myid).distinct(),
             "SELECT DISTINCT mytable.myid FROM mytable",