From 9806d81675ef62363753a028ada43bc460728cf5 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 24 Jan 2010 18:13:21 +0000 Subject: [PATCH] - the "threadlocal" engine has been rewritten and simplified and now supports SAVEPOINT operations. --- CHANGES | 3 + lib/sqlalchemy/engine/base.py | 5 +- lib/sqlalchemy/engine/threadlocal.py | 230 +++++++-------------------- lib/sqlalchemy/test/engines.py | 4 +- test/engine/test_transaction.py | 134 ++++++++++++++-- 5 files changed, 184 insertions(+), 192 deletions(-) diff --git a/CHANGES b/CHANGES index 326d64b1b7..e322df15de 100644 --- a/CHANGES +++ b/CHANGES @@ -378,6 +378,9 @@ CHANGES - All pyodbc-dialects now support extra pyodbc-specific kw arguments 'ansi', 'unicode_results', 'autocommit'. [ticket:1621] + + - the "threadlocal" engine has been rewritten and simplified + and now supports SAVEPOINT operations. - deprecated or removed * result.last_inserted_ids() is deprecated. Use diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 2f26add6b9..26d1b69d92 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -765,7 +765,7 @@ class Connection(Connectable): return self.engine.Connection( self.engine, self.__connection, _branch=self.__branch, _options=opt) - + @property def dialect(self): "Dialect used by this Connection." @@ -1026,7 +1026,8 @@ class Connection(Connectable): conn.close() self.__invalid = False del self.__connection - + self.__transaction = None + def scalar(self, object, *multiparams, **params): """Executes and returns the first column of the first row. diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index 27d857623e..a9892ae7e0 100644 --- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -7,211 +7,95 @@ invoked automatically when the threadlocal engine strategy is used. from sqlalchemy import util from sqlalchemy.engine import base - - -class TLSession(object): - def __init__(self, engine): - self.engine = engine - self.__tcount = 0 - - def get_connection(self, close_with_result=False): - try: - return self.__transaction._increment_connect() - except AttributeError: - return self.engine.TLConnection(self, self.engine.pool.connect(), - close_with_result=close_with_result) - - def reset(self): - try: - self.__transaction._force_close() - del self.__transaction - del self.__trans - except AttributeError: - pass - self.__tcount = 0 - - def _conn_closed(self): - if self.__tcount == 1: - self.__trans._trans.rollback() - self.reset() - - def in_transaction(self): - return self.__tcount > 0 - - def prepare(self): - if self.__tcount == 1: - self.__trans._trans.prepare() - - def begin_twophase(self, xid=None): - if self.__tcount == 0: - self.__transaction = self.get_connection() - self.__trans = self.__transaction._begin_twophase(xid=xid) - self.__tcount += 1 - return self.__trans - - def begin(self, **kwargs): - if self.__tcount == 0: - self.__transaction = self.get_connection() - self.__trans = self.__transaction._begin(**kwargs) - self.__tcount += 1 - return self.__trans - - def rollback(self): - if self.__tcount > 0: - try: - self.__trans._trans.rollback() - finally: - self.reset() - - def commit(self): - if self.__tcount == 1: - try: - self.__trans._trans.commit() - finally: - self.reset() - elif self.__tcount > 1: - self.__tcount -= 1 - - def close(self): - if self.__tcount == 1: - self.rollback() - elif self.__tcount > 1: - self.__tcount -= 1 - - def is_begun(self): - return self.__tcount > 0 - +import weakref class TLConnection(base.Connection): - def __init__(self, session, connection, **kwargs): - base.Connection.__init__(self, session.engine, connection, **kwargs) - self.__session = session - self.__opencount = 1 - - def _branch(self): - return self.engine.Connection(self.engine, self.connection, _branch=True) - - def session(self): - return self.__session - session = property(session) - + def __init__(self, *arg, **kw): + super(TLConnection, self).__init__(*arg, **kw) + self.__opencount = 0 + def _increment_connect(self): self.__opencount += 1 return self - - def _begin(self, **kwargs): - return TLTransaction( - super(TLConnection, self).begin(**kwargs), self.__session) - - def _begin_twophase(self, xid=None): - return TLTransaction( - super(TLConnection, self).begin_twophase(xid=xid), self.__session) - - def in_transaction(self): - return self.session.in_transaction() - - def begin(self, **kwargs): - return self.session.begin(**kwargs) - - def begin_twophase(self, xid=None): - return self.session.begin_twophase(xid=xid) - def begin_nested(self): - raise NotImplementedError("SAVEPOINT transactions with the 'threadlocal' strategy") - def close(self): if self.__opencount == 1: base.Connection.close(self) - self.__session._conn_closed() self.__opencount -= 1 def _force_close(self): self.__opencount = 0 base.Connection.close(self) - -class TLTransaction(base.Transaction): - def __init__(self, trans, session): - self._trans = trans - self._session = session - - def connection(self): - return self._trans.connection - connection = property(connection) - - def is_active(self): - return self._trans.is_active - is_active = property(is_active) - - def rollback(self): - self._session.rollback() - - def prepare(self): - self._session.prepare() - - def commit(self): - self._session.commit() - - def close(self): - self._session.close() - - def __enter__(self): - return self - - def __exit__(self, type, value, traceback): - self._trans.__exit__(type, value, traceback) - - + class TLEngine(base.Engine): - """An Engine that includes support for thread-local managed transactions. + """An Engine that includes support for thread-local managed transactions.""" - The TLEngine relies upon its Pool having "threadlocal" behavior, - so that once a connection is checked out for the current thread, - you get that same connection repeatedly. - """ def __init__(self, *args, **kwargs): - """Construct a new TLEngine.""" - super(TLEngine, self).__init__(*args, **kwargs) - self.context = util.threading.local() - + self._connections = util.threading.local() proxy = kwargs.get('proxy') if proxy: self.TLConnection = base._proxy_connection_cls(TLConnection, proxy) else: self.TLConnection = TLConnection - def session(self): - "Returns the current thread's TLSession" - if not hasattr(self.context, 'session'): - self.context.session = TLSession(self) - return self.context.session - - session = property(session) - - def contextual_connect(self, **kwargs): - """Return a TLConnection which is thread-locally scoped.""" - - return self.session.get_connection(**kwargs) - - def begin_twophase(self, **kwargs): - return self.session.begin_twophase(**kwargs) + def contextual_connect(self, **kw): + if not hasattr(self._connections, 'conn'): + connection = None + else: + connection = self._connections.conn() + + if connection is None or connection.closed: + # guards against pool-level reapers, if desired. + # or not connection.connection.is_valid: + connection = self.TLConnection(self, self.pool.connect(), **kw) + self._connections.conn = conn = weakref.ref(connection) + + return connection._increment_connect() + + def begin_twophase(self, xid=None): + if not hasattr(self._connections, 'trans'): + self._connections.trans = [] + self._connections.trans.append(self.contextual_connect().begin_twophase(xid=xid)) def begin_nested(self): - raise NotImplementedError("SAVEPOINT transactions with the 'threadlocal' strategy") + if not hasattr(self._connections, 'trans'): + self._connections.trans = [] + self._connections.trans.append(self.contextual_connect().begin_nested()) + + def begin(self): + if not hasattr(self._connections, 'trans'): + self._connections.trans = [] + self._connections.trans.append(self.contextual_connect().begin()) - def begin(self, **kwargs): - return self.session.begin(**kwargs) - def prepare(self): - self.session.prepare() + self._connections.trans[-1].prepare() def commit(self): - self.session.commit() - + trans = self._connections.trans.pop(-1) + trans.commit() + def rollback(self): - self.session.rollback() - + trans = self._connections.trans.pop(-1) + trans.rollback() + + def dispose(self): + self._connections = util.threading.local() + super(TLEngine, self).dispose() + + @property + def closed(self): + return not hasattr(self._connections, 'conn') or \ + self._connections.conn() is None or \ + self._connections.conn().closed + + def close(self): + if not self.closed: + self.contextual_connect().close() + del self._connections.conn + self._connections.trans = [] + def __repr__(self): return 'TLEngine(%s)' % str(self.url) diff --git a/lib/sqlalchemy/test/engines.py b/lib/sqlalchemy/test/engines.py index 31d1658af9..2f3d11bda6 100644 --- a/lib/sqlalchemy/test/engines.py +++ b/lib/sqlalchemy/test/engines.py @@ -3,6 +3,7 @@ from collections import deque import config from sqlalchemy.util import function_named, callable import re +import warnings class ConnectionKiller(object): def __init__(self): @@ -24,8 +25,7 @@ class ConnectionKiller(object): except (SystemExit, KeyboardInterrupt): raise except Exception, e: - # fixme - sys.stderr.write("\n" + str(e) + "\n") + warnings.warn("testing_reaper couldn't close connection: %s" % e) def rollback_all(self): self._apply_all(('rollback',)) diff --git a/test/engine/test_transaction.py b/test/engine/test_transaction.py index c51623f2cd..2499571736 100644 --- a/test/engine/test_transaction.py +++ b/test/engine/test_transaction.py @@ -468,23 +468,20 @@ class TLTransactionTest(TestBase): def teardown_class(cls): users.drop(tlengine) tlengine.dispose() - - def test_nested_unsupported(self): - assert_raises(NotImplementedError, tlengine.contextual_connect().begin_nested) - assert_raises(NotImplementedError, tlengine.begin_nested) + + def setup(self): + # ensure tests start with engine closed + tlengine.close() def test_connection_close(self): - """test that when connections are closed for real, transactions are rolled back and disposed.""" + """test that when connections are closed for real, + transactions are rolled back and disposed.""" c = tlengine.contextual_connect() c.begin() - assert tlengine.session.in_transaction() - assert hasattr(tlengine.session, '_TLSession__transaction') - assert hasattr(tlengine.session, '_TLSession__trans') + assert c.in_transaction() c.close() - assert not tlengine.session.in_transaction() - assert not hasattr(tlengine.session, '_TLSession__transaction') - assert not hasattr(tlengine.session, '_TLSession__trans') + assert not c.in_transaction() def test_transaction_close(self): c = tlengine.contextual_connect() @@ -615,8 +612,9 @@ class TLTransactionTest(TestBase): conn.close() external_connection.close() - def test_nesting(self): - """tests nesting of transactions""" + def test_nesting_rollback(self): + """tests nesting of transactions, rollback at the end""" + external_connection = tlengine.connect() self.assert_(external_connection.connection is not tlengine.contextual_connect().connection) tlengine.begin() @@ -633,6 +631,25 @@ class TLTransactionTest(TestBase): finally: external_connection.close() + def test_nesting_commit(self): + """tests nesting of transactions, commit at the end.""" + + external_connection = tlengine.connect() + self.assert_(external_connection.connection is not tlengine.contextual_connect().connection) + tlengine.begin() + tlengine.execute(users.insert(), user_id=1, user_name='user1') + tlengine.execute(users.insert(), user_id=2, user_name='user2') + tlengine.execute(users.insert(), user_id=3, user_name='user3') + tlengine.begin() + tlengine.execute(users.insert(), user_id=4, user_name='user4') + tlengine.execute(users.insert(), user_id=5, user_name='user5') + tlengine.commit() + tlengine.commit() + try: + self.assert_(external_connection.scalar("select count(1) from query_users") == 5) + finally: + external_connection.close() + def test_mixed_nesting(self): """tests nesting of transactions off the TLEngine directly inside of tranasctions off the connection from the TLEngine""" @@ -684,14 +701,101 @@ class TLTransactionTest(TestBase): finally: external_connection.close() + @testing.requires.savepoints + def test_nested_subtransaction_rollback(self): + + tlengine.begin() + tlengine.execute(users.insert(), user_id=1, user_name='user1') + tlengine.begin_nested() + tlengine.execute(users.insert(), user_id=2, user_name='user2') + tlengine.rollback() + tlengine.execute(users.insert(), user_id=3, user_name='user3') + tlengine.commit() + tlengine.close() + + eq_( + tlengine.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), + [(1,),(3,)] + ) + tlengine.close() + + @testing.requires.savepoints + @testing.crashes('oracle+zxjdbc', 'Errors out and causes subsequent tests to deadlock') + def test_nested_subtransaction_commit(self): + tlengine.begin() + tlengine.execute(users.insert(), user_id=1, user_name='user1') + tlengine.begin_nested() + tlengine.execute(users.insert(), user_id=2, user_name='user2') + tlengine.commit() + tlengine.execute(users.insert(), user_id=3, user_name='user3') + tlengine.commit() + + tlengine.close() + eq_( + tlengine.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), + [(1,),(2,),(3,)] + ) + tlengine.close() + + @testing.requires.savepoints + def test_rollback_to_subtransaction(self): + tlengine.begin() + tlengine.execute(users.insert(), user_id=1, user_name='user1') + tlengine.begin_nested() + tlengine.execute(users.insert(), user_id=2, user_name='user2') + tlengine.begin() + tlengine.execute(users.insert(), user_id=3, user_name='user3') + tlengine.rollback() + tlengine.execute(users.insert(), user_id=4, user_name='user4') + tlengine.commit() + tlengine.close() + + eq_( + tlengine.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), + [(1,),(4,)] + ) + tlengine.close() + def test_connections(self): """tests that contextual_connect is threadlocal""" c1 = tlengine.contextual_connect() c2 = tlengine.contextual_connect() assert c1.connection is c2.connection c2.close() - assert c1.connection.connection is not None - + assert not c1.closed + assert not tlengine.closed + + def test_result_closing(self): + """tests that contextual_connect is threadlocal""" + + r1 = tlengine.execute("select 1") + r2 = tlengine.execute("select 1") + row1 = r1.fetchone() + row2 = r2.fetchone() + r1.close() + assert r2.connection is r1.connection + assert not r2.connection.closed + assert not tlengine.closed + + # close again, nothing happens + # since resultproxy calls close() only + # once + r1.close() + assert r2.connection is r1.connection + assert not r2.connection.closed + assert not tlengine.closed + + r2.close() + assert r2.connection.closed + assert tlengine.closed + + def test_dispose(self): + eng = create_engine(testing.db.url, strategy='threadlocal') + result = eng.execute("select 1") + eng.dispose() + eng.execute("select 1") + + @testing.requires.two_phase_transactions def test_two_phase_transaction(self): tlengine.begin_twophase() -- 2.47.3