From 4372eafaabd71e0e3935bf4b3931f1a75ac0993c Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 30 Dec 2005 00:23:01 +0000 Subject: [PATCH] catching up oracle to current, some tweaks to unittests to work better with oracle, allow different ordering of expected statements. unittests still dont work completely with oracle due to sequence columns in INSERT statements --- lib/sqlalchemy/databases/oracle.py | 19 +++++---- test/objectstore.py | 13 +++--- test/testbase.py | 64 ++++++++++++++++++++---------- 3 files changed, 61 insertions(+), 35 deletions(-) diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index ae1445759d..04d59a5828 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -110,7 +110,7 @@ class OracleSQLEngine(ansisql.ANSISQLEngine): def schemadropper(self, proxy, **params): return OracleSchemaDropper(proxy, **params) def defaultrunner(self, proxy): - return OracleDefaultRunner(proxy) + return OracleDefaultRunner(self, proxy) def reflecttable(self, table): raise "not implemented" @@ -143,10 +143,10 @@ class OracleCompiler(ansisql.ANSICompiler): """oracle compiler modifies the lexical structure of Select statements to work under non-ANSI configured Oracle databases, if the use_ansi flag is False.""" - def __init__(self, engine, statement, bindparams, use_ansi = True, **kwargs): + def __init__(self, engine, statement, parameters, use_ansi = True, **kwargs): self._outertable = None self._use_ansi = use_ansi - ansisql.ANSICompiler.__init__(self, engine, statement, bindparams, **kwargs) + ansisql.ANSICompiler.__init__(self, engine, statement, parameters, **kwargs) def visit_join(self, join): if self._use_ansi: @@ -165,7 +165,12 @@ class OracleCompiler(ansisql.ANSICompiler): join.onclause.accept_visitor(self) self._outertable = outertable - + + def visit_alias(self, alias): + """oracle doesnt like 'FROM table AS alias'. is the AS standard SQL??""" + self.froms[alias] = self.get_from_text(alias.selectable) + " " + alias.name + self.strings[alias] = self.get_str(alias.selectable) + def visit_column(self, column): if self._use_ansi: return ansisql.ANSICompiler.visit_column(self, column) @@ -181,8 +186,8 @@ class OracleCompiler(ansisql.ANSICompiler): with autoincrement fields that require they not be present. so, put them all in for all primary key columns.""" for c in insert.table.primary_key: - if not self.bindparams.has_key(c.key): - self.bindparams[c.key] = None + if not self.parameters.has_key(c.key): + self.parameters[c.key] = None return ansisql.ANSICompiler.visit_insert(self, insert) def visit_select(self, select): @@ -235,4 +240,4 @@ class OracleDefaultRunner(ansisql.ANSIDefaultRunner): return self.proxy(str(c), c.get_params()).fetchone()[0] def visit_sequence(self, seq): - return self.exec_default_sql(seq.name + ".nextval") + return self.proxy("SELECT " + seq.name + ".nextval FROM DUAL").fetchone()[0] diff --git a/test/objectstore.py b/test/objectstore.py index d14d8bfe08..a22539cc6f 100644 --- a/test/objectstore.py +++ b/test/objectstore.py @@ -577,20 +577,19 @@ class SaveTest(AssertMixin): l = m.select(items.c.item_name.in_(*[e['item_name'] for e in data[1:]]), order_by=[items.c.item_name, keywords.c.name]) self.assert_result(l, *data) - print "\n\n\n" + print "\n\n\nTESTTESTTEST" objects[4].item_name = 'item4updated' k = Keyword() k.name = 'yellow' objects[5].keywords.append(k) self.assert_sql(db, lambda:objectstore.commit(), [ - ( - "UPDATE items SET item_name=:item_name WHERE items.item_id = :items_item_id", + { + "UPDATE items SET item_name=:item_name WHERE items.item_id = :items_item_id": [{'item_name': 'item4updated', 'items_item_id': objects[4].item_id}] - ), - ( - "INSERT INTO keywords (name) VALUES (:name)", + , + "INSERT INTO keywords (name) VALUES (:name)": {'name': 'yellow'} - ), + }, ("INSERT INTO itemkeywords (item_id, keyword_id) VALUES (:item_id, :keyword_id)", lambda: [{'item_id': objects[5].item_id, 'keyword_id': k.keyword_id}] ) diff --git a/test/testbase.py b/test/testbase.py index d08b585277..a5785f2992 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -33,7 +33,7 @@ def parse_argv(): elif DBTYPE == 'mysql': db = engine.create_engine('mysql://db=test&host=127.0.0.1&user=scott&passwd=tiger', echo=echo) elif DBTYPE == 'oracle': - db = engine.create_engine('oracle://db=test&host=127.0.0.1&user=scott&passwd=tiger', echo=echo) + db = engine.create_engine('oracle://user=scott&password=tiger', echo=echo) db = EngineAssert(db) class PersistTest(unittest.TestCase): @@ -108,35 +108,57 @@ class EngineAssert(object): statement = re.sub(r'\n', '', statement) if self.assert_list is not None: - item = self.assert_list.pop() + item = self.assert_list[-1] + if not isinstance(item, dict): + item = self.assert_list.pop() + else: + # asserting a dictionary of statements->parameters + # this is to specify query assertions where the queries can be in + # multiple orderings + if not item.has_key('_converted'): + for key in item.keys(): + ckey = self.convert_statement(key) + item[ckey] = item[key] + if ckey != key: + del item[key] + item['_converted'] = True + try: + entry = item.pop(statement) + if len(item) == 1: + self.assert_list.pop() + item = (statement, entry) + print "OK ON", statement + except KeyError: + self.unittest.assert_(False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement)) + (query, params) = item if callable(params): params = params() - # deal with paramstyles of different engines - paramstyle = self.engine.paramstyle - if paramstyle == 'named': - pass - elif paramstyle =='pyformat': - query = re.sub(r':([\w_]+)', r"%(\1)s", query) - else: - # positional params - names = [] - repl = None - if paramstyle=='qmark': - repl = "?" - elif paramstyle=='format': - repl = r"%s" - elif paramstyle=='numeric': - repl = None - counter = 0 - query = re.sub(r':([\w_]+)', repl, query) + query = self.convert_statement(query) self.unittest.assert_(statement == query and params == parameters, "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters))) self.sql_count += 1 return self.realexec(proxy, compiled, parameters, **kwargs) - + def convert_statement(self, query): + paramstyle = self.engine.paramstyle + if paramstyle == 'named': + pass + elif paramstyle =='pyformat': + query = re.sub(r':([\w_]+)', r"%(\1)s", query) + else: + # positional params + repl = None + if paramstyle=='qmark': + repl = "?" + elif paramstyle=='format': + repl = r"%s" + elif paramstyle=='numeric': + repl = None + query = re.sub(r':([\w_]+)', repl, query) + return query + class TTestSuite(unittest.TestSuite): """override unittest.TestSuite to provide per-TestCase class setUpAll() and tearDownAll() functionality""" def __init__(self, tests=()): -- 2.47.2