]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
refactor to engine to have a separate SQLSession object. allows nested transactions.
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 17 Mar 2006 02:15:09 +0000 (02:15 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 17 Mar 2006 02:15:09 +0000 (02:15 +0000)
util.ThreadLocal __hasattr__ method/raise_error param meaningless, removed
renamed old engines test to reflection

lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/util.py
test/alltests.py
test/reflection.py [moved from test/engines.py with 99% similarity]
test/tables.py

index 18a97e3f788dddab2cee3bc69f7134d4203fd20e..ffa7c1bd450a1416c9a46a97cba1ff65f338944e 100644 (file)
@@ -151,6 +151,9 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine):
 
     def dbapi(self):
         return sqlite
+        
+    def push_session(self):
+        raise InvalidRequestError("SQLite doesn't support nested sessions")
 
     def schemagenerator(self, **params):
         return SQLiteSchemaGenerator(self, **params)
index 670a742b1881901742f67b23e997168306acb830..aa0449628d258e4ef235e95d526e20b6a6dfaade 100644 (file)
@@ -171,8 +171,38 @@ class DefaultRunner(schema.SchemaVisitor):
             return default.arg
             
 class SQLSession(object):
-    pass
+    """represents a particular connection retrieved from a SQLEngine, and associated transactional state."""
+    def __init__(self, engine, parent=None):
+        self.engine = engine
+        self.parent = parent
+        self.__tcount = 0
+    def _connection(self):
+        try:
+            return self.__connection
+        except AttributeError:
+            self.__connection = self.engine._pool.connect()
+            return self.__connection
+    connection = property(_connection, doc="the connection represented by this SQLSession.  The connection is late-connecting, meaning the call to the connection pool only occurs when it is first called (and the pool will typically only connect the first time it is called as well)")
     
+    def begin(self):
+        """begins" a transaction on this SQLSession's connection.  repeated calls to begin() will increment a counter that must be decreased by corresponding commit() statements before an actual commit occurs.  this is to provide "nested" behavior of transactions so that different functions in a particular call stack can call begin()/commit() independently of each other without knowledge of an existing transaction."""
+        if self.__tcount == 0:
+            self.engine.do_begin(self.connection)
+        self.__tcount += 1
+    def rollback(self):
+        """rolls back the transaction on this SQLSession's connection.  this can be called regardless of the "begin" counter value, i.e. can be called from anywhere inside a callstack.  the "begin" counter is cleared."""
+        if self.__tcount > 0:
+            self.engine.do_rollback(self.connection)
+            self.__tcount = 0
+    def commit(self):
+        """commits the transaction started by begin().  If begin() was called multiple times, a counter will be decreased for each call to commit(), with the actual commit operation occuring when the counter reaches zero.  this is to provide "nested" behavior of transactions so that different functions in a particular call stack can call begin()/commit() independently of each other without knowledge of an existing transaction."""
+        if self.__tcount == 1:
+            self.engine.do_commit(self.connection)
+        elif self.__tcount > 1:
+            self.__tcount -= 1
+    def is_begun(self):
+        return self.__tcount > 0
+                
 class SQLEngine(schema.SchemaEngine):
     """
     The central "database" object used by an application.  Subclasses of this object is used
@@ -192,6 +222,7 @@ class SQLEngine(schema.SchemaEngine):
         (cargs, cparams) = self.connect_args()
         if pool is None:
             params['echo'] = echo_pool
+            params['use_threadlocal'] = False
             self._pool = sqlalchemy.pool.manage(self.dbapi(), **params).get_pool(*cargs, **cparams)
         else:
             self._pool = pool
@@ -200,7 +231,7 @@ class SQLEngine(schema.SchemaEngine):
         self.echo_uow = echo_uow
         self.convert_unicode = convert_unicode
         self.encoding = encoding
-        self.context = util.ThreadLocal(raiseerror=False)
+        self.context = util.ThreadLocal()
         self._ischema = None
         self._figure_paramstyle()
         self.logger = logger or util.Logger(origin='engine')
@@ -390,9 +421,28 @@ class SQLEngine(schema.SchemaEngine):
         """implementations might want to put logic here for turning autocommit on/off, etc."""
         connection.commit()
 
+    def _session(self):
+        if not hasattr(self.context, 'session'):
+            self.context.session = SQLSession(self)
+        return self.context.session
+    session = property(_session, doc="returns the current thread's SQLSession")
+    
+    def push_session(self):
+        """pushes a new SQLSession onto this engine, temporarily replacing the previous one for the current thread.  The previous session can be restored by calling pop_session().  this allows the usage of a new connection and possibly transaction within a particular block, superceding the existing one, including any transactions that are in progress.  Returns the new SQLSession object."""
+        sess = SQLSession(self, self.context.session)
+        self.context.session = sess
+        return sess
+    def pop_session(self):
+        """restores the current thread's SQLSession to that before the last push_session.  Returns the restored SQLSession object.  Raises an exception if there is no SQLSession pushed onto the stack."""
+        sess = self.context.session.parent
+        if sess is None:
+            raise InvalidRequestError("No SQLSession is pushed onto the stack.")
+        self.context.session = sess
+        return sess
+        
     def connection(self):
         """returns a managed DBAPI connection from this SQLEngine's connection pool."""
-        return self._pool.connect()
+        return self.session.connection
 
     def unique_connection(self):
         """returns a DBAPI connection from this SQLEngine's connection pool that is distinct from the current thread's connection."""
@@ -434,40 +484,15 @@ class SQLEngine(schema.SchemaEngine):
         self.commit()
         
     def begin(self):
-        """"begins" a transaction on a pooled connection, and stores the connection in a thread-local
-        context.  repeated calls to begin() within the same thread will increment a counter that must be
-        decreased by corresponding commit() statements before an actual commit occurs.  this is to provide
-        "nested" behavior of transactions so that different functions can all call begin()/commit() and still
-        call each other."""
-        if getattr(self.context, 'transaction', None) is None:
-            conn = self.connection()
-            self.do_begin(conn)
-            self.context.transaction = conn
-            self.context.tcount = 1
-        else:
-            self.context.tcount += 1
+        """"begins a transaction on the current thread's SQLSession."""
+        self.session.begin()
             
     def rollback(self):
-        """rolls back the current thread-local transaction started by begin().  the "begin" counter 
-        is cleared and the transaction ended."""
-        if self.context.transaction is not None:
-            self.do_rollback(self.context.transaction)
-            self.context.transaction = None
-            self.context.tcount = None
+        """rolls back the transaction on the current thread's SQLSession."""
+        self.session.rollback()
             
     def commit(self):
-        """commits the current thread-local transaction started by begin().  If begin() was called multiple
-        times, a counter will be decreased for each call to commit(), with the actual commit operation occuring
-        when the counter reaches zero.  this is to provide
-        "nested" behavior of transactions so that different functions can all call begin()/commit() and still
-        call each other."""
-        if self.context.transaction is not None:
-            count = self.context.tcount - 1
-            self.context.tcount = count
-            if count == 0:
-                self.do_commit(self.context.transaction)
-                self.context.transaction = None
-                self.context.tcount = None
+        self.session.commit()
 
     def _process_defaults(self, proxy, compiled, parameters, **kwargs):
         """INSERT and UPDATE statements, when compiled, may have additional columns added to their
@@ -642,7 +667,7 @@ class SQLEngine(schema.SchemaEngine):
                 self._executemany(cursor, statement, parameters)
             else:
                 self._execute(cursor, statement, parameters)
-            if self.context.transaction is None:
+            if not self.session.is_begun():
                 self.do_commit(connection)
         except:
             self.do_rollback(connection)
index 303cf56830f9201c60f2cab55e024766d821e2a8..98d3a70f146c5db333d5530a5906dd919bd8e950 100644 (file)
@@ -214,11 +214,8 @@ class OrderedDict(dict):
 
 class ThreadLocal(object):
     """an object in which attribute access occurs only within the context of the current thread"""
-    def __init__(self, raiseerror = True):
+    def __init__(self):
         self.__dict__['_tdict'] = {}
-        self.__dict__['_raiseerror'] = raiseerror
-    def __hasattr__(self, key):
-        return self._tdict.has_key("%d_%s" % (thread.get_ident(), key))
     def __delattr__(self, key):
         try:
             del self._tdict["%d_%s" % (thread.get_ident(), key)]
@@ -228,10 +225,7 @@ class ThreadLocal(object):
         try:
             return self._tdict["%d_%s" % (thread.get_ident(), key)]
         except KeyError:
-            if self._raiseerror:
-                raise AttributeError(key)
-            else:
-                return None
+            raise AttributeError(key)
     def __setattr__(self, key, value):
         self._tdict["%d_%s" % (thread.get_ident(), key)] = value
 
index 3199b89f9193ff9a4cfe2b1495af09beacdade0b..d30f97287cd6697fa675e586a1749e555b560e33 100644 (file)
@@ -12,11 +12,12 @@ def suite():
         'attributes', 
         'dependency',
         
-        # connectivity
+        # connectivity, execution
         'pool', 
+        'engine',
         
         # schema/tables
-        'engines', 
+        'reflection', 
         'testtypes',
         'indexes',
 
similarity index 99%
rename from test/engines.py
rename to test/reflection.py
index 3bb20c7251cd8b5c9f46db798003919762c8e6e2..718957addb6a23161fc8b4274a46658915321ab6 100644 (file)
@@ -10,7 +10,7 @@ from testbase import PersistTest
 import testbase
 import unittest, re
 
-class EngineTest(PersistTest):
+class ReflectionTest(PersistTest):
     def testbasic(self):
         # really trip it up with a circular reference
         
index 927903e48681aa8f6b1b002e30d719223ad02722..f1e1a845bcf1f7eece92a7e90957de63393c4ea1 100644 (file)
@@ -14,7 +14,7 @@ db = testbase.db
 users = Table('users', db,
     Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
     Column('user_name', String(40)),
-    
+    mysql_engine='innodb'
 )
 
 addresses = Table('email_addresses', db,