]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure .mapper is set on _ColumnEntity
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 21 Oct 2016 13:34:32 +0000 (09:34 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 21 Oct 2016 14:42:14 +0000 (10:42 -0400)
_ColumnEntity didn't seem to have .mapper present, which
due to the way _mapper_zero() worked didn't tend to come
across it.   With :ticket:`3608` _mapper_zero() has
been simplified so make sure this is now present.
Also ensure that _select_from_entity is an entity and
not a mapped class, though this does not seem to matter
at the moment.

Fixes: #3836
Change-Id: Id6dae8e700269b97de3b01562edee95ac1e01f80

doc/build/changelog/changelog_11.rst
lib/sqlalchemy/orm/query.py
test/orm/test_query.py

index 4361ee63c8c97c2578006ab229833cf0511c7168..70db53d2f12c74fbb0ff52473a93ef2dd832a357 100644 (file)
         is changed.  The autoincrement flag can only be True if the datatype
         is of integer affinity in the 1.1 series.
 
+    .. change::
+        :tags: bug, orm
+        :tickets: 3836
+
+        Fixed regression where some :class:`.Query` methods like
+        :meth:`.Query.update` and others would fail if the :class:`.Query`
+        were against a series of mapped columns, rather than the mapped
+        entity as a whole.
+
     .. change::
         :tags: bug, sql
         :tickets: 3833
index d6a81ffd66608b4e5f86fb1d0341c84b7c5400c8..23d33b0d14f96fb565fc2f818fb4f96cdce7a05d 100644 (file)
@@ -165,7 +165,7 @@ class Query(object):
             info = inspect(from_obj)
             if hasattr(info, 'mapper') and \
                     (info.is_mapper or info.is_aliased_class):
-                self._select_from_entity = from_obj
+                self._select_from_entity = info
                 if set_base_alias:
                     raise sa_exc.ArgumentError(
                         "A selectable (FromClause) instance is "
@@ -3940,8 +3940,10 @@ class _ColumnEntity(_QueryEntity):
             self.entity_zero = _entity
             if _entity:
                 self.entities = [_entity]
+                self.mapper = _entity.mapper
             else:
                 self.entities = []
+                self.mapper = None
             self._from_entities = set(self.entities)
         else:
             all_elements = [
@@ -3963,10 +3965,13 @@ class _ColumnEntity(_QueryEntity):
             ])
             if self.entities:
                 self.entity_zero = self.entities[0]
+                self.mapper = self.entity_zero.mapper
             elif self.namespace is not None:
                 self.entity_zero = self.namespace
+                self.mapper = None
             else:
                 self.entity_zero = None
+                self.mapper = None
 
     supports_single_entity = False
 
index 493f6a7c81289915f5cba407a782dfdc9e834ed0..57408e10ef9fc11dbe64359e48ff20b3b41ec055 100644 (file)
@@ -813,6 +813,36 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL):
         q = s.query(User, Address)
         assert_raises(sa_exc.InvalidRequestError, q.get, 5)
 
+    def test_entity_or_mapper_zero(self):
+        User, Address = self.classes.User, self.classes.Address
+        s = create_session()
+
+        q = s.query(User, Address)
+        is_(q._mapper_zero(), inspect(User))
+        is_(q._entity_zero(), inspect(User))
+
+        u1 = aliased(User)
+        q = s.query(u1, Address)
+        is_(q._mapper_zero(), inspect(User))
+        is_(q._entity_zero(), inspect(u1))
+
+        q = s.query(User).select_from(Address)
+        is_(q._mapper_zero(), inspect(User))
+        is_(q._entity_zero(), inspect(Address))
+
+        q = s.query(User.name, Address)
+        is_(q._mapper_zero(), inspect(User))
+        is_(q._entity_zero(), inspect(User))
+
+        q = s.query(u1.name, Address)
+        is_(q._mapper_zero(), inspect(User))
+        is_(q._entity_zero(), inspect(u1))
+
+        q1 = s.query(User).exists()
+        q = s.query(q1)
+        is_(q._mapper_zero(), None)
+        is_(q._entity_zero(), None)
+
     def test_from_statement(self):
         User = self.classes.User