]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- restore r611883ffb35ca6664649f6328ae8 with additional fixes and an additional test
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 19 Jan 2015 23:31:10 +0000 (18:31 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 19 Jan 2015 23:31:10 +0000 (18:31 -0500)
that is much more specific to #1326

doc/build/changelog/changelog_10.rst
doc/build/changelog/migration_10.rst
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/orm/query.py
test/orm/test_query.py

index 089c9fafb356a3abd00133a4867316268eb289fa..79e43e6a37612470948869ea85ca8806a185fedf 100644 (file)
     series as well.  For changes that are specific to 1.0 with an emphasis
     on compatibility concerns, see :doc:`/changelog/migration_10`.
 
+    .. change::
+        :tags: bug, orm
+        :tickets: 3227, 3242, 1326
+
+        The primary :class:`.Mapper` of a :class:`.Query` is now passed to the
+        :meth:`.Session.get_bind` method when calling upon
+        :meth:`.Query.count`, :meth:`.Query.update`, :meth:`.Query.delete`,
+        as well as queries against mapped columns,
+        :obj:`.column_property` objects, and SQL functions and expressions
+        derived from mapped columns.   This allows sessions that rely upon
+        either customized :meth:`.Session.get_bind` schemes or "bound" metadata
+        to work in all relevant cases.
+
+        .. seealso::
+
+            :ref:`bug_3227`
+
     .. change::
         :tags: enhancement, sql
         :tickets: 3074
index bd878f4cb2bcb1b43c15f59c6388df363a62dc22..c0369d8b868b202dfb4d136bd9f7d09f7029ccfb 100644 (file)
@@ -381,6 +381,59 @@ of inheritance-oriented scenarios, including:
 
 :ticket:`3035`
 
+
+.. _bug_3227:
+
+Session.get_bind() will receive the Mapper in all relevant Query cases
+-----------------------------------------------------------------------
+
+A series of issues were repaired where the :meth:`.Session.get_bind`
+would not receive the primary :class:`.Mapper` of the :class:`.Query`,
+even though this mapper was readily available (the primary mapper is the
+single mapper, or alternatively the first mapper, that is associated with
+a :class:`.Query` object).
+
+The :class:`.Mapper` object, when passed to :meth:`.Session.get_bind`,
+is typically used by sessions that make use of the
+:paramref:`.Session.binds` parameter to associate mappers with a
+series of engines (although in this use case, things frequently
+"worked" in most cases anyway as the bind would be located via the
+mapped table object), or more specifically implement a user-defined
+:meth:`.Session.get_bind` method that provies some pattern of
+selecting engines based on mappers, such as horizontal sharding or a
+so-called "routing" session that routes queries to different backends.
+
+These scenarios include:
+
+* :meth:`.Query.count`::
+
+        session.query(User).count()
+
+* :meth:`.Query.update` and :meth:`.Query.delete`, both for the UPDATE/DELETE
+  statement as well as for the SELECT used by the "fetch" strategy::
+
+        session.query(User).filter(User.id == 15).update(
+                {"name": "foob"}, synchronize_session='fetch')
+
+        session.query(User).filter(User.id == 15).delete(
+                synchronize_session='fetch')
+
+* Queries against individual columns::
+
+        session.query(User.id, User.name).all()
+
+* SQL functions and other expressions against indirect mappings such as
+  :obj:`.column_property`::
+
+        class User(Base):
+            # ...
+
+            score = column_property(func.coalesce(self.tables.users.c.name, None)))
+
+        session.query(func.max(User.score)).scalar()
+
+:ticket:`3227` :ticket:`3242` :ticket:`1326`
+
 .. _feature_2963:
 
 .info dictionary improvements
index e553f399de6ee73561069fb1b114b774ebc624bc..c3b2d7bcbbb77e1ca3dd255aeb5fdd1a6586847f 100644 (file)
@@ -1030,6 +1030,7 @@ class BulkUD(object):
 
     def __init__(self, query):
         self.query = query.enable_eagerloads(False)
+        self.mapper = self.query._bind_mapper()
 
     @property
     def session(self):
@@ -1124,6 +1125,7 @@ class BulkFetch(BulkUD):
             self.primary_table.primary_key)
         self.matched_rows = session.execute(
             select_stmt,
+            mapper=self.mapper,
             params=query._params).fetchall()
 
 
@@ -1134,7 +1136,6 @@ class BulkUpdate(BulkUD):
         super(BulkUpdate, self).__init__(query)
         self.query._no_select_modifiers("update")
         self.values = values
-        self.mapper = self.query._mapper_zero_or_none()
 
     @classmethod
     def factory(cls, query, synchronize_session, values):
@@ -1180,7 +1181,8 @@ class BulkUpdate(BulkUD):
                                  self.context.whereclause, values)
 
         self.result = self.query.session.execute(
-            update_stmt, params=self.query._params)
+            update_stmt, params=self.query._params,
+            mapper=self.mapper)
         self.rowcount = self.result.rowcount
 
     def _do_post(self):
@@ -1207,8 +1209,10 @@ class BulkDelete(BulkUD):
         delete_stmt = sql.delete(self.primary_table,
                                  self.context.whereclause)
 
-        self.result = self.query.session.execute(delete_stmt,
-                                                 params=self.query._params)
+        self.result = self.query.session.execute(
+            delete_stmt,
+            params=self.query._params,
+            mapper=self.mapper)
         self.rowcount = self.result.rowcount
 
     def _do_post(self):
index 7302574e62a0beafc514d5cf828c144a790ee2a6..60a6379521c04cc18eb67dbf2d66916cbd1a0300 100644 (file)
@@ -160,7 +160,6 @@ class Query(object):
 
         for from_obj in obj:
             info = inspect(from_obj)
-
             if hasattr(info, 'mapper') and \
                     (info.is_mapper or info.is_aliased_class):
                 self._select_from_entity = from_obj
@@ -286,8 +285,9 @@ class Query(object):
         return self._entities[0]
 
     def _mapper_zero(self):
-        return self._select_from_entity or \
-            self._entity_zero().entity_zero
+        return self._select_from_entity \
+            if self._select_from_entity is not None \
+            else self._entity_zero().entity_zero
 
     @property
     def _mapper_entities(self):
@@ -301,11 +301,14 @@ class Query(object):
             self._mapper_zero()
         )
 
-    def _mapper_zero_or_none(self):
-        if self._primary_entity:
-            return self._primary_entity.mapper
-        else:
-            return None
+    def _bind_mapper(self):
+        ezero = self._mapper_zero()
+        if ezero is not None:
+            insp = inspect(ezero)
+            if hasattr(insp, 'mapper'):
+                return insp.mapper
+
+        return None
 
     def _only_mapper_zero(self, rationale=None):
         if len(self._entities) > 1:
@@ -988,6 +991,7 @@ class Query(object):
             statement.correlate(None)
         q = self._from_selectable(fromclause)
         q._enable_single_crit = False
+        q._select_from_entity = self._mapper_zero()
         if entities:
             q._set_entities(entities)
         return q
@@ -2526,7 +2530,7 @@ class Query(object):
 
     def _execute_and_instances(self, querycontext):
         conn = self._connection_from_session(
-            mapper=self._mapper_zero_or_none(),
+            mapper=self._bind_mapper(),
             clause=querycontext.statement,
             close_with_result=True)
 
@@ -3592,15 +3596,26 @@ class _ColumnEntity(_QueryEntity):
         # leaking out their entities into the main select construct
         self.actual_froms = actual_froms = set(column._from_objects)
 
-        self.entities = util.OrderedSet(
+        all_elements = [
+            elem for elem in visitors.iterate(column, {})
+            if 'parententity' in elem._annotations
+        ]
+
+        self.entities = util.unique_list(
+            elem._annotations['parententity']
+            for elem in all_elements
+            if 'parententity' in elem._annotations
+        )
+
+        self._from_entities = set(
             elem._annotations['parententity']
-            for elem in visitors.iterate(column, {})
+            for elem in all_elements
             if 'parententity' in elem._annotations
             and actual_froms.intersection(elem._from_objects)
         )
 
         if self.entities:
-            self.entity_zero = list(self.entities)[0]
+            self.entity_zero = self.entities[0]
         elif self.namespace is not None:
             self.entity_zero = self.namespace
         else:
@@ -3626,7 +3641,9 @@ class _ColumnEntity(_QueryEntity):
     def setup_entity(self, ext_info, aliased_adapter):
         if 'selectable' not in self.__dict__:
             self.selectable = ext_info.selectable
-        self.froms.add(ext_info.selectable)
+
+        if self.actual_froms.intersection(ext_info.selectable._from_objects):
+            self.froms.add(ext_info.selectable)
 
     def corresponds_to(self, entity):
         # TODO: just returning False here,
index af6d960f59aef0b84b5380697222da2bd2f83734..a2a1ee09614d5e3911e0aa9fadb0bf768453c9cd 100644 (file)
@@ -3224,8 +3224,9 @@ class SessionBindTest(QueryTest):
         get_bind = mock.Mock(side_effect=session.get_bind)
         with mock.patch.object(session, "get_bind", get_bind):
             yield
-        is_(get_bind.mock_calls[0][1][0], inspect(self.classes.User))
-        is_not_(get_bind.mock_calls[0][2]['clause'], None)
+        for call_ in get_bind.mock_calls:
+            is_(call_[1][0], inspect(self.classes.User))
+            is_not_(call_[2]['clause'], None)
 
     def test_single_entity_q(self):
         User = self.classes.User
@@ -3251,20 +3252,34 @@ class SessionBindTest(QueryTest):
         with self._assert_bind_args(session):
             session.query(func.max(User.name)).all()
 
-    def test_bulk_update(self):
+    def test_bulk_update_no_sync(self):
         User = self.classes.User
         session = Session()
         with self._assert_bind_args(session):
             session.query(User).filter(User.id == 15).update(
                 {"name": "foob"}, synchronize_session=False)
 
-    def test_bulk_delete(self):
+    def test_bulk_delete_no_sync(self):
         User = self.classes.User
         session = Session()
         with self._assert_bind_args(session):
             session.query(User).filter(User.id == 15).delete(
                 synchronize_session=False)
 
+    def test_bulk_update_fetch_sync(self):
+        User = self.classes.User
+        session = Session()
+        with self._assert_bind_args(session):
+            session.query(User).filter(User.id == 15).update(
+                {"name": "foob"}, synchronize_session='fetch')
+
+    def test_bulk_delete_fetch_sync(self):
+        User = self.classes.User
+        session = Session()
+        with self._assert_bind_args(session):
+            session.query(User).filter(User.id == 15).delete(
+                synchronize_session='fetch')
+
     def test_column_property(self):
         User = self.classes.User
 
@@ -3275,3 +3290,21 @@ class SessionBindTest(QueryTest):
         session = Session()
         with self._assert_bind_args(session):
             session.query(func.max(User.score)).scalar()
+
+    def test_column_property_select(self):
+        User = self.classes.User
+        Address = self.classes.Address
+
+        mapper = inspect(User)
+        mapper.add_property(
+            "score",
+            column_property(
+                select([func.sum(Address.id)]).
+                where(Address.user_id == User.id).as_scalar()
+            )
+        )
+        session = Session()
+
+        with self._assert_bind_args(session):
+            session.query(func.max(User.score)).scalar()
+