]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- the "connection" argument from engine.transaction() and
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 29 Jan 2010 22:20:55 +0000 (22:20 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 29 Jan 2010 22:20:55 +0000 (22:20 +0000)
engine.run_callable() is removed - Connection itself
now has those methods.   All four methods accept
*args and **kwargs which are passed to the given callable,
as well as the operating connection.

CHANGES
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/schema.py
test/engine/test_transaction.py

diff --git a/CHANGES b/CHANGES
index aef444c38a23a2f299f69a7201784a5eaafe4f65..1ee77d6f769ee2cd8b3ddc389d85f2f41accb184 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -425,7 +425,12 @@ CHANGES
         result.inserted_primary_key
       * dialect.get_default_schema_name(connection) is now
         public via dialect.default_schema_name.
-            
+      * the "connection" argument from engine.transaction() and
+        engine.run_callable() is removed - Connection itself
+        now has those methods.   All four methods accept
+        *args and **kwargs which are passed to the given callable, 
+        as well as the operating connection.
+        
 - schema
     - the `__contains__()` method of `MetaData` now accepts
       strings or `Table` objects as arguments.  If given
index 6e4a34219573173510db4bb78a13cbf98690764b..6a4fa5d08ed724d1c4e77e9aca505708a70285f9 100644 (file)
@@ -1220,8 +1220,28 @@ class Connection(Connectable):
     def default_schema_name(self):
         return self.engine.dialect.get_default_schema_name(self)
 
-    def run_callable(self, callable_):
-        return callable_(self)
+    def transaction(self, callable_, *args, **kwargs):
+        """Execute the given function within a transaction boundary.
+
+        This is a shortcut for explicitly calling `begin()` and `commit()`
+        and optionally `rollback()` when exceptions are raised.  The
+        given `*args` and `**kwargs` will be passed to the function.
+        
+        See also transaction() on engine.
+        
+        """
+
+        trans = self.begin()
+        try:
+            ret = self.run_callable(callable_, *args, **kwargs)
+            trans.commit()
+            return ret
+        except:
+            trans.rollback()
+            raise
+
+    def run_callable(self, callable_, *args, **kwargs):
+        return callable_(self, *args, **kwargs)
 
 
 class Transaction(object):
@@ -1406,42 +1426,31 @@ class Engine(Connectable):
             if connection is None:
                 conn.close()
 
-    def transaction(self, callable_, connection=None, *args, **kwargs):
+    def transaction(self, callable_, *args, **kwargs):
         """Execute the given function within a transaction boundary.
 
         This is a shortcut for explicitly calling `begin()` and `commit()`
         and optionally `rollback()` when exceptions are raised.  The
-        given `*args` and `**kwargs` will be passed to the function, as
-        well as the Connection used in the transaction.
+        given `*args` and `**kwargs` will be passed to the function.
+        
+        The connection used is that of contextual_connect().
+        
+        See also the similar method on Connection itself.
+        
         """
-
-        if connection is None:
-            conn = self.contextual_connect()
-        else:
-            conn = connection
+        
+        conn = self.contextual_connect()
         try:
-            trans = conn.begin()
-            try:
-                ret = callable_(conn, *args, **kwargs)
-                trans.commit()
-                return ret
-            except:
-                trans.rollback()
-                raise
+            return conn.transaction(callable_, *args, **kwargs)
         finally:
-            if connection is None:
-                conn.close()
+            conn.close()
 
-    def run_callable(self, callable_, connection=None, *args, **kwargs):
-        if connection is None:
-            conn = self.contextual_connect()
-        else:
-            conn = connection
+    def run_callable(self, callable_, *args, **kwargs):
+        conn = self.contextual_connect()
         try:
-            return callable_(conn, *args, **kwargs)
+            return conn.run_callable(callable_, *args, **kwargs)
         finally:
-            if connection is None:
-                conn.close()
+            conn.close()
 
     def execute(self, statement, *multiparams, **params):
         connection = self.contextual_connect(close_with_result=True)
@@ -1506,7 +1515,7 @@ class Engine(Connectable):
                 conn.close()
 
     def has_table(self, table_name, schema=None):
-        return self.run_callable(lambda c: self.dialect.has_table(c, table_name, schema=schema))
+        return self.run_callable(self.dialect.has_table, table_name, schema)
 
     def raw_connection(self):
         """Return a DB-API connection."""
index cd9bd48921a80ce3be98d177a548610519bfbec8..c345b0b21ede5e6116c1fe0172682f180a1f8bb3 100644 (file)
@@ -397,9 +397,7 @@ class Table(SchemaItem, expression.TableClause):
         if bind is None:
             bind = _bind_or_error(self)
 
-        def do(conn):
-            return conn.dialect.has_table(conn, self.name, schema=self.schema)
-        return bind.run_callable(do)
+        return bind.run_callable(bind.dialect.has_table, self.name, schema=self.schema)
 
     def create(self, bind=None, checkfirst=False):
         """Issue a ``CREATE`` statement for this table.
index fa2cc3e59cdf9c9c175ed21adcdf38acfc359364..5be1d937401fe4cd164403a37c4571287d75ccca 100644 (file)
@@ -1,6 +1,7 @@
 from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 import sys, time, threading
-from sqlalchemy import create_engine, MetaData, INT, VARCHAR, Sequence, select, Integer, String, func, text
+from sqlalchemy import create_engine, MetaData, INT, VARCHAR, Sequence, \
+                            select, Integer, String, func, text, exc
 from sqlalchemy.test.schema import Table
 from sqlalchemy.test.schema import Column
 from sqlalchemy.test import TestBase, testing
@@ -73,7 +74,24 @@ class TransactionTest(TestBase):
         result = connection.execute("select * from query_users")
         assert len(result.fetchall()) == 0
         connection.close()
-
+    
+    def test_transaction_container(self):
+        
+        def go(conn, table, data):
+            for d in data:
+                conn.execute(table.insert(), d)
+            
+        testing.db.transaction(go, users, [dict(user_id=1, user_name='user1')])
+        eq_(testing.db.execute(users.select()).fetchall(), [(1, 'user1')])
+        
+        assert_raises(exc.DBAPIError, 
+            testing.db.transaction, go, users, [
+                {'user_id':2, 'user_name':'user2'},
+                {'user_id':1, 'user_name':'user3'},
+            ]
+        )
+        eq_(testing.db.execute(users.select()).fetchall(), [(1, 'user1')])
+        
     def test_nested_rollback(self):
         connection = testing.db.connect()