]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- further changes to attributes with regards to "trackparent". the "commit" operation
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 2 Sep 2006 04:00:44 +0000 (04:00 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 2 Sep 2006 04:00:44 +0000 (04:00 +0000)
now sets a "hasparent" flag for all attributes to all objects.  that way lazy loads
via callables get included in trackparent, and eager loads do as well because the mapper
calls commit() on all objects at load time.  this is a less shaky method than the "optimistic"
thing in the previous commit, but uses more memory and involves more overhead.
- some tweaks/cleanup to unit tests

CHANGES
lib/sqlalchemy/attributes.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
test/base/attributes.py
test/orm/mapper.py
test/orm/objectstore.py
test/orm/session.py

diff --git a/CHANGES b/CHANGES
index 79d8e3d6d5f29cea44a0db2d87a85a7c3fcd9fa6..f99cd35460c99f81fd836193e47ea38fd15af42a 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -29,10 +29,7 @@ parent isnt available to cascade from.
 - mappers can tell if one of their objects is an "orphan" based
 on interactions with the attribute package. this check is based
 on a status flag maintained for each relationship 
-when objects are attached and detached from each other.  if the
-status flag is not present, its assumed to be "False" for a 
-transient instance and assumed to be "True" for a persisted/detached
- instance.
+when objects are attached and detached from each other.
 - it is now invalid to declare a self-referential relationship with
 "delete-orphan" (as the abovementioned check would make them impossible
 to save)
index 8629d85a50c253908c5678507b96a4e8a9d38490..2a98a5f4daa70f09d0621a0de6cd328da7b5c642 100644 (file)
@@ -32,19 +32,16 @@ class InstrumentedAttribute(object):
         return self.get(obj)
 
     def hasparent(self, item, optimistic=False):
-        """return True if the given item is attached to a parent object 
-        via the attribute represented by this InstrumentedAttribute.
-        
-        optimistic indicates what we should return if the given item has no "hasparent"
-        record at all for the given attribute."""
-        return item._state.get(('hasparent', id(self)), optimistic)
+        """return the boolean value of a "hasparent" flag attached to the given item.
+        """
+        return item._state.get(('hasparent', id(self)), False)
         
     def sethasparent(self, item, value):
         """sets a boolean flag on the given item corresponding to whether or not it is
         attached to a parent object via the attribute represented by this InstrumentedAttribute."""
         if item is not None:
             item._state[('hasparent', id(self))] = value
-
+            
     def get_history(self, obj, passive=False):
         """return a new AttributeHistory object for the given object/this attribute's key.
         
@@ -140,16 +137,12 @@ class InstrumentedAttribute(object):
                     values = callable_()
                     l = InstrumentedList(self, obj, self._adapt_list(values), init=False)
                     
-                    # mark loaded instances with "hasparent" status.  commented out
-                    # because loaded objects use "optimistic" parent-checking
-                    #if self.trackparent and values is not None:
-                    #    [self.sethasparent(v, True) for v in values if v is not None]
-                    
                     # if a callable was executed, then its part of the "committed state"
                     # if any, so commit the newly loaded data
                     orig = state.get('original', None)
                     if orig is not None:
                         orig.commit_attribute(self, obj, l)
+                    
                 else:
                     # note that we arent raising AttributeErrors, just creating a new
                     # blank list and setting it.
@@ -165,11 +158,6 @@ class InstrumentedAttribute(object):
                     value = callable_()
                     obj.__dict__[self.key] = value
 
-                    # mark loaded instances with "hasparent" status.  commented out
-                    # because loaded objects use "optimistic" parent-checking
-                    #if self.trackparent and value is not None:
-                    #    self.sethasparent(value, True)
-                    
                     # if a callable was executed, then its part of the "committed state"
                     # if any, so commit the newly loaded data
                     orig = state.get('original', None)
@@ -478,14 +466,21 @@ class CommittedState(object):
         if attr.uselist:
             if value is not False:
                 self.data[attr.key] = [x for x in value]
+                if attr.trackparent:
+                    [attr.sethasparent(x, True) for x in self.data[attr.key]]
             elif obj.__dict__.has_key(attr.key):
                 self.data[attr.key] = [x for x in obj.__dict__[attr.key]]
+                if attr.trackparent:
+                    [attr.sethasparent(x, True) for x in self.data[attr.key]]
         else:
             if value is not False:
                 self.data[attr.key] = value
+                if attr.trackparent:
+                    attr.sethasparent(self.data[attr.key], True)
             elif obj.__dict__.has_key(attr.key):
                 self.data[attr.key] = obj.__dict__[attr.key]
-                        
+                if attr.trackparent:
+                    attr.sethasparent(self.data[attr.key], True)
     def rollback(self, manager, obj):
         for attr in manager.managed_attributes(obj.__class__):
             if self.data.has_key(attr.key):
index 21d4c6e36713997b8a7af721b1d165870dfc0e20..c7531eb8e7c22e3727066b773c213f0161ad54cb 100644 (file)
@@ -141,9 +141,10 @@ class Mapper(object):
         #self.compile()
     
     def _is_orphan(self, obj):
-        optimistic = hasattr(obj, '_instance_key')
         for (key,klass) in self.delete_orphans:
-            if not getattr(klass, key).hasparent(obj, optimistic=optimistic):
+            if not getattr(klass, key).hasparent(obj):
+                if not has_identity(obj):
+                    raise exceptions.FlushError("instance %s is an unsaved, pending instance and is an orphan" % obj)
                 return True
         else:
             return False
@@ -710,7 +711,7 @@ class Mapper(object):
 
         if not postupdate:
             for obj in objects:
-                if not hasattr(obj, "_instance_key"):
+                if not has_identity(obj):
                     self.extension.before_insert(self, connection, obj)
                 else:
                     self.extension.before_update(self, connection, obj)
@@ -747,7 +748,7 @@ class Mapper(object):
                 # 'postupdate' means a PropertyLoader is telling us, "yes I know you 
                 # already inserted/updated this row but I need you to UPDATE one more 
                 # time"
-                isinsert = not postupdate and not hasattr(obj, "_instance_key")
+                isinsert = not postupdate and not has_identity(obj)
                 hasdata = False
                 for col in table.columns:
                     if col is self.version_id_col:
@@ -1392,6 +1393,9 @@ def hash_key(obj):
     else:
         return repr(obj)
 
+def has_identity(object):
+    return hasattr(object, '_instance_key')
+    
 def has_mapper(object):
     """returns True if the given object has a mapper association"""
     return hasattr(object, '_entity_name')
index 16760dc060516e046f4694d9d0c8897777fc8fc5..e8f3ad7dde837d1503d87e39d5a7a479679c87d2 100644 (file)
@@ -218,7 +218,7 @@ class PropertyLoader(mapper.MapperProperty):
 
         if self.cascade.delete_orphan:
             if self.parent.class_ is self.mapper.class_:
-                raise exceptions.ArgumentError("Cant establish 'delete-orphan' cascade rule on a self-referential relationship.  You probably want cascade='all', which includes delete cascading but not orphan detection.")
+                raise exceptions.ArgumentError("Cant establish 'delete-orphan' cascade rule on a self-referential relationship (attribute '%s' on class '%s').  You probably want cascade='all', which includes delete cascading but not orphan detection." %(self.key, self.parent.class_.__name__))
             self.mapper.primary_mapper().delete_orphans.append((self.key, self.parent.class_))
             
         if self.secondaryjoin is not None and self.secondary is None:
@@ -379,9 +379,17 @@ class LazyLoader(PropertyLoader):
                 return None
             else:
                 return mapper.object_mapper(instance).props[self.key].setup_loader(instance)
+        
         def lazyload():
             params = {}
             allparams = True
+            # if the instance wasnt loaded from the database, then it cannot lazy load
+            # child items.  one reason for this is that a bi-directional relationship
+            # will not update properly, since bi-directional uses lazy loading functions
+            # in both directions, and this instance will not be present in the lazily-loaded
+            # results of the other objects since its not in the database
+            if not mapper.has_identity(instance):
+                return None
             #print "setting up loader, lazywhere", str(self.lazywhere), "binds", self.lazybinds
             for col, bind in self.lazybinds.iteritems():
                 params[bind.key] = self.parent._getattrbycolumn(instance, col)
index 831848cd994596fc3c3ae5fba3d8778070f8cd37..abacd0ac960abd569edb527ca9dc383238953b22 100644 (file)
@@ -181,27 +181,21 @@ class AttributesTest(PersistTest):
         p1 = Post()
         Blog.posts.set_callable(b, lambda:[p1])
         Post.blog.set_callable(p1, lambda:b)
-
+        manager.commit(p1, b)
         # assert connections
         assert p1.blog is b
         assert p1 in b.posts
 
-        # no orphans (but we are using optimistic checks)
-        assert getattr(Blog, 'posts').hasparent(p1, optimistic=True)
-        assert getattr(Post, 'blog').hasparent(b, optimistic=True)
-        
-        # lazy loads currently not processed for "hasparent" status, so without
-        # optimistic, it returns false
-        assert not getattr(Blog, 'posts').hasparent(p1, optimistic=False)
-        assert not getattr(Post, 'blog').hasparent(b, optimistic=False)
+        # no orphans
+        assert getattr(Blog, 'posts').hasparent(p1)
+        assert getattr(Post, 'blog').hasparent(b)
         
-        # ok what about non-optimistic.  well, dont use lazy loaders,
-        # assign things manually, so the "hasparent" flags get set
+        # manual connections
         b2 = Blog()
         p2 = Post()
         b2.posts.append(p2)
-        assert getattr(Blog, 'posts').hasparent(p2, optimistic=False)
-        assert getattr(Post, 'blog').hasparent(b2, optimistic=False)
+        assert getattr(Blog, 'posts').hasparent(p2)
+        assert getattr(Post, 'blog').hasparent(b2)
         
     def testinheritance(self):
         """tests that attributes are polymorphic"""
index e12e15139ec44c35468c05e90a41c03472b0ac59..bf16cdfc672d5d576f133541fc0c18a09213c868 100644 (file)
@@ -59,14 +59,10 @@ item_keyword_result = [
 
 class MapperSuperTest(AssertMixin):
     def setUpAll(self):
-        db.echo = False
         tables.create()
         tables.data()
-        db.echo = testbase.echo
     def tearDownAll(self):
-        db.echo = False
         tables.drop()
-        db.echo = testbase.echo
     def tearDown(self):
         clear_mappers()
     def setUp(self):
@@ -85,7 +81,7 @@ class MapperTest(MapperSuperTest):
         self.assert_(u is not u2)
 
     def testunicodeget(self):
-        """tests that Query.get properly sets up the type for the bind parameter.  using unicode would normally fail 
+        """test that Query.get properly sets up the type for the bind parameter.  using unicode would normally fail 
         on postgres, mysql and oracle unless it is converted to an encoded string"""
         metadata = BoundMetaData(db)
         table = Table('foo', metadata, 
@@ -151,7 +147,7 @@ class MapperTest(MapperSuperTest):
         self.assert_(a not in u.addresses)
 
     def testbadconstructor(self):
-        """tests that if the construction of a mapped class fails, the instnace does not get placed in the session"""
+        """test that if the construction of a mapped class fails, the instnace does not get placed in the session"""
         class Foo(object):
             def __init__(self, one, two):
                 pass
@@ -169,7 +165,7 @@ class MapperTest(MapperSuperTest):
             pass
             
     def testrefresh_lazy(self):
-        """tests that when a lazy loader is set as a trigger on an object's attribute (at the attribute level, not the class level), a refresh() operation doesnt fire the lazy loader or create any problems"""
+        """test that when a lazy loader is set as a trigger on an object's attribute (at the attribute level, not the class level), a refresh() operation doesnt fire the lazy loader or create any problems"""
         s = create_session()
         mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))})
         q2 = s.query(User).options(lazyload('addresses'))
@@ -179,6 +175,7 @@ class MapperTest(MapperSuperTest):
         self.assert_sql_count(db, go, 1)
 
     def testexpire(self):
+        """test the expire function"""
         s = create_session()
         mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=False)})
         u = s.get(User, 7)
@@ -209,6 +206,7 @@ class MapperTest(MapperSuperTest):
         self.assert_(u.user_name =='jack')
         
     def testrefresh2(self):
+        """test a hang condition that was occuring on expire/refresh"""
         s = create_session()
         mapper(Address, addresses)
 
@@ -231,6 +229,7 @@ class MapperTest(MapperSuperTest):
         s.refresh(u) #hangs
         
     def testmagic(self):
+        """not sure what this is really testing."""
         mapper(User, users, properties = {
             'addresses' : relation(mapper(Address, addresses))
         })
@@ -270,15 +269,33 @@ class MapperTest(MapperSuperTest):
         })
         q = create_session().query(m)
         q.select_by(email_address='foo')
+
+    def testmappingtojoin(self):
+        """test mapping to a join"""
+        usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id)
+        m = mapper(User, usersaddresses, primary_key=[users.c.user_id])
+        q = create_session().query(m)
+        l = q.select()
+        self.assert_result(l, User, *user_result[0:2])
+        
+    def testmappingtoouterjoin(self):
+        """test mapping to an outer join, with a composite primary key that allows nulls"""
+        result = [
+        {'user_id' : 7, 'address_id' : 1},
+        {'user_id' : 8, 'address_id' : 2},
+        {'user_id' : 8, 'address_id' : 3},
+        {'user_id' : 8, 'address_id' : 4},
+        {'user_id' : 9, 'address_id':None}
+        ]
         
-    def testjoinbyfk(self):
-        class UserWithAddress(object):
-                       pass
         j = join(users, addresses, isouter=True)
-        m = mapper(UserWithAddress, j, allow_null_pks=True)
+        m = mapper(User, j, allow_null_pks=True, primary_key=[users.c.user_id, addresses.c.address_id])
         q = create_session().query(m)
+        l = q.select()
+        self.assert_result(l, User, *result)
         
     def testjoinvia(self):
+        """test the join_via and join_to functions"""
         m = mapper(User, users, properties={
             'orders':relation(mapper(Order, orders, properties={
                 'items':relation(mapper(Item, orderitems))
@@ -300,6 +317,7 @@ class MapperTest(MapperSuperTest):
         self.assert_result(l, User, user_result[0])
         
     def testorderby(self):
+        """test ordering at the mapper and query level"""
         # TODO: make a unit test out of these various combinations
 #        m = mapper(User, users, order_by=desc(users.c.user_name))
         mapper(User, users, order_by=None)
@@ -313,7 +331,7 @@ class MapperTest(MapperSuperTest):
         
     @testbase.unsupported('firebird') 
     def testfunction(self):
-        """tests mapping to a SELECT statement that has functions in it."""
+        """test mapping to a SELECT statement that has functions in it."""
         s = select([users, (users.c.user_id * 2).label('concat'), func.count(addresses.c.address_id).label('count')],
         users.c.user_id==addresses.c.user_id, group_by=[c for c in users.c]).alias('myselect')
         mapper(User, s)
@@ -326,18 +344,15 @@ class MapperTest(MapperSuperTest):
         
     @testbase.unsupported('firebird') 
     def testcount(self):
+        """test the count function on Query
+        
+        (why doesnt this work on firebird?)"""
         mapper(User, users)
         q = create_session().query(User)
         self.assert_(q.count()==3)
         self.assert_(q.count(users.c.user_id.in_(8,9))==2)
         self.assert_(q.count_by(user_name='fred')==1)
             
-    def testmultitable(self):
-        usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id)
-        m = mapper(User, usersaddresses, primary_key=[users.c.user_id])
-        q = create_session().query(m)
-        l = q.select()
-        self.assert_result(l, User, *user_result[0:2])
 
     def testoverride(self):
         # assert that overriding a column raises an error
index 191b6ff4a843bdb572ab176339e9af6aa4275ac9..4c35fee6505287fa52dc7b19a2b63281244fb119 100644 (file)
@@ -25,15 +25,11 @@ class SessionTest(AssertMixin):
 class HistoryTest(SessionTest):
     def setUpAll(self):
         SessionTest.setUpAll(self)
-        db.echo = False
         users.create()
         addresses.create()
-        db.echo = testbase.echo
     def tearDownAll(self):
-        db.echo = False
         addresses.drop()
         users.drop()
-        db.echo = testbase.echo
         SessionTest.tearDownAll(self)
         
     def testattr(self):
@@ -230,7 +226,6 @@ class PKTest(SessionTest):
     @testbase.unsupported('mssql')
     def setUpAll(self):
         SessionTest.setUpAll(self)
-        #db.echo = False
         global table
         global table2
         global table3
@@ -256,14 +251,11 @@ class PKTest(SessionTest):
         table.create()
         table2.create()
         table3.create()
-        db.echo = testbase.echo
     @testbase.unsupported('mssql')
     def tearDownAll(self):
-        db.echo = False
         table.drop()
         table2.drop()
         table3.drop()
-        db.echo = testbase.echo
         SessionTest.tearDownAll(self)
         
     # not support on sqlite since sqlite's auto-pk generation only works with
@@ -430,7 +422,6 @@ class DefaultTest(SessionTest):
     defaults back from the engine."""
     def setUpAll(self):
         SessionTest.setUpAll(self)
-        #db.echo = 'debug'
         use_string_defaults = db.engine.__module__.endswith('postgres') or db.engine.__module__.endswith('oracle') or db.engine.__module__.endswith('sqlite')
 
         if use_string_defaults:
@@ -504,17 +495,12 @@ class SaveTest(SessionTest):
 
     def setUpAll(self):
         SessionTest.setUpAll(self)
-        db.echo = False
         tables.create()
-        db.echo = testbase.echo
     def tearDownAll(self):
-        db.echo = False
         tables.drop()
-        db.echo = testbase.echo
         SessionTest.tearDownAll(self)
         
     def setUp(self):
-        db.echo = False
         keywords.insert().execute(
             dict(name='blue'),
             dict(name='red'),
@@ -524,12 +510,9 @@ class SaveTest(SessionTest):
             dict(name='round'),
             dict(name='square')
         )
-        db.echo = testbase.echo
 
     def tearDown(self):
-        db.echo = False
         tables.delete()
-        db.echo = testbase.echo
 
         #self.assert_(len(ctx.current.new) == 0)
         #self.assert_(len(ctx.current.dirty) == 0)
@@ -1224,7 +1207,6 @@ class SaveTest(SessionTest):
 class SaveTest2(SessionTest):
 
     def setUp(self):
-        db.echo = False
         ctx.current.clear()
         clear_mappers()
         self.users = Table('users', db,
@@ -1244,13 +1226,10 @@ class SaveTest2(SessionTest):
 #        raise repr(self.addresses) + repr(self.addresses.foreign_keys)
         self.users.create()
         self.addresses.create()
-        db.echo = testbase.echo
 
     def tearDown(self):
-        db.echo = False
         self.addresses.drop()
         self.users.drop()
-        db.echo = testbase.echo
         SessionTest.tearDown(self)
     
     def testbackwardsnonmatch(self):
@@ -1339,7 +1318,7 @@ class SaveTest3(SessionTest):
         pass
 
     def testmanytomanyxtracolremove(self):
-        """tests that a many-to-many on a table that has an extra column can properly delete rows from the table
+        """test that a many-to-many on a table that has an extra column can properly delete rows from the table
         without referencing the extra column"""
         mapper(Keyword, t3)
 
index 2be9ef67f904e663495b6bf7f0a4cacac1bb2f8a..085557021c6cbcd5f3f76ab276bff4f77902b238 100644 (file)
@@ -32,11 +32,11 @@ class OrphanDeletionTest(AssertMixin):
         ))
         s = create_session()
         a = Address()
+        s.save(a)
         try:
-            s.save(a)
-        except exceptions.InvalidRequestError, e:
+            s.flush()
+        except exceptions.FlushError, e:
             pass
-        s.flush()
         assert a.address_id is None, "Error: address should not be persistent"
         
     def test_delete_new_object(self):