]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- factor consistent set_isolation_level(), get_isolation_level()
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 16 Jan 2011 18:05:18 +0000 (13:05 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 16 Jan 2011 18:05:18 +0000 (13:05 -0500)
per-connection methods for sqlite, postgresql, psycopg2 dialects
- move isolation test suite to test engines/test_transaction
- preparing for [ticket:2001]

lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/dialects/sqlite/base.py
test/dialect/test_postgresql.py
test/dialect/test_sqlite.py
test/engine/test_transaction.py
test/lib/requires.py

index a8fb4e51a7bf2b83d90fb9dc6b82c491c909a37b..9097c3a6eabbb8315b68676299e1f29ea62e0a27 100644 (file)
@@ -769,16 +769,36 @@ class PGDialect(default.DefaultDialect):
     def on_connect(self):
         if self.isolation_level is not None:
             def connect(conn):
-                cursor = conn.cursor()
-                cursor.execute(
-                    "SET SESSION CHARACTERISTICS AS TRANSACTION "
-                    "ISOLATION LEVEL %s" % self.isolation_level)
-                cursor.execute("COMMIT")
-                cursor.close()
+                self.set_isolation_level(conn, self.isolation_level)
             return connect
         else:
             return None
 
+    _isolation_lookup = set(['SERIALIZABLE', 
+                'READ UNCOMMITTED', 'READ COMMITTED', 'REPEATABLE READ'])
+
+    def set_isolation_level(self, connection, level):
+        level = level.replace('_', ' ')
+        if level not in self._isolation_lookup:
+            raise exc.ArgumentError(
+                "Invalid value '%s' for isolation_level. "
+                "Valid isolation levels for %s are %s" % 
+                (self.name, level, ", ".join(self._isolation_lookup))
+                ) 
+        cursor = connection.cursor()
+        cursor.execute(
+            "SET SESSION CHARACTERISTICS AS TRANSACTION "
+            "ISOLATION LEVEL %s" % level)
+        cursor.execute("COMMIT")
+        cursor.close()
+
+    def get_isolation_level(self, connection):
+        cursor = connection.cursor()
+        cursor.execute('show transaction isolation level')
+        val = cursor.fetchone()[0]
+        cursor.close()
+        return val.upper()
+
     def do_begin_twophase(self, connection, xid):
         self.do_begin(connection.connection)
 
index 411bd42bd3a18fe927ae0d8d5345ef0146fe4542..806ba41f8b62de45932532715bdbbf3b3e74e45b 100644 (file)
@@ -243,23 +243,32 @@ class PGDialect_psycopg2(PGDialect):
         psycopg = __import__('psycopg2')
         return psycopg
 
-    def on_connect(self):
-        if self.isolation_level is not None:
-            extensions = __import__('psycopg2.extensions').extensions
-            isol = {
+    @util.memoized_property
+    def _isolation_lookup(self):
+        extensions = __import__('psycopg2.extensions').extensions
+        return {
             'READ_COMMITTED':extensions.ISOLATION_LEVEL_READ_COMMITTED, 
             'READ_UNCOMMITTED':extensions.ISOLATION_LEVEL_READ_UNCOMMITTED, 
             'REPEATABLE_READ':extensions.ISOLATION_LEVEL_REPEATABLE_READ,
             'SERIALIZABLE':extensions.ISOLATION_LEVEL_SERIALIZABLE
+        }
 
-            }
+    def set_isolation_level(self, connection, level):
+        try:
+            level = self._isolation_lookup[level.replace(' ', '_')]
+        except KeyError:
+            raise exc.ArgumentError(
+                "Invalid value '%s' for isolation_level. "
+                "Valid isolation levels for %s are %s" % 
+                (self.name, level, ", ".join(self._isolation_lookup))
+                ) 
+
+        connection.set_isolation_level(level)
+
+    def on_connect(self):
+        if self.isolation_level is not None:
             def base_on_connect(conn):
-                try:
-                    conn.set_isolation_level(isol[self.isolation_level])
-                except:
-                    raise exc.InvalidRequestError(
-                                "Invalid isolation level: '%s'" % 
-                                self.isolation_level)
+                self.set_isolation_level(conn, self.isolation_level)
         else:
             base_on_connect = None
 
index f732f1f44856fe3b2f7092dddb46c43f45a15714..9f1f6432582700d74a6c4622db66ddabfad66abd 100644 (file)
@@ -372,11 +372,6 @@ class SQLiteDialect(default.DefaultDialect):
 
     def __init__(self, isolation_level=None, native_datetime=False, **kwargs):
         default.DefaultDialect.__init__(self, **kwargs)
-        if isolation_level and isolation_level not in ('SERIALIZABLE',
-                'READ UNCOMMITTED'):
-            raise exc.ArgumentError("Invalid value for isolation_level. "
-                "Valid isolation levels for sqlite are 'SERIALIZABLE' and "
-                "'READ UNCOMMITTED'.")
         self.isolation_level = isolation_level
 
         # this flag used by pysqlite dialect, and perhaps others in the
@@ -391,18 +386,39 @@ class SQLiteDialect(default.DefaultDialect):
             self.supports_cast = \
                                 self.dbapi.sqlite_version_info >= (3, 2, 3)
 
+    _isolation_lookup = {
+        'READ UNCOMMITTED':1,
+        'SERIALIZABLE':0
+    }
+    def set_isolation_level(self, connection, level):
+        try:
+            isolation_level = self._isolation_lookup[level.replace('_', ' ')]
+        except KeyError:
+            raise exc.ArgumentError(
+                "Invalid value '%s' for isolation_level. "
+                "Valid isolation levels for %s are %s" % 
+                (self.name, level, ", ".join(self._isolation_lookup))
+                ) 
+        cursor = connection.cursor()
+        cursor.execute("PRAGMA read_uncommitted = %d" % isolation_level)
+        cursor.close()
+
+    def get_isolation_level(self, connection):
+        cursor = connection.cursor()
+        cursor.execute('PRAGMA read_uncommitted')
+        value = cursor.fetchone()[0]
+        cursor.close()
+        if value == 0:
+            return "SERIALIZABLE"
+        elif value == 1:
+            return "READ UNCOMMITTED"
+        else:
+            assert False, "Unknown isolation level %s" % value
 
     def on_connect(self):
         if self.isolation_level is not None:
-            if self.isolation_level == 'READ UNCOMMITTED':
-                isolation_level = 1
-            else:
-                isolation_level = 0
-
             def connect(conn):
-                cursor = conn.cursor()
-                cursor.execute("PRAGMA read_uncommitted = %d" % isolation_level)
-                cursor.close()
+                self.set_isolation_level(conn, self.isolation_level)
             return connect
         else:
             return None
index 083d32b15c1317369dab66359db6375a35b52145..f8ca65e603ac3e4db9c8c40173e3d7fa5543d901 100644 (file)
@@ -1382,40 +1382,6 @@ class MiscTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
             ])
 
 
-    @testing.fails_on('postgresql+pypostgresql',
-                      'pypostgresql bombs on multiple calls')
-    def test_set_isolation_level(self):
-        """Test setting the isolation level with create_engine"""
-
-        eng = create_engine(testing.db.url)
-        eq_(eng.execute('show transaction isolation level').scalar(),
-            'read committed')
-        eng = create_engine(testing.db.url,
-                            isolation_level='SERIALIZABLE')
-        eq_(eng.execute('show transaction isolation level').scalar(),
-            'serializable')
-
-        # check that it stays
-        conn = eng.connect()
-        eq_(conn.execute('show transaction isolation level').scalar(),
-            'serializable')
-        conn.close()
-
-        conn = eng.connect()
-        eq_(conn.execute('show transaction isolation level').scalar(),
-            'serializable')
-        conn.close()
-
-        eng = create_engine(testing.db.url, isolation_level='FOO')
-        if testing.db.driver == 'zxjdbc':
-            exception_cls = eng.dialect.dbapi.Error
-        elif testing.db.driver == 'psycopg2':
-            exception_cls = exc.InvalidRequestError
-        else:
-            exception_cls = eng.dialect.dbapi.ProgrammingError
-        assert_raises(exception_cls, eng.execute,
-                      'show transaction isolation level')
-
     @testing.fails_on('+zxjdbc', 'psycopg2/pg8000 specific assertion')
     @testing.fails_on('pypostgresql',
                       'psycopg2/pg8000 specific assertion')
index 2413c12e833fe2df638001593bba0d24da5827e6..b99f58bd25bed25a3dc9609a73f2f7e8e3e392a9 100644 (file)
@@ -353,20 +353,6 @@ class DialectTest(TestBase, AssertsExecutionResults):
         finally:
             meta.drop_all()
 
-    def test_set_isolation_level(self):
-        """Test setting the read uncommitted/serializable levels"""
-
-        eng = create_engine(testing.db.url)
-        eq_(eng.execute('PRAGMA read_uncommitted').scalar(), 0)
-        eng = create_engine(testing.db.url,
-                            isolation_level='READ UNCOMMITTED')
-        eq_(eng.execute('PRAGMA read_uncommitted').scalar(), 1)
-        eng = create_engine(testing.db.url,
-                            isolation_level='SERIALIZABLE')
-        eq_(eng.execute('PRAGMA read_uncommitted').scalar(), 0)
-        assert_raises(exc.ArgumentError, create_engine, testing.db.url,
-                      isolation_level='FOO')
-
     def test_create_index_with_schema(self):
         """Test creation of index with explicit schema"""
 
index bec8b0037fb22a19897188550b492a14120b08ba..1fb0267bb5405f1b12820f8b61b302fa92f7b802 100644 (file)
@@ -1,5 +1,5 @@
 from test.lib.testing import eq_, assert_raises, \
-    assert_raises_message
+    assert_raises_message, ne_
 import sys
 import time
 import threading
@@ -1109,3 +1109,66 @@ class ForUpdateTest(TestBase):
                 update_style='nowait')
         self.assert_(len(errors) != 0)
 
+class IsolationLevelTest(TestBase):
+    def _default_isolation_level(self):
+        if testing.against('sqlite'):
+            return 'SERIALIZABLE'
+        elif testing.against('postgresql'):
+            return 'READ COMMITTED'
+        else:
+            assert False, "default isolation level not known"
+
+    def _non_default_isolation_level(self):
+        if testing.against('sqlite'):
+            return 'READ UNCOMMITTED'
+        elif testing.against('postgresql'):
+            return 'SERIALIZABLE'
+        else:
+            assert False, "non default isolation level not known"
+
+    @testing.requires.isolation_level
+    def test_engine_param_stays(self):
+
+        eng = create_engine(testing.db.url)
+        isolation_level = eng.dialect.get_isolation_level(eng.connect().connection)
+        level = self._non_default_isolation_level()
+
+        ne_(isolation_level, level)
+
+        eng = create_engine(testing.db.url,
+                            isolation_level=level)
+        eq_(
+            eng.dialect.get_isolation_level(eng.connect().connection),
+            level
+        )
+
+        # check that it stays
+        conn = eng.connect()
+        eq_(
+            eng.dialect.get_isolation_level(conn.connection),
+            level
+        )
+        conn.close()
+
+        conn = eng.connect()
+        eq_(
+            eng.dialect.get_isolation_level(conn.connection),
+            level
+        )
+        conn.close()
+
+    @testing.requires.isolation_level
+    def test_default_level(self):
+        eng = create_engine(testing.db.url)
+        isolation_level = eng.dialect.get_isolation_level(eng.connect().connection)
+        eq_(isolation_level, self._default_isolation_level())
+
+    @testing.requires.isolation_level
+    def test_invalid_level(self):
+        eng = create_engine(testing.db.url, isolation_level='FOO')
+        assert_raises_message(
+            exc.ArgumentError, 
+                "Invalid value '%s' for isolation_level. "
+                "Valid isolation levels for %s are %s" % 
+                (eng.dialect.name, "FOO", ", ".join(eng.dialect._isolation_lookup)),
+            eng.connect)
index b689250d27540f45392b0786a376b5aa4e52d7b7..03bda74d29876123743a48f13d23ebc81d402e3e 100644 (file)
@@ -11,6 +11,7 @@ from testing import \
      exclude, \
      emits_warning_on,\
      skip_if,\
+     only_on,\
      fails_on,\
      fails_on_everything_except
 
@@ -92,6 +93,14 @@ def independent_connections(fn):
                 'SQL Server 2005+ is required for independent connections'),
         )
 
+def isolation_level(fn):
+    return _chain_decorators_on(
+        fn,
+        only_on(('postgresql', 'sqlite'), "DBAPI has no isolation level support"),
+        fails_on('postgresql+pypostgresql',
+                      'pypostgresql bombs on multiple isolation level calls')
+    )
+
 def row_triggers(fn):
     """Target must support standard statement-running EACH ROW triggers."""
     return _chain_decorators_on(