]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 Aug 2005 22:03:57 +0000 (22:03 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 Aug 2005 22:03:57 +0000 (22:03 +0000)
lib/sqlalchemy/engine.py
lib/sqlalchemy/mapper.py
test/mapper.py

index 9e662f455a3bf86e3d947b35a9af2ecd6795b9c7..db1191a9bc89b8fa6f73fd0f79605c05723a1969 100644 (file)
@@ -55,7 +55,7 @@ class SQLEngine(schema.SchemaEngine):
         self._pool = sqlalchemy.pool.manage(self.dbapi()).get_pool(*cargs, **cparams)
         self._echo = echo
         self.context = util.ThreadLocal()
-        
+
     def schemagenerator(self, proxy, **params):
         raise NotImplementedError()
 
@@ -64,7 +64,7 @@ class SQLEngine(schema.SchemaEngine):
 
     def reflecttable(self, table):
         raise NotImplementedError()
-        
+
     def columnimpl(self, column):
         return sql.ColumnSelectable(column)
 
@@ -72,22 +72,52 @@ class SQLEngine(schema.SchemaEngine):
         """returns a thread-local map of the generated primary keys corresponding to the most recent
         insert statement.  keys are the names of columns."""
         raise NotImplementedError()
-        
+
     def connect_args(self):
         raise NotImplementedError()
-        
+
     def dbapi(self):
         raise NotImplementedError()
 
     def compile(self, statement, bindparams):
         raise NotImplementedError()
 
+    def do_begin(self, connection):
+        """implementations might want to put logic here for turning autocommit on/off, etc."""
+        pass
+    def do_rollback(self, connection):
+        """implementations might want to put logic here for turning autocommit on/off, etc."""
+        connection.rollback()
+    def do_commit(self, connection):
+        """implementations might want to put logic here for turning autocommit on/off, etc."""
+        connection.commit()
+
     def proxy(self):
         return lambda s, p = None: self.execute(s, p)
-        
+
     def connection(self):
         return self._pool.connect()
 
+    def multi_transaction(self, tables, func):
+        """provides a transaction boundary across tables which may be in multiple databases.
+        
+        clearly, this approach only goes so far, such as if database A commits, then database B commits
+        and fails, A is already committed.  Any failure conditions have to be raised before anyone
+        commits for this to be useful."""
+        engines = util.Set()
+        for table in tables:
+            engines.append(table.engine)
+        for engine in engines:
+            engine.begin()
+        try:
+            func()
+        except:
+            for engine in engines:
+                engine.rollback()
+            raise
+        for engine in engines:
+            engine.commit()
+
     def transaction(self, func):
         self.begin()
         try:
@@ -96,10 +126,12 @@ class SQLEngine(schema.SchemaEngine):
             self.rollback()
             raise
         self.commit()
-            
+
+        
     def begin(self):
         if getattr(self.context, 'transaction', None) is None:
             conn = self.connection()
+            self.do_begin(conn)
             self.context.transaction = conn
             self.context.tcount = 1
         else:
@@ -107,7 +139,7 @@ class SQLEngine(schema.SchemaEngine):
             
     def rollback(self):
         if self.context.transaction is not None:
-            self.context.transaction.rollback()
+            self.do_rollback(self.context.transaction)
             self.context.transaction = None
             self.context.tcount = None
             
@@ -116,7 +148,7 @@ class SQLEngine(schema.SchemaEngine):
             count = self.context.tcount - 1
             self.context.tcount = count
             if count == 0:
-                self.context.transaction.commit()
+                self.do_commit(self.context.transaction)
                 self.context.transaction = None
                 self.context.tcount = None
 
index 283bfb97bcd3c44396af1a6dec687922aeb912a2..bef7474f24f9bc1bc59e32476edec3c61150ba5d 100644 (file)
@@ -63,10 +63,10 @@ def identitymap():
 
 def globalidentity():
     return _global_identitymap
-    
+
 def eagerload(name):
     return EagerLazySwitcher(name, toeager = True)
-    
+
 def lazyload(name):
     return EagerLazySwitcher(name, toeager = False)
 
@@ -104,7 +104,7 @@ class Mapper(object):
 
         if isroot:
             self.init(self)
-    
+
     def hash_key(self):
         return mapper_hash_key(
             self.class_,
@@ -181,7 +181,7 @@ class Mapper(object):
             for option in options:
                 option.process(mapper)
             return _mappers.setdefault(hashkey, mapper)
-        
+
     def select(self, arg = None, **params):
         """selects instances of the object from the database.  
         
@@ -196,28 +196,37 @@ class Mapper(object):
             return self._select_statement(arg, **params)
         else:
             return self._select_whereclause(arg, **params)
-        
+
     def save(self, obj, traverse = True, refetch = False):
-        """saves the object.  based on the existence of its primary key, either inserts or updates.
+        """saves the object across all its primary tables.  
+        based on the existence of the primary key for each table, either inserts or updates.
         primary key is determined by the underlying database engine's sequence methodology.
-        traverse indicates attached objects should be saved as well.
+        the traverse flag indicates attached objects should be saved as well.
         
         if smart attributes are being used for the object, the "dirty" flag, or the absense 
         of the attribute, determines if the item is saved.  if smart attributes are not being 
         used, the item is saved unconditionally.
         """
-        # TODO: support multi-table saves
+
         if getattr(obj, 'dirty', True):
-            for table in self.tables:
-                for col in table.columns:
-                    if getattr(obj, col.key, None) is None:
-                        self.insert(obj, table)
-                        break
-                else:
-                    self.update(obj, table)
+            f = def():
+                for table in self.tables:
+                    for col in table.columns:
+                        if getattr(obj, col.key, None) is None:
+                            self.insert(obj, table)
+                            break
+                    else:
+                        self.update(obj, table)
+
+                for prop in self.props.values():
+                    prop.save(obj, traverse, refetch)
+            self.transaction(f)
+        else:
+            for prop in self.props.values():
+                prop.save(obj, traverse, refetch)
 
-        for prop in self.props.values():
-            prop.save(obj, traverse, refetch)
+    def transaction(self, f):
+        return self.table.engine.multi_transaction(self.tables, f)
 
     def remove(self, obj, traverse = True):
         """removes the object.  traverse indicates attached objects should be removed as well."""
@@ -433,7 +442,7 @@ class PropertyLoader(MapperProperty):
             if self.primaryjoin is None:
                 self.primaryjoin = match_primaries(parent.selectable, self.target)
 
-    def save(self, object, traverse, refetch):
+    def save(self, obj, traverse, refetch):
         # if a mapping table does not exist, save a row for all objects
         # in our list normally, setting their primary keys
         # else, determine the foreign key column in our table, set it to the parent
@@ -441,15 +450,20 @@ class PropertyLoader(MapperProperty):
         # if a mapping table exists, determine the two foreign key columns 
         # in the mapping table, set the two values, and insert that row, for
         # each row in the list
-        if self.secondary is None:
-            self.mapper.save(object)
-        else:
-            # TODO: crap, we dont have a simple version of what object props/cols match to which
-            pass
+        for child in getattr(obj, self.key):
+            setter = ForeignKeySetter(obj, child)
+            self.primaryjoin.accept_visitor(setter)
+            self.mapper.save(child)
 
     def delete(self):
         self.mapper.delete()
 
+class ForeignKeySetter(ClauseVisitor):
+    def visit_binary(self, binary):
+        if binary.operator == '==':
+            if binary.left.table == self.primarytable and binary.right.table == self.secondarytable:
+                setattr(self.child, binary.left.colname, getattr(obj, binary.right.colname))
+
 class LazyLoader(PropertyLoader):
 
     def init(self, key, parent, root):
index 7429070a369c17c85ae1c50272d14f4ec0f2aa5a..d1f29e43893e48c053856358de25060176591563 100644 (file)
@@ -250,18 +250,18 @@ class SaveTest(PersistTest):
         m = mapper(User, users, echo=True)
         m.save(u)
         m.save(u2)
-        
+
         # assert the first one retreives the same from the identity map
         nu = m.get(u.user_id)
         self.assert_(u is nu)
-        
+
         # clear out the identity map, so next get forces a SELECT
         m.identitymap.clear()
 
         # check it again, identity should be different but ids the same
         nu = m.get(u.user_id)
         self.assert_(u is not nu and u.user_id == nu.user_id and nu.user_name == 'savetester')
-        
+
         # change first users name and save
         u.user_name = 'modifiedname'
         m.save(u)
@@ -292,6 +292,6 @@ class SaveTest(PersistTest):
         print repr(usertable)
         addresstable = engine.ResultProxy(addresses.select().execute()).fetchall()
         print repr(addresstable)
-        
+
 if __name__ == "__main__":
-    unittest.main()        
+    unittest.main()