from sqlalchemy import util
from sqlalchemy.engine import base
-
-
-class TLSession(object):
- def __init__(self, engine):
- self.engine = engine
- self.__tcount = 0
-
- def get_connection(self, close_with_result=False):
- try:
- return self.__transaction._increment_connect()
- except AttributeError:
- return self.engine.TLConnection(self, self.engine.pool.connect(),
- close_with_result=close_with_result)
-
- def reset(self):
- try:
- self.__transaction._force_close()
- del self.__transaction
- del self.__trans
- except AttributeError:
- pass
- self.__tcount = 0
-
- def _conn_closed(self):
- if self.__tcount == 1:
- self.__trans._trans.rollback()
- self.reset()
-
- def in_transaction(self):
- return self.__tcount > 0
-
- def prepare(self):
- if self.__tcount == 1:
- self.__trans._trans.prepare()
-
- def begin_twophase(self, xid=None):
- if self.__tcount == 0:
- self.__transaction = self.get_connection()
- self.__trans = self.__transaction._begin_twophase(xid=xid)
- self.__tcount += 1
- return self.__trans
-
- def begin(self, **kwargs):
- if self.__tcount == 0:
- self.__transaction = self.get_connection()
- self.__trans = self.__transaction._begin(**kwargs)
- self.__tcount += 1
- return self.__trans
-
- def rollback(self):
- if self.__tcount > 0:
- try:
- self.__trans._trans.rollback()
- finally:
- self.reset()
-
- def commit(self):
- if self.__tcount == 1:
- try:
- self.__trans._trans.commit()
- finally:
- self.reset()
- elif self.__tcount > 1:
- self.__tcount -= 1
-
- def close(self):
- if self.__tcount == 1:
- self.rollback()
- elif self.__tcount > 1:
- self.__tcount -= 1
-
- def is_begun(self):
- return self.__tcount > 0
-
+import weakref
class TLConnection(base.Connection):
- def __init__(self, session, connection, **kwargs):
- base.Connection.__init__(self, session.engine, connection, **kwargs)
- self.__session = session
- self.__opencount = 1
-
- def _branch(self):
- return self.engine.Connection(self.engine, self.connection, _branch=True)
-
- def session(self):
- return self.__session
- session = property(session)
-
+ def __init__(self, *arg, **kw):
+ super(TLConnection, self).__init__(*arg, **kw)
+ self.__opencount = 0
+
def _increment_connect(self):
self.__opencount += 1
return self
-
- def _begin(self, **kwargs):
- return TLTransaction(
- super(TLConnection, self).begin(**kwargs), self.__session)
-
- def _begin_twophase(self, xid=None):
- return TLTransaction(
- super(TLConnection, self).begin_twophase(xid=xid), self.__session)
-
- def in_transaction(self):
- return self.session.in_transaction()
-
- def begin(self, **kwargs):
- return self.session.begin(**kwargs)
-
- def begin_twophase(self, xid=None):
- return self.session.begin_twophase(xid=xid)
- def begin_nested(self):
- raise NotImplementedError("SAVEPOINT transactions with the 'threadlocal' strategy")
-
def close(self):
if self.__opencount == 1:
base.Connection.close(self)
- self.__session._conn_closed()
self.__opencount -= 1
def _force_close(self):
self.__opencount = 0
base.Connection.close(self)
-
-class TLTransaction(base.Transaction):
- def __init__(self, trans, session):
- self._trans = trans
- self._session = session
-
- def connection(self):
- return self._trans.connection
- connection = property(connection)
-
- def is_active(self):
- return self._trans.is_active
- is_active = property(is_active)
-
- def rollback(self):
- self._session.rollback()
-
- def prepare(self):
- self._session.prepare()
-
- def commit(self):
- self._session.commit()
-
- def close(self):
- self._session.close()
-
- def __enter__(self):
- return self
-
- def __exit__(self, type, value, traceback):
- self._trans.__exit__(type, value, traceback)
-
-
+
class TLEngine(base.Engine):
- """An Engine that includes support for thread-local managed transactions.
+ """An Engine that includes support for thread-local managed transactions."""
- The TLEngine relies upon its Pool having "threadlocal" behavior,
- so that once a connection is checked out for the current thread,
- you get that same connection repeatedly.
- """
def __init__(self, *args, **kwargs):
- """Construct a new TLEngine."""
-
super(TLEngine, self).__init__(*args, **kwargs)
- self.context = util.threading.local()
-
+ self._connections = util.threading.local()
proxy = kwargs.get('proxy')
if proxy:
self.TLConnection = base._proxy_connection_cls(TLConnection, proxy)
else:
self.TLConnection = TLConnection
- def session(self):
- "Returns the current thread's TLSession"
- if not hasattr(self.context, 'session'):
- self.context.session = TLSession(self)
- return self.context.session
-
- session = property(session)
-
- def contextual_connect(self, **kwargs):
- """Return a TLConnection which is thread-locally scoped."""
-
- return self.session.get_connection(**kwargs)
-
- def begin_twophase(self, **kwargs):
- return self.session.begin_twophase(**kwargs)
+ def contextual_connect(self, **kw):
+ if not hasattr(self._connections, 'conn'):
+ connection = None
+ else:
+ connection = self._connections.conn()
+
+ if connection is None or connection.closed:
+ # guards against pool-level reapers, if desired.
+ # or not connection.connection.is_valid:
+ connection = self.TLConnection(self, self.pool.connect(), **kw)
+ self._connections.conn = conn = weakref.ref(connection)
+
+ return connection._increment_connect()
+
+ def begin_twophase(self, xid=None):
+ if not hasattr(self._connections, 'trans'):
+ self._connections.trans = []
+ self._connections.trans.append(self.contextual_connect().begin_twophase(xid=xid))
def begin_nested(self):
- raise NotImplementedError("SAVEPOINT transactions with the 'threadlocal' strategy")
+ if not hasattr(self._connections, 'trans'):
+ self._connections.trans = []
+ self._connections.trans.append(self.contextual_connect().begin_nested())
+
+ def begin(self):
+ if not hasattr(self._connections, 'trans'):
+ self._connections.trans = []
+ self._connections.trans.append(self.contextual_connect().begin())
- def begin(self, **kwargs):
- return self.session.begin(**kwargs)
-
def prepare(self):
- self.session.prepare()
+ self._connections.trans[-1].prepare()
def commit(self):
- self.session.commit()
-
+ trans = self._connections.trans.pop(-1)
+ trans.commit()
+
def rollback(self):
- self.session.rollback()
-
+ trans = self._connections.trans.pop(-1)
+ trans.rollback()
+
+ def dispose(self):
+ self._connections = util.threading.local()
+ super(TLEngine, self).dispose()
+
+ @property
+ def closed(self):
+ return not hasattr(self._connections, 'conn') or \
+ self._connections.conn() is None or \
+ self._connections.conn().closed
+
+ def close(self):
+ if not self.closed:
+ self.contextual_connect().close()
+ del self._connections.conn
+ self._connections.trans = []
+
def __repr__(self):
return 'TLEngine(%s)' % str(self.url)
def teardown_class(cls):
users.drop(tlengine)
tlengine.dispose()
-
- def test_nested_unsupported(self):
- assert_raises(NotImplementedError, tlengine.contextual_connect().begin_nested)
- assert_raises(NotImplementedError, tlengine.begin_nested)
+
+ def setup(self):
+ # ensure tests start with engine closed
+ tlengine.close()
def test_connection_close(self):
- """test that when connections are closed for real, transactions are rolled back and disposed."""
+ """test that when connections are closed for real,
+ transactions are rolled back and disposed."""
c = tlengine.contextual_connect()
c.begin()
- assert tlengine.session.in_transaction()
- assert hasattr(tlengine.session, '_TLSession__transaction')
- assert hasattr(tlengine.session, '_TLSession__trans')
+ assert c.in_transaction()
c.close()
- assert not tlengine.session.in_transaction()
- assert not hasattr(tlengine.session, '_TLSession__transaction')
- assert not hasattr(tlengine.session, '_TLSession__trans')
+ assert not c.in_transaction()
def test_transaction_close(self):
c = tlengine.contextual_connect()
conn.close()
external_connection.close()
- def test_nesting(self):
- """tests nesting of transactions"""
+ def test_nesting_rollback(self):
+ """tests nesting of transactions, rollback at the end"""
+
external_connection = tlengine.connect()
self.assert_(external_connection.connection is not tlengine.contextual_connect().connection)
tlengine.begin()
finally:
external_connection.close()
+ def test_nesting_commit(self):
+ """tests nesting of transactions, commit at the end."""
+
+ external_connection = tlengine.connect()
+ self.assert_(external_connection.connection is not tlengine.contextual_connect().connection)
+ tlengine.begin()
+ tlengine.execute(users.insert(), user_id=1, user_name='user1')
+ tlengine.execute(users.insert(), user_id=2, user_name='user2')
+ tlengine.execute(users.insert(), user_id=3, user_name='user3')
+ tlengine.begin()
+ tlengine.execute(users.insert(), user_id=4, user_name='user4')
+ tlengine.execute(users.insert(), user_id=5, user_name='user5')
+ tlengine.commit()
+ tlengine.commit()
+ try:
+ self.assert_(external_connection.scalar("select count(1) from query_users") == 5)
+ finally:
+ external_connection.close()
+
def test_mixed_nesting(self):
"""tests nesting of transactions off the TLEngine directly inside of
tranasctions off the connection from the TLEngine"""
finally:
external_connection.close()
+ @testing.requires.savepoints
+ def test_nested_subtransaction_rollback(self):
+
+ tlengine.begin()
+ tlengine.execute(users.insert(), user_id=1, user_name='user1')
+ tlengine.begin_nested()
+ tlengine.execute(users.insert(), user_id=2, user_name='user2')
+ tlengine.rollback()
+ tlengine.execute(users.insert(), user_id=3, user_name='user3')
+ tlengine.commit()
+ tlengine.close()
+
+ eq_(
+ tlengine.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
+ [(1,),(3,)]
+ )
+ tlengine.close()
+
+ @testing.requires.savepoints
+ @testing.crashes('oracle+zxjdbc', 'Errors out and causes subsequent tests to deadlock')
+ def test_nested_subtransaction_commit(self):
+ tlengine.begin()
+ tlengine.execute(users.insert(), user_id=1, user_name='user1')
+ tlengine.begin_nested()
+ tlengine.execute(users.insert(), user_id=2, user_name='user2')
+ tlengine.commit()
+ tlengine.execute(users.insert(), user_id=3, user_name='user3')
+ tlengine.commit()
+
+ tlengine.close()
+ eq_(
+ tlengine.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
+ [(1,),(2,),(3,)]
+ )
+ tlengine.close()
+
+ @testing.requires.savepoints
+ def test_rollback_to_subtransaction(self):
+ tlengine.begin()
+ tlengine.execute(users.insert(), user_id=1, user_name='user1')
+ tlengine.begin_nested()
+ tlengine.execute(users.insert(), user_id=2, user_name='user2')
+ tlengine.begin()
+ tlengine.execute(users.insert(), user_id=3, user_name='user3')
+ tlengine.rollback()
+ tlengine.execute(users.insert(), user_id=4, user_name='user4')
+ tlengine.commit()
+ tlengine.close()
+
+ eq_(
+ tlengine.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
+ [(1,),(4,)]
+ )
+ tlengine.close()
+
def test_connections(self):
"""tests that contextual_connect is threadlocal"""
c1 = tlengine.contextual_connect()
c2 = tlengine.contextual_connect()
assert c1.connection is c2.connection
c2.close()
- assert c1.connection.connection is not None
-
+ assert not c1.closed
+ assert not tlengine.closed
+
+ def test_result_closing(self):
+ """tests that contextual_connect is threadlocal"""
+
+ r1 = tlengine.execute("select 1")
+ r2 = tlengine.execute("select 1")
+ row1 = r1.fetchone()
+ row2 = r2.fetchone()
+ r1.close()
+ assert r2.connection is r1.connection
+ assert not r2.connection.closed
+ assert not tlengine.closed
+
+ # close again, nothing happens
+ # since resultproxy calls close() only
+ # once
+ r1.close()
+ assert r2.connection is r1.connection
+ assert not r2.connection.closed
+ assert not tlengine.closed
+
+ r2.close()
+ assert r2.connection.closed
+ assert tlengine.closed
+
+ def test_dispose(self):
+ eng = create_engine(testing.db.url, strategy='threadlocal')
+ result = eng.execute("select 1")
+ eng.dispose()
+ eng.execute("select 1")
+
+
@testing.requires.two_phase_transactions
def test_two_phase_transaction(self):
tlengine.begin_twophase()