]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
catching up oracle to current, some tweaks to unittests to work better with oracle,
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 30 Dec 2005 00:23:01 +0000 (00:23 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 30 Dec 2005 00:23:01 +0000 (00:23 +0000)
allow different ordering of expected statements.
unittests still dont work completely with oracle due to sequence columns in INSERT statements

lib/sqlalchemy/databases/oracle.py
test/objectstore.py
test/testbase.py

index ae1445759d8b6b07d5a5bc21c8877bcdb6b71022..04d59a58282c878fb10154422bdd34d3c6b69a28 100644 (file)
@@ -110,7 +110,7 @@ class OracleSQLEngine(ansisql.ANSISQLEngine):
     def schemadropper(self, proxy, **params):
         return OracleSchemaDropper(proxy, **params)
     def defaultrunner(self, proxy):
-        return OracleDefaultRunner(proxy)
+        return OracleDefaultRunner(self, proxy)
         
     def reflecttable(self, table):
         raise "not implemented"
@@ -143,10 +143,10 @@ class OracleCompiler(ansisql.ANSICompiler):
     """oracle compiler modifies the lexical structure of Select statements to work under 
     non-ANSI configured Oracle databases, if the use_ansi flag is False."""
     
-    def __init__(self, engine, statement, bindparams, use_ansi = True, **kwargs):
+    def __init__(self, engine, statement, parameters, use_ansi = True, **kwargs):
         self._outertable = None
         self._use_ansi = use_ansi
-        ansisql.ANSICompiler.__init__(self, engine, statement, bindparams, **kwargs)
+        ansisql.ANSICompiler.__init__(self, engine, statement, parameters, **kwargs)
         
     def visit_join(self, join):
         if self._use_ansi:
@@ -165,7 +165,12 @@ class OracleCompiler(ansisql.ANSICompiler):
             join.onclause.accept_visitor(self)
 
             self._outertable = outertable
-        
+       
+    def visit_alias(self, alias):
+       """oracle doesnt like 'FROM table AS alias'.  is the AS standard SQL??"""
+        self.froms[alias] = self.get_from_text(alias.selectable) + " " + alias.name
+        self.strings[alias] = self.get_str(alias.selectable)
     def visit_column(self, column):
         if self._use_ansi:
             return ansisql.ANSICompiler.visit_column(self, column)
@@ -181,8 +186,8 @@ class OracleCompiler(ansisql.ANSICompiler):
          with autoincrement fields that require they not be present.  so, 
          put them all in for all primary key columns."""
         for c in insert.table.primary_key:
-            if not self.bindparams.has_key(c.key):
-                self.bindparams[c.key] = None
+            if not self.parameters.has_key(c.key):
+                self.parameters[c.key] = None
         return ansisql.ANSICompiler.visit_insert(self, insert)
 
     def visit_select(self, select):
@@ -235,4 +240,4 @@ class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
         return self.proxy(str(c), c.get_params()).fetchone()[0]
     
     def visit_sequence(self, seq):
-        return self.exec_default_sql(seq.name + ".nextval")
+        return self.proxy("SELECT " + seq.name + ".nextval FROM DUAL").fetchone()[0]
index d14d8bfe0801004862f073df2f0486f820ae25ee..a22539cc6fbb82bdab6f44bff879a2d9198c4450 100644 (file)
@@ -577,20 +577,19 @@ class SaveTest(AssertMixin):
         l = m.select(items.c.item_name.in_(*[e['item_name'] for e in data[1:]]), order_by=[items.c.item_name, keywords.c.name])
         self.assert_result(l, *data)
 
-        print "\n\n\n"
+        print "\n\n\nTESTTESTTEST"
         objects[4].item_name = 'item4updated'
         k = Keyword()
         k.name = 'yellow'
         objects[5].keywords.append(k)
         self.assert_sql(db, lambda:objectstore.commit(), [
-            (
-                "UPDATE items SET item_name=:item_name WHERE items.item_id = :items_item_id",
+            {
+                "UPDATE items SET item_name=:item_name WHERE items.item_id = :items_item_id":
                 [{'item_name': 'item4updated', 'items_item_id': objects[4].item_id}]
-            ),
-            (
-                "INSERT INTO keywords (name) VALUES (:name)", 
+            ,
+                "INSERT INTO keywords (name) VALUES (:name)":
                 {'name': 'yellow'}
-            ),
+            },
             ("INSERT INTO itemkeywords (item_id, keyword_id) VALUES (:item_id, :keyword_id)",
             lambda: [{'item_id': objects[5].item_id, 'keyword_id': k.keyword_id}]
             )
index d08b585277f4adaa96093ecb22511685ed7803a7..a5785f29922f5392fde2be47d7a57176ce0b1208 100644 (file)
@@ -33,7 +33,7 @@ def parse_argv():
     elif DBTYPE == 'mysql':
         db = engine.create_engine('mysql://db=test&host=127.0.0.1&user=scott&passwd=tiger', echo=echo)
     elif DBTYPE == 'oracle':
-        db = engine.create_engine('oracle://db=test&host=127.0.0.1&user=scott&passwd=tiger', echo=echo)
+        db = engine.create_engine('oracle://user=scott&password=tiger', echo=echo)
     db = EngineAssert(db)
 
 class PersistTest(unittest.TestCase):
@@ -108,35 +108,57 @@ class EngineAssert(object):
         statement = re.sub(r'\n', '', statement)
 
         if self.assert_list is not None:
-            item = self.assert_list.pop()
+            item = self.assert_list[-1]
+            if not isinstance(item, dict):
+                item = self.assert_list.pop()
+            else:
+                # asserting a dictionary of statements->parameters
+                # this is to specify query assertions where the queries can be in 
+                # multiple orderings
+                if not item.has_key('_converted'):
+                    for key in item.keys():
+                        ckey = self.convert_statement(key)
+                        item[ckey] = item[key]
+                       if ckey != key:
+                            del item[key]
+                    item['_converted'] = True
+                try:
+                    entry = item.pop(statement)
+                    if len(item) == 1:
+                        self.assert_list.pop()
+                    item = (statement, entry)
+                    print "OK ON", statement
+                except KeyError:
+                    self.unittest.assert_(False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement))
+
             (query, params) = item
             if callable(params):
                 params = params()
 
-            # deal with paramstyles of different engines
-            paramstyle = self.engine.paramstyle
-            if paramstyle == 'named':
-                pass
-            elif paramstyle =='pyformat':
-                query = re.sub(r':([\w_]+)', r"%(\1)s", query)
-            else:
-                # positional params
-                names = []
-                repl = None
-                if paramstyle=='qmark':
-                    repl = "?"
-                elif paramstyle=='format':
-                    repl = r"%s"
-                elif paramstyle=='numeric':
-                    repl = None
-                counter = 0
-                query = re.sub(r':([\w_]+)', repl, query)
+            query = self.convert_statement(query)
 
             self.unittest.assert_(statement == query and params == parameters, "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
         self.sql_count += 1
         return self.realexec(proxy, compiled, parameters, **kwargs)
 
-
+    def convert_statement(self, query):
+        paramstyle = self.engine.paramstyle
+        if paramstyle == 'named':
+            pass
+        elif paramstyle =='pyformat':
+            query = re.sub(r':([\w_]+)', r"%(\1)s", query)
+        else:
+            # positional params
+            repl = None
+            if paramstyle=='qmark':
+                repl = "?"
+            elif paramstyle=='format':
+                repl = r"%s"
+            elif paramstyle=='numeric':
+                repl = None
+            query = re.sub(r':([\w_]+)', repl, query)
+        return query
+        
 class TTestSuite(unittest.TestSuite):
     """override unittest.TestSuite to provide per-TestCase class setUpAll() and tearDownAll() functionality"""
     def __init__(self, tests=()):