From 553ac45aae5712e64a5380ba1fa1c6028acf5f39 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 11 Feb 2021 14:05:49 -0500 Subject: [PATCH] Apply consistent labeling for all future style ORM queries Fixed issue in new 1.4/2.0 style ORM queries where a statement-level label style would not be preserved in the keys used by result rows; this has been applied to all combinations of Core/ORM columns / session vs. connection etc. so that the linkage from statement to result row is the same in all cases. also repairs a cache key bug where query.from_statement() vs. select().from_statement() would not be disambiguated; the compile options were not included in the cache key for FromStatement. Fixes: #5933 Change-Id: I22f6cf0f0b3360e55299cdcb2452cead2b2458ea --- doc/build/changelog/unreleased_14/5933.rst | 12 + lib/sqlalchemy/orm/context.py | 89 +++++-- lib/sqlalchemy/orm/loading.py | 4 +- lib/sqlalchemy/orm/query.py | 10 +- lib/sqlalchemy/orm/strategies.py | 2 +- lib/sqlalchemy/sql/elements.py | 3 + lib/sqlalchemy/sql/selectable.py | 147 +++++++----- test/orm/test_cache_key.py | 7 +- test/orm/test_query.py | 262 ++++++++++++++++++++- test/sql/test_compare.py | 2 + test/sql/test_selectable.py | 39 +++ 11 files changed, 488 insertions(+), 89 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/5933.rst diff --git a/doc/build/changelog/unreleased_14/5933.rst b/doc/build/changelog/unreleased_14/5933.rst new file mode 100644 index 0000000000..2d510413c8 --- /dev/null +++ b/doc/build/changelog/unreleased_14/5933.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: bug, orm + :tickets: 5933 + + Fixed issue in new 1.4/2.0 style ORM queries where a statement-level label + style would not be preserved in the keys used by result rows; this has been + applied to all combinations of Core/ORM columns / session vs. connection + etc. so that the linkage from statement to result row is the same in all + cases. + + + diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index f9a0b72fe2..621ed826c0 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -42,6 +42,9 @@ _path_registry = PathRegistry.root _EMPTY_DICT = util.immutabledict() +LABEL_STYLE_LEGACY_ORM = util.symbol("LABEL_STYLE_LEGACY_ORM") + + class QueryContext(object): __slots__ = ( "compile_state", @@ -174,6 +177,21 @@ class ORMCompileState(CompileState): def __init__(self, *arg, **kw): raise NotImplementedError() + @classmethod + def _column_naming_convention(cls, label_style, legacy): + + if legacy: + + def name(col, col_name=None): + if col_name: + return col_name + else: + return getattr(col, "key") + + return name + else: + return SelectState._column_naming_convention(label_style) + @classmethod def create_for_statement(cls, statement_container, compiler, **kw): """Create a context for a statement given a :class:`.Compiler`. @@ -345,6 +363,25 @@ class ORMFromStatementCompileState(ORMCompileState): self.compile_options = statement_container._compile_options + if ( + self.use_legacy_query_style + and isinstance(statement, expression.SelectBase) + and not statement._is_textual + and statement._label_style is LABEL_STYLE_NONE + ): + self.statement = statement.set_label_style( + LABEL_STYLE_TABLENAME_PLUS_COL + ) + else: + self.statement = statement + + self._label_convention = self._column_naming_convention( + statement._label_style + if not statement._is_textual + else LABEL_STYLE_NONE, + self.use_legacy_query_style, + ) + _QueryEntity.to_compile_state(self, statement_container._raw_columns) self.current_path = statement_container._compile_options._current_path @@ -370,16 +407,6 @@ class ORMFromStatementCompileState(ORMCompileState): self.create_eager_joins = [] self._fallback_from_clauses = [] - if ( - isinstance(statement, expression.SelectBase) - and not statement._is_textual - and statement._label_style is util.symbol("LABEL_STYLE_NONE") - ): - self.statement = statement.set_label_style( - LABEL_STYLE_TABLENAME_PLUS_COL - ) - else: - self.statement = statement self.order_by = None if isinstance(self.statement, expression.TextClause): @@ -499,20 +526,27 @@ class ORMSelectCompileState(ORMCompileState, SelectState): self.compile_options = select_statement._compile_options - _QueryEntity.to_compile_state(self, select_statement._raw_columns) - # determine label style. we can make different decisions here. # at the moment, trying to see if we can always use DISAMBIGUATE_ONLY # rather than LABEL_STYLE_NONE, and if we can use disambiguate style # for new style ORM selects too. - if self.select_statement._label_style is LABEL_STYLE_NONE: - if self.use_legacy_query_style and not self.for_statement: + if ( + self.use_legacy_query_style + and self.select_statement._label_style is LABEL_STYLE_LEGACY_ORM + ): + if not self.for_statement: self.label_style = LABEL_STYLE_TABLENAME_PLUS_COL else: self.label_style = LABEL_STYLE_DISAMBIGUATE_ONLY else: self.label_style = self.select_statement._label_style + self._label_convention = self._column_naming_convention( + statement._label_style, self.use_legacy_query_style + ) + + _QueryEntity.to_compile_state(self, select_statement._raw_columns) + self.current_path = select_statement._compile_options._current_path self.eager_order_by = () @@ -685,7 +719,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): ) @classmethod - def _create_entities_collection(cls, query): + def _create_entities_collection(cls, query, legacy): """Creates a partial ORMSelectCompileState that includes the full collection of _MapperEntity and other _QueryEntity objects. @@ -710,6 +744,10 @@ class ORMSelectCompileState(ORMCompileState, SelectState): ) self._setup_with_polymorphics() + self._label_convention = self._column_naming_convention( + query._label_style, legacy + ) + # entities will also set up polymorphic adapters for mappers # that have with_polymorphic configured _QueryEntity.to_compile_state(self, query._raw_columns) @@ -1979,10 +2017,12 @@ class ORMSelectCompileState(ORMCompileState, SelectState): self._where_criteria += (crit,) -def _column_descriptions(query_or_select_stmt, compile_state=None): +def _column_descriptions( + query_or_select_stmt, compile_state=None, legacy=False +): if compile_state is None: compile_state = ORMSelectCompileState._create_entities_collection( - query_or_select_stmt + query_or_select_stmt, legacy=legacy ) ctx = compile_state return [ @@ -2518,7 +2558,8 @@ class _RawColumnEntity(_ColumnEntity): def __init__(self, compile_state, column, parent_bundle=None): self.expr = column - self._label_name = getattr(column, "key", None) + + self._label_name = compile_state._label_convention(column) if parent_bundle: parent_bundle._entities.append(self) @@ -2582,13 +2623,17 @@ class _ORMColumnEntity(_ColumnEntity): # a column if it was acquired using the class' adapter directly, # such as using AliasedInsp._adapt_element(). this occurs # within internal loaders. - self._label_name = _label_name = annotations.get("orm_key", None) - if _label_name: - self.expr = getattr(_entity.entity, _label_name) + + orm_key = annotations.get("orm_key", None) + if orm_key: + self.expr = getattr(_entity.entity, orm_key) else: - self._label_name = getattr(column, "key", None) self.expr = column + self._label_name = compile_state._label_convention( + column, col_name=orm_key + ) + _entity._post_inspect self.entity_zero = self.entity_zero_or_selectable = ezero = _entity self.mapper = mapper = _entity.mapper diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index a63a4236d3..24751bf1d4 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -252,7 +252,9 @@ def merge_result(query, iterator, load=True): else: frozen_result = None - ctx = querycontext.ORMSelectCompileState._create_entities_collection(query) + ctx = querycontext.ORMSelectCompileState._create_entities_collection( + query, True + ) autoflush = session.autoflush try: diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index d368182540..30cb9e7301 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -30,6 +30,7 @@ from .base import _assertions from .context import _column_descriptions from .context import _legacy_determine_last_joined_entity from .context import _legacy_filter_by_entity_zero +from .context import LABEL_STYLE_LEGACY_ORM from .context import ORMCompileState from .context import ORMFromStatementCompileState from .context import QueryContext @@ -59,7 +60,6 @@ from ..sql.selectable import ForUpdateArg from ..sql.selectable import HasHints from ..sql.selectable import HasPrefixes from ..sql.selectable import HasSuffixes -from ..sql.selectable import LABEL_STYLE_NONE from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..sql.selectable import SelectStatementGrouping from ..sql.visitors import InternalTraversal @@ -119,7 +119,7 @@ class Query( _from_obj = () _setup_joins = () _legacy_setup_joins = () - _label_style = LABEL_STYLE_NONE + _label_style = LABEL_STYLE_LEGACY_ORM _compile_options = ORMCompileState.default_compile_options @@ -2825,7 +2825,7 @@ class Query( """ - return _column_descriptions(self) + return _column_descriptions(self, legacy=True) def instances(self, result_proxy, context=None): """Return an ORM result given a :class:`_engine.CursorResult` and @@ -3199,6 +3199,10 @@ class FromStatement(SelectStatementGrouping, Executable): ("element", InternalTraversal.dp_clauseelement), ] + Executable._executable_traverse_internals + _cache_key_traversal = _traverse_internals + [ + ("_compile_options", InternalTraversal.dp_has_cache_key) + ] + def __init__(self, entities, element): self._raw_columns = [ coercions.expect( diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index c80b8f5a2d..51f75baf37 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -1593,7 +1593,7 @@ class SubqueryLoader(PostLoader): # much of this we need. in particular I can't get a test to # fail if the "set_base_alias" is missing and not sure why that is. orig_compile_state = compile_state_cls._create_entities_collection( - orig_query + orig_query, legacy=False ) ( diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index c6eae739d5..2bd1c3ae31 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -3720,6 +3720,9 @@ class Grouping(GroupedElement, ColumnElement): def _key_label(self): return self._label + def _gen_label(self, name): + return name + @property def _label(self): return getattr(self.element, "_label", None) or self.anon_label diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index a273e0c903..7e2c5dd3bd 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -2776,13 +2776,11 @@ class SelectBase( representing the columns that this SELECT statement or similar construct returns in its result set. - This collection differs from the - :attr:`_expression.FromClause.columns` collection - of a :class:`_expression.FromClause` - in that the columns within this collection - cannot be directly nested inside another SELECT statement; a subquery - must be applied first which provides for the necessary parenthesization - required by SQL. + This collection differs from the :attr:`_expression.FromClause.columns` + collection of a :class:`_expression.FromClause` in that the columns + within this collection cannot be directly nested inside another SELECT + statement; a subquery must be applied first which provides for the + necessary parenthesization required by SQL. .. versionadded:: 1.4 @@ -4078,6 +4076,60 @@ class SelectState(util.MemoizedSlots, CompileState): def from_statement(cls, statement, from_statement): cls._plugin_not_implemented() + @classmethod + def _column_naming_convention(cls, label_style): + names = set() + pa = [] + + if label_style is LABEL_STYLE_NONE: + + def go(c, col_name=None): + return col_name or c._proxy_key + + elif label_style is LABEL_STYLE_TABLENAME_PLUS_COL: + + def go(c, col_name=None): + # we use key_label since this name is intended for targeting + # within the ColumnCollection only, it's not related to SQL + # rendering which always uses column name for SQL label names + + if col_name: + name = c._gen_label(col_name) + else: + name = c._key_label + + if name in names: + if not pa: + pa.append(prefix_anon_map()) + + name = c._label_anon_label % pa[0] + else: + names.add(name) + + return name + + else: + + def go(c, col_name=None): + # we use key_label since this name is intended for targeting + # within the ColumnCollection only, it's not related to SQL + # rendering which always uses column name for SQL label names + if col_name: + name = col_name + else: + name = c._proxy_key + if name in names: + if not pa: + pa.append(prefix_anon_map()) + + name = c.anon_label % pa[0] + else: + names.add(name) + + return name + + return go + def _get_froms(self, statement): seen = set() froms = [] @@ -5519,63 +5571,41 @@ class Select( representing the columns that this SELECT statement or similar construct returns in its result set. - This collection differs from the - :attr:`_expression.FromClause.columns` collection - of a :class:`_expression.FromClause` - in that the columns within this collection - cannot be directly nested inside another SELECT statement; a subquery - must be applied first which provides for the necessary parenthesization - required by SQL. + This collection differs from the :attr:`_expression.FromClause.columns` + collection of a :class:`_expression.FromClause` in that the columns + within this collection cannot be directly nested inside another SELECT + statement; a subquery must be applied first which provides for the + necessary parenthesization required by SQL. For a :func:`_expression.select` construct, the collection here is - exactly what would be rendered inside the "SELECT" statement, and the - :class:`_expression.ColumnElement` - objects are directly present as they were - given, e.g.:: + exactly what would be rendered inside the "SELECT" statement, and the + :class:`_expression.ColumnElement` objects are directly present as they + were given, e.g.:: col1 = column('q', Integer) col2 = column('p', Integer) stmt = select(col1, col2) Above, ``stmt.selected_columns`` would be a collection that contains - the ``col1`` and ``col2`` objects directly. For a statement that is + the ``col1`` and ``col2`` objects directly. For a statement that is against a :class:`_schema.Table` or other - :class:`_expression.FromClause`, the collection - will use the :class:`_expression.ColumnElement` - objects that are in the + :class:`_expression.FromClause`, the collection will use the + :class:`_expression.ColumnElement` objects that are in the :attr:`_expression.FromClause.c` collection of the from element. .. versionadded:: 1.4 """ - names = set() - pa = None - collection = [] - - for c in self._exported_columns_iterator(): - # we use key_label since this name is intended for targeting - # within the ColumnCollection only, it's not related to SQL - # rendering which always uses column name for SQL label names - if self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL: - name = c._key_label - else: - name = c._proxy_key - if name in names: - if pa is None: - pa = prefix_anon_map() - - if self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL: - name = c._label_anon_label % pa - else: - name = c.anon_label % pa - else: - names.add(name) - collection.append((name, c)) - return ColumnCollection(collection).as_immutable() + # compare to SelectState._generate_columns_plus_names, which + # generates the actual names used in the SELECT string. that + # method is more complex because it also renders columns that are + # fully ambiguous, e.g. same column more than once. + conv = SelectState._column_naming_convention(self._label_style) - # def _exported_columns_iterator(self): - # return _select_iterables(self._raw_columns) + return ColumnCollection( + [(conv(c), c) for c in self._exported_columns_iterator()] + ).as_immutable() def _exported_columns_iterator(self): meth = SelectState.get_plugin_class(self).exported_columns_iterator @@ -6170,19 +6200,16 @@ class TextualSelect(SelectBase): representing the columns that this SELECT statement or similar construct returns in its result set. - This collection differs from the - :attr:`_expression.FromClause.columns` collection - of a :class:`_expression.FromClause` - in that the columns within this collection - cannot be directly nested inside another SELECT statement; a subquery - must be applied first which provides for the necessary parenthesization - required by SQL. + This collection differs from the :attr:`_expression.FromClause.columns` + collection of a :class:`_expression.FromClause` in that the columns + within this collection cannot be directly nested inside another SELECT + statement; a subquery must be applied first which provides for the + necessary parenthesization required by SQL. - For a :class:`_expression.TextualSelect` construct, - the collection contains the - :class:`_expression.ColumnElement` - objects that were passed to the constructor, - typically via the :meth:`_expression.TextClause.columns` method. + For a :class:`_expression.TextualSelect` construct, the collection + contains the :class:`_expression.ColumnElement` objects that were + passed to the constructor, typically via the + :meth:`_expression.TextClause.columns` method. .. versionadded:: 1.4 diff --git a/test/orm/test_cache_key.py b/test/orm/test_cache_key.py index 7ef9d1b604..8b1d185382 100644 --- a/test/orm/test_cache_key.py +++ b/test/orm/test_cache_key.py @@ -11,6 +11,7 @@ from sqlalchemy.orm import join as orm_join from sqlalchemy.orm import joinedload from sqlalchemy.orm import Load from sqlalchemy.orm import mapper +from sqlalchemy.orm import Query from sqlalchemy.orm import relationship from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session @@ -29,7 +30,10 @@ from ..sql.test_compare import CacheKeyFixture def stmt_20(*elements): - return tuple(elem._statement_20() for elem in elements) + return tuple( + elem._statement_20() if isinstance(elem, Query) else elem + for elem in elements + ) class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest): @@ -294,6 +298,7 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest): fixture_session() .query(User) .from_statement(text("select * from user")), + select(User).from_statement(text("select * from user")), fixture_session() .query(User) .options(selectinload(User.addresses)) diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 0fb5e7dd65..d86d2ff702 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -20,6 +20,9 @@ from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import inspect from sqlalchemy import Integer +from sqlalchemy import LABEL_STYLE_DISAMBIGUATE_ONLY +from sqlalchemy import LABEL_STYLE_NONE +from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL from sqlalchemy import literal from sqlalchemy import literal_column from sqlalchemy import null @@ -57,7 +60,6 @@ from sqlalchemy.orm.util import join from sqlalchemy.orm.util import with_parent from sqlalchemy.sql import expression from sqlalchemy.sql import operators -from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures @@ -489,6 +491,264 @@ class RowTupleTest(QueryTest): eq_(row, (User(id=7), [7])) +class RowLabelingTest(QueryTest): + @testing.fixture + def assert_row_keys(self): + def go(stmt, expected, coreorm_exec): + + if coreorm_exec == "core": + with testing.db.connect() as conn: + row = conn.execute(stmt).first() + else: + s = fixture_session() + + row = s.execute(stmt).first() + + eq_(row.keys(), expected) + + # we are disambiguating in exported_columns even if + # LABEL_STYLE_NONE, this seems weird also + if ( + stmt._label_style is not LABEL_STYLE_NONE + and coreorm_exec == "core" + ): + eq_(stmt.exported_columns.keys(), list(expected)) + + if ( + stmt._label_style is not LABEL_STYLE_NONE + and coreorm_exec == "orm" + ): + try: + column_descriptions = stmt.column_descriptions + except (NotImplementedError, AttributeError): + pass + else: + eq_( + [ + entity["name"] + for entity in column_descriptions + if entity["name"] is not None + ], + list(expected), + ) + + return go + + def test_entity(self, assert_row_keys): + User = self.classes.User + stmt = select(User) + + assert_row_keys(stmt, ("User",), "orm") + + @testing.combinations( + (LABEL_STYLE_NONE, ("id", "name")), + (LABEL_STYLE_DISAMBIGUATE_ONLY, ("id", "name")), + (LABEL_STYLE_TABLENAME_PLUS_COL, ("users_id", "users_name")), + argnames="label_style,expected", + ) + @testing.combinations(("core",), ("orm",), argnames="coreorm_exec") + @testing.combinations(("core",), ("orm",), argnames="coreorm_cols") + def test_explicit_cols( + self, + assert_row_keys, + label_style, + expected, + coreorm_cols, + coreorm_exec, + ): + User = self.classes.User + users = self.tables.users + + if coreorm_cols == "core": + stmt = select(users.c.id, users.c.name).set_label_style( + label_style + ) + else: + stmt = select(User.id, User.name).set_label_style(label_style) + + assert_row_keys(stmt, expected, coreorm_exec) + + def test_explicit_cols_legacy(self): + User = self.classes.User + + s = fixture_session() + q = s.query(User.id, User.name) + row = q.first() + + eq_(row.keys(), ("id", "name")) + + eq_( + [entity["name"] for entity in q.column_descriptions], + ["id", "name"], + ) + + @testing.combinations( + (LABEL_STYLE_NONE, ("id", "name", "id", "name")), + (LABEL_STYLE_DISAMBIGUATE_ONLY, ("id", "name", "id_1", "name_1")), + ( + LABEL_STYLE_TABLENAME_PLUS_COL, + ("u1_id", "u1_name", "u2_id", "u2_name"), + ), + argnames="label_style,expected", + ) + @testing.combinations(("core",), ("orm",), argnames="coreorm_exec") + @testing.combinations(("core",), ("orm",), argnames="coreorm_cols") + def test_explicit_ambiguous_cols_subq( + self, + assert_row_keys, + label_style, + expected, + coreorm_cols, + coreorm_exec, + ): + User = self.classes.User + users = self.tables.users + + if coreorm_cols == "core": + u1 = select(users.c.id, users.c.name).subquery("u1") + u2 = select(users.c.id, users.c.name).subquery("u2") + elif coreorm_cols == "orm": + u1 = select(User.id, User.name).subquery("u1") + u2 = select(User.id, User.name).subquery("u2") + + stmt = ( + select(u1, u2) + .join_from(u1, u2, u1.c.id == u2.c.id) + .set_label_style(label_style) + ) + assert_row_keys(stmt, expected, coreorm_exec) + + @testing.combinations( + (LABEL_STYLE_NONE, ("id", "name", "User", "id", "name", "a1")), + ( + LABEL_STYLE_DISAMBIGUATE_ONLY, + ("id", "name", "User", "id_1", "name_1", "a1"), + ), + ( + LABEL_STYLE_TABLENAME_PLUS_COL, + ("u1_id", "u1_name", "User", "u2_id", "u2_name", "a1"), + ), + argnames="label_style,expected", + ) + def test_explicit_ambiguous_cols_w_entities( + self, + assert_row_keys, + label_style, + expected, + ): + User = self.classes.User + u1 = select(User.id, User.name).subquery("u1") + u2 = select(User.id, User.name).subquery("u2") + + a1 = aliased(User, name="a1") + stmt = ( + select(u1, User, u2, a1) + .join_from(u1, u2, u1.c.id == u2.c.id) + .join(User, User.id == u1.c.id) + .join(a1, a1.id == u1.c.id) + .set_label_style(label_style) + ) + assert_row_keys(stmt, expected, "orm") + + @testing.combinations( + (LABEL_STYLE_NONE, ("id", "name", "id", "name")), + (LABEL_STYLE_DISAMBIGUATE_ONLY, ("id", "name", "id_1", "name_1")), + ( + LABEL_STYLE_TABLENAME_PLUS_COL, + ("u1_id", "u1_name", "u2_id", "u2_name"), + ), + argnames="label_style,expected", + ) + def test_explicit_ambiguous_cols_subq_fromstatement( + self, assert_row_keys, label_style, expected + ): + User = self.classes.User + + u1 = select(User.id, User.name).subquery("u1") + u2 = select(User.id, User.name).subquery("u2") + + stmt = ( + select(u1, u2) + .join_from(u1, u2, u1.c.id == u2.c.id) + .set_label_style(label_style) + ) + + stmt = select(u1, u2).from_statement(stmt) + + assert_row_keys(stmt, expected, "orm") + + @testing.combinations( + (LABEL_STYLE_NONE, ("id", "name", "id", "name")), + (LABEL_STYLE_DISAMBIGUATE_ONLY, ("id", "name", "id", "name")), + (LABEL_STYLE_TABLENAME_PLUS_COL, ("id", "name", "id", "name")), + argnames="label_style,expected", + ) + def test_explicit_ambiguous_cols_subq_fromstatement_legacy( + self, label_style, expected + ): + User = self.classes.User + + u1 = select(User.id, User.name).subquery("u1") + u2 = select(User.id, User.name).subquery("u2") + + stmt = ( + select(u1, u2) + .join_from(u1, u2, u1.c.id == u2.c.id) + .set_label_style(label_style) + ) + + s = fixture_session() + row = s.query(u1, u2).from_statement(stmt).first() + eq_(row.keys(), expected) + + def test_explicit_ambiguous_orm_cols_legacy(self): + User = self.classes.User + + u1 = select(User.id, User.name).subquery("u1") + u2 = select(User.id, User.name).subquery("u2") + + s = fixture_session() + row = s.query(u1, u2).join(u2, u1.c.id == u2.c.id).first() + eq_(row.keys(), ["id", "name", "id", "name"]) + + def test_entity_anon_aliased(self, assert_row_keys): + User = self.classes.User + + u1 = aliased(User) + stmt = select(u1) + + assert_row_keys(stmt, (), "orm") + + def test_entity_name_aliased(self, assert_row_keys): + User = self.classes.User + + u1 = aliased(User, name="u1") + stmt = select(u1) + + assert_row_keys(stmt, ("u1",), "orm") + + @testing.combinations( + (LABEL_STYLE_NONE, ("u1", "u2")), + (LABEL_STYLE_DISAMBIGUATE_ONLY, ("u1", "u2")), + (LABEL_STYLE_TABLENAME_PLUS_COL, ("u1", "u2")), + argnames="label_style,expected", + ) + def test_multi_entity_name_aliased( + self, assert_row_keys, label_style, expected + ): + User = self.classes.User + + u1 = aliased(User, name="u1") + u2 = aliased(User, name="u2") + stmt = ( + select(u1, u2) + .join_from(u1, u2, u1.id == u2.id) + .set_label_style(label_style) + ) + + assert_row_keys(stmt, expected, "orm") + + class GetTest(QueryTest): def test_loader_options(self): User = self.classes.User diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 30235995db..9a4b8b1996 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -61,6 +61,7 @@ from sqlalchemy.sql.lambdas import LambdaOptions from sqlalchemy.sql.selectable import _OffsetLimitParam from sqlalchemy.sql.selectable import AliasedReturnsRows from sqlalchemy.sql.selectable import FromGrouping +from sqlalchemy.sql.selectable import LABEL_STYLE_NONE from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from sqlalchemy.sql.selectable import Select from sqlalchemy.sql.selectable import Selectable @@ -400,6 +401,7 @@ class CoreFixtures(object): select(table_a.c.b, table_a.c.a).set_label_style( LABEL_STYLE_TABLENAME_PLUS_COL ), + select(table_a.c.b, table_a.c.a).set_label_style(LABEL_STYLE_NONE), select(table_a.c.a).where(table_a.c.b == 5), select(table_a.c.a) .where(table_a.c.b == 5) diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index ce33ed10e5..e15c740752 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -32,6 +32,7 @@ from sqlalchemy.sql import annotation from sqlalchemy.sql import base from sqlalchemy.sql import column from sqlalchemy.sql import elements +from sqlalchemy.sql import LABEL_STYLE_DISAMBIGUATE_ONLY from sqlalchemy.sql import LABEL_STYLE_TABLENAME_PLUS_COL from sqlalchemy.sql import operators from sqlalchemy.sql import table @@ -2916,6 +2917,7 @@ class ReprTest(fixtures.TestBase): class WithLabelsTest(fixtures.TestBase): def _assert_result_keys(self, s, keys): compiled = s.compile() + eq_(set(compiled._create_result_map()), set(keys)) def _assert_subq_result_keys(self, s, keys): @@ -2934,10 +2936,13 @@ class WithLabelsTest(fixtures.TestBase): self._assert_subq_result_keys(sel, ["x", "x_1"]) + eq_(sel.selected_columns.keys(), ["x", "x"]) + def test_names_overlap_label(self): sel = self._names_overlap().set_label_style( LABEL_STYLE_TABLENAME_PLUS_COL ) + eq_(sel.selected_columns.keys(), ["t1_x", "t2_x"]) eq_(list(sel.selected_columns.keys()), ["t1_x", "t2_x"]) eq_(list(sel.subquery().c.keys()), ["t1_x", "t2_x"]) self._assert_result_keys(sel, ["t1_x", "t2_x"]) @@ -2951,6 +2956,7 @@ class WithLabelsTest(fixtures.TestBase): def test_names_overlap_keys_dont_nolabel(self): sel = self._names_overlap_keys_dont() + eq_(sel.selected_columns.keys(), ["a", "b"]) eq_(list(sel.selected_columns.keys()), ["a", "b"]) eq_(list(sel.subquery().c.keys()), ["a", "b"]) self._assert_result_keys(sel, ["x"]) @@ -2959,10 +2965,41 @@ class WithLabelsTest(fixtures.TestBase): sel = self._names_overlap_keys_dont().set_label_style( LABEL_STYLE_TABLENAME_PLUS_COL ) + eq_(sel.selected_columns.keys(), ["t1_a", "t2_b"]) eq_(list(sel.selected_columns.keys()), ["t1_a", "t2_b"]) eq_(list(sel.subquery().c.keys()), ["t1_a", "t2_b"]) self._assert_result_keys(sel, ["t1_x", "t2_x"]) + def _columns_repeated(self): + m = MetaData() + t1 = Table("t1", m, Column("x", Integer), Column("y", Integer)) + return select(t1.c.x, t1.c.y, t1.c.x).set_label_style(LABEL_STYLE_NONE) + + def test_element_repeated_nolabels(self): + sel = self._columns_repeated().set_label_style(LABEL_STYLE_NONE) + eq_(sel.selected_columns.keys(), ["x", "y", "x"]) + eq_(list(sel.selected_columns.keys()), ["x", "y", "x"]) + eq_(list(sel.subquery().c.keys()), ["x", "y", "x_1"]) + self._assert_result_keys(sel, ["x", "y"]) + + def test_element_repeated_disambiguate(self): + sel = self._columns_repeated().set_label_style( + LABEL_STYLE_DISAMBIGUATE_ONLY + ) + eq_(sel.selected_columns.keys(), ["x", "y", "x_1"]) + eq_(list(sel.selected_columns.keys()), ["x", "y", "x_1"]) + eq_(list(sel.subquery().c.keys()), ["x", "y", "x_1"]) + self._assert_result_keys(sel, ["x", "y", "x__1"]) + + def test_element_repeated_labels(self): + sel = self._columns_repeated().set_label_style( + LABEL_STYLE_TABLENAME_PLUS_COL + ) + eq_(sel.selected_columns.keys(), ["t1_x", "t1_y", "t1_x_1"]) + eq_(list(sel.selected_columns.keys()), ["t1_x", "t1_y", "t1_x_1"]) + eq_(list(sel.subquery().c.keys()), ["t1_x", "t1_y", "t1_x_1"]) + self._assert_result_keys(sel, ["t1_x__1", "t1_x", "t1_y"]) + def _labels_overlap(self): m = MetaData() t1 = Table("t", m, Column("x_id", Integer)) @@ -2971,6 +3008,7 @@ class WithLabelsTest(fixtures.TestBase): def test_labels_overlap_nolabel(self): sel = self._labels_overlap() + eq_(sel.selected_columns.keys(), ["x_id", "id"]) eq_(list(sel.selected_columns.keys()), ["x_id", "id"]) eq_(list(sel.subquery().c.keys()), ["x_id", "id"]) self._assert_result_keys(sel, ["x_id", "id"]) @@ -3077,6 +3115,7 @@ class WithLabelsTest(fixtures.TestBase): def test_keys_overlap_names_dont_nolabel(self): sel = self._keys_overlap_names_dont() + eq_(sel.selected_columns.keys(), ["x", "b_1"]) self._assert_result_keys(sel, ["a", "b"]) def test_keys_overlap_names_dont_label(self): -- 2.47.2