git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- implicit returning support. insert() will use RETURNING to get at primary key...
author: Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 27 Jul 2009 02:09:54 +0000 (02:09 +0000)
committer: Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 27 Jul 2009 02:09:54 +0000 (02:09 +0000)
21 files changed:
06CHANGES
lib/sqlalchemy/dialects/firebird/base.py
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mssql/pyodbc.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/mysqldb.py
lib/sqlalchemy/dialects/mysql/pyodbc.py
lib/sqlalchemy/dialects/mysql/zxjdbc.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/sqlite/pysqlite.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
test/aaa_profiling/test_zoomark.py
test/dialect/test_postgresql.py
test/sql/test_defaults.py
test/sql/test_query.py
test/sql/test_returning.py
test/sql/test_selectable.py

index ea755655198cd574ab22393e982993d95e9980d1..141f834a26e48960c334f9e4afb3325cd2dec150 100644 (file)
--- a/06CHANGES
+++ b/06CHANGES
 - sql
     - returning() support is native to insert(), update(), delete().  Implementations
       of varying levels of functionality exist for Postgresql, Firebird, MSSQL and
-      Oracle.
+      Oracle.   returning() can be called explicitly with column expressions which
+      are then returned in the resultset, usually via fetchone() or first().
+      
+      insert() constructs will also use RETURNING implicitly to get newly
+      generated primary key values, if the database version in use supports it
+      (a version number check is performed).   This occurs if no end-user
+      returning() was specified.
+      
       
 - engines
     - transaction isolation level may be specified with
index 58fa19f50fd63b85007ef59de3f9c5117e1d788b..1a441eaa6e22ed540d5c933baf29d31a822a01b2 100644 (file)
@@ -253,8 +253,7 @@ class FBCompiler(sql.compiler.SQLCompiler):
 
         return ""
 
-    def returning_clause(self, stmt):
-        returning_cols = stmt._returning
+    def returning_clause(self, stmt, returning_cols):
 
         columns = [
                 self.process(
@@ -312,13 +311,15 @@ class FBDialect(default.DefaultDialect):
     name = 'firebird'
 
     max_identifier_length = 31
+    
     supports_sequences = True
     sequences_optional = False
     supports_default_values = True
-    supports_empty_insert = False
     preexecute_pk_sequences = True
-    supports_pk_autoincrement = False
+    postfetch_lastrowid = False
+    
     requires_name_normalize = True
+    supports_empty_insert = False
 
     statement_compiler = FBCompiler
     ddl_compiler = FBDDLCompiler
@@ -344,7 +345,9 @@ class FBDialect(default.DefaultDialect):
             self.colspecs = {
                 sqltypes.DateTime: sqltypes.DATE
             }
-
+        else:
+            self.implicit_returning = True
+            
     def normalize_name(self, name):
         # Remove trailing spaces: FB uses a CHAR() type,
         # that is padded with spaces
index 7db0d25899c0cfbe64ba7414cbd4ef712cb61b14..f1a7cd9aa40d9d90508f6d4cc4939460145a0581 100644 (file)
@@ -863,7 +863,7 @@ class MSExecutionContext(default.DefaultExecutionContext):
                 self._enable_identity_insert = False
             
             self._select_lastrowid = insert_has_sequence and \
-                                        not self.compiled.statement.returning and \
+                                        not self.compiled.rendered_returning and \
                                         not self._enable_identity_insert and \
                                         not self.executemany
             
@@ -880,14 +880,17 @@ class MSExecutionContext(default.DefaultExecutionContext):
             else:
                 self.cursor.execute("SELECT @@identity AS lastrowid")
             row = self.cursor.fetchall()[0]   # fetchall() ensures the cursor is consumed without closing it
-            self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:]
+            self._lastrowid = int(row[0])
             
         if self._enable_identity_insert:
             self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
 
         if (self.isinsert or self.isupdate or self.isdelete) and \
-                self.compiled.statement._returning:
+                self.compiled.rendered_returning:
             self._result_proxy = base.FullyBufferedResultProxy(self)
+    
+    def get_lastrowid(self):
+        return self._lastrowid
         
     def handle_dbapi_exception(self, e):
         if self._enable_identity_insert:
@@ -1034,8 +1037,7 @@ class MSSQLCompiler(compiler.SQLCompiler):
                 return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs)
             return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
 
-    def returning_clause(self, stmt):
-        returning_cols = stmt._returning
+    def returning_clause(self, stmt, returning_cols):
 
         if self.isinsert or self.isupdate:
             target = stmt.table.alias("inserted")
index 550f26e6762e032657f0c855fa895a6a974aa4b7..5c5d2171a996e23f6e2cdc8aab257e4aeeb56e30 100644 (file)
@@ -43,7 +43,7 @@ class MSExecutionContext_pyodbc(MSExecutionContext):
                     # so we need to just keep flipping
                     self.cursor.nextset()
                     
-            self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:]
+            self._lastrowid = int(row[0])
         else:
             super(MSExecutionContext_pyodbc, self).post_exec()
 
index b325f5ef58377d3e2a67ba858436e1254b4637a6..1c5c251e5441a7c7e1a8f47ab707dbe246f7fcba 100644 (file)
@@ -1182,20 +1182,15 @@ ischema_names = {
 
 class MySQLExecutionContext(default.DefaultExecutionContext):
     def post_exec(self):
-        if self.isinsert and not self.executemany:
-            if (not len(self._last_inserted_ids) or
-                self._last_inserted_ids[0] is None):
-                self._last_inserted_ids = ([self._lastrowid(self.cursor)] +
-                                           self._last_inserted_ids[1:])
-        elif (not self.isupdate and not self.should_autocommit and
+        # TODO: i think this 'charset' in the info thing 
+        # is out
+        
+        if (not self.isupdate and not self.should_autocommit and
               self.statement and SET_RE.match(self.statement)):
             # This misses if a user forces autocommit on text('SET NAMES'),
             # which is probably a programming error anyhow.
             self.connection.info.pop(('mysql', 'charset'), None)
 
-    def _lastrowid(self, cursor):
-        raise NotImplementedError()
-    
     def should_autocommit_text(self, statement):
         return AUTOCOMMIT_RE.match(statement)
 
index b5f7779843816581e37e22e245ad28d5b11123a3..6ecfc4b845c332afb62c52585255b8893031d04a 100644 (file)
@@ -30,9 +30,7 @@ from sqlalchemy.sql import operators as sql_operators
 from sqlalchemy import exc, log, schema, sql, types as sqltypes, util
 
 class MySQL_mysqldbExecutionContext(MySQLExecutionContext):
-    def _lastrowid(self, cursor):
-        return cursor.lastrowid
-
+    
     @property
     def rowcount(self):
         if hasattr(self, '_rowcount'):
index 4896173e45108b65ccc35e218fd5dfd4b44184a4..1ea7ec8646d96d168ca9590b39bf5fb3e26c5266 100644 (file)
@@ -5,7 +5,8 @@ from sqlalchemy import util
 import re
 
 class MySQL_pyodbcExecutionContext(MySQLExecutionContext):
-    def _lastrowid(self, cursor):
+
+    def get_lastrowid(self):
         cursor = self.create_cursor()
         cursor.execute("SELECT LAST_INSERT_ID()")
         lastrowid = cursor.fetchone()[0]
index b32b6fe2a1c39777f8b2193204a1d4a412bc66d6..81ad8379cd3b5864b373dc06ad54b5961d184a1b 100644 (file)
@@ -15,10 +15,8 @@ from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
 from sqlalchemy import types as sqltypes, util
 
 class MySQL_jdbcExecutionContext(MySQLExecutionContext):
-    def _real_lastrowid(self, cursor):
-        return cursor.lastrowid
-
-    def _lastrowid(self, cursor):
+    
+    def get_lastrowid(self):
         cursor = self.create_cursor()
         cursor.execute("SELECT LAST_INSERT_ID()")
         lastrowid = cursor.fetchone()[0]
index 882fc05c71988975ac9d495415546afef8e8dc7b..09d18cb19b5905ba51cd782230c56b336866ec4f 100644 (file)
@@ -318,8 +318,7 @@ class OracleCompiler(compiler.SQLCompiler):
         else:
             return self.process(alias.original, **kwargs)
 
-    def returning_clause(self, stmt):
-        returning_cols = stmt._returning
+    def returning_clause(self, stmt, returning_cols):
             
         def create_out_param(col, i):
             bindparam = sql.outparam("ret_%d" % i, type_=col.type)
@@ -473,10 +472,12 @@ class OracleDialect(default.DefaultDialect):
     max_identifier_length = 30
     supports_sane_rowcount = True
     supports_sane_multi_rowcount = False
+
     supports_sequences = True
     sequences_optional = False
     preexecute_pk_sequences = True
-    supports_pk_autoincrement = False
+    postfetch_lastrowid = False
+    
     default_paramstyle = 'named'
     colspecs = colspecs
     ischema_names = ischema_names
@@ -502,6 +503,12 @@ class OracleDialect(default.DefaultDialect):
         self.use_ansi = use_ansi
         self.optimize_limits = optimize_limits
 
+# TODO: implement server_version_info for oracle
+#    def initialize(self, connection):
+#        super(OracleDialect, self).initialize(connection)
+#        self.implicit_returning = self.server_version_info > (10, ) and \
+#                                        self.__dict__.get('implicit_returning', True)
+
     def has_table(self, connection, table_name, schema=None):
         if not schema:
             schema = self.get_default_schema_name(connection)
index 54e4d119e15b8dd2f9b9197940251f25361d178d..c007998ec41f15d463c4b9c5a6f463380509b08b 100644 (file)
@@ -216,7 +216,7 @@ class Oracle_cx_oracleExecutionContext(DefaultExecutionContext):
         
         if hasattr(self, 'out_parameters') and \
             (self.isinsert or self.isupdate or self.isdelete) and \
-                self.compiled.statement._returning:
+                self.compiled.rendered_returning:
                 
             return ReturningResultProxy(self)
         else:
@@ -226,7 +226,7 @@ class ReturningResultProxy(base.FullyBufferedResultProxy):
     """Result proxy which stuffs the _returning clause + outparams into the fetch."""
     
     def _cursor_description(self):
-        returning = self.context.compiled.statement._returning
+        returning = self.context.compiled.returning or self.context.compiled.statement._returning
         
         ret = []
         for c in returning:
@@ -237,7 +237,7 @@ class ReturningResultProxy(base.FullyBufferedResultProxy):
         return ret
     
     def _buffer_rows(self):
-        returning = self.context.compiled.statement._returning
+        returning = self.context.compiled.returning or self.context.compiled.statement._returning
         return [tuple(self.context.out_parameters["ret_%d" % i] for i, c in enumerate(returning))]
 
 class Oracle_cx_oracle(OracleDialect):
index 2b0ebf5f40e02a6dfbe03cab619afce046118513..a865f069bd9e68cef9ba7c63e62bd5abd8e3745e 100644 (file)
@@ -263,8 +263,7 @@ class PGCompiler(compiler.SQLCompiler):
         else:
             return super(PGCompiler, self).for_update_clause(select)
 
-    def returning_clause(self, stmt):
-        returning_cols = stmt._returning
+    def returning_clause(self, stmt, returning_cols):
         
         columns = [
                 self.process(
@@ -449,10 +448,13 @@ class PGDialect(default.DefaultDialect):
     supports_alter = True
     max_identifier_length = 63
     supports_sane_rowcount = True
+    
     supports_sequences = True
     sequences_optional = True
     preexecute_pk_sequences = True
-    supports_pk_autoincrement = False
+    preexecute_autoincrement_sequences = True
+    postfetch_lastrowid = False
+    
     supports_default_values = True
     supports_empty_insert = False
     default_paramstyle = 'pyformat'
@@ -471,6 +473,11 @@ class PGDialect(default.DefaultDialect):
         default.DefaultDialect.__init__(self, **kwargs)
         self.isolation_level = isolation_level
 
+    def initialize(self, connection):
+        super(PGDialect, self).initialize(connection)
+        self.implicit_returning = self.server_version_info > (8, 3) and \
+                                        self.__dict__.get('implicit_returning', True)
+        
     def visit_pool(self, pool):
         if self.isolation_level is not None:
             class SetIsolationLevel(object):
index 0eb4e9ede57bdf2e4609ba7e068038230983236b..a1873f33a85bb1cfbfee08f6c6e227ab4a28630c 100644 (file)
@@ -110,17 +110,9 @@ from sqlalchemy.engine import default
 from sqlalchemy import types as sqltypes
 from sqlalchemy import util
 
-class SQLite_pysqliteExecutionContext(default.DefaultExecutionContext):
-    def post_exec(self):
-        if self.isinsert and not self.executemany:
-            if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
-                self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
-
-
 class SQLite_pysqlite(SQLiteDialect):
     default_paramstyle = 'qmark'
     poolclass = pool.SingletonThreadPool
-    execution_ctx_cls = SQLite_pysqliteExecutionContext
     
     # Py3K
     #description_encoding = None
index 7fff18d0239fbea58dcd886d32cd7805875555be..ba3880ca2b988fa285810782b8324a0b35bbcd3a 100644 (file)
@@ -105,14 +105,25 @@ class Dialect(object):
       executemany.
 
     preexecute_pk_sequences
-      Indicate if the dialect should pre-execute sequences on primary
-      key columns during an INSERT, if it's desired that the new row's
-      primary key be available after execution.
-
-    supports_pk_autoincrement
-      Indicates if the dialect should allow the database to passively assign
-      a primary key column value.
-
+      Indicate if the dialect should pre-execute sequences or default
+      generation functions on primary key columns during an INSERT, if 
+      it's desired that the new row's primary key be available after execution.
+      Pre-execution is disabled if the database supports "returning"
+      and "implicit_returning" is True.
+
+    preexecute_autoincrement_sequences
+      True if 'implicit' primary key functions must be executed separately
+      in order to get their value.   This is currently oriented towards
+      Postgresql.
+      
+    implicit_returning
+      use RETURNING or equivalent during INSERT execution in order to load 
+      newly generated primary keys and other column defaults in one execution,
+      which are then available via last_inserted_ids().
+      If an insert statement has returning() specified explicitly, 
+      the "implicit" functionality is not used and last_inserted_ids()
+      will not be available.
+      
     dbapi_type_map
       A mapping of DB-API type objects present in this Dialect's
       DB-API implementation mapped to TypeEngine implementations used
@@ -1069,11 +1080,14 @@ class Connection(Connectable):
             self._cursor_execute(context.cursor, context.statement, context.parameters[0], context=context)
         if context.compiled:
             context.post_exec()
+            if context.isinsert and not context.executemany:
+                context.post_insert()
             
         if context.should_autocommit and not self.in_transaction():
             self._commit_impl()
+            
         return context.get_result_proxy()
-
+        
     def _handle_dbapi_exception(self, e, statement, parameters, cursor, context):
         if getattr(self, '_reentrant_error', False):
             # Py3K
@@ -1608,6 +1622,18 @@ class ResultProxy(object):
 
     @property
     def lastrowid(self):
+        """return the 'lastrowid' accessor on the DBAPI cursor.
+        
+        This is a DBAPI specific method and is only functional
+        for those backends which support it, for statements
+        where it is appropriate.
+        
+        Usage of this method is normally unnecessary; the
+        last_inserted_ids() method provides a
+        tuple of primary key values for a newly inserted row,
+        regardless of database backend.
+        
+        """
         return self.cursor.lastrowid
 
     @property
@@ -1751,8 +1777,7 @@ class ResultProxy(object):
 
         See ExecutionContext for details.
         """
-
-        return self.context.last_inserted_ids()
+        return self.context.last_inserted_ids(self)
 
     def last_updated_params(self):
         """Return ``last_updated_params()`` from the underlying ExecutionContext.
index 3f90baa5c2d5f15eefc98597bd236c75660f5e72..14e73e58862eea702b28ae55354a178595c1304c 100644 (file)
@@ -30,9 +30,14 @@ class DefaultDialect(base.Dialect):
     preparer = compiler.IdentifierPreparer
     defaultrunner = base.DefaultRunner
     supports_alter = True
+
     supports_sequences = False
     sequences_optional = False
-
+    preexecute_pk_sequences = False
+    preexecute_autoincrement_sequences = False
+    postfetch_lastrowid = True
+    implicit_returning = False
+    
     # Py3K
     #supports_unicode_statements = True
     #supports_unicode_binds = True
@@ -45,8 +50,6 @@ class DefaultDialect(base.Dialect):
     max_identifier_length = 9999
     supports_sane_rowcount = True
     supports_sane_multi_rowcount = True
-    preexecute_pk_sequences = False
-    supports_pk_autoincrement = True
     dbapi_type_map = {}
     default_paramstyle = 'named'
     supports_default_values = False
@@ -63,6 +66,7 @@ class DefaultDialect(base.Dialect):
 
     def __init__(self, convert_unicode=False, assert_unicode=False,
                  encoding='utf-8', paramstyle=None, dbapi=None,
+                 implicit_returning=None,
                  label_length=None, **kwargs):
         self.convert_unicode = convert_unicode
         self.assert_unicode = assert_unicode
@@ -76,6 +80,8 @@ class DefaultDialect(base.Dialect):
             self.paramstyle = self.dbapi.paramstyle
         else:
             self.paramstyle = self.default_paramstyle
+        if implicit_returning is not None:
+            self.implicit_returning = implicit_returning
         self.positional = self.paramstyle in ('qmark', 'format', 'numeric')
         self.identifier_preparer = self.preparer(self)
         self.type_compiler = self.type_compiler(self)
@@ -176,6 +182,8 @@ class DefaultDialect(base.Dialect):
 
 
 class DefaultExecutionContext(base.ExecutionContext):
+    _lastrowid = None
+    
     def __init__(self, dialect, connection, compiled_sql=None, compiled_ddl=None, statement=None, parameters=None):
         self.dialect = dialect
         self._connection = self.root_connection = connection
@@ -329,6 +337,35 @@ class DefaultExecutionContext(base.ExecutionContext):
 
     def post_exec(self):
         pass
+    
+    def get_lastrowid(self):
+        """return self.cursor.lastrowid, or equivalent, after an INSERT.
+        
+        This may involve calling special cursor functions,
+        issuing a new SELECT on the cursor (or a new one),
+        or returning a stored value that was
+        calculated within post_exec().
+        
+        This function will only be called for dialects
+        which support "implicit" primary key generation,
+        keep preexecute_autoincrement_sequences set to False,
+        and when no explicit id value was bound to the
+        statement.
+        
+        The function is called once, directly after 
+        post_exec() and before the transaction is committed
+        or ResultProxy is generated.   If the post_exec()
+        method assigns a value to `self._lastrowid`, the
+        value is used in place of calling get_lastrowid().
+        
+        Note that this method is *not* equivalent to the
+        ``lastrowid`` method on ``ResultProxy``, which is a
+        direct proxy to the DBAPI ``lastrowid`` accessor
+        in all cases.
+        
+        """
+        
+        return self.cursor.lastrowid
 
     def handle_dbapi_exception(self, e):
         pass
@@ -345,9 +382,34 @@ class DefaultExecutionContext(base.ExecutionContext):
 
     def supports_sane_multi_rowcount(self):
         return self.dialect.supports_sane_multi_rowcount
-
-    def last_inserted_ids(self):
-        return self._last_inserted_ids
+    
+    def post_insert(self):
+        if self.dialect.postfetch_lastrowid and \
+            self._lastrowid is None and \
+            (not len(self._last_inserted_ids) or \
+                        self._last_inserted_ids[0] is None):
+            
+            self._lastrowid = self.get_lastrowid()
+        
+    def last_inserted_ids(self, resultproxy):
+        if not self.isinsert:
+            raise exc.InvalidRequestError("Statement is not an insert() expression construct.")
+            
+        if self.dialect.implicit_returning and \
+                not self.compiled.statement._returning and \
+                not resultproxy.closed:
+
+            row = resultproxy.first()
+
+            self._last_inserted_ids = [v is not None and v or row[c] 
+                for c, v in zip(self.compiled.statement.table.primary_key, self._last_inserted_ids)
+            ]
+            return self._last_inserted_ids
+            
+        elif self._lastrowid is not None:
+            return [self._lastrowid] + self._last_inserted_ids[1:]
+        else:
+            return self._last_inserted_ids
 
     def last_inserted_params(self):
         return self._last_inserted_params
index 66dc84f194c5a9a4aee583f820590aa23ddc413d..79ed44e1ba1f5feed62a6e79e42cffbedf47437b 100644 (file)
@@ -161,7 +161,8 @@ class SQLCompiler(engine.Compiled):
     # level to define if this Compiled instance represents
     # INSERT/UPDATE/DELETE
     isdelete = isinsert = isupdate = False
-
+    rendered_returning = False
+    
     def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
         """Construct a new ``DefaultCompiler`` object.
 
@@ -696,9 +697,10 @@ class SQLCompiler(engine.Compiled):
             text += " (%s)" % ', '.join([preparer.format_column(c[0])
                        for c in colparams])
 
-        if insert_stmt._returning:
-            returning_clause = self.returning_clause(insert_stmt)
-
+        if self.returning or insert_stmt._returning:
+            returning_clause = self.returning_clause(insert_stmt, self.returning or insert_stmt._returning)
+            self.rendered_returning = True
+            
             # cheating
             if returning_clause.startswith("OUTPUT"):
                 text += " " + returning_clause
@@ -708,7 +710,7 @@ class SQLCompiler(engine.Compiled):
             text += " VALUES (%s)" % \
                      ', '.join([c[1] for c in colparams])
         
-        if insert_stmt._returning and returning_clause:
+        if (self.returning or insert_stmt._returning) and returning_clause:
             text += " " + returning_clause
         
         return text
@@ -728,7 +730,8 @@ class SQLCompiler(engine.Compiled):
                 )
 
         if update_stmt._returning:
-            returning_clause = self.returning_clause(update_stmt)
+            returning_clause = self.returning_clause(update_stmt, update_stmt._returning)
+            self.rendered_returning = True
             if returning_clause.startswith("OUTPUT"):
                 text += " " + returning_clause
                 returning_clause = None
@@ -756,7 +759,8 @@ class SQLCompiler(engine.Compiled):
 
         self.postfetch = []
         self.prefetch = []
-
+        self.returning = []
+        
         # no parameters in the statement, no parameters in the
         # compiled params - return binds for all columns
         if self.column_keys is None and stmt.parameters is None:
@@ -785,19 +789,43 @@ class SQLCompiler(engine.Compiled):
                     self.postfetch.append(c)
                     value = self.process(value.self_group())
                 values.append((c, value))
+
             elif isinstance(c, schema.Column):
                 if self.isinsert:
-                    if (c.primary_key and self.dialect.preexecute_pk_sequences and not self.inline):
-                        if (((isinstance(c.default, schema.Sequence) and
-                              not c.default.optional) or
-                             not self.dialect.supports_pk_autoincrement) or
-                            (c.default is not None and
-                             not isinstance(c.default, schema.Sequence))):
-                            values.append((c, create_bind_param(c, None)))
-                            self.prefetch.append(c)
+                    if c.primary_key and \
+                        (
+                            self.dialect.preexecute_pk_sequences or 
+                            self.dialect.implicit_returning
+                        ) and \
+                        not self.inline and \
+                        not self.statement._returning:
+
+                        if self.dialect.implicit_returning:
+                            if isinstance(c.default, schema.Sequence):
+                                proc = self.process(c.default)
+                                if proc is not None:
+                                    values.append((c, proc))
+                                self.returning.append(c)
+                            elif isinstance(c.default, schema.ColumnDefault) and \
+                                        isinstance(c.default.arg, sql.ClauseElement):
+                                values.append((c, self.process(c.default.arg.self_group())))
+                                self.returning.append(c)
+                            elif c.default is not None:
+                                values.append((c, create_bind_param(c, None)))
+                                self.prefetch.append(c)
+                            else:
+                                self.returning.append(c)
+                        else:
+                            if c.default is not None or \
+                                self.dialect.preexecute_autoincrement_sequences:
+
+                                values.append((c, create_bind_param(c, None)))
+                                self.prefetch.append(c)
+                                
                     elif isinstance(c.default, schema.ColumnDefault):
                         if isinstance(c.default.arg, sql.ClauseElement):
                             values.append((c, self.process(c.default.arg.self_group())))
+                            
                             if not c.primary_key:
                                 # dont add primary key column to postfetch
                                 self.postfetch.append(c)
@@ -835,7 +863,9 @@ class SQLCompiler(engine.Compiled):
         text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
 
         if delete_stmt._returning:
-            returning_clause = self.returning_clause(delete_stmt)
+            returning_clause = self.returning_clause(delete_stmt, delete_stmt._returning)
+            self.rendered_returning = True
+            self.returning = delete_stmt._returning
             if returning_clause.startswith("OUTPUT"):
                 text += " " + returning_clause
                 returning_clause = None
index 83fe2b78d3afaa1e3dc70730f94200e8d364a8ae..a1f6277df87545b9c950f89a6c5419171a600ce5 100644 (file)
@@ -327,7 +327,7 @@ class ZooMarkTest(TestBase):
     def test_profile_1a_populate(self):
         self.test_baseline_1a_populate()
 
-    @profiling.function_call_count(322, {'2.4': 202})
+    @profiling.function_call_count(305, {'2.4': 202})
     def test_profile_2_insert(self):
         self.test_baseline_2_insert()
 
index 33d3eda20c583cfc45c27c672603f8df7a09ee41..3e3884ebe6c72c3cd2b48f6220c7a8287c9d672e 100644 (file)
@@ -1,4 +1,5 @@
 from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
+from sqlalchemy.test import  engines
 import datetime
 from sqlalchemy import *
 from sqlalchemy.orm import *
@@ -110,11 +111,14 @@ class InsertTest(TestBase, AssertsExecutionResults):
     @classmethod
     def setup_class(cls):
         global metadata
+        cls.engine= testing.db
         metadata = MetaData(testing.db)
 
     def teardown(self):
         metadata.drop_all()
         metadata.tables.clear()
+        if self.engine is not testing.db:
+            self.engine.dispose()
 
     def test_compiled_insert(self):
         table = Table('testtable', metadata,
@@ -134,6 +138,13 @@ class InsertTest(TestBase, AssertsExecutionResults):
         metadata.create_all()
         self._assert_data_with_sequence(table, "my_seq")
 
+    def test_sequence_returning_insert(self):
+        table = Table('testtable', metadata,
+            Column('id', Integer, Sequence('my_seq'), primary_key=True),
+            Column('data', String(30)))
+        metadata.create_all()
+        self._assert_data_with_sequence_returning(table, "my_seq")
+
     def test_opt_sequence_insert(self):
         table = Table('testtable', metadata,
             Column('id', Integer, Sequence('my_seq', optional=True), primary_key=True),
@@ -141,6 +152,13 @@ class InsertTest(TestBase, AssertsExecutionResults):
         metadata.create_all()
         self._assert_data_autoincrement(table)
 
+    def test_opt_sequence_returning_insert(self):
+        table = Table('testtable', metadata,
+            Column('id', Integer, Sequence('my_seq', optional=True), primary_key=True),
+            Column('data', String(30)))
+        metadata.create_all()
+        self._assert_data_autoincrement_returning(table)
+
     def test_autoincrement_insert(self):
         table = Table('testtable', metadata,
             Column('id', Integer, primary_key=True),
@@ -148,6 +166,13 @@ class InsertTest(TestBase, AssertsExecutionResults):
         metadata.create_all()
         self._assert_data_autoincrement(table)
 
+    def test_autoincrement_returning_insert(self):
+        table = Table('testtable', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('data', String(30)))
+        metadata.create_all()
+        self._assert_data_autoincrement_returning(table)
+
     def test_noautoincrement_insert(self):
         table = Table('testtable', metadata,
             Column('id', Integer, primary_key=True, autoincrement=False),
@@ -156,6 +181,9 @@ class InsertTest(TestBase, AssertsExecutionResults):
         self._assert_data_noautoincrement(table)
 
     def _assert_data_autoincrement(self, table):
+        self.engine = engines.testing_engine(options={'implicit_returning':False})
+        metadata.bind = self.engine
+
         def go():
             # execute with explicit id
             r = table.insert().execute({'id':30, 'data':'d1'})
@@ -180,7 +208,7 @@ class InsertTest(TestBase, AssertsExecutionResults):
         # note that the test framework doesnt capture the "preexecute" of a seqeuence
         # or default.  we just see it in the bind params.
 
-        self.assert_sql(testing.db, go, [], with_sequences=[
+        self.assert_sql(self.engine, go, [], with_sequences=[
             (
                 "INSERT INTO testtable (id, data) VALUES (:id, :data)",
                 {'id':30, 'data':'d1'}
@@ -221,7 +249,7 @@ class InsertTest(TestBase, AssertsExecutionResults):
 
         # test the same series of events using a reflected
         # version of the table
-        m2 = MetaData(testing.db)
+        m2 = MetaData(self.engine)
         table = Table(table.name, m2, autoload=True)
 
         def go():
@@ -233,7 +261,7 @@ class InsertTest(TestBase, AssertsExecutionResults):
             table.insert(inline=True).execute({'id':33, 'data':'d7'})
             table.insert(inline=True).execute({'data':'d8'})
 
-        self.assert_sql(testing.db, go, [], with_sequences=[
+        self.assert_sql(self.engine, go, [], with_sequences=[
             (
                 "INSERT INTO testtable (id, data) VALUES (:id, :data)",
                 {'id':30, 'data':'d1'}
@@ -272,7 +300,127 @@ class InsertTest(TestBase, AssertsExecutionResults):
         ]
         table.delete().execute()
 
+    def _assert_data_autoincrement_returning(self, table):
+        self.engine = engines.testing_engine(options={'implicit_returning':True})
+        metadata.bind = self.engine
+
+        def go():
+            # execute with explicit id
+            r = table.insert().execute({'id':30, 'data':'d1'})
+            assert r.last_inserted_ids() == [30]
+
+            # execute with prefetch id
+            r = table.insert().execute({'data':'d2'})
+            assert r.last_inserted_ids() == [1]
+
+            # executemany with explicit ids
+            table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'})
+
+            # executemany, uses SERIAL
+            table.insert().execute({'data':'d5'}, {'data':'d6'})
+
+            # single execute, explicit id, inline
+            table.insert(inline=True).execute({'id':33, 'data':'d7'})
+
+            # single execute, inline, uses SERIAL
+            table.insert(inline=True).execute({'data':'d8'})
+        
+        self.assert_sql(self.engine, go, [], with_sequences=[
+            (
+                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                {'id':30, 'data':'d1'}
+            ),
+            (
+                "INSERT INTO testtable (data) VALUES (:data) RETURNING testtable.id",
+                {'data': 'd2'}
+            ),
+            (
+                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                [{'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}]
+            ),
+            (
+                "INSERT INTO testtable (data) VALUES (:data)",
+                [{'data':'d5'}, {'data':'d6'}]
+            ),
+            (
+                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                [{'id':33, 'data':'d7'}]
+            ),
+            (
+                "INSERT INTO testtable (data) VALUES (:data)",
+                [{'data':'d8'}]
+            ),
+        ])
+
+        assert table.select().execute().fetchall() == [
+            (30, 'd1'),
+            (1, 'd2'),
+            (31, 'd3'),
+            (32, 'd4'),
+            (2, 'd5'),
+            (3, 'd6'),
+            (33, 'd7'),
+            (4, 'd8'),
+        ]
+        table.delete().execute()
+
+        # test the same series of events using a reflected
+        # version of the table
+        m2 = MetaData(self.engine)
+        table = Table(table.name, m2, autoload=True)
+
+        def go():
+            table.insert().execute({'id':30, 'data':'d1'})
+            r = table.insert().execute({'data':'d2'})
+            assert r.last_inserted_ids() == [5]
+            table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'})
+            table.insert().execute({'data':'d5'}, {'data':'d6'})
+            table.insert(inline=True).execute({'id':33, 'data':'d7'})
+            table.insert(inline=True).execute({'data':'d8'})
+
+        self.assert_sql(self.engine, go, [], with_sequences=[
+            (
+                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                {'id':30, 'data':'d1'}
+            ),
+            (
+                "INSERT INTO testtable (data) VALUES (:data) RETURNING testtable.id",
+                {'data':'d2'}
+            ),
+            (
+                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                [{'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}]
+            ),
+            (
+                "INSERT INTO testtable (data) VALUES (:data)",
+                [{'data':'d5'}, {'data':'d6'}]
+            ),
+            (
+                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                [{'id':33, 'data':'d7'}]
+            ),
+            (
+                "INSERT INTO testtable (data) VALUES (:data)",
+                [{'data':'d8'}]
+            ),
+        ])
+
+        assert table.select().execute().fetchall() == [
+            (30, 'd1'),
+            (5, 'd2'),
+            (31, 'd3'),
+            (32, 'd4'),
+            (6, 'd5'),
+            (7, 'd6'),
+            (33, 'd7'),
+            (8, 'd8'),
+        ]
+        table.delete().execute()
+
     def _assert_data_with_sequence(self, table, seqname):
+        self.engine = engines.testing_engine(options={'implicit_returning':False})
+        metadata.bind = self.engine
+
         def go():
             table.insert().execute({'id':30, 'data':'d1'})
             table.insert().execute({'data':'d2'})
@@ -281,7 +429,7 @@ class InsertTest(TestBase, AssertsExecutionResults):
             table.insert(inline=True).execute({'id':33, 'data':'d7'})
             table.insert(inline=True).execute({'data':'d8'})
 
-        self.assert_sql(testing.db, go, [], with_sequences=[
+        self.assert_sql(self.engine, go, [], with_sequences=[
             (
                 "INSERT INTO testtable (id, data) VALUES (:id, :data)",
                 {'id':30, 'data':'d1'}
@@ -322,10 +470,66 @@ class InsertTest(TestBase, AssertsExecutionResults):
         # can't test reflection here since the Sequence must be
         # explicitly specified
 
+    def _assert_data_with_sequence_returning(self, table, seqname):
+        self.engine = engines.testing_engine(options={'implicit_returning':True})
+        metadata.bind = self.engine
+
+        def go():
+            table.insert().execute({'id':30, 'data':'d1'})
+            table.insert().execute({'data':'d2'})
+            table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'})
+            table.insert().execute({'data':'d5'}, {'data':'d6'})
+            table.insert(inline=True).execute({'id':33, 'data':'d7'})
+            table.insert(inline=True).execute({'data':'d8'})
+
+        self.assert_sql(self.engine, go, [], with_sequences=[
+            (
+                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                {'id':30, 'data':'d1'}
+            ),
+            (
+                "INSERT INTO testtable (id, data) VALUES (nextval('my_seq'), :data) RETURNING testtable.id",
+                {'data':'d2'}
+            ),
+            (
+                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                [{'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}]
+            ),
+            (
+                "INSERT INTO testtable (id, data) VALUES (nextval('%s'), :data)" % seqname,
+                [{'data':'d5'}, {'data':'d6'}]
+            ),
+            (
+                "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+                [{'id':33, 'data':'d7'}]
+            ),
+            (
+                "INSERT INTO testtable (id, data) VALUES (nextval('%s'), :data)" % seqname,
+                [{'data':'d8'}]
+            ),
+        ])
+
+        assert table.select().execute().fetchall() == [
+            (30, 'd1'),
+            (1, 'd2'),
+            (31, 'd3'),
+            (32, 'd4'),
+            (2, 'd5'),
+            (3, 'd6'),
+            (33, 'd7'),
+            (4, 'd8'),
+        ]
+
+        # can't test reflection here since the Sequence must be
+        # explicitly specified
+
     def _assert_data_noautoincrement(self, table):
+        self.engine = engines.testing_engine(options={'implicit_returning':False})
+        metadata.bind = self.engine
+
         table.insert().execute({'id':30, 'data':'d1'})
         
-        if testing.db.driver == 'pg8000':
+        if self.engine.driver == 'pg8000':
             exception_cls = exc.ProgrammingError
         else:
             exception_cls = exc.IntegrityError
@@ -350,7 +554,7 @@ class InsertTest(TestBase, AssertsExecutionResults):
 
         # test the same series of events using a reflected
         # version of the table
-        m2 = MetaData(testing.db)
+        m2 = MetaData(self.engine)
         table = Table(table.name, m2, autoload=True)
         table.insert().execute({'id':30, 'data':'d1'})
 
index 3f1c1c10d56651bb10c159a713c01b65317f02d8..f2bc5a53b4b4dc1b8222426422750214ad42c658 100644 (file)
@@ -3,7 +3,7 @@ import datetime
 from sqlalchemy import Sequence, Column, func
 from sqlalchemy.sql import select, text
 import sqlalchemy as sa
-from sqlalchemy.test import testing
+from sqlalchemy.test import testing, engines
 from sqlalchemy import MetaData, Integer, String, ForeignKey, Boolean
 from sqlalchemy.test.schema import Table
 from sqlalchemy.test.testing import eq_
@@ -540,16 +540,17 @@ class SequenceTest(testing.TestBase):
     def testseqnonpk(self):
         """test sequences fire off as defaults on non-pk columns"""
 
-        result = sometable.insert().execute(name="somename")
+        engine = engines.testing_engine(options={'implicit_returning':False})
+        result = engine.execute(sometable.insert(), name="somename")
         assert 'id' in result.postfetch_cols()
 
-        result = sometable.insert().execute(name="someother")
+        result = engine.execute(sometable.insert(), name="someother")
         assert 'id' in result.postfetch_cols()
 
         sometable.insert().execute(
             {'name':'name3'},
             {'name':'name4'})
-        eq_(sometable.select().execute().fetchall(),
+        eq_(sometable.select().order_by(sometable.c.id).execute().fetchall(),
             [(1, "somename", 1),
              (2, "someother", 2),
              (3, "name3", 3),
index bbc399aa6ba38846130ea35ce7d4a0fb90234594..37030c94f4577ba099fec907429da10e6d937b67 100644 (file)
@@ -5,6 +5,7 @@ from sqlalchemy import exc, sql
 from sqlalchemy.engine import default
 from sqlalchemy.test import *
 from sqlalchemy.test.testing import eq_, assert_raises_message
+from sqlalchemy.test.schema import Table, Column
 
 class QueryTest(TestBase):
 
@@ -13,11 +14,11 @@ class QueryTest(TestBase):
         global users, users2, addresses, metadata
         metadata = MetaData(testing.db)
         users = Table('query_users', metadata,
-            Column('user_id', INT, Sequence('user_id_seq', optional=True), primary_key = True),
+            Column('user_id', INT, primary_key=True, test_needs_autoincrement=True),
             Column('user_name', VARCHAR(20)),
         )
         addresses = Table('query_addresses', metadata,
-            Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key=True),
+            Column('address_id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('user_id', Integer, ForeignKey('query_users.user_id')),
             Column('address', String(30)))
             
@@ -59,14 +60,14 @@ class QueryTest(TestBase):
     def test_lastrow_accessor(self):
         """Tests the last_inserted_ids() and lastrow_has_id() functions."""
 
-        def insert_values(table, values):
+        def insert_values(engine, table, values):
             """
             Inserts a row into a table, returns the full list of values
             INSERTed including defaults that fired off on the DB side and
             detects rows that had defaults and post-fetches.
             """
 
-            result = table.insert().execute(**values)
+            result = engine.execute(table.insert(), **values)
             ret = values.copy()
             
             for col, id in zip(table.primary_key, result.last_inserted_ids()):
@@ -74,68 +75,78 @@ class QueryTest(TestBase):
 
             if result.lastrow_has_defaults():
                 criterion = and_(*[col==id for col, id in zip(table.primary_key, result.last_inserted_ids())])
-                row = table.select(criterion).execute().first()
+                row = engine.execute(table.select(criterion)).first()
                 for c in table.c:
                     ret[c.key] = row[c]
             return ret
 
-        for supported, table, values, assertvalues in [
-            (
-                {'unsupported':['sqlite']},
-                Table("t1", metadata,
-                    Column('id', Integer, Sequence('t1_id_seq', optional=True), primary_key=True),
-                    Column('foo', String(30), primary_key=True)),
-                {'foo':'hi'},
-                {'id':1, 'foo':'hi'}
-            ),
-            (
-                {'unsupported':['sqlite']},
-                Table("t2", metadata,
-                    Column('id', Integer, Sequence('t2_id_seq', optional=True), primary_key=True),
-                    Column('foo', String(30), primary_key=True),
-                    Column('bar', String(30), server_default='hi')
+        if testing.against('firebird', 'postgres', 'oracle', 'mssql'):
+            test_engines = [
+                engines.testing_engine(options={'implicit_returning':False}),
+                engines.testing_engine(options={'implicit_returning':True}),
+            ]
+        else:
+            test_engines = [testing.db]
+            
+        for engine in test_engines:
+            metadata = MetaData()
+            for supported, table, values, assertvalues in [
+                (
+                    {'unsupported':['sqlite']},
+                    Table("t1", metadata,
+                        Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+                        Column('foo', String(30), primary_key=True)),
+                    {'foo':'hi'},
+                    {'id':1, 'foo':'hi'}
                 ),
-                {'foo':'hi'},
-                {'id':1, 'foo':'hi', 'bar':'hi'}
-            ),
-            (
-                {'unsupported':[]},
-                Table("t3", metadata,
-                    Column("id", String(40), primary_key=True),
-                    Column('foo', String(30), primary_key=True),
-                    Column("bar", String(30))
+                (
+                    {'unsupported':['sqlite']},
+                    Table("t2", metadata,
+                        Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+                        Column('foo', String(30), primary_key=True),
+                        Column('bar', String(30), server_default='hi')
                     ),
-                    {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"},
-                    {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"}
-            ),
-            (
-                {'unsupported':[]},
-                Table("t4", metadata,
-                    Column('id', Integer, Sequence('t4_id_seq', optional=True), primary_key=True),
-                    Column('foo', String(30), primary_key=True),
-                    Column('bar', String(30), server_default='hi')
+                    {'foo':'hi'},
+                    {'id':1, 'foo':'hi', 'bar':'hi'}
                 ),
-                {'foo':'hi', 'id':1},
-                {'id':1, 'foo':'hi', 'bar':'hi'}
-            ),
-            (
-                {'unsupported':[]},
-                Table("t5", metadata,
-                    Column('id', String(10), primary_key=True),
-                    Column('bar', String(30), server_default='hi')
+                (
+                    {'unsupported':[]},
+                    Table("t3", metadata,
+                        Column("id", String(40), primary_key=True),
+                        Column('foo', String(30), primary_key=True),
+                        Column("bar", String(30))
+                        ),
+                        {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"},
+                        {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"}
                 ),
-                {'id':'id1'},
-                {'id':'id1', 'bar':'hi'},
-            ),
-        ]:
-            if testing.db.name in supported['unsupported']:
-                continue
-            try:
-                table.create()
-                i = insert_values(table, values)
-                assert i == assertvalues, repr(i) + " " + repr(assertvalues)
-            finally:
-                table.drop()
+                (
+                    {'unsupported':[]},
+                    Table("t4", metadata,
+                        Column('id', Integer, Sequence('t4_id_seq', optional=True), primary_key=True),
+                        Column('foo', String(30), primary_key=True),
+                        Column('bar', String(30), server_default='hi')
+                    ),
+                    {'foo':'hi', 'id':1},
+                    {'id':1, 'foo':'hi', 'bar':'hi'}
+                ),
+                (
+                    {'unsupported':[]},
+                    Table("t5", metadata,
+                        Column('id', String(10), primary_key=True),
+                        Column('bar', String(30), server_default='hi')
+                    ),
+                    {'id':'id1'},
+                    {'id':'id1', 'bar':'hi'},
+                ),
+            ]:
+                if testing.db.name in supported['unsupported']:
+                    continue
+                try:
+                    table.create(bind=engine, checkfirst=True)
+                    i = insert_values(engine, table, values)
+                    assert i == assertvalues, "tablename: %s %r %r" % (table.name, repr(i), repr(assertvalues))
+                finally:
+                    table.drop(bind=engine)
 
     def test_row_iteration(self):
         users.insert().execute(
index ead61cd418b1e7576040c34cda2f2bdf2b87303c..04cfa4be8f820eb01e2d8d5b359c7eec5e34a61c 100644 (file)
@@ -4,6 +4,7 @@ from sqlalchemy.test import *
 from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.types import TypeDecorator
 
+        
 class ReturningTest(TestBase, AssertsExecutionResults):
     __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access')
 
@@ -30,7 +31,7 @@ class ReturningTest(TestBase, AssertsExecutionResults):
             Column('full', Boolean),
             Column('goofy', GoofyType(50))
         )
-        table.create()
+        table.create(checkfirst=True)
     
     def teardown(self):
         table.drop()
@@ -134,3 +135,24 @@ class ReturningTest(TestBase, AssertsExecutionResults):
 
         result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
         eq_(result2.fetchall(), [(2,False),])
+
+class SequenceReturningTest(TestBase):
+    __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access', 'mssql')
+
+    def setup(self):
+        meta = MetaData(testing.db)
+        global table, seq
+        seq = Sequence('tid_seq')
+        table = Table('tables', meta,
+                    Column('id', Integer, seq, primary_key=True),
+                    Column('data', String(50))
+                )
+        table.create(checkfirst=True)
+
+    def teardown(self):
+        table.drop()
+
+    def test_insert(self):
+        r = table.insert().values(data='hi').returning(table.c.id).execute()
+        assert r.first() == (1, )
+        assert seq.execute() == 2
index a172eb45230596d3f7670529ba89133c8b8cc8b2..670ae1fd0c410f5628ac1eff97b0ac7e60a46dc4 100644 (file)
@@ -422,7 +422,7 @@ class ReduceTest(TestBase, AssertsExecutionResults):
                 'm': page_table.join(magazine_page_table),
                 'c': page_table.join(magazine_page_table).join(classified_page_table),
             }, None, 'page_join')
-            
+        
         eq_(
             util.column_set(sql_util.reduce_columns([pjoin.c.id, pjoin.c.page_id, pjoin.c.magazine_page_id])),
             util.column_set([pjoin.c.id])