]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fixed up paramstyle translation
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 4 Dec 2005 18:45:58 +0000 (18:45 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 4 Dec 2005 18:45:58 +0000 (18:45 +0000)
test/testbase.py

index a42cf31bf7330d239cffae09faf287a23cbf143f..875e0d0b2bcd3ff5bdd3611a1f55ebb37e327dc4 100644 (file)
@@ -2,13 +2,39 @@ import unittest
 import StringIO
 import sqlalchemy.engine as engine
 import re, sys
+import sqlalchemy.databases.sqlite as sqlite
 import sqlalchemy.databases.postgres as postgres
-import sqlalchemy.databases.mysql as mysql
+#import sqlalchemy.databases.mysql as mysql
 
 echo = True
 db = None
 
 
+def parse_argv():
+    # we are using the unittest main runner, so we are just popping out the 
+    # arguments we need instead of using our own getopt type of thing
+    if len(sys.argv) >= 3:
+        if sys.argv[1] == '--db':
+            (param, DBTYPE) = (sys.argv.pop(1), sys.argv.pop(1))
+    else:
+        DBTYPE = 'sqlite'
+
+    global db
+    if DBTYPE == 'sqlite':
+        try:
+            db = engine.create_engine('sqlite://filename=:memory:', echo = echo)
+        except:
+            raise "Could not create sqlite engine.  specify --db <sqlite|sqlite_file|postgres|mysql|oracle> to test runner."
+    elif DBTYPE == 'sqlite_file':
+        db = engine.create_engine('sqlite://filename=querytest.db', echo = echo)
+    elif DBTYPE == 'postgres':
+        db = engine.create_engine('postgres://database=test&host=127.0.0.1&user=scott&password=tiger', echo=echo)
+    elif DBTYPE == 'mysql':
+        db = engine.create_engine('mysql://db=test&host=127.0.0.1&user=scott&passwd=tiger', echo=echo)
+    elif DBTYPE == 'oracle':
+        db = engine.create_engine('oracle://db=test&host=127.0.0.1&user=scott&passwd=tiger', echo=echo)
+    db = EngineAssert(db)
+
 class PersistTest(unittest.TestCase):
     """persist base class, provides default setUpAll, tearDownAll and echo functionality"""
     def __init__(self, *args, **params):
@@ -75,18 +101,41 @@ class EngineAssert(object):
             (query, params) = item
             if callable(params):
                 params = params()
-            
-            # TODO: standardize this to param styles instaed of checking dbengine types
-            if isinstance(self.engine, postgres.PGSQLEngine):
+
+            # deal with paramstyles of different engines
+            if isinstance(self.engine, sqlite.SQLiteSQLEngine):
+                paramstyle = 'named'
+            else:
+                db = self.engine.dbapi()
+                if db is not None:
+                    paramstyle = db.paramstyle
+                else:
+                    paramstyle = 'named'
+            if paramstyle == 'named':
+                pass
+            elif paramstyle =='pyformat':
                 query = re.sub(r':([\w_]+)', r"%(\1)s", query)
-            elif isinstance(self.engine, mysql.MySQLEngine):
+            else:
+                # positional params
                 names = []
+                repl = None
+                if paramstyle=='qmark':
+                    repl = "?"
+                elif paramstyle=='format':
+                    repl = r"%s"
+                elif paramstyle=='numeric':
+                    repl = None
+                counter = 0
                 def append_arg(match):
                     names.append(match.group(1))
-                    return r"%s"
-                    
+                    if repl is None:
+                        counter += 1
+                        return counter
+                    else:
+                        return repl
+                # substitute bind string in query, translate bind param
+                # dict to a list (or a list of dicts to a list of lists)
                 query = re.sub(r':([\w_]+)', append_arg, query)
-                
                 if isinstance(params, list):
                     args = []
                     for p in params:
@@ -99,7 +148,7 @@ class EngineAssert(object):
                     for n in names:
                         args.append(params[n])
                 params = args
-                
+            
             self.unittest.assert_(statement == query and params == parameters, "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
         return self.realexec(statement, parameters, **kwargs)
 
@@ -155,24 +204,7 @@ class TTestSuite(unittest.TestSuite):
 
 unittest.TestLoader.suiteClass = TTestSuite
 
-if len(sys.argv) >= 3:
-    (param, DBTYPE) = (sys.argv.pop(1), sys.argv.pop(1))
-else:
-    (param, DBTYPE) = None, None
-if (param != '--db'):
-    raise "--db <sqlite|postgres|oracle|sqlite_file> param required"
-        
-if DBTYPE == 'sqlite':
-    db = engine.create_engine('sqlite://filename=:memory:', echo = echo)
-elif DBTYPE == 'sqlite_file':
-    db = engine.create_engine('sqlite://filename=querytest.db', echo = echo)
-elif DBTYPE == 'postgres':
-    db = engine.create_engine('postgres://database=test&host=127.0.0.1&user=scott&password=tiger', echo=echo)
-elif DBTYPE == 'mysql':
-    db = engine.create_engine('mysql://db=test&host=127.0.0.1&user=scott&passwd=tiger', echo=echo)
-elif DBTYPE == 'oracle':
-    db = engine.create_engine('oracle://db=test&host=127.0.0.1&user=scott&passwd=tiger', echo=echo)
-db = EngineAssert(db)
+parse_argv()
 
                     
 def runTests(suite):