From b48e0147ab03e267f01aa7270172905abe0867df Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 25 Aug 2012 20:08:52 +0000 Subject: [PATCH] - refine oracle returning some more to use purely positional approach --- lib/sqlalchemy/dialects/oracle/base.py | 31 +++++++++++---------- lib/sqlalchemy/dialects/oracle/cx_oracle.py | 13 ++++----- test/dialect/test_oracle.py | 4 +-- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index ff1b8043de..399f900354 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -501,20 +501,23 @@ class OracleCompiler(compiler.SQLCompiler): def returning_clause(self, stmt, returning_cols): - def create_out_param(col, i): - bindparam = sql.outparam("ret_%d" % i, type_=col.type) - self.binds[bindparam.key] = bindparam - return self.bindparam_string(self._truncate_bindparam(bindparam)) - - columnlist = list(expression._select_iterables(returning_cols)) - - columns = [ - self._label_select_column(None, c, True, False, {}, - within_columns_clause=False) - for c in columnlist - ] - - binds = [create_out_param(c, i) for i, c in enumerate(columnlist)] + columns = [] + binds = [] + for i, column in enumerate(expression._select_iterables(returning_cols)): + if column.type._has_column_expression: + col_expr = column.type.column_expression(column) + else: + col_expr = column + outparam = sql.outparam("ret_%d" % i, type_=column.type) + self.binds[outparam.key] = outparam + binds.append(self.bindparam_string(self._truncate_bindparam(outparam))) + columns.append(self.process(col_expr, within_columns_clause=False)) + self.result_map[outparam.key] = ( + outparam.key, + (column, getattr(column, 'name', None), + getattr(column, 'key', None)), + column.type + ) return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds) diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index ea1913c384..b6feb426ab 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -354,7 +354,7 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): "Cannot create out parameter for parameter " "%r - it's type %r is not supported by" " cx_oracle" % - (bindparam.name, bindparam.type) + (bindparam.key, bindparam.type) ) name = self.compiled.bind_names[bindparam] self.out_parameters[name] = self.cursor.var(dbtype) @@ -443,13 +443,10 @@ class ReturningResultProxy(_result.FullyBufferedResultProxy): def _cursor_description(self): returning = self.context.compiled.returning - ret = [] - for c in returning: - if hasattr(c, 'name'): - ret.append((c.name, c.type)) - else: - ret.append((c.anon_label, c.type)) - return ret + return [ + ("ret_%d" % i, None) + for i, col in enumerate(returning) + ] def _buffer_rows(self): return collections.deque([tuple(self._returning_params["ret_%d" % i] diff --git a/test/dialect/test_oracle.py b/test/dialect/test_oracle.py index 0348cd1372..943cffb48e 100644 --- a/test/dialect/test_oracle.py +++ b/test/dialect/test_oracle.py @@ -460,8 +460,8 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): compiled = stmt.compile(dialect=oracle.dialect()) eq_( compiled.result_map, - {'c3': ('c3', (t1.c.c3, 'c3', 'c3'), t1.c.c3.type), - 'lower': ('lower', (), fn.type)} + {'ret_1': ('ret_1', (t1.c.c3, 'c3', 'c3'), t1.c.c3.type), + 'ret_0': ('ret_0', (fn, 'lower', None), fn.type)} ) self.assert_compile( -- 2.47.3