]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
repaired oracle savepoint implementation
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 11 Aug 2007 00:03:26 +0000 (00:03 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 11 Aug 2007 00:03:26 +0000 (00:03 +0000)
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/oracle.py
test/engine/transaction.py

index 430027ed8d2ac60b26e397db7dd6b00c2a402b16..0efaf86575d39338ef8e3dfde25fe75e100c403d 100644 (file)
@@ -745,13 +745,13 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor):
         return text
         
     def visit_savepoint(self, savepoint_stmt):
-        return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
+        return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
 
     def visit_rollback_to_savepoint(self, savepoint_stmt):
-        return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
+        return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
     
     def visit_release_savepoint(self, savepoint_stmt):
-        return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
+        return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
     
     def __str__(self):
         return self.string
@@ -1052,8 +1052,8 @@ class ANSIIdentifierPreparer(object):
     def format_alias(self, alias, name=None):
         return self.__generic_obj_format(alias, name or alias.name)
 
-    def format_savepoint(self, savepoint):
-        return self.__generic_obj_format(savepoint, savepoint)
+    def format_savepoint(self, savepoint, name=None):
+        return self.__generic_obj_format(savepoint, name or savepoint.ident)
 
     def format_constraint(self, constraint):
         return self.__generic_obj_format(constraint, constraint.name)
index 9b8bb2f9e585232ef80a7112c7dfa9c1524f6263..7bbc63fba7ab2f02d97d9d88d88dc7db6f3bfd7e 100644 (file)
@@ -280,12 +280,19 @@ class OracleDialect(ansisql.ANSIDialect):
         else:
             return "rowid"
 
+    def do_release_savepoint(self, connection, name):
+        # Oracle does not support RELEASE SAVEPOINT
+        pass
+
     def create_execution_context(self, *args, **kwargs):
         return OracleExecutionContext(self, *args, **kwargs)
 
     def compiler(self, statement, bindparams, **kwargs):
         return OracleCompiler(self, statement, bindparams, **kwargs)
 
+    def preparer(self):
+        return OracleIdentifierPreparer(self)
+
     def schemagenerator(self, *args, **kwargs):
         return OracleSchemaGenerator(self, *args, **kwargs)
 
@@ -662,4 +669,10 @@ class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
     def visit_sequence(self, seq):
         return self.connection.execute("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL").scalar()
 
+class OracleIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+    def format_savepoint(self, savepoint):
+        name = re.sub(r'^_+', '', savepoint.ident)
+        return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
+
+    
 dialect = OracleDialect
index 3c84684daeb87e3b43e8532a30226c7b247c87d4..8c516b19525addac2712f1d14f27ae267985c125 100644 (file)
@@ -173,7 +173,7 @@ class TransactionTest(PersistTest):
         )
         connection.close()
     
-    @testing.supported('postgres', 'mysql')
+    @testing.supported('postgres', 'mysql', 'oracle')
     @testing.exclude('mysql', '<', (5, 0, 3))
     def testtwophasetransaction(self):
         connection = testbase.db.connect()
@@ -301,7 +301,7 @@ class TLTransactionTest(PersistTest):
         tlengine = create_engine(testbase.db.url, strategy='threadlocal')
         metadata = MetaData()
         users = Table('query_users', metadata,
-            Column('user_id', INT, primary_key = True),
+            Column('user_id', INT, Sequence('query_users_id_seq', optional=True), primary_key=True),
             Column('user_name', VARCHAR(20)),
             test_needs_acid=True,
         )