From: Mike Bayer Date: Fri, 27 Sep 2019 21:32:10 +0000 (-0400) Subject: Simplify _ColumnEntity, related X-Git-Tag: rel_1_4_0b1~713^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6ddb62a8ba66b19afd41b967911ce5982250856e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Simplify _ColumnEntity, related In the interests of making Query much more lightweight up front, rework the calculations done at the top when the entities are constructed to be much less inolved. Use the new coercion system for _ColumnEntity and stop accepting plain strings, this will need to emit a deprecation warning in 1.3.x. Use annotations and other techniques to reduce the decisionmaking and complexity of Query. For the use case of subquery(), .statement, etc. we would like to do minimal work in order to get the columns clause. Change-Id: I7e459bbd3bb10ec71235f75ef4f3b0a969bec590 --- diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 9d404e00d4..117dd4cea4 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -168,18 +168,18 @@ class QueryableAttribute( """ return inspection.inspect(self._parententity) - @property + @util.memoized_property def expression(self): - return self.comparator.__clause_element__() - - def __clause_element__(self): - return self.comparator.__clause_element__() + return self.comparator.__clause_element__()._annotate( + {"orm_key": self.key} + ) - def _query_clause_element(self): - """like __clause_element__(), but called specifically - by :class:`.Query` to allow special behavior.""" + @property + def _annotations(self): + return self.__clause_element__()._annotations - return self.comparator._query_clause_element() + def __clause_element__(self): + return self.expression def _bulk_update_tuples(self, value): """Return setter tuples for a bulk UPDATE.""" @@ -207,7 +207,7 @@ class QueryableAttribute( ) def label(self, name): - return self._query_clause_element().label(name) + return self.__clause_element__().label(name) def operate(self, op, *other, **kwargs): return op(self.comparator, *other, **kwargs) diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 28b3bc5db3..075638fed9 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -413,19 +413,26 @@ class CompositeProperty(DescriptorProperty): __hash__ = None - @property + @util.memoized_property def clauses(self): - return self.__clause_element__() - - def __clause_element__(self): return expression.ClauseList( group=False, *self._comparable_elements ) - def _query_clause_element(self): - return CompositeProperty.CompositeBundle( - self.prop, self.__clause_element__() + def __clause_element__(self): + return self.expression + + @util.memoized_property + def expression(self): + clauses = self.clauses._annotate( + { + "bundle": True, + "parententity": self._parententity, + "parentmapper": self._parententity, + "orm_key": self.prop.key, + } ) + return CompositeProperty.CompositeBundle(self.prop, clauses) def _bulk_update_tuples(self, value): if value is None: diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 5098a55ce3..d6bdfb9247 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -363,6 +363,7 @@ class PropComparator(operators.ColumnOperators): __slots__ = "prop", "property", "_parententity", "_adapt_to_entity" def __init__(self, prop, parentmapper, adapt_to_entity=None): + # type: (MapperProperty, Mapper, Optional(AliasedInsp)) self.prop = self.property = prop self._parententity = adapt_to_entity or parentmapper self._adapt_to_entity = adapt_to_entity @@ -370,10 +371,15 @@ class PropComparator(operators.ColumnOperators): def __clause_element__(self): raise NotImplementedError("%r" % self) - def _query_clause_element(self): - return self.__clause_element__() - def _bulk_update_tuples(self, value): + # type: (ColumnOperators) -> List[tuple[ColumnOperators, Any]] + """Receive a SQL expression that represents a value in the SET + clause of an UPDATE statement. + + Return a tuple that can be passed to a :class:`.Update` construct. + + """ + return [(self.__clause_element__(), value)] def adapt_to_entity(self, adapt_to_entity): diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index e2c10e50aa..f804d6eed4 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -292,7 +292,7 @@ class ColumnProperty(StrategizedProperty): def _memoized_method___clause_element__(self): if self.adapter: - return self.adapter(self.prop.columns[0]) + return self.adapter(self.prop.columns[0], self.prop.key) else: # no adapter, so we aren't aliased # assert self._parententity is self._parentmapper @@ -300,6 +300,7 @@ class ColumnProperty(StrategizedProperty): { "parententity": self._parententity, "parentmapper": self._parententity, + "orm_key": self.prop.key, } ) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 37bd77f636..3d08dce22c 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -114,7 +114,6 @@ class Query(Generative): _from_obj = () _join_entities = () _select_from_entity = None - _mapper_adapter_map = {} _filter_aliases = () _from_obj_alias = None _joinpath = _joinpoint = util.immutabledict() @@ -177,61 +176,23 @@ class Query(Generative): self._primary_entity = None self._has_mapper_entities = False - # 1. don't run util.to_list() or _set_entity_selectables - # if no entities were passed - major performance bottleneck - # from lazy loader implementation when it seeks to use Query - # class for an identity lookup, causes test_orm.py to fail - # with thousands of extra function calls, see issue #4228 - # for why this use had to be added - # 2. can't use classmethod on Query because session.query_cls - # is an arbitrary callable in some user recipes, not - # necessarily a class, so we don't have the class available. - # see issue #4256 - # 3. can't do "if entities is not None" because we usually get here - # from session.query() which takes in *entities. - # 4. can't do "if entities" because users make use of undocumented - # to_list() behavior here and they pass clause expressions that - # can't be evaluated as boolean. See issue #4269. - # 5. the empty tuple is a singleton in cPython, take advantage of this - # so that we can skip for the empty "*entities" case without using - # any Python overloadable operators. - # if entities is not (): for ent in util.to_list(entities): entity_wrapper(self, ent) - self._set_entity_selectables(self._entities) - - def _set_entity_selectables(self, entities): - self._mapper_adapter_map = d = self._mapper_adapter_map.copy() - - for ent in entities: - for entity in ent.entities: - if entity not in d: - ext_info = inspect(entity) - if ( - not ext_info.is_aliased_class - and ext_info.mapper.with_polymorphic - ): - if ( - ext_info.mapper.persist_selectable - not in self._polymorphic_adapters - ): - self._mapper_loads_polymorphically_with( - ext_info.mapper, - sql_util.ColumnAdapter( - ext_info.selectable, - ext_info.mapper._equivalent_columns, - ), - ) - aliased_adapter = None - elif ext_info.is_aliased_class: - aliased_adapter = ext_info._adapter - else: - aliased_adapter = None - - d[entity] = (ext_info, aliased_adapter) - ent.setup_entity(*d[entity]) + def _setup_query_adapters(self, entity, ext_info): + if not ext_info.is_aliased_class and ext_info.mapper.with_polymorphic: + if ( + ext_info.mapper.persist_selectable + not in self._polymorphic_adapters + ): + self._mapper_loads_polymorphically_with( + ext_info.mapper, + sql_util.ColumnAdapter( + ext_info.selectable, + ext_info.mapper._equivalent_columns, + ), + ) def _mapper_loads_polymorphically_with(self, mapper, adapter): for m2 in mapper._with_polymorphic_mappers or [mapper]: @@ -1162,8 +1123,7 @@ class Query(Generative): entity = aliased(entity, alias) self._entities = list(self._entities) - m = _MapperEntity(self, entity) - self._set_entity_selectables([m]) + _MapperEntity(self, entity) @_generative def with_session(self, session): @@ -1455,12 +1415,9 @@ class Query(Generative): of result columns to be returned.""" self._entities = list(self._entities) - l = len(self._entities) + for c in column: _ColumnEntity(self, c) - # _ColumnEntity may add many entities if the - # given arg is a FROM clause - self._set_entity_selectables(self._entities[l:]) @util.pending_deprecation( "0.7", @@ -2464,9 +2421,13 @@ class Query(Generative): ) else: # add a new element to the self._from_obj list - if use_entity_index is not None: - # why doesn't this work as .entity_zero_or_selectable? + # make use of _MapperEntity selectable, which is usually + # entity_zero.selectable, but if with_polymorphic() were used + # might be distinct + assert isinstance( + self._entities[use_entity_index], _MapperEntity + ) left_clause = self._entities[use_entity_index].selectable else: left_clause = left @@ -3529,7 +3490,7 @@ class Query(Generative): # we get just "SELECT 1" without any entities. return sql.exists( self.enable_eagerloads(False) - .add_columns("1") + .add_columns(sql.literal_column("1")) .with_labels() .statement.with_only_columns([1]) ) @@ -4029,10 +3990,10 @@ class Query(Generative): """ - search = set(self._mapper_adapter_map.values()) + search = set(context.single_inh_entities.values()) if ( self._select_from_entity - and self._select_from_entity not in self._mapper_adapter_map + and self._select_from_entity not in context.single_inh_entities ): insp = inspect(self._select_from_entity) if insp.is_aliased_class: @@ -4110,23 +4071,27 @@ class _MapperEntity(_QueryEntity): self.entities = [entity] self.expr = entity - supports_single_entity = True - - use_id_for_hash = True + ext_info = self.entity_zero = inspect(entity) - def setup_entity(self, ext_info, aliased_adapter): self.mapper = ext_info.mapper - self.aliased_adapter = aliased_adapter + + if ext_info.is_aliased_class: + self._label_name = ext_info.name + else: + self._label_name = self.mapper.class_.__name__ + self.selectable = ext_info.selectable self.is_aliased_class = ext_info.is_aliased_class self._with_polymorphic = ext_info.with_polymorphic_mappers self._polymorphic_discriminator = ext_info.polymorphic_on - self.entity_zero = ext_info - if ext_info.is_aliased_class: - self._label_name = self.entity_zero.name - else: - self._label_name = self.mapper.class_.__name__ - self.path = self.entity_zero._path_registry + self.path = ext_info._path_registry + + if ext_info.mapper.with_polymorphic: + query._setup_query_adapters(entity, ext_info) + + supports_single_entity = True + + use_id_for_hash = True def set_with_polymorphic( self, query, cls_or_mappers, selectable, polymorphic_on @@ -4185,7 +4150,7 @@ class _MapperEntity(_QueryEntity): if query._polymorphic_adapters: adapter = query._polymorphic_adapters.get(self.mapper, None) else: - adapter = self.aliased_adapter + adapter = self.entity_zero._adapter if adapter: if query._from_obj_alias: @@ -4235,6 +4200,14 @@ class _MapperEntity(_QueryEntity): def setup_context(self, query, context): adapter = self._get_entity_clauses(query, context) + single_table_crit = self.mapper._single_table_criterion + if single_table_crit is not None: + ext_info = self.entity_zero + context.single_inh_entities[ext_info] = ( + ext_info, + ext_info._adapter if ext_info.is_aliased_class else None, + ) + # if self._adapted_selectable is None: context.froms += (self.selectable,) @@ -4352,7 +4325,9 @@ class Bundle(InspectionAttr): return cloned def __clause_element__(self): - return expression.ClauseList(group=False, *self.exprs) + return expression.ClauseList(group=False, *self.exprs)._annotate( + {"bundle": True} + ) @property def clauses(self): @@ -4386,8 +4361,19 @@ class Bundle(InspectionAttr): class _BundleEntity(_QueryEntity): use_id_for_hash = False - def __init__(self, query, bundle, setup_entities=True): - query._entities.append(self) + def __init__(self, query, expr, setup_entities=True, parent_bundle=None): + if parent_bundle: + parent_bundle._entities.append(self) + else: + query._entities.append(self) + + if isinstance( + expr, (attributes.QueryableAttribute, interfaces.PropComparator) + ): + bundle = expr.__clause_element__() + else: + bundle = expr + self.bundle = self.expr = bundle self.type = type(bundle) self._label_name = bundle.name @@ -4396,9 +4382,9 @@ class _BundleEntity(_QueryEntity): if setup_entities: for expr in bundle.exprs: if isinstance(expr, Bundle): - _BundleEntity(self, expr) + _BundleEntity(query, expr, parent_bundle=self) else: - _ColumnEntity(self, expr) + _ColumnEntity(query, expr, parent_bundle=self) self.supports_single_entity = self.bundle.single_entity @@ -4448,18 +4434,19 @@ class _BundleEntity(_QueryEntity): else: return None - def adapt_to_selectable(self, query, sel): - c = _BundleEntity(query, self.bundle, setup_entities=False) + def adapt_to_selectable(self, query, sel, parent_bundle=None): + c = _BundleEntity( + query, + self.bundle, + setup_entities=False, + parent_bundle=parent_bundle, + ) # c._label_name = self._label_name # c.entity_zero = self.entity_zero # c.entities = self.entities for ent in self._entities: - ent.adapt_to_selectable(c, sel) - - def setup_entity(self, ext_info, aliased_adapter): - for ent in self._entities: - ent.setup_entity(ext_info, aliased_adapter) + ent.adapt_to_selectable(query, sel, parent_bundle=c) def setup_context(self, query, context): for ent in self._entities: @@ -4481,76 +4468,52 @@ class _BundleEntity(_QueryEntity): class _ColumnEntity(_QueryEntity): """Column/expression based entity.""" - def __init__(self, query, column, namespace=None): - self.expr = column + froms = frozenset() + + def __init__(self, query, column, namespace=None, parent_bundle=None): + self.expr = expr = column self.namespace = namespace - search_entities = True - check_column = False - - if isinstance(column, util.string_types): - column = sql.literal_column(column) - self._label_name = column.name - search_entities = False - check_column = True - _entity = None - elif isinstance( - column, (attributes.QueryableAttribute, interfaces.PropComparator) - ): - _entity = getattr(column, "_parententity", None) - if _entity is not None: - search_entities = False - self._label_name = column.key - column = column._query_clause_element() - check_column = True - if isinstance(column, Bundle): - _BundleEntity(query, column) - return + _label_name = None - if not isinstance(column, sql.ColumnElement): - if hasattr(column, "_select_iterable"): - # break out an object like Table into - # individual columns - for c in column._select_iterable: - if c is column: - break - _ColumnEntity(query, c, namespace=column) - else: - return + column = coercions.expect(roles.ColumnsClauseRole, column) - raise sa_exc.InvalidRequestError( - "SQL expression, column, or mapped entity " - "expected - got '%r'" % (column,) - ) - elif not check_column: + annotations = column._annotations + + if annotations.get("bundle", False): + _BundleEntity(query, expr, parent_bundle=parent_bundle) + return + + orm_expr = False + + if "parententity" in annotations: + _entity = annotations["parententity"] + self._label_name = _label_name = annotations.get("orm_key", None) + orm_expr = True + + if hasattr(column, "_select_iterable"): + # break out an object like Table into + # individual columns + for c in column._select_iterable: + if c is column: + break + _ColumnEntity(query, c, namespace=column) + else: + return + + if _label_name is None: self._label_name = getattr(column, "key", None) - search_entities = True self.type = type_ = column.type self.use_id_for_hash = not type_.hashable - # If the Column is unnamed, give it a - # label() so that mutable column expressions - # can be located in the result even - # if the expression's identity has been changed - # due to adaption. - - if not column._label and not getattr(column, "is_literal", False): - column = column.label(self._label_name) - - query._entities.append(self) + if parent_bundle: + parent_bundle._entities.append(self) + else: + query._entities.append(self) self.column = column - self.froms = set() - - # look for ORM entities represented within the - # given expression. Try to count only entities - # for columns whose FROM object is in the actual list - # of FROMs for the overall expression - this helps - # subqueries which were built from ORM constructs from - # leaking out their entities into the main select construct - self.actual_froms = set(column._from_objects) - if not search_entities: + if orm_expr: self.entity_zero = _entity if _entity: self.entities = [_entity] @@ -4559,21 +4522,20 @@ class _ColumnEntity(_QueryEntity): self.entities = [] self.mapper = None else: - all_elements = [ - elem - for elem in sql_util.surface_column_elements( - column, include_scalar_selects=False - ) - if "parententity" in elem._annotations - ] - self.entities = util.unique_list( - [elem._annotations["parententity"] for elem in all_elements] + entity = sql_util.extract_first_column_annotation( + column, "parententity" ) + if entity: + self.entities = [entity] + else: + self.entities = [] + if self.entities: self.entity_zero = self.entities[0] self.mapper = self.entity_zero.mapper + elif self.namespace is not None: self.entity_zero = self.namespace self.mapper = None @@ -4581,6 +4543,9 @@ class _ColumnEntity(_QueryEntity): self.entity_zero = None self.mapper = None + if self.entities and self.entity_zero.mapper.with_polymorphic: + query._setup_query_adapters(self.entity_zero, self.entity_zero) + supports_single_entity = False def _deep_entity_zero(self): @@ -4603,24 +4568,21 @@ class _ColumnEntity(_QueryEntity): def entity_zero_or_selectable(self): if self.entity_zero is not None: return self.entity_zero - elif self.actual_froms: - return list(self.actual_froms)[0] + elif self.column._from_objects: + return self.column._from_objects[0] else: return None - def adapt_to_selectable(self, query, sel): - c = _ColumnEntity(query, sel.corresponding_column(self.column)) + def adapt_to_selectable(self, query, sel, parent_bundle=None): + c = _ColumnEntity( + query, + sel.corresponding_column(self.column), + parent_bundle=parent_bundle, + ) c._label_name = self._label_name c.entity_zero = self.entity_zero c.entities = self.entities - def setup_entity(self, ext_info, aliased_adapter): - if "selectable" not in self.__dict__: - self.selectable = ext_info.selectable - - if self.actual_froms.intersection(ext_info.selectable._from_objects): - self.froms.add(ext_info.selectable) - def corresponds_to(self, entity): if self.entity_zero is None: return False @@ -4651,13 +4613,32 @@ class _ColumnEntity(_QueryEntity): def setup_context(self, query, context): column = query._adapt_clause(self.column, False, True) + ezero = self.entity_zero + + if self.mapper: + single_table_crit = self.mapper._single_table_criterion + if single_table_crit is not None: + context.single_inh_entities[ezero] = ( + ezero, + ezero._adapter if ezero.is_aliased_class else None, + ) if column._annotations: # annotated columns perform more slowly in compiler and # result due to the __eq__() method, so use deannotated column = column._deannotate() - context.froms += tuple(self.froms) + if ezero is not None: + # use entity_zero as the from if we have it. this is necessary + # for polymorpic scenarios where our FROM is based on ORM entity, + # not the FROM of the column. but also, don't use it if our column + # doesn't actually have any FROMs that line up, such as when its + # a scalar subquery. + if set(self.column._from_objects).intersection( + ezero.selectable._from_objects + ): + context.froms += (ezero.selectable,) + context.primary_columns.append(column) context.attributes[("fetch_column", self)] = column @@ -4697,6 +4678,7 @@ class QueryContext(object): "partials", "post_load_paths", "identity_token", + "single_inh_entities", ) def __init__(self, query): @@ -4731,6 +4713,7 @@ class QueryContext(object): self.secondary_columns = [] self.eager_order_by = [] self.eager_joins = {} + self.single_inh_entities = {} self.create_eager_joins = [] self.propagate_options = set( o for o in query._with_options if o.propagate_to_loaders diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 63ec21099a..731947cbaf 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -2319,11 +2319,11 @@ class JoinCondition(object): """ self.primaryjoin = _deep_deannotate( - self.primaryjoin, values=("parententity",) + self.primaryjoin, values=("parententity", "orm_key") ) if self.secondaryjoin is not None: self.secondaryjoin = _deep_deannotate( - self.secondaryjoin, values=("parententity",) + self.secondaryjoin, values=("parententity", "orm_key") ) def _determine_joins(self): diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 4b4fa40523..747ec7e658 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -668,10 +668,11 @@ class AliasedInsp(InspectionAttr): state["represents_outer_join"], ) - def _adapt_element(self, elem): - return self._adapter.traverse(elem)._annotate( - {"parententity": self, "parentmapper": self.mapper} - ) + def _adapt_element(self, elem, key=None): + d = {"parententity": self, "parentmapper": self.mapper} + if key: + d["orm_key"] = key + return self._adapter.traverse(elem)._annotate(d) def _entity_for_mapper(self, mapper): self_poly = self.with_polymorphic_mappers diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index a7a856bba9..95aee0468f 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -57,7 +57,7 @@ def expect(role, element, **kw): else: resolved = element - if issubclass(resolved.__class__, impl._role_class): + if impl._role_class in resolved.__class__.__mro__: if impl._post_coercion: resolved = impl._post_coercion(resolved, **kw) return resolved @@ -102,13 +102,16 @@ class RoleImpl(object): def _resolve_for_clause_element(self, element, argname=None, **kw): original_element = element - is_clause_element = False + is_clause_element = hasattr(element, "__clause_element__") - while hasattr(element, "__clause_element__") and not isinstance( - element, (elements.ClauseElement, schema.SchemaItem) - ): - element = element.__clause_element__() - is_clause_element = True + if is_clause_element: + while not isinstance( + element, (elements.ClauseElement, schema.SchemaItem) + ): + try: + element = element.__clause_element__() + except AttributeError: + break if not is_clause_element: if self._use_inspection: diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index fe83b163cd..3c7f904dea 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -364,23 +364,19 @@ def surface_selectables_only(clause): stack.append(elem.table) -def surface_column_elements(clause, include_scalar_selects=True): - """traverse and yield only outer-exposed column elements, such as would - be addressable in the WHERE clause of a SELECT if this element were - in the columns clause.""" +def extract_first_column_annotation(column, annotation_name): + filter_ = (FromGrouping, SelectBase) - filter_ = (FromGrouping,) - if not include_scalar_selects: - filter_ += (SelectBase,) - - stack = deque([clause]) + stack = deque([column]) while stack: elem = stack.popleft() - yield elem + if annotation_name in elem._annotations: + return elem._annotations[annotation_name] for sub in elem.get_children(): if isinstance(sub, filter_): continue stack.append(sub) + return None def selectables_overlap(left, right): diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py index 4e52a77789..632f559373 100644 --- a/test/aaa_profiling/test_orm.py +++ b/test/aaa_profiling/test_orm.py @@ -560,6 +560,10 @@ class QueryTest(fixtures.MappedTest): self._fixture() sess = Session() + # warm up cache + for attr in [Parent.data1, Parent.data2, Parent.data3, Parent.data4]: + attr.__clause_element__() + @profiling.function_call_count() def go(): for i in range(10): diff --git a/test/orm/inheritance/test_single.py b/test/orm/inheritance/test_single.py index 7b8d413a46..d0db76b215 100644 --- a/test/orm/inheritance/test_single.py +++ b/test/orm/inheritance/test_single.py @@ -287,7 +287,7 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): self.assert_compile( sess.query(literal("1")).select_from(a1), - "SELECT :param_1 AS param_1 FROM employees AS employees_1 " + "SELECT :param_1 AS anon_1 FROM employees AS employees_1 " "WHERE employees_1.type IN (:type_1, :type_2)", ) diff --git a/test/orm/test_composites.py b/test/orm/test_composites.py index 7247c859a6..0d679e6db5 100644 --- a/test/orm/test_composites.py +++ b/test/orm/test_composites.py @@ -20,6 +20,8 @@ from sqlalchemy.testing.schema import Table class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): + __dialect__ = "default" + @classmethod def define_tables(cls, metadata): Table( @@ -311,6 +313,20 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): [(Point(3, 4), Point(5, 6))], ) + def test_cols_as_core_clauseelement(self): + Edge = self.classes.Edge + Point = self.classes.Point + + start, end = Edge.start, Edge.end + + stmt = select([start, end]).where(start == Point(3, 4)) + self.assert_compile( + stmt, + "SELECT edges.x1, edges.y1, edges.x2, edges.y2 " + "FROM edges WHERE edges.x1 = :x1_1 AND edges.y1 = :y1_1", + checkparams={"x1_1": 3, "y1_1": 4}, + ) + def test_query_cols_labeled(self): Edge = self.classes.Edge Point = self.classes.Point diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py index 498b680577..efa45affa0 100644 --- a/test/orm/test_froms.py +++ b/test/orm/test_froms.py @@ -7,7 +7,6 @@ from sqlalchemy import exc as sa_exc from sqlalchemy import exists from sqlalchemy import ForeignKey from sqlalchemy import func -from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import literal_column from sqlalchemy import select @@ -2211,7 +2210,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): sess.expunge_all() assert_raises( - sa_exc.InvalidRequestError, sess.query(User).add_column, object() + sa_exc.ArgumentError, sess.query(User).add_column, object() ) def test_add_multi_columns(self): @@ -2270,7 +2269,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): .order_by(User.id) ) q = sess.query(User) - result = q.add_column("count").from_statement(s).all() + result = q.add_column(s.selected_columns.count).from_statement(s).all() assert result == expected def test_raw_columns(self): @@ -2315,7 +2314,10 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): ) q = create_session().query(User) result = ( - q.add_column("count").add_column("concat").from_statement(s).all() + q.add_column(s.selected_columns.count) + .add_column(s.selected_columns.concat) + .from_statement(s) + .all() ) assert result == expected @@ -2399,7 +2401,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): ]: q = s.query(crit) mzero = q._entity_zero() - is_(mzero.persist_selectable, q._query_entity_zero().selectable) + is_(mzero, q._query_entity_zero().entity_zero) q = q.join(j) self.assert_compile(q, exp) @@ -2429,7 +2431,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): ]: q = s.query(crit) mzero = q._entity_zero() - is_(inspect(mzero).selectable, q._query_entity_zero().selectable) + is_(mzero, q._query_entity_zero().entity_zero) q = q.join(j) self.assert_compile(q, exp) diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 4dff6fe56d..bcd13e6e2a 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -1005,14 +1005,14 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): s = create_session() q = s.query(User) - assert_raises(sa_exc.InvalidRequestError, q.add_column, object()) + assert_raises(sa_exc.ArgumentError, q.add_column, object()) def test_invalid_column_tuple(self): User = self.classes.User s = create_session() q = s.query(User) - assert_raises(sa_exc.InvalidRequestError, q.add_column, (1, 1)) + assert_raises(sa_exc.ArgumentError, q.add_column, (1, 1)) def test_distinct(self): """test that a distinct() call is not valid before 'clauseelement' @@ -2449,6 +2449,9 @@ class ComparatorTest(QueryTest): def __clause_element__(self): return self.expr + # this use case isn't exactly needed in this form, however it tests + # that we resolve for multiple __clause_element__() calls as is needed + # by systems like composites sess = Session() eq_( sess.query(Comparator(User.id)) @@ -3398,11 +3401,11 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): q3, "SELECT anon_1.users_id AS anon_1_users_id, " "anon_1.users_name AS anon_1_users_name, " - "anon_1.param_1 AS anon_1_param_1 " - "FROM (SELECT users.id AS users_id, users.name AS " - "users_name, :param_1 AS param_1 " - "FROM users UNION SELECT users.id AS users_id, " - "users.name AS users_name, 'y' FROM users) AS anon_1", + "anon_1.anon_2 AS anon_1_anon_2 FROM " + "(SELECT users.id AS users_id, users.name AS users_name, " + ":param_1 AS anon_2 FROM users " + "UNION SELECT users.id AS users_id, users.name AS users_name, " + "'y' FROM users) AS anon_1", ) def test_union_literal_expressions_results(self): @@ -3410,7 +3413,8 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): s = Session() - q1 = s.query(User, literal("x")) + x_literal = literal("x") + q1 = s.query(User, x_literal) q2 = s.query(User, literal_column("'y'")) q3 = q1.union(q2) @@ -3421,7 +3425,7 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): eq_([x["name"] for x in q6.column_descriptions], ["User", "foo"]) for q in ( - q3.order_by(User.id, text("anon_1_param_1")), + q3.order_by(User.id, x_literal), q6.order_by(User.id, "foo"), ): eq_( @@ -4231,12 +4235,14 @@ class TextTest(QueryTest, AssertsCompiledSQL): User = self.classes.User s = create_session() - assert_raises( - sa_exc.InvalidRequestError, s.query, User.id, text("users.name") + + self.assert_compile( + s.query(User.id, text("users.name")), + "SELECT users.id AS users_id, users.name FROM users", ) eq_( - s.query(User.id, "name").order_by(User.id).all(), + s.query(User.id, literal_column("name")).order_by(User.id).all(), [(7, "jack"), (8, "ed"), (9, "fred"), (10, "chuck")], ) diff --git a/test/orm/test_subquery_relations.py b/test/orm/test_subquery_relations.py index b32b6547fa..03e17d2917 100644 --- a/test/orm/test_subquery_relations.py +++ b/test/orm/test_subquery_relations.py @@ -3,6 +3,7 @@ from sqlalchemy import bindparam from sqlalchemy import ForeignKey from sqlalchemy import inspect from sqlalchemy import Integer +from sqlalchemy import literal_column from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing @@ -792,7 +793,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): sess = create_session() self.assert_compile( - sess.query(User, "1"), + sess.query(User, literal_column("1")), "SELECT users.id AS users_id, users.name AS users_name, " "1 FROM users", ) diff --git a/test/orm/test_utils.py b/test/orm/test_utils.py index e47fc3f267..4bc2a5c880 100644 --- a/test/orm/test_utils.py +++ b/test/orm/test_utils.py @@ -210,7 +210,24 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL): eq_(str(alias.x + 1), "point_1.x + :x_1") eq_(str(alias.x_alone + 1), "point_1.x + :x_1") - is_(Point.x_alone.__clause_element__(), Point.x.__clause_element__()) + point_mapper = inspect(Point) + + eq_( + Point.x_alone._annotations, + { + "parententity": point_mapper, + "parentmapper": point_mapper, + "orm_key": "x_alone", + }, + ) + eq_( + Point.x._annotations, + { + "parententity": point_mapper, + "parentmapper": point_mapper, + "orm_key": "x", + }, + ) eq_(str(alias.x_alone == alias.x), "point_1.x = point_1.x") diff --git a/test/profiles.txt b/test/profiles.txt index 0750ab7679..e4b99ba680 100644 --- a/test/profiles.txt +++ b/test/profiles.txt @@ -13,22 +13,22 @@ # TEST: test.aaa_profiling.test_compiler.CompileTest.test_insert -test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mssql_pyodbc_dbapiunicode_cextensions 70,70 -test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mssql_pyodbc_dbapiunicode_nocextensions 70,70 -test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_mysqldb_dbapiunicode_cextensions 70,70,70,70 -test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_mysqldb_dbapiunicode_nocextensions 70,70,70,70 -test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_pymysql_dbapiunicode_cextensions 70,70 -test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_pymysql_dbapiunicode_nocextensions 70,70 -test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_oracle_cx_oracle_dbapiunicode_cextensions 70,70 -test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_oracle_cx_oracle_dbapiunicode_nocextensions 70,70 +test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mssql_pyodbc_dbapiunicode_cextensions 66 +test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mssql_pyodbc_dbapiunicode_nocextensions 66 +test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_mysqldb_dbapiunicode_cextensions 66 +test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_mysqldb_dbapiunicode_nocextensions 66 +test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_pymysql_dbapiunicode_cextensions 66 +test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_pymysql_dbapiunicode_nocextensions 66 +test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_oracle_cx_oracle_dbapiunicode_cextensions 66 +test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_oracle_cx_oracle_dbapiunicode_nocextensions 66 test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_postgresql_psycopg2_dbapiunicode_cextensions 67 test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_postgresql_psycopg2_dbapiunicode_nocextensions 67 test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_sqlite_pysqlite_dbapiunicode_cextensions 67 test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 67 -test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_mysqldb_dbapiunicode_cextensions 73,73,73,73 -test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_mysqldb_dbapiunicode_nocextensions 73,73,73,73 -test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_pymysql_dbapiunicode_cextensions 73,73 -test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_pymysql_dbapiunicode_nocextensions 73,73 +test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_mysqldb_dbapiunicode_cextensions 73 +test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_mysqldb_dbapiunicode_nocextensions 73 +test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_pymysql_dbapiunicode_cextensions 73 +test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_pymysql_dbapiunicode_nocextensions 73 test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_oracle_cx_oracle_dbapiunicode_cextensions 73 test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_oracle_cx_oracle_dbapiunicode_nocextensions 73 test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_postgresql_psycopg2_dbapiunicode_cextensions 72 @@ -523,24 +523,14 @@ test.aaa_profiling.test_orm.MergeTest.test_merge_no_load 3.7_sqlite_pysqlite_dba # TEST: test.aaa_profiling.test_orm.QueryTest.test_query_cols -test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_mssql_pyodbc_dbapiunicode_cextensions 6230 -test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_mssql_pyodbc_dbapiunicode_nocextensions 6780 -test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_mysql_mysqldb_dbapiunicode_cextensions 6290 -test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_mysql_mysqldb_dbapiunicode_nocextensions 6840 -test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_oracle_cx_oracle_dbapiunicode_cextensions 6360 -test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_oracle_cx_oracle_dbapiunicode_nocextensions 8190 -test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_postgresql_psycopg2_dbapiunicode_cextensions 6100 -test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_postgresql_psycopg2_dbapiunicode_nocextensions 6641 -test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_sqlite_pysqlite_dbapiunicode_cextensions 6035 +test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_postgresql_psycopg2_dbapiunicode_cextensions 5900 +test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_postgresql_psycopg2_dbapiunicode_nocextensions 6441 +test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_sqlite_pysqlite_dbapiunicode_cextensions 5900 test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 6585 -test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_mysql_mysqldb_dbapiunicode_cextensions 6483 -test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_mysql_mysqldb_dbapiunicode_nocextensions 7143 -test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_oracle_cx_oracle_dbapiunicode_cextensions 6473 -test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_oracle_cx_oracle_dbapiunicode_nocextensions 7043 -test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_postgresql_psycopg2_dbapiunicode_cextensions 6464 -test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_postgresql_psycopg2_dbapiunicode_nocextensions 7034 -test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_sqlite_pysqlite_dbapiunicode_cextensions 6326 -test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_sqlite_pysqlite_dbapiunicode_nocextensions 6806 +test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_postgresql_psycopg2_dbapiunicode_cextensions 6138 +test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_postgresql_psycopg2_dbapiunicode_nocextensions 6800 +test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_sqlite_pysqlite_dbapiunicode_cextensions 6226 +test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_sqlite_pysqlite_dbapiunicode_nocextensions 6506 # TEST: test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results