]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- returning() support is native to insert(), update(), delete(). Implementations
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 19 Jul 2009 02:20:18 +0000 (02:20 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 19 Jul 2009 02:20:18 +0000 (02:20 +0000)
of varying levels of functionality exist for Postgresql, Firebird, MSSQL and
Oracle.
- MSSQL still has a few glitches that need to be resolved via label/column targeting logic.
- its looking like time to take another look at positional column targeting overall.

14 files changed:
06CHANGES
lib/sqlalchemy/dialects/firebird/base.py
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/dialect/test_firebird.py
test/dialect/test_mssql.py
test/dialect/test_postgresql.py
test/sql/test_returning.py [new file with mode: 0644]

index 8cadbcb02c2dba63fc5fdafce4e561e922091718..3045721cba4addea6aa348a2d412852e79f03bef 100644 (file)
--- a/06CHANGES
+++ b/06CHANGES
@@ -8,6 +8,11 @@
       on the structure of criteria, so success/failure is deterministic based on 
       code structure.
 
+- sql
+    - returning() support is native to insert(), update(), delete().  Implementations
+      of varying levels of functionality exist for Postgresql, Firebird, MSSQL and
+      Oracle.
+      
 - engines
     - transaction isolation level may be specified with
       create_engine(... isolation_level="..."); available on
index f30749ed72a1628c93e44cbaa8aed3f28c1c79e8..949289eb360b75f4c7acb0d3c64f31e3fbf00f43 100644 (file)
@@ -221,7 +221,7 @@ class FBCompiler(sql.compiler.SQLCompiler):
 
     visit_char_length_func = visit_length_func
 
-    def function_argspec(self, func):
+    def function_argspec(self, func, **kw):
         if func.clauses:
             return self.process(func.clause_expr)
         else:
@@ -253,40 +253,22 @@ class FBCompiler(sql.compiler.SQLCompiler):
 
         return ""
 
-    def _append_returning(self, text, stmt):
-        returning_cols = stmt.kwargs["firebird_returning"]
+    def returning_clause(self, stmt):
+        returning_cols = stmt._returning
+
         def flatten_columnlist(collist):
             for c in collist:
-                if isinstance(c, sql.expression.Selectable):
+                if isinstance(c, expression.Selectable):
                     for co in c.columns:
                         yield co
                 else:
                     yield c
-        columns = [self.process(c, within_columns_clause=True)
-                   for c in flatten_columnlist(returning_cols)]
-        text += ' RETURNING ' + ', '.join(columns)
-        return text
-
-    def visit_update(self, update_stmt):
-        text = super(FBCompiler, self).visit_update(update_stmt)
-        if "firebird_returning" in update_stmt.kwargs:
-            return self._append_returning(text, update_stmt)
-        else:
-            return text
 
-    def visit_insert(self, insert_stmt):
-        text = super(FBCompiler, self).visit_insert(insert_stmt)
-        if "firebird_returning" in insert_stmt.kwargs:
-            return self._append_returning(text, insert_stmt)
-        else:
-            return text
-
-    def visit_delete(self, delete_stmt):
-        text = super(FBCompiler, self).visit_delete(delete_stmt)
-        if "firebird_returning" in delete_stmt.kwargs:
-            return self._append_returning(text, delete_stmt)
-        else:
-            return text
+        columns = [
+                self.process(c, within_columns_clause=True, result_map=self.result_map) 
+                for c in flatten_columnlist(returning_cols)
+            ]
+        return 'RETURNING ' + ', '.join(columns)
 
 
 class FBDDLCompiler(sql.compiler.DDLCompiler):
index 849b72b9793888598c44a95bca44e692dc628fd8..9831b5134bbe2286f410cec5dec9b363cc03d4cd 100644 (file)
@@ -224,7 +224,9 @@ Known Issues
 import datetime, decimal, inspect, operator, sys, re
 
 from sqlalchemy import sql, schema as sa_schema, exc, util
-from sqlalchemy.sql import select, compiler, expression, operators as sql_operators, functions as sql_functions
+from sqlalchemy.sql import select, compiler, expression, \
+                            operators as sql_operators, \
+                            functions as sql_functions, util as sql_util
 from sqlalchemy.engine import default, base, reflection
 from sqlalchemy import types as sqltypes
 from decimal import Decimal as _python_Decimal
@@ -844,6 +846,7 @@ def _table_sequence_column(tbl):
 class MSExecutionContext(default.DefaultExecutionContext):
     _enable_identity_insert = False
     _select_lastrowid = False
+    _result_proxy = None
     
     def pre_exec(self):
         """Activate IDENTITY_INSERT if needed."""
@@ -859,6 +862,7 @@ class MSExecutionContext(default.DefaultExecutionContext):
                 self._enable_identity_insert = False
             
             self._select_lastrowid = insert_has_sequence and \
+                                        not self.compiled.statement.returning and \
                                         not self._enable_identity_insert and \
                                         not self.executemany
             
@@ -880,6 +884,10 @@ class MSExecutionContext(default.DefaultExecutionContext):
         if self._enable_identity_insert:
             self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
 
+        if (self.isinsert or self.isupdate or self.isdelete) and \
+                self.compiled.statement._returning:
+            self._result_proxy = base.FullyBufferedResultProxy(self)
+        
     def handle_dbapi_exception(self, e):
         if self._enable_identity_insert:
             try:
@@ -887,6 +895,8 @@ class MSExecutionContext(default.DefaultExecutionContext):
             except:
                 pass
 
+    def get_result_proxy(self):
+        return self._result_proxy or base.ResultProxy(self)
 
 class MSSQLCompiler(compiler.SQLCompiler):
 
@@ -1023,7 +1033,7 @@ class MSSQLCompiler(compiler.SQLCompiler):
                 return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs)
             return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
 
-    def visit_insert(self, insert_stmt):
+    def dont_visit_insert(self, insert_stmt):
         insert_select = False
         if insert_stmt.parameters:
             insert_select = [p for p in insert_stmt.parameters.values() if isinstance(p, sql.Select)]
@@ -1050,6 +1060,30 @@ class MSSQLCompiler(compiler.SQLCompiler):
         else:
             return super(MSSQLCompiler, self).visit_insert(insert_stmt)
 
+    def returning_clause(self, stmt):
+        returning_cols = stmt._returning
+
+        def flatten_columnlist(collist):
+            for c in collist:
+                if isinstance(c, expression.Selectable):
+                    for co in c.columns:
+                        yield co
+                else:
+                    yield c
+                    
+        if self.isinsert or self.isupdate:
+            target = stmt.table.alias("inserted")
+        else:
+            target = stmt.table.alias("deleted")
+        
+        adapter = sql_util.ClauseAdapter(target)
+        columns = [
+            self.process(adapter.traverse(c), within_columns_clause=True, result_map=self.result_map) 
+            for c in flatten_columnlist(returning_cols)
+        ]
+
+        return 'OUTPUT ' + ', '.join(columns)
+
     def label_select_column(self, select, column, asfrom):
         if isinstance(column, expression.Function):
             return column.label(None)
index 35c85c2c95a422df50e7508352445bf3a060fa57..7c956f6bede3b8bce38b0854745021faba6f4792 100644 (file)
@@ -145,6 +145,9 @@ class LONG(sqltypes.Text):
     __visit_name__ = 'LONG'
     
 class _OracleBoolean(sqltypes.Boolean):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.NUMBER
+    
     def result_processor(self, dialect):
         def process(value):
             if value is None:
@@ -315,6 +318,29 @@ class OracleCompiler(compiler.SQLCompiler):
         else:
             return self.process(alias.original, **kwargs)
 
+    def returning_clause(self, stmt):
+        returning_cols = stmt._returning
+            
+        def flatten_columnlist(collist):
+            for c in collist:
+                if isinstance(c, expression.Selectable):
+                    for co in c.columns:
+                        yield co
+                else:
+                    yield c
+
+        def create_out_param(col, i):
+            bindparam = sql.outparam("ret_%d" % i, type_=col.type)
+            self.binds[bindparam.key] = bindparam
+            return self.bindparam_string(self._truncate_bindparam(bindparam))
+        
+        # within_columns_clause =False so that labels (foo AS bar) don't render
+        columns = [self.process(c, within_columns_clause=False) for c in flatten_columnlist(returning_cols)]
+        
+        binds = [create_out_param(c, i) for i, c in enumerate(flatten_columnlist(returning_cols))]
+        
+        return 'RETURNING ' + ', '.join(columns) +  " INTO " + ", ".join(binds)
+
     def _TODO_visit_compound_select(self, select):
         """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
         pass
@@ -424,7 +450,9 @@ class OracleDDLCompiler(compiler.DDLCompiler):
 
 class OracleDefaultRunner(base.DefaultRunner):
     def visit_sequence(self, seq):
-        return self.execute_string("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL", {})
+        return self.execute_string("SELECT " + 
+                    self.dialect.identifier_preparer.format_sequence(seq) + 
+                    ".nextval FROM DUAL", {})
 
 class OracleIdentifierPreparer(compiler.IdentifierPreparer):
     
index fe74dce7af7cad325c73c6e3397946e89b3a7f62..54e4d119e15b8dd2f9b9197940251f25361d178d 100644 (file)
@@ -213,9 +213,32 @@ class Oracle_cx_oracleExecutionContext(DefaultExecutionContext):
                 type_code = column[1]
                 if type_code in self.dialect.ORACLE_BINARY_TYPES:
                     return base.BufferedColumnResultProxy(self)
+        
+        if hasattr(self, 'out_parameters') and \
+            (self.isinsert or self.isupdate or self.isdelete) and \
+                self.compiled.statement._returning:
+                
+            return ReturningResultProxy(self)
+        else:
+            return base.ResultProxy(self)
 
-        return base.ResultProxy(self)
-
+class ReturningResultProxy(base.FullyBufferedResultProxy):
+    """Result proxy which stuffs the _returning clause + outparams into the fetch."""
+    
+    def _cursor_description(self):
+        returning = self.context.compiled.statement._returning
+        
+        ret = []
+        for c in returning:
+            if hasattr(c, 'key'):
+                ret.append((c.key, c.type))
+            else:
+                ret.append((c.anon_label, c.type))
+        return ret
+    
+    def _buffer_rows(self):
+        returning = self.context.compiled.statement._returning
+        return [tuple(self.context.out_parameters["ret_%d" % i] for i, c in enumerate(returning))]
 
 class Oracle_cx_oracle(OracleDialect):
     execution_ctx_cls = Oracle_cx_oracleExecutionContext
index 1aa96e85242469e35d6d08a2f704645d3c4c2081..849ec500668a9d1d199308239db483033e2a8fdb 100644 (file)
@@ -263,13 +263,9 @@ class PGCompiler(compiler.SQLCompiler):
         else:
             return super(PGCompiler, self).for_update_clause(select)
 
-    def _append_returning(self, text, stmt):
-        try:
-            returning_cols = stmt.kwargs['postgresql_returning']
-        except KeyError:
-            returning_cols = stmt.kwargs['postgres_returning']
-            util.warn_deprecated("The 'postgres_returning' argument has been renamed 'postgresql_returning'")
-            
+    def returning_clause(self, stmt):
+        returning_cols = stmt._returning
+        
         def flatten_columnlist(collist):
             for c in collist:
                 if isinstance(c, expression.Selectable):
@@ -277,23 +273,13 @@ class PGCompiler(compiler.SQLCompiler):
                         yield co
                 else:
                     yield c
-        columns = [self.process(c, within_columns_clause=True) for c in flatten_columnlist(returning_cols)]
-        text += ' RETURNING ' + ', '.join(columns)
-        return text
-
-    def visit_update(self, update_stmt):
-        text = super(PGCompiler, self).visit_update(update_stmt)
-        if 'postgresql_returning' in update_stmt.kwargs or 'postgres_returning' in update_stmt.kwargs:
-            return self._append_returning(text, update_stmt)
-        else:
-            return text
-
-    def visit_insert(self, insert_stmt):
-        text = super(PGCompiler, self).visit_insert(insert_stmt)
-        if 'postgresql_returning' in insert_stmt.kwargs or 'postgres_returning' in insert_stmt.kwargs:
-            return self._append_returning(text, insert_stmt)
-        else:
-            return text
+    
+        columns = [
+                self.process(c, within_columns_clause=True, result_map=self.result_map) 
+                for c in flatten_columnlist(returning_cols)
+            ]
+            
+        return 'RETURNING ' + ', '.join(columns)
 
     def visit_extract(self, extract, **kwargs):
         field = self.extract_map.get(extract.field, extract.field)
index 470ca811b158c56e5d45328b593651bd6647d5b8..feefb88d269b33d8c325fdf32e47cb1d47e0e3b2 100644 (file)
@@ -1064,6 +1064,7 @@ class Connection(Connectable):
             self._cursor_execute(context.cursor, context.statement, context.parameters[0], context=context)
         if context.compiled:
             context.post_exec()
+            
         if context.should_autocommit and not self.in_transaction():
             self._commit_impl()
         return context.get_result_proxy()
@@ -1586,7 +1587,7 @@ class ResultProxy(object):
     """
 
     _process_row = RowProxy
-
+    
     def __init__(self, context):
         self.context = context
         self.dialect = context.dialect
@@ -1607,14 +1608,22 @@ class ResultProxy(object):
     @property
     def out_parameters(self):
         return self.context.out_parameters
-
-    def _init_metadata(self):
+    
+    def _cursor_description(self):
         metadata = self.cursor.description
         if metadata is None:
-            # no results, get rowcount (which requires open cursor on some DB's such as firebird),
-            # then close
+            return
+        else:
+            return [(r[0], r[1]) for r in metadata]
+            
+    def _init_metadata(self):
+        
+        metadata = self._cursor_description()
+        if metadata is None:
+            # no results, get rowcount 
+            # (which requires open cursor on some DB's such as firebird),
             self.rowcount
-            self.close()
+            self.close() # autoclose
             return
 
         self._props = util.populate_column_dict(None)
@@ -1623,8 +1632,7 @@ class ResultProxy(object):
 
         typemap = self.dialect.dbapi_type_map
 
-        for i, item in enumerate(metadata):
-            colname = item[0]
+        for i, (colname, coltype) in enumerate(metadata):
 
             if self.dialect.description_encoding:
                 colname = colname.decode(self.dialect.description_encoding)
@@ -1640,9 +1648,9 @@ class ResultProxy(object):
                 try:
                     (name, obj, type_) = self.context.result_map[colname.lower()]
                 except KeyError:
-                    (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE))
+                    (name, obj, type_) = (colname, None, typemap.get(coltype, types.NULLTYPE))
             else:
-                (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE))
+                (name, obj, type_) = (colname, None, typemap.get(coltype, types.NULLTYPE))
 
             rec = (type_, type_.dialect_impl(self.dialect).result_processor(self.dialect), i)
 
@@ -1949,8 +1957,44 @@ class BufferedRowResultProxy(ResultProxy):
         return result
 
     def _fetchall_impl(self):
-        return self.__rowbuffer + list(self.cursor.fetchall())
+        ret = self.__rowbuffer + list(self.cursor.fetchall())
+        self.__rowbuffer[:] = []
+        return ret
+
+class FullyBufferedResultProxy(ResultProxy):
+    """A result proxy that buffers rows fully upon creation.
+    
+    Used for operations where a result is to be delivered
+    after the database conversation can not be continued,
+    such as MSSQL INSERT...OUTPUT after an autocommit.
+    
+    """
+    def _init_metadata(self):
+        self.__rowbuffer = self._buffer_rows()
+        super(FullyBufferedResultProxy, self)._init_metadata()
+        
+    def _buffer_rows(self):
+        return self.cursor.fetchall()
+        
+    def _fetchone_impl(self):
+        if self.__rowbuffer:
+            return self.__rowbuffer.pop(0)
+        else:
+            return None
 
+    def _fetchmany_impl(self, size=None):
+        result = []
+        for x in range(0, size):
+            row = self._fetchone_impl()
+            if row is None:
+                break
+            result.append(row)
+        return result
+
+    def _fetchall_impl(self):
+        ret = self.__rowbuffer
+        self.__rowbuffer = []
+        return ret
 
 class BufferedColumnResultProxy(ResultProxy):
     """A ResultProxy with column buffering behavior.
index 5a86d7c94b0882a905e809d59f625c914c086cc1..3f90baa5c2d5f15eefc98597bd236c75660f5e72 100644 (file)
@@ -187,7 +187,7 @@ class DefaultExecutionContext(base.ExecutionContext):
                 self.statement = unicode(compiled).encode(self.dialect.encoding)
             else:
                 self.statement = unicode(compiled)
-            self.isinsert = self.isupdate = self.executemany = False
+            self.isinsert = self.isupdate = self.isdelete = self.executemany = False
             self.should_autocommit = True
             self.result_map = None
             self.cursor = self.create_cursor()
@@ -221,6 +221,7 @@ class DefaultExecutionContext(base.ExecutionContext):
 
             self.isinsert = compiled.isinsert
             self.isupdate = compiled.isupdate
+            self.isdelete = compiled.isdelete
             self.should_autocommit = compiled.statement._autocommit
             if isinstance(compiled.statement, expression._TextClause):
                 self.should_autocommit = self.should_autocommit or self.should_autocommit_text(self.statement)
@@ -246,13 +247,13 @@ class DefaultExecutionContext(base.ExecutionContext):
                 self.statement = statement.encode(self.dialect.encoding)
             else:
                 self.statement = statement
-            self.isinsert = self.isupdate = False
+            self.isinsert = self.isupdate = self.isdelete = False
             self.cursor = self.create_cursor()
             self.should_autocommit = self.should_autocommit_text(statement)
         else:
             # no statement. used for standalone ColumnDefault execution.
             self.statement = self.compiled = None
-            self.isinsert = self.isupdate = self.executemany = self.should_autocommit = False
+            self.isinsert = self.isupdate = self.isdelete = self.executemany = self.should_autocommit = False
             self.cursor = self.create_cursor()
 
     @property
index 30bcc45e5971ad38df7278b12aa4d2708b86c579..b862c8c8114f3970bd1d44467f84d83d5dd3c444 100644 (file)
@@ -672,43 +672,73 @@ class SQLCompiler(engine.Compiled):
     def visit_insert(self, insert_stmt):
         self.isinsert = True
         colparams = self._get_colparams(insert_stmt)
-        preparer = self.preparer
-
-        insert = ' '.join(["INSERT"] +
-                          [self.process(x) for x in insert_stmt._prefixes])
 
         if not colparams and \
                 not self.dialect.supports_default_values and \
                 not self.dialect.supports_empty_insert:
             raise exc.CompileError(
                 "The version of %s you are using does not support empty inserts." % self.dialect.name)
-        elif not colparams and self.dialect.supports_default_values:
-            return (insert + " INTO %s DEFAULT VALUES" % (
-                (preparer.format_table(insert_stmt.table),)))
-        else: 
-            return (insert + " INTO %s (%s) VALUES (%s)" %
-                (preparer.format_table(insert_stmt.table),
-                 ', '.join([preparer.format_column(c[0])
-                           for c in colparams]),
-                 ', '.join([c[1] for c in colparams])))
 
+        preparer = self.preparer
+        supports_default_values = self.dialect.supports_default_values
+        
+        text = "INSERT"
+        
+        prefixes = [self.process(x) for x in insert_stmt._prefixes]
+        if prefixes:
+            text += " " + " ".join(prefixes)
+        
+        text += " INTO " + preparer.format_table(insert_stmt.table)
+         
+        if not colparams and supports_default_values:
+            text += " DEFAULT VALUES"
+        else: 
+            text += " (%s)" % ', '.join([preparer.format_column(c[0])
+                       for c in colparams])
+
+        if insert_stmt._returning:
+            returning_clause = self.returning_clause(insert_stmt)
+
+            # cheating
+            if returning_clause.startswith("OUTPUT"):
+                text += " " + returning_clause
+                returning_clause = None
+                
+        if colparams or not supports_default_values:
+            text += " VALUES (%s)" % \
+                     ', '.join([c[1] for c in colparams])
+        
+        if insert_stmt._returning and returning_clause:
+            text += " " + returning_clause
+        
+        return text
+        
     def visit_update(self, update_stmt):
         self.stack.append({'from': set([update_stmt.table])})
 
         self.isupdate = True
         colparams = self._get_colparams(update_stmt)
 
-        text = ' '.join((
-            "UPDATE",
-            self.preparer.format_table(update_stmt.table),
-            'SET',
-            ', '.join(self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1]
-                      for c in colparams)
-            ))
+        text = "UPDATE " + self.preparer.format_table(update_stmt.table)
+        
+        text += ' SET ' + \
+                ', '.join(
+                        self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1]
+                      for c in colparams
+                )
 
+        if update_stmt._returning:
+            returning_clause = self.returning_clause(update_stmt)
+            if returning_clause.startswith("OUTPUT"):
+                text += " " + returning_clause
+                returning_clause = None
+                
         if update_stmt._whereclause:
             text += " WHERE " + self.process(update_stmt._whereclause)
 
+        if update_stmt._returning and returning_clause:
+            text += " " + returning_clause
+            
         self.stack.pop(-1)
 
         return text
@@ -804,9 +834,18 @@ class SQLCompiler(engine.Compiled):
 
         text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
 
+        if delete_stmt._returning:
+            returning_clause = self.returning_clause(delete_stmt)
+            if returning_clause.startswith("OUTPUT"):
+                text += " " + returning_clause
+                returning_clause = None
+                
         if delete_stmt._whereclause:
             text += " WHERE " + self.process(delete_stmt._whereclause)
 
+        if delete_stmt._returning and returning_clause:
+            text += " " + returning_clause
+            
         self.stack.pop(-1)
 
         return text
index fd144a2101b0e7e42194eb5e45f125569c7ab143..142cdcbe5a251638d68f84157c1a082a4afc7451 100644 (file)
@@ -3743,7 +3743,7 @@ class _UpdateBase(ClauseElement):
 
     supports_execution = True
     _autocommit = True
-
+    
     def _generate(self):
         s = self.__class__.__new__(self.__class__)
         s.__dict__ = self.__dict__.copy()
@@ -3771,6 +3771,51 @@ class _UpdateBase(ClauseElement):
         self._bind = bind
     bind = property(bind, _set_bind)
 
+    _returning_re = re.compile(r'(?:firebird|postgres(?:ql)?)_returning')
+    def _process_deprecated_kw(self, kwargs):
+        for k in list(kwargs):
+            m = self._returning_re.match(k)
+            if m:
+                self._returning = kwargs.pop(k)
+                util.warn_deprecated(
+                    "The %r argument is deprecated.  Please use statement.returning(col1, col2, ...)" % k
+                )
+        return kwargs
+    
+    @_generative
+    def returning(self, *cols):
+        """Add a RETURNING or equivalent clause to this statement.
+        
+        The given list of columns represent columns within the table
+        that is the target of the INSERT, UPDATE, or DELETE.  Each 
+        element can be any column expression.  ``Table`` objects
+        will be expanded into their individual columns.
+        
+        Upon compilation, a RETURNING clause, or database equivalent, 
+        will be rendered within the statement.   For INSERT and UPDATE, 
+        the values are the newly inserted/updated values.  For DELETE, 
+        the values are those of the rows which were deleted.
+        
+        Upon execution, the values of the columns to be returned
+        are made available via the result set and can be iterated
+        using ``fetchone()`` and similar.   For DBAPIs which do not
+        natively support returning values (i.e. cx_oracle), 
+        SQLAlchemy will approximate this behavior at the result level
+        so that a reasonable amount of behavioral neutrality is 
+        provided.
+        
+        Note that not all databases/DBAPIs
+        support RETURNING.   For those backends with no support,
+        an exception is raised upon compilation and/or execution.
+        For those who do support it, the functionality across backends
+        varies greatly, including restrictions on executemany()
+        and other statements which return multiple rows. Please 
+        read the documentation notes for the database in use in 
+        order to determine the availability of RETURNING.
+        
+        """
+        self._returning = cols
+        
 class _ValuesBase(_UpdateBase):
 
     __visit_name__ = 'values_base'
@@ -3819,16 +3864,19 @@ class Insert(_ValuesBase):
                 inline=False, 
                 bind=None, 
                 prefixes=None, 
+                returning=None,
                 **kwargs):
         _ValuesBase.__init__(self, table, values)
         self._bind = bind
         self.select = None
         self.inline = inline
+        self._returning = returning
         if prefixes:
             self._prefixes = [_literal_as_text(p) for p in prefixes]
         else:
             self._prefixes = []
-        self.kwargs = kwargs
+            
+        self.kwargs = self._process_deprecated_kw(kwargs)
 
     def get_children(self, **kwargs):
         if self.select is not None:
@@ -3865,15 +3913,18 @@ class Update(_ValuesBase):
                 values=None, 
                 inline=False, 
                 bind=None, 
+                returning=None,
                 **kwargs):
         _ValuesBase.__init__(self, table, values)
         self._bind = bind
+        self._returning = returning
         if whereclause:
             self._whereclause = _literal_as_text(whereclause)
         else:
             self._whereclause = None
         self.inline = inline
-        self.kwargs = kwargs
+
+        self.kwargs = self._process_deprecated_kw(kwargs)
 
     def get_children(self, **kwargs):
         if self._whereclause is not None:
@@ -3907,15 +3958,22 @@ class Delete(_UpdateBase):
 
     __visit_name__ = 'delete'
 
-    def __init__(self, table, whereclause, bind=None, **kwargs):
+    def __init__(self, 
+            table, 
+            whereclause, 
+            bind=None, 
+            returning =None,
+            **kwargs):
         self._bind = bind
         self.table = table
+        self._returning = returning
+        
         if whereclause:
             self._whereclause = _literal_as_text(whereclause)
         else:
             self._whereclause = None
 
-        self.kwargs = kwargs
+        self.kwargs = self._process_deprecated_kw(kwargs)
 
     def get_children(self, **kwargs):
         if self._whereclause is not None:
index 017306691c251822c2949e36aac7a192d34954f5..0c19a4c7e737186353152684e3a817e076acea2c 100644 (file)
@@ -105,14 +105,14 @@ class CompileTest(TestBase, AssertsCompiledSQL):
             column('description', String(128)),
         )
 
-        u = update(table1, values=dict(name='foo'), firebird_returning=[table1.c.myid, table1.c.name])
+        u = update(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
         self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING mytable.myid, mytable.name")
 
-        u = update(table1, values=dict(name='foo'), firebird_returning=[table1])
+        u = update(table1, values=dict(name='foo')).returning(table1)
         self.assert_compile(u, "UPDATE mytable SET name=:name "\
             "RETURNING mytable.myid, mytable.name, mytable.description")
 
-        u = update(table1, values=dict(name='foo'), firebird_returning=[func.length(table1.c.name)])
+        u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
         self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING char_length(mytable.name)")
 
     def test_insert_returning(self):
@@ -122,87 +122,17 @@ class CompileTest(TestBase, AssertsCompiledSQL):
             column('description', String(128)),
         )
 
-        i = insert(table1, values=dict(name='foo'), firebird_returning=[table1.c.myid, table1.c.name])
+        i = insert(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
         self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING mytable.myid, mytable.name")
 
-        i = insert(table1, values=dict(name='foo'), firebird_returning=[table1])
+        i = insert(table1, values=dict(name='foo')).returning(table1)
         self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) "\
             "RETURNING mytable.myid, mytable.name, mytable.description")
 
-        i = insert(table1, values=dict(name='foo'), firebird_returning=[func.length(table1.c.name)])
+        i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
         self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING char_length(mytable.name)")
 
 
-class ReturningTest(TestBase, AssertsExecutionResults):
-    __only_on__ = 'firebird'
-
-    @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
-    def test_update_returning(self):
-        meta = MetaData(testing.db)
-        table = Table('tables', meta,
-            Column('id', Integer, Sequence('gen_tables_id'), primary_key=True),
-            Column('persons', Integer),
-            Column('full', Boolean)
-        )
-        table.create()
-        try:
-            table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
-
-            result = table.update(table.c.persons > 4, dict(full=True), firebird_returning=[table.c.id]).execute()
-            eq_(result.fetchall(), [(1,)])
-
-            result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
-            eq_(result2.fetchall(), [(1,True),(2,False)])
-        finally:
-            table.drop()
-
-    @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
-    def test_insert_returning(self):
-        meta = MetaData(testing.db)
-        table = Table('tables', meta,
-            Column('id', Integer, Sequence('gen_tables_id'), primary_key=True),
-            Column('persons', Integer),
-            Column('full', Boolean)
-        )
-        table.create()
-        try:
-            result = table.insert(firebird_returning=[table.c.id]).execute({'persons': 1, 'full': False})
-
-            eq_(result.fetchall(), [(1,)])
-
-            # Multiple inserts only return the last row
-            result2 = table.insert(firebird_returning=[table]).execute(
-                 [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}])
-
-            eq_(result2.fetchall(), [(3,3,True)])
-
-            result3 = table.insert(firebird_returning=[table.c.id]).execute({'persons': 4, 'full': False})
-            eq_([dict(row) for row in result3], [{'id': 4}])
-
-            result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, 1) returning persons')
-            eq_([dict(row) for row in result4], [{'persons': 10}])
-        finally:
-            table.drop()
-
-    @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
-    def test_delete_returning(self):
-        meta = MetaData(testing.db)
-        table = Table('tables', meta,
-            Column('id', Integer, Sequence('gen_tables_id'), primary_key=True),
-            Column('persons', Integer),
-            Column('full', Boolean)
-        )
-        table.create()
-        try:
-            table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
-
-            result = table.delete(table.c.persons > 4, firebird_returning=[table.c.id]).execute()
-            eq_(result.fetchall(), [(1,)])
-
-            result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
-            eq_(result2.fetchall(), [(2,False),])
-        finally:
-            table.drop()
 
 
 class MiscTest(TestBase):
index 5bb42d805b5ba36a77567a6b6df10ddbd1163bef..f76e1c9fb8a16170b81809508ac3d3cdab64d4d9 100644 (file)
@@ -158,6 +158,45 @@ class CompileTest(TestBase, AssertsCompiledSQL):
                 select([extract(field, t.c.col1)]),
                 'SELECT DATEPART("%s", t.col1) AS anon_1 FROM t' % field)
 
+    def test_update_returning(self):
+        table1 = table('mytable',
+            column('myid', Integer),
+            column('name', String(128)),
+            column('description', String(128)),
+        )
+
+        u = update(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
+        self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT inserted.myid, inserted.name")
+
+        u = update(table1, values=dict(name='foo')).returning(table1)
+        self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT inserted.myid, "
+                            "inserted.name, inserted.description")
+
+        u = update(table1, values=dict(name='foo')).returning(table1).where(table1.c.name=='bar')
+        self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT inserted.myid, "
+                            "inserted.name, inserted.description WHERE mytable.name = :name_1")
+        
+        u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
+        self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT LEN(inserted.name)")
+
+    def test_insert_returning(self):
+        table1 = table('mytable',
+            column('myid', Integer),
+            column('name', String(128)),
+            column('description', String(128)),
+        )
+
+        i = insert(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
+        self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT inserted.myid, inserted.name VALUES (:name)")
+
+        i = insert(table1, values=dict(name='foo')).returning(table1)
+        self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT inserted.myid, "
+                                "inserted.name, inserted.description VALUES (:name)")
+
+        i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
+        self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT LEN(inserted.name) VALUES (:name)")
+
+
 
 class IdentityInsertTest(TestBase, AssertsCompiledSQL):
     __only_on__ = 'mssql'
index 19364942ec70c6da676d9660eb73693b0534a309..2b9a687ebf087cfb12ae58184abdcc09ce52fef6 100644 (file)
@@ -32,14 +32,14 @@ class CompileTest(TestBase, AssertsCompiledSQL):
             column('description', String(128)),
         )
 
-        u = update(table1, values=dict(name='foo'), postgresql_returning=[table1.c.myid, table1.c.name])
+        u = update(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
         self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect)
 
-        u = update(table1, values=dict(name='foo'), postgresql_returning=[table1])
+        u = update(table1, values=dict(name='foo')).returning(table1)
         self.assert_compile(u, "UPDATE mytable SET name=%(name)s "\
             "RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect)
 
-        u = update(table1, values=dict(name='foo'), postgresql_returning=[func.length(table1.c.name)])
+        u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
         self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING length(mytable.name)", dialect=dialect)
 
         
@@ -51,17 +51,17 @@ class CompileTest(TestBase, AssertsCompiledSQL):
             column('description', String(128)),
         )
 
-        i = insert(table1, values=dict(name='foo'), postgresql_returning=[table1.c.myid, table1.c.name])
+        i = insert(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
         self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING mytable.myid, mytable.name", dialect=dialect)
 
-        i = insert(table1, values=dict(name='foo'), postgresql_returning=[table1])
+        i = insert(table1, values=dict(name='foo')).returning(table1)
         self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) "\
             "RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect)
 
-        i = insert(table1, values=dict(name='foo'), postgresql_returning=[func.length(table1.c.name)])
+        i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
         self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name)", dialect=dialect)
     
-    @testing.uses_deprecated(r".*'postgres_returning' argument has been renamed.*")
+    @testing.uses_deprecated(r".*argument is deprecated.  Please use statement.returning.*")
     def test_old_returning_names(self):
         dialect = postgresql.dialect()
         table1 = table('mytable',
@@ -73,6 +73,9 @@ class CompileTest(TestBase, AssertsCompiledSQL):
         u = update(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name])
         self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect)
 
+        u = update(table1, values=dict(name='foo'), postgresql_returning=[table1.c.myid, table1.c.name])
+        self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect)
+
         i = insert(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name])
         self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING mytable.myid, mytable.name", dialect=dialect)
         
@@ -100,60 +103,6 @@ class CompileTest(TestBase, AssertsCompiledSQL):
                 "SELECT EXTRACT(%s FROM t.col1::timestamp) AS anon_1 "
                 "FROM t" % field)
 
-class ReturningTest(TestBase, AssertsExecutionResults):
-    __only_on__ = 'postgresql'
-
-    @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
-    def test_update_returning(self):
-        meta = MetaData(testing.db)
-        table = Table('tables', meta,
-            Column('id', Integer, primary_key=True),
-            Column('persons', Integer),
-            Column('full', Boolean)
-        )
-        table.create()
-        try:
-            table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
-
-            result = table.update(table.c.persons > 4, dict(full=True), postgresql_returning=[table.c.id]).execute()
-            eq_(result.fetchall(), [(1,)])
-
-            result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
-            eq_(result2.fetchall(), [(1,True),(2,False)])
-        finally:
-            table.drop()
-
-    @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
-    def test_insert_returning(self):
-        meta = MetaData(testing.db)
-        table = Table('tables', meta,
-            Column('id', Integer, primary_key=True),
-            Column('persons', Integer),
-            Column('full', Boolean)
-        )
-        table.create()
-        try:
-            result = table.insert(postgresql_returning=[table.c.id]).execute({'persons': 1, 'full': False})
-
-            eq_(result.fetchall(), [(1,)])
-
-            @testing.fails_on('postgresql', 'Known limitation of psycopg2')
-            def test_executemany():
-                # return value is documented as failing with psycopg2/executemany
-                result2 = table.insert(postgresql_returning=[table]).execute(
-                     [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}])
-                eq_(result2.fetchall(), [(2, 2, False), (3,3,True)])
-            
-            test_executemany()
-            
-            result3 = table.insert(postgresql_returning=[(table.c.id*2).label('double_id')]).execute({'persons': 4, 'full': False})
-            eq_([dict(row) for row in result3], [{'double_id':8}])
-
-            result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, true) returning persons')
-            eq_([dict(row) for row in result4], [{'persons': 10}])
-        finally:
-            table.drop()
-
 
 class InsertTest(TestBase, AssertsExecutionResults):
     __only_on__ = 'postgresql'
diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py
new file mode 100644 (file)
index 0000000..ead61cd
--- /dev/null
@@ -0,0 +1,136 @@
+from sqlalchemy.test.testing import eq_
+from sqlalchemy import *
+from sqlalchemy.test import *
+from sqlalchemy.test.schema import Table, Column
+from sqlalchemy.types import TypeDecorator
+
+class ReturningTest(TestBase, AssertsExecutionResults):
+    __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access')
+
+    def setup(self):
+        meta = MetaData(testing.db)
+        global table, GoofyType
+        
+        class GoofyType(TypeDecorator):
+            impl = String
+            
+            def process_bind_param(self, value, dialect):
+                if value is None:
+                    return None
+                return "FOO" + value
+
+            def process_result_value(self, value, dialect):
+                if value is None:
+                    return None
+                return value + "BAR"
+            
+        table = Table('tables', meta,
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('persons', Integer),
+            Column('full', Boolean),
+            Column('goofy', GoofyType(50))
+        )
+        table.create()
+    
+    def teardown(self):
+        table.drop()
+
+    @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
+    @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+    def test_column_targeting(self):
+        result = table.insert().returning(table.c.id, table.c.full).execute({'persons': 1, 'full': False})
+        
+        row = result.first()
+        assert row[table.c.id] == row['id'] == 1
+        assert row[table.c.full] == row['full'] == False
+        
+        result = table.insert().values(persons=5, full=True, goofy="somegoofy").\
+                            returning(table.c.persons, table.c.full, table.c.goofy).execute()
+        row = result.first()
+        assert row[table.c.persons] == row['persons'] == 5
+        assert row[table.c.full] == row['full'] == True
+        assert row[table.c.goofy] == row['goofy'] == "FOOsomegoofyBAR"
+    
+    @testing.fails_on('firebird', "fb can't handle returning x AS y")
+    @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
+    @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+    def test_labeling(self):
+        result = table.insert().values(persons=6).\
+                            returning(table.c.persons.label('lala')).execute()
+        row = result.first()
+        assert row['lala'] == 6
+
+    @testing.fails_on('firebird', "fb/kintersbasdb can't handle the bind params")
+    @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
+    @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+    def test_anon_expressions(self):
+        result = table.insert().values(goofy="someOTHERgoofy").\
+                            returning(func.lower(table.c.goofy, type_=GoofyType)).execute()
+        row = result.first()
+        assert row[0] == "foosomeothergoofyBAR"
+
+        result = table.insert().values(persons=12).\
+                            returning(table.c.persons + 18).execute()
+        row = result.first()
+        assert row[0] == 30
+        
+    @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
+    @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+    def test_update_returning(self):
+        table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
+
+        result = table.update(table.c.persons > 4, dict(full=True)).returning(table.c.id).execute()
+        eq_(result.fetchall(), [(1,)])
+
+        result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
+        eq_(result2.fetchall(), [(1,True),(2,False)])
+
+    @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
+    @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+    def test_insert_returning(self):
+        result = table.insert().returning(table.c.id).execute({'persons': 1, 'full': False})
+
+        eq_(result.fetchall(), [(1,)])
+
+        @testing.fails_on('postgresql', '')
+        @testing.fails_on('oracle', '')
+        def test_executemany():
+            # return value is documented as failing with psycopg2/executemany
+            result2 = table.insert().returning(table).execute(
+                 [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}])
+            
+            if testing.against('firebird', 'mssql'):
+                # Multiple inserts only return the last row
+                eq_(result2.fetchall(), [(3,3,True, None)])
+            else:
+                # nobody does this as far as we know (pg8000?)
+                eq_(result2.fetchall(), [(2, 2, False, None), (3,3,True, None)])
+
+        test_executemany()
+
+        result3 = table.insert().returning(table.c.id).execute({'persons': 4, 'full': False})
+        eq_([dict(row) for row in result3], [{'id': 4}])
+
+    @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
+    @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+    @testing.fails_on_everything_except('postgresql', 'firebird')
+    def test_literal_returning(self):
+        if testing.against("postgresql"):
+            literal_true = "true"
+        else:
+            literal_true = "1"
+
+        result4 = testing.db.execute('insert into tables (id, persons, "full") '
+                                        'values (5, 10, %s) returning persons' % literal_true)
+        eq_([dict(row) for row in result4], [{'persons': 10}])
+
+    @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
+    @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+    def test_delete_returning(self):
+        table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
+
+        result = table.delete(table.c.persons > 4).returning(table.c.id).execute()
+        eq_(result.fetchall(), [(1,)])
+
+        result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
+        eq_(result2.fetchall(), [(2,False),])