]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- the "threadlocal" engine has been rewritten and simplified
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 24 Jan 2010 18:13:21 +0000 (18:13 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 24 Jan 2010 18:13:21 +0000 (18:13 +0000)
and now supports SAVEPOINT operations.

CHANGES
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/threadlocal.py
lib/sqlalchemy/test/engines.py
test/engine/test_transaction.py

diff --git a/CHANGES b/CHANGES
index 326d64b1b7738a72b67c86720de0cb6a35914b44..e322df15de9e0f43835e8de404e239713dde62bf 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -378,6 +378,9 @@ CHANGES
   - All pyodbc-dialects now support extra pyodbc-specific 
     kw arguments 'ansi', 'unicode_results', 'autocommit'.
     [ticket:1621]
+
+  - the "threadlocal" engine has been rewritten and simplified 
+    and now supports SAVEPOINT operations.
     
   - deprecated or removed
       * result.last_inserted_ids() is deprecated.  Use 
index 2f26add6b910de806094c350279ac793a9706ede..26d1b69d92259d0fdf8568f36d415c8ac0c121b4 100644 (file)
@@ -765,7 +765,7 @@ class Connection(Connectable):
         return self.engine.Connection(
                     self.engine, self.__connection,
                      _branch=self.__branch, _options=opt)
-        
+    
     @property
     def dialect(self):
         "Dialect used by this Connection."
@@ -1026,7 +1026,8 @@ class Connection(Connectable):
             conn.close()
         self.__invalid = False
         del self.__connection
-
+        self.__transaction = None
+        
     def scalar(self, object, *multiparams, **params):
         """Executes and returns the first column of the first row.
 
index 27d857623e37f200c220b31c34794f8aa4a27c81..a9892ae7e09cbfda7cc8f97fbacd19e436db5323 100644 (file)
@@ -7,211 +7,95 @@ invoked automatically when the threadlocal engine strategy is used.
 
 from sqlalchemy import util
 from sqlalchemy.engine import base
-
-
-class TLSession(object):
-    def __init__(self, engine):
-        self.engine = engine
-        self.__tcount = 0
-
-    def get_connection(self, close_with_result=False):
-        try:
-            return self.__transaction._increment_connect()
-        except AttributeError:
-            return self.engine.TLConnection(self, self.engine.pool.connect(),
-                                            close_with_result=close_with_result)
-
-    def reset(self):
-        try:
-            self.__transaction._force_close()
-            del self.__transaction
-            del self.__trans
-        except AttributeError:
-            pass
-        self.__tcount = 0
-
-    def _conn_closed(self):
-        if self.__tcount == 1:
-            self.__trans._trans.rollback()
-            self.reset()
-
-    def in_transaction(self):
-        return self.__tcount > 0
-
-    def prepare(self):
-        if self.__tcount == 1:
-            self.__trans._trans.prepare()
-
-    def begin_twophase(self, xid=None):
-        if self.__tcount == 0:
-            self.__transaction = self.get_connection()
-            self.__trans = self.__transaction._begin_twophase(xid=xid)
-        self.__tcount += 1
-        return self.__trans
-
-    def begin(self, **kwargs):
-        if self.__tcount == 0:
-            self.__transaction = self.get_connection()
-            self.__trans = self.__transaction._begin(**kwargs)
-        self.__tcount += 1
-        return self.__trans
-
-    def rollback(self):
-        if self.__tcount > 0:
-            try:
-                self.__trans._trans.rollback()
-            finally:
-                self.reset()
-
-    def commit(self):
-        if self.__tcount == 1:
-            try:
-                self.__trans._trans.commit()
-            finally:
-                self.reset()
-        elif self.__tcount > 1:
-            self.__tcount -= 1
-            
-    def close(self):
-        if self.__tcount == 1:
-            self.rollback()
-        elif self.__tcount > 1:
-            self.__tcount -= 1
-        
-    def is_begun(self):
-        return self.__tcount > 0
-
+import weakref
 
 class TLConnection(base.Connection):
-    def __init__(self, session, connection, **kwargs):
-        base.Connection.__init__(self, session.engine, connection, **kwargs)
-        self.__session = session
-        self.__opencount = 1
-
-    def _branch(self):
-        return self.engine.Connection(self.engine, self.connection, _branch=True)
-
-    def session(self):
-        return self.__session
-    session = property(session)
-
+    def __init__(self, *arg, **kw):
+        super(TLConnection, self).__init__(*arg, **kw)
+        self.__opencount = 0
+    
     def _increment_connect(self):
         self.__opencount += 1
         return self
-
-    def _begin(self, **kwargs):
-        return TLTransaction(
-            super(TLConnection, self).begin(**kwargs), self.__session)
-
-    def _begin_twophase(self, xid=None):
-        return TLTransaction(
-            super(TLConnection, self).begin_twophase(xid=xid), self.__session)
-
-    def in_transaction(self):
-        return self.session.in_transaction()
-
-    def begin(self, **kwargs):
-        return self.session.begin(**kwargs)
-
-    def begin_twophase(self, xid=None):
-        return self.session.begin_twophase(xid=xid)
     
-    def begin_nested(self):
-        raise NotImplementedError("SAVEPOINT transactions with the 'threadlocal' strategy")
-        
     def close(self):
         if self.__opencount == 1:
             base.Connection.close(self)
-            self.__session._conn_closed()
         self.__opencount -= 1
 
     def _force_close(self):
         self.__opencount = 0
         base.Connection.close(self)
 
-
-class TLTransaction(base.Transaction):
-    def __init__(self, trans, session):
-        self._trans = trans
-        self._session = session
-
-    def connection(self):
-        return self._trans.connection
-    connection = property(connection)
-    
-    def is_active(self):
-        return self._trans.is_active
-    is_active = property(is_active)
-
-    def rollback(self):
-        self._session.rollback()
-
-    def prepare(self):
-        self._session.prepare()
-
-    def commit(self):
-        self._session.commit()
-
-    def close(self):
-        self._session.close()
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, type, value, traceback):
-        self._trans.__exit__(type, value, traceback)
-
-
+        
 class TLEngine(base.Engine):
-    """An Engine that includes support for thread-local managed transactions.
+    """An Engine that includes support for thread-local managed transactions."""
 
-    The TLEngine relies upon its Pool having "threadlocal" behavior,
-    so that once a connection is checked out for the current thread,
-    you get that same connection repeatedly.
-    """
 
     def __init__(self, *args, **kwargs):
-        """Construct a new TLEngine."""
-
         super(TLEngine, self).__init__(*args, **kwargs)
-        self.context = util.threading.local()
-
+        self._connections = util.threading.local()
         proxy = kwargs.get('proxy')
         if proxy:
             self.TLConnection = base._proxy_connection_cls(TLConnection, proxy)
         else:
             self.TLConnection = TLConnection
 
-    def session(self):
-        "Returns the current thread's TLSession"
-        if not hasattr(self.context, 'session'):
-            self.context.session = TLSession(self)
-        return self.context.session
-
-    session = property(session)
-
-    def contextual_connect(self, **kwargs):
-        """Return a TLConnection which is thread-locally scoped."""
-
-        return self.session.get_connection(**kwargs)
-
-    def begin_twophase(self, **kwargs):
-        return self.session.begin_twophase(**kwargs)
+    def contextual_connect(self, **kw):
+        if not hasattr(self._connections, 'conn'):
+            connection = None
+        else:
+            connection = self._connections.conn()
+        
+        if connection is None or connection.closed:
+            # guards against pool-level reapers, if desired.
+            # or not connection.connection.is_valid:
+            connection = self.TLConnection(self, self.pool.connect(), **kw)
+            self._connections.conn = conn = weakref.ref(connection)
+        
+        return connection._increment_connect()
+    
+    def begin_twophase(self, xid=None):
+        if not hasattr(self._connections, 'trans'):
+            self._connections.trans = []
+        self._connections.trans.append(self.contextual_connect().begin_twophase(xid=xid))
 
     def begin_nested(self):
-        raise NotImplementedError("SAVEPOINT transactions with the 'threadlocal' strategy")
+        if not hasattr(self._connections, 'trans'):
+            self._connections.trans = []
+        self._connections.trans.append(self.contextual_connect().begin_nested())
+        
+    def begin(self):
+        if not hasattr(self._connections, 'trans'):
+            self._connections.trans = []
+        self._connections.trans.append(self.contextual_connect().begin())
         
-    def begin(self, **kwargs):
-        return self.session.begin(**kwargs)
-
     def prepare(self):
-        self.session.prepare()
+        self._connections.trans[-1].prepare()
         
     def commit(self):
-        self.session.commit()
-
+        trans = self._connections.trans.pop(-1)
+        trans.commit()
+        
     def rollback(self):
-        self.session.rollback()
-
+        trans = self._connections.trans.pop(-1)
+        trans.rollback()
+    
+    def dispose(self):
+        self._connections = util.threading.local()
+        super(TLEngine, self).dispose()
+        
+    @property
+    def closed(self):
+        return not hasattr(self._connections, 'conn') or \
+                self._connections.conn() is None or \
+                self._connections.conn().closed
+        
+    def close(self):
+        if not self.closed:
+            self.contextual_connect().close()
+            del self._connections.conn
+            self._connections.trans = []
+        
     def __repr__(self):
         return 'TLEngine(%s)' % str(self.url)
index 31d1658af9569e1460bba4b42804c61999ee8b3d..2f3d11bda656fd3240340e09f84049de85d94f3b 100644 (file)
@@ -3,6 +3,7 @@ from collections import deque
 import config
 from sqlalchemy.util import function_named, callable
 import re
+import warnings
 
 class ConnectionKiller(object):
     def __init__(self):
@@ -24,8 +25,7 @@ class ConnectionKiller(object):
                 except (SystemExit, KeyboardInterrupt):
                     raise
                 except Exception, e:
-                    # fixme
-                    sys.stderr.write("\n" + str(e) + "\n")
+                    warnings.warn("testing_reaper couldn't close connection: %s" % e)
 
     def rollback_all(self):
         self._apply_all(('rollback',))
index c51623f2cd921ebbf78d06d71fb924c04703d65d..2499571736505426b15187113e5dcd28b12c73bd 100644 (file)
@@ -468,23 +468,20 @@ class TLTransactionTest(TestBase):
     def teardown_class(cls):
         users.drop(tlengine)
         tlengine.dispose()
-
-    def test_nested_unsupported(self):
-        assert_raises(NotImplementedError, tlengine.contextual_connect().begin_nested)
-        assert_raises(NotImplementedError, tlengine.begin_nested)
+    
+    def setup(self):
+        # ensure tests start with engine closed
+        tlengine.close()
         
     def test_connection_close(self):
-        """test that when connections are closed for real, transactions are rolled back and disposed."""
+        """test that when connections are closed for real, 
+        transactions are rolled back and disposed."""
 
         c = tlengine.contextual_connect()
         c.begin()
-        assert tlengine.session.in_transaction()
-        assert hasattr(tlengine.session, '_TLSession__transaction')
-        assert hasattr(tlengine.session, '_TLSession__trans')
+        assert c.in_transaction()
         c.close()
-        assert not tlengine.session.in_transaction()
-        assert not hasattr(tlengine.session, '_TLSession__transaction')
-        assert not hasattr(tlengine.session, '_TLSession__trans')
+        assert not c.in_transaction()
 
     def test_transaction_close(self):
         c = tlengine.contextual_connect()
@@ -615,8 +612,9 @@ class TLTransactionTest(TestBase):
             conn.close()
             external_connection.close()
 
-    def test_nesting(self):
-        """tests nesting of transactions"""
+    def test_nesting_rollback(self):
+        """tests nesting of transactions, rollback at the end"""
+        
         external_connection = tlengine.connect()
         self.assert_(external_connection.connection is not tlengine.contextual_connect().connection)
         tlengine.begin()
@@ -633,6 +631,25 @@ class TLTransactionTest(TestBase):
         finally:
             external_connection.close()
 
+    def test_nesting_commit(self):
+        """tests nesting of transactions, commit at the end."""
+        
+        external_connection = tlengine.connect()
+        self.assert_(external_connection.connection is not tlengine.contextual_connect().connection)
+        tlengine.begin()
+        tlengine.execute(users.insert(), user_id=1, user_name='user1')
+        tlengine.execute(users.insert(), user_id=2, user_name='user2')
+        tlengine.execute(users.insert(), user_id=3, user_name='user3')
+        tlengine.begin()
+        tlengine.execute(users.insert(), user_id=4, user_name='user4')
+        tlengine.execute(users.insert(), user_id=5, user_name='user5')
+        tlengine.commit()
+        tlengine.commit()
+        try:
+            self.assert_(external_connection.scalar("select count(1) from query_users") == 5)
+        finally:
+            external_connection.close()
+
     def test_mixed_nesting(self):
         """tests nesting of transactions off the TLEngine directly inside of
         tranasctions off the connection from the TLEngine"""
@@ -684,14 +701,101 @@ class TLTransactionTest(TestBase):
         finally:
             external_connection.close()
 
+    @testing.requires.savepoints
+    def test_nested_subtransaction_rollback(self):
+        
+        tlengine.begin()
+        tlengine.execute(users.insert(), user_id=1, user_name='user1')
+        tlengine.begin_nested()
+        tlengine.execute(users.insert(), user_id=2, user_name='user2')
+        tlengine.rollback()
+        tlengine.execute(users.insert(), user_id=3, user_name='user3')
+        tlengine.commit()
+        tlengine.close()
+
+        eq_(
+            tlengine.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
+            [(1,),(3,)]
+        )
+        tlengine.close()
+
+    @testing.requires.savepoints
+    @testing.crashes('oracle+zxjdbc', 'Errors out and causes subsequent tests to deadlock')
+    def test_nested_subtransaction_commit(self):
+        tlengine.begin()
+        tlengine.execute(users.insert(), user_id=1, user_name='user1')
+        tlengine.begin_nested()
+        tlengine.execute(users.insert(), user_id=2, user_name='user2')
+        tlengine.commit()
+        tlengine.execute(users.insert(), user_id=3, user_name='user3')
+        tlengine.commit()
+
+        tlengine.close()
+        eq_(
+            tlengine.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
+            [(1,),(2,),(3,)]
+        )
+        tlengine.close()
+
+    @testing.requires.savepoints
+    def test_rollback_to_subtransaction(self):
+        tlengine.begin()
+        tlengine.execute(users.insert(), user_id=1, user_name='user1')
+        tlengine.begin_nested()
+        tlengine.execute(users.insert(), user_id=2, user_name='user2')
+        tlengine.begin()
+        tlengine.execute(users.insert(), user_id=3, user_name='user3')
+        tlengine.rollback()
+        tlengine.execute(users.insert(), user_id=4, user_name='user4')
+        tlengine.commit()
+        tlengine.close()
+
+        eq_(
+            tlengine.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
+            [(1,),(4,)]
+        )
+        tlengine.close()
+
     def test_connections(self):
         """tests that contextual_connect is threadlocal"""
         c1 = tlengine.contextual_connect()
         c2 = tlengine.contextual_connect()
         assert c1.connection is c2.connection
         c2.close()
-        assert c1.connection.connection is not None
-
+        assert not c1.closed
+        assert not tlengine.closed
+        
+    def test_result_closing(self):
+        """tests that contextual_connect is threadlocal"""
+        
+        r1 = tlengine.execute("select 1")
+        r2 = tlengine.execute("select 1")
+        row1 = r1.fetchone()
+        row2 = r2.fetchone()
+        r1.close()
+        assert r2.connection is r1.connection
+        assert not r2.connection.closed
+        assert not tlengine.closed
+        
+        # close again, nothing happens
+        # since resultproxy calls close() only
+        # once
+        r1.close()
+        assert r2.connection is r1.connection
+        assert not r2.connection.closed
+        assert not tlengine.closed
+
+        r2.close()
+        assert r2.connection.closed
+        assert tlengine.closed
+    
+    def test_dispose(self):
+        eng = create_engine(testing.db.url, strategy='threadlocal')
+        result = eng.execute("select 1")
+        eng.dispose()
+        eng.execute("select 1")
+        
+        
     @testing.requires.two_phase_transactions
     def test_two_phase_transaction(self):
         tlengine.begin_twophase()