]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- query.get() now returns None if queried for an identifier
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 13 Mar 2010 17:28:50 +0000 (12:28 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 13 Mar 2010 17:28:50 +0000 (12:28 -0500)
that is present in the identity map with a different class
than the one requested, i.e. when using polymorphic loading.
[ticket:1727]

CHANGES
lib/sqlalchemy/orm/query.py
test/orm/inheritance/test_basic.py

diff --git a/CHANGES b/CHANGES
index b8ca5a0a7edaff3a1c55308a6962203138383a4f..61aa45b1f96cbffa0c493aa026b8c756c128f4e4 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -47,6 +47,11 @@ CHANGES
     from_statement() to start with since it no longer modifies
     the query.  [ticket:1688]
 
+  - query.get() now returns None if queried for an identifier
+    that is present in the identity map with a different class 
+    than the one requested, i.e. when using polymorphic loading.  
+    [ticket:1727]
+    
   - A major fix in query.join(), when the "on" clause is an
     attribute of an aliased() construct, but there is already
     an existing join made out to a compatible target, query properly
index 1b5bcb7fe42928dc997f2fda56af45e91db9f296..fde93ff1b28a0c6a2efc92aa9f2c1a5da40295a4 100644 (file)
@@ -1537,17 +1537,23 @@ class Query(object):
                                         only_load_props=None, passive=None):
         lockmode = lockmode or self._lockmode
         
+        mapper = self._mapper_zero()
         if not self._populate_existing and \
                 not refresh_state and \
-                not self._mapper_zero().always_refresh and \
+                not mapper.always_refresh and \
                 lockmode is None:
             instance = self.session.identity_map.get(key)
             if instance:
+                # item present in identity map with a different class
+                if not issubclass(instance.__class__, mapper.class_):
+                    return None
+                    
                 state = attributes.instance_state(instance)
+                
+                # expired - ensure it still exists
                 if state.expired:
                     if passive is attributes.PASSIVE_NO_FETCH:
                         return attributes.PASSIVE_NO_RESULT
-                    
                     try:
                         state()
                     except orm_exc.ObjectDeletedError:
@@ -1570,8 +1576,6 @@ class Query(object):
             q = self._clone()
 
         if ident is not None:
-            mapper = q._mapper_zero()
-            params = {}
             (_get_clause, _get_params) = mapper._get_clause
             
             # None present in ident - turn those comparisons
@@ -1587,14 +1591,16 @@ class Query(object):
             _get_clause = q._adapt_clause(_get_clause, True, False)
             q._criterion = _get_clause
 
-            for i, primary_key in enumerate(mapper.primary_key):
-                try:
-                    params[_get_params[primary_key].key] = ident[i]
-                except IndexError:
-                    raise sa_exc.InvalidRequestError(
-                        "Could not find enough values to formulate primary "
-                        "key for query.get(); primary key columns are %s" %
-                        ','.join("'%s'" % c for c in mapper.primary_key))
+            params = dict([
+                (_get_params[primary_key].key, id_val)
+                for id_val, primary_key in zip(ident, mapper.primary_key)
+            ])
+
+            if len(params) != len(mapper.primary_key):
+                raise sa_exc.InvalidRequestError(
+                    "Incorrect number of values in identifier to formulate primary "
+                    "key for query.get(); primary key columns are %s" %
+                    ','.join("'%s'" % c for c in mapper.primary_key))
                         
             q._params = params
 
index aed7cf5efaac9aebe3df6cd693b74e2dec6f7431..ce773a7bc8244ba34f97144d3d17df5e50442843 100644 (file)
@@ -28,7 +28,7 @@ class O2MTest(_base.MappedTest):
             Column('foo_id', Integer, ForeignKey('foo.id'), nullable=False),
             Column('data', String(20)))
 
-    def testbasic(self):
+    def test_basic(self):
         class Foo(object):
             def __init__(self, data=None):
                 self.data = data
@@ -279,78 +279,88 @@ class GetTest(_base.MappedTest):
             Column('foo_id', Integer, ForeignKey('foo.id')),
             Column('bar_id', Integer, ForeignKey('bar.id')),
             Column('data', String(20)))
+    
+    @classmethod
+    def setup_classes(cls):
+        class Foo(_base.BasicEntity):
+            pass
 
-    def _create_test(polymorphic, name):
-        def test_get(self):
-            class Foo(object):
-                pass
-
-            class Bar(Foo):
-                pass
-
-            class Blub(Bar):
-                pass
-
-            if polymorphic:
-                mapper(Foo, foo, polymorphic_on=foo.c.type, polymorphic_identity='foo')
-                mapper(Bar, bar, inherits=Foo, polymorphic_identity='bar')
-                mapper(Blub, blub, inherits=Bar, polymorphic_identity='blub')
-            else:
-                mapper(Foo, foo)
-                mapper(Bar, bar, inherits=Foo)
-                mapper(Blub, blub, inherits=Bar)
-
-            sess = create_session()
-            f = Foo()
-            b = Bar()
-            bl = Blub()
-            sess.add(f)
-            sess.add(b)
-            sess.add(bl)
-            sess.flush()
+        class Bar(Foo):
+            pass
 
-            if polymorphic:
-                def go():
-                    assert sess.query(Foo).get(f.id) == f
-                    assert sess.query(Foo).get(b.id) == b
-                    assert sess.query(Foo).get(bl.id) == bl
-                    assert sess.query(Bar).get(b.id) == b
-                    assert sess.query(Bar).get(bl.id) == bl
-                    assert sess.query(Blub).get(bl.id) == bl
+        class Blub(Bar):
+            pass
 
-                self.assert_sql_count(testing.db, go, 0)
-            else:
-                # this is testing the 'wrong' behavior of using get()
-                # polymorphically with mappers that are not configured to be
-                # polymorphic.  the important part being that get() always
-                # returns an instance of the query's type.
-                def go():
-                    assert sess.query(Foo).get(f.id) == f
+    def test_get_polymorphic(self):
+        self._do_get_test(True)
+    
+    def test_get_nonpolymorphic(self):
+        self._do_get_test(False)
 
-                    bb = sess.query(Foo).get(b.id)
-                    assert isinstance(b, Foo) and bb.id==b.id
+    @testing.resolve_artifact_names
+    def _do_get_test(self, polymorphic):
+        if polymorphic:
+            mapper(Foo, foo, polymorphic_on=foo.c.type, polymorphic_identity='foo')
+            mapper(Bar, bar, inherits=Foo, polymorphic_identity='bar')
+            mapper(Blub, blub, inherits=Bar, polymorphic_identity='blub')
+        else:
+            mapper(Foo, foo)
+            mapper(Bar, bar, inherits=Foo)
+            mapper(Blub, blub, inherits=Bar)
 
-                    bll = sess.query(Foo).get(bl.id)
-                    assert isinstance(bll, Foo) and bll.id==bl.id
+        sess = create_session()
+        f = Foo()
+        b = Bar()
+        bl = Blub()
+        sess.add(f)
+        sess.add(b)
+        sess.add(bl)
+        sess.flush()
+
+        if polymorphic:
+            def go():
+                assert sess.query(Foo).get(f.id) is f
+                assert sess.query(Foo).get(b.id) is b
+                assert sess.query(Foo).get(bl.id) is bl
+                assert sess.query(Bar).get(b.id) is b
+                assert sess.query(Bar).get(bl.id) is bl
+                assert sess.query(Blub).get(bl.id) is bl
+
+                # test class mismatches - item is present
+                # in the identity map but we requested a subclass
+                assert sess.query(Blub).get(f.id) is None
+                assert sess.query(Blub).get(b.id) is None
+                assert sess.query(Bar).get(f.id) is None
+                
+            self.assert_sql_count(testing.db, go, 0)
+        else:
+            # this is testing the 'wrong' behavior of using get()
+            # polymorphically with mappers that are not configured to be
+            # polymorphic.  the important part being that get() always
+            # returns an instance of the query's type.
+            def go():
+                assert sess.query(Foo).get(f.id) is f
 
-                    assert sess.query(Bar).get(b.id) == b
+                bb = sess.query(Foo).get(b.id)
+                assert isinstance(b, Foo) and bb.id==b.id
 
-                    bll = sess.query(Bar).get(bl.id)
-                    assert isinstance(bll, Bar) and bll.id == bl.id
+                bll = sess.query(Foo).get(bl.id)
+                assert isinstance(bll, Foo) and bll.id==bl.id
 
-                    assert sess.query(Blub).get(bl.id) == bl
+                assert sess.query(Bar).get(b.id) is b
 
-                self.assert_sql_count(testing.db, go, 3)
+                bll = sess.query(Bar).get(bl.id)
+                assert isinstance(bll, Bar) and bll.id == bl.id
 
-        test_get = function_named(test_get, name)
-        return test_get
+                assert sess.query(Blub).get(bl.id) is bl
+
+            self.assert_sql_count(testing.db, go, 3)
 
-    test_get_polymorphic = _create_test(True, 'test_get_polymorphic')
-    test_get_nonpolymorphic = _create_test(False, 'test_get_nonpolymorphic')
 
 class EagerLazyTest(_base.MappedTest):
     """tests eager load/lazy load of child items off inheritance mappers, tests that
     LazyLoader constructs the right query condition."""
+    
     @classmethod
     def define_tables(cls, metadata):
         global foo, bar, bar_foo
@@ -367,7 +377,7 @@ class EagerLazyTest(_base.MappedTest):
         )
 
     @testing.fails_on('maxdb', 'FIXME: unknown')
-    def testbasic(self):
+    def test_basic(self):
         class Foo(object): pass
         class Bar(Foo): pass
 
@@ -394,7 +404,8 @@ class EagerLazyTest(_base.MappedTest):
         self.assert_(len(q.first().eager) == 1)
 
 class EagerTargetingTest(_base.MappedTest):
-    """test a scenario where joined table inheritance might be confused as an eagerly loaded joined table."""
+    """test a scenario where joined table inheritance might be 
+    confused as an eagerly loaded joined table."""
     
     @classmethod
     def define_tables(cls, metadata):
@@ -450,31 +461,32 @@ class EagerTargetingTest(_base.MappedTest):
         
 class FlushTest(_base.MappedTest):
     """test dependency sorting among inheriting mappers"""
+    
     @classmethod
     def define_tables(cls, metadata):
-        global users, roles, user_roles, admins
-        users = Table('users', metadata,
+        Table('users', metadata,
             Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('email', String(128)),
             Column('password', String(16)),
         )
 
-        roles = Table('role', metadata,
+        Table('roles', metadata,
             Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('description', String(32))
         )
 
-        user_roles = Table('user_role', metadata,
+        Table('user_roles', metadata,
             Column('user_id', Integer, ForeignKey('users.id'), primary_key=True),
-            Column('role_id', Integer, ForeignKey('role.id'), primary_key=True)
+            Column('role_id', Integer, ForeignKey('roles.id'), primary_key=True)
         )
 
-        admins = Table('admin', metadata,
+        Table('admins', metadata,
             Column('admin_id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('user_id', Integer, ForeignKey('users.id'))
         )
 
-    def testone(self):
+    @testing.resolve_artifact_names
+    def test_one(self):
         class User(object):pass
         class Role(object):pass
         class Admin(User):pass
@@ -501,7 +513,8 @@ class FlushTest(_base.MappedTest):
 
         assert user_roles.count().scalar() == 1
 
-    def testtwo(self):
+    @testing.resolve_artifact_names
+    def test_two(self):
         class User(object):
             def __init__(self, email=None, password=None):
                 self.email = email
@@ -541,34 +554,24 @@ class FlushTest(_base.MappedTest):
 class VersioningTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
-        global base, subtable, stuff
-        base = Table('base', metadata,
+        Table('base', metadata,
             Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('version_id', Integer, nullable=False),
             Column('value', String(40)),
             Column('discriminator', Integer, nullable=False)
         )
-        subtable = Table('subtable', metadata,
+        Table('subtable', metadata,
             Column('id', None, ForeignKey('base.id'), primary_key=True),
             Column('subdata', String(50))
             )
-        stuff = Table('stuff', metadata,
+        Table('stuff', metadata,
             Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('parent', Integer, ForeignKey('base.id'))
             )
 
-    def setup(self):
-        super(VersioningTest, self).setup()
-        if not testing.db.dialect.supports_sane_rowcount:
-            self._warnings_filters = warnings.filters[:]
-            warnings.filterwarnings('ignore', category=sa_exc.SAWarning)
-
-    def teardown(self):
-        super(VersioningTest, self).teardown()
-        if not testing.db.dialect.supports_sane_rowcount:
-            warnings.filters[:] = self._warnings_filters
-
+    @testing.emits_warning(r".*updated rowcount")
     @engines.close_open_connections
+    @testing.resolve_artifact_names
     def test_save_update(self):
         class Base(_fixtures.Base):
             pass
@@ -577,7 +580,10 @@ class VersioningTest(_base.MappedTest):
         class Stuff(Base):
             pass
         mapper(Stuff, stuff)
-        mapper(Base, base, polymorphic_on=base.c.discriminator, version_id_col=base.c.version_id, polymorphic_identity=1, properties={
+        mapper(Base, base, 
+                    polymorphic_on=base.c.discriminator, 
+                    version_id_col=base.c.version_id, 
+                    polymorphic_identity=1, properties={
             'stuff':relation(Stuff)
         })
         mapper(Sub, subtable, inherits=Base, polymorphic_identity=2)
@@ -599,17 +605,14 @@ class VersioningTest(_base.MappedTest):
 
         sess.flush()
 
-        try:
-            sess2.query(Base).with_lockmode('read').get(s1.id)
-            assert False
-        except orm_exc.ConcurrentModificationError, e:
-            assert True
+        assert_raises(orm_exc.ConcurrentModificationError,
+                        sess2.query(Base).with_lockmode('read').get, 
+                        s1.id)
 
-        try:
+        if not testing.db.dialect.supports_sane_rowcount:
             sess2.flush()
-            assert not testing.db.dialect.supports_sane_rowcount
-        except orm_exc.ConcurrentModificationError, e:
-            assert True
+        else:
+            assert_raises(orm_exc.ConcurrentModificationError, sess2.flush)
 
         sess2.refresh(s2)
         if testing.db.dialect.supports_sane_rowcount:
@@ -617,13 +620,17 @@ class VersioningTest(_base.MappedTest):
         s2.subdata = 'sess2 subdata'
         sess2.flush()
 
+    @testing.emits_warning(r".*updated rowcount")
+    @testing.resolve_artifact_names
     def test_delete(self):
         class Base(_fixtures.Base):
             pass
         class Sub(Base):
             pass
 
-        mapper(Base, base, polymorphic_on=base.c.discriminator, version_id_col=base.c.version_id, polymorphic_identity=1)
+        mapper(Base, base, 
+                    polymorphic_on=base.c.discriminator, 
+                    version_id_col=base.c.version_id, polymorphic_identity=1)
         mapper(Sub, subtable, inherits=Base, polymorphic_identity=2)
 
         sess = create_session()
@@ -697,17 +704,24 @@ class DistinctPKTest(_base.MappedTest):
 
     def test_explicit_props(self):
         person_mapper = mapper(Person, person_table)
-        mapper(Employee, employee_table, inherits=person_mapper, properties={'pid':person_table.c.id, 'eid':employee_table.c.id})
+        mapper(Employee, employee_table, inherits=person_mapper,
+                        properties={'pid':person_table.c.id, 
+                                    'eid':employee_table.c.id})
         self._do_test(True)
 
     def test_explicit_composite_pk(self):
         person_mapper = mapper(Person, person_table)
-        try:
-            mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id, employee_table.c.id])
-            self._do_test(True)
-            assert False
-        except sa_exc.SAWarning, e:
-            assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'.  Use explicit properties to give each column its own mapped attribute name.", str(e)
+        mapper(Employee, employee_table, 
+                    inherits=person_mapper, 
+                    primary_key=[person_table.c.id, employee_table.c.id])
+        assert_raises_message(sa_exc.SAWarning, 
+                                    r"On mapper Mapper\|Employee\|employees, "
+                                    "primary key column 'employees.id' is being "
+                                    "combined with distinct primary key column 'persons.id' "
+                                    "in attribute 'id'.  Use explicit properties to give "
+                                    "each column its own mapped attribute name.",
+            self._do_test, True
+        )
 
     def test_explicit_pk(self):
         person_mapper = mapper(Person, person_table)
@@ -1242,6 +1256,7 @@ class DeleteOrphanTest(_base.MappedTest):
         s1 = SubClass(data='s1')
         sess.add(s1)
         assert_raises_message(orm_exc.FlushError, 
-            "is not attached to any parent 'Parent' instance via that classes' 'related' attribute", sess.flush)
+            r"is not attached to any parent 'Parent' instance via "
+            "that classes' 'related' attribute", sess.flush)