]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- refine oracle returning some more to use purely positional approach
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 25 Aug 2012 20:08:52 +0000 (20:08 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 25 Aug 2012 20:08:52 +0000 (20:08 +0000)
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/oracle/cx_oracle.py
test/dialect/test_oracle.py

index ff1b8043de1d21f97f07996012d43352e78e8fb4..399f9003547bd6bd6068a3d055de1e35b570e372 100644 (file)
@@ -501,20 +501,23 @@ class OracleCompiler(compiler.SQLCompiler):
 
     def returning_clause(self, stmt, returning_cols):
 
-        def create_out_param(col, i):
-            bindparam = sql.outparam("ret_%d" % i, type_=col.type)
-            self.binds[bindparam.key] = bindparam
-            return self.bindparam_string(self._truncate_bindparam(bindparam))
-
-        columnlist = list(expression._select_iterables(returning_cols))
-
-        columns = [
-                self._label_select_column(None, c, True, False, {},
-                                            within_columns_clause=False)
-                for c in columnlist
-            ]
-
-        binds = [create_out_param(c, i) for i, c in enumerate(columnlist)]
+        columns = []
+        binds = []
+        for i, column in enumerate(expression._select_iterables(returning_cols)):
+            if column.type._has_column_expression:
+                col_expr = column.type.column_expression(column)
+            else:
+                col_expr = column
+            outparam = sql.outparam("ret_%d" % i, type_=column.type)
+            self.binds[outparam.key] = outparam
+            binds.append(self.bindparam_string(self._truncate_bindparam(outparam)))
+            columns.append(self.process(col_expr, within_columns_clause=False))
+            self.result_map[outparam.key] = (
+                outparam.key,
+                (column, getattr(column, 'name', None),
+                                        getattr(column, 'key', None)),
+                column.type
+            )
 
         return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds)
 
index ea1913c3844cea24f2aa386229c69eff87b62670..b6feb426ab8a7910a57d7f5b97277bc605b885bb 100644 (file)
@@ -354,7 +354,7 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
                                     "Cannot create out parameter for parameter "
                                     "%r - it's type %r is not supported by"
                                     " cx_oracle" %
-                                    (bindparam.name, bindparam.type)
+                                    (bindparam.key, bindparam.type)
                                     )
                     name = self.compiled.bind_names[bindparam]
                     self.out_parameters[name] = self.cursor.var(dbtype)
@@ -443,13 +443,10 @@ class ReturningResultProxy(_result.FullyBufferedResultProxy):
     def _cursor_description(self):
         returning = self.context.compiled.returning
 
-        ret = []
-        for c in returning:
-            if hasattr(c, 'name'):
-                ret.append((c.name, c.type))
-            else:
-                ret.append((c.anon_label, c.type))
-        return ret
+        return [
+            ("ret_%d" % i, None)
+            for i, col in enumerate(returning)
+        ]
 
     def _buffer_rows(self):
         return collections.deque([tuple(self._returning_params["ret_%d" % i]
index 0348cd1372c0e225fb17e38916ee1b0b0aee08f0..943cffb48e3f50aad52c9ee7620625f4fce76f46 100644 (file)
@@ -460,8 +460,8 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         compiled = stmt.compile(dialect=oracle.dialect())
         eq_(
             compiled.result_map,
-            {'c3': ('c3', (t1.c.c3, 'c3', 'c3'), t1.c.c3.type),
-            'lower': ('lower', (), fn.type)}
+            {'ret_1': ('ret_1', (t1.c.c3, 'c3', 'c3'), t1.c.c3.type),
+            'ret_0': ('ret_0', (fn, 'lower', None), fn.type)}
 
         )
         self.assert_compile(