state. An actual commit() may or may not have occured, depending on whether or not
the flush started its own transaction or participated in a larger transaction.
"""
+
+ def after_begin(self, session, transaction, connection):
+ """Execute after a transaction is begun on a connection
+
+ `transaction` is the SessionTransaction. This method is called after an
+ engine level transaction is begun on a connection.
+ """
class SessionTransaction(object):
"""Represents a Session-level Transaction.
transaction = conn.begin()
self._connections[conn] = self._connections[conn.engine] = (conn, transaction, conn is not bind)
+ if self.session.extension is not None:
+ self.session.extension.after_begin(self.session, self, conn)
return conn
def prepare(self):
for subtransaction in self.session.transaction._iterate_parents(upto=self):
subtransaction.close()
- if self.is_active:
+ if self.is_active or self._prepared:
for transaction in self._iterate_parents():
if transaction._parent is None or transaction.nested:
transaction._rollback_impl()
break
else:
transaction._deactivate()
+
self.close()
return self._parent
log.append('after_flush')
def after_flush_postexec(self, session, flush_context):
log.append('after_flush_postexec')
+ def after_begin(self, session, transaction, connection):
+ log.append('after_begin')
sess = create_session(extension = MyExt())
u = User()
sess.save(u)
sess.flush()
-
- assert log == ['before_flush', 'after_flush', 'before_commit', 'after_commit', 'after_flush_postexec']
+ assert log == ['before_flush', 'after_begin', 'after_flush', 'before_commit', 'after_commit', 'after_flush_postexec']
log = []
sess = create_session(transactional=True, extension=MyExt())
u = User()
sess.save(u)
sess.flush()
- assert log == ['before_flush', 'after_flush', 'after_flush_postexec']
+ assert log == ['before_flush', 'after_begin', 'after_flush', 'after_flush_postexec']
log = []
u.user_name = 'ed'
log = []
sess.commit()
assert log == ['before_commit', 'after_commit']
+
+ log = []
+ sess = create_session(transactional=True, extension=MyExt(), bind=testing.db)
+ conn = sess.connection()
+ assert log == ['after_begin']
def test_pickled_update(self):
mapper(User, users)