From: Mike Bayer Date: Mon, 27 Jul 2009 02:09:54 +0000 (+0000) Subject: - implicit returning support. insert() will use RETURNING to get at primary key... X-Git-Tag: rel_0_6_6~63 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2ef368aa54dfabb871d17f0d4ac64d698182e95b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - implicit returning support. insert() will use RETURNING to get at primary key values generated via sequence or default if the dialect detects the feature being availble. works for fb+ pg, needs work for oracle, mssql --- diff --git a/06CHANGES b/06CHANGES index ea75565519..141f834a26 100644 --- a/06CHANGES +++ b/06CHANGES @@ -13,7 +13,14 @@ - sql - returning() support is native to insert(), update(), delete(). Implementations of varying levels of functionality exist for Postgresql, Firebird, MSSQL and - Oracle. + Oracle. returning() can be called explicitly with column expressions which + are then returned in the resultset, usually via fetchone() or first(). + + insert() constructs will also use RETURNING implicitly to get newly + generated primary key values, if the database version in use supports it + (a version number check is performed). This occurs if no end-user + returning() was specified. + - engines - transaction isolation level may be specified with diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py index 58fa19f50f..1a441eaa6e 100644 --- a/lib/sqlalchemy/dialects/firebird/base.py +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -253,8 +253,7 @@ class FBCompiler(sql.compiler.SQLCompiler): return "" - def returning_clause(self, stmt): - returning_cols = stmt._returning + def returning_clause(self, stmt, returning_cols): columns = [ self.process( @@ -312,13 +311,15 @@ class FBDialect(default.DefaultDialect): name = 'firebird' max_identifier_length = 31 + supports_sequences = True sequences_optional = False supports_default_values = True - supports_empty_insert = False preexecute_pk_sequences = True - supports_pk_autoincrement = False + postfetch_lastrowid = False + requires_name_normalize = True + supports_empty_insert = False statement_compiler = FBCompiler ddl_compiler = FBDDLCompiler @@ -344,7 +345,9 @@ class FBDialect(default.DefaultDialect): self.colspecs = { sqltypes.DateTime: sqltypes.DATE } - + else: + self.implicit_returning = True + def normalize_name(self, name): # Remove trailing spaces: FB uses a CHAR() type, # that is padded with spaces diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 7db0d25899..f1a7cd9aa4 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -863,7 +863,7 @@ class MSExecutionContext(default.DefaultExecutionContext): self._enable_identity_insert = False self._select_lastrowid = insert_has_sequence and \ - not self.compiled.statement.returning and \ + not self.compiled.rendered_returning and \ not self._enable_identity_insert and \ not self.executemany @@ -880,14 +880,17 @@ class MSExecutionContext(default.DefaultExecutionContext): else: self.cursor.execute("SELECT @@identity AS lastrowid") row = self.cursor.fetchall()[0] # fetchall() ensures the cursor is consumed without closing it - self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:] + self._lastrowid = int(row[0]) if self._enable_identity_insert: self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table)) if (self.isinsert or self.isupdate or self.isdelete) and \ - self.compiled.statement._returning: + self.compiled.rendered_returning: self._result_proxy = base.FullyBufferedResultProxy(self) + + def get_lastrowid(self): + return self._lastrowid def handle_dbapi_exception(self, e): if self._enable_identity_insert: @@ -1034,8 +1037,7 @@ class MSSQLCompiler(compiler.SQLCompiler): return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs) return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) - def returning_clause(self, stmt): - returning_cols = stmt._returning + def returning_clause(self, stmt, returning_cols): if self.isinsert or self.isupdate: target = stmt.table.alias("inserted") diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py index 550f26e676..5c5d2171a9 100644 --- a/lib/sqlalchemy/dialects/mssql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -43,7 +43,7 @@ class MSExecutionContext_pyodbc(MSExecutionContext): # so we need to just keep flipping self.cursor.nextset() - self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:] + self._lastrowid = int(row[0]) else: super(MSExecutionContext_pyodbc, self).post_exec() diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index b325f5ef58..1c5c251e54 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1182,20 +1182,15 @@ ischema_names = { class MySQLExecutionContext(default.DefaultExecutionContext): def post_exec(self): - if self.isinsert and not self.executemany: - if (not len(self._last_inserted_ids) or - self._last_inserted_ids[0] is None): - self._last_inserted_ids = ([self._lastrowid(self.cursor)] + - self._last_inserted_ids[1:]) - elif (not self.isupdate and not self.should_autocommit and + # TODO: i think this 'charset' in the info thing + # is out + + if (not self.isupdate and not self.should_autocommit and self.statement and SET_RE.match(self.statement)): # This misses if a user forces autocommit on text('SET NAMES'), # which is probably a programming error anyhow. self.connection.info.pop(('mysql', 'charset'), None) - def _lastrowid(self, cursor): - raise NotImplementedError() - def should_autocommit_text(self, statement): return AUTOCOMMIT_RE.match(statement) diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index b5f7779843..6ecfc4b845 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -30,9 +30,7 @@ from sqlalchemy.sql import operators as sql_operators from sqlalchemy import exc, log, schema, sql, types as sqltypes, util class MySQL_mysqldbExecutionContext(MySQLExecutionContext): - def _lastrowid(self, cursor): - return cursor.lastrowid - + @property def rowcount(self): if hasattr(self, '_rowcount'): diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py index 4896173e45..1ea7ec8646 100644 --- a/lib/sqlalchemy/dialects/mysql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py @@ -5,7 +5,8 @@ from sqlalchemy import util import re class MySQL_pyodbcExecutionContext(MySQLExecutionContext): - def _lastrowid(self, cursor): + + def get_lastrowid(self): cursor = self.create_cursor() cursor.execute("SELECT LAST_INSERT_ID()") lastrowid = cursor.fetchone()[0] diff --git a/lib/sqlalchemy/dialects/mysql/zxjdbc.py b/lib/sqlalchemy/dialects/mysql/zxjdbc.py index b32b6fe2a1..81ad8379cd 100644 --- a/lib/sqlalchemy/dialects/mysql/zxjdbc.py +++ b/lib/sqlalchemy/dialects/mysql/zxjdbc.py @@ -15,10 +15,8 @@ from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector from sqlalchemy import types as sqltypes, util class MySQL_jdbcExecutionContext(MySQLExecutionContext): - def _real_lastrowid(self, cursor): - return cursor.lastrowid - - def _lastrowid(self, cursor): + + def get_lastrowid(self): cursor = self.create_cursor() cursor.execute("SELECT LAST_INSERT_ID()") lastrowid = cursor.fetchone()[0] diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 882fc05c71..09d18cb19b 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -318,8 +318,7 @@ class OracleCompiler(compiler.SQLCompiler): else: return self.process(alias.original, **kwargs) - def returning_clause(self, stmt): - returning_cols = stmt._returning + def returning_clause(self, stmt, returning_cols): def create_out_param(col, i): bindparam = sql.outparam("ret_%d" % i, type_=col.type) @@ -473,10 +472,12 @@ class OracleDialect(default.DefaultDialect): max_identifier_length = 30 supports_sane_rowcount = True supports_sane_multi_rowcount = False + supports_sequences = True sequences_optional = False preexecute_pk_sequences = True - supports_pk_autoincrement = False + postfetch_lastrowid = False + default_paramstyle = 'named' colspecs = colspecs ischema_names = ischema_names @@ -502,6 +503,12 @@ class OracleDialect(default.DefaultDialect): self.use_ansi = use_ansi self.optimize_limits = optimize_limits +# TODO: implement server_version_info for oracle +# def initialize(self, connection): +# super(OracleDialect, self).initialize(connection) +# self.implicit_returning = self.server_version_info > (10, ) and \ +# self.__dict__.get('implicit_returning', True) + def has_table(self, connection, table_name, schema=None): if not schema: schema = self.get_default_schema_name(connection) diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 54e4d119e1..c007998ec4 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -216,7 +216,7 @@ class Oracle_cx_oracleExecutionContext(DefaultExecutionContext): if hasattr(self, 'out_parameters') and \ (self.isinsert or self.isupdate or self.isdelete) and \ - self.compiled.statement._returning: + self.compiled.rendered_returning: return ReturningResultProxy(self) else: @@ -226,7 +226,7 @@ class ReturningResultProxy(base.FullyBufferedResultProxy): """Result proxy which stuffs the _returning clause + outparams into the fetch.""" def _cursor_description(self): - returning = self.context.compiled.statement._returning + returning = self.context.compiled.returning or self.context.compiled.statement._returning ret = [] for c in returning: @@ -237,7 +237,7 @@ class ReturningResultProxy(base.FullyBufferedResultProxy): return ret def _buffer_rows(self): - returning = self.context.compiled.statement._returning + returning = self.context.compiled.returning or self.context.compiled.statement._returning return [tuple(self.context.out_parameters["ret_%d" % i] for i, c in enumerate(returning))] class Oracle_cx_oracle(OracleDialect): diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 2b0ebf5f40..a865f069bd 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -263,8 +263,7 @@ class PGCompiler(compiler.SQLCompiler): else: return super(PGCompiler, self).for_update_clause(select) - def returning_clause(self, stmt): - returning_cols = stmt._returning + def returning_clause(self, stmt, returning_cols): columns = [ self.process( @@ -449,10 +448,13 @@ class PGDialect(default.DefaultDialect): supports_alter = True max_identifier_length = 63 supports_sane_rowcount = True + supports_sequences = True sequences_optional = True preexecute_pk_sequences = True - supports_pk_autoincrement = False + preexecute_autoincrement_sequences = True + postfetch_lastrowid = False + supports_default_values = True supports_empty_insert = False default_paramstyle = 'pyformat' @@ -471,6 +473,11 @@ class PGDialect(default.DefaultDialect): default.DefaultDialect.__init__(self, **kwargs) self.isolation_level = isolation_level + def initialize(self, connection): + super(PGDialect, self).initialize(connection) + self.implicit_returning = self.server_version_info > (8, 3) and \ + self.__dict__.get('implicit_returning', True) + def visit_pool(self, pool): if self.isolation_level is not None: class SetIsolationLevel(object): diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py index 0eb4e9ede5..a1873f33a8 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -110,17 +110,9 @@ from sqlalchemy.engine import default from sqlalchemy import types as sqltypes from sqlalchemy import util -class SQLite_pysqliteExecutionContext(default.DefaultExecutionContext): - def post_exec(self): - if self.isinsert and not self.executemany: - if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: - self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:] - - class SQLite_pysqlite(SQLiteDialect): default_paramstyle = 'qmark' poolclass = pool.SingletonThreadPool - execution_ctx_cls = SQLite_pysqliteExecutionContext # Py3K #description_encoding = None diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 7fff18d023..ba3880ca2b 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -105,14 +105,25 @@ class Dialect(object): executemany. preexecute_pk_sequences - Indicate if the dialect should pre-execute sequences on primary - key columns during an INSERT, if it's desired that the new row's - primary key be available after execution. - - supports_pk_autoincrement - Indicates if the dialect should allow the database to passively assign - a primary key column value. - + Indicate if the dialect should pre-execute sequences or default + generation functions on primary key columns during an INSERT, if + it's desired that the new row's primary key be available after execution. + Pre-execution is disabled if the database supports "returning" + and "implicit_returning" is True. + + preexecute_autoincrement_sequences + True if 'implicit' primary key functions must be executed separately + in order to get their value. This is currently oriented towards + Postgresql. + + implicit_returning + use RETURNING or equivalent during INSERT execution in order to load + newly generated primary keys and other column defaults in one execution, + which are then available via last_inserted_ids(). + If an insert statement has returning() specified explicitly, + the "implicit" functionality is not used and last_inserted_ids() + will not be available. + dbapi_type_map A mapping of DB-API type objects present in this Dialect's DB-API implementation mapped to TypeEngine implementations used @@ -1069,11 +1080,14 @@ class Connection(Connectable): self._cursor_execute(context.cursor, context.statement, context.parameters[0], context=context) if context.compiled: context.post_exec() + if context.isinsert and not context.executemany: + context.post_insert() if context.should_autocommit and not self.in_transaction(): self._commit_impl() + return context.get_result_proxy() - + def _handle_dbapi_exception(self, e, statement, parameters, cursor, context): if getattr(self, '_reentrant_error', False): # Py3K @@ -1608,6 +1622,18 @@ class ResultProxy(object): @property def lastrowid(self): + """return the 'lastrowid' accessor on the DBAPI cursor. + + This is a DBAPI specific method and is only functional + for those backends which support it, for statements + where it is appropriate. + + Usage of this method is normally unnecessary; the + last_inserted_ids() method provides a + tuple of primary key values for a newly inserted row, + regardless of database backend. + + """ return self.cursor.lastrowid @property @@ -1751,8 +1777,7 @@ class ResultProxy(object): See ExecutionContext for details. """ - - return self.context.last_inserted_ids() + return self.context.last_inserted_ids(self) def last_updated_params(self): """Return ``last_updated_params()`` from the underlying ExecutionContext. diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 3f90baa5c2..14e73e5886 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -30,9 +30,14 @@ class DefaultDialect(base.Dialect): preparer = compiler.IdentifierPreparer defaultrunner = base.DefaultRunner supports_alter = True + supports_sequences = False sequences_optional = False - + preexecute_pk_sequences = False + preexecute_autoincrement_sequences = False + postfetch_lastrowid = True + implicit_returning = False + # Py3K #supports_unicode_statements = True #supports_unicode_binds = True @@ -45,8 +50,6 @@ class DefaultDialect(base.Dialect): max_identifier_length = 9999 supports_sane_rowcount = True supports_sane_multi_rowcount = True - preexecute_pk_sequences = False - supports_pk_autoincrement = True dbapi_type_map = {} default_paramstyle = 'named' supports_default_values = False @@ -63,6 +66,7 @@ class DefaultDialect(base.Dialect): def __init__(self, convert_unicode=False, assert_unicode=False, encoding='utf-8', paramstyle=None, dbapi=None, + implicit_returning=None, label_length=None, **kwargs): self.convert_unicode = convert_unicode self.assert_unicode = assert_unicode @@ -76,6 +80,8 @@ class DefaultDialect(base.Dialect): self.paramstyle = self.dbapi.paramstyle else: self.paramstyle = self.default_paramstyle + if implicit_returning is not None: + self.implicit_returning = implicit_returning self.positional = self.paramstyle in ('qmark', 'format', 'numeric') self.identifier_preparer = self.preparer(self) self.type_compiler = self.type_compiler(self) @@ -176,6 +182,8 @@ class DefaultDialect(base.Dialect): class DefaultExecutionContext(base.ExecutionContext): + _lastrowid = None + def __init__(self, dialect, connection, compiled_sql=None, compiled_ddl=None, statement=None, parameters=None): self.dialect = dialect self._connection = self.root_connection = connection @@ -329,6 +337,35 @@ class DefaultExecutionContext(base.ExecutionContext): def post_exec(self): pass + + def get_lastrowid(self): + """return self.cursor.lastrowid, or equivalent, after an INSERT. + + This may involve calling special cursor functions, + issuing a new SELECT on the cursor (or a new one), + or returning a stored value that was + calculated within post_exec(). + + This function will only be called for dialects + which support "implicit" primary key generation, + keep preexecute_autoincrement_sequences set to False, + and when no explicit id value was bound to the + statement. + + The function is called once, directly after + post_exec() and before the transaction is committed + or ResultProxy is generated. If the post_exec() + method assigns a value to `self._lastrowid`, the + value is used in place of calling get_lastrowid(). + + Note that this method is *not* equivalent to the + ``lastrowid`` method on ``ResultProxy``, which is a + direct proxy to the DBAPI ``lastrowid`` accessor + in all cases. + + """ + + return self.cursor.lastrowid def handle_dbapi_exception(self, e): pass @@ -345,9 +382,34 @@ class DefaultExecutionContext(base.ExecutionContext): def supports_sane_multi_rowcount(self): return self.dialect.supports_sane_multi_rowcount - - def last_inserted_ids(self): - return self._last_inserted_ids + + def post_insert(self): + if self.dialect.postfetch_lastrowid and \ + self._lastrowid is None and \ + (not len(self._last_inserted_ids) or \ + self._last_inserted_ids[0] is None): + + self._lastrowid = self.get_lastrowid() + + def last_inserted_ids(self, resultproxy): + if not self.isinsert: + raise exc.InvalidRequestError("Statement is not an insert() expression construct.") + + if self.dialect.implicit_returning and \ + not self.compiled.statement._returning and \ + not resultproxy.closed: + + row = resultproxy.first() + + self._last_inserted_ids = [v is not None and v or row[c] + for c, v in zip(self.compiled.statement.table.primary_key, self._last_inserted_ids) + ] + return self._last_inserted_ids + + elif self._lastrowid is not None: + return [self._lastrowid] + self._last_inserted_ids[1:] + else: + return self._last_inserted_ids def last_inserted_params(self): return self._last_inserted_params diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 66dc84f194..79ed44e1ba 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -161,7 +161,8 @@ class SQLCompiler(engine.Compiled): # level to define if this Compiled instance represents # INSERT/UPDATE/DELETE isdelete = isinsert = isupdate = False - + rendered_returning = False + def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): """Construct a new ``DefaultCompiler`` object. @@ -696,9 +697,10 @@ class SQLCompiler(engine.Compiled): text += " (%s)" % ', '.join([preparer.format_column(c[0]) for c in colparams]) - if insert_stmt._returning: - returning_clause = self.returning_clause(insert_stmt) - + if self.returning or insert_stmt._returning: + returning_clause = self.returning_clause(insert_stmt, self.returning or insert_stmt._returning) + self.rendered_returning = True + # cheating if returning_clause.startswith("OUTPUT"): text += " " + returning_clause @@ -708,7 +710,7 @@ class SQLCompiler(engine.Compiled): text += " VALUES (%s)" % \ ', '.join([c[1] for c in colparams]) - if insert_stmt._returning and returning_clause: + if (self.returning or insert_stmt._returning) and returning_clause: text += " " + returning_clause return text @@ -728,7 +730,8 @@ class SQLCompiler(engine.Compiled): ) if update_stmt._returning: - returning_clause = self.returning_clause(update_stmt) + returning_clause = self.returning_clause(update_stmt, update_stmt._returning) + self.rendered_returning = True if returning_clause.startswith("OUTPUT"): text += " " + returning_clause returning_clause = None @@ -756,7 +759,8 @@ class SQLCompiler(engine.Compiled): self.postfetch = [] self.prefetch = [] - + self.returning = [] + # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if self.column_keys is None and stmt.parameters is None: @@ -785,19 +789,43 @@ class SQLCompiler(engine.Compiled): self.postfetch.append(c) value = self.process(value.self_group()) values.append((c, value)) + elif isinstance(c, schema.Column): if self.isinsert: - if (c.primary_key and self.dialect.preexecute_pk_sequences and not self.inline): - if (((isinstance(c.default, schema.Sequence) and - not c.default.optional) or - not self.dialect.supports_pk_autoincrement) or - (c.default is not None and - not isinstance(c.default, schema.Sequence))): - values.append((c, create_bind_param(c, None))) - self.prefetch.append(c) + if c.primary_key and \ + ( + self.dialect.preexecute_pk_sequences or + self.dialect.implicit_returning + ) and \ + not self.inline and \ + not self.statement._returning: + + if self.dialect.implicit_returning: + if isinstance(c.default, schema.Sequence): + proc = self.process(c.default) + if proc is not None: + values.append((c, proc)) + self.returning.append(c) + elif isinstance(c.default, schema.ColumnDefault) and \ + isinstance(c.default.arg, sql.ClauseElement): + values.append((c, self.process(c.default.arg.self_group()))) + self.returning.append(c) + elif c.default is not None: + values.append((c, create_bind_param(c, None))) + self.prefetch.append(c) + else: + self.returning.append(c) + else: + if c.default is not None or \ + self.dialect.preexecute_autoincrement_sequences: + + values.append((c, create_bind_param(c, None))) + self.prefetch.append(c) + elif isinstance(c.default, schema.ColumnDefault): if isinstance(c.default.arg, sql.ClauseElement): values.append((c, self.process(c.default.arg.self_group()))) + if not c.primary_key: # dont add primary key column to postfetch self.postfetch.append(c) @@ -835,7 +863,9 @@ class SQLCompiler(engine.Compiled): text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table) if delete_stmt._returning: - returning_clause = self.returning_clause(delete_stmt) + returning_clause = self.returning_clause(delete_stmt, delete_stmt._returning) + self.rendered_returning = True + self.returning = delete_stmt._returning if returning_clause.startswith("OUTPUT"): text += " " + returning_clause returning_clause = None diff --git a/test/aaa_profiling/test_zoomark.py b/test/aaa_profiling/test_zoomark.py index 83fe2b78d3..a1f6277df8 100644 --- a/test/aaa_profiling/test_zoomark.py +++ b/test/aaa_profiling/test_zoomark.py @@ -327,7 +327,7 @@ class ZooMarkTest(TestBase): def test_profile_1a_populate(self): self.test_baseline_1a_populate() - @profiling.function_call_count(322, {'2.4': 202}) + @profiling.function_call_count(305, {'2.4': 202}) def test_profile_2_insert(self): self.test_baseline_2_insert() diff --git a/test/dialect/test_postgresql.py b/test/dialect/test_postgresql.py index 33d3eda20c..3e3884ebe6 100644 --- a/test/dialect/test_postgresql.py +++ b/test/dialect/test_postgresql.py @@ -1,4 +1,5 @@ from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message +from sqlalchemy.test import engines import datetime from sqlalchemy import * from sqlalchemy.orm import * @@ -110,11 +111,14 @@ class InsertTest(TestBase, AssertsExecutionResults): @classmethod def setup_class(cls): global metadata + cls.engine= testing.db metadata = MetaData(testing.db) def teardown(self): metadata.drop_all() metadata.tables.clear() + if self.engine is not testing.db: + self.engine.dispose() def test_compiled_insert(self): table = Table('testtable', metadata, @@ -134,6 +138,13 @@ class InsertTest(TestBase, AssertsExecutionResults): metadata.create_all() self._assert_data_with_sequence(table, "my_seq") + def test_sequence_returning_insert(self): + table = Table('testtable', metadata, + Column('id', Integer, Sequence('my_seq'), primary_key=True), + Column('data', String(30))) + metadata.create_all() + self._assert_data_with_sequence_returning(table, "my_seq") + def test_opt_sequence_insert(self): table = Table('testtable', metadata, Column('id', Integer, Sequence('my_seq', optional=True), primary_key=True), @@ -141,6 +152,13 @@ class InsertTest(TestBase, AssertsExecutionResults): metadata.create_all() self._assert_data_autoincrement(table) + def test_opt_sequence_returning_insert(self): + table = Table('testtable', metadata, + Column('id', Integer, Sequence('my_seq', optional=True), primary_key=True), + Column('data', String(30))) + metadata.create_all() + self._assert_data_autoincrement_returning(table) + def test_autoincrement_insert(self): table = Table('testtable', metadata, Column('id', Integer, primary_key=True), @@ -148,6 +166,13 @@ class InsertTest(TestBase, AssertsExecutionResults): metadata.create_all() self._assert_data_autoincrement(table) + def test_autoincrement_returning_insert(self): + table = Table('testtable', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30))) + metadata.create_all() + self._assert_data_autoincrement_returning(table) + def test_noautoincrement_insert(self): table = Table('testtable', metadata, Column('id', Integer, primary_key=True, autoincrement=False), @@ -156,6 +181,9 @@ class InsertTest(TestBase, AssertsExecutionResults): self._assert_data_noautoincrement(table) def _assert_data_autoincrement(self, table): + self.engine = engines.testing_engine(options={'implicit_returning':False}) + metadata.bind = self.engine + def go(): # execute with explicit id r = table.insert().execute({'id':30, 'data':'d1'}) @@ -180,7 +208,7 @@ class InsertTest(TestBase, AssertsExecutionResults): # note that the test framework doesnt capture the "preexecute" of a seqeuence # or default. we just see it in the bind params. - self.assert_sql(testing.db, go, [], with_sequences=[ + self.assert_sql(self.engine, go, [], with_sequences=[ ( "INSERT INTO testtable (id, data) VALUES (:id, :data)", {'id':30, 'data':'d1'} @@ -221,7 +249,7 @@ class InsertTest(TestBase, AssertsExecutionResults): # test the same series of events using a reflected # version of the table - m2 = MetaData(testing.db) + m2 = MetaData(self.engine) table = Table(table.name, m2, autoload=True) def go(): @@ -233,7 +261,7 @@ class InsertTest(TestBase, AssertsExecutionResults): table.insert(inline=True).execute({'id':33, 'data':'d7'}) table.insert(inline=True).execute({'data':'d8'}) - self.assert_sql(testing.db, go, [], with_sequences=[ + self.assert_sql(self.engine, go, [], with_sequences=[ ( "INSERT INTO testtable (id, data) VALUES (:id, :data)", {'id':30, 'data':'d1'} @@ -272,7 +300,127 @@ class InsertTest(TestBase, AssertsExecutionResults): ] table.delete().execute() + def _assert_data_autoincrement_returning(self, table): + self.engine = engines.testing_engine(options={'implicit_returning':True}) + metadata.bind = self.engine + + def go(): + # execute with explicit id + r = table.insert().execute({'id':30, 'data':'d1'}) + assert r.last_inserted_ids() == [30] + + # execute with prefetch id + r = table.insert().execute({'data':'d2'}) + assert r.last_inserted_ids() == [1] + + # executemany with explicit ids + table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}) + + # executemany, uses SERIAL + table.insert().execute({'data':'d5'}, {'data':'d6'}) + + # single execute, explicit id, inline + table.insert(inline=True).execute({'id':33, 'data':'d7'}) + + # single execute, inline, uses SERIAL + table.insert(inline=True).execute({'data':'d8'}) + + self.assert_sql(self.engine, go, [], with_sequences=[ + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {'id':30, 'data':'d1'} + ), + ( + "INSERT INTO testtable (data) VALUES (:data) RETURNING testtable.id", + {'data': 'd2'} + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}] + ), + ( + "INSERT INTO testtable (data) VALUES (:data)", + [{'data':'d5'}, {'data':'d6'}] + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':33, 'data':'d7'}] + ), + ( + "INSERT INTO testtable (data) VALUES (:data)", + [{'data':'d8'}] + ), + ]) + + assert table.select().execute().fetchall() == [ + (30, 'd1'), + (1, 'd2'), + (31, 'd3'), + (32, 'd4'), + (2, 'd5'), + (3, 'd6'), + (33, 'd7'), + (4, 'd8'), + ] + table.delete().execute() + + # test the same series of events using a reflected + # version of the table + m2 = MetaData(self.engine) + table = Table(table.name, m2, autoload=True) + + def go(): + table.insert().execute({'id':30, 'data':'d1'}) + r = table.insert().execute({'data':'d2'}) + assert r.last_inserted_ids() == [5] + table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}) + table.insert().execute({'data':'d5'}, {'data':'d6'}) + table.insert(inline=True).execute({'id':33, 'data':'d7'}) + table.insert(inline=True).execute({'data':'d8'}) + + self.assert_sql(self.engine, go, [], with_sequences=[ + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {'id':30, 'data':'d1'} + ), + ( + "INSERT INTO testtable (data) VALUES (:data) RETURNING testtable.id", + {'data':'d2'} + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}] + ), + ( + "INSERT INTO testtable (data) VALUES (:data)", + [{'data':'d5'}, {'data':'d6'}] + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':33, 'data':'d7'}] + ), + ( + "INSERT INTO testtable (data) VALUES (:data)", + [{'data':'d8'}] + ), + ]) + + assert table.select().execute().fetchall() == [ + (30, 'd1'), + (5, 'd2'), + (31, 'd3'), + (32, 'd4'), + (6, 'd5'), + (7, 'd6'), + (33, 'd7'), + (8, 'd8'), + ] + table.delete().execute() + def _assert_data_with_sequence(self, table, seqname): + self.engine = engines.testing_engine(options={'implicit_returning':False}) + metadata.bind = self.engine + def go(): table.insert().execute({'id':30, 'data':'d1'}) table.insert().execute({'data':'d2'}) @@ -281,7 +429,7 @@ class InsertTest(TestBase, AssertsExecutionResults): table.insert(inline=True).execute({'id':33, 'data':'d7'}) table.insert(inline=True).execute({'data':'d8'}) - self.assert_sql(testing.db, go, [], with_sequences=[ + self.assert_sql(self.engine, go, [], with_sequences=[ ( "INSERT INTO testtable (id, data) VALUES (:id, :data)", {'id':30, 'data':'d1'} @@ -322,10 +470,66 @@ class InsertTest(TestBase, AssertsExecutionResults): # cant test reflection here since the Sequence must be # explicitly specified + def _assert_data_with_sequence_returning(self, table, seqname): + self.engine = engines.testing_engine(options={'implicit_returning':True}) + metadata.bind = self.engine + + def go(): + table.insert().execute({'id':30, 'data':'d1'}) + table.insert().execute({'data':'d2'}) + table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}) + table.insert().execute({'data':'d5'}, {'data':'d6'}) + table.insert(inline=True).execute({'id':33, 'data':'d7'}) + table.insert(inline=True).execute({'data':'d8'}) + + self.assert_sql(self.engine, go, [], with_sequences=[ + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {'id':30, 'data':'d1'} + ), + ( + "INSERT INTO testtable (id, data) VALUES (nextval('my_seq'), :data) RETURNING testtable.id", + {'data':'d2'} + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}] + ), + ( + "INSERT INTO testtable (id, data) VALUES (nextval('%s'), :data)" % seqname, + [{'data':'d5'}, {'data':'d6'}] + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':33, 'data':'d7'}] + ), + ( + "INSERT INTO testtable (id, data) VALUES (nextval('%s'), :data)" % seqname, + [{'data':'d8'}] + ), + ]) + + assert table.select().execute().fetchall() == [ + (30, 'd1'), + (1, 'd2'), + (31, 'd3'), + (32, 'd4'), + (2, 'd5'), + (3, 'd6'), + (33, 'd7'), + (4, 'd8'), + ] + + # cant test reflection here since the Sequence must be + # explicitly specified + def _assert_data_noautoincrement(self, table): + self.engine = engines.testing_engine(options={'implicit_returning':False}) + metadata.bind = self.engine + table.insert().execute({'id':30, 'data':'d1'}) - if testing.db.driver == 'pg8000': + if self.engine.driver == 'pg8000': exception_cls = exc.ProgrammingError else: exception_cls = exc.IntegrityError @@ -350,7 +554,7 @@ class InsertTest(TestBase, AssertsExecutionResults): # test the same series of events using a reflected # version of the table - m2 = MetaData(testing.db) + m2 = MetaData(self.engine) table = Table(table.name, m2, autoload=True) table.insert().execute({'id':30, 'data':'d1'}) diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index 3f1c1c10d5..f2bc5a53b4 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -3,7 +3,7 @@ import datetime from sqlalchemy import Sequence, Column, func from sqlalchemy.sql import select, text import sqlalchemy as sa -from sqlalchemy.test import testing +from sqlalchemy.test import testing, engines from sqlalchemy import MetaData, Integer, String, ForeignKey, Boolean from sqlalchemy.test.schema import Table from sqlalchemy.test.testing import eq_ @@ -540,16 +540,17 @@ class SequenceTest(testing.TestBase): def testseqnonpk(self): """test sequences fire off as defaults on non-pk columns""" - result = sometable.insert().execute(name="somename") + engine = engines.testing_engine(options={'implicit_returning':False}) + result = engine.execute(sometable.insert(), name="somename") assert 'id' in result.postfetch_cols() - result = sometable.insert().execute(name="someother") + result = engine.execute(sometable.insert(), name="someother") assert 'id' in result.postfetch_cols() sometable.insert().execute( {'name':'name3'}, {'name':'name4'}) - eq_(sometable.select().execute().fetchall(), + eq_(sometable.select().order_by(sometable.c.id).execute().fetchall(), [(1, "somename", 1), (2, "someother", 2), (3, "name3", 3), diff --git a/test/sql/test_query.py b/test/sql/test_query.py index bbc399aa6b..37030c94f4 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -5,6 +5,7 @@ from sqlalchemy import exc, sql from sqlalchemy.engine import default from sqlalchemy.test import * from sqlalchemy.test.testing import eq_, assert_raises_message +from sqlalchemy.test.schema import Table, Column class QueryTest(TestBase): @@ -13,11 +14,11 @@ class QueryTest(TestBase): global users, users2, addresses, metadata metadata = MetaData(testing.db) users = Table('query_users', metadata, - Column('user_id', INT, Sequence('user_id_seq', optional=True), primary_key = True), + Column('user_id', INT, primary_key=True, test_needs_autoincrement=True), Column('user_name', VARCHAR(20)), ) addresses = Table('query_addresses', metadata, - Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key=True), + Column('address_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('user_id', Integer, ForeignKey('query_users.user_id')), Column('address', String(30))) @@ -59,14 +60,14 @@ class QueryTest(TestBase): def test_lastrow_accessor(self): """Tests the last_inserted_ids() and lastrow_has_id() functions.""" - def insert_values(table, values): + def insert_values(engine, table, values): """ Inserts a row into a table, returns the full list of values INSERTed including defaults that fired off on the DB side and detects rows that had defaults and post-fetches. """ - result = table.insert().execute(**values) + result = engine.execute(table.insert(), **values) ret = values.copy() for col, id in zip(table.primary_key, result.last_inserted_ids()): @@ -74,68 +75,78 @@ class QueryTest(TestBase): if result.lastrow_has_defaults(): criterion = and_(*[col==id for col, id in zip(table.primary_key, result.last_inserted_ids())]) - row = table.select(criterion).execute().first() + row = engine.execute(table.select(criterion)).first() for c in table.c: ret[c.key] = row[c] return ret - for supported, table, values, assertvalues in [ - ( - {'unsupported':['sqlite']}, - Table("t1", metadata, - Column('id', Integer, Sequence('t1_id_seq', optional=True), primary_key=True), - Column('foo', String(30), primary_key=True)), - {'foo':'hi'}, - {'id':1, 'foo':'hi'} - ), - ( - {'unsupported':['sqlite']}, - Table("t2", metadata, - Column('id', Integer, Sequence('t2_id_seq', optional=True), primary_key=True), - Column('foo', String(30), primary_key=True), - Column('bar', String(30), server_default='hi') + if testing.against('firebird', 'postgres', 'oracle', 'mssql'): + test_engines = [ + engines.testing_engine(options={'implicit_returning':False}), + engines.testing_engine(options={'implicit_returning':True}), + ] + else: + test_engines = [testing.db] + + for engine in test_engines: + metadata = MetaData() + for supported, table, values, assertvalues in [ + ( + {'unsupported':['sqlite']}, + Table("t1", metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('foo', String(30), primary_key=True)), + {'foo':'hi'}, + {'id':1, 'foo':'hi'} ), - {'foo':'hi'}, - {'id':1, 'foo':'hi', 'bar':'hi'} - ), - ( - {'unsupported':[]}, - Table("t3", metadata, - Column("id", String(40), primary_key=True), - Column('foo', String(30), primary_key=True), - Column("bar", String(30)) + ( + {'unsupported':['sqlite']}, + Table("t2", metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('foo', String(30), primary_key=True), + Column('bar', String(30), server_default='hi') ), - {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"}, - {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"} - ), - ( - {'unsupported':[]}, - Table("t4", metadata, - Column('id', Integer, Sequence('t4_id_seq', optional=True), primary_key=True), - Column('foo', String(30), primary_key=True), - Column('bar', String(30), server_default='hi') + {'foo':'hi'}, + {'id':1, 'foo':'hi', 'bar':'hi'} ), - {'foo':'hi', 'id':1}, - {'id':1, 'foo':'hi', 'bar':'hi'} - ), - ( - {'unsupported':[]}, - Table("t5", metadata, - Column('id', String(10), primary_key=True), - Column('bar', String(30), server_default='hi') + ( + {'unsupported':[]}, + Table("t3", metadata, + Column("id", String(40), primary_key=True), + Column('foo', String(30), primary_key=True), + Column("bar", String(30)) + ), + {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"}, + {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"} ), - {'id':'id1'}, - {'id':'id1', 'bar':'hi'}, - ), - ]: - if testing.db.name in supported['unsupported']: - continue - try: - table.create() - i = insert_values(table, values) - assert i == assertvalues, repr(i) + " " + repr(assertvalues) - finally: - table.drop() + ( + {'unsupported':[]}, + Table("t4", metadata, + Column('id', Integer, Sequence('t4_id_seq', optional=True), primary_key=True), + Column('foo', String(30), primary_key=True), + Column('bar', String(30), server_default='hi') + ), + {'foo':'hi', 'id':1}, + {'id':1, 'foo':'hi', 'bar':'hi'} + ), + ( + {'unsupported':[]}, + Table("t5", metadata, + Column('id', String(10), primary_key=True), + Column('bar', String(30), server_default='hi') + ), + {'id':'id1'}, + {'id':'id1', 'bar':'hi'}, + ), + ]: + if testing.db.name in supported['unsupported']: + continue + try: + table.create(bind=engine, checkfirst=True) + i = insert_values(engine, table, values) + assert i == assertvalues, "tablename: %s %r %r" % (table.name, repr(i), repr(assertvalues)) + finally: + table.drop(bind=engine) def test_row_iteration(self): users.insert().execute( diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index ead61cd418..04cfa4be8f 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -4,6 +4,7 @@ from sqlalchemy.test import * from sqlalchemy.test.schema import Table, Column from sqlalchemy.types import TypeDecorator + class ReturningTest(TestBase, AssertsExecutionResults): __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access') @@ -30,7 +31,7 @@ class ReturningTest(TestBase, AssertsExecutionResults): Column('full', Boolean), Column('goofy', GoofyType(50)) ) - table.create() + table.create(checkfirst=True) def teardown(self): table.drop() @@ -134,3 +135,24 @@ class ReturningTest(TestBase, AssertsExecutionResults): result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() eq_(result2.fetchall(), [(2,False),]) + +class SequenceReturningTest(TestBase): + __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access', 'mssql') + + def setup(self): + meta = MetaData(testing.db) + global table, seq + seq = Sequence('tid_seq') + table = Table('tables', meta, + Column('id', Integer, seq, primary_key=True), + Column('data', String(50)) + ) + table.create(checkfirst=True) + + def teardown(self): + table.drop() + + def test_insert(self): + r = table.insert().values(data='hi').returning(table.c.id).execute() + assert r.first() == (1, ) + assert seq.execute() == 2 diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index a172eb4523..670ae1fd0c 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -422,7 +422,7 @@ class ReduceTest(TestBase, AssertsExecutionResults): 'm': page_table.join(magazine_page_table), 'c': page_table.join(magazine_page_table).join(classified_page_table), }, None, 'page_join') - + eq_( util.column_set(sql_util.reduce_columns([pjoin.c.id, pjoin.c.page_id, pjoin.c.magazine_page_id])), util.column_set([pjoin.c.id])