]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 4 Sep 2005 18:48:26 +0000 (18:48 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 4 Sep 2005 18:48:26 +0000 (18:48 +0000)
lib/sqlalchemy/mapper.py
test/mapper.py
test/tables.py

index 46dfde3e84085b313c00a1e3d5484e19e669ead8..2b981e647e1bf063319ff663829aeda0f49efb0d 100644 (file)
@@ -216,7 +216,7 @@ class Mapper(object):
     def _setattrbycolumn(self, obj, column, value):
         self.columntoproperty[column][0].setattr(obj, value)
         
-    def save(self, obj, traverse = True, refetch = False):
+    def save(self, obj, traverse = True):
         """saves the object across all its primary tables.  
         based on the existence of the primary key for each table, either inserts or updates.
         primary key is determined by the underlying database engine's sequence methodology.
@@ -229,6 +229,8 @@ class Mapper(object):
 
         if getattr(obj, 'dirty', True):
             def foo():
+                insert_statement = None
+                update_statement = None
                 for table in self.tables:
                     params = {}
                     # TODO: prepare the insert() and update() - (1) within the code or
@@ -264,11 +266,11 @@ class Mapper(object):
                 obj.dirty = False
                 for prop in self.props.values():
                     if not isinstance(prop, ColumnProperty):
-                        prop.save(obj, traverse, refetch)
+                        prop.save(obj, traverse)
             self.transaction(foo)
         else:
             for prop in self.props.values():
-                prop.save(obj, traverse, refetch)
+                prop.save(obj, traverse)
 
     def transaction(self, f):
         return self.table.engine.multi_transaction(self.tables, f)
@@ -277,15 +279,6 @@ class Mapper(object):
         """removes the object.  traverse indicates attached objects should be removed as well."""
         pass
 
-    def delete(self, obj):
-        """deletes the object's row from its table unconditionally. this is a lower-level
-        operation than remove."""
-        # delete dependencies ?
-        # delete row
-        # remove primary keys
-        # unset dirty flag
-        pass
-
     def _compile(self, whereclause = None, **options):
         statement = sql.select([self.selectable], whereclause)
         for key, value in self.props.iteritems():
@@ -365,7 +358,7 @@ class MapperProperty:
         """called when the MapperProperty is first attached to a new parent Mapper."""
         pass
 
-    def save(self, object, traverse, refetch):
+    def save(self, object, traverse):
         """called when the instance is being saved"""
         pass
 
@@ -432,48 +425,56 @@ class PropertyLoader(MapperProperty):
             if self.primaryjoin is None:
                 self.primaryjoin = match_primaries(parent.selectable, self.target)
 
-    def save(self, obj, traverse, refetch):
+    def save(self, obj, traverse):
         # saves child objects
         
-        # TODO: put association table inserts/deletes into one batch
-        #if self.secondary is not None:
-         #   secondary_delete = self.secondary.delete(sql.and_([c == bindparam(c.key) for c in setter.secondary.c]))
-            
+        if self.secondary is not None:
+            secondary_delete = []
+            secondary_insert = []
+             
         setter = ForeignKeySetter(self.parent, self.mapper, self.parent.table, self.target, self.secondary, obj)
         childlist = getattr(obj, self.key)
         if not isinstance(childlist, util.HistoryArraySet):
             childlist = util.HistoryArraySet(childlist)
             clean_setattr(obj, self.key, childlist)
+            
         for child in childlist.deleted_items():
             setter.child = child
+            setter.associationrow = {}
             setter.clearkeys = True
             self.primaryjoin.accept_visitor(setter)
             child.dirty = True
-            self.mapper.save(child)
+            self.mapper.save(child, traverse)
             if self.secondary is not None:
                 self.secondaryjoin.accept_visitor(setter)
-                # TODO: prepare this above
-                statement = self.secondary.delete(sql.and_(*[c == setter.associationrow[c.key] for c in self.secondary.c]))
-                statement.echo = self.mapper.echo
-                statement.execute()
+                secondary_delete.append(setter.associationrow)
+                
         for child in childlist.added_items():
             setter.child = child
+            setter.associationrow = {}
             self.primaryjoin.accept_visitor(setter)
             child.dirty = True
-            self.mapper.save(child)
+            self.mapper.save(child, traverse)
             if self.secondary is not None:
                 self.secondaryjoin.accept_visitor(setter)
-                # TODO: prepare this above
+                secondary_insert.append(setter.associationrow)
+
+        if self.secondary is not None:
+            if len(secondary_delete):
+                statement = self.secondary.delete(sql.and_(*[c == sql.bindparam(c.key) for c in self.secondary.c]))
+                statement.echo = self.mapper.echo
+                statement.execute(*secondary_delete)
+            if len(secondary_insert):
                 statement = self.secondary.insert()
                 statement.echo = self.mapper.echo
-                statement.execute(**setter.associationrow)
+                statement.execute(*secondary_insert)
+
         for child in childlist.unchanged_items():
-            self.mapper.save(child)
+            self.mapper.save(child, traverse)
         # TODO: if transaction fails state is invalid
         # use unit of work ?
         childlist.clear_history()
             
-            
     def delete(self):
         self.mapper.delete()
 
index 7f129cf945fde189913d37461b02249fbadbb3b5..ba37335ace91e93fb9c10887d396c9c3e77e208f 100644 (file)
@@ -2,6 +2,7 @@ from testbase import PersistTest
 import unittest, sys, os
 from sqlalchemy.mapper import *
 
+ECHO = False
 execfile("test/tables.py")
 
 class User(object):
@@ -58,7 +59,7 @@ class MapperTest(AssertMixin):
         #globalidentity().clear()
 
     def testget(self):
-        m = mapper(User, users, scope = "thread", echo = True)
+        m = mapper(User, users, scope = "thread")
         self.assert_(m.get(19) is None)
         u = m.get(7)
         u2 = m.get(7)
@@ -85,7 +86,7 @@ class MapperTest(AssertMixin):
         """tests that a lazy relation can be upgraded to an eager relation via the options method"""
         m = mapper(User, users, properties = dict(
             addresses = relation(Address, addresses, lazy = True)
-        ), echo = True)
+        ))
         l = m.options(eagerload('addresses')).select()
         self.assert_result(l, User,
             {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])},
@@ -97,7 +98,7 @@ class MapperTest(AssertMixin):
         """tests that an eager relation can be upgraded to a lazy relation via the options method"""
         m = mapper(User, users, properties = dict(
             addresses = relation(Address, addresses, lazy = False)
-        ), echo = True)
+        ))
         l = m.options(lazyload('addresses')).select()
         self.assert_result(l, User,
             {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])},
@@ -114,7 +115,7 @@ class LazyTest(AssertMixin):
         """tests a basic one-to-many lazy load"""
         m = mapper(User, users, properties = dict(
             addresses = relation(Address, addresses, lazy = True)
-        ), echo = True)
+        ))
         l = m.select(users.c.user_id == 7)
         self.assert_result(l, User,
             {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])},
@@ -126,7 +127,7 @@ class LazyTest(AssertMixin):
 
         m = mapper(Item, items, properties = dict(
                 keywords = relation(Keyword, keywords, itemkeywords, lazy = True),
-            ), echo = True)
+            ))
         l = m.select()
         self.assert_result(l, Item, 
             {'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])},
@@ -156,7 +157,7 @@ class EagerTest(PersistTest):
         m = mapper(User, users, properties = dict(
             #addresses = relation(Address, addresses, lazy = False),
             addresses = relation(m, lazy = False),
-        ), echo = True)
+        ))
         l = m.select()
         print repr(l)
 
@@ -166,7 +167,7 @@ class EagerTest(PersistTest):
         criterion doesnt interfere with the eager load criterion."""
         m = mapper(User, users, properties = dict(
             addresses = relation(Address, addresses, primaryjoin = users.c.user_id==addresses.c.user_id, lazy = False)
-        ), echo = True)
+        ))
         l = m.select(and_(addresses.c.email_address == 'ed@lala.com', addresses.c.user_id==users.c.user_id))
         print repr(l)
 
@@ -199,7 +200,7 @@ class EagerTest(PersistTest):
         m = mapper(User, users, properties = dict(
             orders_open = relation(Order, openorders, primaryjoin = and_(openorders.c.isopen == 1, users.c.user_id==openorders.c.user_id), lazy = False),
             orders_closed = relation(Order, closedorders, primaryjoin = and_(closedorders.c.isopen == 0, users.c.user_id==closedorders.c.user_id), lazy = False)
-        ), echo = True)
+        ))
         l = m.select()
         print repr(l)
 
@@ -213,7 +214,7 @@ class EagerTest(PersistTest):
         m = mapper(User, users, properties = dict(
             addresses = relation(Address, addresses, lazy = False),
             orders = relation(ordermapper, primaryjoin = users.c.user_id==orders.c.user_id, lazy = False),
-        ), echo = True)
+        ))
         l = m.select()
         print repr(l)
     
@@ -222,7 +223,7 @@ class EagerTest(PersistTest):
         
         m = mapper(Item, items, properties = dict(
                 keywords = relation(Keyword, keywords, itemkeywords, lazy = False),
-            ), echo = True)
+            ))
         l = m.select()
         print repr(l)
         
@@ -235,17 +236,16 @@ class EagerTest(PersistTest):
         m = mapper(Item, items, 
         properties = dict(
                 keywords = relation(Keyword, keywords, itemkeywords, lazy = False),
-            ), 
-        echo = True)
+            ))
 
         m = mapper(Order, orders, properties = dict(
                 items = relation(m, lazy = False)
-            ), echo = True)
+            ))
         l = m.select("orders.order_id in (1,2,3)")
         #l = m.select()
         print repr(l)
 
-class SaveTest(PersistTest):
+class SaveTest(AssertMixin):
 
     def testbasic(self):
         # save two users
@@ -253,7 +253,7 @@ class SaveTest(PersistTest):
         u.user_name = 'savetester'
         u2 = User()
         u2.user_name = 'savetester2'
-        m = mapper(User, users, echo=True)
+        m = mapper(User, users)
         m.save(u)
         m.save(u2)
 
@@ -282,7 +282,7 @@ class SaveTest(PersistTest):
         """tests a save of an object where each instance spans two tables. also tests
         redefinition of the keynames for the column properties."""
         usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id)
-        m = mapper(User, usersaddresses, table = users, echo = True, 
+        m = mapper(User, usersaddresses, table = users,  
             properties = dict(
                 email = ColumnProperty(addresses.c.email_address), 
                 foo_id = ColumnProperty(users.c.user_id, addresses.c.user_id)
@@ -314,7 +314,7 @@ class SaveTest(PersistTest):
         """test basic save of one to many."""
         m = mapper(User, users, properties = dict(
             addresses = relation(Address, addresses, lazy = True)
-        ), echo = True)
+        ))
         u = User()
         u.user_name = 'one2manytester'
         u.addresses = []
@@ -344,7 +344,7 @@ class SaveTest(PersistTest):
         """tests that an alias of a table can be used in a mapper. 
         the mapper has to locate the original table and columns to keep it all straight."""
         ualias = Alias(users, 'ualias')
-        m = mapper(User, ualias, echo = True)
+        m = mapper(User, ualias)
         u = User()
         u.user_name = 'testalias'
         m.save(u)
@@ -355,7 +355,7 @@ class SaveTest(PersistTest):
     def testremove(self):
         m = mapper(User, users, properties = dict(
             addresses = relation(Address, addresses, lazy = True)
-        ), echo = True)
+        ))
         u = User()
         u.user_name = 'one2manytester'
         u.addresses = []
@@ -381,7 +381,7 @@ class SaveTest(PersistTest):
 
         m = mapper(Item, items, properties = dict(
                 keywords = relation(Keyword, keywords, itemkeywords, lazy = False),
-            ), echo = True)
+            ))
 
         keywordmapper = mapper(Keyword, keywords)
 
@@ -395,12 +395,25 @@ class SaveTest(PersistTest):
         for k in klist:
             item.keywords.append(k)
         m.save(item)
-        print repr(m.select(items.c.item_id == item.item_id))
+        l = m.select(items.c.item_id == item.item_id)
+
+        self.assert_result(l, Item,
+            {'item_id' : item.item_id, 'keywords' : (Keyword, [
+                {'name' : 'purple'},
+                {'name' : 'blue'},
+                {'name' : 'big'},
+                {'name' : 'round'}
+            ])})
 
         del item.keywords[2]
         del item.keywords[2]
         m.save(item)
-        print repr(m.select(items.c.item_id == item.item_id))
+        l = m.select(items.c.item_id == item.item_id)
+        self.assert_result(l, Item,
+            {'item_id' : item.item_id, 'keywords' : (Keyword, [
+                {'name' : 'purple'},
+                {'name' : 'blue'},
+            ])})
 
 if __name__ == "__main__":
     unittest.main()
index bddb7b499d8db79cf1dc6200c8b800436b5af05a..3822ab9280761b4538b4a44cfe030893f5caf2c5 100644 (file)
@@ -8,12 +8,12 @@ DBTYPE = 'sqlite_memory'
 
 if DBTYPE == 'sqlite_memory':
     import sqlalchemy.databases.sqlite as sqllite
-    db = sqllite.engine(':memory:', {}, echo = False)
+    db = sqllite.engine(':memory:', {}, echo = ECHO)
 elif DBTYPE == 'sqlite_file':
     import sqlalchemy.databases.sqlite as sqllite
     if os.access('querytest.db', os.F_OK):
         os.remove('querytest.db')
-    db = sqllite.engine('querytest.db', opts = {}, echo = True)
+    db = sqllite.engine('querytest.db', opts = {}, echo = ECHO)
 elif DBTYPE == 'postgres':
     pass