]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Apply consistent labeling for all future style ORM queries
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 11 Feb 2021 19:05:49 +0000 (14:05 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 11 Feb 2021 23:25:52 +0000 (18:25 -0500)
Fixed issue in new 1.4/2.0 style ORM queries where a statement-level label
style would not be preserved in the keys used by result rows; this has been
applied to all combinations of Core/ORM columns / session vs. connection
etc. so that the linkage from statement to result row is the same in all
cases.

also repairs a cache key bug where query.from_statement()
vs. select().from_statement() would not be disambiguated; the
compile options were not included in the cache key for
FromStatement.

Fixes: #5933
Change-Id: I22f6cf0f0b3360e55299cdcb2452cead2b2458ea

doc/build/changelog/unreleased_14/5933.rst [new file with mode: 0644]
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/loading.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/selectable.py
test/orm/test_cache_key.py
test/orm/test_query.py
test/sql/test_compare.py
test/sql/test_selectable.py

diff --git a/doc/build/changelog/unreleased_14/5933.rst b/doc/build/changelog/unreleased_14/5933.rst
new file mode 100644 (file)
index 0000000..2d51041
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 5933
+
+    Fixed issue in new 1.4/2.0 style ORM queries where a statement-level label
+    style would not be preserved in the keys used by result rows; this has been
+    applied to all combinations of Core/ORM columns / session vs. connection
+    etc. so that the linkage from statement to result row is the same in all
+    cases.
+
+
+
index f9a0b72fe2802beb9f7977488b89a21fc5752a41..621ed826c022da57931b527a0d9f40c863674ff2 100644 (file)
@@ -42,6 +42,9 @@ _path_registry = PathRegistry.root
 _EMPTY_DICT = util.immutabledict()
 
 
+LABEL_STYLE_LEGACY_ORM = util.symbol("LABEL_STYLE_LEGACY_ORM")
+
+
 class QueryContext(object):
     __slots__ = (
         "compile_state",
@@ -174,6 +177,21 @@ class ORMCompileState(CompileState):
     def __init__(self, *arg, **kw):
         raise NotImplementedError()
 
+    @classmethod
+    def _column_naming_convention(cls, label_style, legacy):
+
+        if legacy:
+
+            def name(col, col_name=None):
+                if col_name:
+                    return col_name
+                else:
+                    return getattr(col, "key")
+
+            return name
+        else:
+            return SelectState._column_naming_convention(label_style)
+
     @classmethod
     def create_for_statement(cls, statement_container, compiler, **kw):
         """Create a context for a statement given a :class:`.Compiler`.
@@ -345,6 +363,25 @@ class ORMFromStatementCompileState(ORMCompileState):
 
         self.compile_options = statement_container._compile_options
 
+        if (
+            self.use_legacy_query_style
+            and isinstance(statement, expression.SelectBase)
+            and not statement._is_textual
+            and statement._label_style is LABEL_STYLE_NONE
+        ):
+            self.statement = statement.set_label_style(
+                LABEL_STYLE_TABLENAME_PLUS_COL
+            )
+        else:
+            self.statement = statement
+
+        self._label_convention = self._column_naming_convention(
+            statement._label_style
+            if not statement._is_textual
+            else LABEL_STYLE_NONE,
+            self.use_legacy_query_style,
+        )
+
         _QueryEntity.to_compile_state(self, statement_container._raw_columns)
 
         self.current_path = statement_container._compile_options._current_path
@@ -370,16 +407,6 @@ class ORMFromStatementCompileState(ORMCompileState):
         self.create_eager_joins = []
         self._fallback_from_clauses = []
 
-        if (
-            isinstance(statement, expression.SelectBase)
-            and not statement._is_textual
-            and statement._label_style is util.symbol("LABEL_STYLE_NONE")
-        ):
-            self.statement = statement.set_label_style(
-                LABEL_STYLE_TABLENAME_PLUS_COL
-            )
-        else:
-            self.statement = statement
         self.order_by = None
 
         if isinstance(self.statement, expression.TextClause):
@@ -499,20 +526,27 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
 
         self.compile_options = select_statement._compile_options
 
-        _QueryEntity.to_compile_state(self, select_statement._raw_columns)
-
         # determine label style.   we can make different decisions here.
         # at the moment, trying to see if we can always use DISAMBIGUATE_ONLY
         # rather than LABEL_STYLE_NONE, and if we can use disambiguate style
         # for new style ORM selects too.
-        if self.select_statement._label_style is LABEL_STYLE_NONE:
-            if self.use_legacy_query_style and not self.for_statement:
+        if (
+            self.use_legacy_query_style
+            and self.select_statement._label_style is LABEL_STYLE_LEGACY_ORM
+        ):
+            if not self.for_statement:
                 self.label_style = LABEL_STYLE_TABLENAME_PLUS_COL
             else:
                 self.label_style = LABEL_STYLE_DISAMBIGUATE_ONLY
         else:
             self.label_style = self.select_statement._label_style
 
+        self._label_convention = self._column_naming_convention(
+            statement._label_style, self.use_legacy_query_style
+        )
+
+        _QueryEntity.to_compile_state(self, select_statement._raw_columns)
+
         self.current_path = select_statement._compile_options._current_path
 
         self.eager_order_by = ()
@@ -685,7 +719,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
                 )
 
     @classmethod
-    def _create_entities_collection(cls, query):
+    def _create_entities_collection(cls, query, legacy):
         """Creates a partial ORMSelectCompileState that includes
         the full collection of _MapperEntity and other _QueryEntity objects.
 
@@ -710,6 +744,10 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
             )
             self._setup_with_polymorphics()
 
+        self._label_convention = self._column_naming_convention(
+            query._label_style, legacy
+        )
+
         # entities will also set up polymorphic adapters for mappers
         # that have with_polymorphic configured
         _QueryEntity.to_compile_state(self, query._raw_columns)
@@ -1979,10 +2017,12 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
                 self._where_criteria += (crit,)
 
 
-def _column_descriptions(query_or_select_stmt, compile_state=None):
+def _column_descriptions(
+    query_or_select_stmt, compile_state=None, legacy=False
+):
     if compile_state is None:
         compile_state = ORMSelectCompileState._create_entities_collection(
-            query_or_select_stmt
+            query_or_select_stmt, legacy=legacy
         )
     ctx = compile_state
     return [
@@ -2518,7 +2558,8 @@ class _RawColumnEntity(_ColumnEntity):
 
     def __init__(self, compile_state, column, parent_bundle=None):
         self.expr = column
-        self._label_name = getattr(column, "key", None)
+
+        self._label_name = compile_state._label_convention(column)
 
         if parent_bundle:
             parent_bundle._entities.append(self)
@@ -2582,13 +2623,17 @@ class _ORMColumnEntity(_ColumnEntity):
         # a column if it was acquired using the class' adapter directly,
         # such as using AliasedInsp._adapt_element().  this occurs
         # within internal loaders.
-        self._label_name = _label_name = annotations.get("orm_key", None)
-        if _label_name:
-            self.expr = getattr(_entity.entity, _label_name)
+
+        orm_key = annotations.get("orm_key", None)
+        if orm_key:
+            self.expr = getattr(_entity.entity, orm_key)
         else:
-            self._label_name = getattr(column, "key", None)
             self.expr = column
 
+        self._label_name = compile_state._label_convention(
+            column, col_name=orm_key
+        )
+
         _entity._post_inspect
         self.entity_zero = self.entity_zero_or_selectable = ezero = _entity
         self.mapper = mapper = _entity.mapper
index a63a4236d33fc8224df901d7971c98c9d4c9cf6e..24751bf1d41dfee4ba36ea04f988f4ac366992b5 100644 (file)
@@ -252,7 +252,9 @@ def merge_result(query, iterator, load=True):
     else:
         frozen_result = None
 
-    ctx = querycontext.ORMSelectCompileState._create_entities_collection(query)
+    ctx = querycontext.ORMSelectCompileState._create_entities_collection(
+        query, True
+    )
 
     autoflush = session.autoflush
     try:
index d368182540d6b170ba65666ad37972ade1a24227..30cb9e73017a4f3f5da9fb5949c0df307447e0e9 100644 (file)
@@ -30,6 +30,7 @@ from .base import _assertions
 from .context import _column_descriptions
 from .context import _legacy_determine_last_joined_entity
 from .context import _legacy_filter_by_entity_zero
+from .context import LABEL_STYLE_LEGACY_ORM
 from .context import ORMCompileState
 from .context import ORMFromStatementCompileState
 from .context import QueryContext
@@ -59,7 +60,6 @@ from ..sql.selectable import ForUpdateArg
 from ..sql.selectable import HasHints
 from ..sql.selectable import HasPrefixes
 from ..sql.selectable import HasSuffixes
-from ..sql.selectable import LABEL_STYLE_NONE
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
 from ..sql.selectable import SelectStatementGrouping
 from ..sql.visitors import InternalTraversal
@@ -119,7 +119,7 @@ class Query(
     _from_obj = ()
     _setup_joins = ()
     _legacy_setup_joins = ()
-    _label_style = LABEL_STYLE_NONE
+    _label_style = LABEL_STYLE_LEGACY_ORM
 
     _compile_options = ORMCompileState.default_compile_options
 
@@ -2825,7 +2825,7 @@ class Query(
 
         """
 
-        return _column_descriptions(self)
+        return _column_descriptions(self, legacy=True)
 
     def instances(self, result_proxy, context=None):
         """Return an ORM result given a :class:`_engine.CursorResult` and
@@ -3199,6 +3199,10 @@ class FromStatement(SelectStatementGrouping, Executable):
         ("element", InternalTraversal.dp_clauseelement),
     ] + Executable._executable_traverse_internals
 
+    _cache_key_traversal = _traverse_internals + [
+        ("_compile_options", InternalTraversal.dp_has_cache_key)
+    ]
+
     def __init__(self, entities, element):
         self._raw_columns = [
             coercions.expect(
index c80b8f5a2d72e8ea829e421d4980db879b99ec8a..51f75baf37da06550edf3389a05af47bd0893fb6 100644 (file)
@@ -1593,7 +1593,7 @@ class SubqueryLoader(PostLoader):
         # much of this we need.    in particular I can't get a test to
         # fail if the "set_base_alias" is missing and not sure why that is.
         orig_compile_state = compile_state_cls._create_entities_collection(
-            orig_query
+            orig_query, legacy=False
         )
 
         (
index c6eae739d5a98ea6c3af320ed78a2e404a153907..2bd1c3ae31f9e5d43b58d6397cf54dce8fd83447 100644 (file)
@@ -3720,6 +3720,9 @@ class Grouping(GroupedElement, ColumnElement):
     def _key_label(self):
         return self._label
 
+    def _gen_label(self, name):
+        return name
+
     @property
     def _label(self):
         return getattr(self.element, "_label", None) or self.anon_label
index a273e0c903d9887b5569519d4666b542818374de..7e2c5dd3bd081999bac5f42ff0c4107ac0daabae 100644 (file)
@@ -2776,13 +2776,11 @@ class SelectBase(
         representing the columns that
         this SELECT statement or similar construct returns in its result set.
 
-        This collection differs from the
-        :attr:`_expression.FromClause.columns` collection
-        of a :class:`_expression.FromClause`
-        in that the columns within this collection
-        cannot be directly nested inside another SELECT statement; a subquery
-        must be applied first which provides for the necessary parenthesization
-        required by SQL.
+        This collection differs from the :attr:`_expression.FromClause.columns`
+        collection of a :class:`_expression.FromClause` in that the columns
+        within this collection cannot be directly nested inside another SELECT
+        statement; a subquery must be applied first which provides for the
+        necessary parenthesization required by SQL.
 
         .. versionadded:: 1.4
 
@@ -4078,6 +4076,60 @@ class SelectState(util.MemoizedSlots, CompileState):
     def from_statement(cls, statement, from_statement):
         cls._plugin_not_implemented()
 
+    @classmethod
+    def _column_naming_convention(cls, label_style):
+        names = set()
+        pa = []
+
+        if label_style is LABEL_STYLE_NONE:
+
+            def go(c, col_name=None):
+                return col_name or c._proxy_key
+
+        elif label_style is LABEL_STYLE_TABLENAME_PLUS_COL:
+
+            def go(c, col_name=None):
+                # we use key_label since this name is intended for targeting
+                # within the ColumnCollection only, it's not related to SQL
+                # rendering which always uses column name for SQL label names
+
+                if col_name:
+                    name = c._gen_label(col_name)
+                else:
+                    name = c._key_label
+
+                if name in names:
+                    if not pa:
+                        pa.append(prefix_anon_map())
+
+                    name = c._label_anon_label % pa[0]
+                else:
+                    names.add(name)
+
+                return name
+
+        else:
+
+            def go(c, col_name=None):
+                # we use key_label since this name is intended for targeting
+                # within the ColumnCollection only, it's not related to SQL
+                # rendering which always uses column name for SQL label names
+                if col_name:
+                    name = col_name
+                else:
+                    name = c._proxy_key
+                if name in names:
+                    if not pa:
+                        pa.append(prefix_anon_map())
+
+                    name = c.anon_label % pa[0]
+                else:
+                    names.add(name)
+
+                return name
+
+        return go
+
     def _get_froms(self, statement):
         seen = set()
         froms = []
@@ -5519,63 +5571,41 @@ class Select(
         representing the columns that
         this SELECT statement or similar construct returns in its result set.
 
-        This collection differs from the
-        :attr:`_expression.FromClause.columns` collection
-        of a :class:`_expression.FromClause`
-        in that the columns within this collection
-        cannot be directly nested inside another SELECT statement; a subquery
-        must be applied first which provides for the necessary parenthesization
-        required by SQL.
+        This collection differs from the :attr:`_expression.FromClause.columns`
+        collection of a :class:`_expression.FromClause` in that the columns
+        within this collection cannot be directly nested inside another SELECT
+        statement; a subquery must be applied first which provides for the
+        necessary parenthesization required by SQL.
 
         For a :func:`_expression.select` construct, the collection here is
-        exactly what would be rendered inside the  "SELECT" statement, and the
-        :class:`_expression.ColumnElement`
-        objects are  directly present as they were
-        given, e.g.::
+        exactly what would be rendered inside the "SELECT" statement, and the
+        :class:`_expression.ColumnElement` objects are directly present as they
+        were given, e.g.::
 
             col1 = column('q', Integer)
             col2 = column('p', Integer)
             stmt = select(col1, col2)
 
         Above, ``stmt.selected_columns`` would be a collection that contains
-        the ``col1`` and ``col2`` objects directly.    For a statement that is
+        the ``col1`` and ``col2`` objects directly. For a statement that is
         against a :class:`_schema.Table` or other
-        :class:`_expression.FromClause`, the collection
-        will use the :class:`_expression.ColumnElement`
-        objects that are in the
+        :class:`_expression.FromClause`, the collection will use the
+        :class:`_expression.ColumnElement` objects that are in the
         :attr:`_expression.FromClause.c` collection of the from element.
 
         .. versionadded:: 1.4
 
         """
-        names = set()
-        pa = None
-        collection = []
-
-        for c in self._exported_columns_iterator():
-            # we use key_label since this name is intended for targeting
-            # within the ColumnCollection only, it's not related to SQL
-            # rendering which always uses column name for SQL label names
-            if self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL:
-                name = c._key_label
-            else:
-                name = c._proxy_key
-            if name in names:
-                if pa is None:
-                    pa = prefix_anon_map()
-
-                if self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL:
-                    name = c._label_anon_label % pa
-                else:
-                    name = c.anon_label % pa
-            else:
-                names.add(name)
-            collection.append((name, c))
 
-        return ColumnCollection(collection).as_immutable()
+        # compare to SelectState._generate_columns_plus_names, which
+        # generates the actual names used in the SELECT string.  that
+        # method is more complex because it also renders columns that are
+        # fully ambiguous, e.g. same column more than once.
+        conv = SelectState._column_naming_convention(self._label_style)
 
-    # def _exported_columns_iterator(self):
-    #    return _select_iterables(self._raw_columns)
+        return ColumnCollection(
+            [(conv(c), c) for c in self._exported_columns_iterator()]
+        ).as_immutable()
 
     def _exported_columns_iterator(self):
         meth = SelectState.get_plugin_class(self).exported_columns_iterator
@@ -6170,19 +6200,16 @@ class TextualSelect(SelectBase):
         representing the columns that
         this SELECT statement or similar construct returns in its result set.
 
-        This collection differs from the
-        :attr:`_expression.FromClause.columns` collection
-        of a :class:`_expression.FromClause`
-        in that the columns within this collection
-        cannot be directly nested inside another SELECT statement; a subquery
-        must be applied first which provides for the necessary parenthesization
-        required by SQL.
+        This collection differs from the :attr:`_expression.FromClause.columns`
+        collection of a :class:`_expression.FromClause` in that the columns
+        within this collection cannot be directly nested inside another SELECT
+        statement; a subquery must be applied first which provides for the
+        necessary parenthesization required by SQL.
 
-        For a :class:`_expression.TextualSelect` construct,
-        the collection contains the
-        :class:`_expression.ColumnElement`
-        objects that were passed to the constructor,
-        typically via the :meth:`_expression.TextClause.columns` method.
+        For a :class:`_expression.TextualSelect` construct, the collection
+        contains the :class:`_expression.ColumnElement` objects that were
+        passed to the constructor, typically via the
+        :meth:`_expression.TextClause.columns` method.
 
         .. versionadded:: 1.4
 
index 7ef9d1b604f56bea3f11d4c543d9bc27448e4b5e..8b1d185382009d40b6d3770d4336d17ff3e419b3 100644 (file)
@@ -11,6 +11,7 @@ from sqlalchemy.orm import join as orm_join
 from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import Load
 from sqlalchemy.orm import mapper
+from sqlalchemy.orm import Query
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import selectinload
 from sqlalchemy.orm import Session
@@ -29,7 +30,10 @@ from ..sql.test_compare import CacheKeyFixture
 
 
 def stmt_20(*elements):
-    return tuple(elem._statement_20() for elem in elements)
+    return tuple(
+        elem._statement_20() if isinstance(elem, Query) else elem
+        for elem in elements
+    )
 
 
 class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
@@ -294,6 +298,7 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
                 fixture_session()
                 .query(User)
                 .from_statement(text("select * from user")),
+                select(User).from_statement(text("select * from user")),
                 fixture_session()
                 .query(User)
                 .options(selectinload(User.addresses))
index 0fb5e7dd65294283feab55226e6119483effab96..d86d2ff702a9d99a2b43798fbd2577c156f3737c 100644 (file)
@@ -20,6 +20,9 @@ from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import inspect
 from sqlalchemy import Integer
+from sqlalchemy import LABEL_STYLE_DISAMBIGUATE_ONLY
+from sqlalchemy import LABEL_STYLE_NONE
+from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL
 from sqlalchemy import literal
 from sqlalchemy import literal_column
 from sqlalchemy import null
@@ -57,7 +60,6 @@ from sqlalchemy.orm.util import join
 from sqlalchemy.orm.util import with_parent
 from sqlalchemy.sql import expression
 from sqlalchemy.sql import operators
-from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
@@ -489,6 +491,264 @@ class RowTupleTest(QueryTest):
         eq_(row, (User(id=7), [7]))
 
 
+class RowLabelingTest(QueryTest):
+    @testing.fixture
+    def assert_row_keys(self):
+        def go(stmt, expected, coreorm_exec):
+
+            if coreorm_exec == "core":
+                with testing.db.connect() as conn:
+                    row = conn.execute(stmt).first()
+            else:
+                s = fixture_session()
+
+                row = s.execute(stmt).first()
+
+            eq_(row.keys(), expected)
+
+            # we are disambiguating in exported_columns even if
+            # LABEL_STYLE_NONE, this seems weird also
+            if (
+                stmt._label_style is not LABEL_STYLE_NONE
+                and coreorm_exec == "core"
+            ):
+                eq_(stmt.exported_columns.keys(), list(expected))
+
+            if (
+                stmt._label_style is not LABEL_STYLE_NONE
+                and coreorm_exec == "orm"
+            ):
+                try:
+                    column_descriptions = stmt.column_descriptions
+                except (NotImplementedError, AttributeError):
+                    pass
+                else:
+                    eq_(
+                        [
+                            entity["name"]
+                            for entity in column_descriptions
+                            if entity["name"] is not None
+                        ],
+                        list(expected),
+                    )
+
+        return go
+
+    def test_entity(self, assert_row_keys):
+        User = self.classes.User
+        stmt = select(User)
+
+        assert_row_keys(stmt, ("User",), "orm")
+
+    @testing.combinations(
+        (LABEL_STYLE_NONE, ("id", "name")),
+        (LABEL_STYLE_DISAMBIGUATE_ONLY, ("id", "name")),
+        (LABEL_STYLE_TABLENAME_PLUS_COL, ("users_id", "users_name")),
+        argnames="label_style,expected",
+    )
+    @testing.combinations(("core",), ("orm",), argnames="coreorm_exec")
+    @testing.combinations(("core",), ("orm",), argnames="coreorm_cols")
+    def test_explicit_cols(
+        self,
+        assert_row_keys,
+        label_style,
+        expected,
+        coreorm_cols,
+        coreorm_exec,
+    ):
+        User = self.classes.User
+        users = self.tables.users
+
+        if coreorm_cols == "core":
+            stmt = select(users.c.id, users.c.name).set_label_style(
+                label_style
+            )
+        else:
+            stmt = select(User.id, User.name).set_label_style(label_style)
+
+        assert_row_keys(stmt, expected, coreorm_exec)
+
+    def test_explicit_cols_legacy(self):
+        User = self.classes.User
+
+        s = fixture_session()
+        q = s.query(User.id, User.name)
+        row = q.first()
+
+        eq_(row.keys(), ("id", "name"))
+
+        eq_(
+            [entity["name"] for entity in q.column_descriptions],
+            ["id", "name"],
+        )
+
+    @testing.combinations(
+        (LABEL_STYLE_NONE, ("id", "name", "id", "name")),
+        (LABEL_STYLE_DISAMBIGUATE_ONLY, ("id", "name", "id_1", "name_1")),
+        (
+            LABEL_STYLE_TABLENAME_PLUS_COL,
+            ("u1_id", "u1_name", "u2_id", "u2_name"),
+        ),
+        argnames="label_style,expected",
+    )
+    @testing.combinations(("core",), ("orm",), argnames="coreorm_exec")
+    @testing.combinations(("core",), ("orm",), argnames="coreorm_cols")
+    def test_explicit_ambiguous_cols_subq(
+        self,
+        assert_row_keys,
+        label_style,
+        expected,
+        coreorm_cols,
+        coreorm_exec,
+    ):
+        User = self.classes.User
+        users = self.tables.users
+
+        if coreorm_cols == "core":
+            u1 = select(users.c.id, users.c.name).subquery("u1")
+            u2 = select(users.c.id, users.c.name).subquery("u2")
+        elif coreorm_cols == "orm":
+            u1 = select(User.id, User.name).subquery("u1")
+            u2 = select(User.id, User.name).subquery("u2")
+
+        stmt = (
+            select(u1, u2)
+            .join_from(u1, u2, u1.c.id == u2.c.id)
+            .set_label_style(label_style)
+        )
+        assert_row_keys(stmt, expected, coreorm_exec)
+
+    @testing.combinations(
+        (LABEL_STYLE_NONE, ("id", "name", "User", "id", "name", "a1")),
+        (
+            LABEL_STYLE_DISAMBIGUATE_ONLY,
+            ("id", "name", "User", "id_1", "name_1", "a1"),
+        ),
+        (
+            LABEL_STYLE_TABLENAME_PLUS_COL,
+            ("u1_id", "u1_name", "User", "u2_id", "u2_name", "a1"),
+        ),
+        argnames="label_style,expected",
+    )
+    def test_explicit_ambiguous_cols_w_entities(
+        self,
+        assert_row_keys,
+        label_style,
+        expected,
+    ):
+        User = self.classes.User
+        u1 = select(User.id, User.name).subquery("u1")
+        u2 = select(User.id, User.name).subquery("u2")
+
+        a1 = aliased(User, name="a1")
+        stmt = (
+            select(u1, User, u2, a1)
+            .join_from(u1, u2, u1.c.id == u2.c.id)
+            .join(User, User.id == u1.c.id)
+            .join(a1, a1.id == u1.c.id)
+            .set_label_style(label_style)
+        )
+        assert_row_keys(stmt, expected, "orm")
+
+    @testing.combinations(
+        (LABEL_STYLE_NONE, ("id", "name", "id", "name")),
+        (LABEL_STYLE_DISAMBIGUATE_ONLY, ("id", "name", "id_1", "name_1")),
+        (
+            LABEL_STYLE_TABLENAME_PLUS_COL,
+            ("u1_id", "u1_name", "u2_id", "u2_name"),
+        ),
+        argnames="label_style,expected",
+    )
+    def test_explicit_ambiguous_cols_subq_fromstatement(
+        self, assert_row_keys, label_style, expected
+    ):
+        User = self.classes.User
+
+        u1 = select(User.id, User.name).subquery("u1")
+        u2 = select(User.id, User.name).subquery("u2")
+
+        stmt = (
+            select(u1, u2)
+            .join_from(u1, u2, u1.c.id == u2.c.id)
+            .set_label_style(label_style)
+        )
+
+        stmt = select(u1, u2).from_statement(stmt)
+
+        assert_row_keys(stmt, expected, "orm")
+
+    @testing.combinations(
+        (LABEL_STYLE_NONE, ("id", "name", "id", "name")),
+        (LABEL_STYLE_DISAMBIGUATE_ONLY, ("id", "name", "id", "name")),
+        (LABEL_STYLE_TABLENAME_PLUS_COL, ("id", "name", "id", "name")),
+        argnames="label_style,expected",
+    )
+    def test_explicit_ambiguous_cols_subq_fromstatement_legacy(
+        self, label_style, expected
+    ):
+        User = self.classes.User
+
+        u1 = select(User.id, User.name).subquery("u1")
+        u2 = select(User.id, User.name).subquery("u2")
+
+        stmt = (
+            select(u1, u2)
+            .join_from(u1, u2, u1.c.id == u2.c.id)
+            .set_label_style(label_style)
+        )
+
+        s = fixture_session()
+        row = s.query(u1, u2).from_statement(stmt).first()
+        eq_(row.keys(), expected)
+
+    def test_explicit_ambiguous_orm_cols_legacy(self):
+        User = self.classes.User
+
+        u1 = select(User.id, User.name).subquery("u1")
+        u2 = select(User.id, User.name).subquery("u2")
+
+        s = fixture_session()
+        row = s.query(u1, u2).join(u2, u1.c.id == u2.c.id).first()
+        eq_(row.keys(), ["id", "name", "id", "name"])
+
+    def test_entity_anon_aliased(self, assert_row_keys):
+        User = self.classes.User
+
+        u1 = aliased(User)
+        stmt = select(u1)
+
+        assert_row_keys(stmt, (), "orm")
+
+    def test_entity_name_aliased(self, assert_row_keys):
+        User = self.classes.User
+
+        u1 = aliased(User, name="u1")
+        stmt = select(u1)
+
+        assert_row_keys(stmt, ("u1",), "orm")
+
+    @testing.combinations(
+        (LABEL_STYLE_NONE, ("u1", "u2")),
+        (LABEL_STYLE_DISAMBIGUATE_ONLY, ("u1", "u2")),
+        (LABEL_STYLE_TABLENAME_PLUS_COL, ("u1", "u2")),
+        argnames="label_style,expected",
+    )
+    def test_multi_entity_name_aliased(
+        self, assert_row_keys, label_style, expected
+    ):
+        User = self.classes.User
+
+        u1 = aliased(User, name="u1")
+        u2 = aliased(User, name="u2")
+        stmt = (
+            select(u1, u2)
+            .join_from(u1, u2, u1.id == u2.id)
+            .set_label_style(label_style)
+        )
+
+        assert_row_keys(stmt, expected, "orm")
+
+
 class GetTest(QueryTest):
     def test_loader_options(self):
         User = self.classes.User
index 30235995dbf7056ee8ce7e17f97d0f08ef818379..9a4b8b1996d9be69dd67a5d4d0c63078ee28e33a 100644 (file)
@@ -61,6 +61,7 @@ from sqlalchemy.sql.lambdas import LambdaOptions
 from sqlalchemy.sql.selectable import _OffsetLimitParam
 from sqlalchemy.sql.selectable import AliasedReturnsRows
 from sqlalchemy.sql.selectable import FromGrouping
+from sqlalchemy.sql.selectable import LABEL_STYLE_NONE
 from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
 from sqlalchemy.sql.selectable import Select
 from sqlalchemy.sql.selectable import Selectable
@@ -400,6 +401,7 @@ class CoreFixtures(object):
             select(table_a.c.b, table_a.c.a).set_label_style(
                 LABEL_STYLE_TABLENAME_PLUS_COL
             ),
+            select(table_a.c.b, table_a.c.a).set_label_style(LABEL_STYLE_NONE),
             select(table_a.c.a).where(table_a.c.b == 5),
             select(table_a.c.a)
             .where(table_a.c.b == 5)
index ce33ed10e51dd560aa37cacf3b81d9d090713815..e15c740752da0176656ce3c62ce280fe7b021718 100644 (file)
@@ -32,6 +32,7 @@ from sqlalchemy.sql import annotation
 from sqlalchemy.sql import base
 from sqlalchemy.sql import column
 from sqlalchemy.sql import elements
+from sqlalchemy.sql import LABEL_STYLE_DISAMBIGUATE_ONLY
 from sqlalchemy.sql import LABEL_STYLE_TABLENAME_PLUS_COL
 from sqlalchemy.sql import operators
 from sqlalchemy.sql import table
@@ -2916,6 +2917,7 @@ class ReprTest(fixtures.TestBase):
 class WithLabelsTest(fixtures.TestBase):
     def _assert_result_keys(self, s, keys):
         compiled = s.compile()
+
         eq_(set(compiled._create_result_map()), set(keys))
 
     def _assert_subq_result_keys(self, s, keys):
@@ -2934,10 +2936,13 @@ class WithLabelsTest(fixtures.TestBase):
 
         self._assert_subq_result_keys(sel, ["x", "x_1"])
 
+        eq_(sel.selected_columns.keys(), ["x", "x"])
+
     def test_names_overlap_label(self):
         sel = self._names_overlap().set_label_style(
             LABEL_STYLE_TABLENAME_PLUS_COL
         )
+        eq_(sel.selected_columns.keys(), ["t1_x", "t2_x"])
         eq_(list(sel.selected_columns.keys()), ["t1_x", "t2_x"])
         eq_(list(sel.subquery().c.keys()), ["t1_x", "t2_x"])
         self._assert_result_keys(sel, ["t1_x", "t2_x"])
@@ -2951,6 +2956,7 @@ class WithLabelsTest(fixtures.TestBase):
     def test_names_overlap_keys_dont_nolabel(self):
         sel = self._names_overlap_keys_dont()
 
+        eq_(sel.selected_columns.keys(), ["a", "b"])
         eq_(list(sel.selected_columns.keys()), ["a", "b"])
         eq_(list(sel.subquery().c.keys()), ["a", "b"])
         self._assert_result_keys(sel, ["x"])
@@ -2959,10 +2965,41 @@ class WithLabelsTest(fixtures.TestBase):
         sel = self._names_overlap_keys_dont().set_label_style(
             LABEL_STYLE_TABLENAME_PLUS_COL
         )
+        eq_(sel.selected_columns.keys(), ["t1_a", "t2_b"])
         eq_(list(sel.selected_columns.keys()), ["t1_a", "t2_b"])
         eq_(list(sel.subquery().c.keys()), ["t1_a", "t2_b"])
         self._assert_result_keys(sel, ["t1_x", "t2_x"])
 
+    def _columns_repeated(self):
+        m = MetaData()
+        t1 = Table("t1", m, Column("x", Integer), Column("y", Integer))
+        return select(t1.c.x, t1.c.y, t1.c.x).set_label_style(LABEL_STYLE_NONE)
+
+    def test_element_repeated_nolabels(self):
+        sel = self._columns_repeated().set_label_style(LABEL_STYLE_NONE)
+        eq_(sel.selected_columns.keys(), ["x", "y", "x"])
+        eq_(list(sel.selected_columns.keys()), ["x", "y", "x"])
+        eq_(list(sel.subquery().c.keys()), ["x", "y", "x_1"])
+        self._assert_result_keys(sel, ["x", "y"])
+
+    def test_element_repeated_disambiguate(self):
+        sel = self._columns_repeated().set_label_style(
+            LABEL_STYLE_DISAMBIGUATE_ONLY
+        )
+        eq_(sel.selected_columns.keys(), ["x", "y", "x_1"])
+        eq_(list(sel.selected_columns.keys()), ["x", "y", "x_1"])
+        eq_(list(sel.subquery().c.keys()), ["x", "y", "x_1"])
+        self._assert_result_keys(sel, ["x", "y", "x__1"])
+
+    def test_element_repeated_labels(self):
+        sel = self._columns_repeated().set_label_style(
+            LABEL_STYLE_TABLENAME_PLUS_COL
+        )
+        eq_(sel.selected_columns.keys(), ["t1_x", "t1_y", "t1_x_1"])
+        eq_(list(sel.selected_columns.keys()), ["t1_x", "t1_y", "t1_x_1"])
+        eq_(list(sel.subquery().c.keys()), ["t1_x", "t1_y", "t1_x_1"])
+        self._assert_result_keys(sel, ["t1_x__1", "t1_x", "t1_y"])
+
     def _labels_overlap(self):
         m = MetaData()
         t1 = Table("t", m, Column("x_id", Integer))
@@ -2971,6 +3008,7 @@ class WithLabelsTest(fixtures.TestBase):
 
     def test_labels_overlap_nolabel(self):
         sel = self._labels_overlap()
+        eq_(sel.selected_columns.keys(), ["x_id", "id"])
         eq_(list(sel.selected_columns.keys()), ["x_id", "id"])
         eq_(list(sel.subquery().c.keys()), ["x_id", "id"])
         self._assert_result_keys(sel, ["x_id", "id"])
@@ -3077,6 +3115,7 @@ class WithLabelsTest(fixtures.TestBase):
 
     def test_keys_overlap_names_dont_nolabel(self):
         sel = self._keys_overlap_names_dont()
+        eq_(sel.selected_columns.keys(), ["x", "b_1"])
         self._assert_result_keys(sel, ["a", "b"])
 
     def test_keys_overlap_names_dont_label(self):