From: Mike Bayer Date: Tue, 15 Jun 2021 19:13:34 +0000 (-0400) Subject: memoize current options and joins w with_entities/with_only_cols X-Git-Tag: rel_1_4_19~13^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5b3e887f46afdbee312d5efd2a14f7c9b7eeac65;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git memoize current options and joins w with_entities/with_only_cols Fixed further regressions in the same area as that of :ticket:`6052` where loader options as well as invocations of methods like :meth:`_orm.Query.join` would fail if the left side of the statement for which the option/join depends upon were replaced by using the :meth:`_orm.Query.with_entities` method, or when using 2.0 style queries when using the :meth:`_sql.Select.with_only_columns` method. A new set of state has been added to the objects which tracks the "left" entities that the options / join were made against which is memoized when the lead entities are changed. Fixes: #6503 Fixes: #6253 Change-Id: I211b2af98b0b20d1263fb15dc513884dcc5de6a4 --- diff --git a/doc/build/changelog/unreleased_14/6503.rst b/doc/build/changelog/unreleased_14/6503.rst new file mode 100644 index 0000000000..a2d50bc99c --- /dev/null +++ b/doc/build/changelog/unreleased_14/6503.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: bug, orm, regression + :tickets: 6503, 6253 + + Fixed further regressions in the same area as that of :ticket:`6052` where + loader options as well as invocations of methods like + :meth:`_orm.Query.join` would fail if the left side of the statement for + which the option/join depends upon were replaced by using the + :meth:`_orm.Query.with_entities` method, or when using 2.0 style queries + when using the :meth:`_sql.Select.with_only_columns` method. A new set of + state has been added to the objects which tracks the "left" entities that + the options / join were made against which is memoized when the lead + entities are changed. diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index e4448f9536..321eeada01 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -322,10 +322,16 @@ class ORMCompileState(CompileState): return loading.instances(result, querycontext) @property - def _mapper_entities(self): - return ( + def _lead_mapper_entities(self): + """return all _MapperEntity objects in the lead entities collection. + + Does **not** include entities that have been replaced by + with_entities(), with_only_columns() + + """ + return [ ent for ent in self._entities if isinstance(ent, _MapperEntity) - ) + ] def _create_with_polymorphic_adapter(self, ext_info, selectable): if ( @@ -405,7 +411,9 @@ class ORMFromStatementCompileState(ORMCompileState): self.use_legacy_query_style, ) - _QueryEntity.to_compile_state(self, statement_container._raw_columns) + _QueryEntity.to_compile_state( + self, statement_container._raw_columns, self._entities + ) self.current_path = statement_container._compile_options._current_path @@ -477,6 +485,8 @@ class ORMFromStatementCompileState(ORMCompileState): class ORMSelectCompileState(ORMCompileState, SelectState): _joinpath = _joinpoint = _EMPTY_DICT + _memoized_entities = _EMPTY_DICT + _from_obj_alias = None _has_mapper_entities = False @@ -572,15 +582,48 @@ class ORMSelectCompileState(ORMCompileState, SelectState): statement._label_style, self.use_legacy_query_style ) - _QueryEntity.to_compile_state(self, select_statement._raw_columns) + if select_statement._memoized_select_entities: + self._memoized_entities = { + memoized_entities: _QueryEntity.to_compile_state( + self, + memoized_entities._raw_columns, + [], + ) + for memoized_entities in ( + select_statement._memoized_select_entities + ) + } + + _QueryEntity.to_compile_state( + self, select_statement._raw_columns, self._entities + ) self.current_path = select_statement._compile_options._current_path self.eager_order_by = () - if toplevel and select_statement._with_options: + if toplevel and ( + select_statement._with_options + or select_statement._memoized_select_entities + ): self.attributes = {"_unbound_load_dedupes": set()} + for ( + memoized_entities + ) in select_statement._memoized_select_entities: + for opt in memoized_entities._with_options: + if opt._is_compile_state: + opt.process_compile_state_replaced_entities( + self, + [ + ent + for ent in self._memoized_entities[ + memoized_entities + ] + if isinstance(ent, _MapperEntity) + ], + ) + for opt in self.select_statement._with_options: if opt._is_compile_state: opt.process_compile_state(self) @@ -626,11 +669,23 @@ class ORMSelectCompileState(ORMCompileState, SelectState): if self.compile_options._set_base_alias: self._set_select_from_alias() + for memoized_entities in query._memoized_select_entities: + if memoized_entities._setup_joins: + self._join( + memoized_entities._setup_joins, + self._memoized_entities[memoized_entities], + ) + if memoized_entities._legacy_setup_joins: + self._legacy_join( + memoized_entities._legacy_setup_joins, + self._memoized_entities[memoized_entities], + ) + if query._setup_joins: - self._join(query._setup_joins) + self._join(query._setup_joins, self._entities) if query._legacy_setup_joins: - self._legacy_join(query._legacy_setup_joins) + self._legacy_join(query._legacy_setup_joins, self._entities) current_adapter = self._get_current_adapter() @@ -782,7 +837,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): # entities will also set up polymorphic adapters for mappers # that have with_polymorphic configured - _QueryEntity.to_compile_state(self, query._raw_columns) + _QueryEntity.to_compile_state(self, query._raw_columns, self._entities) return self @classmethod @@ -921,7 +976,18 @@ class ORMSelectCompileState(ORMCompileState, SelectState): def _all_equivs(self): equivs = {} - for ent in self._mapper_entities: + + for memoized_entities in self._memoized_entities.values(): + for ent in [ + ent + for ent in memoized_entities + if isinstance(ent, _MapperEntity) + ]: + equivs.update(ent.mapper._equivalent_columns) + + for ent in [ + ent for ent in self._entities if isinstance(ent, _MapperEntity) + ]: equivs.update(ent.mapper._equivalent_columns) return equivs @@ -1211,7 +1277,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): return _adapt_clause - def _join(self, args): + def _join(self, args, entities_collection): for (right, onclause, from_, flags) in args: isouter = flags["isouter"] full = flags["full"] @@ -1316,6 +1382,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): # figure out the final "left" and "right" sides and create an # ORMJoin to add to our _from_obj tuple self._join_left_to_right( + entities_collection, left, right, onclause, @@ -1326,7 +1393,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): full, ) - def _legacy_join(self, args): + def _legacy_join(self, args, entities_collection): """consumes arguments from join() or outerjoin(), places them into a consistent format with which to form the actual JOIN constructs. @@ -1474,6 +1541,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): # figure out the final "left" and "right" sides and create an # ORMJoin to add to our _from_obj tuple self._join_left_to_right( + entities_collection, left, right, onclause, @@ -1489,6 +1557,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): def _join_left_to_right( self, + entities_collection, left, right, onclause, @@ -1513,7 +1582,9 @@ class ORMSelectCompileState(ORMCompileState, SelectState): left, replace_from_obj_index, use_entity_index, - ) = self._join_determine_implicit_left_side(left, right, onclause) + ) = self._join_determine_implicit_left_side( + entities_collection, left, right, onclause + ) else: # left is given via a relationship/name, or as explicit left side. # Determine where in our @@ -1522,7 +1593,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): ( replace_from_obj_index, use_entity_index, - ) = self._join_place_explicit_left_side(left) + ) = self._join_place_explicit_left_side(entities_collection, left) if left is right and not create_aliases: raise sa_exc.InvalidRequestError( @@ -1568,9 +1639,9 @@ class ORMSelectCompileState(ORMCompileState, SelectState): # entity_zero.selectable, but if with_polymorphic() were used # might be distinct assert isinstance( - self._entities[use_entity_index], _MapperEntity + entities_collection[use_entity_index], _MapperEntity ) - left_clause = self._entities[use_entity_index].selectable + left_clause = entities_collection[use_entity_index].selectable else: left_clause = left @@ -1585,7 +1656,9 @@ class ORMSelectCompileState(ORMCompileState, SelectState): ) ] - def _join_determine_implicit_left_side(self, left, right, onclause): + def _join_determine_implicit_left_side( + self, entities_collection, left, right, onclause + ): """When join conditions don't express the left side explicitly, determine if an existing FROM or entity in this query can serve as the left hand side. @@ -1635,12 +1708,12 @@ class ORMSelectCompileState(ORMCompileState, SelectState): "to help resolve the ambiguity." % (right,) ) - elif self._entities: + elif entities_collection: # we have no explicit FROMs, so the implicit left has to # come from our list of entities. potential = {} - for entity_index, ent in enumerate(self._entities): + for entity_index, ent in enumerate(entities_collection): entity = ent.entity_zero_or_selectable if entity is None: continue @@ -1689,7 +1762,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): return left, replace_from_obj_index, use_entity_index - def _join_place_explicit_left_side(self, left): + def _join_place_explicit_left_side(self, entities_collection, left): """When join conditions express a left side explicitly, determine where in our existing list of FROM clauses we should join towards, or if we need to make a new join, and if so is it from one of our @@ -1743,10 +1816,10 @@ class ORMSelectCompileState(ORMCompileState, SelectState): # aliasing / adaptation rules present on that entity if any if ( replace_from_obj_index is None - and self._entities + and entities_collection and hasattr(l_info, "mapper") ): - for idx, ent in enumerate(self._entities): + for idx, ent in enumerate(entities_collection): # TODO: should we be checking for multiple mapper entities # matching? if isinstance(ent, _MapperEntity) and ent.corresponds_to(left): @@ -2194,11 +2267,14 @@ class _QueryEntity(object): __slots__ = () @classmethod - def to_compile_state(cls, compile_state, entities): + def to_compile_state(cls, compile_state, entities, entities_collection): + for idx, entity in enumerate(entities): if entity._is_lambda_element: if entity._is_sequence: - cls.to_compile_state(compile_state, entity._resolved) + cls.to_compile_state( + compile_state, entity._resolved, entities_collection + ) continue else: entity = entity._resolved @@ -2206,26 +2282,38 @@ class _QueryEntity(object): if entity.is_clause_element: if entity.is_selectable: if "parententity" in entity._annotations: - _MapperEntity(compile_state, entity) + _MapperEntity( + compile_state, entity, entities_collection + ) else: _ColumnEntity._for_columns( - compile_state, entity._select_iterable, idx + compile_state, + entity._select_iterable, + entities_collection, + idx, ) else: if entity._annotations.get("bundle", False): - _BundleEntity(compile_state, entity) + _BundleEntity( + compile_state, entity, entities_collection + ) elif entity._is_clause_list: # this is legacy only - test_composites.py # test_query_cols_legacy _ColumnEntity._for_columns( - compile_state, entity._select_iterable, idx + compile_state, + entity._select_iterable, + entities_collection, + idx, ) else: _ColumnEntity._for_columns( - compile_state, [entity], idx + compile_state, [entity], entities_collection, idx ) elif entity.is_bundle: - _BundleEntity(compile_state, entity) + _BundleEntity(compile_state, entity, entities_collection) + + return entities_collection class _MapperEntity(_QueryEntity): @@ -2244,8 +2332,8 @@ class _MapperEntity(_QueryEntity): "_polymorphic_discriminator", ) - def __init__(self, compile_state, entity): - compile_state._entities.append(self) + def __init__(self, compile_state, entity, entities_collection): + entities_collection.append(self) if compile_state._primary_entity is None: compile_state._primary_entity = self compile_state._has_mapper_entities = True @@ -2418,7 +2506,12 @@ class _BundleEntity(_QueryEntity): ) def __init__( - self, compile_state, expr, setup_entities=True, parent_bundle=None + self, + compile_state, + expr, + entities_collection, + setup_entities=True, + parent_bundle=None, ): compile_state._has_orm_entities = True @@ -2426,7 +2519,7 @@ class _BundleEntity(_QueryEntity): if parent_bundle: parent_bundle._entities.append(self) else: - compile_state._entities.append(self) + entities_collection.append(self) if isinstance( expr, (attributes.QueryableAttribute, interfaces.PropComparator) @@ -2443,12 +2536,26 @@ class _BundleEntity(_QueryEntity): if setup_entities: for expr in bundle.exprs: if "bundle" in expr._annotations: - _BundleEntity(compile_state, expr, parent_bundle=self) + _BundleEntity( + compile_state, + expr, + entities_collection, + parent_bundle=self, + ) elif isinstance(expr, Bundle): - _BundleEntity(compile_state, expr, parent_bundle=self) + _BundleEntity( + compile_state, + expr, + entities_collection, + parent_bundle=self, + ) else: _ORMColumnEntity._for_columns( - compile_state, [expr], None, parent_bundle=self + compile_state, + [expr], + entities_collection, + None, + parent_bundle=self, ) self.supports_single_entity = self.bundle.single_entity @@ -2516,7 +2623,12 @@ class _ColumnEntity(_QueryEntity): @classmethod def _for_columns( - cls, compile_state, columns, raw_column_index, parent_bundle=None + cls, + compile_state, + columns, + entities_collection, + raw_column_index, + parent_bundle=None, ): for column in columns: annotations = column._annotations @@ -2532,6 +2644,7 @@ class _ColumnEntity(_QueryEntity): _IdentityTokenEntity( compile_state, column, + entities_collection, _entity, raw_column_index, parent_bundle=parent_bundle, @@ -2540,6 +2653,7 @@ class _ColumnEntity(_QueryEntity): _ORMColumnEntity( compile_state, column, + entities_collection, _entity, raw_column_index, parent_bundle=parent_bundle, @@ -2548,6 +2662,7 @@ class _ColumnEntity(_QueryEntity): _RawColumnEntity( compile_state, column, + entities_collection, raw_column_index, parent_bundle=parent_bundle, ) @@ -2630,7 +2745,12 @@ class _RawColumnEntity(_ColumnEntity): ) def __init__( - self, compile_state, column, raw_column_index, parent_bundle=None + self, + compile_state, + column, + entities_collection, + raw_column_index, + parent_bundle=None, ): self.expr = column self.raw_column_index = raw_column_index @@ -2643,7 +2763,7 @@ class _RawColumnEntity(_ColumnEntity): if parent_bundle: parent_bundle._entities.append(self) else: - compile_state._entities.append(self) + entities_collection.append(self) self.column = column self.entity_zero_or_selectable = ( @@ -2690,6 +2810,7 @@ class _ORMColumnEntity(_ColumnEntity): self, compile_state, column, + entities_collection, parententity, raw_column_index, parent_bundle=None, @@ -2729,7 +2850,7 @@ class _ORMColumnEntity(_ColumnEntity): if parent_bundle: parent_bundle._entities.append(self) else: - compile_state._entities.append(self) + entities_collection.append(self) compile_state._has_orm_entities = True diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index c9a601f995..28b4bfb2d0 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -750,6 +750,18 @@ class LoaderOption(ORMOption): _is_compile_state = True + def process_compile_state_replaced_entities( + self, compile_state, mapper_entities + ): + """Apply a modification to a given :class:`.CompileState`, + given entities that were replaced by with_only_columns() or + with_entities(). + + .. versionadded:: 1.4.19 + + """ + self.process_compile_state(compile_state) + def process_compile_state(self, compile_state): """Apply a modification to a given :class:`.CompileState`.""" diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index cacfb8d84e..7ba31fa7a0 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -57,6 +57,7 @@ from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import _entity_namespace_key from ..sql.base import _generative from ..sql.base import Executable +from ..sql.selectable import _MemoizedSelectEntities from ..sql.selectable import _SelectFromElements from ..sql.selectable import ForUpdateArg from ..sql.selectable import GroupedElement @@ -125,6 +126,8 @@ class Query( _legacy_setup_joins = () _label_style = LABEL_STYLE_LEGACY_ORM + _memoized_select_entities = () + _compile_options = ORMCompileState.default_compile_options load_options = QueryContext.default_load_options @@ -1433,6 +1436,7 @@ class Query( limit(1) """ + _MemoizedSelectEntities._generate_for_statement(self) self._set_entities(entities) @_generative diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index e371442fdd..91e6275250 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -172,13 +172,32 @@ class Load(Generative, LoaderOption): _of_type = None _extra_criteria = () + def process_compile_state_replaced_entities( + self, compile_state, mapper_entities + ): + if not compile_state.compile_options._enable_eagerloads: + return + + # process is being run here so that the options given are validated + # against what the lead entities were, as well as to accommodate + # for the entities having been replaced with equivalents + self._process( + compile_state, + mapper_entities, + not bool(compile_state.current_path), + ) + def process_compile_state(self, compile_state): if not compile_state.compile_options._enable_eagerloads: return - self._process(compile_state, not bool(compile_state.current_path)) + self._process( + compile_state, + compile_state._lead_mapper_entities, + not bool(compile_state.current_path), + ) - def _process(self, compile_state, raiseerr): + def _process(self, compile_state, mapper_entities, raiseerr): is_refresh = compile_state.compile_options._for_refresh_state current_path = compile_state.current_path if current_path: @@ -700,7 +719,7 @@ class _UnboundLoad(Load): state["path"] = tuple(ret) self.__dict__ = state - def _process(self, compile_state, raiseerr): + def _process(self, compile_state, mapper_entities, raiseerr): dedupes = compile_state.attributes["_unbound_load_dedupes"] is_refresh = compile_state.compile_options._for_refresh_state for val in self._to_bind: @@ -709,10 +728,7 @@ class _UnboundLoad(Load): if is_refresh and not val.propagate_to_loaders: continue val._bind_loader( - [ - ent.entity_zero - for ent in compile_state._mapper_entities - ], + [ent.entity_zero for ent in mapper_entities], compile_state.current_path, compile_state.attributes, raiseerr, diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 213f47c409..709106b6b9 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -32,7 +32,6 @@ from .base import NO_ARG from .base import PARSE_AUTOCOMMIT from .base import SingletonConstant from .coercions import _document_text_coercion -from .traversals import _get_children from .traversals import HasCopyInternals from .traversals import MemoizedHasCacheKey from .traversals import NO_CACHE @@ -389,33 +388,6 @@ class ClauseElement( """ return traversals.compare(self, other, **kw) - def get_children(self, omit_attrs=(), **kw): - r"""Return immediate child :class:`.visitors.Traversible` - elements of this :class:`.visitors.Traversible`. - - This is used for visit traversal. - - \**kw may contain flags that change the collection that is - returned, for example to return a subset of items in order to - cut down on larger traversals, or to return child items from a - different context (such as schema-level collections instead of - clause-level). - - """ - try: - traverse_internals = self._traverse_internals - except AttributeError: - # user-defined classes may not have a _traverse_internals - return [] - - return itertools.chain.from_iterable( - meth(obj, **kw) - for attrname, obj, meth in _get_children.run_generated_dispatch( - self, traverse_internals, "_generated_get_children_traversal" - ) - if attrname not in omit_attrs and obj is not None - ) - def self_group(self, against=None): """Apply a 'grouping' to this :class:`_expression.ClauseElement`. diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 1610191d1e..e1dee091bd 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -18,7 +18,9 @@ from operator import attrgetter from . import coercions from . import operators from . import roles +from . import traversals from . import type_api +from . import visitors from .annotation import Annotated from .annotation import SupportsCloneAnnotations from .base import _clone @@ -4131,8 +4133,13 @@ class SelectState(util.MemoizedSlots, CompileState): self.statement = statement self.from_clauses = statement._from_obj + for memoized_entities in statement._memoized_select_entities: + self._setup_joins( + memoized_entities._setup_joins, memoized_entities._raw_columns + ) + if statement._setup_joins: - self._setup_joins(statement._setup_joins) + self._setup_joins(statement._setup_joins, statement._raw_columns) self.froms = self._get_froms(statement) @@ -4361,7 +4368,7 @@ class SelectState(util.MemoizedSlots, CompileState): def all_selected_columns(cls, statement): return [c for c in _select_iterables(statement._raw_columns)] - def _setup_joins(self, args): + def _setup_joins(self, args, raw_columns): for (right, onclause, left, flags) in args: isouter = flags["isouter"] full = flags["full"] @@ -4371,7 +4378,7 @@ class SelectState(util.MemoizedSlots, CompileState): left, replace_from_obj_index, ) = self._join_determine_implicit_left_side( - left, right, onclause + raw_columns, left, right, onclause ) else: (replace_from_obj_index) = self._join_place_explicit_left_side( @@ -4403,7 +4410,9 @@ class SelectState(util.MemoizedSlots, CompileState): ) @util.preload_module("sqlalchemy.sql.util") - def _join_determine_implicit_left_side(self, left, right, onclause): + def _join_determine_implicit_left_side( + self, raw_columns, left, right, onclause + ): """When join conditions don't express the left side explicitly, determine if an existing FROM or entity in this query can serve as the left hand side. @@ -4431,10 +4440,7 @@ class SelectState(util.MemoizedSlots, CompileState): for from_clause in itertools.chain( itertools.chain.from_iterable( - [ - element._from_objects - for element in statement._raw_columns - ] + [element._from_objects for element in raw_columns] ), itertools.chain.from_iterable( [ @@ -4531,6 +4537,47 @@ class _SelectFromElements(object): yield element +class _MemoizedSelectEntities( + traversals.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible +): + __visit_name__ = "memoized_select_entities" + + _traverse_internals = [ + ("_raw_columns", InternalTraversal.dp_clauseelement_list), + ("_setup_joins", InternalTraversal.dp_setup_join_tuple), + ("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple), + ("_with_options", InternalTraversal.dp_executable_options), + ] + + _annotations = util.EMPTY_DICT + + def _clone(self, **kw): + c = self.__class__.__new__(self.__class__) + c.__dict__ = {k: v for k, v in self.__dict__.items()} + c._is_clone_of = self + return c + + @classmethod + def _generate_for_statement(cls, select_stmt): + if ( + select_stmt._setup_joins + or select_stmt._legacy_setup_joins + or select_stmt._with_options + ): + self = _MemoizedSelectEntities() + self._raw_columns = select_stmt._raw_columns + self._setup_joins = select_stmt._setup_joins + self._legacy_setup_joins = select_stmt._legacy_setup_joins + self._with_options = select_stmt._with_options + + select_stmt._memoized_select_entities += (self,) + select_stmt._raw_columns = ( + select_stmt._setup_joins + ) = ( + select_stmt._legacy_setup_joins + ) = select_stmt._with_options = () + + class Select( HasPrefixes, HasSuffixes, @@ -4559,6 +4606,7 @@ class Select( _setup_joins = () _legacy_setup_joins = () + _memoized_select_entities = () _distinct = False _distinct_on = () @@ -4574,6 +4622,10 @@ class Select( _traverse_internals = ( [ ("_raw_columns", InternalTraversal.dp_clauseelement_list), + ( + "_memoized_select_entities", + InternalTraversal.dp_memoized_select_entities, + ), ("_from_obj", InternalTraversal.dp_clauseelement_list), ("_where_criteria", InternalTraversal.dp_clauseelement_tuple), ("_having_criteria", InternalTraversal.dp_clauseelement_tuple), @@ -5461,16 +5513,14 @@ class Select( # is the case for now. self._assert_no_memoizations() - rc = [] - for c in coercions._expression_collection_was_a_list( - "columns", "Select.with_only_columns", columns - ): - c = coercions.expect(roles.ColumnsClauseRole, c) - # TODO: why are we doing this here? - if isinstance(c, ScalarSelect): - c = c.self_group(against=operators.comma_op) - rc.append(c) - self._raw_columns = rc + _MemoizedSelectEntities._generate_for_statement(self) + + self._raw_columns = [ + coercions.expect(roles.ColumnsClauseRole, c) + for c in coercions._expression_collection_was_a_list( + "columns", "Select.with_only_columns", columns + ) + ] @property def whereclause(self): diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 35f2bd62f9..a86d16ef4c 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -194,6 +194,8 @@ class HasCacheKey(object): elif ( meth is InternalTraversal.dp_clauseelement_list or meth is InternalTraversal.dp_clauseelement_tuple + or meth + is InternalTraversal.dp_memoized_select_entities ): result += ( attrname, @@ -409,6 +411,9 @@ class _CacheKey(ExtendedInternalTraversal): visit_clauseelement_list = InternalTraversal.dp_clauseelement_list visit_annotations_key = InternalTraversal.dp_annotations_key visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple + visit_memoized_select_entities = ( + InternalTraversal.dp_memoized_select_entities + ) visit_string = ( visit_boolean @@ -799,6 +804,9 @@ class _CopyInternals(InternalTraversal): for (target, onclause, from_, flags) in element ) + def visit_memoized_select_entities(self, attrname, parent, element, **kw): + return self.visit_clauseelement_tuple(attrname, parent, element, **kw) + def visit_dml_ordered_values( self, attrname, parent, element, clone=_clone, **kw ): @@ -919,6 +927,9 @@ class _GetChildren(InternalTraversal): if onclause is not None and not isinstance(onclause, str): yield _flatten_clauseelement(onclause) + def visit_memoized_select_entities(self, element, **kw): + return self.visit_clauseelement_tuple(element, **kw) + def visit_dml_ordered_values(self, element, **kw): for k, v in element: if hasattr(k, "__clause_element__"): @@ -1265,6 +1276,13 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): self.stack.append((l_onclause, r_onclause)) self.stack.append((l_from, r_from)) + def visit_memoized_select_entities( + self, attrname, left_parent, left, right_parent, right, **kw + ): + return self.visit_clauseelement_tuple( + attrname, left_parent, left, right_parent, right, **kw + ) + def visit_table_hint_list( self, attrname, left_parent, left, right_parent, right, **kw ): diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 93ee8eb1c1..c750c546ad 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -24,6 +24,7 @@ http://techspot.zzzeek.org/2008/01/23/expression-transformations/ . """ from collections import deque +import itertools import operator from .. import exc @@ -119,6 +120,38 @@ class Traversible(util.with_metaclass(TraversibleType)): """ + @util.preload_module("sqlalchemy.sql.traversals") + def get_children(self, omit_attrs=(), **kw): + r"""Return immediate child :class:`.visitors.Traversible` + elements of this :class:`.visitors.Traversible`. + + This is used for visit traversal. + + \**kw may contain flags that change the collection that is + returned, for example to return a subset of items in order to + cut down on larger traversals, or to return child items from a + different context (such as schema-level collections instead of + clause-level). + + """ + + traversals = util.preloaded.sql_traversals + + try: + traverse_internals = self._traverse_internals + except AttributeError: + # user-defined classes may not have a _traverse_internals + return [] + + dispatch = traversals._get_children.run_generated_dispatch + return itertools.chain.from_iterable( + meth(obj, **kw) + for attrname, obj, meth in dispatch( + self, traverse_internals, "_generated_get_children_traversal" + ) + if attrname not in omit_attrs and obj is not None + ) + class _InternalTraversalType(type): def __init__(cls, clsname, bases, clsdict): @@ -393,6 +426,8 @@ class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)): dp_setup_join_tuple = symbol("SJ") + dp_memoized_select_entities = symbol("ME") + dp_statement_hint_list = symbol("SH") """Visit the ``_statement_hints`` collection of a :class:`_expression.Select` diff --git a/test/orm/test_cache_key.py b/test/orm/test_cache_key.py index 67f2d02306..7b6feb96a2 100644 --- a/test/orm/test_cache_key.py +++ b/test/orm/test_cache_key.py @@ -30,6 +30,7 @@ from sqlalchemy.sql.visitors import InternalTraversal from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import mock +from sqlalchemy.testing import ne_ from sqlalchemy.testing.fixtures import fixture_session from test.orm import _fixtures from .inheritance import _poly_fixtures @@ -313,6 +314,111 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest): compare_values=True, ) + def test_orm_query_using_with_entities(self): + """test issue #6503""" + User, Address, Keyword, Order, Item = self.classes( + "User", "Address", "Keyword", "Order", "Item" + ) + + self._run_cache_key_fixture( + lambda: stmt_20( + fixture_session() + .query(User) + .join(User.addresses) + .with_entities(Address.id), + # + fixture_session().query(Address.id).join(User.addresses), + # + fixture_session() + .query(User) + .options(selectinload(User.addresses)) + .with_entities(User.id), + # + fixture_session() + .query(User) + .options(selectinload(User.addresses)), + # + fixture_session().query(User).with_entities(User.id), + # + # here, propagate_attr->orm is Address, entity is Address.id, + # but the join() + with_entities() will log a + # _MemoizedSelectEntities to differentiate + fixture_session() + .query(Address, Order) + .join(Address.dingaling) + .with_entities(Address.id), + # + # same, propagate_attr->orm is Address, entity is Address.id, + # but the join() + with_entities() will log a + # _MemoizedSelectEntities to differentiate + fixture_session() + .query(Address, User) + .join(Address.dingaling) + .with_entities(Address.id), + ), + compare_values=True, + ) + + def test_more_with_entities_sanity_checks(self): + """test issue #6503""" + User, Address, Keyword, Order, Item = self.classes( + "User", "Address", "Keyword", "Order", "Item" + ) + + sess = fixture_session() + + q1 = ( + sess.query(Address, Order) + .with_entities(Address.id) + ._statement_20() + ) + q2 = ( + sess.query(Address, User).with_entities(Address.id)._statement_20() + ) + + assert not q1._memoized_select_entities + assert not q2._memoized_select_entities + + # no joins or options, so q1 and q2 have the same cache key as Order/ + # User are discarded. Note Address is first so propagate_attrs->orm is + # Address. + eq_(q1._generate_cache_key(), q2._generate_cache_key()) + + q3 = sess.query(Order).with_entities(Address.id)._statement_20() + q4 = sess.query(User).with_entities(Address.id)._statement_20() + + # with Order/User as lead entity, this affects propagate_attrs->orm + # so keys are different + ne_(q3._generate_cache_key(), q4._generate_cache_key()) + + # confirm by deleting propagate attrs and memoized key and + # running again + q3._propagate_attrs = None + q4._propagate_attrs = None + del q3.__dict__["_generate_cache_key"] + del q4.__dict__["_generate_cache_key"] + eq_(q3._generate_cache_key(), q4._generate_cache_key()) + + # once there's a join() or options() prior to with_entities, now they + # are not discarded from the key; Order and User are in the + # _MemoizedSelectEntities + q5 = ( + sess.query(Address, Order) + .join(Address.dingaling) + .with_entities(Address.id) + ._statement_20() + ) + q6 = ( + sess.query(Address, User) + .join(Address.dingaling) + .with_entities(Address.id) + ._statement_20() + ) + + assert q5._memoized_select_entities + assert q6._memoized_select_entities + ne_(q5._generate_cache_key(), q6._generate_cache_key()) + def test_orm_query_from_statement(self): User, Address, Keyword, Order, Item = self.classes( "User", "Address", "Keyword", "Order", "Item" diff --git a/test/orm/test_joins.py b/test/orm/test_joins.py index 7f6e1b72ec..25fa7e6615 100644 --- a/test/orm/test_joins.py +++ b/test/orm/test_joins.py @@ -327,6 +327,43 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "JOIN addresses ON users.id = addresses.user_id", ) + @testing.combinations((True,), (False,), argnames="legacy") + @testing.combinations((True,), (False,), argnames="threelevel") + def test_join_with_entities(self, legacy, threelevel): + """test issue #6503""" + + User, Address, Dingaling = self.classes("User", "Address", "Dingaling") + + if legacy: + sess = fixture_session() + stmt = sess.query(User).join(Address).with_entities(Address.id) + else: + stmt = select(User).join(Address).with_only_columns(Address.id) + + stmt = stmt.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + + if threelevel: + if legacy: + stmt = stmt.join(Address.dingaling).with_entities(Dingaling.id) + else: + stmt = stmt.join(Address.dingaling).with_only_columns( + Dingaling.id + ) + + if threelevel: + self.assert_compile( + stmt, + "SELECT dingalings.id AS dingalings_id " + "FROM users JOIN addresses ON users.id = addresses.user_id " + "JOIN dingalings ON addresses.id = dingalings.address_id", + ) + else: + self.assert_compile( + stmt, + "SELECT addresses.id AS addresses_id FROM users " + "JOIN addresses ON users.id = addresses.user_id", + ) + def test_invalid_kwarg_join(self): User = self.classes.User sess = fixture_session() diff --git a/test/orm/test_options.py b/test/orm/test_options.py index 4bef121d91..31ab100fac 100644 --- a/test/orm/test_options.py +++ b/test/orm/test_options.py @@ -3,6 +3,7 @@ from sqlalchemy import Column from sqlalchemy import ForeignKey from sqlalchemy import inspect from sqlalchemy import Integer +from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.orm import aliased @@ -24,6 +25,7 @@ from sqlalchemy.orm import util as orm_util from sqlalchemy.orm import with_polymorphic from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import assert_raises_message +from sqlalchemy.testing.assertions import AssertsCompiledSQL from sqlalchemy.testing.assertions import eq_ from sqlalchemy.testing.fixtures import fixture_session from test.orm import _fixtures @@ -95,7 +97,7 @@ class PathTest(object): val._bind_loader( [ ent.entity_zero - for ent in q._compile_state()._mapper_entities + for ent in q._compile_state()._lead_mapper_entities ], q._compile_options._current_path, attr, @@ -104,7 +106,7 @@ class PathTest(object): else: compile_state = q._compile_state() compile_state.attributes = attr = {} - opt._process(compile_state, True) + opt._process(compile_state, [], True) assert_paths = [k[1] for k in attr] eq_( @@ -401,6 +403,92 @@ class OfTypePathingTest(PathTest, QueryTest): ) +class WithEntitiesTest(QueryTest, AssertsCompiledSQL): + def test_options_legacy_with_entities_onelevel(self): + """test issue #6253 (part of #6503)""" + + User = self.classes.User + sess = fixture_session() + + q = ( + sess.query(User) + .options(joinedload(User.addresses)) + .with_entities(User.id) + ) + self.assert_compile(q, "SELECT users.id AS users_id FROM users") + + def test_options_with_only_cols_onelevel(self): + """test issue #6253 (part of #6503)""" + + User = self.classes.User + + q = ( + select(User) + .options(joinedload(User.addresses)) + .with_only_columns(User.id) + ) + self.assert_compile(q, "SELECT users.id FROM users") + + def test_options_entities_replaced_with_equivs_one(self): + User = self.classes.User + Address = self.classes.Address + + q = ( + select(User, Address) + .options(joinedload(User.addresses)) + .with_only_columns(User) + ) + self.assert_compile( + q, + "SELECT users.id, users.name, addresses_1.id AS id_1, " + "addresses_1.user_id, addresses_1.email_address FROM users " + "LEFT OUTER JOIN addresses AS addresses_1 " + "ON users.id = addresses_1.user_id ORDER BY addresses_1.id", + ) + + def test_options_entities_replaced_with_equivs_two(self): + User = self.classes.User + Address = self.classes.Address + + q = ( + select(User, Address) + .options(joinedload(User.addresses), joinedload(Address.dingaling)) + .with_only_columns(User) + ) + self.assert_compile( + q, + "SELECT users.id, users.name, addresses_1.id AS id_1, " + "addresses_1.user_id, addresses_1.email_address FROM users " + "LEFT OUTER JOIN addresses AS addresses_1 " + "ON users.id = addresses_1.user_id ORDER BY addresses_1.id", + ) + + def test_options_entities_replaced_with_equivs_three(self): + User = self.classes.User + Address = self.classes.Address + + q = ( + select(User) + .options(joinedload(User.addresses)) + .with_only_columns(User, Address) + .options(joinedload(Address.dingaling)) + ) + self.assert_compile( + q, + "SELECT users.id, users.name, addresses.id AS id_1, " + "addresses.user_id, addresses.email_address, " + "addresses_1.id AS id_2, addresses_1.user_id AS user_id_1, " + "addresses_1.email_address AS email_address_1, " + "dingalings_1.id AS id_3, dingalings_1.address_id, " + "dingalings_1.data " + "FROM users LEFT OUTER JOIN addresses AS addresses_1 " + "ON users.id = addresses_1.user_id, addresses " + "LEFT OUTER JOIN dingalings AS dingalings_1 " + "ON addresses.id = dingalings_1.address_id " + "ORDER BY addresses_1.id", + ) + + class OptionsTest(PathTest, QueryTest): def _option_fixture(self, *arg): return strategy_options._UnboundLoad._from_keys( @@ -1479,7 +1567,7 @@ class PickleTest(PathTest, QueryTest): load = opt._bind_loader( [ ent.entity_zero - for ent in query._compile_state()._mapper_entities + for ent in query._compile_state()._lead_mapper_entities ], query._compile_options._current_path, attr, @@ -1516,7 +1604,7 @@ class PickleTest(PathTest, QueryTest): load = opt._bind_loader( [ ent.entity_zero - for ent in query._compile_state()._mapper_entities + for ent in query._compile_state()._lead_mapper_entities ], query._compile_options._current_path, attr, @@ -1560,7 +1648,7 @@ class LocalOptsTest(PathTest, QueryTest): ctx = query._compile_state() for tb in opt._to_bind: tb._bind_loader( - [ent.entity_zero for ent in ctx._mapper_entities], + [ent.entity_zero for ent in ctx._lead_mapper_entities], query._compile_options._current_path, attr, False, @@ -1658,7 +1746,7 @@ class SubOptionsTest(PathTest, QueryTest): val._bind_loader( [ ent.entity_zero - for ent in q._compile_state()._mapper_entities + for ent in q._compile_state()._lead_mapper_entities ], q._compile_options._current_path, attr_a, @@ -1672,7 +1760,7 @@ class SubOptionsTest(PathTest, QueryTest): val._bind_loader( [ ent.entity_zero - for ent in q._compile_state()._mapper_entities + for ent in q._compile_state()._lead_mapper_entities ], q._compile_options._current_path, attr_b, diff --git a/test/profiles.txt b/test/profiles.txt index 3b5b1aca3e..6e6f430a39 100644 --- a/test/profiles.txt +++ b/test/profiles.txt @@ -1,15 +1,15 @@ # /home/classic/dev/sqlalchemy/test/profiles.txt # This file is written out on a per-environment basis. -# For each test in aaa_profiling, the corresponding function and +# For each test in aaa_profiling, the corresponding function and # environment is located within this file. If it doesn't exist, # the test is skipped. -# If a callcount does exist, it is compared to what we received. +# If a callcount does exist, it is compared to what we received. # assertions are raised if the counts do not match. -# -# To add a new callcount test, apply the function_call_count -# decorator and re-run the tests using the --write-profiles +# +# To add a new callcount test, apply the function_call_count +# decorator and re-run the tests using the --write-profiles # option - this file will be rewritten including the new count. -# +# # TEST: test.aaa_profiling.test_compiler.CompileTest.test_insert @@ -240,10 +240,10 @@ test.aaa_profiling.test_orm.AttributeOverheadTest.test_collection_append_remove # TEST: test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching -test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_cextensions 60 -test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_nocextensions 60 -test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_3.9_sqlite_pysqlite_dbapiunicode_cextensions 61 -test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_3.9_sqlite_pysqlite_dbapiunicode_nocextensions 61 +test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_cextensions 68 +test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_nocextensions 68 +test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_3.9_sqlite_pysqlite_dbapiunicode_cextensions 73 +test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_key_bound_branching x86_64_linux_cpython_3.9_sqlite_pysqlite_dbapiunicode_nocextensions 73 # TEST: test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 257776c506..e96a47553b 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -513,6 +513,14 @@ class CoreFixtures(object): func.bernoulli(1), name="bar", seed=func.random() ), ), + lambda: ( + # test issue #6503 + # join from table_a -> table_c, select table_b.c.a + select(table_a).join(table_c).with_only_columns(table_b.c.a), + # join from table_b -> table_c, select table_b.c.a + select(table_b.c.a).join(table_c).with_only_columns(table_b.c.a), + select(table_a).with_only_columns(table_b.c.a), + ), lambda: ( table_a.insert(), table_a.insert().values({})._annotate({"nocache": True}), diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py index 3469dcb372..c7e51c8070 100644 --- a/test/sql/test_external_traversal.py +++ b/test/sql/test_external_traversal.py @@ -1747,6 +1747,29 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): "addresses.user_id", ) + def test_prev_entities_adapt(self): + """test #6503""" + + m = MetaData() + users = Table("users", m, Column("id", Integer, primary_key=True)) + addresses = Table( + "addresses", + m, + Column("id", Integer, primary_key=True), + Column("user_id", ForeignKey("users.id")), + ) + + ualias = users.alias() + + s = select(users).join(addresses).with_only_columns(addresses.c.id) + s = sql_util.ClauseAdapter(ualias).traverse(s) + + self.assert_compile( + s, + "SELECT addresses.id FROM users AS users_1 " + "JOIN addresses ON users_1.id = addresses.user_id", + ) + @testing.combinations((True,), (False,), argnames="use_adapt_from") def test_table_to_alias_1(self, use_adapt_from): t1alias = t1.alias("t1alias") diff --git a/test/sql/test_select.py b/test/sql/test_select.py index f9f1acfa01..d1f9e381f9 100644 --- a/test/sql/test_select.py +++ b/test/sql/test_select.py @@ -266,6 +266,33 @@ class FutureSelectTest(fixtures.TestBase, AssertsCompiledSQL): "ON parent.id = child.parent_id", ) + def test_join_implicit_left_side_wo_cols_onelevel(self): + """test issue #6503""" + stmt = select(parent).join(child).with_only_columns(child.c.id) + + self.assert_compile( + stmt, + "SELECT child.id FROM parent " + "JOIN child ON parent.id = child.parent_id", + ) + + def test_join_implicit_left_side_wo_cols_twolevel(self): + """test issue #6503""" + stmt = ( + select(parent) + .join(child) + .with_only_columns(child.c.id) + .join(grandchild) + .with_only_columns(grandchild.c.id) + ) + + self.assert_compile( + stmt, + "SELECT grandchild.id FROM parent " + "JOIN child ON parent.id = child.parent_id " + "JOIN grandchild ON child.id = grandchild.child_id", + ) + def test_right_nested_inner_join(self): inner = child.join(grandchild)