]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Annotate session-bind-lookup entity in Query-produced selectables
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 29 Aug 2019 16:09:17 +0000 (12:09 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 30 Aug 2019 21:57:38 +0000 (17:57 -0400)
Added new entity-targeting capabilities to the :class:`.Query` object to
help with the case where the :class:`.Session` is using a bind dictionary
against mapped classes, rather than a single bind, and the :class:`.Query`
is against a Core statement that was ultimately generated from a method
such as :meth:`.Query.subquery`; a deep search is performed to locate
any ORM entity related to the query in order to locate a mapper if
one is not otherwise present.

Fixes: #4829
Change-Id: I95cf325a5aba21baec4b313246c6f4d692284820

doc/build/changelog/unreleased_14/4829.rst [new file with mode: 0644]
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/annotation.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/selectable.py
test/orm/test_query.py
test/sql/test_selectable.py

diff --git a/doc/build/changelog/unreleased_14/4829.rst b/doc/build/changelog/unreleased_14/4829.rst
new file mode 100644 (file)
index 0000000..93c582f
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 4829
+
+    Added new entity-targeting capabilities to the :class:`.Query` object to
+    help with the case where the :class:`.Session` is using a bind dictionary
+    against mapped classes, rather than a single bind, and the :class:`.Query`
+    is against a Core statement that was ultimately generated from a method
+    such as :meth:`.Query.subquery`; a deep search is performed to locate
+    any ORM entity related to the query in order to locate a mapper if
+    one is not otherwise present.
+
index 936929703303516515deda696bc3026472ff3ada..d4ff35d2e59f922ecd7b496c324d345c3f47968b 100644 (file)
@@ -384,6 +384,25 @@ class Query(object):
             else self._query_entity_zero().entity_zero
         )
 
+    def _deep_entity_zero(self):
+        """Return a 'deep' entity; this is any entity we can find associated
+        with the first entity / column experssion.   this is used only for
+        session.get_bind().
+
+        """
+
+        if (
+            self._select_from_entity is not None
+            and not self._select_from_entity.is_clause_element
+        ):
+            return self._select_from_entity.mapper
+        for ent in self._entities:
+            ezero = ent._deep_entity_zero()
+            if ezero is not None:
+                return ezero.mapper
+        else:
+            return None
+
     @property
     def _mapper_entities(self):
         for ent in self._entities:
@@ -394,13 +413,7 @@ class Query(object):
         return self._joinpoint.get("_joinpoint_entity", self._entity_zero())
 
     def _bind_mapper(self):
-        ezero = self._entity_zero()
-        if ezero is not None:
-            insp = inspect(ezero)
-            if not insp.is_clause_element:
-                return insp.mapper
-
-        return None
+        return self._deep_entity_zero()
 
     def _only_full_mapper_zero(self, methname):
         if self._entities != [self._primary_entity]:
@@ -3900,6 +3913,12 @@ class Query(object):
         else:
             context.statement = self._simple_statement(context)
 
+        if for_statement:
+            ezero = self._mapper_zero()
+            if ezero is not None:
+                context.statement = context.statement._annotate(
+                    {"deepentity": ezero}
+                )
         return context
 
     def _compound_eager_statement(self, context):
@@ -4161,6 +4180,9 @@ class _MapperEntity(_QueryEntity):
     def entity_zero_or_selectable(self):
         return self.entity_zero
 
+    def _deep_entity_zero(self):
+        return self.entity_zero
+
     def corresponds_to(self, entity):
         return _entity_corresponds_to(self.entity_zero, entity)
 
@@ -4430,6 +4452,14 @@ class _BundleEntity(_QueryEntity):
         else:
             return None
 
+    def _deep_entity_zero(self):
+        for ent in self._entities:
+            ezero = ent._deep_entity_zero()
+            if ezero is not None:
+                return ezero
+        else:
+            return None
+
     def adapt_to_selectable(self, query, sel):
         c = _BundleEntity(query, self.bundle, setup_entities=False)
         # c._label_name = self._label_name
@@ -4530,7 +4560,7 @@ class _ColumnEntity(_QueryEntity):
         # of FROMs for the overall expression - this helps
         # subqueries which were built from ORM constructs from
         # leaking out their entities into the main select construct
-        self.actual_froms = actual_froms = set(column._from_objects)
+        self.actual_froms = set(column._from_objects)
 
         if not search_entities:
             self.entity_zero = _entity
@@ -4540,7 +4570,6 @@ class _ColumnEntity(_QueryEntity):
             else:
                 self.entities = []
                 self.mapper = None
-            self._from_entities = set(self.entities)
         else:
             all_elements = [
                 elem
@@ -4551,21 +4580,9 @@ class _ColumnEntity(_QueryEntity):
             ]
 
             self.entities = util.unique_list(
-                [
-                    elem._annotations["parententity"]
-                    for elem in all_elements
-                    if "parententity" in elem._annotations
-                ]
+                [elem._annotations["parententity"] for elem in all_elements]
             )
 
-            self._from_entities = set(
-                [
-                    elem._annotations["parententity"]
-                    for elem in all_elements
-                    if "parententity" in elem._annotations
-                    and actual_froms.intersection(elem._from_objects)
-                ]
-            )
             if self.entities:
                 self.entity_zero = self.entities[0]
                 self.mapper = self.entity_zero.mapper
@@ -4578,6 +4595,22 @@ class _ColumnEntity(_QueryEntity):
 
     supports_single_entity = False
 
+    def _deep_entity_zero(self):
+        if self.mapper is not None:
+            return self.mapper
+
+        else:
+            for obj in visitors.iterate(
+                self.column,
+                {"column_tables": True, "column_collections": False},
+            ):
+                if "parententity" in obj._annotations:
+                    return obj._annotations["parententity"]
+                elif "deepentity" in obj._annotations:
+                    return obj._annotations["deepentity"]
+            else:
+                return None
+
     @property
     def entity_zero_or_selectable(self):
         if self.entity_zero is not None:
index 7fc9245ab5138d52add3ec988c78c278661544d8..a0264845e382cdabb77d2f9bb77b7010dd2cfd19 100644 (file)
@@ -15,8 +15,80 @@ from . import operators
 from .. import util
 
 
+class SupportsCloneAnnotations(object):
+    _annotations = util.immutabledict()
+
+    def _annotate(self, values):
+        """return a copy of this ClauseElement with annotations
+        updated by the given dictionary.
+
+        """
+        new = self._clone()
+        new._annotations = new._annotations.union(values)
+        return new
+
+    def _with_annotations(self, values):
+        """return a copy of this ClauseElement with annotations
+        replaced by the given dictionary.
+
+        """
+        new = self._clone()
+        new._annotations = util.immutabledict(values)
+        return new
+
+    def _deannotate(self, values=None, clone=False):
+        """return a copy of this :class:`.ClauseElement` with annotations
+        removed.
+
+        :param values: optional tuple of individual values
+         to remove.
+
+        """
+        if clone or self._annotations:
+            # clone is used when we are also copying
+            # the expression for a deep deannotation
+            new = self._clone()
+            new._annotations = {}
+            return new
+        else:
+            return self
+
+
+class SupportsWrappingAnnotations(object):
+    def _annotate(self, values):
+        """return a copy of this ClauseElement with annotations
+        updated by the given dictionary.
+
+        """
+        return Annotated(self, values)
+
+    def _with_annotations(self, values):
+        """return a copy of this ClauseElement with annotations
+        replaced by the given dictionary.
+
+        """
+        return Annotated(self, values)
+
+    def _deannotate(self, values=None, clone=False):
+        """return a copy of this :class:`.ClauseElement` with annotations
+        removed.
+
+        :param values: optional tuple of individual values
+         to remove.
+
+        """
+        if clone:
+            # clone is used when we are also copying
+            # the expression for a deep deannotation
+            return self._clone()
+        else:
+            # if no clone, since we have no annotations we return
+            # self
+            return self
+
+
 class Annotated(object):
-    """clones a ClauseElement and applies an 'annotations' dictionary.
+    """clones a SupportsAnnotated and applies an 'annotations' dictionary.
 
     Unlike regular clones, this clone also mimics __hash__() and
     __cmp__() of the original element so that it takes its place
index e2df1adc2d5843463a450052ea8f9bc6bdd7e19b..19d26f138282d41c5ae4ef2eebde77291d8294b0 100644 (file)
@@ -22,6 +22,7 @@ from . import operators
 from . import roles
 from . import type_api
 from .annotation import Annotated
+from .annotation import SupportsWrappingAnnotations
 from .base import _clone
 from .base import _generative
 from .base import Executable
@@ -161,7 +162,7 @@ def not_(clause):
 
 
 @inspection._self_inspects
-class ClauseElement(roles.SQLRole, Visitable):
+class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
     """Base class for elements of a programmatically constructed SQL
     expression.
 
@@ -267,37 +268,6 @@ class ClauseElement(roles.SQLRole, Visitable):
         d.pop("_is_clone_of", None)
         return d
 
-    def _annotate(self, values):
-        """return a copy of this ClauseElement with annotations
-        updated by the given dictionary.
-
-        """
-        return Annotated(self, values)
-
-    def _with_annotations(self, values):
-        """return a copy of this ClauseElement with annotations
-        replaced by the given dictionary.
-
-        """
-        return Annotated(self, values)
-
-    def _deannotate(self, values=None, clone=False):
-        """return a copy of this :class:`.ClauseElement` with annotations
-        removed.
-
-        :param values: optional tuple of individual values
-         to remove.
-
-        """
-        if clone:
-            # clone is used when we are also copying
-            # the expression for a deep deannotation
-            return self._clone()
-        else:
-            # if no clone, since we have no annotations we return
-            # self
-            return self
-
     def _execute_on_connection(self, connection, multiparams, params):
         if self.supports_execution:
             return connection._execute_clauseelement(self, multiparams, params)
@@ -4136,6 +4106,12 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement):
         self._memoized_property.expire_instance(self)
         self.__dict__["table"] = table
 
+    def get_children(self, column_tables=False, **kw):
+        if column_tables and self.table is not None:
+            return [self.table]
+        else:
+            return []
+
     table = property(_get_table, _set_table)
 
     def _cache_key(self, **kw):
index 03dbcd449acb9a396c50b9e2cae6fae6dd893779..97c49f8fcc98e6b888ab80b5122172ac90519f30 100644 (file)
@@ -19,6 +19,7 @@ from . import operators
 from . import roles
 from . import type_api
 from .annotation import Annotated
+from .annotation import SupportsCloneAnnotations
 from .base import _clone
 from .base import _cloned_difference
 from .base import _cloned_intersection
@@ -2068,6 +2069,7 @@ class SelectBase(
     roles.InElementRole,
     HasCTE,
     Executable,
+    SupportsCloneAnnotations,
     Selectable,
 ):
     """Base class for SELECT statements.
index f5283ff443ae24b15b78f48837c90ab10b99113f..4dff6fe56d96efca52a6a7c444c7ff4e148a6334 100644 (file)
@@ -139,6 +139,23 @@ class RowTupleTest(QueryTest):
         assert row.id == 7
         assert row.uname == "jack"
 
+    def test_deep_entity(self):
+        users, User = (self.tables.users, self.classes.User)
+
+        mapper(User, users)
+
+        sess = create_session()
+        bundle = Bundle("b1", User.id, User.name)
+        subq1 = sess.query(User.id).subquery()
+        subq2 = sess.query(bundle).subquery()
+        cte = sess.query(User.id).cte()
+        ex = sess.query(User).exists()
+
+        is_(sess.query(subq1)._deep_entity_zero(), inspect(User))
+        is_(sess.query(subq2)._deep_entity_zero(), inspect(User))
+        is_(sess.query(cte)._deep_entity_zero(), inspect(User))
+        is_(sess.query(ex)._deep_entity_zero(), inspect(User))
+
     def test_column_metadata(self):
         users, Address, addresses, User = (
             self.tables.users,
@@ -156,6 +173,8 @@ class RowTupleTest(QueryTest):
         fn = func.count(User.id)
         name_label = User.name.label("uname")
         bundle = Bundle("b1", User.id, User.name)
+        subq1 = sess.query(User.id).subquery()
+        subq2 = sess.query(bundle).subquery()
         cte = sess.query(User.id).cte()
         for q, asserted in [
             (
@@ -275,6 +294,30 @@ class RowTupleTest(QueryTest):
                     }
                 ],
             ),
+            (
+                sess.query(subq1.c.id),
+                [
+                    {
+                        "aliased": False,
+                        "expr": subq1.c.id,
+                        "type": subq1.c.id.type,
+                        "name": "id",
+                        "entity": None,
+                    }
+                ],
+            ),
+            (
+                sess.query(subq2.c.id),
+                [
+                    {
+                        "aliased": False,
+                        "expr": subq2.c.id,
+                        "type": subq2.c.id.type,
+                        "name": "id",
+                        "entity": None,
+                    }
+                ],
+            ),
             (
                 sess.query(users),
                 [
@@ -5518,12 +5561,15 @@ class BooleanEvalTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
 class SessionBindTest(QueryTest):
     @contextlib.contextmanager
-    def _assert_bind_args(self, session):
+    def _assert_bind_args(self, session, expect_mapped_bind=True):
         get_bind = mock.Mock(side_effect=session.get_bind)
         with mock.patch.object(session, "get_bind", get_bind):
             yield
         for call_ in get_bind.mock_calls:
-            is_(call_[1][0], inspect(self.classes.User))
+            if expect_mapped_bind:
+                is_(call_[1][0], inspect(self.classes.User))
+            else:
+                is_(call_[1][0], None)
             is_not_(call_[2]["clause"], None)
 
     def test_single_entity_q(self):
@@ -5532,12 +5578,43 @@ class SessionBindTest(QueryTest):
         with self._assert_bind_args(session):
             session.query(User).all()
 
+    def test_aliased_entity_q(self):
+        User = self.classes.User
+        u = aliased(User)
+        session = Session()
+        with self._assert_bind_args(session):
+            session.query(u).all()
+
     def test_sql_expr_entity_q(self):
         User = self.classes.User
         session = Session()
         with self._assert_bind_args(session):
             session.query(User.id).all()
 
+    def test_sql_expr_subquery_from_entity(self):
+        User = self.classes.User
+        session = Session()
+        with self._assert_bind_args(session):
+            subq = session.query(User.id).subquery()
+            session.query(subq).all()
+
+    def test_sql_expr_cte_from_entity(self):
+        User = self.classes.User
+        session = Session()
+        with self._assert_bind_args(session):
+            cte = session.query(User.id).cte()
+            subq = session.query(cte).subquery()
+            session.query(subq).all()
+
+    def test_sql_expr_bundle_cte_from_entity(self):
+        User = self.classes.User
+        session = Session()
+        with self._assert_bind_args(session):
+            cte = session.query(User.id, User.name).cte()
+            subq = session.query(cte).subquery()
+            bundle = Bundle(subq.c.id, subq.c.name)
+            session.query(bundle).all()
+
     def test_count(self):
         User = self.classes.User
         session = Session()
@@ -5594,6 +5671,35 @@ class SessionBindTest(QueryTest):
         with self._assert_bind_args(session):
             session.query(func.max(User.score)).scalar()
 
+    def test_plain_table(self):
+        User = self.classes.User
+
+        session = Session()
+        with self._assert_bind_args(session, expect_mapped_bind=False):
+            session.query(inspect(User).local_table).all()
+
+    def test_plain_table_from_self(self):
+        User = self.classes.User
+
+        session = Session()
+        with self._assert_bind_args(session, expect_mapped_bind=False):
+            session.query(inspect(User).local_table).from_self().all()
+
+    def test_plain_table_count(self):
+        User = self.classes.User
+
+        session = Session()
+        with self._assert_bind_args(session, expect_mapped_bind=False):
+            session.query(inspect(User).local_table).count()
+
+    def test_plain_table_select_from(self):
+        User = self.classes.User
+
+        table = inspect(User).local_table
+        session = Session()
+        with self._assert_bind_args(session, expect_mapped_bind=False):
+            session.query(table).select_from(table).all()
+
     @testing.requires.nested_aggregates
     def test_column_property_select(self):
         User = self.classes.User
index 189436192db53e6277dc858581e7142f50b9dd4f..c54f27c23b3629c0d9935da73c5ef577aba6aab8 100644 (file)
@@ -41,7 +41,10 @@ from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import AssertsExecutionResults
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import in_
 from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_not_
+from sqlalchemy.testing import ne_
 
 
 metadata = MetaData()
@@ -2196,12 +2199,21 @@ class AnnotationsTest(fixtures.TestBase):
         t = table("t", column("x"))
 
         a = t.alias()
+
+        for obj in [t, t.c.x, a, t.c.x > 1, (t.c.x > 1).label(None)]:
+            annot = obj._annotate({})
+            eq_(set([obj]), set([annot]))
+
+    def test_clone_annotations_dont_hash(self):
+        t = table("t", column("x"))
+
         s = t.select()
+        a = t.alias()
         s2 = a.select()
 
-        for obj in [t, t.c.x, a, s, s2, t.c.x > 1, (t.c.x > 1).label(None)]:
+        for obj in [s, s2]:
             annot = obj._annotate({})
-            eq_(set([obj]), set([annot]))
+            ne_(set([obj]), set([annot]))
 
     def test_compare(self):
         t = table("t", column("x"), column("y"))
@@ -2423,7 +2435,7 @@ class AnnotationsTest(fixtures.TestBase):
                 expected,
             )
 
-    def test_deannotate(self):
+    def test_deannotate_wrapping(self):
         table1 = table("table1", column("col1"), column("col2"))
 
         bin_ = table1.c.col1 == bindparam("foo", value=None)
@@ -2433,7 +2445,7 @@ class AnnotationsTest(fixtures.TestBase):
         b4 = sql_util._deep_deannotate(bin_)
 
         for elem in (b2._annotations, b2.left._annotations):
-            assert "_orm_adapt" in elem
+            in_("_orm_adapt", elem)
 
         for elem in (
             b3._annotations,
@@ -2441,17 +2453,47 @@ class AnnotationsTest(fixtures.TestBase):
             b4._annotations,
             b4.left._annotations,
         ):
-            assert elem == {}
+            eq_(elem, {})
 
-        assert b2.left is not bin_.left
-        assert b3.left is not b2.left and b2.left is not bin_.left
-        assert b4.left is bin_.left  # since column is immutable
+        is_not_(b2.left, bin_.left)
+        is_not_(b3.left, b2.left)
+        is_not_(b2.left, bin_.left)
+        is_(b4.left, bin_.left)  # since column is immutable
         # deannotate copies the element
-        assert (
-            bin_.right is not b2.right
-            and b2.right is not b3.right
-            and b3.right is not b4.right
+        is_not_(bin_.right, b2.right)
+        is_not_(b2.right, b3.right)
+        is_not_(b3.right, b4.right)
+
+    def test_deannotate_clone(self):
+        table1 = table("table1", column("col1"), column("col2"))
+
+        subq = (
+            select([table1])
+            .where(table1.c.col1 == bindparam("foo"))
+            .subquery()
         )
+        stmt = select([subq])
+
+        s2 = sql_util._deep_annotate(stmt, {"_orm_adapt": True})
+        s3 = sql_util._deep_deannotate(s2)
+        s4 = sql_util._deep_deannotate(s3)
+
+        eq_(stmt._annotations, {})
+        eq_(subq._annotations, {})
+
+        eq_(s2._annotations, {"_orm_adapt": True})
+        eq_(s3._annotations, {})
+        eq_(s4._annotations, {})
+
+        # select._raw_columns[0] is the subq object
+        eq_(s2._raw_columns[0]._annotations, {"_orm_adapt": True})
+        eq_(s3._raw_columns[0]._annotations, {})
+        eq_(s4._raw_columns[0]._annotations, {})
+
+        is_not_(s3, s2)
+        is_not_(s4, s3)  # deep deannotate makes a clone unconditionally
+
+        is_(s3._deannotate(), s3)  # regular deannotate returns same object
 
     def test_annotate_unique_traversal(self):
         """test that items are copied only once during