]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
memoize current options and joins w with_entities/with_only_cols
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 15 Jun 2021 19:13:34 +0000 (15:13 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 17 Jun 2021 13:48:52 +0000 (09:48 -0400)
Fixed further regressions in the same area as that of :ticket:`6052` where
loader options as well as invocations of methods like
:meth:`_orm.Query.join` would fail if the left side of the statement for
which the option/join depends upon were replaced by using the
:meth:`_orm.Query.with_entities` method, or when using 2.0 style queries
when using the :meth:`_sql.Select.with_only_columns` method. A new set of
state has been added to the objects which tracks the "left" entities that
the options / join were made against which is memoized when the lead
entities are changed.

Fixes: #6503
Fixes: #6253
Change-Id: I211b2af98b0b20d1263fb15dc513884dcc5de6a4

16 files changed:
doc/build/changelog/unreleased_14/6503.rst [new file with mode: 0644]
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategy_options.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/sql/traversals.py
lib/sqlalchemy/sql/visitors.py
test/orm/test_cache_key.py
test/orm/test_joins.py
test/orm/test_options.py
test/profiles.txt
test/sql/test_compare.py
test/sql/test_external_traversal.py
test/sql/test_select.py

diff --git a/doc/build/changelog/unreleased_14/6503.rst b/doc/build/changelog/unreleased_14/6503.rst
new file mode 100644 (file)
index 0000000..a2d50bc
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: bug, orm, regression
+    :tickets: 6503, 6253
+
+    Fixed further regressions in the same area as that of :ticket:`6052` where
+    loader options as well as invocations of methods like
+    :meth:`_orm.Query.join` would fail if the left side of the statement for
+    which the option/join depends upon were replaced by using the
+    :meth:`_orm.Query.with_entities` method, or when using 2.0 style queries
+    when using the :meth:`_sql.Select.with_only_columns` method. A new set of
+    state has been added to the objects which tracks the "left" entities that
+    the options / join were made against which is memoized when the lead
+    entities are changed.
index e4448f9536008fb6a315cbde232168d105384bfa..321eeada01203e1166c404316d82601f40a5eac6 100644 (file)
@@ -322,10 +322,16 @@ class ORMCompileState(CompileState):
         return loading.instances(result, querycontext)
 
     @property
-    def _mapper_entities(self):
-        return (
+    def _lead_mapper_entities(self):
+        """return all _MapperEntity objects in the lead entities collection.
+
+        Does **not** include entities that have been replaced by
+        with_entities(), with_only_columns()
+
+        """
+        return [
             ent for ent in self._entities if isinstance(ent, _MapperEntity)
-        )
+        ]
 
     def _create_with_polymorphic_adapter(self, ext_info, selectable):
         if (
@@ -405,7 +411,9 @@ class ORMFromStatementCompileState(ORMCompileState):
             self.use_legacy_query_style,
         )
 
-        _QueryEntity.to_compile_state(self, statement_container._raw_columns)
+        _QueryEntity.to_compile_state(
+            self, statement_container._raw_columns, self._entities
+        )
 
         self.current_path = statement_container._compile_options._current_path
 
@@ -477,6 +485,8 @@ class ORMFromStatementCompileState(ORMCompileState):
 class ORMSelectCompileState(ORMCompileState, SelectState):
     _joinpath = _joinpoint = _EMPTY_DICT
 
+    _memoized_entities = _EMPTY_DICT
+
     _from_obj_alias = None
     _has_mapper_entities = False
 
@@ -572,15 +582,48 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
             statement._label_style, self.use_legacy_query_style
         )
 
-        _QueryEntity.to_compile_state(self, select_statement._raw_columns)
+        if select_statement._memoized_select_entities:
+            self._memoized_entities = {
+                memoized_entities: _QueryEntity.to_compile_state(
+                    self,
+                    memoized_entities._raw_columns,
+                    [],
+                )
+                for memoized_entities in (
+                    select_statement._memoized_select_entities
+                )
+            }
+
+        _QueryEntity.to_compile_state(
+            self, select_statement._raw_columns, self._entities
+        )
 
         self.current_path = select_statement._compile_options._current_path
 
         self.eager_order_by = ()
 
-        if toplevel and select_statement._with_options:
+        if toplevel and (
+            select_statement._with_options
+            or select_statement._memoized_select_entities
+        ):
             self.attributes = {"_unbound_load_dedupes": set()}
 
+            for (
+                memoized_entities
+            ) in select_statement._memoized_select_entities:
+                for opt in memoized_entities._with_options:
+                    if opt._is_compile_state:
+                        opt.process_compile_state_replaced_entities(
+                            self,
+                            [
+                                ent
+                                for ent in self._memoized_entities[
+                                    memoized_entities
+                                ]
+                                if isinstance(ent, _MapperEntity)
+                            ],
+                        )
+
             for opt in self.select_statement._with_options:
                 if opt._is_compile_state:
                     opt.process_compile_state(self)
@@ -626,11 +669,23 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
         if self.compile_options._set_base_alias:
             self._set_select_from_alias()
 
+        for memoized_entities in query._memoized_select_entities:
+            if memoized_entities._setup_joins:
+                self._join(
+                    memoized_entities._setup_joins,
+                    self._memoized_entities[memoized_entities],
+                )
+            if memoized_entities._legacy_setup_joins:
+                self._legacy_join(
+                    memoized_entities._legacy_setup_joins,
+                    self._memoized_entities[memoized_entities],
+                )
+
         if query._setup_joins:
-            self._join(query._setup_joins)
+            self._join(query._setup_joins, self._entities)
 
         if query._legacy_setup_joins:
-            self._legacy_join(query._legacy_setup_joins)
+            self._legacy_join(query._legacy_setup_joins, self._entities)
 
         current_adapter = self._get_current_adapter()
 
@@ -782,7 +837,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
 
         # entities will also set up polymorphic adapters for mappers
         # that have with_polymorphic configured
-        _QueryEntity.to_compile_state(self, query._raw_columns)
+        _QueryEntity.to_compile_state(self, query._raw_columns, self._entities)
         return self
 
     @classmethod
@@ -921,7 +976,18 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
 
     def _all_equivs(self):
         equivs = {}
-        for ent in self._mapper_entities:
+
+        for memoized_entities in self._memoized_entities.values():
+            for ent in [
+                ent
+                for ent in memoized_entities
+                if isinstance(ent, _MapperEntity)
+            ]:
+                equivs.update(ent.mapper._equivalent_columns)
+
+        for ent in [
+            ent for ent in self._entities if isinstance(ent, _MapperEntity)
+        ]:
             equivs.update(ent.mapper._equivalent_columns)
         return equivs
 
@@ -1211,7 +1277,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
 
         return _adapt_clause
 
-    def _join(self, args):
+    def _join(self, args, entities_collection):
         for (right, onclause, from_, flags) in args:
             isouter = flags["isouter"]
             full = flags["full"]
@@ -1316,6 +1382,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
             # figure out the final "left" and "right" sides and create an
             # ORMJoin to add to our _from_obj tuple
             self._join_left_to_right(
+                entities_collection,
                 left,
                 right,
                 onclause,
@@ -1326,7 +1393,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
                 full,
             )
 
-    def _legacy_join(self, args):
+    def _legacy_join(self, args, entities_collection):
         """consumes arguments from join() or outerjoin(), places them into a
         consistent format with which to form the actual JOIN constructs.
 
@@ -1474,6 +1541,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
             # figure out the final "left" and "right" sides and create an
             # ORMJoin to add to our _from_obj tuple
             self._join_left_to_right(
+                entities_collection,
                 left,
                 right,
                 onclause,
@@ -1489,6 +1557,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
 
     def _join_left_to_right(
         self,
+        entities_collection,
         left,
         right,
         onclause,
@@ -1513,7 +1582,9 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
                 left,
                 replace_from_obj_index,
                 use_entity_index,
-            ) = self._join_determine_implicit_left_side(left, right, onclause)
+            ) = self._join_determine_implicit_left_side(
+                entities_collection, left, right, onclause
+            )
         else:
             # left is given via a relationship/name, or as explicit left side.
             # Determine where in our
@@ -1522,7 +1593,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
             (
                 replace_from_obj_index,
                 use_entity_index,
-            ) = self._join_place_explicit_left_side(left)
+            ) = self._join_place_explicit_left_side(entities_collection, left)
 
         if left is right and not create_aliases:
             raise sa_exc.InvalidRequestError(
@@ -1568,9 +1639,9 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
                 # entity_zero.selectable, but if with_polymorphic() were used
                 # might be distinct
                 assert isinstance(
-                    self._entities[use_entity_index], _MapperEntity
+                    entities_collection[use_entity_index], _MapperEntity
                 )
-                left_clause = self._entities[use_entity_index].selectable
+                left_clause = entities_collection[use_entity_index].selectable
             else:
                 left_clause = left
 
@@ -1585,7 +1656,9 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
                 )
             ]
 
-    def _join_determine_implicit_left_side(self, left, right, onclause):
+    def _join_determine_implicit_left_side(
+        self, entities_collection, left, right, onclause
+    ):
         """When join conditions don't express the left side explicitly,
         determine if an existing FROM or entity in this query
         can serve as the left hand side.
@@ -1635,12 +1708,12 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
                     "to help resolve the ambiguity." % (right,)
                 )
 
-        elif self._entities:
+        elif entities_collection:
             # we have no explicit FROMs, so the implicit left has to
             # come from our list of entities.
 
             potential = {}
-            for entity_index, ent in enumerate(self._entities):
+            for entity_index, ent in enumerate(entities_collection):
                 entity = ent.entity_zero_or_selectable
                 if entity is None:
                     continue
@@ -1689,7 +1762,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
 
         return left, replace_from_obj_index, use_entity_index
 
-    def _join_place_explicit_left_side(self, left):
+    def _join_place_explicit_left_side(self, entities_collection, left):
         """When join conditions express a left side explicitly, determine
         where in our existing list of FROM clauses we should join towards,
         or if we need to make a new join, and if so is it from one of our
@@ -1743,10 +1816,10 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
         # aliasing / adaptation rules present on that entity if any
         if (
             replace_from_obj_index is None
-            and self._entities
+            and entities_collection
             and hasattr(l_info, "mapper")
         ):
-            for idx, ent in enumerate(self._entities):
+            for idx, ent in enumerate(entities_collection):
                 # TODO: should we be checking for multiple mapper entities
                 # matching?
                 if isinstance(ent, _MapperEntity) and ent.corresponds_to(left):
@@ -2194,11 +2267,14 @@ class _QueryEntity(object):
     __slots__ = ()
 
     @classmethod
-    def to_compile_state(cls, compile_state, entities):
+    def to_compile_state(cls, compile_state, entities, entities_collection):
+
         for idx, entity in enumerate(entities):
             if entity._is_lambda_element:
                 if entity._is_sequence:
-                    cls.to_compile_state(compile_state, entity._resolved)
+                    cls.to_compile_state(
+                        compile_state, entity._resolved, entities_collection
+                    )
                     continue
                 else:
                     entity = entity._resolved
@@ -2206,26 +2282,38 @@ class _QueryEntity(object):
             if entity.is_clause_element:
                 if entity.is_selectable:
                     if "parententity" in entity._annotations:
-                        _MapperEntity(compile_state, entity)
+                        _MapperEntity(
+                            compile_state, entity, entities_collection
+                        )
                     else:
                         _ColumnEntity._for_columns(
-                            compile_state, entity._select_iterable, idx
+                            compile_state,
+                            entity._select_iterable,
+                            entities_collection,
+                            idx,
                         )
                 else:
                     if entity._annotations.get("bundle", False):
-                        _BundleEntity(compile_state, entity)
+                        _BundleEntity(
+                            compile_state, entity, entities_collection
+                        )
                     elif entity._is_clause_list:
                         # this is legacy only - test_composites.py
                         # test_query_cols_legacy
                         _ColumnEntity._for_columns(
-                            compile_state, entity._select_iterable, idx
+                            compile_state,
+                            entity._select_iterable,
+                            entities_collection,
+                            idx,
                         )
                     else:
                         _ColumnEntity._for_columns(
-                            compile_state, [entity], idx
+                            compile_state, [entity], entities_collection, idx
                         )
             elif entity.is_bundle:
-                _BundleEntity(compile_state, entity)
+                _BundleEntity(compile_state, entity, entities_collection)
+
+        return entities_collection
 
 
 class _MapperEntity(_QueryEntity):
@@ -2244,8 +2332,8 @@ class _MapperEntity(_QueryEntity):
         "_polymorphic_discriminator",
     )
 
-    def __init__(self, compile_state, entity):
-        compile_state._entities.append(self)
+    def __init__(self, compile_state, entity, entities_collection):
+        entities_collection.append(self)
         if compile_state._primary_entity is None:
             compile_state._primary_entity = self
         compile_state._has_mapper_entities = True
@@ -2418,7 +2506,12 @@ class _BundleEntity(_QueryEntity):
     )
 
     def __init__(
-        self, compile_state, expr, setup_entities=True, parent_bundle=None
+        self,
+        compile_state,
+        expr,
+        entities_collection,
+        setup_entities=True,
+        parent_bundle=None,
     ):
         compile_state._has_orm_entities = True
 
@@ -2426,7 +2519,7 @@ class _BundleEntity(_QueryEntity):
         if parent_bundle:
             parent_bundle._entities.append(self)
         else:
-            compile_state._entities.append(self)
+            entities_collection.append(self)
 
         if isinstance(
             expr, (attributes.QueryableAttribute, interfaces.PropComparator)
@@ -2443,12 +2536,26 @@ class _BundleEntity(_QueryEntity):
         if setup_entities:
             for expr in bundle.exprs:
                 if "bundle" in expr._annotations:
-                    _BundleEntity(compile_state, expr, parent_bundle=self)
+                    _BundleEntity(
+                        compile_state,
+                        expr,
+                        entities_collection,
+                        parent_bundle=self,
+                    )
                 elif isinstance(expr, Bundle):
-                    _BundleEntity(compile_state, expr, parent_bundle=self)
+                    _BundleEntity(
+                        compile_state,
+                        expr,
+                        entities_collection,
+                        parent_bundle=self,
+                    )
                 else:
                     _ORMColumnEntity._for_columns(
-                        compile_state, [expr], None, parent_bundle=self
+                        compile_state,
+                        [expr],
+                        entities_collection,
+                        None,
+                        parent_bundle=self,
                     )
 
         self.supports_single_entity = self.bundle.single_entity
@@ -2516,7 +2623,12 @@ class _ColumnEntity(_QueryEntity):
 
     @classmethod
     def _for_columns(
-        cls, compile_state, columns, raw_column_index, parent_bundle=None
+        cls,
+        compile_state,
+        columns,
+        entities_collection,
+        raw_column_index,
+        parent_bundle=None,
     ):
         for column in columns:
             annotations = column._annotations
@@ -2532,6 +2644,7 @@ class _ColumnEntity(_QueryEntity):
                     _IdentityTokenEntity(
                         compile_state,
                         column,
+                        entities_collection,
                         _entity,
                         raw_column_index,
                         parent_bundle=parent_bundle,
@@ -2540,6 +2653,7 @@ class _ColumnEntity(_QueryEntity):
                     _ORMColumnEntity(
                         compile_state,
                         column,
+                        entities_collection,
                         _entity,
                         raw_column_index,
                         parent_bundle=parent_bundle,
@@ -2548,6 +2662,7 @@ class _ColumnEntity(_QueryEntity):
                 _RawColumnEntity(
                     compile_state,
                     column,
+                    entities_collection,
                     raw_column_index,
                     parent_bundle=parent_bundle,
                 )
@@ -2630,7 +2745,12 @@ class _RawColumnEntity(_ColumnEntity):
     )
 
     def __init__(
-        self, compile_state, column, raw_column_index, parent_bundle=None
+        self,
+        compile_state,
+        column,
+        entities_collection,
+        raw_column_index,
+        parent_bundle=None,
     ):
         self.expr = column
         self.raw_column_index = raw_column_index
@@ -2643,7 +2763,7 @@ class _RawColumnEntity(_ColumnEntity):
         if parent_bundle:
             parent_bundle._entities.append(self)
         else:
-            compile_state._entities.append(self)
+            entities_collection.append(self)
 
         self.column = column
         self.entity_zero_or_selectable = (
@@ -2690,6 +2810,7 @@ class _ORMColumnEntity(_ColumnEntity):
         self,
         compile_state,
         column,
+        entities_collection,
         parententity,
         raw_column_index,
         parent_bundle=None,
@@ -2729,7 +2850,7 @@ class _ORMColumnEntity(_ColumnEntity):
         if parent_bundle:
             parent_bundle._entities.append(self)
         else:
-            compile_state._entities.append(self)
+            entities_collection.append(self)
 
         compile_state._has_orm_entities = True
 
index c9a601f9956295dd5cecac7824603a3ce08880ce..28b4bfb2d05dd20622d083612deb1f9bfe39c641 100644 (file)
@@ -750,6 +750,18 @@ class LoaderOption(ORMOption):
 
     _is_compile_state = True
 
+    def process_compile_state_replaced_entities(
+        self, compile_state, mapper_entities
+    ):
+        """Apply a modification to a given :class:`.CompileState`,
+        given entities that were replaced by with_only_columns() or
+        with_entities().
+
+        .. versionadded:: 1.4.19
+
+        """
+        self.process_compile_state(compile_state)
+
     def process_compile_state(self, compile_state):
         """Apply a modification to a given :class:`.CompileState`."""
 
index cacfb8d84e74784b514cc196381f6a7bd41810dc..7ba31fa7a0e01f75f152064f4158c29a3712853d 100644 (file)
@@ -57,6 +57,7 @@ from ..sql.annotation import SupportsCloneAnnotations
 from ..sql.base import _entity_namespace_key
 from ..sql.base import _generative
 from ..sql.base import Executable
+from ..sql.selectable import _MemoizedSelectEntities
 from ..sql.selectable import _SelectFromElements
 from ..sql.selectable import ForUpdateArg
 from ..sql.selectable import GroupedElement
@@ -125,6 +126,8 @@ class Query(
     _legacy_setup_joins = ()
     _label_style = LABEL_STYLE_LEGACY_ORM
 
+    _memoized_select_entities = ()
+
     _compile_options = ORMCompileState.default_compile_options
 
     load_options = QueryContext.default_load_options
@@ -1433,6 +1436,7 @@ class Query(
                         limit(1)
 
         """
+        _MemoizedSelectEntities._generate_for_statement(self)
         self._set_entities(entities)
 
     @_generative
index e371442fdde39adc592b6d498770e2b0baf91b7b..91e62752502faa8c4d1ab8badd81c98babfd9225 100644 (file)
@@ -172,13 +172,32 @@ class Load(Generative, LoaderOption):
     _of_type = None
     _extra_criteria = ()
 
+    def process_compile_state_replaced_entities(
+        self, compile_state, mapper_entities
+    ):
+        if not compile_state.compile_options._enable_eagerloads:
+            return
+
+        # process is being run here so that the options given are validated
+        # against what the lead entities were, as well as to accommodate
+        # for the entities having been replaced with equivalents
+        self._process(
+            compile_state,
+            mapper_entities,
+            not bool(compile_state.current_path),
+        )
+
     def process_compile_state(self, compile_state):
         if not compile_state.compile_options._enable_eagerloads:
             return
 
-        self._process(compile_state, not bool(compile_state.current_path))
+        self._process(
+            compile_state,
+            compile_state._lead_mapper_entities,
+            not bool(compile_state.current_path),
+        )
 
-    def _process(self, compile_state, raiseerr):
+    def _process(self, compile_state, mapper_entities, raiseerr):
         is_refresh = compile_state.compile_options._for_refresh_state
         current_path = compile_state.current_path
         if current_path:
@@ -700,7 +719,7 @@ class _UnboundLoad(Load):
         state["path"] = tuple(ret)
         self.__dict__ = state
 
-    def _process(self, compile_state, raiseerr):
+    def _process(self, compile_state, mapper_entities, raiseerr):
         dedupes = compile_state.attributes["_unbound_load_dedupes"]
         is_refresh = compile_state.compile_options._for_refresh_state
         for val in self._to_bind:
@@ -709,10 +728,7 @@ class _UnboundLoad(Load):
                 if is_refresh and not val.propagate_to_loaders:
                     continue
                 val._bind_loader(
-                    [
-                        ent.entity_zero
-                        for ent in compile_state._mapper_entities
-                    ],
+                    [ent.entity_zero for ent in mapper_entities],
                     compile_state.current_path,
                     compile_state.attributes,
                     raiseerr,
index 213f47c4097f5deda43dfe6e6f0feb08114cccd9..709106b6b9ba7073745e3d4a16244de64506ffa2 100644 (file)
@@ -32,7 +32,6 @@ from .base import NO_ARG
 from .base import PARSE_AUTOCOMMIT
 from .base import SingletonConstant
 from .coercions import _document_text_coercion
-from .traversals import _get_children
 from .traversals import HasCopyInternals
 from .traversals import MemoizedHasCacheKey
 from .traversals import NO_CACHE
@@ -389,33 +388,6 @@ class ClauseElement(
         """
         return traversals.compare(self, other, **kw)
 
-    def get_children(self, omit_attrs=(), **kw):
-        r"""Return immediate child :class:`.visitors.Traversible`
-        elements of this :class:`.visitors.Traversible`.
-
-        This is used for visit traversal.
-
-        \**kw may contain flags that change the collection that is
-        returned, for example to return a subset of items in order to
-        cut down on larger traversals, or to return child items from a
-        different context (such as schema-level collections instead of
-        clause-level).
-
-        """
-        try:
-            traverse_internals = self._traverse_internals
-        except AttributeError:
-            # user-defined classes may not have a _traverse_internals
-            return []
-
-        return itertools.chain.from_iterable(
-            meth(obj, **kw)
-            for attrname, obj, meth in _get_children.run_generated_dispatch(
-                self, traverse_internals, "_generated_get_children_traversal"
-            )
-            if attrname not in omit_attrs and obj is not None
-        )
-
     def self_group(self, against=None):
         """Apply a 'grouping' to this :class:`_expression.ClauseElement`.
 
index 1610191d1e7e58b3a95bf8cfe596cdda119a57f4..e1dee091bdadafd8efc127273e1e6176cd0f01b5 100644 (file)
@@ -18,7 +18,9 @@ from operator import attrgetter
 from . import coercions
 from . import operators
 from . import roles
+from . import traversals
 from . import type_api
+from . import visitors
 from .annotation import Annotated
 from .annotation import SupportsCloneAnnotations
 from .base import _clone
@@ -4131,8 +4133,13 @@ class SelectState(util.MemoizedSlots, CompileState):
         self.statement = statement
         self.from_clauses = statement._from_obj
 
+        for memoized_entities in statement._memoized_select_entities:
+            self._setup_joins(
+                memoized_entities._setup_joins, memoized_entities._raw_columns
+            )
+
         if statement._setup_joins:
-            self._setup_joins(statement._setup_joins)
+            self._setup_joins(statement._setup_joins, statement._raw_columns)
 
         self.froms = self._get_froms(statement)
 
@@ -4361,7 +4368,7 @@ class SelectState(util.MemoizedSlots, CompileState):
     def all_selected_columns(cls, statement):
         return [c for c in _select_iterables(statement._raw_columns)]
 
-    def _setup_joins(self, args):
+    def _setup_joins(self, args, raw_columns):
         for (right, onclause, left, flags) in args:
             isouter = flags["isouter"]
             full = flags["full"]
@@ -4371,7 +4378,7 @@ class SelectState(util.MemoizedSlots, CompileState):
                     left,
                     replace_from_obj_index,
                 ) = self._join_determine_implicit_left_side(
-                    left, right, onclause
+                    raw_columns, left, right, onclause
                 )
             else:
                 (replace_from_obj_index) = self._join_place_explicit_left_side(
@@ -4403,7 +4410,9 @@ class SelectState(util.MemoizedSlots, CompileState):
                 )
 
     @util.preload_module("sqlalchemy.sql.util")
-    def _join_determine_implicit_left_side(self, left, right, onclause):
+    def _join_determine_implicit_left_side(
+        self, raw_columns, left, right, onclause
+    ):
         """When join conditions don't express the left side explicitly,
         determine if an existing FROM or entity in this query
         can serve as the left hand side.
@@ -4431,10 +4440,7 @@ class SelectState(util.MemoizedSlots, CompileState):
 
             for from_clause in itertools.chain(
                 itertools.chain.from_iterable(
-                    [
-                        element._from_objects
-                        for element in statement._raw_columns
-                    ]
+                    [element._from_objects for element in raw_columns]
                 ),
                 itertools.chain.from_iterable(
                     [
@@ -4531,6 +4537,47 @@ class _SelectFromElements(object):
             yield element
 
 
+class _MemoizedSelectEntities(
+    traversals.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible
+):
+    __visit_name__ = "memoized_select_entities"
+
+    _traverse_internals = [
+        ("_raw_columns", InternalTraversal.dp_clauseelement_list),
+        ("_setup_joins", InternalTraversal.dp_setup_join_tuple),
+        ("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple),
+        ("_with_options", InternalTraversal.dp_executable_options),
+    ]
+
+    _annotations = util.EMPTY_DICT
+
+    def _clone(self, **kw):
+        c = self.__class__.__new__(self.__class__)
+        c.__dict__ = {k: v for k, v in self.__dict__.items()}
+        c._is_clone_of = self
+        return c
+
+    @classmethod
+    def _generate_for_statement(cls, select_stmt):
+        if (
+            select_stmt._setup_joins
+            or select_stmt._legacy_setup_joins
+            or select_stmt._with_options
+        ):
+            self = _MemoizedSelectEntities()
+            self._raw_columns = select_stmt._raw_columns
+            self._setup_joins = select_stmt._setup_joins
+            self._legacy_setup_joins = select_stmt._legacy_setup_joins
+            self._with_options = select_stmt._with_options
+
+            select_stmt._memoized_select_entities += (self,)
+            select_stmt._raw_columns = (
+                select_stmt._setup_joins
+            ) = (
+                select_stmt._legacy_setup_joins
+            ) = select_stmt._with_options = ()
+
+
 class Select(
     HasPrefixes,
     HasSuffixes,
@@ -4559,6 +4606,7 @@ class Select(
 
     _setup_joins = ()
     _legacy_setup_joins = ()
+    _memoized_select_entities = ()
 
     _distinct = False
     _distinct_on = ()
@@ -4574,6 +4622,10 @@ class Select(
     _traverse_internals = (
         [
             ("_raw_columns", InternalTraversal.dp_clauseelement_list),
+            (
+                "_memoized_select_entities",
+                InternalTraversal.dp_memoized_select_entities,
+            ),
             ("_from_obj", InternalTraversal.dp_clauseelement_list),
             ("_where_criteria", InternalTraversal.dp_clauseelement_tuple),
             ("_having_criteria", InternalTraversal.dp_clauseelement_tuple),
@@ -5461,16 +5513,14 @@ class Select(
         # is the case for now.
         self._assert_no_memoizations()
 
-        rc = []
-        for c in coercions._expression_collection_was_a_list(
-            "columns", "Select.with_only_columns", columns
-        ):
-            c = coercions.expect(roles.ColumnsClauseRole, c)
-            # TODO: why are we doing this here?
-            if isinstance(c, ScalarSelect):
-                c = c.self_group(against=operators.comma_op)
-            rc.append(c)
-        self._raw_columns = rc
+        _MemoizedSelectEntities._generate_for_statement(self)
+
+        self._raw_columns = [
+            coercions.expect(roles.ColumnsClauseRole, c)
+            for c in coercions._expression_collection_was_a_list(
+                "columns", "Select.with_only_columns", columns
+            )
+        ]
 
     @property
     def whereclause(self):
index 35f2bd62f94cf08f499195596dbe3dba2762831e..a86d16ef4c06a8345df183ee0c18d725e0e58e2a 100644 (file)
@@ -194,6 +194,8 @@ class HasCacheKey(object):
                     elif (
                         meth is InternalTraversal.dp_clauseelement_list
                         or meth is InternalTraversal.dp_clauseelement_tuple
+                        or meth
+                        is InternalTraversal.dp_memoized_select_entities
                     ):
                         result += (
                             attrname,
@@ -409,6 +411,9 @@ class _CacheKey(ExtendedInternalTraversal):
     visit_clauseelement_list = InternalTraversal.dp_clauseelement_list
     visit_annotations_key = InternalTraversal.dp_annotations_key
     visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple
+    visit_memoized_select_entities = (
+        InternalTraversal.dp_memoized_select_entities
+    )
 
     visit_string = (
         visit_boolean
@@ -799,6 +804,9 @@ class _CopyInternals(InternalTraversal):
             for (target, onclause, from_, flags) in element
         )
 
+    def visit_memoized_select_entities(self, attrname, parent, element, **kw):
+        return self.visit_clauseelement_tuple(attrname, parent, element, **kw)
+
     def visit_dml_ordered_values(
         self, attrname, parent, element, clone=_clone, **kw
     ):
@@ -919,6 +927,9 @@ class _GetChildren(InternalTraversal):
             if onclause is not None and not isinstance(onclause, str):
                 yield _flatten_clauseelement(onclause)
 
+    def visit_memoized_select_entities(self, element, **kw):
+        return self.visit_clauseelement_tuple(element, **kw)
+
     def visit_dml_ordered_values(self, element, **kw):
         for k, v in element:
             if hasattr(k, "__clause_element__"):
@@ -1265,6 +1276,13 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
             self.stack.append((l_onclause, r_onclause))
             self.stack.append((l_from, r_from))
 
+    def visit_memoized_select_entities(
+        self, attrname, left_parent, left, right_parent, right, **kw
+    ):
+        return self.visit_clauseelement_tuple(
+            attrname, left_parent, left, right_parent, right, **kw
+        )
+
     def visit_table_hint_list(
         self, attrname, left_parent, left, right_parent, right, **kw
     ):
index 93ee8eb1c175df94084d3141cb23596f0dc87122..c750c546ad55997d3c285fc14dcd2236bced68af 100644 (file)
@@ -24,6 +24,7 @@ http://techspot.zzzeek.org/2008/01/23/expression-transformations/ .
 """
 
 from collections import deque
+import itertools
 import operator
 
 from .. import exc
@@ -119,6 +120,38 @@ class Traversible(util.with_metaclass(TraversibleType)):
 
     """
 
+    @util.preload_module("sqlalchemy.sql.traversals")
+    def get_children(self, omit_attrs=(), **kw):
+        r"""Return immediate child :class:`.visitors.Traversible`
+        elements of this :class:`.visitors.Traversible`.
+
+        This is used for visit traversal.
+
+        \**kw may contain flags that change the collection that is
+        returned, for example to return a subset of items in order to
+        cut down on larger traversals, or to return child items from a
+        different context (such as schema-level collections instead of
+        clause-level).
+
+        """
+
+        traversals = util.preloaded.sql_traversals
+
+        try:
+            traverse_internals = self._traverse_internals
+        except AttributeError:
+            # user-defined classes may not have a _traverse_internals
+            return []
+
+        dispatch = traversals._get_children.run_generated_dispatch
+        return itertools.chain.from_iterable(
+            meth(obj, **kw)
+            for attrname, obj, meth in dispatch(
+                self, traverse_internals, "_generated_get_children_traversal"
+            )
+            if attrname not in omit_attrs and obj is not None
+        )
+
 
 class _InternalTraversalType(type):
     def __init__(cls, clsname, bases, clsdict):
@@ -393,6 +426,8 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
 
     dp_setup_join_tuple = symbol("SJ")
 
+    dp_memoized_select_entities = symbol("ME")
+
     dp_statement_hint_list = symbol("SH")
     """Visit the ``_statement_hints`` collection of a
     :class:`_expression.Select`
index 67f2d0230616d62191a40216173d8a49fb15b365..7b6feb96a2fd5f46c4892b5ed00879a2ef952022 100644 (file)
@@ -30,6 +30,7 @@ from sqlalchemy.sql.visitors import InternalTraversal
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import mock
+from sqlalchemy.testing import ne_
 from sqlalchemy.testing.fixtures import fixture_session
 from test.orm import _fixtures
 from .inheritance import _poly_fixtures
@@ -313,6 +314,111 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
             compare_values=True,
         )
 
+    def test_orm_query_using_with_entities(self):
+        """test issue #6503"""
+        User, Address, Keyword, Order, Item = self.classes(
+            "User", "Address", "Keyword", "Order", "Item"
+        )
+
+        self._run_cache_key_fixture(
+            lambda: stmt_20(
+                fixture_session()
+                .query(User)
+                .join(User.addresses)
+                .with_entities(Address.id),
+                #
+                fixture_session().query(Address.id).join(User.addresses),
+                #
+                fixture_session()
+                .query(User)
+                .options(selectinload(User.addresses))
+                .with_entities(User.id),
+                #
+                fixture_session()
+                .query(User)
+                .options(selectinload(User.addresses)),
+                #
+                fixture_session().query(User).with_entities(User.id),
+                #
+                # here, propagate_attr->orm is Address, entity is Address.id,
+                # but the join() + with_entities() will log a
+                # _MemoizedSelectEntities to differentiate
+                fixture_session()
+                .query(Address, Order)
+                .join(Address.dingaling)
+                .with_entities(Address.id),
+                #
+                # same, propagate_attr->orm is Address, entity is Address.id,
+                # but the join() + with_entities() will log a
+                # _MemoizedSelectEntities to differentiate
+                fixture_session()
+                .query(Address, User)
+                .join(Address.dingaling)
+                .with_entities(Address.id),
+            ),
+            compare_values=True,
+        )
+
+    def test_more_with_entities_sanity_checks(self):
+        """test issue #6503"""
+        User, Address, Keyword, Order, Item = self.classes(
+            "User", "Address", "Keyword", "Order", "Item"
+        )
+
+        sess = fixture_session()
+
+        q1 = (
+            sess.query(Address, Order)
+            .with_entities(Address.id)
+            ._statement_20()
+        )
+        q2 = (
+            sess.query(Address, User).with_entities(Address.id)._statement_20()
+        )
+
+        assert not q1._memoized_select_entities
+        assert not q2._memoized_select_entities
+
+        # no joins or options, so q1 and q2 have the same cache key as Order/
+        # User are discarded.  Note Address is first so propagate_attrs->orm is
+        # Address.
+        eq_(q1._generate_cache_key(), q2._generate_cache_key())
+
+        q3 = sess.query(Order).with_entities(Address.id)._statement_20()
+        q4 = sess.query(User).with_entities(Address.id)._statement_20()
+
+        # with Order/User as lead entity, this affects propagate_attrs->orm
+        # so keys are different
+        ne_(q3._generate_cache_key(), q4._generate_cache_key())
+
+        # confirm by deleting propagate attrs and memoized key and
+        # running again
+        q3._propagate_attrs = None
+        q4._propagate_attrs = None
+        del q3.__dict__["_generate_cache_key"]
+        del q4.__dict__["_generate_cache_key"]
+        eq_(q3._generate_cache_key(), q4._generate_cache_key())
+
+        # once there's a join() or options() prior to with_entities, now they
+        # are not discarded from the key; Order and User are in the
+        # _MemoizedSelectEntities
+        q5 = (
+            sess.query(Address, Order)
+            .join(Address.dingaling)
+            .with_entities(Address.id)
+            ._statement_20()
+        )
+        q6 = (
+            sess.query(Address, User)
+            .join(Address.dingaling)
+            .with_entities(Address.id)
+            ._statement_20()
+        )
+
+        assert q5._memoized_select_entities
+        assert q6._memoized_select_entities
+        ne_(q5._generate_cache_key(), q6._generate_cache_key())
+
     def test_orm_query_from_statement(self):
         User, Address, Keyword, Order, Item = self.classes(
             "User", "Address", "Keyword", "Order", "Item"
index 7f6e1b72eca548d61a39735fe8a5abc0d973fe31..25fa7e6615c1340ca11da4da592df9cffc279c1f 100644 (file)
@@ -327,6 +327,43 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             "JOIN addresses ON users.id = addresses.user_id",
         )
 
+    @testing.combinations((True,), (False,), argnames="legacy")
+    @testing.combinations((True,), (False,), argnames="threelevel")
+    def test_join_with_entities(self, legacy, threelevel):
+        """test issue #6503"""
+
+        User, Address, Dingaling = self.classes("User", "Address", "Dingaling")
+
+        if legacy:
+            sess = fixture_session()
+            stmt = sess.query(User).join(Address).with_entities(Address.id)
+        else:
+            stmt = select(User).join(Address).with_only_columns(Address.id)
+
+            stmt = stmt.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+
+        if threelevel:
+            if legacy:
+                stmt = stmt.join(Address.dingaling).with_entities(Dingaling.id)
+            else:
+                stmt = stmt.join(Address.dingaling).with_only_columns(
+                    Dingaling.id
+                )
+
+        if threelevel:
+            self.assert_compile(
+                stmt,
+                "SELECT dingalings.id AS dingalings_id "
+                "FROM users JOIN addresses ON users.id = addresses.user_id "
+                "JOIN dingalings ON addresses.id = dingalings.address_id",
+            )
+        else:
+            self.assert_compile(
+                stmt,
+                "SELECT addresses.id AS addresses_id FROM users "
+                "JOIN addresses ON users.id = addresses.user_id",
+            )
+
     def test_invalid_kwarg_join(self):
         User = self.classes.User
         sess = fixture_session()
index 4bef121d919208bc24f751f14f75154d69ed7ecb..31ab100fac177e6cca19281adfaeffae1da1df82 100644 (file)
@@ -3,6 +3,7 @@ from sqlalchemy import Column
 from sqlalchemy import ForeignKey
 from sqlalchemy import inspect
 from sqlalchemy import Integer
+from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy.orm import aliased
@@ -24,6 +25,7 @@ from sqlalchemy.orm import util as orm_util
 from sqlalchemy.orm import with_polymorphic
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.assertions import assert_raises_message
+from sqlalchemy.testing.assertions import AssertsCompiledSQL
 from sqlalchemy.testing.assertions import eq_
 from sqlalchemy.testing.fixtures import fixture_session
 from test.orm import _fixtures
@@ -95,7 +97,7 @@ class PathTest(object):
                 val._bind_loader(
                     [
                         ent.entity_zero
-                        for ent in q._compile_state()._mapper_entities
+                        for ent in q._compile_state()._lead_mapper_entities
                     ],
                     q._compile_options._current_path,
                     attr,
@@ -104,7 +106,7 @@ class PathTest(object):
         else:
             compile_state = q._compile_state()
             compile_state.attributes = attr = {}
-            opt._process(compile_state, True)
+            opt._process(compile_state, [], True)
 
         assert_paths = [k[1] for k in attr]
         eq_(
@@ -401,6 +403,92 @@ class OfTypePathingTest(PathTest, QueryTest):
         )
 
 
+class WithEntitiesTest(QueryTest, AssertsCompiledSQL):
+    def test_options_legacy_with_entities_onelevel(self):
+        """test issue #6253 (part of #6503)"""
+
+        User = self.classes.User
+        sess = fixture_session()
+
+        q = (
+            sess.query(User)
+            .options(joinedload(User.addresses))
+            .with_entities(User.id)
+        )
+        self.assert_compile(q, "SELECT users.id AS users_id FROM users")
+
+    def test_options_with_only_cols_onelevel(self):
+        """test issue #6253 (part of #6503)"""
+
+        User = self.classes.User
+
+        q = (
+            select(User)
+            .options(joinedload(User.addresses))
+            .with_only_columns(User.id)
+        )
+        self.assert_compile(q, "SELECT users.id FROM users")
+
+    def test_options_entities_replaced_with_equivs_one(self):
+        User = self.classes.User
+        Address = self.classes.Address
+
+        q = (
+            select(User, Address)
+            .options(joinedload(User.addresses))
+            .with_only_columns(User)
+        )
+        self.assert_compile(
+            q,
+            "SELECT users.id, users.name, addresses_1.id AS id_1, "
+            "addresses_1.user_id, addresses_1.email_address FROM users "
+            "LEFT OUTER JOIN addresses AS addresses_1 "
+            "ON users.id = addresses_1.user_id ORDER BY addresses_1.id",
+        )
+
+    def test_options_entities_replaced_with_equivs_two(self):
+        User = self.classes.User
+        Address = self.classes.Address
+
+        q = (
+            select(User, Address)
+            .options(joinedload(User.addresses), joinedload(Address.dingaling))
+            .with_only_columns(User)
+        )
+        self.assert_compile(
+            q,
+            "SELECT users.id, users.name, addresses_1.id AS id_1, "
+            "addresses_1.user_id, addresses_1.email_address FROM users "
+            "LEFT OUTER JOIN addresses AS addresses_1 "
+            "ON users.id = addresses_1.user_id ORDER BY addresses_1.id",
+        )
+
+    def test_options_entities_replaced_with_equivs_three(self):
+        User = self.classes.User
+        Address = self.classes.Address
+
+        q = (
+            select(User)
+            .options(joinedload(User.addresses))
+            .with_only_columns(User, Address)
+            .options(joinedload(Address.dingaling))
+        )
+        self.assert_compile(
+            q,
+            "SELECT users.id, users.name, addresses.id AS id_1, "
+            "addresses.user_id, addresses.email_address, "
+            "addresses_1.id AS id_2, addresses_1.user_id AS user_id_1, "
+            "addresses_1.email_address AS email_address_1, "
+            "dingalings_1.id AS id_3, dingalings_1.address_id, "
+            "dingalings_1.data "
+            "FROM users LEFT OUTER JOIN addresses AS addresses_1 "
+            "ON users.id = addresses_1.user_id, addresses "
+            "LEFT OUTER JOIN dingalings AS dingalings_1 "
+            "ON addresses.id = dingalings_1.address_id "
+            "ORDER BY addresses_1.id",
+        )
+
+
 class OptionsTest(PathTest, QueryTest):
     def _option_fixture(self, *arg):
         return strategy_options._UnboundLoad._from_keys(
@@ -1479,7 +1567,7 @@ class PickleTest(PathTest, QueryTest):
         load = opt._bind_loader(
             [
                 ent.entity_zero
-                for ent in query._compile_state()._mapper_entities
+                for ent in query._compile_state()._lead_mapper_entities
             ],
             query._compile_options._current_path,
             attr,
@@ -1516,7 +1604,7 @@ class PickleTest(PathTest, QueryTest):
         load = opt._bind_loader(
             [
                 ent.entity_zero
-                for ent in query._compile_state()._mapper_entities
+                for ent in query._compile_state()._lead_mapper_entities
             ],
             query._compile_options._current_path,
             attr,
@@ -1560,7 +1648,7 @@ class LocalOptsTest(PathTest, QueryTest):
                 ctx = query._compile_state()
                 for tb in opt._to_bind:
                     tb._bind_loader(
-                        [ent.entity_zero for ent in ctx._mapper_entities],
+                        [ent.entity_zero for ent in ctx._lead_mapper_entities],
                         query._compile_options._current_path,
                         attr,
                         False,
@@ -1658,7 +1746,7 @@ class SubOptionsTest(PathTest, QueryTest):
             val._bind_loader(
                 [
                     ent.entity_zero
-                    for ent in q._compile_state()._mapper_entities
+                    for ent in q._compile_state()._lead_mapper_entities
                 ],
                 q._compile_options._current_path,
                 attr_a,
@@ -1672,7 +1760,7 @@ class SubOptionsTest(PathTest, QueryTest):
                 val._bind_loader(
                     [
                         ent.entity_zero
-                        for ent in q._compile_state()._mapper_entities
+                        for ent in q._compile_state()._lead_mapper_entities
                     ],
                     q._compile_options._current_path,
                     attr_b,
index 3b5b1aca3e025b144670eac1cab42309647468f5..6e6f430a3900b08321573f5054574bf08e9177e6 100644 (file)
@@ -1,15 +1,15 @@
 # /home/classic/dev/sqlalchemy/test/profiles.txt
 # This file is written out on a per-environment basis.
-# For each test in aaa_profiling, the corresponding function and 
+# For each test in aaa_profiling, the corresponding function and
 # environment is located within this file.  If it doesn't exist,
 # the test is skipped.
-# If a callcount does exist, it is compared to what we received. 
+# If a callcount does exist, it is compared to what we received.
 # assertions are raised if the counts do not match.
-# 
-# To add a new callcount test, apply the function_call_count 
-# decorator and re-run the tests using the --write-profiles 
+#
+# To add a new callcount test, apply the function_call_count
+# decorator and re-run the tests using the --write-profiles
 # option - this file will be rewritten including the new count.
-# 
+#
 
 # TEST: test.aaa_profiling.test_compiler.CompileTest.test_insert
 
@@ -240,10 +240,10 @@ test.aaa_profiling.test_orm.AttributeOverheadTest.test_collection_append_remove
 
 # TEST: test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching
 
-test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_cextensions 60
-test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_nocextensions 60
-test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_3.9_sqlite_pysqlite_dbapiunicode_cextensions 61
-test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_3.9_sqlite_pysqlite_dbapiunicode_nocextensions 61
+test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_cextensions 68
+test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_nocextensions 68
+test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_3.9_sqlite_pysqlite_dbapiunicode_cextensions 73
+test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_3.9_sqlite_pysqlite_dbapiunicode_nocextensions 73
 
 # TEST: test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching
 
index 257776c506209b0ae892ff20c8a8b62dce61861e..e96a47553b17d47ab36ac57abe555653e4d6c94b 100644 (file)
@@ -513,6 +513,14 @@ class CoreFixtures(object):
                 func.bernoulli(1), name="bar", seed=func.random()
             ),
         ),
+        lambda: (
+            # test issue #6503
+            # join from table_a -> table_c, select table_b.c.a
+            select(table_a).join(table_c).with_only_columns(table_b.c.a),
+            # join from table_b -> table_c, select table_b.c.a
+            select(table_b.c.a).join(table_c).with_only_columns(table_b.c.a),
+            select(table_a).with_only_columns(table_b.c.a),
+        ),
         lambda: (
             table_a.insert(),
             table_a.insert().values({})._annotate({"nocache": True}),
index 3469dcb372ecc2452c91c23be97406dcf76c8dbc..c7e51c80703603b767298c92fb723f9c05ccc2f4 100644 (file)
@@ -1747,6 +1747,29 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
             "addresses.user_id",
         )
 
+    def test_prev_entities_adapt(self):
+        """test #6503"""
+
+        m = MetaData()
+        users = Table("users", m, Column("id", Integer, primary_key=True))
+        addresses = Table(
+            "addresses",
+            m,
+            Column("id", Integer, primary_key=True),
+            Column("user_id", ForeignKey("users.id")),
+        )
+
+        ualias = users.alias()
+
+        s = select(users).join(addresses).with_only_columns(addresses.c.id)
+        s = sql_util.ClauseAdapter(ualias).traverse(s)
+
+        self.assert_compile(
+            s,
+            "SELECT addresses.id FROM users AS users_1 "
+            "JOIN addresses ON users_1.id = addresses.user_id",
+        )
+
     @testing.combinations((True,), (False,), argnames="use_adapt_from")
     def test_table_to_alias_1(self, use_adapt_from):
         t1alias = t1.alias("t1alias")
index f9f1acfa0193685fe20ffea3d17db79833264746..d1f9e381f9da824a8edb48e7f438f2e06a197aad 100644 (file)
@@ -266,6 +266,33 @@ class FutureSelectTest(fixtures.TestBase, AssertsCompiledSQL):
             "ON parent.id = child.parent_id",
         )
 
+    def test_join_implicit_left_side_wo_cols_onelevel(self):
+        """test issue #6503"""
+        stmt = select(parent).join(child).with_only_columns(child.c.id)
+
+        self.assert_compile(
+            stmt,
+            "SELECT child.id FROM parent "
+            "JOIN child ON parent.id = child.parent_id",
+        )
+
+    def test_join_implicit_left_side_wo_cols_twolevel(self):
+        """test issue #6503"""
+        stmt = (
+            select(parent)
+            .join(child)
+            .with_only_columns(child.c.id)
+            .join(grandchild)
+            .with_only_columns(grandchild.c.id)
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT grandchild.id FROM parent "
+            "JOIN child ON parent.id = child.parent_id "
+            "JOIN grandchild ON child.id = grandchild.child_id",
+        )
+
     def test_right_nested_inner_join(self):
         inner = child.join(grandchild)