From ecb5b87c19eeec41a0dc0d96af88454de17c6511 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 17 Mar 2006 02:15:09 +0000 Subject: [PATCH] refactor to engine to have a separate SQLSession object. allows nested transactions. util.ThreadLocal __hasattr__ method/raise_error param meaningless, removed renamed old engines test to reflection --- lib/sqlalchemy/databases/sqlite.py | 3 + lib/sqlalchemy/engine.py | 93 +++++++++++++++++++----------- lib/sqlalchemy/util.py | 10 +--- test/alltests.py | 5 +- test/{engines.py => reflection.py} | 2 +- test/tables.py | 2 +- 6 files changed, 69 insertions(+), 46 deletions(-) rename test/{engines.py => reflection.py} (99%) diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 18a97e3f78..ffa7c1bd45 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -151,6 +151,9 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine): def dbapi(self): return sqlite + + def push_session(self): + raise InvalidRequestError("SQLite doesn't support nested sessions") def schemagenerator(self, **params): return SQLiteSchemaGenerator(self, **params) diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 670a742b18..aa0449628d 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -171,8 +171,38 @@ class DefaultRunner(schema.SchemaVisitor): return default.arg class SQLSession(object): - pass + """represents a particular connection retrieved from a SQLEngine, and associated transactional state.""" + def __init__(self, engine, parent=None): + self.engine = engine + self.parent = parent + self.__tcount = 0 + def _connection(self): + try: + return self.__connection + except AttributeError: + self.__connection = self.engine._pool.connect() + return self.__connection + connection = property(_connection, doc="the connection represented by this SQLSession. The connection is late-connecting, meaning the call to the connection pool only occurs when it is first called (and the pool will typically only connect the first time it is called as well)") + def begin(self): + """begins" a transaction on this SQLSession's connection. repeated calls to begin() will increment a counter that must be decreased by corresponding commit() statements before an actual commit occurs. this is to provide "nested" behavior of transactions so that different functions in a particular call stack can call begin()/commit() independently of each other without knowledge of an existing transaction.""" + if self.__tcount == 0: + self.engine.do_begin(self.connection) + self.__tcount += 1 + def rollback(self): + """rolls back the transaction on this SQLSession's connection. this can be called regardless of the "begin" counter value, i.e. can be called from anywhere inside a callstack. the "begin" counter is cleared.""" + if self.__tcount > 0: + self.engine.do_rollback(self.connection) + self.__tcount = 0 + def commit(self): + """commits the transaction started by begin(). If begin() was called multiple times, a counter will be decreased for each call to commit(), with the actual commit operation occuring when the counter reaches zero. this is to provide "nested" behavior of transactions so that different functions in a particular call stack can call begin()/commit() independently of each other without knowledge of an existing transaction.""" + if self.__tcount == 1: + self.engine.do_commit(self.connection) + elif self.__tcount > 1: + self.__tcount -= 1 + def is_begun(self): + return self.__tcount > 0 + class SQLEngine(schema.SchemaEngine): """ The central "database" object used by an application. Subclasses of this object is used @@ -192,6 +222,7 @@ class SQLEngine(schema.SchemaEngine): (cargs, cparams) = self.connect_args() if pool is None: params['echo'] = echo_pool + params['use_threadlocal'] = False self._pool = sqlalchemy.pool.manage(self.dbapi(), **params).get_pool(*cargs, **cparams) else: self._pool = pool @@ -200,7 +231,7 @@ class SQLEngine(schema.SchemaEngine): self.echo_uow = echo_uow self.convert_unicode = convert_unicode self.encoding = encoding - self.context = util.ThreadLocal(raiseerror=False) + self.context = util.ThreadLocal() self._ischema = None self._figure_paramstyle() self.logger = logger or util.Logger(origin='engine') @@ -390,9 +421,28 @@ class SQLEngine(schema.SchemaEngine): """implementations might want to put logic here for turning autocommit on/off, etc.""" connection.commit() + def _session(self): + if not hasattr(self.context, 'session'): + self.context.session = SQLSession(self) + return self.context.session + session = property(_session, doc="returns the current thread's SQLSession") + + def push_session(self): + """pushes a new SQLSession onto this engine, temporarily replacing the previous one for the current thread. The previous session can be restored by calling pop_session(). this allows the usage of a new connection and possibly transaction within a particular block, superceding the existing one, including any transactions that are in progress. Returns the new SQLSession object.""" + sess = SQLSession(self, self.context.session) + self.context.session = sess + return sess + def pop_session(self): + """restores the current thread's SQLSession to that before the last push_session. Returns the restored SQLSession object. Raises an exception if there is no SQLSession pushed onto the stack.""" + sess = self.context.session.parent + if sess is None: + raise InvalidRequestError("No SQLSession is pushed onto the stack.") + self.context.session = sess + return sess + def connection(self): """returns a managed DBAPI connection from this SQLEngine's connection pool.""" - return self._pool.connect() + return self.session.connection def unique_connection(self): """returns a DBAPI connection from this SQLEngine's connection pool that is distinct from the current thread's connection.""" @@ -434,40 +484,15 @@ class SQLEngine(schema.SchemaEngine): self.commit() def begin(self): - """"begins" a transaction on a pooled connection, and stores the connection in a thread-local - context. repeated calls to begin() within the same thread will increment a counter that must be - decreased by corresponding commit() statements before an actual commit occurs. this is to provide - "nested" behavior of transactions so that different functions can all call begin()/commit() and still - call each other.""" - if getattr(self.context, 'transaction', None) is None: - conn = self.connection() - self.do_begin(conn) - self.context.transaction = conn - self.context.tcount = 1 - else: - self.context.tcount += 1 + """"begins a transaction on the current thread's SQLSession.""" + self.session.begin() def rollback(self): - """rolls back the current thread-local transaction started by begin(). the "begin" counter - is cleared and the transaction ended.""" - if self.context.transaction is not None: - self.do_rollback(self.context.transaction) - self.context.transaction = None - self.context.tcount = None + """rolls back the transaction on the current thread's SQLSession.""" + self.session.rollback() def commit(self): - """commits the current thread-local transaction started by begin(). If begin() was called multiple - times, a counter will be decreased for each call to commit(), with the actual commit operation occuring - when the counter reaches zero. this is to provide - "nested" behavior of transactions so that different functions can all call begin()/commit() and still - call each other.""" - if self.context.transaction is not None: - count = self.context.tcount - 1 - self.context.tcount = count - if count == 0: - self.do_commit(self.context.transaction) - self.context.transaction = None - self.context.tcount = None + self.session.commit() def _process_defaults(self, proxy, compiled, parameters, **kwargs): """INSERT and UPDATE statements, when compiled, may have additional columns added to their @@ -642,7 +667,7 @@ class SQLEngine(schema.SchemaEngine): self._executemany(cursor, statement, parameters) else: self._execute(cursor, statement, parameters) - if self.context.transaction is None: + if not self.session.is_begun(): self.do_commit(connection) except: self.do_rollback(connection) diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 303cf56830..98d3a70f14 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -214,11 +214,8 @@ class OrderedDict(dict): class ThreadLocal(object): """an object in which attribute access occurs only within the context of the current thread""" - def __init__(self, raiseerror = True): + def __init__(self): self.__dict__['_tdict'] = {} - self.__dict__['_raiseerror'] = raiseerror - def __hasattr__(self, key): - return self._tdict.has_key("%d_%s" % (thread.get_ident(), key)) def __delattr__(self, key): try: del self._tdict["%d_%s" % (thread.get_ident(), key)] @@ -228,10 +225,7 @@ class ThreadLocal(object): try: return self._tdict["%d_%s" % (thread.get_ident(), key)] except KeyError: - if self._raiseerror: - raise AttributeError(key) - else: - return None + raise AttributeError(key) def __setattr__(self, key, value): self._tdict["%d_%s" % (thread.get_ident(), key)] = value diff --git a/test/alltests.py b/test/alltests.py index 3199b89f91..d30f97287c 100644 --- a/test/alltests.py +++ b/test/alltests.py @@ -12,11 +12,12 @@ def suite(): 'attributes', 'dependency', - # connectivity + # connectivity, execution 'pool', + 'engine', # schema/tables - 'engines', + 'reflection', 'testtypes', 'indexes', diff --git a/test/engines.py b/test/reflection.py similarity index 99% rename from test/engines.py rename to test/reflection.py index 3bb20c7251..718957addb 100644 --- a/test/engines.py +++ b/test/reflection.py @@ -10,7 +10,7 @@ from testbase import PersistTest import testbase import unittest, re -class EngineTest(PersistTest): +class ReflectionTest(PersistTest): def testbasic(self): # really trip it up with a circular reference diff --git a/test/tables.py b/test/tables.py index 927903e486..f1e1a845bc 100644 --- a/test/tables.py +++ b/test/tables.py @@ -14,7 +14,7 @@ db = testbase.db users = Table('users', db, Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True), Column('user_name', String(40)), - + mysql_engine='innodb' ) addresses = Table('email_addresses', db, -- 2.47.2