From: Mike Bayer Date: Sun, 19 Jul 2009 02:20:18 +0000 (+0000) Subject: - returning() support is native to insert(), update(), delete(). Implementations X-Git-Tag: rel_0_6_6~110 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5fc747943259e6b69441f79b944641438383a1bc;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - returning() support is native to insert(), update(), delete(). Implementations of varying levels of functionality exist for Postgresql, Firebird, MSSQL and Oracle. - MSSQL still has a few glitches that need to be resolved via label/column targeting logic. - its looking like time to take another look at positional column targeting overall. --- diff --git a/06CHANGES b/06CHANGES index 8cadbcb02c..3045721cba 100644 --- a/06CHANGES +++ b/06CHANGES @@ -8,6 +8,11 @@ on the structure of criteria, so success/failure is deterministic based on code structure. +- sql + - returning() support is native to insert(), update(), delete(). Implementations + of varying levels of functionality exist for Postgresql, Firebird, MSSQL and + Oracle. + - engines - transaction isolation level may be specified with create_engine(... isolation_level="..."); available on diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py index f30749ed72..949289eb36 100644 --- a/lib/sqlalchemy/dialects/firebird/base.py +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -221,7 +221,7 @@ class FBCompiler(sql.compiler.SQLCompiler): visit_char_length_func = visit_length_func - def function_argspec(self, func): + def function_argspec(self, func, **kw): if func.clauses: return self.process(func.clause_expr) else: @@ -253,40 +253,22 @@ class FBCompiler(sql.compiler.SQLCompiler): return "" - def _append_returning(self, text, stmt): - returning_cols = stmt.kwargs["firebird_returning"] + def returning_clause(self, stmt): + returning_cols = stmt._returning + def flatten_columnlist(collist): for c in collist: - if isinstance(c, sql.expression.Selectable): + if isinstance(c, expression.Selectable): for co in c.columns: yield co else: yield c - columns = [self.process(c, within_columns_clause=True) - for c in flatten_columnlist(returning_cols)] - text += ' RETURNING ' + ', '.join(columns) - return text - - def visit_update(self, update_stmt): - text = super(FBCompiler, self).visit_update(update_stmt) - if "firebird_returning" in update_stmt.kwargs: - return self._append_returning(text, update_stmt) - else: - return text - def visit_insert(self, insert_stmt): - text = super(FBCompiler, self).visit_insert(insert_stmt) - if "firebird_returning" in insert_stmt.kwargs: - return self._append_returning(text, insert_stmt) - else: - return text - - def visit_delete(self, delete_stmt): - text = super(FBCompiler, self).visit_delete(delete_stmt) - if "firebird_returning" in delete_stmt.kwargs: - return self._append_returning(text, delete_stmt) - else: - return text + columns = [ + self.process(c, within_columns_clause=True, result_map=self.result_map) + for c in flatten_columnlist(returning_cols) + ] + return 'RETURNING ' + ', '.join(columns) class FBDDLCompiler(sql.compiler.DDLCompiler): diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 849b72b979..9831b5134b 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -224,7 +224,9 @@ Known Issues import datetime, decimal, inspect, operator, sys, re from sqlalchemy import sql, schema as sa_schema, exc, util -from sqlalchemy.sql import select, compiler, expression, operators as sql_operators, functions as sql_functions +from sqlalchemy.sql import select, compiler, expression, \ + operators as sql_operators, \ + functions as sql_functions, util as sql_util from sqlalchemy.engine import default, base, reflection from sqlalchemy import types as sqltypes from decimal import Decimal as _python_Decimal @@ -844,6 +846,7 @@ def _table_sequence_column(tbl): class MSExecutionContext(default.DefaultExecutionContext): _enable_identity_insert = False _select_lastrowid = False + _result_proxy = None def pre_exec(self): """Activate IDENTITY_INSERT if needed.""" @@ -859,6 +862,7 @@ class MSExecutionContext(default.DefaultExecutionContext): self._enable_identity_insert = False self._select_lastrowid = insert_has_sequence and \ + not self.compiled.statement.returning and \ not self._enable_identity_insert and \ not self.executemany @@ -880,6 +884,10 @@ class MSExecutionContext(default.DefaultExecutionContext): if self._enable_identity_insert: self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table)) + if (self.isinsert or self.isupdate or self.isdelete) and \ + self.compiled.statement._returning: + self._result_proxy = base.FullyBufferedResultProxy(self) + def handle_dbapi_exception(self, e): if self._enable_identity_insert: try: @@ -887,6 +895,8 @@ class MSExecutionContext(default.DefaultExecutionContext): except: pass + def get_result_proxy(self): + return self._result_proxy or base.ResultProxy(self) class MSSQLCompiler(compiler.SQLCompiler): @@ -1023,7 +1033,7 @@ class MSSQLCompiler(compiler.SQLCompiler): return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs) return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) - def visit_insert(self, insert_stmt): + def dont_visit_insert(self, insert_stmt): insert_select = False if insert_stmt.parameters: insert_select = [p for p in insert_stmt.parameters.values() if isinstance(p, sql.Select)] @@ -1050,6 +1060,30 @@ class MSSQLCompiler(compiler.SQLCompiler): else: return super(MSSQLCompiler, self).visit_insert(insert_stmt) + def returning_clause(self, stmt): + returning_cols = stmt._returning + + def flatten_columnlist(collist): + for c in collist: + if isinstance(c, expression.Selectable): + for co in c.columns: + yield co + else: + yield c + + if self.isinsert or self.isupdate: + target = stmt.table.alias("inserted") + else: + target = stmt.table.alias("deleted") + + adapter = sql_util.ClauseAdapter(target) + columns = [ + self.process(adapter.traverse(c), within_columns_clause=True, result_map=self.result_map) + for c in flatten_columnlist(returning_cols) + ] + + return 'OUTPUT ' + ', '.join(columns) + def label_select_column(self, select, column, asfrom): if isinstance(column, expression.Function): return column.label(None) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 35c85c2c95..7c956f6bed 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -145,6 +145,9 @@ class LONG(sqltypes.Text): __visit_name__ = 'LONG' class _OracleBoolean(sqltypes.Boolean): + def get_dbapi_type(self, dbapi): + return dbapi.NUMBER + def result_processor(self, dialect): def process(value): if value is None: @@ -315,6 +318,29 @@ class OracleCompiler(compiler.SQLCompiler): else: return self.process(alias.original, **kwargs) + def returning_clause(self, stmt): + returning_cols = stmt._returning + + def flatten_columnlist(collist): + for c in collist: + if isinstance(c, expression.Selectable): + for co in c.columns: + yield co + else: + yield c + + def create_out_param(col, i): + bindparam = sql.outparam("ret_%d" % i, type_=col.type) + self.binds[bindparam.key] = bindparam + return self.bindparam_string(self._truncate_bindparam(bindparam)) + + # within_columns_clause =False so that labels (foo AS bar) don't render + columns = [self.process(c, within_columns_clause=False) for c in flatten_columnlist(returning_cols)] + + binds = [create_out_param(c, i) for i, c in enumerate(flatten_columnlist(returning_cols))] + + return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds) + def _TODO_visit_compound_select(self, select): """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle.""" pass @@ -424,7 +450,9 @@ class OracleDDLCompiler(compiler.DDLCompiler): class OracleDefaultRunner(base.DefaultRunner): def visit_sequence(self, seq): - return self.execute_string("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL", {}) + return self.execute_string("SELECT " + + self.dialect.identifier_preparer.format_sequence(seq) + + ".nextval FROM DUAL", {}) class OracleIdentifierPreparer(compiler.IdentifierPreparer): diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index fe74dce7af..54e4d119e1 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -213,9 +213,32 @@ class Oracle_cx_oracleExecutionContext(DefaultExecutionContext): type_code = column[1] if type_code in self.dialect.ORACLE_BINARY_TYPES: return base.BufferedColumnResultProxy(self) + + if hasattr(self, 'out_parameters') and \ + (self.isinsert or self.isupdate or self.isdelete) and \ + self.compiled.statement._returning: + + return ReturningResultProxy(self) + else: + return base.ResultProxy(self) - return base.ResultProxy(self) - +class ReturningResultProxy(base.FullyBufferedResultProxy): + """Result proxy which stuffs the _returning clause + outparams into the fetch.""" + + def _cursor_description(self): + returning = self.context.compiled.statement._returning + + ret = [] + for c in returning: + if hasattr(c, 'key'): + ret.append((c.key, c.type)) + else: + ret.append((c.anon_label, c.type)) + return ret + + def _buffer_rows(self): + returning = self.context.compiled.statement._returning + return [tuple(self.context.out_parameters["ret_%d" % i] for i, c in enumerate(returning))] class Oracle_cx_oracle(OracleDialect): execution_ctx_cls = Oracle_cx_oracleExecutionContext diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 1aa96e8524..849ec50066 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -263,13 +263,9 @@ class PGCompiler(compiler.SQLCompiler): else: return super(PGCompiler, self).for_update_clause(select) - def _append_returning(self, text, stmt): - try: - returning_cols = stmt.kwargs['postgresql_returning'] - except KeyError: - returning_cols = stmt.kwargs['postgres_returning'] - util.warn_deprecated("The 'postgres_returning' argument has been renamed 'postgresql_returning'") - + def returning_clause(self, stmt): + returning_cols = stmt._returning + def flatten_columnlist(collist): for c in collist: if isinstance(c, expression.Selectable): @@ -277,23 +273,13 @@ class PGCompiler(compiler.SQLCompiler): yield co else: yield c - columns = [self.process(c, within_columns_clause=True) for c in flatten_columnlist(returning_cols)] - text += ' RETURNING ' + ', '.join(columns) - return text - - def visit_update(self, update_stmt): - text = super(PGCompiler, self).visit_update(update_stmt) - if 'postgresql_returning' in update_stmt.kwargs or 'postgres_returning' in update_stmt.kwargs: - return self._append_returning(text, update_stmt) - else: - return text - - def visit_insert(self, insert_stmt): - text = super(PGCompiler, self).visit_insert(insert_stmt) - if 'postgresql_returning' in insert_stmt.kwargs or 'postgres_returning' in insert_stmt.kwargs: - return self._append_returning(text, insert_stmt) - else: - return text + + columns = [ + self.process(c, within_columns_clause=True, result_map=self.result_map) + for c in flatten_columnlist(returning_cols) + ] + + return 'RETURNING ' + ', '.join(columns) def visit_extract(self, extract, **kwargs): field = self.extract_map.get(extract.field, extract.field) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 470ca811b1..feefb88d26 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1064,6 +1064,7 @@ class Connection(Connectable): self._cursor_execute(context.cursor, context.statement, context.parameters[0], context=context) if context.compiled: context.post_exec() + if context.should_autocommit and not self.in_transaction(): self._commit_impl() return context.get_result_proxy() @@ -1586,7 +1587,7 @@ class ResultProxy(object): """ _process_row = RowProxy - + def __init__(self, context): self.context = context self.dialect = context.dialect @@ -1607,14 +1608,22 @@ class ResultProxy(object): @property def out_parameters(self): return self.context.out_parameters - - def _init_metadata(self): + + def _cursor_description(self): metadata = self.cursor.description if metadata is None: - # no results, get rowcount (which requires open cursor on some DB's such as firebird), - # then close + return + else: + return [(r[0], r[1]) for r in metadata] + + def _init_metadata(self): + + metadata = self._cursor_description() + if metadata is None: + # no results, get rowcount + # (which requires open cursor on some DB's such as firebird), self.rowcount - self.close() + self.close() # autoclose return self._props = util.populate_column_dict(None) @@ -1623,8 +1632,7 @@ class ResultProxy(object): typemap = self.dialect.dbapi_type_map - for i, item in enumerate(metadata): - colname = item[0] + for i, (colname, coltype) in enumerate(metadata): if self.dialect.description_encoding: colname = colname.decode(self.dialect.description_encoding) @@ -1640,9 +1648,9 @@ class ResultProxy(object): try: (name, obj, type_) = self.context.result_map[colname.lower()] except KeyError: - (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE)) + (name, obj, type_) = (colname, None, typemap.get(coltype, types.NULLTYPE)) else: - (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE)) + (name, obj, type_) = (colname, None, typemap.get(coltype, types.NULLTYPE)) rec = (type_, type_.dialect_impl(self.dialect).result_processor(self.dialect), i) @@ -1949,8 +1957,44 @@ class BufferedRowResultProxy(ResultProxy): return result def _fetchall_impl(self): - return self.__rowbuffer + list(self.cursor.fetchall()) + ret = self.__rowbuffer + list(self.cursor.fetchall()) + self.__rowbuffer[:] = [] + return ret + +class FullyBufferedResultProxy(ResultProxy): + """A result proxy that buffers rows fully upon creation. + + Used for operations where a result is to be delivered + after the database conversation can not be continued, + such as MSSQL INSERT...OUTPUT after an autocommit. + + """ + def _init_metadata(self): + self.__rowbuffer = self._buffer_rows() + super(FullyBufferedResultProxy, self)._init_metadata() + + def _buffer_rows(self): + return self.cursor.fetchall() + + def _fetchone_impl(self): + if self.__rowbuffer: + return self.__rowbuffer.pop(0) + else: + return None + def _fetchmany_impl(self, size=None): + result = [] + for x in range(0, size): + row = self._fetchone_impl() + if row is None: + break + result.append(row) + return result + + def _fetchall_impl(self): + ret = self.__rowbuffer + self.__rowbuffer = [] + return ret class BufferedColumnResultProxy(ResultProxy): """A ResultProxy with column buffering behavior. diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 5a86d7c94b..3f90baa5c2 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -187,7 +187,7 @@ class DefaultExecutionContext(base.ExecutionContext): self.statement = unicode(compiled).encode(self.dialect.encoding) else: self.statement = unicode(compiled) - self.isinsert = self.isupdate = self.executemany = False + self.isinsert = self.isupdate = self.isdelete = self.executemany = False self.should_autocommit = True self.result_map = None self.cursor = self.create_cursor() @@ -221,6 +221,7 @@ class DefaultExecutionContext(base.ExecutionContext): self.isinsert = compiled.isinsert self.isupdate = compiled.isupdate + self.isdelete = compiled.isdelete self.should_autocommit = compiled.statement._autocommit if isinstance(compiled.statement, expression._TextClause): self.should_autocommit = self.should_autocommit or self.should_autocommit_text(self.statement) @@ -246,13 +247,13 @@ class DefaultExecutionContext(base.ExecutionContext): self.statement = statement.encode(self.dialect.encoding) else: self.statement = statement - self.isinsert = self.isupdate = False + self.isinsert = self.isupdate = self.isdelete = False self.cursor = self.create_cursor() self.should_autocommit = self.should_autocommit_text(statement) else: # no statement. used for standalone ColumnDefault execution. self.statement = self.compiled = None - self.isinsert = self.isupdate = self.executemany = self.should_autocommit = False + self.isinsert = self.isupdate = self.isdelete = self.executemany = self.should_autocommit = False self.cursor = self.create_cursor() @property diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 30bcc45e59..b862c8c811 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -672,43 +672,73 @@ class SQLCompiler(engine.Compiled): def visit_insert(self, insert_stmt): self.isinsert = True colparams = self._get_colparams(insert_stmt) - preparer = self.preparer - - insert = ' '.join(["INSERT"] + - [self.process(x) for x in insert_stmt._prefixes]) if not colparams and \ not self.dialect.supports_default_values and \ not self.dialect.supports_empty_insert: raise exc.CompileError( "The version of %s you are using does not support empty inserts." % self.dialect.name) - elif not colparams and self.dialect.supports_default_values: - return (insert + " INTO %s DEFAULT VALUES" % ( - (preparer.format_table(insert_stmt.table),))) - else: - return (insert + " INTO %s (%s) VALUES (%s)" % - (preparer.format_table(insert_stmt.table), - ', '.join([preparer.format_column(c[0]) - for c in colparams]), - ', '.join([c[1] for c in colparams]))) + preparer = self.preparer + supports_default_values = self.dialect.supports_default_values + + text = "INSERT" + + prefixes = [self.process(x) for x in insert_stmt._prefixes] + if prefixes: + text += " " + " ".join(prefixes) + + text += " INTO " + preparer.format_table(insert_stmt.table) + + if not colparams and supports_default_values: + text += " DEFAULT VALUES" + else: + text += " (%s)" % ', '.join([preparer.format_column(c[0]) + for c in colparams]) + + if insert_stmt._returning: + returning_clause = self.returning_clause(insert_stmt) + + # cheating + if returning_clause.startswith("OUTPUT"): + text += " " + returning_clause + returning_clause = None + + if colparams or not supports_default_values: + text += " VALUES (%s)" % \ + ', '.join([c[1] for c in colparams]) + + if insert_stmt._returning and returning_clause: + text += " " + returning_clause + + return text + def visit_update(self, update_stmt): self.stack.append({'from': set([update_stmt.table])}) self.isupdate = True colparams = self._get_colparams(update_stmt) - text = ' '.join(( - "UPDATE", - self.preparer.format_table(update_stmt.table), - 'SET', - ', '.join(self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1] - for c in colparams) - )) + text = "UPDATE " + self.preparer.format_table(update_stmt.table) + + text += ' SET ' + \ + ', '.join( + self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1] + for c in colparams + ) + if update_stmt._returning: + returning_clause = self.returning_clause(update_stmt) + if returning_clause.startswith("OUTPUT"): + text += " " + returning_clause + returning_clause = None + if update_stmt._whereclause: text += " WHERE " + self.process(update_stmt._whereclause) + if update_stmt._returning and returning_clause: + text += " " + returning_clause + self.stack.pop(-1) return text @@ -804,9 +834,18 @@ class SQLCompiler(engine.Compiled): text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table) + if delete_stmt._returning: + returning_clause = self.returning_clause(delete_stmt) + if returning_clause.startswith("OUTPUT"): + text += " " + returning_clause + returning_clause = None + if delete_stmt._whereclause: text += " WHERE " + self.process(delete_stmt._whereclause) + if delete_stmt._returning and returning_clause: + text += " " + returning_clause + self.stack.pop(-1) return text diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index fd144a2101..142cdcbe5a 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -3743,7 +3743,7 @@ class _UpdateBase(ClauseElement): supports_execution = True _autocommit = True - + def _generate(self): s = self.__class__.__new__(self.__class__) s.__dict__ = self.__dict__.copy() @@ -3771,6 +3771,51 @@ class _UpdateBase(ClauseElement): self._bind = bind bind = property(bind, _set_bind) + _returning_re = re.compile(r'(?:firebird|postgres(?:ql)?)_returning') + def _process_deprecated_kw(self, kwargs): + for k in list(kwargs): + m = self._returning_re.match(k) + if m: + self._returning = kwargs.pop(k) + util.warn_deprecated( + "The %r argument is deprecated. Please use statement.returning(col1, col2, ...)" % k + ) + return kwargs + + @_generative + def returning(self, *cols): + """Add a RETURNING or equivalent clause to this statement. + + The given list of columns represent columns within the table + that is the target of the INSERT, UPDATE, or DELETE. Each + element can be any column expression. ``Table`` objects + will be expanded into their individual columns. + + Upon compilation, a RETURNING clause, or database equivalent, + will be rendered within the statement. For INSERT and UPDATE, + the values are the newly inserted/updated values. For DELETE, + the values are those of the rows which were deleted. + + Upon execution, the values of the columns to be returned + are made available via the result set and can be iterated + using ``fetchone()`` and similar. For DBAPIs which do not + natively support returning values (i.e. cx_oracle), + SQLAlchemy will approximate this behavior at the result level + so that a reasonable amount of behavioral neutrality is + provided. + + Note that not all databases/DBAPIs + support RETURNING. For those backends with no support, + an exception is raised upon compilation and/or execution. + For those who do support it, the functionality across backends + varies greatly, including restrictions on executemany() + and other statements which return multiple rows. Please + read the documentation notes for the database in use in + order to determine the availability of RETURNING. + + """ + self._returning = cols + class _ValuesBase(_UpdateBase): __visit_name__ = 'values_base' @@ -3819,16 +3864,19 @@ class Insert(_ValuesBase): inline=False, bind=None, prefixes=None, + returning=None, **kwargs): _ValuesBase.__init__(self, table, values) self._bind = bind self.select = None self.inline = inline + self._returning = returning if prefixes: self._prefixes = [_literal_as_text(p) for p in prefixes] else: self._prefixes = [] - self.kwargs = kwargs + + self.kwargs = self._process_deprecated_kw(kwargs) def get_children(self, **kwargs): if self.select is not None: @@ -3865,15 +3913,18 @@ class Update(_ValuesBase): values=None, inline=False, bind=None, + returning=None, **kwargs): _ValuesBase.__init__(self, table, values) self._bind = bind + self._returning = returning if whereclause: self._whereclause = _literal_as_text(whereclause) else: self._whereclause = None self.inline = inline - self.kwargs = kwargs + + self.kwargs = self._process_deprecated_kw(kwargs) def get_children(self, **kwargs): if self._whereclause is not None: @@ -3907,15 +3958,22 @@ class Delete(_UpdateBase): __visit_name__ = 'delete' - def __init__(self, table, whereclause, bind=None, **kwargs): + def __init__(self, + table, + whereclause, + bind=None, + returning =None, + **kwargs): self._bind = bind self.table = table + self._returning = returning + if whereclause: self._whereclause = _literal_as_text(whereclause) else: self._whereclause = None - self.kwargs = kwargs + self.kwargs = self._process_deprecated_kw(kwargs) def get_children(self, **kwargs): if self._whereclause is not None: diff --git a/test/dialect/test_firebird.py b/test/dialect/test_firebird.py index 017306691c..0c19a4c7e7 100644 --- a/test/dialect/test_firebird.py +++ b/test/dialect/test_firebird.py @@ -105,14 +105,14 @@ class CompileTest(TestBase, AssertsCompiledSQL): column('description', String(128)), ) - u = update(table1, values=dict(name='foo'), firebird_returning=[table1.c.myid, table1.c.name]) + u = update(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name) self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING mytable.myid, mytable.name") - u = update(table1, values=dict(name='foo'), firebird_returning=[table1]) + u = update(table1, values=dict(name='foo')).returning(table1) self.assert_compile(u, "UPDATE mytable SET name=:name "\ "RETURNING mytable.myid, mytable.name, mytable.description") - u = update(table1, values=dict(name='foo'), firebird_returning=[func.length(table1.c.name)]) + u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING char_length(mytable.name)") def test_insert_returning(self): @@ -122,87 +122,17 @@ class CompileTest(TestBase, AssertsCompiledSQL): column('description', String(128)), ) - i = insert(table1, values=dict(name='foo'), firebird_returning=[table1.c.myid, table1.c.name]) + i = insert(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name) self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING mytable.myid, mytable.name") - i = insert(table1, values=dict(name='foo'), firebird_returning=[table1]) + i = insert(table1, values=dict(name='foo')).returning(table1) self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) "\ "RETURNING mytable.myid, mytable.name, mytable.description") - i = insert(table1, values=dict(name='foo'), firebird_returning=[func.length(table1.c.name)]) + i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING char_length(mytable.name)") -class ReturningTest(TestBase, AssertsExecutionResults): - __only_on__ = 'firebird' - - @testing.exclude('firebird', '<', (2, 1), '2.1+ feature') - def test_update_returning(self): - meta = MetaData(testing.db) - table = Table('tables', meta, - Column('id', Integer, Sequence('gen_tables_id'), primary_key=True), - Column('persons', Integer), - Column('full', Boolean) - ) - table.create() - try: - table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) - - result = table.update(table.c.persons > 4, dict(full=True), firebird_returning=[table.c.id]).execute() - eq_(result.fetchall(), [(1,)]) - - result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() - eq_(result2.fetchall(), [(1,True),(2,False)]) - finally: - table.drop() - - @testing.exclude('firebird', '<', (2, 0), '2.0+ feature') - def test_insert_returning(self): - meta = MetaData(testing.db) - table = Table('tables', meta, - Column('id', Integer, Sequence('gen_tables_id'), primary_key=True), - Column('persons', Integer), - Column('full', Boolean) - ) - table.create() - try: - result = table.insert(firebird_returning=[table.c.id]).execute({'persons': 1, 'full': False}) - - eq_(result.fetchall(), [(1,)]) - - # Multiple inserts only return the last row - result2 = table.insert(firebird_returning=[table]).execute( - [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}]) - - eq_(result2.fetchall(), [(3,3,True)]) - - result3 = table.insert(firebird_returning=[table.c.id]).execute({'persons': 4, 'full': False}) - eq_([dict(row) for row in result3], [{'id': 4}]) - - result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, 1) returning persons') - eq_([dict(row) for row in result4], [{'persons': 10}]) - finally: - table.drop() - - @testing.exclude('firebird', '<', (2, 1), '2.1+ feature') - def test_delete_returning(self): - meta = MetaData(testing.db) - table = Table('tables', meta, - Column('id', Integer, Sequence('gen_tables_id'), primary_key=True), - Column('persons', Integer), - Column('full', Boolean) - ) - table.create() - try: - table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) - - result = table.delete(table.c.persons > 4, firebird_returning=[table.c.id]).execute() - eq_(result.fetchall(), [(1,)]) - - result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() - eq_(result2.fetchall(), [(2,False),]) - finally: - table.drop() class MiscTest(TestBase): diff --git a/test/dialect/test_mssql.py b/test/dialect/test_mssql.py index 5bb42d805b..f76e1c9fb8 100644 --- a/test/dialect/test_mssql.py +++ b/test/dialect/test_mssql.py @@ -158,6 +158,45 @@ class CompileTest(TestBase, AssertsCompiledSQL): select([extract(field, t.c.col1)]), 'SELECT DATEPART("%s", t.col1) AS anon_1 FROM t' % field) + def test_update_returning(self): + table1 = table('mytable', + column('myid', Integer), + column('name', String(128)), + column('description', String(128)), + ) + + u = update(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name) + self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT inserted.myid, inserted.name") + + u = update(table1, values=dict(name='foo')).returning(table1) + self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT inserted.myid, " + "inserted.name, inserted.description") + + u = update(table1, values=dict(name='foo')).returning(table1).where(table1.c.name=='bar') + self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT inserted.myid, " + "inserted.name, inserted.description WHERE mytable.name = :name_1") + + u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) + self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT LEN(inserted.name)") + + def test_insert_returning(self): + table1 = table('mytable', + column('myid', Integer), + column('name', String(128)), + column('description', String(128)), + ) + + i = insert(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name) + self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT inserted.myid, inserted.name VALUES (:name)") + + i = insert(table1, values=dict(name='foo')).returning(table1) + self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT inserted.myid, " + "inserted.name, inserted.description VALUES (:name)") + + i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) + self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT LEN(inserted.name) VALUES (:name)") + + class IdentityInsertTest(TestBase, AssertsCompiledSQL): __only_on__ = 'mssql' diff --git a/test/dialect/test_postgresql.py b/test/dialect/test_postgresql.py index 19364942ec..2b9a687ebf 100644 --- a/test/dialect/test_postgresql.py +++ b/test/dialect/test_postgresql.py @@ -32,14 +32,14 @@ class CompileTest(TestBase, AssertsCompiledSQL): column('description', String(128)), ) - u = update(table1, values=dict(name='foo'), postgresql_returning=[table1.c.myid, table1.c.name]) + u = update(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name) self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect) - u = update(table1, values=dict(name='foo'), postgresql_returning=[table1]) + u = update(table1, values=dict(name='foo')).returning(table1) self.assert_compile(u, "UPDATE mytable SET name=%(name)s "\ "RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect) - u = update(table1, values=dict(name='foo'), postgresql_returning=[func.length(table1.c.name)]) + u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING length(mytable.name)", dialect=dialect) @@ -51,17 +51,17 @@ class CompileTest(TestBase, AssertsCompiledSQL): column('description', String(128)), ) - i = insert(table1, values=dict(name='foo'), postgresql_returning=[table1.c.myid, table1.c.name]) + i = insert(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name) self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING mytable.myid, mytable.name", dialect=dialect) - i = insert(table1, values=dict(name='foo'), postgresql_returning=[table1]) + i = insert(table1, values=dict(name='foo')).returning(table1) self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) "\ "RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect) - i = insert(table1, values=dict(name='foo'), postgresql_returning=[func.length(table1.c.name)]) + i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name)", dialect=dialect) - @testing.uses_deprecated(r".*'postgres_returning' argument has been renamed.*") + @testing.uses_deprecated(r".*argument is deprecated. Please use statement.returning.*") def test_old_returning_names(self): dialect = postgresql.dialect() table1 = table('mytable', @@ -73,6 +73,9 @@ class CompileTest(TestBase, AssertsCompiledSQL): u = update(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name]) self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect) + u = update(table1, values=dict(name='foo'), postgresql_returning=[table1.c.myid, table1.c.name]) + self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect) + i = insert(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name]) self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING mytable.myid, mytable.name", dialect=dialect) @@ -100,60 +103,6 @@ class CompileTest(TestBase, AssertsCompiledSQL): "SELECT EXTRACT(%s FROM t.col1::timestamp) AS anon_1 " "FROM t" % field) -class ReturningTest(TestBase, AssertsExecutionResults): - __only_on__ = 'postgresql' - - @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') - def test_update_returning(self): - meta = MetaData(testing.db) - table = Table('tables', meta, - Column('id', Integer, primary_key=True), - Column('persons', Integer), - Column('full', Boolean) - ) - table.create() - try: - table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) - - result = table.update(table.c.persons > 4, dict(full=True), postgresql_returning=[table.c.id]).execute() - eq_(result.fetchall(), [(1,)]) - - result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() - eq_(result2.fetchall(), [(1,True),(2,False)]) - finally: - table.drop() - - @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') - def test_insert_returning(self): - meta = MetaData(testing.db) - table = Table('tables', meta, - Column('id', Integer, primary_key=True), - Column('persons', Integer), - Column('full', Boolean) - ) - table.create() - try: - result = table.insert(postgresql_returning=[table.c.id]).execute({'persons': 1, 'full': False}) - - eq_(result.fetchall(), [(1,)]) - - @testing.fails_on('postgresql', 'Known limitation of psycopg2') - def test_executemany(): - # return value is documented as failing with psycopg2/executemany - result2 = table.insert(postgresql_returning=[table]).execute( - [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}]) - eq_(result2.fetchall(), [(2, 2, False), (3,3,True)]) - - test_executemany() - - result3 = table.insert(postgresql_returning=[(table.c.id*2).label('double_id')]).execute({'persons': 4, 'full': False}) - eq_([dict(row) for row in result3], [{'double_id':8}]) - - result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, true) returning persons') - eq_([dict(row) for row in result4], [{'persons': 10}]) - finally: - table.drop() - class InsertTest(TestBase, AssertsExecutionResults): __only_on__ = 'postgresql' diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py new file mode 100644 index 0000000000..ead61cd418 --- /dev/null +++ b/test/sql/test_returning.py @@ -0,0 +1,136 @@ +from sqlalchemy.test.testing import eq_ +from sqlalchemy import * +from sqlalchemy.test import * +from sqlalchemy.test.schema import Table, Column +from sqlalchemy.types import TypeDecorator + +class ReturningTest(TestBase, AssertsExecutionResults): + __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access') + + def setup(self): + meta = MetaData(testing.db) + global table, GoofyType + + class GoofyType(TypeDecorator): + impl = String + + def process_bind_param(self, value, dialect): + if value is None: + return None + return "FOO" + value + + def process_result_value(self, value, dialect): + if value is None: + return None + return value + "BAR" + + table = Table('tables', meta, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('persons', Integer), + Column('full', Boolean), + Column('goofy', GoofyType(50)) + ) + table.create() + + def teardown(self): + table.drop() + + @testing.exclude('firebird', '<', (2, 0), '2.0+ feature') + @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') + def test_column_targeting(self): + result = table.insert().returning(table.c.id, table.c.full).execute({'persons': 1, 'full': False}) + + row = result.first() + assert row[table.c.id] == row['id'] == 1 + assert row[table.c.full] == row['full'] == False + + result = table.insert().values(persons=5, full=True, goofy="somegoofy").\ + returning(table.c.persons, table.c.full, table.c.goofy).execute() + row = result.first() + assert row[table.c.persons] == row['persons'] == 5 + assert row[table.c.full] == row['full'] == True + assert row[table.c.goofy] == row['goofy'] == "FOOsomegoofyBAR" + + @testing.fails_on('firebird', "fb can't handle returning x AS y") + @testing.exclude('firebird', '<', (2, 0), '2.0+ feature') + @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') + def test_labeling(self): + result = table.insert().values(persons=6).\ + returning(table.c.persons.label('lala')).execute() + row = result.first() + assert row['lala'] == 6 + + @testing.fails_on('firebird', "fb/kintersbasdb can't handle the bind params") + @testing.exclude('firebird', '<', (2, 0), '2.0+ feature') + @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') + def test_anon_expressions(self): + result = table.insert().values(goofy="someOTHERgoofy").\ + returning(func.lower(table.c.goofy, type_=GoofyType)).execute() + row = result.first() + assert row[0] == "foosomeothergoofyBAR" + + result = table.insert().values(persons=12).\ + returning(table.c.persons + 18).execute() + row = result.first() + assert row[0] == 30 + + @testing.exclude('firebird', '<', (2, 1), '2.1+ feature') + @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') + def test_update_returning(self): + table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) + + result = table.update(table.c.persons > 4, dict(full=True)).returning(table.c.id).execute() + eq_(result.fetchall(), [(1,)]) + + result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() + eq_(result2.fetchall(), [(1,True),(2,False)]) + + @testing.exclude('firebird', '<', (2, 0), '2.0+ feature') + @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') + def test_insert_returning(self): + result = table.insert().returning(table.c.id).execute({'persons': 1, 'full': False}) + + eq_(result.fetchall(), [(1,)]) + + @testing.fails_on('postgresql', '') + @testing.fails_on('oracle', '') + def test_executemany(): + # return value is documented as failing with psycopg2/executemany + result2 = table.insert().returning(table).execute( + [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}]) + + if testing.against('firebird', 'mssql'): + # Multiple inserts only return the last row + eq_(result2.fetchall(), [(3,3,True, None)]) + else: + # nobody does this as far as we know (pg8000?) + eq_(result2.fetchall(), [(2, 2, False, None), (3,3,True, None)]) + + test_executemany() + + result3 = table.insert().returning(table.c.id).execute({'persons': 4, 'full': False}) + eq_([dict(row) for row in result3], [{'id': 4}]) + + @testing.exclude('firebird', '<', (2, 1), '2.1+ feature') + @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') + @testing.fails_on_everything_except('postgresql', 'firebird') + def test_literal_returning(self): + if testing.against("postgresql"): + literal_true = "true" + else: + literal_true = "1" + + result4 = testing.db.execute('insert into tables (id, persons, "full") ' + 'values (5, 10, %s) returning persons' % literal_true) + eq_([dict(row) for row in result4], [{'persons': 10}]) + + @testing.exclude('firebird', '<', (2, 1), '2.1+ feature') + @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') + def test_delete_returning(self): + table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) + + result = table.delete(table.c.persons > 4).returning(table.c.id).execute() + eq_(result.fetchall(), [(1,)]) + + result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() + eq_(result2.fetchall(), [(2,False),])