From: Catherine Devlin Date: Thu, 20 Mar 2008 16:48:46 +0000 (+0000) Subject: Undoing patch #994, for now; more testing needed. Sorry. Also modifying test for... X-Git-Tag: rel_0_4_5~66 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6d5cb2522b0fbb849032a7cdcbfa27baee10c587;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Undoing patch #994, for now; more testing needed. Sorry. Also modifying test for query equivalence to account for underscoring of bind variables. --- diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index df45f69bb4..fc35df2bb7 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -659,16 +659,6 @@ class OracleCompiler(compiler.DefaultCompiler): """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle.""" pass - def create_insert_update_bind(self, col, value): - key = col.key - # TODO: make this check more specific to reserved words - if len(key) < 30: - key += '_' - bindparam = sql.bindparam(key, value, shortname=col.key, type_=col.type) - self.binds[col.key] = bindparam - return self.bindparam_string(self._truncate_bindparam(bindparam)) - - def visit_select(self, select, **kwargs): """Look for ``LIMIT`` and OFFSET in a select statement, and if so tries to wrap it in a subquery with ``row_number()`` criterion. @@ -745,86 +735,3 @@ dialect.schemagenerator = OracleSchemaGenerator dialect.schemadropper = OracleSchemaDropper dialect.preparer = OracleIdentifierPreparer dialect.defaultrunner = OracleDefaultRunner - - -RESERVED_WORDS = util.Set(''' -SHARE -RAW -DROP -BETWEEN -FROM -DESC -OPTION -PRIOR -LONG -THEN -DEFAULT -ALTER -IS -INTO -MINUS -INTEGER -NUMBER -GRANT -IDENTIFIED -ALL -TO -ORDER -ON -FLOAT -DATE -HAVING -CLUSTER -NOWAIT -RESOURCE -ANY -TABLE -INDEX -FOR -UPDATE -WHERE -CHECK -SMALLINT -WITH -DELETE -BY -ASC -REVOKE -LIKE -SIZE -RENAME -NOCOMPRESS -NULL -GROUP -VALUES -AS -IN -VIEW -EXCLUSIVE -COMPRESS -SYNONYM -SELECT -INSERT -EXISTS -NOT -TRIGGER -ELSE -CREATE -INTERSECT -PCTFREE -DISTINCT -CONNECT -SET -MODE -OF -UNIQUE -VARCHAR2 -VARCHAR -LOCK -OR -CHAR -DECIMAL -UNION -PUBLIC -AND -START'''.splitlines()) \ No newline at end of file diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index f716d06f58..dfeefa337c 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -371,7 +371,7 @@ class DefaultExecutionContext(base.ExecutionContext): else: val = drunner.get_column_onupdate(c) if val is not None: - param[self.compiled.binds[c.key].key] = val + param[c.key] = val self.compiled_parameters = params else: @@ -385,15 +385,12 @@ class DefaultExecutionContext(base.ExecutionContext): val = drunner.get_column_onupdate(c) if val is not None: - compiled_parameters[self.compiled.binds[c.key].key] = val + compiled_parameters[c.key] = val if self.isinsert: - self._last_inserted_ids = [ - k and compiled_parameters.get(k.key, None) or None for k in - [self.compiled.binds.get(c.key, None) for c in self.compiled.statement.table.primary_key] - ] - self._last_inserted_params = dict([(key, compiled_parameters[self.compiled.bind_names[b]]) for key, b in self.compiled.binds.iteritems()]) + self._last_inserted_ids = [compiled_parameters.get(c.key, None) for c in self.compiled.statement.table.primary_key] + self._last_inserted_params = compiled_parameters else: - self._last_updated_params = dict([(key, compiled_parameters[self.compiled.bind_names[b]]) for key, b in self.compiled.binds.iteritems()]) + self._last_updated_params = compiled_parameters self.postfetch_cols = self.compiled.postfetch diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 4a45d6c153..6a048a7809 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -649,11 +649,6 @@ class DefaultCompiler(engine.Compiled): self.stack.pop(-1) return text - - def create_insert_update_bind(self, col, value): - bindparam = sql.bindparam(col.key, value, type_=col.type) - self.binds[col.key] = bindparam - return self.bindparam_string(self._truncate_bindparam(bindparam)) def _get_colparams(self, stmt): """create a set of tuples representing column/string pairs for use @@ -661,13 +656,18 @@ class DefaultCompiler(engine.Compiled): """ + def create_bind_param(col, value): + bindparam = sql.bindparam(col.key, value, type_=col.type) + self.binds[col.key] = bindparam + return self.bindparam_string(self._truncate_bindparam(bindparam)) + self.postfetch = [] self.prefetch = [] # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if self.column_keys is None and stmt.parameters is None: - return [(c, self.create_insert_update_bind(c, None)) for c in stmt.table.columns] + return [(c, create_bind_param(c, None)) for c in stmt.table.columns] # if we have statement parameters - set defaults in the # compiled params @@ -686,7 +686,7 @@ class DefaultCompiler(engine.Compiled): if c.key in parameters: value = parameters[c.key] if sql._is_literal(value): - value = self.create_insert_update_bind(c, value) + value = create_bind_param(c, value) else: self.postfetch.append(c) value = self.process(value.self_group()) @@ -699,7 +699,7 @@ class DefaultCompiler(engine.Compiled): not self.dialect.supports_pk_autoincrement) or (c.default is not None and not isinstance(c.default, schema.Sequence))): - values.append((c, self.create_insert_update_bind(c, None))) + values.append((c, create_bind_param(c, None))) self.prefetch.append(c) elif isinstance(c.default, schema.ColumnDefault): if isinstance(c.default.arg, sql.ClauseElement): @@ -708,7 +708,7 @@ class DefaultCompiler(engine.Compiled): # dont add primary key column to postfetch self.postfetch.append(c) else: - values.append((c, self.create_insert_update_bind(c, None))) + values.append((c, create_bind_param(c, None))) self.prefetch.append(c) elif isinstance(c.default, schema.PassiveDefault): if not c.primary_key: @@ -725,7 +725,7 @@ class DefaultCompiler(engine.Compiled): values.append((c, self.process(c.onupdate.arg.self_group()))) self.postfetch.append(c) else: - values.append((c, self.create_insert_update_bind(c, None))) + values.append((c, create_bind_param(c, None))) self.prefetch.append(c) elif isinstance(c.onupdate, schema.PassiveDefault): self.postfetch.append(c) diff --git a/test/testlib/testing.py b/test/testlib/testing.py index 03a8e6539c..08c3c9192b 100644 --- a/test/testlib/testing.py +++ b/test/testlib/testing.py @@ -370,6 +370,7 @@ class ExecutionContextWrapper(object): def __setattr__(self, key, value): setattr(self.ctx, key, value) + trailing_underscore_pattern = re.compile(r'(\W:[\w_#]+)_\b',re.MULTILINE) def post_execution(self): ctx = self.ctx statement = unicode(ctx.compiled) @@ -412,7 +413,15 @@ class ExecutionContextWrapper(object): parameters = ctx.compiled_parameters query = self.convert_statement(query) - testdata.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters))) + equivalent = ( (statement == query) + or ( (config.db.name == 'oracle') and (self.trailing_underscore_pattern.sub(r'\1', statement) == query) ) + ) \ + and \ + ( (params is None) or (params == parameters) + or params == [dict((k.strip('_'), v) for (k, v) in p.items())for p in parameters] + ) + testdata.unittest.assert_(equivalent, + "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters))) testdata.sql_count += 1 self.ctx.post_execution()