]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- The primary :class:`.Mapper` of a :class:`.Query` is now passed to the
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 19 Jan 2015 22:55:23 +0000 (17:55 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 19 Jan 2015 22:55:23 +0000 (17:55 -0500)
:meth:`.Session.get_bind` method when calling upon
:meth:`.Query.count`, :meth:`.Query.update`, :meth:`.Query.delete`,
as well as queries against mapped columns,
:obj:`.column_property` objects, and SQL functions and expressions
derived from mapped columns.   This allows sessions that rely upon
either customized :meth:`.Session.get_bind` schemes or "bound" metadata
to work in all relevant cases.
fixes #3227 fixes #3242 fixes #1326

doc/build/changelog/changelog_10.rst
doc/build/changelog/migration_10.rst
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/orm/query.py
test/orm/test_query.py

index 089c9fafb356a3abd00133a4867316268eb289fa..79e43e6a37612470948869ea85ca8806a185fedf 100644 (file)
     series as well.  For changes that are specific to 1.0 with an emphasis
     on compatibility concerns, see :doc:`/changelog/migration_10`.
 
+    .. change::
+        :tags: bug, orm
+        :tickets: 3227, 3242, 1326
+
+        The primary :class:`.Mapper` of a :class:`.Query` is now passed to the
+        :meth:`.Session.get_bind` method when calling upon
+        :meth:`.Query.count`, :meth:`.Query.update`, :meth:`.Query.delete`,
+        as well as queries against mapped columns,
+        :obj:`.column_property` objects, and SQL functions and expressions
+        derived from mapped columns.   This allows sessions that rely upon
+        either customized :meth:`.Session.get_bind` schemes or "bound" metadata
+        to work in all relevant cases.
+
+        .. seealso::
+
+            :ref:`bug_3227`
+
     .. change::
         :tags: enhancement, sql
         :tickets: 3074
index bd878f4cb2bcb1b43c15f59c6388df363a62dc22..c0369d8b868b202dfb4d136bd9f7d09f7029ccfb 100644 (file)
@@ -381,6 +381,59 @@ of inheritance-oriented scenarios, including:
 
 :ticket:`3035`
 
+
+.. _bug_3227:
+
+Session.get_bind() will receive the Mapper in all relevant Query cases
+-----------------------------------------------------------------------
+
+A series of issues were repaired where the :meth:`.Session.get_bind`
+would not receive the primary :class:`.Mapper` of the :class:`.Query`,
+even though this mapper was readily available (the primary mapper is the
+single mapper, or alternatively the first mapper, that is associated with
+a :class:`.Query` object).
+
+The :class:`.Mapper` object, when passed to :meth:`.Session.get_bind`,
+is typically used by sessions that make use of the
+:paramref:`.Session.binds` parameter to associate mappers with a
+series of engines (although in this use case, things frequently
+"worked" in most cases anyway as the bind would be located via the
+mapped table object), or more specifically implement a user-defined
+:meth:`.Session.get_bind` method that provies some pattern of
+selecting engines based on mappers, such as horizontal sharding or a
+so-called "routing" session that routes queries to different backends.
+
+These scenarios include:
+
+* :meth:`.Query.count`::
+
+        session.query(User).count()
+
+* :meth:`.Query.update` and :meth:`.Query.delete`, both for the UPDATE/DELETE
+  statement as well as for the SELECT used by the "fetch" strategy::
+
+        session.query(User).filter(User.id == 15).update(
+                {"name": "foob"}, synchronize_session='fetch')
+
+        session.query(User).filter(User.id == 15).delete(
+                synchronize_session='fetch')
+
+* Queries against individual columns::
+
+        session.query(User.id, User.name).all()
+
+* SQL functions and other expressions against indirect mappings such as
+  :obj:`.column_property`::
+
+        class User(Base):
+            # ...
+
+            score = column_property(func.coalesce(self.tables.users.c.name, None)))
+
+        session.query(func.max(User.score)).scalar()
+
+:ticket:`3227` :ticket:`3242` :ticket:`1326`
+
 .. _feature_2963:
 
 .info dictionary improvements
index e553f399de6ee73561069fb1b114b774ebc624bc..c3b2d7bcbbb77e1ca3dd255aeb5fdd1a6586847f 100644 (file)
@@ -1030,6 +1030,7 @@ class BulkUD(object):
 
     def __init__(self, query):
         self.query = query.enable_eagerloads(False)
+        self.mapper = self.query._bind_mapper()
 
     @property
     def session(self):
@@ -1124,6 +1125,7 @@ class BulkFetch(BulkUD):
             self.primary_table.primary_key)
         self.matched_rows = session.execute(
             select_stmt,
+            mapper=self.mapper,
             params=query._params).fetchall()
 
 
@@ -1134,7 +1136,6 @@ class BulkUpdate(BulkUD):
         super(BulkUpdate, self).__init__(query)
         self.query._no_select_modifiers("update")
         self.values = values
-        self.mapper = self.query._mapper_zero_or_none()
 
     @classmethod
     def factory(cls, query, synchronize_session, values):
@@ -1180,7 +1181,8 @@ class BulkUpdate(BulkUD):
                                  self.context.whereclause, values)
 
         self.result = self.query.session.execute(
-            update_stmt, params=self.query._params)
+            update_stmt, params=self.query._params,
+            mapper=self.mapper)
         self.rowcount = self.result.rowcount
 
     def _do_post(self):
@@ -1207,8 +1209,10 @@ class BulkDelete(BulkUD):
         delete_stmt = sql.delete(self.primary_table,
                                  self.context.whereclause)
 
-        self.result = self.query.session.execute(delete_stmt,
-                                                 params=self.query._params)
+        self.result = self.query.session.execute(
+            delete_stmt,
+            params=self.query._params,
+            mapper=self.mapper)
         self.rowcount = self.result.rowcount
 
     def _do_post(self):
index 7302574e62a0beafc514d5cf828c144a790ee2a6..cd8b0efbef161417d2c6ea6219b6a24365638b43 100644 (file)
@@ -146,7 +146,7 @@ class Query(object):
                         ext_info,
                         aliased_adapter
                     )
-                ent.setup_entity(*d[entity])
+                ent.setup_entity(ent, *d[entity])
 
     def _mapper_loads_polymorphically_with(self, mapper, adapter):
         for m2 in mapper._with_polymorphic_mappers or [mapper]:
@@ -160,7 +160,6 @@ class Query(object):
 
         for from_obj in obj:
             info = inspect(from_obj)
-
             if hasattr(info, 'mapper') and \
                     (info.is_mapper or info.is_aliased_class):
                 self._select_from_entity = from_obj
@@ -286,8 +285,9 @@ class Query(object):
         return self._entities[0]
 
     def _mapper_zero(self):
-        return self._select_from_entity or \
-            self._entity_zero().entity_zero
+        return self._select_from_entity \
+            if self._select_from_entity is not None \
+            else self._entity_zero().entity_zero
 
     @property
     def _mapper_entities(self):
@@ -301,11 +301,14 @@ class Query(object):
             self._mapper_zero()
         )
 
-    def _mapper_zero_or_none(self):
-        if self._primary_entity:
-            return self._primary_entity.mapper
-        else:
-            return None
+    def _bind_mapper(self):
+        ezero = self._mapper_zero()
+        if ezero is not None:
+            insp = inspect(ezero)
+            if hasattr(insp, 'mapper'):
+                return insp.mapper
+
+        return None
 
     def _only_mapper_zero(self, rationale=None):
         if len(self._entities) > 1:
@@ -988,6 +991,7 @@ class Query(object):
             statement.correlate(None)
         q = self._from_selectable(fromclause)
         q._enable_single_crit = False
+        q._select_from_entity = self._mapper_zero()
         if entities:
             q._set_entities(entities)
         return q
@@ -2526,7 +2530,7 @@ class Query(object):
 
     def _execute_and_instances(self, querycontext):
         conn = self._connection_from_session(
-            mapper=self._mapper_zero_or_none(),
+            mapper=self._bind_mapper(),
             clause=querycontext.statement,
             close_with_result=True)
 
@@ -3160,7 +3164,7 @@ class _MapperEntity(_QueryEntity):
 
     supports_single_entity = True
 
-    def setup_entity(self, ext_info, aliased_adapter):
+    def setup_entity(self, original_entity, ext_info, aliased_adapter):
         self.mapper = ext_info.mapper
         self.aliased_adapter = aliased_adapter
         self.selectable = ext_info.selectable
@@ -3507,9 +3511,9 @@ class _BundleEntity(_QueryEntity):
         for ent in self._entities:
             ent.adapt_to_selectable(c, sel)
 
-    def setup_entity(self, ext_info, aliased_adapter):
+    def setup_entity(self, original_entity, ext_info, aliased_adapter):
         for ent in self._entities:
-            ent.setup_entity(ext_info, aliased_adapter)
+            ent.setup_entity(original_entity, ext_info, aliased_adapter)
 
     def setup_context(self, query, context):
         for ent in self._entities:
@@ -3592,15 +3596,23 @@ class _ColumnEntity(_QueryEntity):
         # leaking out their entities into the main select construct
         self.actual_froms = actual_froms = set(column._from_objects)
 
-        self.entities = util.OrderedSet(
-            elem._annotations['parententity']
-            for elem in visitors.iterate(column, {})
+        all_elements = [
+            elem for elem in visitors.iterate(column, {})
             if 'parententity' in elem._annotations
-            and actual_froms.intersection(elem._from_objects)
+        ]
+
+        self.entities = util.unique_list([
+            elem._annotations['parententity']
+            for elem in all_elements
+        ])
+        self._from_entities = set(
+            elem._annotations['parententity']
+            for elem in all_elements
+            if actual_froms.intersection(elem._from_objects)
         )
 
         if self.entities:
-            self.entity_zero = list(self.entities)[0]
+            self.entity_zero = self.entities[0]
         elif self.namespace is not None:
             self.entity_zero = self.namespace
         else:
@@ -3623,10 +3635,12 @@ class _ColumnEntity(_QueryEntity):
         c.entity_zero = self.entity_zero
         c.entities = self.entities
 
-    def setup_entity(self, ext_info, aliased_adapter):
+    def setup_entity(self, original_entity, ext_info, aliased_adapter):
         if 'selectable' not in self.__dict__:
             self.selectable = ext_info.selectable
-        self.froms.add(ext_info.selectable)
+
+        if original_entity in self._from_entities:
+            self.froms.add(ext_info.selectable)
 
     def corresponds_to(self, entity):
         # TODO: just returning False here,
index af6d960f59aef0b84b5380697222da2bd2f83734..8639dde740025cdbe9310d8889c418523babe5f8 100644 (file)
@@ -3224,8 +3224,9 @@ class SessionBindTest(QueryTest):
         get_bind = mock.Mock(side_effect=session.get_bind)
         with mock.patch.object(session, "get_bind", get_bind):
             yield
-        is_(get_bind.mock_calls[0][1][0], inspect(self.classes.User))
-        is_not_(get_bind.mock_calls[0][2]['clause'], None)
+        for call_ in get_bind.mock_calls:
+            is_(call_[1][0], inspect(self.classes.User))
+            is_not_(call_[2]['clause'], None)
 
     def test_single_entity_q(self):
         User = self.classes.User
@@ -3251,20 +3252,34 @@ class SessionBindTest(QueryTest):
         with self._assert_bind_args(session):
             session.query(func.max(User.name)).all()
 
-    def test_bulk_update(self):
+    def test_bulk_update_no_sync(self):
         User = self.classes.User
         session = Session()
         with self._assert_bind_args(session):
             session.query(User).filter(User.id == 15).update(
                 {"name": "foob"}, synchronize_session=False)
 
-    def test_bulk_delete(self):
+    def test_bulk_delete_no_sync(self):
         User = self.classes.User
         session = Session()
         with self._assert_bind_args(session):
             session.query(User).filter(User.id == 15).delete(
                 synchronize_session=False)
 
+    def test_bulk_update_fetch_sync(self):
+        User = self.classes.User
+        session = Session()
+        with self._assert_bind_args(session):
+            session.query(User).filter(User.id == 15).update(
+                {"name": "foob"}, synchronize_session='fetch')
+
+    def test_bulk_delete_fetch_sync(self):
+        User = self.classes.User
+        session = Session()
+        with self._assert_bind_args(session):
+            session.query(User).filter(User.id == 15).delete(
+                synchronize_session='fetch')
+
     def test_column_property(self):
         User = self.classes.User