]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Simplify _ColumnEntity, related
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 27 Sep 2019 21:32:10 +0000 (17:32 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 30 Sep 2019 14:10:58 +0000 (10:10 -0400)
In the interests of making Query much more lightweight up front,
rework the calculations done at the top when the entities
are constructed to be much less inolved.  Use the new
coercion system for _ColumnEntity and stop accepting
plain strings, this will need to emit a deprecation warning
in 1.3.x.     Use annotations and other techniques to reduce
the decisionmaking and complexity of Query.

For the use case of subquery(), .statement, etc. we would like
to do minimal work in order to get the columns clause.

Change-Id: I7e459bbd3bb10ec71235f75ef4f3b0a969bec590

17 files changed:
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/descriptor_props.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/util.py
test/aaa_profiling/test_orm.py
test/orm/inheritance/test_single.py
test/orm/test_composites.py
test/orm/test_froms.py
test/orm/test_query.py
test/orm/test_subquery_relations.py
test/orm/test_utils.py
test/profiles.txt

index 9d404e00d43d057b2dd9cddbcbbea52872902b58..117dd4cea43d774bedb3fe8f9dbddc5b787020a0 100644 (file)
@@ -168,18 +168,18 @@ class QueryableAttribute(
         """
         return inspection.inspect(self._parententity)
 
-    @property
+    @util.memoized_property
     def expression(self):
-        return self.comparator.__clause_element__()
-
-    def __clause_element__(self):
-        return self.comparator.__clause_element__()
+        return self.comparator.__clause_element__()._annotate(
+            {"orm_key": self.key}
+        )
 
-    def _query_clause_element(self):
-        """like __clause_element__(), but called specifically
-        by :class:`.Query` to allow special behavior."""
+    @property
+    def _annotations(self):
+        return self.__clause_element__()._annotations
 
-        return self.comparator._query_clause_element()
+    def __clause_element__(self):
+        return self.expression
 
     def _bulk_update_tuples(self, value):
         """Return setter tuples for a bulk UPDATE."""
@@ -207,7 +207,7 @@ class QueryableAttribute(
         )
 
     def label(self, name):
-        return self._query_clause_element().label(name)
+        return self.__clause_element__().label(name)
 
     def operate(self, op, *other, **kwargs):
         return op(self.comparator, *other, **kwargs)
index 28b3bc5db3b959cceb9550ac0ea98bb4ec8fd3a5..075638fed9a68731fd7f382b512269c5a9d69d29 100644 (file)
@@ -413,19 +413,26 @@ class CompositeProperty(DescriptorProperty):
 
         __hash__ = None
 
-        @property
+        @util.memoized_property
         def clauses(self):
-            return self.__clause_element__()
-
-        def __clause_element__(self):
             return expression.ClauseList(
                 group=False, *self._comparable_elements
             )
 
-        def _query_clause_element(self):
-            return CompositeProperty.CompositeBundle(
-                self.prop, self.__clause_element__()
+        def __clause_element__(self):
+            return self.expression
+
+        @util.memoized_property
+        def expression(self):
+            clauses = self.clauses._annotate(
+                {
+                    "bundle": True,
+                    "parententity": self._parententity,
+                    "parentmapper": self._parententity,
+                    "orm_key": self.prop.key,
+                }
             )
+            return CompositeProperty.CompositeBundle(self.prop, clauses)
 
         def _bulk_update_tuples(self, value):
             if value is None:
index 5098a55ce3793065a28bfa1ebd2a31948dcac622..d6bdfb924789ca554d3d6263012f3f30b0266f94 100644 (file)
@@ -363,6 +363,7 @@ class PropComparator(operators.ColumnOperators):
     __slots__ = "prop", "property", "_parententity", "_adapt_to_entity"
 
     def __init__(self, prop, parentmapper, adapt_to_entity=None):
+        # type: (MapperProperty, Mapper, Optional(AliasedInsp))
         self.prop = self.property = prop
         self._parententity = adapt_to_entity or parentmapper
         self._adapt_to_entity = adapt_to_entity
@@ -370,10 +371,15 @@ class PropComparator(operators.ColumnOperators):
     def __clause_element__(self):
         raise NotImplementedError("%r" % self)
 
-    def _query_clause_element(self):
-        return self.__clause_element__()
-
     def _bulk_update_tuples(self, value):
+        # type: (ColumnOperators) -> List[tuple[ColumnOperators, Any]]
+        """Receive a SQL expression that represents a value in the SET
+        clause of an UPDATE statement.
+
+        Return a tuple that can be passed to a :class:`.Update` construct.
+
+        """
+
         return [(self.__clause_element__(), value)]
 
     def adapt_to_entity(self, adapt_to_entity):
index e2c10e50aa964e985a41a1cfcea3054522688e26..f804d6eed4831751877acb484eb2e2f553979cf7 100644 (file)
@@ -292,7 +292,7 @@ class ColumnProperty(StrategizedProperty):
 
         def _memoized_method___clause_element__(self):
             if self.adapter:
-                return self.adapter(self.prop.columns[0])
+                return self.adapter(self.prop.columns[0], self.prop.key)
             else:
                 # no adapter, so we aren't aliased
                 # assert self._parententity is self._parentmapper
@@ -300,6 +300,7 @@ class ColumnProperty(StrategizedProperty):
                     {
                         "parententity": self._parententity,
                         "parentmapper": self._parententity,
+                        "orm_key": self.prop.key,
                     }
                 )
 
index 37bd77f6361cfe2f3726f991d4bf9a2409262ca9..3d08dce22c4a419f47e106acbaf1dc021317b37b 100644 (file)
@@ -114,7 +114,6 @@ class Query(Generative):
     _from_obj = ()
     _join_entities = ()
     _select_from_entity = None
-    _mapper_adapter_map = {}
     _filter_aliases = ()
     _from_obj_alias = None
     _joinpath = _joinpoint = util.immutabledict()
@@ -177,61 +176,23 @@ class Query(Generative):
         self._primary_entity = None
         self._has_mapper_entities = False
 
-        # 1. don't run util.to_list() or _set_entity_selectables
-        #    if no entities were passed - major performance bottleneck
-        #    from lazy loader implementation when it seeks to use Query
-        #    class for an identity lookup, causes test_orm.py to fail
-        #    with thousands of extra function calls, see issue #4228
-        #    for why this use had to be added
-        # 2. can't use classmethod on Query because session.query_cls
-        #    is an arbitrary callable in some user recipes, not
-        #    necessarily a class, so we don't have the class available.
-        #    see issue #4256
-        # 3. can't do "if entities is not None" because we usually get here
-        #    from session.query() which takes in *entities.
-        # 4. can't do "if entities" because users make use of undocumented
-        #    to_list() behavior here and they pass clause expressions that
-        #    can't be evaluated as boolean.  See issue #4269.
-        # 5. the empty tuple is a singleton in cPython, take advantage of this
-        #    so that we can skip for the empty "*entities" case without using
-        #    any Python overloadable operators.
-        #
         if entities is not ():
             for ent in util.to_list(entities):
                 entity_wrapper(self, ent)
 
-            self._set_entity_selectables(self._entities)
-
-    def _set_entity_selectables(self, entities):
-        self._mapper_adapter_map = d = self._mapper_adapter_map.copy()
-
-        for ent in entities:
-            for entity in ent.entities:
-                if entity not in d:
-                    ext_info = inspect(entity)
-                    if (
-                        not ext_info.is_aliased_class
-                        and ext_info.mapper.with_polymorphic
-                    ):
-                        if (
-                            ext_info.mapper.persist_selectable
-                            not in self._polymorphic_adapters
-                        ):
-                            self._mapper_loads_polymorphically_with(
-                                ext_info.mapper,
-                                sql_util.ColumnAdapter(
-                                    ext_info.selectable,
-                                    ext_info.mapper._equivalent_columns,
-                                ),
-                            )
-                        aliased_adapter = None
-                    elif ext_info.is_aliased_class:
-                        aliased_adapter = ext_info._adapter
-                    else:
-                        aliased_adapter = None
-
-                    d[entity] = (ext_info, aliased_adapter)
-                ent.setup_entity(*d[entity])
+    def _setup_query_adapters(self, entity, ext_info):
+        if not ext_info.is_aliased_class and ext_info.mapper.with_polymorphic:
+            if (
+                ext_info.mapper.persist_selectable
+                not in self._polymorphic_adapters
+            ):
+                self._mapper_loads_polymorphically_with(
+                    ext_info.mapper,
+                    sql_util.ColumnAdapter(
+                        ext_info.selectable,
+                        ext_info.mapper._equivalent_columns,
+                    ),
+                )
 
     def _mapper_loads_polymorphically_with(self, mapper, adapter):
         for m2 in mapper._with_polymorphic_mappers or [mapper]:
@@ -1162,8 +1123,7 @@ class Query(Generative):
             entity = aliased(entity, alias)
 
         self._entities = list(self._entities)
-        m = _MapperEntity(self, entity)
-        self._set_entity_selectables([m])
+        _MapperEntity(self, entity)
 
     @_generative
     def with_session(self, session):
@@ -1455,12 +1415,9 @@ class Query(Generative):
         of result columns to be returned."""
 
         self._entities = list(self._entities)
-        l = len(self._entities)
+
         for c in column:
             _ColumnEntity(self, c)
-        # _ColumnEntity may add many entities if the
-        # given arg is a FROM clause
-        self._set_entity_selectables(self._entities[l:])
 
     @util.pending_deprecation(
         "0.7",
@@ -2464,9 +2421,13 @@ class Query(Generative):
             )
         else:
             # add a new element to the self._from_obj list
-
             if use_entity_index is not None:
-                # why doesn't this work as .entity_zero_or_selectable?
+                # make use of _MapperEntity selectable, which is usually
+                # entity_zero.selectable, but if with_polymorphic() were used
+                # might be distinct
+                assert isinstance(
+                    self._entities[use_entity_index], _MapperEntity
+                )
                 left_clause = self._entities[use_entity_index].selectable
             else:
                 left_clause = left
@@ -3529,7 +3490,7 @@ class Query(Generative):
         # we get just "SELECT 1" without any entities.
         return sql.exists(
             self.enable_eagerloads(False)
-            .add_columns("1")
+            .add_columns(sql.literal_column("1"))
             .with_labels()
             .statement.with_only_columns([1])
         )
@@ -4029,10 +3990,10 @@ class Query(Generative):
 
         """
 
-        search = set(self._mapper_adapter_map.values())
+        search = set(context.single_inh_entities.values())
         if (
             self._select_from_entity
-            and self._select_from_entity not in self._mapper_adapter_map
+            and self._select_from_entity not in context.single_inh_entities
         ):
             insp = inspect(self._select_from_entity)
             if insp.is_aliased_class:
@@ -4110,23 +4071,27 @@ class _MapperEntity(_QueryEntity):
         self.entities = [entity]
         self.expr = entity
 
-    supports_single_entity = True
-
-    use_id_for_hash = True
+        ext_info = self.entity_zero = inspect(entity)
 
-    def setup_entity(self, ext_info, aliased_adapter):
         self.mapper = ext_info.mapper
-        self.aliased_adapter = aliased_adapter
+
+        if ext_info.is_aliased_class:
+            self._label_name = ext_info.name
+        else:
+            self._label_name = self.mapper.class_.__name__
+
         self.selectable = ext_info.selectable
         self.is_aliased_class = ext_info.is_aliased_class
         self._with_polymorphic = ext_info.with_polymorphic_mappers
         self._polymorphic_discriminator = ext_info.polymorphic_on
-        self.entity_zero = ext_info
-        if ext_info.is_aliased_class:
-            self._label_name = self.entity_zero.name
-        else:
-            self._label_name = self.mapper.class_.__name__
-        self.path = self.entity_zero._path_registry
+        self.path = ext_info._path_registry
+
+        if ext_info.mapper.with_polymorphic:
+            query._setup_query_adapters(entity, ext_info)
+
+    supports_single_entity = True
+
+    use_id_for_hash = True
 
     def set_with_polymorphic(
         self, query, cls_or_mappers, selectable, polymorphic_on
@@ -4185,7 +4150,7 @@ class _MapperEntity(_QueryEntity):
             if query._polymorphic_adapters:
                 adapter = query._polymorphic_adapters.get(self.mapper, None)
         else:
-            adapter = self.aliased_adapter
+            adapter = self.entity_zero._adapter
 
         if adapter:
             if query._from_obj_alias:
@@ -4235,6 +4200,14 @@ class _MapperEntity(_QueryEntity):
     def setup_context(self, query, context):
         adapter = self._get_entity_clauses(query, context)
 
+        single_table_crit = self.mapper._single_table_criterion
+        if single_table_crit is not None:
+            ext_info = self.entity_zero
+            context.single_inh_entities[ext_info] = (
+                ext_info,
+                ext_info._adapter if ext_info.is_aliased_class else None,
+            )
+
         # if self._adapted_selectable is None:
         context.froms += (self.selectable,)
 
@@ -4352,7 +4325,9 @@ class Bundle(InspectionAttr):
         return cloned
 
     def __clause_element__(self):
-        return expression.ClauseList(group=False, *self.exprs)
+        return expression.ClauseList(group=False, *self.exprs)._annotate(
+            {"bundle": True}
+        )
 
     @property
     def clauses(self):
@@ -4386,8 +4361,19 @@ class Bundle(InspectionAttr):
 class _BundleEntity(_QueryEntity):
     use_id_for_hash = False
 
-    def __init__(self, query, bundle, setup_entities=True):
-        query._entities.append(self)
+    def __init__(self, query, expr, setup_entities=True, parent_bundle=None):
+        if parent_bundle:
+            parent_bundle._entities.append(self)
+        else:
+            query._entities.append(self)
+
+        if isinstance(
+            expr, (attributes.QueryableAttribute, interfaces.PropComparator)
+        ):
+            bundle = expr.__clause_element__()
+        else:
+            bundle = expr
+
         self.bundle = self.expr = bundle
         self.type = type(bundle)
         self._label_name = bundle.name
@@ -4396,9 +4382,9 @@ class _BundleEntity(_QueryEntity):
         if setup_entities:
             for expr in bundle.exprs:
                 if isinstance(expr, Bundle):
-                    _BundleEntity(self, expr)
+                    _BundleEntity(query, expr, parent_bundle=self)
                 else:
-                    _ColumnEntity(self, expr)
+                    _ColumnEntity(query, expr, parent_bundle=self)
 
         self.supports_single_entity = self.bundle.single_entity
 
@@ -4448,18 +4434,19 @@ class _BundleEntity(_QueryEntity):
         else:
             return None
 
-    def adapt_to_selectable(self, query, sel):
-        c = _BundleEntity(query, self.bundle, setup_entities=False)
+    def adapt_to_selectable(self, query, sel, parent_bundle=None):
+        c = _BundleEntity(
+            query,
+            self.bundle,
+            setup_entities=False,
+            parent_bundle=parent_bundle,
+        )
         # c._label_name = self._label_name
         # c.entity_zero = self.entity_zero
         # c.entities = self.entities
 
         for ent in self._entities:
-            ent.adapt_to_selectable(c, sel)
-
-    def setup_entity(self, ext_info, aliased_adapter):
-        for ent in self._entities:
-            ent.setup_entity(ext_info, aliased_adapter)
+            ent.adapt_to_selectable(query, sel, parent_bundle=c)
 
     def setup_context(self, query, context):
         for ent in self._entities:
@@ -4481,76 +4468,52 @@ class _BundleEntity(_QueryEntity):
 class _ColumnEntity(_QueryEntity):
     """Column/expression based entity."""
 
-    def __init__(self, query, column, namespace=None):
-        self.expr = column
+    froms = frozenset()
+
+    def __init__(self, query, column, namespace=None, parent_bundle=None):
+        self.expr = expr = column
         self.namespace = namespace
-        search_entities = True
-        check_column = False
-
-        if isinstance(column, util.string_types):
-            column = sql.literal_column(column)
-            self._label_name = column.name
-            search_entities = False
-            check_column = True
-            _entity = None
-        elif isinstance(
-            column, (attributes.QueryableAttribute, interfaces.PropComparator)
-        ):
-            _entity = getattr(column, "_parententity", None)
-            if _entity is not None:
-                search_entities = False
-            self._label_name = column.key
-            column = column._query_clause_element()
-            check_column = True
-            if isinstance(column, Bundle):
-                _BundleEntity(query, column)
-                return
+        _label_name = None
 
-        if not isinstance(column, sql.ColumnElement):
-            if hasattr(column, "_select_iterable"):
-                # break out an object like Table into
-                # individual columns
-                for c in column._select_iterable:
-                    if c is column:
-                        break
-                    _ColumnEntity(query, c, namespace=column)
-                else:
-                    return
+        column = coercions.expect(roles.ColumnsClauseRole, column)
 
-            raise sa_exc.InvalidRequestError(
-                "SQL expression, column, or mapped entity "
-                "expected - got '%r'" % (column,)
-            )
-        elif not check_column:
+        annotations = column._annotations
+
+        if annotations.get("bundle", False):
+            _BundleEntity(query, expr, parent_bundle=parent_bundle)
+            return
+
+        orm_expr = False
+
+        if "parententity" in annotations:
+            _entity = annotations["parententity"]
+            self._label_name = _label_name = annotations.get("orm_key", None)
+            orm_expr = True
+
+        if hasattr(column, "_select_iterable"):
+            # break out an object like Table into
+            # individual columns
+            for c in column._select_iterable:
+                if c is column:
+                    break
+                _ColumnEntity(query, c, namespace=column)
+            else:
+                return
+
+        if _label_name is None:
             self._label_name = getattr(column, "key", None)
-            search_entities = True
 
         self.type = type_ = column.type
         self.use_id_for_hash = not type_.hashable
 
-        # If the Column is unnamed, give it a
-        # label() so that mutable column expressions
-        # can be located in the result even
-        # if the expression's identity has been changed
-        # due to adaption.
-
-        if not column._label and not getattr(column, "is_literal", False):
-            column = column.label(self._label_name)
-
-        query._entities.append(self)
+        if parent_bundle:
+            parent_bundle._entities.append(self)
+        else:
+            query._entities.append(self)
 
         self.column = column
-        self.froms = set()
-
-        # look for ORM entities represented within the
-        # given expression.  Try to count only entities
-        # for columns whose FROM object is in the actual list
-        # of FROMs for the overall expression - this helps
-        # subqueries which were built from ORM constructs from
-        # leaking out their entities into the main select construct
-        self.actual_froms = set(column._from_objects)
 
-        if not search_entities:
+        if orm_expr:
             self.entity_zero = _entity
             if _entity:
                 self.entities = [_entity]
@@ -4559,21 +4522,20 @@ class _ColumnEntity(_QueryEntity):
                 self.entities = []
                 self.mapper = None
         else:
-            all_elements = [
-                elem
-                for elem in sql_util.surface_column_elements(
-                    column, include_scalar_selects=False
-                )
-                if "parententity" in elem._annotations
-            ]
 
-            self.entities = util.unique_list(
-                [elem._annotations["parententity"] for elem in all_elements]
+            entity = sql_util.extract_first_column_annotation(
+                column, "parententity"
             )
 
+            if entity:
+                self.entities = [entity]
+            else:
+                self.entities = []
+
             if self.entities:
                 self.entity_zero = self.entities[0]
                 self.mapper = self.entity_zero.mapper
+
             elif self.namespace is not None:
                 self.entity_zero = self.namespace
                 self.mapper = None
@@ -4581,6 +4543,9 @@ class _ColumnEntity(_QueryEntity):
                 self.entity_zero = None
                 self.mapper = None
 
+        if self.entities and self.entity_zero.mapper.with_polymorphic:
+            query._setup_query_adapters(self.entity_zero, self.entity_zero)
+
     supports_single_entity = False
 
     def _deep_entity_zero(self):
@@ -4603,24 +4568,21 @@ class _ColumnEntity(_QueryEntity):
     def entity_zero_or_selectable(self):
         if self.entity_zero is not None:
             return self.entity_zero
-        elif self.actual_froms:
-            return list(self.actual_froms)[0]
+        elif self.column._from_objects:
+            return self.column._from_objects[0]
         else:
             return None
 
-    def adapt_to_selectable(self, query, sel):
-        c = _ColumnEntity(query, sel.corresponding_column(self.column))
+    def adapt_to_selectable(self, query, sel, parent_bundle=None):
+        c = _ColumnEntity(
+            query,
+            sel.corresponding_column(self.column),
+            parent_bundle=parent_bundle,
+        )
         c._label_name = self._label_name
         c.entity_zero = self.entity_zero
         c.entities = self.entities
 
-    def setup_entity(self, ext_info, aliased_adapter):
-        if "selectable" not in self.__dict__:
-            self.selectable = ext_info.selectable
-
-        if self.actual_froms.intersection(ext_info.selectable._from_objects):
-            self.froms.add(ext_info.selectable)
-
     def corresponds_to(self, entity):
         if self.entity_zero is None:
             return False
@@ -4651,13 +4613,32 @@ class _ColumnEntity(_QueryEntity):
 
     def setup_context(self, query, context):
         column = query._adapt_clause(self.column, False, True)
+        ezero = self.entity_zero
+
+        if self.mapper:
+            single_table_crit = self.mapper._single_table_criterion
+            if single_table_crit is not None:
+                context.single_inh_entities[ezero] = (
+                    ezero,
+                    ezero._adapter if ezero.is_aliased_class else None,
+                )
 
         if column._annotations:
             # annotated columns perform more slowly in compiler and
             # result due to the __eq__() method, so use deannotated
             column = column._deannotate()
 
-        context.froms += tuple(self.froms)
+        if ezero is not None:
+            # use entity_zero as the from if we have it. this is necessary
+            # for polymorpic scenarios where our FROM is based on ORM entity,
+            # not the FROM of the column.  but also, don't use it if our column
+            # doesn't actually have any FROMs that line up, such as when its
+            # a scalar subquery.
+            if set(self.column._from_objects).intersection(
+                ezero.selectable._from_objects
+            ):
+                context.froms += (ezero.selectable,)
+
         context.primary_columns.append(column)
 
         context.attributes[("fetch_column", self)] = column
@@ -4697,6 +4678,7 @@ class QueryContext(object):
         "partials",
         "post_load_paths",
         "identity_token",
+        "single_inh_entities",
     )
 
     def __init__(self, query):
@@ -4731,6 +4713,7 @@ class QueryContext(object):
         self.secondary_columns = []
         self.eager_order_by = []
         self.eager_joins = {}
+        self.single_inh_entities = {}
         self.create_eager_joins = []
         self.propagate_options = set(
             o for o in query._with_options if o.propagate_to_loaders
index 63ec21099a328fdf0484fbc00f9e8238631826d9..731947cbaf08df0f5bcf2fc2f15ce7d2f79de7d7 100644 (file)
@@ -2319,11 +2319,11 @@ class JoinCondition(object):
         """
 
         self.primaryjoin = _deep_deannotate(
-            self.primaryjoin, values=("parententity",)
+            self.primaryjoin, values=("parententity", "orm_key")
         )
         if self.secondaryjoin is not None:
             self.secondaryjoin = _deep_deannotate(
-                self.secondaryjoin, values=("parententity",)
+                self.secondaryjoin, values=("parententity", "orm_key")
             )
 
     def _determine_joins(self):
index 4b4fa405238570f4242429bc4ec0dacabf1771ee..747ec7e658f0b1a5f06d4519956a9143483c5ef5 100644 (file)
@@ -668,10 +668,11 @@ class AliasedInsp(InspectionAttr):
             state["represents_outer_join"],
         )
 
-    def _adapt_element(self, elem):
-        return self._adapter.traverse(elem)._annotate(
-            {"parententity": self, "parentmapper": self.mapper}
-        )
+    def _adapt_element(self, elem, key=None):
+        d = {"parententity": self, "parentmapper": self.mapper}
+        if key:
+            d["orm_key"] = key
+        return self._adapter.traverse(elem)._annotate(d)
 
     def _entity_for_mapper(self, mapper):
         self_poly = self.with_polymorphic_mappers
index a7a856bba9e78401dac4e9e80153bb2424c57d10..95aee0468fed93cd94f2dc6c6104c1930c3ca4e8 100644 (file)
@@ -57,7 +57,7 @@ def expect(role, element, **kw):
     else:
         resolved = element
 
-    if issubclass(resolved.__class__, impl._role_class):
+    if impl._role_class in resolved.__class__.__mro__:
         if impl._post_coercion:
             resolved = impl._post_coercion(resolved, **kw)
         return resolved
@@ -102,13 +102,16 @@ class RoleImpl(object):
 
     def _resolve_for_clause_element(self, element, argname=None, **kw):
         original_element = element
-        is_clause_element = False
+        is_clause_element = hasattr(element, "__clause_element__")
 
-        while hasattr(element, "__clause_element__") and not isinstance(
-            element, (elements.ClauseElement, schema.SchemaItem)
-        ):
-            element = element.__clause_element__()
-            is_clause_element = True
+        if is_clause_element:
+            while not isinstance(
+                element, (elements.ClauseElement, schema.SchemaItem)
+            ):
+                try:
+                    element = element.__clause_element__()
+                except AttributeError:
+                    break
 
         if not is_clause_element:
             if self._use_inspection:
index fe83b163cd24f505b41ef4022a7e40c0b8d96a40..3c7f904deaf6e2874462a0e11de1201ff78388eb 100644 (file)
@@ -364,23 +364,19 @@ def surface_selectables_only(clause):
             stack.append(elem.table)
 
 
-def surface_column_elements(clause, include_scalar_selects=True):
-    """traverse and yield only outer-exposed column elements, such as would
-    be addressable in the WHERE clause of a SELECT if this element were
-    in the columns clause."""
+def extract_first_column_annotation(column, annotation_name):
+    filter_ = (FromGrouping, SelectBase)
 
-    filter_ = (FromGrouping,)
-    if not include_scalar_selects:
-        filter_ += (SelectBase,)
-
-    stack = deque([clause])
+    stack = deque([column])
     while stack:
         elem = stack.popleft()
-        yield elem
+        if annotation_name in elem._annotations:
+            return elem._annotations[annotation_name]
         for sub in elem.get_children():
             if isinstance(sub, filter_):
                 continue
             stack.append(sub)
+    return None
 
 
 def selectables_overlap(left, right):
index 4e52a777895c505a1b1c48b29cd85c01deb4d20a..632f559373d9bb0c3162946cc365d71f9a62bf57 100644 (file)
@@ -560,6 +560,10 @@ class QueryTest(fixtures.MappedTest):
         self._fixture()
         sess = Session()
 
+        # warm up cache
+        for attr in [Parent.data1, Parent.data2, Parent.data3, Parent.data4]:
+            attr.__clause_element__()
+
         @profiling.function_call_count()
         def go():
             for i in range(10):
index 7b8d413a46d05532b06cdb510007f31e0a88862e..d0db76b215df0201fecf2b59624bd6e1f5fae3fe 100644 (file)
@@ -287,7 +287,7 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest):
 
         self.assert_compile(
             sess.query(literal("1")).select_from(a1),
-            "SELECT :param_1 AS param_1 FROM employees AS employees_1 "
+            "SELECT :param_1 AS anon_1 FROM employees AS employees_1 "
             "WHERE employees_1.type IN (:type_1, :type_2)",
         )
 
index 7247c859a6b4e5d8879750d91b78309824bbb650..0d679e6db52e9240076fa78423b69aed1a1c57e3 100644 (file)
@@ -20,6 +20,8 @@ from sqlalchemy.testing.schema import Table
 
 
 class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
+    __dialect__ = "default"
+
     @classmethod
     def define_tables(cls, metadata):
         Table(
@@ -311,6 +313,20 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
             [(Point(3, 4), Point(5, 6))],
         )
 
+    def test_cols_as_core_clauseelement(self):
+        Edge = self.classes.Edge
+        Point = self.classes.Point
+
+        start, end = Edge.start, Edge.end
+
+        stmt = select([start, end]).where(start == Point(3, 4))
+        self.assert_compile(
+            stmt,
+            "SELECT edges.x1, edges.y1, edges.x2, edges.y2 "
+            "FROM edges WHERE edges.x1 = :x1_1 AND edges.y1 = :y1_1",
+            checkparams={"x1_1": 3, "y1_1": 4},
+        )
+
     def test_query_cols_labeled(self):
         Edge = self.classes.Edge
         Point = self.classes.Point
index 498b68057759f88b2c23ff566399cfcd6b4ee863..efa45affa0bf77adc96454adecff127e11ce139b 100644 (file)
@@ -7,7 +7,6 @@ from sqlalchemy import exc as sa_exc
 from sqlalchemy import exists
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
-from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import literal_column
 from sqlalchemy import select
@@ -2211,7 +2210,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
             sess.expunge_all()
 
         assert_raises(
-            sa_exc.InvalidRequestError, sess.query(User).add_column, object()
+            sa_exc.ArgumentError, sess.query(User).add_column, object()
         )
 
     def test_add_multi_columns(self):
@@ -2270,7 +2269,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
             .order_by(User.id)
         )
         q = sess.query(User)
-        result = q.add_column("count").from_statement(s).all()
+        result = q.add_column(s.selected_columns.count).from_statement(s).all()
         assert result == expected
 
     def test_raw_columns(self):
@@ -2315,7 +2314,10 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
         )
         q = create_session().query(User)
         result = (
-            q.add_column("count").add_column("concat").from_statement(s).all()
+            q.add_column(s.selected_columns.count)
+            .add_column(s.selected_columns.concat)
+            .from_statement(s)
+            .all()
         )
         assert result == expected
 
@@ -2399,7 +2401,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
         ]:
             q = s.query(crit)
             mzero = q._entity_zero()
-            is_(mzero.persist_selectable, q._query_entity_zero().selectable)
+            is_(mzero, q._query_entity_zero().entity_zero)
             q = q.join(j)
             self.assert_compile(q, exp)
 
@@ -2429,7 +2431,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
         ]:
             q = s.query(crit)
             mzero = q._entity_zero()
-            is_(inspect(mzero).selectable, q._query_entity_zero().selectable)
+            is_(mzero, q._query_entity_zero().entity_zero)
             q = q.join(j)
             self.assert_compile(q, exp)
 
index 4dff6fe56d96efca52a6a7c444c7ff4e148a6334..bcd13e6e2a6ad226fdab9a0a092d1c34eaa68997 100644 (file)
@@ -1005,14 +1005,14 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL):
 
         s = create_session()
         q = s.query(User)
-        assert_raises(sa_exc.InvalidRequestError, q.add_column, object())
+        assert_raises(sa_exc.ArgumentError, q.add_column, object())
 
     def test_invalid_column_tuple(self):
         User = self.classes.User
 
         s = create_session()
         q = s.query(User)
-        assert_raises(sa_exc.InvalidRequestError, q.add_column, (1, 1))
+        assert_raises(sa_exc.ArgumentError, q.add_column, (1, 1))
 
     def test_distinct(self):
         """test that a distinct() call is not valid before 'clauseelement'
@@ -2449,6 +2449,9 @@ class ComparatorTest(QueryTest):
             def __clause_element__(self):
                 return self.expr
 
+        # this use case isn't exactly needed in this form, however it tests
+        # that we resolve for multiple __clause_element__() calls as is needed
+        # by systems like composites
         sess = Session()
         eq_(
             sess.query(Comparator(User.id))
@@ -3398,11 +3401,11 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL):
             q3,
             "SELECT anon_1.users_id AS anon_1_users_id, "
             "anon_1.users_name AS anon_1_users_name, "
-            "anon_1.param_1 AS anon_1_param_1 "
-            "FROM (SELECT users.id AS users_id, users.name AS "
-            "users_name, :param_1 AS param_1 "
-            "FROM users UNION SELECT users.id AS users_id, "
-            "users.name AS users_name, 'y' FROM users) AS anon_1",
+            "anon_1.anon_2 AS anon_1_anon_2 FROM "
+            "(SELECT users.id AS users_id, users.name AS users_name, "
+            ":param_1 AS anon_2 FROM users "
+            "UNION SELECT users.id AS users_id, users.name AS users_name, "
+            "'y' FROM users) AS anon_1",
         )
 
     def test_union_literal_expressions_results(self):
@@ -3410,7 +3413,8 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL):
 
         s = Session()
 
-        q1 = s.query(User, literal("x"))
+        x_literal = literal("x")
+        q1 = s.query(User, x_literal)
         q2 = s.query(User, literal_column("'y'"))
         q3 = q1.union(q2)
 
@@ -3421,7 +3425,7 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL):
         eq_([x["name"] for x in q6.column_descriptions], ["User", "foo"])
 
         for q in (
-            q3.order_by(User.id, text("anon_1_param_1")),
+            q3.order_by(User.id, x_literal),
             q6.order_by(User.id, "foo"),
         ):
             eq_(
@@ -4231,12 +4235,14 @@ class TextTest(QueryTest, AssertsCompiledSQL):
         User = self.classes.User
 
         s = create_session()
-        assert_raises(
-            sa_exc.InvalidRequestError, s.query, User.id, text("users.name")
+
+        self.assert_compile(
+            s.query(User.id, text("users.name")),
+            "SELECT users.id AS users_id, users.name FROM users",
         )
 
         eq_(
-            s.query(User.id, "name").order_by(User.id).all(),
+            s.query(User.id, literal_column("name")).order_by(User.id).all(),
             [(7, "jack"), (8, "ed"), (9, "fred"), (10, "chuck")],
         )
 
index b32b6547fa363149c704a27251fcb7f09298bf37..03e17d291764fc89d4ec7a55abb7d12fed0057ce 100644 (file)
@@ -3,6 +3,7 @@ from sqlalchemy import bindparam
 from sqlalchemy import ForeignKey
 from sqlalchemy import inspect
 from sqlalchemy import Integer
+from sqlalchemy import literal_column
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
@@ -792,7 +793,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
         sess = create_session()
 
         self.assert_compile(
-            sess.query(User, "1"),
+            sess.query(User, literal_column("1")),
             "SELECT users.id AS users_id, users.name AS users_name, "
             "1 FROM users",
         )
index e47fc3f267a2e426e15eb4a458afdf6ea8e5d4bc..4bc2a5c88018a8db45493a045d48832fb7217d71 100644 (file)
@@ -210,7 +210,24 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL):
         eq_(str(alias.x + 1), "point_1.x + :x_1")
         eq_(str(alias.x_alone + 1), "point_1.x + :x_1")
 
-        is_(Point.x_alone.__clause_element__(), Point.x.__clause_element__())
+        point_mapper = inspect(Point)
+
+        eq_(
+            Point.x_alone._annotations,
+            {
+                "parententity": point_mapper,
+                "parentmapper": point_mapper,
+                "orm_key": "x_alone",
+            },
+        )
+        eq_(
+            Point.x._annotations,
+            {
+                "parententity": point_mapper,
+                "parentmapper": point_mapper,
+                "orm_key": "x",
+            },
+        )
 
         eq_(str(alias.x_alone == alias.x), "point_1.x = point_1.x")
 
index 0750ab767991ee712111ba12eda2efe78beb0a12..e4b99ba6806f859fbee2955558ff26c0f0418d41 100644 (file)
 
 # TEST: test.aaa_profiling.test_compiler.CompileTest.test_insert
 
-test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mssql_pyodbc_dbapiunicode_cextensions 70,70
-test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mssql_pyodbc_dbapiunicode_nocextensions 70,70
-test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_mysqldb_dbapiunicode_cextensions 70,70,70,70
-test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_mysqldb_dbapiunicode_nocextensions 70,70,70,70
-test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_pymysql_dbapiunicode_cextensions 70,70
-test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_pymysql_dbapiunicode_nocextensions 70,70
-test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_oracle_cx_oracle_dbapiunicode_cextensions 70,70
-test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_oracle_cx_oracle_dbapiunicode_nocextensions 70,70
+test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mssql_pyodbc_dbapiunicode_cextensions 66
+test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mssql_pyodbc_dbapiunicode_nocextensions 66
+test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_mysqldb_dbapiunicode_cextensions 66
+test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_mysqldb_dbapiunicode_nocextensions 66
+test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_pymysql_dbapiunicode_cextensions 66
+test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_mysql_pymysql_dbapiunicode_nocextensions 66
+test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_oracle_cx_oracle_dbapiunicode_cextensions 66
+test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_oracle_cx_oracle_dbapiunicode_nocextensions 66
 test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_postgresql_psycopg2_dbapiunicode_cextensions 67
 test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_postgresql_psycopg2_dbapiunicode_nocextensions 67
 test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_sqlite_pysqlite_dbapiunicode_cextensions 67
 test.aaa_profiling.test_compiler.CompileTest.test_insert 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 67
-test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_mysqldb_dbapiunicode_cextensions 73,73,73,73
-test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_mysqldb_dbapiunicode_nocextensions 73,73,73,73
-test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_pymysql_dbapiunicode_cextensions 73,73
-test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_pymysql_dbapiunicode_nocextensions 73,73
+test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_mysqldb_dbapiunicode_cextensions 73
+test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_mysqldb_dbapiunicode_nocextensions 73
+test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_pymysql_dbapiunicode_cextensions 73
+test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_mysql_pymysql_dbapiunicode_nocextensions 73
 test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_oracle_cx_oracle_dbapiunicode_cextensions 73
 test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_oracle_cx_oracle_dbapiunicode_nocextensions 73
 test.aaa_profiling.test_compiler.CompileTest.test_insert 3.7_postgresql_psycopg2_dbapiunicode_cextensions 72
@@ -523,24 +523,14 @@ test.aaa_profiling.test_orm.MergeTest.test_merge_no_load 3.7_sqlite_pysqlite_dba
 
 # TEST: test.aaa_profiling.test_orm.QueryTest.test_query_cols
 
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_mssql_pyodbc_dbapiunicode_cextensions 6230
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_mssql_pyodbc_dbapiunicode_nocextensions 6780
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_mysql_mysqldb_dbapiunicode_cextensions 6290
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_mysql_mysqldb_dbapiunicode_nocextensions 6840
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_oracle_cx_oracle_dbapiunicode_cextensions 6360
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_oracle_cx_oracle_dbapiunicode_nocextensions 8190
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_postgresql_psycopg2_dbapiunicode_cextensions 6100
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_postgresql_psycopg2_dbapiunicode_nocextensions 6641
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_sqlite_pysqlite_dbapiunicode_cextensions 6035
+test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_postgresql_psycopg2_dbapiunicode_cextensions 5900
+test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_postgresql_psycopg2_dbapiunicode_nocextensions 6441
+test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_sqlite_pysqlite_dbapiunicode_cextensions 5900
 test.aaa_profiling.test_orm.QueryTest.test_query_cols 2.7_sqlite_pysqlite_dbapiunicode_nocextensions 6585
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_mysql_mysqldb_dbapiunicode_cextensions 6483
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_mysql_mysqldb_dbapiunicode_nocextensions 7143
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_oracle_cx_oracle_dbapiunicode_cextensions 6473
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_oracle_cx_oracle_dbapiunicode_nocextensions 7043
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_postgresql_psycopg2_dbapiunicode_cextensions 6464
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_postgresql_psycopg2_dbapiunicode_nocextensions 7034
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_sqlite_pysqlite_dbapiunicode_cextensions 6326
-test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_sqlite_pysqlite_dbapiunicode_nocextensions 6806
+test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_postgresql_psycopg2_dbapiunicode_cextensions 6138
+test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_postgresql_psycopg2_dbapiunicode_nocextensions 6800
+test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_sqlite_pysqlite_dbapiunicode_cextensions 6226
+test.aaa_profiling.test_orm.QueryTest.test_query_cols 3.7_sqlite_pysqlite_dbapiunicode_nocextensions 6506
 
 # TEST: test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results