From: Mike Bayer Date: Wed, 29 Apr 2020 23:46:43 +0000 (-0400) Subject: Improve rendering of core statements w/ ORM elements X-Git-Tag: rel_1_4_0b1~287^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4ecd352a9fbb9dbac7b428fe0f098f665c1f0cb1;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Improve rendering of core statements w/ ORM elements This patch contains a variety of ORM and expression layer tweaks to support ORM constructs in select() statements, without the 1.3.x requiremnt in Query that a full _compile_context() + new select() is needed in order to get a working statement object. Includes such tweaks as the ability to implement aliased class of an aliased class, as we are looking to fully support ACs against subqueries, as well as the ability to access anonymously-labeled ColumnProperty expressions within subqueries by naming the ".key" of the label after the property key. Some tuning to query.join() as well as ORMJoin internals to allow things to work more smoothly. Change-Id: Id810f485c5f7ed971529489b84694e02a3356d6d --- diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index c324276447..1d832e4afa 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -117,8 +117,6 @@ class CursorResultMetaData(ResultMetaData): compiled_statement = context.compiled.statement invoked_statement = context.invoked_statement - # same statement was invoked as the one we cached against, - # return self if compiled_statement is invoked_statement: return self diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index 7ac556dcc3..f95a30fda0 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -228,7 +228,7 @@ class BakedQuery(object): # in 1.4, this is where before_compile() event is # invoked - statement = query._statement_20(orm_results=True) + statement = query._statement_20() # if the query is not safe to cache, we still do everything as though # we did cache it, since the receiver of _bake() assumes subqueryload diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index 32975a9495..7736a1290a 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -401,7 +401,6 @@ Example usage:: from .. import exc from .. import util from ..sql import sqltypes -from ..sql import visitors def compiles(class_, *specs): @@ -456,12 +455,12 @@ def compiles(class_, *specs): def deregister(class_): """Remove all custom compilers associated with a given - :class:`_expression.ClauseElement` type.""" + :class:`_expression.ClauseElement` type. + + """ if hasattr(class_, "_compiler_dispatcher"): - # regenerate default _compiler_dispatch - visitors._generate_compiler_dispatch(class_) - # remove custom directive + class_._compiler_dispatch = class_._original_compiler_dispatch del class_._compiler_dispatcher diff --git a/lib/sqlalchemy/future/selectable.py b/lib/sqlalchemy/future/selectable.py index 74cc13501f..407ec96332 100644 --- a/lib/sqlalchemy/future/selectable.py +++ b/lib/sqlalchemy/future/selectable.py @@ -71,6 +71,10 @@ class Select(_LegacySelect): return self.where(*criteria) + def _exported_columns_iterator(self): + meth = SelectState.get_plugin_class(self).exported_columns_iterator + return meth(self) + def _filter_by_zero(self): if self._setup_joins: meth = SelectState.get_plugin_class( diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 5589f0e0ce..ba30d203bd 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -5,7 +5,6 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php - from . import attributes from . import interfaces from . import loading @@ -27,6 +26,7 @@ from ..sql import expression from ..sql import roles from ..sql import util as sql_util from ..sql import visitors +from ..sql.base import _select_iterables from ..sql.base import CacheableOptions from ..sql.base import CompileState from ..sql.base import Options @@ -90,7 +90,7 @@ class QueryContext(object): self.execution_options = execution_options or _EMPTY_DICT self.bind_arguments = bind_arguments or _EMPTY_DICT self.compile_state = compile_state - self.query = query = compile_state.query + self.query = query = compile_state.select_statement self.session = session self.propagated_loader_options = { @@ -119,10 +119,14 @@ class QueryContext(object): class ORMCompileState(CompileState): + # note this is a dictionary, but the + # default_compile_options._with_polymorphic_adapt_map is a tuple + _with_polymorphic_adapt_map = _EMPTY_DICT + class default_compile_options(CacheableOptions): _cache_key_traversal = [ ("_use_legacy_query_style", InternalTraversal.dp_boolean), - ("_orm_results", InternalTraversal.dp_boolean), + ("_for_statement", InternalTraversal.dp_boolean), ("_bake_ok", InternalTraversal.dp_boolean), ( "_with_polymorphic_adapt_map", @@ -137,8 +141,18 @@ class ORMCompileState(CompileState): ("_for_refresh_state", InternalTraversal.dp_boolean), ] + # set to True by default from Query._statement_20(), to indicate + # the rendered query should look like a legacy ORM query. right + # now this basically indicates we should use tablename_columnname + # style labels. Generally indicates the statement originated + # from a Query object. _use_legacy_query_style = False - _orm_results = True + + # set *only* when we are coming from the Query.statement + # accessor, or a Query-level equivalent such as + # query.subquery(). this supersedes "toplevel". + _for_statement = False + _bake_ok = True _with_polymorphic_adapt_map = () _current_path = _path_registry @@ -149,42 +163,24 @@ class ORMCompileState(CompileState): _set_base_alias = False _for_refresh_state = False - @classmethod - def merge(cls, other): - return cls + other._state_dict() - current_path = _path_registry def __init__(self, *arg, **kw): raise NotImplementedError() - def dispose(self): - self.attributes.clear() - @classmethod def create_for_statement(cls, statement_container, compiler, **kw): - raise NotImplementedError() + """Create a context for a statement given a :class:`.Compiler`. - @classmethod - def _create_for_legacy_query(cls, query, toplevel, for_statement=False): - stmt = query._statement_20(orm_results=not for_statement) - - # this chooses between ORMFromStatementCompileState and - # ORMSelectCompileState. We could also base this on - # query._statement is not None as we have the ORM Query here - # however this is the more general path. - compile_state_cls = CompileState._get_plugin_class_for_plugin( - stmt, "orm" - ) + This method is always invoked in the context of SQLCompiler.process(). - return compile_state_cls._create_for_statement_or_query( - stmt, toplevel, for_statement=for_statement - ) + For a Select object, this would be invoked from + SQLCompiler.visit_select(). For the special FromStatement object used + by Query to indicate "Query.from_statement()", this is called by + FromStatement._compiler_dispatch() that would be called by + SQLCompiler.process(). - @classmethod - def _create_for_statement_or_query( - cls, statement_container, for_statement=False, - ): + """ raise NotImplementedError() @classmethod @@ -266,21 +262,20 @@ class ORMCompileState(CompileState): and ext_info.mapper.persist_selectable not in self._polymorphic_adapters ): - self._mapper_loads_polymorphically_with( - ext_info.mapper, - sql_util.ColumnAdapter( - selectable, ext_info.mapper._equivalent_columns - ), - ) + for mp in ext_info.mapper.iterate_to_root(): + self._mapper_loads_polymorphically_with( + mp, + sql_util.ColumnAdapter(selectable, mp._equivalent_columns), + ) def _mapper_loads_polymorphically_with(self, mapper, adapter): for m2 in mapper._with_polymorphic_mappers or [mapper]: self._polymorphic_adapters[m2] = adapter - for m in m2.iterate_to_root(): + for m in m2.iterate_to_root(): # TODO: redundant ? self._polymorphic_adapters[m.local_table] = adapter -@sql.base.CompileState.plugin_for("orm", "grouping") +@sql.base.CompileState.plugin_for("orm", "orm_from_statement") class ORMFromStatementCompileState(ORMCompileState): _aliased_generations = util.immutabledict() _from_obj_alias = None @@ -294,31 +289,23 @@ class ORMFromStatementCompileState(ORMCompileState): @classmethod def create_for_statement(cls, statement_container, compiler, **kw): - compiler._rewrites_selected_columns = True - toplevel = not compiler.stack - return cls._create_for_statement_or_query( - statement_container, toplevel - ) - @classmethod - def _create_for_statement_or_query( - cls, statement_container, toplevel, for_statement=False, - ): - # from .query import FromStatement - - # assert isinstance(statement_container, FromStatement) + if compiler is not None: + compiler._rewrites_selected_columns = True + toplevel = not compiler.stack + else: + toplevel = True self = cls.__new__(cls) self._primary_entity = None - self.use_orm_style = ( + self.use_legacy_query_style = ( statement_container.compile_options._use_legacy_query_style ) - self.statement_container = self.query = statement_container - self.requested_statement = statement_container.element + self.statement_container = self.select_statement = statement_container + self.requested_statement = statement = statement_container.element self._entities = [] - self._with_polymorphic_adapt_map = {} self._polymorphic_adapters = {} self._no_yield_pers = set() @@ -349,12 +336,6 @@ class ORMFromStatementCompileState(ORMCompileState): self.create_eager_joins = [] self._fallback_from_clauses = [] - self._setup_for_statement() - - return self - - def _setup_for_statement(self): - statement = self.requested_statement if ( isinstance(statement, expression.SelectBase) and not statement._is_textual @@ -392,6 +373,8 @@ class ORMFromStatementCompileState(ORMCompileState): # for entity in self._entities: # entity.setup_compile_state(self) + return self + def _adapt_col_list(self, cols, current_adapter): return cols @@ -401,7 +384,8 @@ class ORMFromStatementCompileState(ORMCompileState): @sql.base.CompileState.plugin_for("orm", "select") class ORMSelectCompileState(ORMCompileState, SelectState): - _joinpath = _joinpoint = util.immutabledict() + _joinpath = _joinpoint = _EMPTY_DICT + _from_obj_alias = None _has_mapper_entities = False @@ -417,77 +401,71 @@ class ORMSelectCompileState(ORMCompileState, SelectState): @classmethod def create_for_statement(cls, statement, compiler, **kw): + """compiler hook, we arrive here from compiler.visit_select() only.""" + if not statement._is_future: return SelectState(statement, compiler, **kw) - toplevel = not compiler.stack + if compiler is not None: + toplevel = not compiler.stack + compiler._rewrites_selected_columns = True + else: + toplevel = True - compiler._rewrites_selected_columns = True + select_statement = statement - orm_state = cls._create_for_statement_or_query( - statement, for_statement=True, toplevel=toplevel - ) - SelectState.__init__(orm_state, orm_state.statement, compiler, **kw) - return orm_state - - @classmethod - def _create_for_statement_or_query( - cls, query, toplevel, for_statement=False, _entities_only=False - ): - assert isinstance(query, future.Select) - - query.compile_options = cls.default_compile_options.merge( - query.compile_options + # if we are a select() that was never a legacy Query, we won't + # have ORM level compile options. + statement.compile_options = cls.default_compile_options.safe_merge( + statement.compile_options ) self = cls.__new__(cls) - self._primary_entity = None - - self.query = query - self.use_orm_style = query.compile_options._use_legacy_query_style + self.select_statement = select_statement - self.select_statement = select_statement = query + # indicates this select() came from Query.statement + self.for_statement = ( + for_statement + ) = select_statement.compile_options._for_statement - if not hasattr(select_statement.compile_options, "_orm_results"): - select_statement.compile_options = cls.default_compile_options - select_statement.compile_options += {"_orm_results": for_statement} - else: - for_statement = not select_statement.compile_options._orm_results + if not for_statement and not toplevel: + # for subqueries, turn off eagerloads. + # if "for_statement" mode is set, Query.subquery() + # would have set this flag to False already if that's what's + # desired + select_statement.compile_options += { + "_enable_eagerloads": False, + } - self.query = query + # generally if we are from Query or directly from a select() + self.use_legacy_query_style = ( + select_statement.compile_options._use_legacy_query_style + ) self._entities = [] - + self._primary_entity = None self._aliased_generations = {} self._polymorphic_adapters = {} self._no_yield_pers = set() # legacy: only for query.with_polymorphic() - self._with_polymorphic_adapt_map = wpam = dict( - select_statement.compile_options._with_polymorphic_adapt_map - ) - if wpam: + if select_statement.compile_options._with_polymorphic_adapt_map: + self._with_polymorphic_adapt_map = dict( + select_statement.compile_options._with_polymorphic_adapt_map + ) self._setup_with_polymorphics() _QueryEntity.to_compile_state(self, select_statement._raw_columns) - if _entities_only: - return self - - self.compile_options = query.compile_options - - # TODO: the name of this flag "for_statement" has to change, - # as it is difficult to distinguish from the "query._statement" use - # case which is something totally different - self.for_statement = for_statement + self.compile_options = select_statement.compile_options # determine label style. we can make different decisions here. # at the moment, trying to see if we can always use DISAMBIGUATE_ONLY # rather than LABEL_STYLE_NONE, and if we can use disambiguate style # for new style ORM selects too. if self.select_statement._label_style is LABEL_STYLE_NONE: - if self.use_orm_style and not for_statement: + if self.use_legacy_query_style and not self.for_statement: self.label_style = LABEL_STYLE_TABLENAME_PLUS_COL else: self.label_style = LABEL_STYLE_DISAMBIGUATE_ONLY @@ -522,129 +500,16 @@ class ORMSelectCompileState(ORMCompileState, SelectState): info.selectable for info in select_statement._from_obj ] + # this is a fairly arbitrary break into a second method, + # so it might be nicer to break up create_for_statement() + # and _setup_for_generate into three or four logical sections self._setup_for_generate() - return self - - @classmethod - def _create_entities_collection(cls, query): - """Creates a partial ORMSelectCompileState that includes - the full collection of _MapperEntity and other _QueryEntity objects. - - Supports a few remaining use cases that are pre-compilation - but still need to gather some of the column / adaption information. - - """ - self = cls.__new__(cls) - - self._entities = [] - self._primary_entity = None - self._aliased_generations = {} - self._polymorphic_adapters = {} - - # legacy: only for query.with_polymorphic() - self._with_polymorphic_adapt_map = wpam = dict( - query.compile_options._with_polymorphic_adapt_map - ) - if wpam: - self._setup_with_polymorphics() + if compiler is not None: + SelectState.__init__(self, self.statement, compiler, **kw) - _QueryEntity.to_compile_state(self, query._raw_columns) return self - @classmethod - def determine_last_joined_entity(cls, statement): - setup_joins = statement._setup_joins - - if not setup_joins: - return None - - (target, onclause, from_, flags) = setup_joins[-1] - - if isinstance(target, interfaces.PropComparator): - return target.entity - else: - return target - - def _setup_with_polymorphics(self): - # legacy: only for query.with_polymorphic() - for ext_info, wp in self._with_polymorphic_adapt_map.items(): - self._mapper_loads_polymorphically_with(ext_info, wp._adapter) - - def _set_select_from_alias(self): - - query = self.select_statement # query - - assert self.compile_options._set_base_alias - assert len(query._from_obj) == 1 - - adapter = self._get_select_from_alias_from_obj(query._from_obj[0]) - if adapter: - self.compile_options += {"_enable_single_crit": False} - self._from_obj_alias = adapter - - def _get_select_from_alias_from_obj(self, from_obj): - info = from_obj - - if "parententity" in info._annotations: - info = info._annotations["parententity"] - - if hasattr(info, "mapper"): - if not info.is_aliased_class: - raise sa_exc.ArgumentError( - "A selectable (FromClause) instance is " - "expected when the base alias is being set." - ) - else: - return info._adapter - - elif isinstance(info.selectable, sql.selectable.AliasedReturnsRows): - equivs = self._all_equivs() - return sql_util.ColumnAdapter(info, equivs) - else: - return None - - def _mapper_zero(self): - """return the Mapper associated with the first QueryEntity.""" - return self._entities[0].mapper - - def _entity_zero(self): - """Return the 'entity' (mapper or AliasedClass) associated - with the first QueryEntity, or alternatively the 'select from' - entity if specified.""" - - for ent in self.from_clauses: - if "parententity" in ent._annotations: - return ent._annotations["parententity"] - for qent in self._entities: - if qent.entity_zero: - return qent.entity_zero - - return None - - def _only_full_mapper_zero(self, methname): - if self._entities != [self._primary_entity]: - raise sa_exc.InvalidRequestError( - "%s() can only be used against " - "a single mapped class." % methname - ) - return self._primary_entity.entity_zero - - def _only_entity_zero(self, rationale=None): - if len(self._entities) > 1: - raise sa_exc.InvalidRequestError( - rationale - or "This operation requires a Query " - "against a single mapper." - ) - return self._entity_zero() - - def _all_equivs(self): - equivs = {} - for ent in self._mapper_entities: - equivs.update(ent.mapper._equivalent_columns) - return equivs - def _setup_for_generate(self): query = self.select_statement @@ -772,6 +637,140 @@ class ORMSelectCompileState(ORMCompileState, SelectState): {"deepentity": ezero} ) + @classmethod + def _create_entities_collection(cls, query): + """Creates a partial ORMSelectCompileState that includes + the full collection of _MapperEntity and other _QueryEntity objects. + + Supports a few remaining use cases that are pre-compilation + but still need to gather some of the column / adaption information. + + """ + self = cls.__new__(cls) + + self._entities = [] + self._primary_entity = None + self._aliased_generations = {} + self._polymorphic_adapters = {} + + # legacy: only for query.with_polymorphic() + if query.compile_options._with_polymorphic_adapt_map: + self._with_polymorphic_adapt_map = dict( + query.compile_options._with_polymorphic_adapt_map + ) + self._setup_with_polymorphics() + + _QueryEntity.to_compile_state(self, query._raw_columns) + return self + + @classmethod + def determine_last_joined_entity(cls, statement): + setup_joins = statement._setup_joins + + if not setup_joins: + return None + + (target, onclause, from_, flags) = setup_joins[-1] + + if isinstance(target, interfaces.PropComparator): + return target.entity + else: + return target + + @classmethod + def exported_columns_iterator(cls, statement): + for element in statement._raw_columns: + if ( + element.is_selectable + and "entity_namespace" in element._annotations + ): + for elem in _select_iterables( + element._annotations["entity_namespace"].columns + ): + yield elem + else: + for elem in _select_iterables([element]): + yield elem + + def _setup_with_polymorphics(self): + # legacy: only for query.with_polymorphic() + for ext_info, wp in self._with_polymorphic_adapt_map.items(): + self._mapper_loads_polymorphically_with(ext_info, wp._adapter) + + def _set_select_from_alias(self): + + query = self.select_statement # query + + assert self.compile_options._set_base_alias + assert len(query._from_obj) == 1 + + adapter = self._get_select_from_alias_from_obj(query._from_obj[0]) + if adapter: + self.compile_options += {"_enable_single_crit": False} + self._from_obj_alias = adapter + + def _get_select_from_alias_from_obj(self, from_obj): + info = from_obj + + if "parententity" in info._annotations: + info = info._annotations["parententity"] + + if hasattr(info, "mapper"): + if not info.is_aliased_class: + raise sa_exc.ArgumentError( + "A selectable (FromClause) instance is " + "expected when the base alias is being set." + ) + else: + return info._adapter + + elif isinstance(info.selectable, sql.selectable.AliasedReturnsRows): + equivs = self._all_equivs() + return sql_util.ColumnAdapter(info, equivs) + else: + return None + + def _mapper_zero(self): + """return the Mapper associated with the first QueryEntity.""" + return self._entities[0].mapper + + def _entity_zero(self): + """Return the 'entity' (mapper or AliasedClass) associated + with the first QueryEntity, or alternatively the 'select from' + entity if specified.""" + + for ent in self.from_clauses: + if "parententity" in ent._annotations: + return ent._annotations["parententity"] + for qent in self._entities: + if qent.entity_zero: + return qent.entity_zero + + return None + + def _only_full_mapper_zero(self, methname): + if self._entities != [self._primary_entity]: + raise sa_exc.InvalidRequestError( + "%s() can only be used against " + "a single mapped class." % methname + ) + return self._primary_entity.entity_zero + + def _only_entity_zero(self, rationale=None): + if len(self._entities) > 1: + raise sa_exc.InvalidRequestError( + rationale + or "This operation requires a Query " + "against a single mapper." + ) + return self._entity_zero() + + def _all_equivs(self): + equivs = {} + for ent in self._mapper_entities: + equivs.update(ent.mapper._equivalent_columns) + return equivs + def _compound_eager_statement(self): # for eager joins present and LIMIT/OFFSET/DISTINCT, # wrap the query inside a select, @@ -920,6 +919,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): statement = Select.__new__(Select) statement._raw_columns = raw_columns statement._from_obj = from_obj + statement._label_style = label_style if where_criteria: @@ -1653,31 +1653,10 @@ class ORMSelectCompileState(ORMCompileState, SelectState): "target." ) - aliased_entity = ( - right_mapper - and not right_is_aliased - and ( - # TODO: there is a reliance here on aliasing occurring - # when we join to a polymorphic mapper that doesn't actually - # need aliasing. When this condition is present, we should - # be able to say mapper_loads_polymorphically_with() - # and render the straight polymorphic selectable. this - # does not appear to be possible at the moment as the - # adapter no longer takes place on the rest of the query - # and it's not clear where that's failing to happen. - ( - right_mapper.with_polymorphic - and isinstance( - right_mapper._with_polymorphic_selectable, - expression.AliasedReturnsRows, - ) - ) - or overlap - # test for overlap: - # orm/inheritance/relationships.py - # SelfReferentialM2MTest - ) - ) + # test for overlap: + # orm/inheritance/relationships.py + # SelfReferentialM2MTest + aliased_entity = right_mapper and not right_is_aliased and overlap if not need_adapter and (create_aliases or aliased_entity): # there are a few places in the ORM that automatic aliasing @@ -1707,7 +1686,30 @@ class ORMSelectCompileState(ORMCompileState, SelectState): self._aliased_generations[aliased_generation] = ( adapter, ) + self._aliased_generations.get(aliased_generation, ()) - + elif ( + not r_info.is_clause_element + and not right_is_aliased + and right_mapper.with_polymorphic + and isinstance( + right_mapper._with_polymorphic_selectable, + expression.AliasedReturnsRows, + ) + ): + # for the case where the target mapper has a with_polymorphic + # set up, ensure an adapter is set up for criteria that works + # against this mapper. Previously, this logic used to + # use the "create_aliases or aliased_entity" case to generate + # an aliased() object, but this creates an alias that isn't + # strictly necessary. + # see test/orm/test_core_compilation.py + # ::RelNaturalAliasedJoinsTest::test_straight + # and similar + self._mapper_loads_polymorphically_with( + right_mapper, + sql_util.ColumnAdapter( + right_mapper.selectable, right_mapper._equivalent_columns, + ), + ) # if the onclause is a ClauseElement, adapt it with any # adapters that are in place right now if isinstance(onclause, expression.ClauseElement): @@ -1755,8 +1757,8 @@ class ORMSelectCompileState(ORMCompileState, SelectState): "offset_clause": self.select_statement._offset_clause, "distinct": self.distinct, "distinct_on": self.distinct_on, - "prefixes": self.query._prefixes, - "suffixes": self.query._suffixes, + "prefixes": self.select_statement._prefixes, + "suffixes": self.select_statement._suffixes, "group_by": self.group_by or None, } @@ -2036,7 +2038,14 @@ class _MapperEntity(_QueryEntity): self._with_polymorphic_mappers = ext_info.with_polymorphic_mappers self._polymorphic_discriminator = ext_info.polymorphic_on - if mapper.with_polymorphic or mapper._requires_row_aliasing: + if ( + mapper.with_polymorphic + # controversy - only if inheriting mapper is also + # polymorphic? + # or (mapper.inherits and mapper.inherits.with_polymorphic) + or mapper.inherits + or mapper._requires_row_aliasing + ): compile_state._create_with_polymorphic_adapter( ext_info, self.selectable ) @@ -2361,7 +2370,7 @@ class _ORMColumnEntity(_ColumnEntity): _entity._post_inspect self.entity_zero = self.entity_zero_or_selectable = ezero = _entity - self.mapper = _entity.mapper + self.mapper = mapper = _entity.mapper if parent_bundle: parent_bundle._entities.append(self) @@ -2373,7 +2382,11 @@ class _ORMColumnEntity(_ColumnEntity): self._extra_entities = (self.expr, self.column) - if self.mapper.with_polymorphic: + if ( + mapper.with_polymorphic + or mapper.inherits + or mapper._requires_row_aliasing + ): compile_state._create_with_polymorphic_adapter( ezero, ezero.selectable ) @@ -2414,6 +2427,7 @@ class _ORMColumnEntity(_ColumnEntity): column = current_adapter(self.column, False) else: column = self.column + ezero = self.entity_zero single_table_crit = self.mapper._single_table_criterion diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 88d01eb0f2..424ed5dfee 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -345,7 +345,7 @@ def load_on_pk_identity( if load_options is None: load_options = QueryContext.default_load_options - compile_options = ORMCompileState.default_compile_options.merge( + compile_options = ORMCompileState.default_compile_options.safe_merge( q.compile_options ) diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 7bfe70c36b..4166e6d2a9 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1743,6 +1743,11 @@ class Mapper( or prop.columns[0] is self.polymorphic_on ) + if isinstance(col, expression.Label): + # new in 1.4, get column property against expressions + # to be addressable in subqueries + col.key = col._key_label = key + self.columns.add(col, key) for col in prop.columns + prop._orig_columns: for col in col.proxy_set: @@ -2282,6 +2287,29 @@ class Mapper( ) ) + def _columns_plus_keys(self, polymorphic_mappers=()): + if polymorphic_mappers: + poly_properties = self._iterate_polymorphic_properties( + polymorphic_mappers + ) + else: + poly_properties = self._polymorphic_properties + + return [ + (prop.key, prop.columns[0]) + for prop in poly_properties + if isinstance(prop, properties.ColumnProperty) + ] + + @HasMemoized.memoized_attribute + def _polymorphic_adapter(self): + if self.with_polymorphic: + return sql_util.ColumnAdapter( + self.selectable, equivalents=self._equivalent_columns + ) + else: + return None + def _iterate_polymorphic_properties(self, mappers=None): """Return an iterator of MapperProperty objects which will render into a SELECT.""" diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 4cf501e3f3..02f0752a54 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -264,6 +264,7 @@ class ColumnProperty(StrategizedProperty): def do_init(self): super(ColumnProperty, self).do_init() + if len(self.columns) > 1 and set(self.parent.primary_key).issuperset( self.columns ): @@ -339,28 +340,51 @@ class ColumnProperty(StrategizedProperty): __slots__ = "__clause_element__", "info", "expressions" + def _orm_annotate_column(self, column): + """annotate and possibly adapt a column to be returned + as the mapped-attribute exposed version of the column. + + The column in this context needs to act as much like the + column in an ORM mapped context as possible, so includes + annotations to give hints to various ORM functions as to + the source entity of this column. It also adapts it + to the mapper's with_polymorphic selectable if one is + present. + + """ + + pe = self._parententity + annotations = { + "entity_namespace": pe, + "parententity": pe, + "parentmapper": pe, + "orm_key": self.prop.key, + } + + col = column + + # for a mapper with polymorphic_on and an adapter, return + # the column against the polymorphic selectable. + # see also orm.util._orm_downgrade_polymorphic_columns + # for the reverse operation. + if self._parentmapper._polymorphic_adapter: + mapper_local_col = col + col = self._parentmapper._polymorphic_adapter.traverse(col) + + # this is a clue to the ORM Query etc. that this column + # was adapted to the mapper's polymorphic_adapter. the + # ORM uses this hint to know which column its adapting. + annotations["adapt_column"] = mapper_local_col + + return col._annotate(annotations)._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": pe} + ) + def _memoized_method___clause_element__(self): if self.adapter: return self.adapter(self.prop.columns[0], self.prop.key) else: - pe = self._parententity - # no adapter, so we aren't aliased - # assert self._parententity is self._parentmapper - return ( - self.prop.columns[0] - ._annotate( - { - "entity_namespace": pe, - "parententity": pe, - "parentmapper": pe, - "orm_key": self.prop.key, - "compile_state_plugin": "orm", - } - ) - ._set_propagate_attrs( - {"compile_state_plugin": "orm", "plugin_subject": pe} - ) - ) + return self._orm_annotate_column(self.prop.columns[0]) def _memoized_attr_info(self): """The .info dictionary for this attribute.""" @@ -384,23 +408,8 @@ class ColumnProperty(StrategizedProperty): for col in self.prop.columns ] else: - # no adapter, so we aren't aliased - # assert self._parententity is self._parentmapper return [ - col._annotate( - { - "parententity": self._parententity, - "parentmapper": self._parententity, - "orm_key": self.prop.key, - "compile_state_plugin": "orm", - } - )._set_propagate_attrs( - { - "compile_state_plugin": "orm", - "plugin_subject": self._parententity, - } - ) - for col in self.prop.columns + self._orm_annotate_column(col) for col in self.prop.columns ] def _fallback_getattr(self, key): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 97a81e30fa..5137f9b1d4 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -360,7 +360,7 @@ class Query( ): # if we don't have legacy top level aliasing features in use # then convert to a future select() directly - stmt = self._statement_20() + stmt = self._statement_20(for_statement=True) else: stmt = self._compile_state(for_statement=True).statement @@ -371,7 +371,24 @@ class Query( return stmt - def _statement_20(self, orm_results=False): + def _final_statement(self, legacy_query_style=True): + """Return the 'final' SELECT statement for this :class:`.Query`. + + This is the Core-only select() that will be rendered by a complete + compilation of this query, and is what .statement used to return + in 1.3. + + This method creates a complete compile state so is fairly expensive. + + """ + + q = self._clone() + + return q._compile_state( + use_legacy_query_style=legacy_query_style + ).statement + + def _statement_20(self, for_statement=False, use_legacy_query_style=True): # TODO: this event needs to be deprecated, as it currently applies # only to ORM query and occurs at this spot that is now more # or less an artificial spot @@ -384,7 +401,10 @@ class Query( self.compile_options += {"_bake_ok": False} compile_options = self.compile_options - compile_options += {"_use_legacy_query_style": True} + compile_options += { + "_for_statement": for_statement, + "_use_legacy_query_style": use_legacy_query_style, + } if self._statement is not None: stmt = FromStatement(self._raw_columns, self._statement) @@ -404,13 +424,16 @@ class Query( compile_options=compile_options, ) - if not orm_results: - stmt.compile_options += {"_orm_results": False} - stmt._propagate_attrs = self._propagate_attrs return stmt - def subquery(self, name=None, with_labels=False, reduce_columns=False): + def subquery( + self, + name=None, + with_labels=False, + reduce_columns=False, + _legacy_core_statement=False, + ): """return the full SELECT statement represented by this :class:`_query.Query`, embedded within an :class:`_expression.Alias`. @@ -436,7 +459,11 @@ class Query( q = self.enable_eagerloads(False) if with_labels: q = q.with_labels() - q = q.statement + + if _legacy_core_statement: + q = q._compile_state(for_statement=True).statement + else: + q = q.statement if reduce_columns: q = q.reduce_columns() @@ -943,7 +970,7 @@ class Query( # tablename_colname style is used which at the moment is asserted # in a lot of unit tests :) - statement = self._statement_20(orm_results=True).apply_labels() + statement = self._statement_20().apply_labels() return db_load_fn( self.session, statement, @@ -1328,13 +1355,13 @@ class Query( self.with_labels() .enable_eagerloads(False) .correlate(None) - .subquery() + .subquery(_legacy_core_statement=True) ._anonymous_fromclause() ) parententity = self._raw_columns[0]._annotations.get("parententity") if parententity: - ac = aliased(parententity, alias=fromclause) + ac = aliased(parententity.mapper, alias=fromclause) q = self._from_selectable(ac) else: q = self._from_selectable(fromclause) @@ -2782,7 +2809,7 @@ class Query( def _iter(self): # new style execution. params = self.load_options._params - statement = self._statement_20(orm_results=True) + statement = self._statement_20() result = self.session.execute( statement, params, @@ -2808,7 +2835,7 @@ class Query( ) def __str__(self): - statement = self._statement_20(orm_results=True) + statement = self._statement_20() try: bind = ( @@ -2879,9 +2906,8 @@ class Query( "for linking ORM results to arbitrary select constructs.", version="1.4", ) - compile_state = ORMCompileState._create_for_legacy_query( - self, toplevel=True - ) + compile_state = self._compile_state(for_statement=False) + context = QueryContext( compile_state, self.session, self.load_options ) @@ -3294,10 +3320,35 @@ class Query( return update_op.rowcount def _compile_state(self, for_statement=False, **kw): - return ORMCompileState._create_for_legacy_query( - self, toplevel=True, for_statement=for_statement, **kw + """Create an out-of-compiler ORMCompileState object. + + The ORMCompileState object is normally created directly as a result + of the SQLCompiler.process() method being handed a Select() + or FromStatement() object that uses the "orm" plugin. This method + provides a means of creating this ORMCompileState object directly + without using the compiler. + + This method is used only for deprecated cases, which include + the .from_self() method for a Query that has multiple levels + of .from_self() in use, as well as the instances() method. It is + also used within the test suite to generate ORMCompileState objects + for test purposes. + + """ + + stmt = self._statement_20(for_statement=for_statement, **kw) + assert for_statement == stmt.compile_options._for_statement + + # this chooses between ORMFromStatementCompileState and + # ORMSelectCompileState. We could also base this on + # query._statement is not None as we have the ORM Query here + # however this is the more general path. + compile_state_cls = ORMCompileState._get_plugin_class_for_plugin( + stmt, "orm" ) + return compile_state_cls.create_for_statement(stmt, None) + def _compile_context(self, for_statement=False): compile_state = self._compile_state(for_statement=for_statement) context = QueryContext(compile_state, self.session, self.load_options) @@ -3311,6 +3362,8 @@ class FromStatement(SelectStatementGrouping, Executable): """ + __visit_name__ = "orm_from_statement" + compile_options = ORMFromStatementCompileState.default_compile_options _compile_state_factory = ORMFromStatementCompileState.create_for_statement @@ -3329,6 +3382,14 @@ class FromStatement(SelectStatementGrouping, Executable): super(FromStatement, self).__init__(element) def _compiler_dispatch(self, compiler, **kw): + + """provide a fixed _compiler_dispatch method. + + This is roughly similar to using the sqlalchemy.ext.compiler + ``@compiles`` extension. + + """ + compile_state = self._compile_state_factory(self, compiler, **kw) toplevel = not compiler.stack diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index e82cd174fc..683f2b9787 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -1170,9 +1170,9 @@ class RelationshipProperty(StrategizedProperty): def __clause_element__(self): adapt_from = self._source_selectable() if self._of_type: - of_type_mapper = inspect(self._of_type).mapper + of_type_entity = inspect(self._of_type) else: - of_type_mapper = None + of_type_entity = None ( pj, @@ -1184,7 +1184,7 @@ class RelationshipProperty(StrategizedProperty): ) = self.property._create_joins( source_selectable=adapt_from, source_polymorphic=True, - of_type_mapper=of_type_mapper, + of_type_entity=of_type_entity, alias_secondary=True, ) if sj is not None: @@ -1311,7 +1311,6 @@ class RelationshipProperty(StrategizedProperty): secondary, target_adapter, ) = self.property._create_joins( - dest_polymorphic=True, dest_selectable=to_selectable, source_selectable=source_selectable, ) @@ -2424,9 +2423,8 @@ class RelationshipProperty(StrategizedProperty): self, source_polymorphic=False, source_selectable=None, - dest_polymorphic=False, dest_selectable=None, - of_type_mapper=None, + of_type_entity=None, alias_secondary=False, ): @@ -2439,9 +2437,17 @@ class RelationshipProperty(StrategizedProperty): if source_polymorphic and self.parent.with_polymorphic: source_selectable = self.parent._with_polymorphic_selectable + if of_type_entity: + dest_mapper = of_type_entity.mapper + if dest_selectable is None: + dest_selectable = of_type_entity.selectable + aliased = True + else: + dest_mapper = self.mapper + if dest_selectable is None: dest_selectable = self.entity.selectable - if dest_polymorphic and self.mapper.with_polymorphic: + if self.mapper.with_polymorphic: aliased = True if self._is_self_referential and source_selectable is None: @@ -2453,8 +2459,6 @@ class RelationshipProperty(StrategizedProperty): ): aliased = True - dest_mapper = of_type_mapper or self.mapper - single_crit = dest_mapper._single_table_criterion aliased = aliased or ( source_selectable is not None diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 626018997a..2b8c384c9a 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -1143,7 +1143,7 @@ class SubqueryLoader(PostLoader): ) = self._get_leftmost(subq_path) orig_query = compile_state.attributes.get( - ("orig_query", SubqueryLoader), compile_state.query + ("orig_query", SubqueryLoader), compile_state.select_statement ) # generate a new Query from the original, then diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index ce37d962e8..85f4f85d1e 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -466,7 +466,7 @@ class AliasedClass(object): def __init__( self, - cls, + mapped_class_or_ac, alias=None, name=None, flat=False, @@ -478,7 +478,9 @@ class AliasedClass(object): use_mapper_path=False, represents_outer_join=False, ): - mapper = _class_to_mapper(cls) + insp = inspection.inspect(mapped_class_or_ac) + mapper = insp.mapper + if alias is None: alias = mapper._with_polymorphic_selectable._anonymous_fromclause( name=name, flat=flat @@ -486,7 +488,7 @@ class AliasedClass(object): self._aliased_insp = AliasedInsp( self, - mapper, + insp, alias, name, with_polymorphic_mappers @@ -617,7 +619,7 @@ class AliasedInsp( def __init__( self, entity, - mapper, + inspected, selectable, name, with_polymorphic_mappers, @@ -627,6 +629,10 @@ class AliasedInsp( adapt_on_names, represents_outer_join, ): + + mapped_class_or_ac = inspected.entity + mapper = inspected.mapper + self._weak_entity = weakref.ref(entity) self.mapper = mapper self.selectable = ( @@ -665,9 +671,12 @@ class AliasedInsp( adapt_on_names=adapt_on_names, anonymize_labels=True, ) + if inspected.is_aliased_class: + self._adapter = inspected._adapter.wrap(self._adapter) self._adapt_on_names = adapt_on_names - self._target = mapper.class_ + self._target = mapped_class_or_ac + # self._target = mapper.class_ # mapped_class_or_ac @property def entity(self): @@ -795,6 +804,21 @@ class AliasedInsp( def _memoized_values(self): return {} + @util.memoized_property + def columns(self): + if self._is_with_polymorphic: + cols_plus_keys = self.mapper._columns_plus_keys( + [ent.mapper for ent in self._with_polymorphic_entities] + ) + else: + cols_plus_keys = self.mapper._columns_plus_keys() + + cols_plus_keys = [ + (key, self._adapt_element(col)) for key, col in cols_plus_keys + ] + + return ColumnCollection(cols_plus_keys) + def _memo(self, key, callable_, *args, **kw): if key in self._memoized_values: return self._memoized_values[key] @@ -1290,8 +1314,7 @@ class _ORMJoin(expression.Join): source_selectable=adapt_from, dest_selectable=adapt_to, source_polymorphic=True, - dest_polymorphic=True, - of_type_mapper=right_info.mapper, + of_type_entity=right_info, alias_secondary=True, ) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 6415d4b370..f143190890 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -522,7 +522,12 @@ class _MetaOptions(type): def __init__(cls, classname, bases, dict_): cls._cache_attrs = tuple( - sorted(d for d in dict_ if not d.startswith("__")) + sorted( + d + for d in dict_ + if not d.startswith("__") + and d not in ("_cache_key_traversal",) + ) ) type.__init__(cls, classname, bases, dict_) @@ -561,6 +566,31 @@ class Options(util.with_metaclass(_MetaOptions)): def _state_dict(cls): return cls._state_dict_const + @classmethod + def safe_merge(cls, other): + d = other._state_dict() + + # only support a merge with another object of our class + # and which does not have attrs that we dont. otherwise + # we risk having state that might not be part of our cache + # key strategy + + if ( + cls is not other.__class__ + and other._cache_attrs + and set(other._cache_attrs).difference(cls._cache_attrs) + ): + raise TypeError( + "other element %r is not empty, is not of type %s, " + "and contains attributes not covered here %r" + % ( + other, + cls, + set(other._cache_attrs).difference(cls._cache_attrs), + ) + ) + return cls + d + class CacheableOptions(Options, HasCacheKey): @hybridmethod diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 287e537242..fa2888a23e 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -878,6 +878,7 @@ class ColumnElement( key = self._proxy_key else: key = name + co = ColumnClause( coercions.expect(roles.TruncatedLabelRole, name) if name_is_truncatable @@ -885,6 +886,7 @@ class ColumnElement( type_=getattr(self, "type", None), _selectable=selectable, ) + co._propagate_attrs = selectable._propagate_attrs co._proxies = [self] if selectable._is_clone_of is not None: @@ -1284,6 +1286,7 @@ class BindParameter(roles.InElementRole, ColumnElement): """ + if required is NO_ARG: required = value is NO_ARG and callable_ is None if value is NO_ARG: @@ -1302,6 +1305,7 @@ class BindParameter(roles.InElementRole, ColumnElement): id(self), re.sub(r"[%\(\) \$]+", "_", key).strip("_") if key is not None + and not isinstance(key, _anonymous_label) else "param", ) ) @@ -4182,16 +4186,27 @@ class Label(roles.LabeledColumnExprRole, ColumnElement): return self.element._from_objects def _make_proxy(self, selectable, name=None, **kw): + name = self.name if not name else name + key, e = self.element._make_proxy( selectable, - name=name if name else self.name, + name=name, disallow_is_literal=True, + name_is_truncatable=isinstance(name, _truncated_label), ) + # TODO: want to remove this assertion at some point. all + # _make_proxy() implementations will give us back the key that + # is our "name" in the first place. based on this we can + # safely return our "self.key" as the key here, to support a new + # case where the key and name are separate. + assert key == self.name + e._propagate_attrs = selectable._propagate_attrs e._proxies.append(self) if self._type is not None: e.type = self._type - return key, e + + return self.key, e class ColumnClause( @@ -4240,7 +4255,7 @@ class ColumnClause( __visit_name__ = "column" _traverse_internals = [ - ("name", InternalTraversal.dp_string), + ("name", InternalTraversal.dp_anon_name), ("type", InternalTraversal.dp_type), ("table", InternalTraversal.dp_clauseelement), ("is_literal", InternalTraversal.dp_boolean), @@ -4410,10 +4425,8 @@ class ColumnClause( def _gen_label(self, name, dedupe_on_key=True): t = self.table - if self.is_literal: return None - elif t is not None and t.named_with_column: if getattr(t, "schema", None): label = t.schema.replace(".", "_") + "_" + t.name + "_" + name diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 170e016a56..d6845e05f7 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -3451,8 +3451,8 @@ class SelectState(util.MemoizedSlots, CompileState): self.columns_plus_names = statement._generate_columns_plus_names(True) def _get_froms(self, statement): - froms = [] seen = set() + froms = [] for item in itertools.chain( itertools.chain.from_iterable( @@ -3474,6 +3474,16 @@ class SelectState(util.MemoizedSlots, CompileState): froms.append(item) seen.update(item._cloned_set) + toremove = set( + itertools.chain.from_iterable( + [_expand_cloned(f._hide_froms) for f in froms] + ) + ) + if toremove: + # filter out to FROM clauses not in the list, + # using a list to maintain ordering + froms = [f for f in froms if f not in toremove] + return froms def _get_display_froms( @@ -3490,16 +3500,6 @@ class SelectState(util.MemoizedSlots, CompileState): froms = self.froms - toremove = set( - itertools.chain.from_iterable( - [_expand_cloned(f._hide_froms) for f in froms] - ) - ) - if toremove: - # filter out to FROM clauses not in the list, - # using a list to maintain ordering - froms = [f for f in froms if f not in toremove] - if self.statement._correlate: to_correlate = self.statement._correlate if to_correlate: @@ -3557,7 +3557,7 @@ class SelectState(util.MemoizedSlots, CompileState): def _memoized_attr__label_resolve_dict(self): with_cols = dict( (c._resolve_label or c._label or c.key, c) - for c in _select_iterables(self.statement._raw_columns) + for c in self.statement._exported_columns_iterator() if c._allow_label_resolve ) only_froms = dict( @@ -3578,6 +3578,10 @@ class SelectState(util.MemoizedSlots, CompileState): else: return None + @classmethod + def exported_columns_iterator(cls, statement): + return _select_iterables(statement._raw_columns) + def _setup_joins(self, args): for (right, onclause, left, flags) in args: isouter = flags["isouter"] @@ -4599,7 +4603,7 @@ class Select( pa = None collection = [] - for c in _select_iterables(self._raw_columns): + for c in self._exported_columns_iterator(): # we use key_label since this name is intended for targeting # within the ColumnCollection only, it's not related to SQL # rendering which always uses column name for SQL label names @@ -4630,7 +4634,7 @@ class Select( return self def _generate_columns_plus_names(self, anon_for_dupe_key): - cols = _select_iterables(self._raw_columns) + cols = self._exported_columns_iterator() # when use_labels is on: # in all cases == if we see the same label name, use _label_anon_label diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index a38088a27b..388097e45a 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -18,6 +18,7 @@ NO_CACHE = util.symbol("no_cache") CACHE_IN_PLACE = util.symbol("cache_in_place") CALL_GEN_CACHE_KEY = util.symbol("call_gen_cache_key") STATIC_CACHE_KEY = util.symbol("static_cache_key") +ANON_NAME = util.symbol("anon_name") def compare(obj1, obj2, **kw): @@ -33,6 +34,7 @@ class HasCacheKey(object): _cache_key_traversal = NO_CACHE __slots__ = () + @util.preload_module("sqlalchemy.sql.elements") def _gen_cache_key(self, anon_map, bindparams): """return an optional cache key. @@ -54,6 +56,8 @@ class HasCacheKey(object): """ + elements = util.preloaded.sql_elements + idself = id(self) if anon_map is not None: @@ -102,6 +106,10 @@ class HasCacheKey(object): result += (attrname, obj) elif meth is STATIC_CACHE_KEY: result += (attrname, obj._static_cache_key) + elif meth is ANON_NAME: + if elements._anonymous_label in obj.__class__.__mro__: + obj = obj.apply_map(anon_map) + result += (attrname, obj) elif meth is CALL_GEN_CACHE_KEY: result += ( attrname, @@ -321,6 +329,7 @@ class _CacheKey(ExtendedInternalTraversal): ) = visit_operator = visit_plain_obj = CACHE_IN_PLACE visit_statement_hint_list = CACHE_IN_PLACE visit_type = STATIC_CACHE_KEY + visit_anon_name = ANON_NAME def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams): return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams)) @@ -387,15 +396,6 @@ class _CacheKey(ExtendedInternalTraversal): attrname, obj, parent, anon_map, bindparams ) - def visit_anon_name(self, attrname, obj, parent, anon_map, bindparams): - from . import elements - - name = obj - if isinstance(name, elements._anonymous_label): - name = name.apply_map(anon_map) - - return (attrname, name) - def visit_fromclause_ordered_set( self, attrname, obj, parent, anon_map, bindparams ): diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 377aa4fe01..e8726000b8 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -822,9 +822,14 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): # is another join or selectable that contains a table which our # selectable derives from, that we want to process return None + elif not isinstance(col, ColumnElement): return None - elif self.include_fn and not self.include_fn(col): + + if "adapt_column" in col._annotations: + col = col._annotations["adapt_column"] + + if self.include_fn and not self.include_fn(col): return None elif self.exclude_fn and self.exclude_fn(col): return None diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 683f545dd0..5de68f504f 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -50,6 +50,13 @@ def _generate_compiler_dispatch(cls): """ visit_name = cls.__visit_name__ + if "_compiler_dispatch" in cls.__dict__: + # class has a fixed _compiler_dispatch() method. + # copy it to "original" so that we can get it back if + # sqlalchemy.ext.compiles overrides it. + cls._original_compiler_dispatch = cls._compiler_dispatch + return + if not isinstance(visit_name, util.compat.string_types): raise exc.InvalidRequestError( "__visit_name__ on class %s must be a string at the class level" @@ -76,7 +83,9 @@ def _generate_compiler_dispatch(cls): + self.__visit_name__ on the visitor, and call it with the same kw params. """ - cls._compiler_dispatch = _compiler_dispatch + cls._compiler_dispatch = ( + cls._original_compiler_dispatch + ) = _compiler_dispatch class TraversibleType(type): diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 79de3c9783..247dbc13c3 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -399,7 +399,7 @@ def reraise(tp, value, tb=None, cause=None): raise_(value, with_traceback=tb, from_=cause) -def with_metaclass(meta, *bases): +def with_metaclass(meta, *bases, **kw): """Create a base class with a metaclass. Drops the middle class upon creation. @@ -414,8 +414,15 @@ def with_metaclass(meta, *bases): def __new__(cls, name, this_bases, d): if this_bases is None: - return type.__new__(cls, name, (), d) - return meta(name, bases, d) + cls = type.__new__(cls, name, (), d) + else: + cls = meta(name, bases, d) + + if hasattr(cls, "__init_subclass__") and hasattr( + cls.__init_subclass__, "__func__" + ): + cls.__init_subclass__.__func__(cls, **kw) + return cls return metaclass("temporary_class", None, {}) diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py index f261bc8119..5dbbc2f5c1 100644 --- a/test/aaa_profiling/test_orm.py +++ b/test/aaa_profiling/test_orm.py @@ -857,7 +857,10 @@ class JoinedEagerLoadTest(fixtures.MappedTest): exec_opts = {} bind_arguments = {} ORMCompileState.orm_pre_session_exec( - sess, compile_state.query, exec_opts, bind_arguments + sess, + compile_state.select_statement, + exec_opts, + bind_arguments, ) r = sess.connection().execute( diff --git a/test/orm/inheritance/_poly_fixtures.py b/test/orm/inheritance/_poly_fixtures.py index 5d23e7801c..da2ad4cdf8 100644 --- a/test/orm/inheritance/_poly_fixtures.py +++ b/test/orm/inheritance/_poly_fixtures.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import create_session from sqlalchemy.orm import mapper from sqlalchemy.orm import polymorphic_union from sqlalchemy.orm import relationship +from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import config from sqlalchemy.testing import fixtures @@ -54,6 +55,8 @@ class _PolymorphicFixtureBase(fixtures.MappedTest, AssertsCompiledSQL): run_setup_mappers = "once" run_deletes = None + label_style = LABEL_STYLE_TABLENAME_PLUS_COL + @classmethod def define_tables(cls, metadata): global people, engineers, managers, boss @@ -427,14 +430,16 @@ class _PolymorphicAliasedJoins(_PolymorphicFixtureBase): person_join = ( people.outerjoin(engineers) .outerjoin(managers) - .select(use_labels=True) - .alias("pjoin") + .select() + ._set_label_style(cls.label_style) + .subquery("pjoin") ) manager_join = ( people.join(managers) .outerjoin(boss) - .select(use_labels=True) - .alias("mjoin") + .select() + ._set_label_style(cls.label_style) + .subquery("mjoin") ) person_with_polymorphic = ([Person, Manager, Engineer], person_join) manager_with_polymorphic = ("*", manager_join) diff --git a/test/orm/inheritance/test_assorted_poly.py b/test/orm/inheritance/test_assorted_poly.py index 62f2097d39..514f4ba76f 100644 --- a/test/orm/inheritance/test_assorted_poly.py +++ b/test/orm/inheritance/test_assorted_poly.py @@ -1342,6 +1342,7 @@ class GenerativeTest(fixtures.MappedTest, AssertsExecutionResults): ) .order_by(Person.name) ) + eq_( list(r), [ diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index 5e832e9345..e38758ee2d 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -1357,6 +1357,7 @@ class EagerTargetingTest(fixtures.MappedTest): bid = b1.id sess.expunge_all() + node = sess.query(B).filter(B.id == bid).all()[0] eq_(node, B(id=1, name="b1", b_data="i")) eq_(node.children[0], B(id=2, name="b2", b_data="l")) diff --git a/test/orm/inheritance/test_polymorphic_rel.py b/test/orm/inheritance/test_polymorphic_rel.py index 5494145078..e7e2530b29 100644 --- a/test/orm/inheritance/test_polymorphic_rel.py +++ b/test/orm/inheritance/test_polymorphic_rel.py @@ -1522,6 +1522,32 @@ class _PolymorphicTestBase(object): expected, ) + def test_self_referential_two_newstyle(self): + # TODO: this is the first test *EVER* of an aliased class of + # an aliased class. we should add many more tests for this. + # new case added in Id810f485c5f7ed971529489b84694e02a3356d6d + sess = create_session() + expected = [(m1, e1), (m1, e2), (m1, b1)] + + p1 = aliased(Person) + p2 = aliased(Person) + stmt = ( + future_select(p1, p2) + .filter(p1.company_id == p2.company_id) + .filter(p1.name == "dogbert") + .filter(p1.person_id > p2.person_id) + ) + subq = stmt.subquery() + + pa1 = aliased(p1, subq) + pa2 = aliased(p2, subq) + + stmt = future_select(pa1, pa2).order_by(pa1.person_id, pa2.person_id) + + eq_( + sess.execute(stmt).unique().all(), expected, + ) + def test_nesting_queries(self): # query.statement places a flag "no_adapt" on the returned # statement. This prevents the polymorphic adaptation in the diff --git a/test/orm/inheritance/test_relationship.py b/test/orm/inheritance/test_relationship.py index ea5b9f96bf..5ba482649f 100644 --- a/test/orm/inheritance/test_relationship.py +++ b/test/orm/inheritance/test_relationship.py @@ -6,6 +6,7 @@ from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.orm import aliased from sqlalchemy.orm import backref +from sqlalchemy.orm import configure_mappers from sqlalchemy.orm import contains_eager from sqlalchemy.orm import create_session from sqlalchemy.orm import joinedload @@ -765,6 +766,48 @@ class SelfReferentialM2MTest(fixtures.MappedTest, AssertsCompiledSQL): "secondary_1.left_id", ) + def test_query_crit_core_workaround(self): + # do a test in the style of orm/test_core_compilation.py + + Child1, Child2 = self.classes.Child1, self.classes.Child2 + secondary = self.tables.secondary + + configure_mappers() + + from sqlalchemy.sql import join + + C1 = aliased(Child1, flat=True) + + # figure out all the things we need to do in Core to make + # the identical query that the ORM renders. + + salias = secondary.alias() + stmt = ( + select([Child2]) + .select_from( + join( + Child2, + salias, + Child2.id.expressions[1] == salias.c.left_id, + ).join(C1, salias.c.right_id == C1.id.expressions[1]) + ) + .where(C1.left_child2 == Child2(id=1)) + ) + + self.assert_compile( + stmt.apply_labels(), + "SELECT parent.id AS parent_id, " + "parent.cls AS parent_cls, child2.id AS child2_id " + "FROM secondary AS secondary_1, " + "parent JOIN child2 ON parent.id = child2.id JOIN secondary AS " + "secondary_2 ON parent.id = secondary_2.left_id JOIN " + "(parent AS parent_1 JOIN child1 AS child1_1 " + "ON parent_1.id = child1_1.id) " + "ON parent_1.id = secondary_2.right_id WHERE " + "parent_1.id = secondary_1.right_id AND :param_1 = " + "secondary_1.left_id", + ) + def test_eager_join(self): Child1, Child2 = self.classes.Child1, self.classes.Child2 sess = create_session() diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py index 61df1d277e..a26d0ae267 100644 --- a/test/orm/test_core_compilation.py +++ b/test/orm/test_core_compilation.py @@ -2,21 +2,30 @@ from sqlalchemy import exc from sqlalchemy import func from sqlalchemy import insert from sqlalchemy import literal_column +from sqlalchemy import or_ from sqlalchemy import testing +from sqlalchemy import util from sqlalchemy.future import select from sqlalchemy.orm import aliased from sqlalchemy.orm import column_property +from sqlalchemy.orm import contains_eager from sqlalchemy.orm import join as orm_join +from sqlalchemy.orm import joinedload from sqlalchemy.orm import mapper +from sqlalchemy.orm import query_expression +from sqlalchemy.orm import relationship from sqlalchemy.orm import Session +from sqlalchemy.orm import with_expression from sqlalchemy.orm import with_polymorphic from sqlalchemy.sql.selectable import Join as core_join +from sqlalchemy.sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY +from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL +from sqlalchemy.testing import eq_ from .inheritance import _poly_fixtures from .test_query import QueryTest - # TODO: # composites / unions, etc. @@ -178,6 +187,344 @@ class JoinTest(QueryTest, AssertsCompiledSQL): ) +class LoadersInSubqueriesTest(QueryTest, AssertsCompiledSQL): + """The Query object calls eanble_eagerloads(False) when you call + .subquery(). With Core select, we don't have that information, we instead + have to look at the "toplevel" flag to know where we are. make sure + the many different combinations that these two objects and still + too many flags at the moment work as expected on the outside. + + """ + + __dialect__ = "default" + + run_setup_mappers = None + + @testing.fixture + def joinedload_fixture(self): + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={"addresses": relationship(Address, lazy="joined")}, + ) + + mapper(Address, addresses) + + return User, Address + + def test_no_joinedload_in_subquery_select_rows(self, joinedload_fixture): + User, Address = joinedload_fixture + + sess = Session() + stmt1 = sess.query(User).subquery() + stmt1 = sess.query(stmt1) + + stmt2 = select(User).subquery() + + stmt2 = select(stmt2) + + expected = ( + "SELECT anon_1.id, anon_1.name FROM " + "(SELECT users.id AS id, users.name AS name " + "FROM users) AS anon_1" + ) + self.assert_compile( + stmt1._final_statement(legacy_query_style=False), expected, + ) + + self.assert_compile(stmt2, expected) + + def test_no_joinedload_in_subquery_select_entity(self, joinedload_fixture): + User, Address = joinedload_fixture + + sess = Session() + stmt1 = sess.query(User).subquery() + ua = aliased(User, stmt1) + stmt1 = sess.query(ua) + + stmt2 = select(User).subquery() + + ua = aliased(User, stmt2) + stmt2 = select(ua) + + expected = ( + "SELECT anon_1.id, anon_1.name, addresses_1.id AS id_1, " + "addresses_1.user_id, addresses_1.email_address FROM " + "(SELECT users.id AS id, users.name AS name FROM users) AS anon_1 " + "LEFT OUTER JOIN addresses AS addresses_1 " + "ON anon_1.id = addresses_1.user_id" + ) + + self.assert_compile( + stmt1._final_statement(legacy_query_style=False), expected, + ) + + self.assert_compile(stmt2, expected) + + # TODO: need to test joinedload options, deferred mappings, deferred + # options. these are all loader options that should *only* have an + # effect on the outermost statement, never a subquery. + + +class ExtraColsTest(QueryTest, AssertsCompiledSQL): + __dialect__ = "default" + + run_setup_mappers = None + + @testing.fixture + def query_expression_fixture(self): + users, User = ( + self.tables.users, + self.classes.User, + ) + + mapper( + User, + users, + properties=util.OrderedDict([("value", query_expression())]), + ) + return User + + @testing.fixture + def column_property_fixture(self): + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties=util.OrderedDict( + [ + ("concat", column_property((users.c.id * 2))), + ( + "count", + column_property( + select(func.count(addresses.c.id)) + .where(users.c.id == addresses.c.user_id,) + .correlate(users) + .scalar_subquery() + ), + ), + ] + ), + ) + + mapper(Address, addresses, properties={"user": relationship(User,)}) + + return User, Address + + @testing.fixture + def plain_fixture(self): + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, users, + ) + + mapper(Address, addresses, properties={"user": relationship(User,)}) + + return User, Address + + def test_no_joinedload_embedded(self, plain_fixture): + User, Address = plain_fixture + + stmt = select(Address).options(joinedload(Address.user)) + + subq = stmt.subquery() + + s2 = select(subq) + + self.assert_compile( + s2, + "SELECT anon_1.id, anon_1.user_id, anon_1.email_address " + "FROM (SELECT addresses.id AS id, addresses.user_id AS " + "user_id, addresses.email_address AS email_address " + "FROM addresses) AS anon_1", + ) + + def test_with_expr_one(self, query_expression_fixture): + User = query_expression_fixture + + stmt = select(User).options( + with_expression(User.value, User.name + "foo") + ) + + self.assert_compile( + stmt, + "SELECT users.name || :name_1 AS anon_1, users.id, " + "users.name FROM users", + ) + + def test_with_expr_two(self, query_expression_fixture): + User = query_expression_fixture + + stmt = select(User.id, User.name, (User.name + "foo").label("foo")) + + subq = stmt.subquery() + u1 = aliased(User, subq) + + stmt = select(u1).options(with_expression(u1.value, subq.c.foo)) + + self.assert_compile( + stmt, + "SELECT anon_1.foo, anon_1.id, anon_1.name FROM " + "(SELECT users.id AS id, users.name AS name, " + "users.name || :name_1 AS foo FROM users) AS anon_1", + ) + + def test_joinedload_outermost(self, plain_fixture): + User, Address = plain_fixture + + stmt = select(Address).options(joinedload(Address.user)) + + # render joined eager loads with stringify + self.assert_compile( + stmt, + "SELECT addresses.id, addresses.user_id, addresses.email_address, " + "users_1.id AS id_1, users_1.name FROM addresses " + "LEFT OUTER JOIN users AS users_1 " + "ON users_1.id = addresses.user_id", + ) + + def test_contains_eager_outermost(self, plain_fixture): + User, Address = plain_fixture + + stmt = ( + select(Address) + .join(Address.user) + .options(contains_eager(Address.user)) + ) + + # render joined eager loads with stringify + self.assert_compile( + stmt, + "SELECT users.id, users.name, addresses.id AS id_1, " + "addresses.user_id, " + "addresses.email_address " + "FROM addresses JOIN users ON users.id = addresses.user_id", + ) + + def test_column_properties(self, column_property_fixture): + """test querying mappings that reference external columns or + selectables.""" + + User, Address = column_property_fixture + + stmt = select(User) + + self.assert_compile( + stmt, + "SELECT users.id * :id_1 AS anon_1, " + "(SELECT count(addresses.id) AS count_1 FROM addresses " + "WHERE users.id = addresses.user_id) AS anon_2, users.id, " + "users.name FROM users", + checkparams={"id_1": 2}, + ) + + def test_column_properties_can_we_use(self, column_property_fixture): + """test querying mappings that reference external columns or + selectables. """ + + # User, Address = column_property_fixture + + # stmt = select(User) + + # TODO: shouldn't we be able to get at count ? + + # stmt = stmt.where(stmt.selected_columns.count > 5) + + # self.assert_compile(stmt, "") + + def test_column_properties_subquery(self, column_property_fixture): + """test querying mappings that reference external columns or + selectables.""" + + User, Address = column_property_fixture + + stmt = select(User) + + # here, the subquery needs to export the columns that include + # the column properties + stmt = select(stmt.subquery()) + + # TODO: shouldnt we be able to get to stmt.subquery().c.count ? + self.assert_compile( + stmt, + "SELECT anon_2.anon_1, anon_2.anon_3, anon_2.id, anon_2.name " + "FROM (SELECT users.id * :id_1 AS anon_1, " + "(SELECT count(addresses.id) AS count_1 FROM addresses " + "WHERE users.id = addresses.user_id) AS anon_3, users.id AS id, " + "users.name AS name FROM users) AS anon_2", + checkparams={"id_1": 2}, + ) + + def test_column_properties_subquery_two(self, column_property_fixture): + """test querying mappings that reference external columns or + selectables.""" + + User, Address = column_property_fixture + + # col properties will retain anonymous labels, however will + # adopt the .key within the subquery collection so they can + # be addressed. + stmt = select(User.id, User.name, User.concat, User.count,) + + subq = stmt.subquery() + # here, the subquery needs to export the columns that include + # the column properties + stmt = select(subq).where(subq.c.concat == "foo") + + self.assert_compile( + stmt, + "SELECT anon_1.id, anon_1.name, anon_1.anon_2, anon_1.anon_3 " + "FROM (SELECT users.id AS id, users.name AS name, " + "users.id * :id_1 AS anon_2, " + "(SELECT count(addresses.id) AS count_1 " + "FROM addresses WHERE users.id = addresses.user_id) AS anon_3 " + "FROM users) AS anon_1 WHERE anon_1.anon_2 = :param_1", + checkparams={"id_1": 2, "param_1": "foo"}, + ) + + def test_column_properties_aliased_subquery(self, column_property_fixture): + """test querying mappings that reference external columns or + selectables.""" + + User, Address = column_property_fixture + + u1 = aliased(User) + stmt = select(u1) + + # here, the subquery needs to export the columns that include + # the column properties + stmt = select(stmt.subquery()) + self.assert_compile( + stmt, + "SELECT anon_2.anon_1, anon_2.anon_3, anon_2.id, anon_2.name " + "FROM (SELECT users_1.id * :id_1 AS anon_1, " + "(SELECT count(addresses.id) AS count_1 FROM addresses " + "WHERE users_1.id = addresses.user_id) AS anon_3, " + "users_1.id AS id, users_1.name AS name " + "FROM users AS users_1) AS anon_2", + checkparams={"id_1": 2}, + ) + + class RelationshipNaturalCompileTest(QueryTest, AssertsCompiledSQL): """test using core join() with relationship attributes. @@ -193,7 +540,6 @@ class RelationshipNaturalCompileTest(QueryTest, AssertsCompiledSQL): __dialect__ = "default" - @testing.fails("need to have of_type() expressions render directly") def test_of_type_implicit_join(self): User, Address = self.classes("User", "Address") @@ -201,7 +547,12 @@ class RelationshipNaturalCompileTest(QueryTest, AssertsCompiledSQL): a1 = aliased(Address) stmt1 = select(u1).where(u1.addresses.of_type(a1)) - stmt2 = Session().query(u1).filter(u1.addresses.of_type(a1)) + stmt2 = ( + Session() + .query(u1) + .filter(u1.addresses.of_type(a1)) + ._final_statement(legacy_query_style=False) + ) expected = ( "SELECT users_1.id, users_1.name FROM users AS users_1, " @@ -260,6 +611,118 @@ class InheritedTest(_poly_fixtures._Polymorphic): run_setup_mappers = "once" +class ExplicitWithPolymorhpicTest( + _poly_fixtures._PolymorphicUnions, AssertsCompiledSQL +): + + __dialect__ = "default" + + default_punion = ( + "(SELECT pjoin.person_id AS person_id, " + "pjoin.company_id AS company_id, " + "pjoin.name AS name, pjoin.type AS type, " + "pjoin.status AS status, pjoin.engineer_name AS engineer_name, " + "pjoin.primary_language AS primary_language, " + "pjoin.manager_name AS manager_name " + "FROM (SELECT engineers.person_id AS person_id, " + "people.company_id AS company_id, people.name AS name, " + "people.type AS type, engineers.status AS status, " + "engineers.engineer_name AS engineer_name, " + "engineers.primary_language AS primary_language, " + "CAST(NULL AS VARCHAR(50)) AS manager_name " + "FROM people JOIN engineers ON people.person_id = engineers.person_id " + "UNION ALL SELECT managers.person_id AS person_id, " + "people.company_id AS company_id, people.name AS name, " + "people.type AS type, managers.status AS status, " + "CAST(NULL AS VARCHAR(50)) AS engineer_name, " + "CAST(NULL AS VARCHAR(50)) AS primary_language, " + "managers.manager_name AS manager_name FROM people " + "JOIN managers ON people.person_id = managers.person_id) AS pjoin) " + "AS anon_1" + ) + + def test_subquery_col_expressions_wpoly_one(self): + Person, Manager, Engineer = self.classes( + "Person", "Manager", "Engineer" + ) + + wp1 = with_polymorphic(Person, [Manager, Engineer]) + + subq1 = select(wp1).subquery() + + wp2 = with_polymorphic(Person, [Engineer, Manager]) + subq2 = select(wp2).subquery() + + # first thing we see, is that when we go through with_polymorphic, + # the entities that get placed into the aliased class go through + # Mapper._mappers_from_spec(), which matches them up to the + # existing Mapper.self_and_descendants collection, meaning, + # the order is the same every time. Assert here that's still + # happening. If a future internal change modifies this assumption, + # that's not necessarily bad, but it would change things. + + eq_( + subq1.c.keys(), + [ + "person_id", + "company_id", + "name", + "type", + "person_id_1", + "status", + "engineer_name", + "primary_language", + "person_id_1", + "status_1", + "manager_name", + ], + ) + eq_( + subq2.c.keys(), + [ + "person_id", + "company_id", + "name", + "type", + "person_id_1", + "status", + "engineer_name", + "primary_language", + "person_id_1", + "status_1", + "manager_name", + ], + ) + + def test_subquery_col_expressions_wpoly_two(self): + Person, Manager, Engineer = self.classes( + "Person", "Manager", "Engineer" + ) + + wp1 = with_polymorphic(Person, [Manager, Engineer]) + + subq1 = select(wp1).subquery() + + stmt = select(subq1).where( + or_( + subq1.c.engineer_name == "dilbert", + subq1.c.manager_name == "dogbert", + ) + ) + + self.assert_compile( + stmt, + "SELECT anon_1.person_id, anon_1.company_id, anon_1.name, " + "anon_1.type, anon_1.person_id AS person_id_1, anon_1.status, " + "anon_1.engineer_name, anon_1.primary_language, " + "anon_1.person_id AS person_id_2, anon_1.status AS status_1, " + "anon_1.manager_name FROM " + "%s WHERE " + "anon_1.engineer_name = :engineer_name_1 " + "OR anon_1.manager_name = :manager_name_1" % (self.default_punion), + ) + + class ImplicitWithPolymorphicTest( _poly_fixtures._PolymorphicUnions, AssertsCompiledSQL ): @@ -310,7 +773,9 @@ class ImplicitWithPolymorphicTest( ) self.assert_compile(stmt, expected) - self.assert_compile(q.statement, expected) + self.assert_compile( + q._final_statement(legacy_query_style=False), expected, + ) def test_select_where_baseclass(self): Person = self.classes.Person @@ -349,7 +814,9 @@ class ImplicitWithPolymorphicTest( ) self.assert_compile(stmt, expected) - self.assert_compile(q.statement, expected) + self.assert_compile( + q._final_statement(legacy_query_style=False), expected, + ) def test_select_where_subclass(self): @@ -397,7 +864,10 @@ class ImplicitWithPolymorphicTest( # in context.py self.assert_compile(stmt, disambiguate_expected) - self.assert_compile(q.statement, disambiguate_expected) + self.assert_compile( + q._final_statement(legacy_query_style=False), + disambiguate_expected, + ) def test_select_where_columns_subclass(self): @@ -436,7 +906,9 @@ class ImplicitWithPolymorphicTest( ) self.assert_compile(stmt, expected) - self.assert_compile(q.statement, expected) + self.assert_compile( + q._final_statement(legacy_query_style=False), expected, + ) class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): @@ -506,15 +978,14 @@ class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): orm_join(Company, Person, Company.employees) ) stmt2 = select(Company).join(Company.employees) - stmt3 = Session().query(Company).join(Company.employees).statement - - # TODO: can't get aliasing to not happen for .join() verion - self.assert_compile( - stmt1, - self.straight_company_to_person_expected.replace( - "pjoin_1", "pjoin" - ), + stmt3 = ( + Session() + .query(Company) + .join(Company.employees) + ._final_statement(legacy_query_style=False) ) + + self.assert_compile(stmt1, self.straight_company_to_person_expected) self.assert_compile(stmt2, self.straight_company_to_person_expected) self.assert_compile(stmt3, self.straight_company_to_person_expected) @@ -532,12 +1003,11 @@ class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): "Company", "Person", "Manager", "Engineer" ) - # TODO: fails - # stmt1 = ( - # select(Company) - # .select_from(orm_join(Company, Person, Company.employees)) - # .where(Person.name == "ed") - # ) + stmt1 = ( + select(Company) + .select_from(orm_join(Company, Person, Company.employees)) + .where(Person.name == "ed") + ) stmt2 = ( select(Company).join(Company.employees).where(Person.name == "ed") @@ -547,20 +1017,10 @@ class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): .query(Company) .join(Company.employees) .filter(Person.name == "ed") - .statement + ._final_statement(legacy_query_style=False) ) - # TODO: more inheriance woes, the first statement doesn't know that - # it loads polymorphically with Person. should we have mappers and - # ORM attributes return their polymorphic entity for - # __clause_element__() ? or should we know to look inside the - # orm_join and find all the entities that are important? it is - # looking like having ORM expressions use their polymoprhic selectable - # will solve a lot but not all of these problems. - - # self.assert_compile(stmt1, self.c_to_p_whereclause) - - # self.assert_compile(stmt1, self.c_to_p_whereclause) + self.assert_compile(stmt1, self.c_to_p_whereclause) self.assert_compile(stmt2, self.c_to_p_whereclause) self.assert_compile(stmt3, self.c_to_p_whereclause) @@ -581,16 +1041,12 @@ class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): .query(Company) .join(Company.employees) .join(Person.paperwork) - .statement + ._final_statement(legacy_query_style=False) ) self.assert_compile(stmt1, self.person_paperwork_expected) - self.assert_compile( - stmt2, self.person_paperwork_expected.replace("pjoin", "pjoin_1") - ) - self.assert_compile( - stmt3, self.person_paperwork_expected.replace("pjoin", "pjoin_1") - ) + self.assert_compile(stmt2, self.person_paperwork_expected) + self.assert_compile(stmt3, self.person_paperwork_expected) def test_wpoly_of_type(self): Company, Person, Manager, Engineer = self.classes( @@ -608,7 +1064,7 @@ class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): Session() .query(Company) .join(Company.employees.of_type(p1)) - .statement + ._final_statement(legacy_query_style=False) ) expected = ( "SELECT companies.company_id, companies.name " @@ -633,7 +1089,11 @@ class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): stmt2 = select(Company).join(p1, Company.employees.of_type(p1)) - stmt3 = s.query(Company).join(Company.employees.of_type(p1)).statement + stmt3 = ( + s.query(Company) + .join(Company.employees.of_type(p1)) + ._final_statement(legacy_query_style=False) + ) expected = ( "SELECT companies.company_id, companies.name FROM companies " @@ -661,7 +1121,7 @@ class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): Session() .query(Company) .join(Company.employees.of_type(p1)) - .statement + ._final_statement(legacy_query_style=False) ) expected = ( @@ -677,9 +1137,12 @@ class RelationshipNaturalInheritedTest(InheritedTest, AssertsCompiledSQL): class RelNaturalAliasedJoinsTest( _poly_fixtures._PolymorphicAliasedJoins, RelationshipNaturalInheritedTest ): + + # this is the label style for the polymorphic selectable, not the + # outside query + label_style = LABEL_STYLE_TABLENAME_PLUS_COL + straight_company_to_person_expected = ( - # TODO: would rather not have the aliasing here but can't fix - # that right now "SELECT companies.company_id, companies.name FROM companies " "JOIN (SELECT people.person_id AS people_person_id, people.company_id " "AS people_company_id, people.name AS people_name, people.type " @@ -691,8 +1154,8 @@ class RelNaturalAliasedJoinsTest( "managers.manager_name AS managers_manager_name FROM people " "LEFT OUTER JOIN engineers ON people.person_id = " "engineers.person_id LEFT OUTER JOIN managers ON people.person_id = " - "managers.person_id) AS pjoin_1 ON companies.company_id = " - "pjoin_1.people_company_id" + "managers.person_id) AS pjoin ON companies.company_id = " + "pjoin.people_company_id" ) person_paperwork_expected = ( @@ -768,8 +1231,8 @@ class RelNaturalAliasedJoinsTest( "FROM people LEFT OUTER JOIN engineers " "ON people.person_id = engineers.person_id " "LEFT OUTER JOIN managers ON people.person_id = managers.person_id) " - "AS pjoin_1 ON companies.company_id = pjoin_1.people_company_id " - "WHERE pjoin_1.people_name = :name_1" + "AS pjoin ON companies.company_id = pjoin.people_company_id " + "WHERE pjoin.people_name = :people_name_1" ) poly_columns = ( @@ -788,6 +1251,113 @@ class RelNaturalAliasedJoinsTest( ) +class RelNaturalAliasedJoinsDisamTest( + _poly_fixtures._PolymorphicAliasedJoins, RelationshipNaturalInheritedTest +): + # this is the label style for the polymorphic selectable, not the + # outside query + label_style = LABEL_STYLE_DISAMBIGUATE_ONLY + + straight_company_to_person_expected = ( + "SELECT companies.company_id, companies.name FROM companies JOIN " + "(SELECT people.person_id AS person_id, " + "people.company_id AS company_id, people.name AS name, " + "people.type AS type, engineers.person_id AS person_id_1, " + "engineers.status AS status, " + "engineers.engineer_name AS engineer_name, " + "engineers.primary_language AS primary_language, " + "managers.person_id AS person_id_2, managers.status AS status_1, " + "managers.manager_name AS manager_name FROM people " + "LEFT OUTER JOIN engineers ON people.person_id = engineers.person_id " + "LEFT OUTER JOIN managers ON people.person_id = managers.person_id) " + "AS pjoin ON companies.company_id = pjoin.company_id" + ) + + person_paperwork_expected = ( + "SELECT companies.company_id, companies.name FROM companies " + "JOIN (SELECT people.person_id AS person_id, people.company_id " + "AS company_id, people.name AS name, people.type AS type, " + "engineers.person_id AS person_id_1, engineers.status AS status, " + "engineers.engineer_name AS engineer_name, " + "engineers.primary_language AS primary_language, managers.person_id " + "AS person_id_2, managers.status AS status_1, managers.manager_name " + "AS manager_name FROM people LEFT OUTER JOIN engineers " + "ON people.person_id = engineers.person_id " + "LEFT OUTER JOIN managers ON people.person_id = managers.person_id) " + "AS pjoin ON companies.company_id = pjoin.company_id " + "JOIN paperwork ON pjoin.person_id = paperwork.person_id" + ) + + default_pjoin = ( + "(SELECT people.person_id AS person_id, people.company_id AS " + "company_id, people.name AS name, people.type AS type, " + "engineers.person_id AS person_id_1, engineers.status AS status, " + "engineers.engineer_name AS engineer_name, engineers.primary_language " + "AS primary_language, managers.person_id AS person_id_2, " + "managers.status AS status_1, managers.manager_name AS manager_name " + "FROM people LEFT OUTER JOIN engineers ON people.person_id = " + "engineers.person_id LEFT OUTER JOIN managers ON people.person_id = " + "managers.person_id) AS pjoin " + "ON companies.company_id = pjoin.company_id" + ) + flat_aliased_pjoin = ( + "(SELECT people.person_id AS person_id, people.company_id AS " + "company_id, people.name AS name, people.type AS type, " + "engineers.person_id AS person_id_1, engineers.status AS status, " + "engineers.engineer_name AS engineer_name, " + "engineers.primary_language AS primary_language, " + "managers.person_id AS person_id_2, managers.status AS status_1, " + "managers.manager_name AS manager_name FROM people " + "LEFT OUTER JOIN engineers ON people.person_id = engineers.person_id " + "LEFT OUTER JOIN managers ON people.person_id = managers.person_id) " + "AS pjoin_1 ON companies.company_id = pjoin_1.company_id" + ) + + aliased_pjoin = ( + "(SELECT people.person_id AS person_id, people.company_id AS " + "company_id, people.name AS name, people.type AS type, " + "engineers.person_id AS person_id_1, engineers.status AS status, " + "engineers.engineer_name AS engineer_name, engineers.primary_language " + "AS primary_language, managers.person_id AS person_id_2, " + "managers.status AS status_1, managers.manager_name AS manager_name " + "FROM people LEFT OUTER JOIN engineers ON people.person_id = " + "engineers.person_id LEFT OUTER JOIN managers ON people.person_id = " + "managers.person_id) AS pjoin_1 " + "ON companies.company_id = pjoin_1.company_id" + ) + + c_to_p_whereclause = ( + "SELECT companies.company_id, companies.name FROM companies JOIN " + "(SELECT people.person_id AS person_id, " + "people.company_id AS company_id, people.name AS name, " + "people.type AS type, engineers.person_id AS person_id_1, " + "engineers.status AS status, " + "engineers.engineer_name AS engineer_name, " + "engineers.primary_language AS primary_language, " + "managers.person_id AS person_id_2, managers.status AS status_1, " + "managers.manager_name AS manager_name FROM people " + "LEFT OUTER JOIN engineers ON people.person_id = engineers.person_id " + "LEFT OUTER JOIN managers ON people.person_id = managers.person_id) " + "AS pjoin ON companies.company_id = pjoin.company_id " + "WHERE pjoin.name = :name_1" + ) + + poly_columns = ( + "SELECT pjoin.person_id FROM (SELECT people.person_id AS " + "person_id, people.company_id AS company_id, people.name AS name, " + "people.type AS type, engineers.person_id AS person_id_1, " + "engineers.status AS status, " + "engineers.engineer_name AS engineer_name, " + "engineers.primary_language AS primary_language, " + "managers.person_id AS person_id_2, " + "managers.status AS status_1, managers.manager_name AS manager_name " + "FROM people LEFT OUTER JOIN engineers " + "ON people.person_id = engineers.person_id " + "LEFT OUTER JOIN managers " + "ON people.person_id = managers.person_id) AS pjoin" + ) + + class RawSelectTest(QueryTest, AssertsCompiledSQL): """older tests from test_query. Here, they are converted to use future selects with ORM compilation. @@ -808,7 +1378,12 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): User = self.classes.User stmt1 = select(User).where(User.addresses) - stmt2 = Session().query(User).filter(User.addresses).statement + stmt2 = ( + Session() + .query(User) + .filter(User.addresses) + ._final_statement(legacy_query_style=False) + ) expected = ( "SELECT users.id, users.name FROM users, addresses " @@ -829,7 +1404,12 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): ) stmt1 = select(Item).where(Item.keywords) - stmt2 = Session().query(Item).filter(Item.keywords).statement + stmt2 = ( + Session() + .query(Item) + .filter(Item.keywords) + ._final_statement(legacy_query_style=False) + ) self.assert_compile(stmt1, expected) self.assert_compile(stmt2, expected) @@ -839,7 +1419,10 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): expected = "SELECT * FROM users" stmt1 = select(literal_column("*")).select_from(User) stmt2 = ( - Session().query(literal_column("*")).select_from(User).statement + Session() + .query(literal_column("*")) + .select_from(User) + ._final_statement(legacy_query_style=False) ) self.assert_compile(stmt1, expected) @@ -850,7 +1433,12 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): ua = aliased(User, name="ua") stmt1 = select(literal_column("*")).select_from(ua) - stmt2 = Session().query(literal_column("*")).select_from(ua) + stmt2 = ( + Session() + .query(literal_column("*")) + .select_from(ua) + ._final_statement(legacy_query_style=False) + ) expected = "SELECT * FROM users AS ua" @@ -886,7 +1474,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): .correlate(User) .scalar_subquery(), ) - .statement + ._final_statement(legacy_query_style=False) ) self.assert_compile(stmt1, expected) @@ -916,7 +1504,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): .correlate(uu) .scalar_subquery(), ) - .statement + ._final_statement(legacy_query_style=False) ) expected = ( @@ -935,7 +1523,9 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): expected = "SELECT users.id, users.name FROM users" stmt1 = select(User) - stmt2 = Session().query(User).statement + stmt2 = ( + Session().query(User)._final_statement(legacy_query_style=False) + ) self.assert_compile(stmt1, expected) self.assert_compile(stmt2, expected) @@ -946,7 +1536,11 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): expected = "SELECT users.id, users.name FROM users" stmt1 = select(User.id, User.name) - stmt2 = Session().query(User.id, User.name).statement + stmt2 = ( + Session() + .query(User.id, User.name) + ._final_statement(legacy_query_style=False) + ) self.assert_compile(stmt1, expected) self.assert_compile(stmt2, expected) @@ -956,7 +1550,11 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): ua = aliased(User, name="ua") stmt1 = select(ua.id, ua.name) - stmt2 = Session().query(ua.id, ua.name).statement + stmt2 = ( + Session() + .query(ua.id, ua.name) + ._final_statement(legacy_query_style=False) + ) expected = "SELECT ua.id, ua.name FROM users AS ua" self.assert_compile(stmt1, expected) @@ -967,7 +1565,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): ua = aliased(User, name="ua") stmt1 = select(ua) - stmt2 = Session().query(ua).statement + stmt2 = Session().query(ua)._final_statement(legacy_query_style=False) expected = "SELECT ua.id, ua.name FROM users AS ua" self.assert_compile(stmt1, expected) @@ -1081,7 +1679,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): .query(Foo) .filter(Foo.foob == "somename") .order_by(Foo.foob) - .statement + ._final_statement(legacy_query_style=False) ) expected = ( diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py index ce687fdeec..1fde343d8c 100644 --- a/test/orm/test_froms.py +++ b/test/orm/test_froms.py @@ -2478,6 +2478,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): "email_address" ), ).group_by(Address.user_id) + ag1 = aliased(Address, agg_address.subquery()) ag2 = aliased(Address, agg_address.subquery(), adapt_on_names=True) diff --git a/test/orm/test_joins.py b/test/orm/test_joins.py index 300670a701..fe3f2a7214 100644 --- a/test/orm/test_joins.py +++ b/test/orm/test_joins.py @@ -17,7 +17,6 @@ from sqlalchemy import true from sqlalchemy.engine import default from sqlalchemy.orm import aliased from sqlalchemy.orm import backref -from sqlalchemy.orm import configure_mappers from sqlalchemy.orm import create_session from sqlalchemy.orm import join from sqlalchemy.orm import joinedload @@ -33,283 +32,15 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing.schema import Column from test.orm import _fixtures +from .inheritance import _poly_fixtures +from .test_query import QueryTest -class QueryTest(_fixtures.FixtureTest): +class InheritedTest(_poly_fixtures._Polymorphic): run_setup_mappers = "once" - run_inserts = "once" - run_deletes = None - - @classmethod - def setup_mappers(cls): - ( - Node, - composite_pk_table, - users, - Keyword, - items, - Dingaling, - order_items, - item_keywords, - Item, - User, - dingalings, - Address, - keywords, - CompositePk, - nodes, - Order, - orders, - addresses, - ) = ( - cls.classes.Node, - cls.tables.composite_pk_table, - cls.tables.users, - cls.classes.Keyword, - cls.tables.items, - cls.classes.Dingaling, - cls.tables.order_items, - cls.tables.item_keywords, - cls.classes.Item, - cls.classes.User, - cls.tables.dingalings, - cls.classes.Address, - cls.tables.keywords, - cls.classes.CompositePk, - cls.tables.nodes, - cls.classes.Order, - cls.tables.orders, - cls.tables.addresses, - ) - - mapper( - User, - users, - properties={ - "addresses": relationship( - Address, backref="user", order_by=addresses.c.id - ), - # o2m, m2o - "orders": relationship( - Order, backref="user", order_by=orders.c.id - ), - }, - ) - mapper( - Address, - addresses, - properties={ - # o2o - "dingaling": relationship( - Dingaling, uselist=False, backref="address" - ) - }, - ) - mapper(Dingaling, dingalings) - mapper( - Order, - orders, - properties={ - # m2m - "items": relationship( - Item, secondary=order_items, order_by=items.c.id - ), - "address": relationship(Address), # m2o - }, - ) - mapper( - Item, - items, - properties={ - "keywords": relationship( - Keyword, secondary=item_keywords - ) # m2m - }, - ) - mapper(Keyword, keywords) - - mapper( - Node, - nodes, - properties={ - "children": relationship( - Node, backref=backref("parent", remote_side=[nodes.c.id]) - ) - }, - ) - - mapper(CompositePk, composite_pk_table) - - configure_mappers() - - -class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL): - run_setup_mappers = "once" - - @classmethod - def define_tables(cls, metadata): - Table( - "companies", - metadata, - Column( - "company_id", - Integer, - primary_key=True, - test_needs_autoincrement=True, - ), - Column("name", String(50)), - ) - - Table( - "people", - metadata, - Column( - "person_id", - Integer, - primary_key=True, - test_needs_autoincrement=True, - ), - Column("company_id", Integer, ForeignKey("companies.company_id")), - Column("name", String(50)), - Column("type", String(30)), - ) - - Table( - "engineers", - metadata, - Column( - "person_id", - Integer, - ForeignKey("people.person_id"), - primary_key=True, - ), - Column("status", String(30)), - Column("engineer_name", String(50)), - Column("primary_language", String(50)), - ) - Table( - "machines", - metadata, - Column( - "machine_id", - Integer, - primary_key=True, - test_needs_autoincrement=True, - ), - Column("name", String(50)), - Column("engineer_id", Integer, ForeignKey("engineers.person_id")), - ) - - Table( - "managers", - metadata, - Column( - "person_id", - Integer, - ForeignKey("people.person_id"), - primary_key=True, - ), - Column("status", String(30)), - Column("manager_name", String(50)), - ) - - Table( - "boss", - metadata, - Column( - "boss_id", - Integer, - ForeignKey("managers.person_id"), - primary_key=True, - ), - Column("golf_swing", String(30)), - ) - - Table( - "paperwork", - metadata, - Column( - "paperwork_id", - Integer, - primary_key=True, - test_needs_autoincrement=True, - ), - Column("description", String(50)), - Column("person_id", Integer, ForeignKey("people.person_id")), - ) - - @classmethod - def setup_classes(cls): - paperwork, people, companies, boss, managers, machines, engineers = ( - cls.tables.paperwork, - cls.tables.people, - cls.tables.companies, - cls.tables.boss, - cls.tables.managers, - cls.tables.machines, - cls.tables.engineers, - ) - - class Company(cls.Comparable): - pass - - class Person(cls.Comparable): - pass - - class Engineer(Person): - pass - - class Manager(Person): - pass - - class Boss(Manager): - pass - - class Machine(cls.Comparable): - pass - - class Paperwork(cls.Comparable): - pass - - mapper( - Company, - companies, - properties={ - "employees": relationship(Person, order_by=people.c.person_id) - }, - ) - - mapper(Machine, machines) - - mapper( - Person, - people, - polymorphic_on=people.c.type, - polymorphic_identity="person", - properties={ - "paperwork": relationship( - Paperwork, order_by=paperwork.c.paperwork_id - ) - }, - ) - mapper( - Engineer, - engineers, - inherits=Person, - polymorphic_identity="engineer", - properties={ - "machines": relationship( - Machine, order_by=machines.c.machine_id - ) - }, - ) - mapper( - Manager, managers, inherits=Person, polymorphic_identity="manager" - ) - mapper(Boss, boss, inherits=Manager, polymorphic_identity="boss") - mapper(Paperwork, paperwork) +class InheritedJoinTest(InheritedTest, AssertsCompiledSQL): def test_single_prop(self): Company = self.classes.Company diff --git a/test/orm/test_of_type.py b/test/orm/test_of_type.py index 82930f754f..daac38dc23 100644 --- a/test/orm/test_of_type.py +++ b/test/orm/test_of_type.py @@ -54,7 +54,7 @@ class _PolymorphicTestBase(object): def test_any_four(self): sess = Session() - any_ = Company.employees.of_type(Boss).any( + any_ = Company.employees.of_type(Manager).any( Manager.manager_name == "pointy" ) eq_(sess.query(Company).filter(any_).one(), self.c1) diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 76706b37b8..478fc71476 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -16,7 +16,6 @@ from sqlalchemy import exc as sa_exc from sqlalchemy import exists from sqlalchemy import ForeignKey from sqlalchemy import func -from sqlalchemy import insert from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import literal @@ -553,227 +552,6 @@ class BindSensitiveStringifyTest(fixtures.TestBase): self._test(True, True, True, True) -class RawSelectTest(QueryTest, AssertsCompiledSQL): - __dialect__ = "default" - - def test_select_from_entity(self): - User = self.classes.User - - self.assert_compile( - select(["*"]).select_from(User), "SELECT * FROM users" - ) - - def test_where_relationship(self): - User = self.classes.User - - self.assert_compile( - select([User]).where(User.addresses), - "SELECT users.id, users.name FROM users, addresses " - "WHERE users.id = addresses.user_id", - ) - - def test_where_m2m_relationship(self): - Item = self.classes.Item - - self.assert_compile( - select([Item]).where(Item.keywords), - "SELECT items.id, items.description FROM items, " - "item_keywords AS item_keywords_1, keywords " - "WHERE items.id = item_keywords_1.item_id " - "AND keywords.id = item_keywords_1.keyword_id", - ) - - def test_inline_select_from_entity(self): - User = self.classes.User - - self.assert_compile( - select(["*"], from_obj=User), "SELECT * FROM users" - ) - - def test_select_from_aliased_entity(self): - User = self.classes.User - ua = aliased(User, name="ua") - self.assert_compile( - select(["*"]).select_from(ua), "SELECT * FROM users AS ua" - ) - - def test_correlate_entity(self): - User = self.classes.User - Address = self.classes.Address - - self.assert_compile( - select( - [ - User.name, - Address.id, - select([func.count(Address.id)]) - .where(User.id == Address.user_id) - .correlate(User) - .scalar_subquery(), - ] - ), - "SELECT users.name, addresses.id, " - "(SELECT count(addresses.id) AS count_1 " - "FROM addresses WHERE users.id = addresses.user_id) AS anon_1 " - "FROM users, addresses", - ) - - def test_correlate_aliased_entity(self): - User = self.classes.User - Address = self.classes.Address - uu = aliased(User, name="uu") - - self.assert_compile( - select( - [ - uu.name, - Address.id, - select([func.count(Address.id)]) - .where(uu.id == Address.user_id) - .correlate(uu) - .scalar_subquery(), - ] - ), - # for a long time, "uu.id = address.user_id" was reversed; - # this was resolved as of #2872 and had to do with - # InstrumentedAttribute.__eq__() taking precedence over - # QueryableAttribute.__eq__() - "SELECT uu.name, addresses.id, " - "(SELECT count(addresses.id) AS count_1 " - "FROM addresses WHERE uu.id = addresses.user_id) AS anon_1 " - "FROM users AS uu, addresses", - ) - - def test_columns_clause_entity(self): - User = self.classes.User - - self.assert_compile( - select([User]), "SELECT users.id, users.name FROM users" - ) - - def test_columns_clause_columns(self): - User = self.classes.User - - self.assert_compile( - select([User.id, User.name]), - "SELECT users.id, users.name FROM users", - ) - - def test_columns_clause_aliased_columns(self): - User = self.classes.User - ua = aliased(User, name="ua") - self.assert_compile( - select([ua.id, ua.name]), "SELECT ua.id, ua.name FROM users AS ua" - ) - - def test_columns_clause_aliased_entity(self): - User = self.classes.User - ua = aliased(User, name="ua") - self.assert_compile( - select([ua]), "SELECT ua.id, ua.name FROM users AS ua" - ) - - def test_core_join(self): - User = self.classes.User - Address = self.classes.Address - from sqlalchemy.sql import join - - self.assert_compile( - select([User]).select_from(join(User, Address)), - "SELECT users.id, users.name FROM users " - "JOIN addresses ON users.id = addresses.user_id", - ) - - def test_insert_from_query(self): - User = self.classes.User - Address = self.classes.Address - - s = Session() - q = s.query(User.id, User.name).filter_by(name="ed") - self.assert_compile( - insert(Address).from_select(("id", "email_address"), q), - "INSERT INTO addresses (id, email_address) " - "SELECT users.id AS users_id, users.name AS users_name " - "FROM users WHERE users.name = :name_1", - ) - - def test_insert_from_query_col_attr(self): - User = self.classes.User - Address = self.classes.Address - - s = Session() - q = s.query(User.id, User.name).filter_by(name="ed") - self.assert_compile( - insert(Address).from_select( - (Address.id, Address.email_address), q - ), - "INSERT INTO addresses (id, email_address) " - "SELECT users.id AS users_id, users.name AS users_name " - "FROM users WHERE users.name = :name_1", - ) - - def test_update_from_entity(self): - from sqlalchemy.sql import update - - User = self.classes.User - self.assert_compile( - update(User), "UPDATE users SET id=:id, name=:name" - ) - - self.assert_compile( - update(User).values(name="ed").where(User.id == 5), - "UPDATE users SET name=:name WHERE users.id = :id_1", - checkparams={"id_1": 5, "name": "ed"}, - ) - - def test_delete_from_entity(self): - from sqlalchemy.sql import delete - - User = self.classes.User - self.assert_compile(delete(User), "DELETE FROM users") - - self.assert_compile( - delete(User).where(User.id == 5), - "DELETE FROM users WHERE users.id = :id_1", - checkparams={"id_1": 5}, - ) - - def test_insert_from_entity(self): - from sqlalchemy.sql import insert - - User = self.classes.User - self.assert_compile( - insert(User), "INSERT INTO users (id, name) VALUES (:id, :name)" - ) - - self.assert_compile( - insert(User).values(name="ed"), - "INSERT INTO users (name) VALUES (:name)", - checkparams={"name": "ed"}, - ) - - def test_col_prop_builtin_function(self): - class Foo(object): - pass - - mapper( - Foo, - self.tables.users, - properties={ - "foob": column_property( - func.coalesce(self.tables.users.c.name) - ) - }, - ) - - self.assert_compile( - select([Foo]).where(Foo.foob == "somename").order_by(Foo.foob), - "SELECT users.id, users.name FROM users " - "WHERE coalesce(users.name) = :param_1 " - "ORDER BY coalesce(users.name)", - ) - - class GetTest(QueryTest): def test_get_composite_pk_keyword_based_no_result(self): CompositePk = self.classes.CompositePk diff --git a/test/orm/test_utils.py b/test/orm/test_utils.py index 5e3f516061..1d98826782 100644 --- a/test/orm/test_utils.py +++ b/test/orm/test_utils.py @@ -220,7 +220,6 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): "parententity": point_mapper, "parentmapper": point_mapper, "orm_key": "x_alone", - "compile_state_plugin": "orm", }, ) eq_( @@ -230,7 +229,6 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): "parententity": point_mapper, "parentmapper": point_mapper, "orm_key": "x", - "compile_state_plugin": "orm", }, ) diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index d3d21cb0e3..2d84ab6764 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -661,6 +661,56 @@ class CoreFixtures(object): fixtures.append(_statements_w_context_options_fixtures) + def _statements_w_anonymous_col_names(): + def one(): + c = column("q") + + l = c.label(None) + + # new case as of Id810f485c5f7ed971529489b84694e02a3356d6d + subq = select([l]).subquery() + + # this creates a ColumnClause as a proxy to the Label() that has + # an anoymous name, so the column has one too. + anon_col = subq.c[0] + + # then when BindParameter is created, it checks the label + # and doesn't double up on the anonymous name which is uncachable + return anon_col > 5 + + def two(): + c = column("p") + + l = c.label(None) + + # new case as of Id810f485c5f7ed971529489b84694e02a3356d6d + subq = select([l]).subquery() + + # this creates a ColumnClause as a proxy to the Label() that has + # an anoymous name, so the column has one too. + anon_col = subq.c[0] + + # then when BindParameter is created, it checks the label + # and doesn't double up on the anonymous name which is uncachable + return anon_col > 5 + + def three(): + + l1, l2 = table_a.c.a.label(None), table_a.c.b.label(None) + + stmt = select([table_a.c.a, table_a.c.b, l1, l2]) + + subq = stmt.subquery() + return select([subq]).where(subq.c[2] == 10) + + return ( + one(), + two(), + three(), + ) + + fixtures.append(_statements_w_anonymous_col_names) + class CacheKeyFixture(object): def _run_cache_key_fixture(self, fixture, compare_values): diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 20f31ba1e2..4c87c0a469 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -71,6 +71,7 @@ from sqlalchemy.engine import default from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import column from sqlalchemy.sql import compiler +from sqlalchemy.sql import elements from sqlalchemy.sql import label from sqlalchemy.sql import operators from sqlalchemy.sql import table @@ -3294,6 +3295,29 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase): checkparams={"3foo_1": "foo", "4_foo_1": "bar"}, ) + def test_bind_given_anon_name_dont_double(self): + c = column("id") + l = c.label(None) + + # new case as of Id810f485c5f7ed971529489b84694e02a3356d6d + subq = select([l]).subquery() + + # this creates a ColumnClause as a proxy to the Label() that has + # an anoymous name, so the column has one too. + anon_col = subq.c[0] + assert isinstance(anon_col.name, elements._anonymous_label) + + # then when BindParameter is created, it checks the label + # and doesn't double up on the anonymous name which is uncachable + expr = anon_col > 5 + + self.assert_compile( + expr, "anon_1.id_1 > :param_1", checkparams={"param_1": 5} + ) + + # see also test_compare.py -> _statements_w_anonymous_col_names + # fixture for cache key + def test_bind_as_col(self): t = table("foo", column("id")) diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index e509c9f95e..d53ee33853 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -147,6 +147,66 @@ class SelectableTest( assert s1.corresponding_column(scalar_select) is s1.c.foo assert s2.corresponding_column(scalar_select) is s2.c.foo + def test_labels_name_w_separate_key(self): + label = select([table1.c.col1]).label("foo") + label.key = "bar" + + s1 = select([label]) + assert s1.corresponding_column(label) is s1.selected_columns.bar + + # renders as foo + self.assert_compile( + s1, "SELECT (SELECT table1.col1 FROM table1) AS foo" + ) + + def test_labels_anon_w_separate_key(self): + label = select([table1.c.col1]).label(None) + label.key = "bar" + + s1 = select([label]) + + # .bar is there + assert s1.corresponding_column(label) is s1.selected_columns.bar + + # renders as anon_1 + self.assert_compile( + s1, "SELECT (SELECT table1.col1 FROM table1) AS anon_1" + ) + + def test_labels_anon_w_separate_key_subquery(self): + label = select([table1.c.col1]).label(None) + label.key = label._key_label = "bar" + + s1 = select([label]) + + subq = s1.subquery() + + s2 = select([subq]).where(subq.c.bar > 5) + self.assert_compile( + s2, + "SELECT anon_2.anon_1 FROM (SELECT (SELECT table1.col1 " + "FROM table1) AS anon_1) AS anon_2 " + "WHERE anon_2.anon_1 > :param_1", + checkparams={"param_1": 5}, + ) + + def test_labels_anon_generate_binds_subquery(self): + label = select([table1.c.col1]).label(None) + label.key = label._key_label = "bar" + + s1 = select([label]) + + subq = s1.subquery() + + s2 = select([subq]).where(subq.c[0] > 5) + self.assert_compile( + s2, + "SELECT anon_2.anon_1 FROM (SELECT (SELECT table1.col1 " + "FROM table1) AS anon_1) AS anon_2 " + "WHERE anon_2.anon_1 > :param_1", + checkparams={"param_1": 5}, + ) + def test_select_label_grouped_still_corresponds(self): label = select([table1.c.col1]).label("foo") label2 = label.self_group() diff --git a/test/sql/test_utils.py b/test/sql/test_utils.py index 4e713dd286..d68a744753 100644 --- a/test/sql/test_utils.py +++ b/test/sql/test_utils.py @@ -4,8 +4,11 @@ from sqlalchemy import MetaData from sqlalchemy import select from sqlalchemy import String from sqlalchemy import Table +from sqlalchemy.sql import base as sql_base from sqlalchemy.sql import util as sql_util from sqlalchemy.sql.elements import ColumnElement +from sqlalchemy.testing import assert_raises +from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures @@ -48,3 +51,41 @@ class MiscTest(fixtures.TestBase): set(sql_util.find_tables(subset_select, include_aliases=True)), {common, calias, subset_select}, ) + + def test_options_merge(self): + class opt1(sql_base.CacheableOptions): + _cache_key_traversal = [] + + class opt2(sql_base.CacheableOptions): + _cache_key_traversal = [] + + foo = "bar" + + class opt3(sql_base.CacheableOptions): + _cache_key_traversal = [] + + foo = "bar" + bat = "hi" + + o2 = opt2.safe_merge(opt1) + eq_(o2.__dict__, {}) + eq_(o2.foo, "bar") + + assert_raises_message( + TypeError, + r"other element .*opt2.* is not empty, is not of type .*opt1.*, " + r"and contains attributes not covered here .*'foo'.*", + opt1.safe_merge, + opt2, + ) + + o2 = opt2 + {"foo": "bat"} + o3 = opt2.safe_merge(o2) + + eq_(o3.foo, "bat") + + o4 = opt3.safe_merge(o2) + eq_(o4.foo, "bat") + eq_(o4.bat, "hi") + + assert_raises(TypeError, opt2.safe_merge, o4)