on the structure of criteria, so success/failure is deterministic based on
code structure.
+- sql
+ - returning() support is native to insert(), update(), delete(). Implementations
+ of varying levels of functionality exist for Postgresql, Firebird, MSSQL and
+ Oracle.
+
- engines
- transaction isolation level may be specified with
create_engine(... isolation_level="..."); available on
visit_char_length_func = visit_length_func
- def function_argspec(self, func):
+ def function_argspec(self, func, **kw):
if func.clauses:
return self.process(func.clause_expr)
else:
return ""
- def _append_returning(self, text, stmt):
- returning_cols = stmt.kwargs["firebird_returning"]
+ def returning_clause(self, stmt):
+ returning_cols = stmt._returning
+
def flatten_columnlist(collist):
for c in collist:
- if isinstance(c, sql.expression.Selectable):
+ if isinstance(c, expression.Selectable):
for co in c.columns:
yield co
else:
yield c
- columns = [self.process(c, within_columns_clause=True)
- for c in flatten_columnlist(returning_cols)]
- text += ' RETURNING ' + ', '.join(columns)
- return text
-
- def visit_update(self, update_stmt):
- text = super(FBCompiler, self).visit_update(update_stmt)
- if "firebird_returning" in update_stmt.kwargs:
- return self._append_returning(text, update_stmt)
- else:
- return text
- def visit_insert(self, insert_stmt):
- text = super(FBCompiler, self).visit_insert(insert_stmt)
- if "firebird_returning" in insert_stmt.kwargs:
- return self._append_returning(text, insert_stmt)
- else:
- return text
-
- def visit_delete(self, delete_stmt):
- text = super(FBCompiler, self).visit_delete(delete_stmt)
- if "firebird_returning" in delete_stmt.kwargs:
- return self._append_returning(text, delete_stmt)
- else:
- return text
+ columns = [
+ self.process(c, within_columns_clause=True, result_map=self.result_map)
+ for c in flatten_columnlist(returning_cols)
+ ]
+ return 'RETURNING ' + ', '.join(columns)
class FBDDLCompiler(sql.compiler.DDLCompiler):
import datetime, decimal, inspect, operator, sys, re
from sqlalchemy import sql, schema as sa_schema, exc, util
-from sqlalchemy.sql import select, compiler, expression, operators as sql_operators, functions as sql_functions
+from sqlalchemy.sql import select, compiler, expression, \
+ operators as sql_operators, \
+ functions as sql_functions, util as sql_util
from sqlalchemy.engine import default, base, reflection
from sqlalchemy import types as sqltypes
from decimal import Decimal as _python_Decimal
class MSExecutionContext(default.DefaultExecutionContext):
_enable_identity_insert = False
_select_lastrowid = False
+ _result_proxy = None
def pre_exec(self):
"""Activate IDENTITY_INSERT if needed."""
self._enable_identity_insert = False
self._select_lastrowid = insert_has_sequence and \
+ not self.compiled.statement.returning and \
not self._enable_identity_insert and \
not self.executemany
if self._enable_identity_insert:
self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
+ if (self.isinsert or self.isupdate or self.isdelete) and \
+ self.compiled.statement._returning:
+ self._result_proxy = base.FullyBufferedResultProxy(self)
+
def handle_dbapi_exception(self, e):
if self._enable_identity_insert:
try:
except:
pass
+ def get_result_proxy(self):
+ return self._result_proxy or base.ResultProxy(self)
class MSSQLCompiler(compiler.SQLCompiler):
return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs)
return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
- def visit_insert(self, insert_stmt):
+ def dont_visit_insert(self, insert_stmt):
insert_select = False
if insert_stmt.parameters:
insert_select = [p for p in insert_stmt.parameters.values() if isinstance(p, sql.Select)]
else:
return super(MSSQLCompiler, self).visit_insert(insert_stmt)
+ def returning_clause(self, stmt):
+ returning_cols = stmt._returning
+
+ def flatten_columnlist(collist):
+ for c in collist:
+ if isinstance(c, expression.Selectable):
+ for co in c.columns:
+ yield co
+ else:
+ yield c
+
+ if self.isinsert or self.isupdate:
+ target = stmt.table.alias("inserted")
+ else:
+ target = stmt.table.alias("deleted")
+
+ adapter = sql_util.ClauseAdapter(target)
+ columns = [
+ self.process(adapter.traverse(c), within_columns_clause=True, result_map=self.result_map)
+ for c in flatten_columnlist(returning_cols)
+ ]
+
+ return 'OUTPUT ' + ', '.join(columns)
+
def label_select_column(self, select, column, asfrom):
if isinstance(column, expression.Function):
return column.label(None)
__visit_name__ = 'LONG'
class _OracleBoolean(sqltypes.Boolean):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NUMBER
+
def result_processor(self, dialect):
def process(value):
if value is None:
else:
return self.process(alias.original, **kwargs)
+ def returning_clause(self, stmt):
+ returning_cols = stmt._returning
+
+ def flatten_columnlist(collist):
+ for c in collist:
+ if isinstance(c, expression.Selectable):
+ for co in c.columns:
+ yield co
+ else:
+ yield c
+
+ def create_out_param(col, i):
+ bindparam = sql.outparam("ret_%d" % i, type_=col.type)
+ self.binds[bindparam.key] = bindparam
+ return self.bindparam_string(self._truncate_bindparam(bindparam))
+
+ # within_columns_clause =False so that labels (foo AS bar) don't render
+ columns = [self.process(c, within_columns_clause=False) for c in flatten_columnlist(returning_cols)]
+
+ binds = [create_out_param(c, i) for i, c in enumerate(flatten_columnlist(returning_cols))]
+
+ return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds)
+
def _TODO_visit_compound_select(self, select):
"""Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
pass
class OracleDefaultRunner(base.DefaultRunner):
def visit_sequence(self, seq):
- return self.execute_string("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL", {})
+ return self.execute_string("SELECT " +
+ self.dialect.identifier_preparer.format_sequence(seq) +
+ ".nextval FROM DUAL", {})
class OracleIdentifierPreparer(compiler.IdentifierPreparer):
type_code = column[1]
if type_code in self.dialect.ORACLE_BINARY_TYPES:
return base.BufferedColumnResultProxy(self)
+
+ if hasattr(self, 'out_parameters') and \
+ (self.isinsert or self.isupdate or self.isdelete) and \
+ self.compiled.statement._returning:
+
+ return ReturningResultProxy(self)
+ else:
+ return base.ResultProxy(self)
- return base.ResultProxy(self)
-
+class ReturningResultProxy(base.FullyBufferedResultProxy):
+ """Result proxy which stuffs the _returning clause + outparams into the fetch."""
+
+ def _cursor_description(self):
+ returning = self.context.compiled.statement._returning
+
+ ret = []
+ for c in returning:
+ if hasattr(c, 'key'):
+ ret.append((c.key, c.type))
+ else:
+ ret.append((c.anon_label, c.type))
+ return ret
+
+ def _buffer_rows(self):
+ returning = self.context.compiled.statement._returning
+ return [tuple(self.context.out_parameters["ret_%d" % i] for i, c in enumerate(returning))]
class Oracle_cx_oracle(OracleDialect):
execution_ctx_cls = Oracle_cx_oracleExecutionContext
else:
return super(PGCompiler, self).for_update_clause(select)
- def _append_returning(self, text, stmt):
- try:
- returning_cols = stmt.kwargs['postgresql_returning']
- except KeyError:
- returning_cols = stmt.kwargs['postgres_returning']
- util.warn_deprecated("The 'postgres_returning' argument has been renamed 'postgresql_returning'")
-
+ def returning_clause(self, stmt):
+ returning_cols = stmt._returning
+
def flatten_columnlist(collist):
for c in collist:
if isinstance(c, expression.Selectable):
yield co
else:
yield c
- columns = [self.process(c, within_columns_clause=True) for c in flatten_columnlist(returning_cols)]
- text += ' RETURNING ' + ', '.join(columns)
- return text
-
- def visit_update(self, update_stmt):
- text = super(PGCompiler, self).visit_update(update_stmt)
- if 'postgresql_returning' in update_stmt.kwargs or 'postgres_returning' in update_stmt.kwargs:
- return self._append_returning(text, update_stmt)
- else:
- return text
-
- def visit_insert(self, insert_stmt):
- text = super(PGCompiler, self).visit_insert(insert_stmt)
- if 'postgresql_returning' in insert_stmt.kwargs or 'postgres_returning' in insert_stmt.kwargs:
- return self._append_returning(text, insert_stmt)
- else:
- return text
+
+ columns = [
+ self.process(c, within_columns_clause=True, result_map=self.result_map)
+ for c in flatten_columnlist(returning_cols)
+ ]
+
+ return 'RETURNING ' + ', '.join(columns)
def visit_extract(self, extract, **kwargs):
field = self.extract_map.get(extract.field, extract.field)
self._cursor_execute(context.cursor, context.statement, context.parameters[0], context=context)
if context.compiled:
context.post_exec()
+
if context.should_autocommit and not self.in_transaction():
self._commit_impl()
return context.get_result_proxy()
"""
_process_row = RowProxy
-
+
def __init__(self, context):
self.context = context
self.dialect = context.dialect
@property
def out_parameters(self):
return self.context.out_parameters
-
- def _init_metadata(self):
+
+ def _cursor_description(self):
metadata = self.cursor.description
if metadata is None:
- # no results, get rowcount (which requires open cursor on some DB's such as firebird),
- # then close
+ return
+ else:
+ return [(r[0], r[1]) for r in metadata]
+
+ def _init_metadata(self):
+
+ metadata = self._cursor_description()
+ if metadata is None:
+ # no results, get rowcount
+ # (which requires open cursor on some DB's such as firebird),
self.rowcount
- self.close()
+ self.close() # autoclose
return
self._props = util.populate_column_dict(None)
typemap = self.dialect.dbapi_type_map
- for i, item in enumerate(metadata):
- colname = item[0]
+ for i, (colname, coltype) in enumerate(metadata):
if self.dialect.description_encoding:
colname = colname.decode(self.dialect.description_encoding)
try:
(name, obj, type_) = self.context.result_map[colname.lower()]
except KeyError:
- (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE))
+ (name, obj, type_) = (colname, None, typemap.get(coltype, types.NULLTYPE))
else:
- (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE))
+ (name, obj, type_) = (colname, None, typemap.get(coltype, types.NULLTYPE))
rec = (type_, type_.dialect_impl(self.dialect).result_processor(self.dialect), i)
return result
def _fetchall_impl(self):
- return self.__rowbuffer + list(self.cursor.fetchall())
+ ret = self.__rowbuffer + list(self.cursor.fetchall())
+ self.__rowbuffer[:] = []
+ return ret
+
+class FullyBufferedResultProxy(ResultProxy):
+ """A result proxy that buffers rows fully upon creation.
+
+ Used for operations where a result is to be delivered
+ after the database conversation can not be continued,
+ such as MSSQL INSERT...OUTPUT after an autocommit.
+
+ """
+ def _init_metadata(self):
+ self.__rowbuffer = self._buffer_rows()
+ super(FullyBufferedResultProxy, self)._init_metadata()
+
+ def _buffer_rows(self):
+ return self.cursor.fetchall()
+
+ def _fetchone_impl(self):
+ if self.__rowbuffer:
+ return self.__rowbuffer.pop(0)
+ else:
+ return None
+ def _fetchmany_impl(self, size=None):
+ result = []
+ for x in range(0, size):
+ row = self._fetchone_impl()
+ if row is None:
+ break
+ result.append(row)
+ return result
+
+ def _fetchall_impl(self):
+ ret = self.__rowbuffer
+ self.__rowbuffer = []
+ return ret
class BufferedColumnResultProxy(ResultProxy):
"""A ResultProxy with column buffering behavior.
self.statement = unicode(compiled).encode(self.dialect.encoding)
else:
self.statement = unicode(compiled)
- self.isinsert = self.isupdate = self.executemany = False
+ self.isinsert = self.isupdate = self.isdelete = self.executemany = False
self.should_autocommit = True
self.result_map = None
self.cursor = self.create_cursor()
self.isinsert = compiled.isinsert
self.isupdate = compiled.isupdate
+ self.isdelete = compiled.isdelete
self.should_autocommit = compiled.statement._autocommit
if isinstance(compiled.statement, expression._TextClause):
self.should_autocommit = self.should_autocommit or self.should_autocommit_text(self.statement)
self.statement = statement.encode(self.dialect.encoding)
else:
self.statement = statement
- self.isinsert = self.isupdate = False
+ self.isinsert = self.isupdate = self.isdelete = False
self.cursor = self.create_cursor()
self.should_autocommit = self.should_autocommit_text(statement)
else:
# no statement. used for standalone ColumnDefault execution.
self.statement = self.compiled = None
- self.isinsert = self.isupdate = self.executemany = self.should_autocommit = False
+ self.isinsert = self.isupdate = self.isdelete = self.executemany = self.should_autocommit = False
self.cursor = self.create_cursor()
@property
def visit_insert(self, insert_stmt):
self.isinsert = True
colparams = self._get_colparams(insert_stmt)
- preparer = self.preparer
-
- insert = ' '.join(["INSERT"] +
- [self.process(x) for x in insert_stmt._prefixes])
if not colparams and \
not self.dialect.supports_default_values and \
not self.dialect.supports_empty_insert:
raise exc.CompileError(
"The version of %s you are using does not support empty inserts." % self.dialect.name)
- elif not colparams and self.dialect.supports_default_values:
- return (insert + " INTO %s DEFAULT VALUES" % (
- (preparer.format_table(insert_stmt.table),)))
- else:
- return (insert + " INTO %s (%s) VALUES (%s)" %
- (preparer.format_table(insert_stmt.table),
- ', '.join([preparer.format_column(c[0])
- for c in colparams]),
- ', '.join([c[1] for c in colparams])))
+ preparer = self.preparer
+ supports_default_values = self.dialect.supports_default_values
+
+ text = "INSERT"
+
+ prefixes = [self.process(x) for x in insert_stmt._prefixes]
+ if prefixes:
+ text += " " + " ".join(prefixes)
+
+ text += " INTO " + preparer.format_table(insert_stmt.table)
+
+ if not colparams and supports_default_values:
+ text += " DEFAULT VALUES"
+ else:
+ text += " (%s)" % ', '.join([preparer.format_column(c[0])
+ for c in colparams])
+
+ if insert_stmt._returning:
+ returning_clause = self.returning_clause(insert_stmt)
+
+ # cheating
+ if returning_clause.startswith("OUTPUT"):
+ text += " " + returning_clause
+ returning_clause = None
+
+ if colparams or not supports_default_values:
+ text += " VALUES (%s)" % \
+ ', '.join([c[1] for c in colparams])
+
+ if insert_stmt._returning and returning_clause:
+ text += " " + returning_clause
+
+ return text
+
def visit_update(self, update_stmt):
self.stack.append({'from': set([update_stmt.table])})
self.isupdate = True
colparams = self._get_colparams(update_stmt)
- text = ' '.join((
- "UPDATE",
- self.preparer.format_table(update_stmt.table),
- 'SET',
- ', '.join(self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1]
- for c in colparams)
- ))
+ text = "UPDATE " + self.preparer.format_table(update_stmt.table)
+
+ text += ' SET ' + \
+ ', '.join(
+ self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1]
+ for c in colparams
+ )
+ if update_stmt._returning:
+ returning_clause = self.returning_clause(update_stmt)
+ if returning_clause.startswith("OUTPUT"):
+ text += " " + returning_clause
+ returning_clause = None
+
if update_stmt._whereclause:
text += " WHERE " + self.process(update_stmt._whereclause)
+ if update_stmt._returning and returning_clause:
+ text += " " + returning_clause
+
self.stack.pop(-1)
return text
text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
+ if delete_stmt._returning:
+ returning_clause = self.returning_clause(delete_stmt)
+ if returning_clause.startswith("OUTPUT"):
+ text += " " + returning_clause
+ returning_clause = None
+
if delete_stmt._whereclause:
text += " WHERE " + self.process(delete_stmt._whereclause)
+ if delete_stmt._returning and returning_clause:
+ text += " " + returning_clause
+
self.stack.pop(-1)
return text
supports_execution = True
_autocommit = True
-
+
def _generate(self):
s = self.__class__.__new__(self.__class__)
s.__dict__ = self.__dict__.copy()
self._bind = bind
bind = property(bind, _set_bind)
+ _returning_re = re.compile(r'(?:firebird|postgres(?:ql)?)_returning')
+ def _process_deprecated_kw(self, kwargs):
+ for k in list(kwargs):
+ m = self._returning_re.match(k)
+ if m:
+ self._returning = kwargs.pop(k)
+ util.warn_deprecated(
+ "The %r argument is deprecated. Please use statement.returning(col1, col2, ...)" % k
+ )
+ return kwargs
+
+ @_generative
+ def returning(self, *cols):
+ """Add a RETURNING or equivalent clause to this statement.
+
+ The given list of columns represent columns within the table
+ that is the target of the INSERT, UPDATE, or DELETE. Each
+ element can be any column expression. ``Table`` objects
+ will be expanded into their individual columns.
+
+ Upon compilation, a RETURNING clause, or database equivalent,
+ will be rendered within the statement. For INSERT and UPDATE,
+ the values are the newly inserted/updated values. For DELETE,
+ the values are those of the rows which were deleted.
+
+ Upon execution, the values of the columns to be returned
+ are made available via the result set and can be iterated
+ using ``fetchone()`` and similar. For DBAPIs which do not
+ natively support returning values (i.e. cx_oracle),
+ SQLAlchemy will approximate this behavior at the result level
+ so that a reasonable amount of behavioral neutrality is
+ provided.
+
+ Note that not all databases/DBAPIs
+ support RETURNING. For those backends with no support,
+ an exception is raised upon compilation and/or execution.
+ For those who do support it, the functionality across backends
+ varies greatly, including restrictions on executemany()
+ and other statements which return multiple rows. Please
+ read the documentation notes for the database in use in
+ order to determine the availability of RETURNING.
+
+ """
+ self._returning = cols
+
class _ValuesBase(_UpdateBase):
__visit_name__ = 'values_base'
inline=False,
bind=None,
prefixes=None,
+ returning=None,
**kwargs):
_ValuesBase.__init__(self, table, values)
self._bind = bind
self.select = None
self.inline = inline
+ self._returning = returning
if prefixes:
self._prefixes = [_literal_as_text(p) for p in prefixes]
else:
self._prefixes = []
- self.kwargs = kwargs
+
+ self.kwargs = self._process_deprecated_kw(kwargs)
def get_children(self, **kwargs):
if self.select is not None:
values=None,
inline=False,
bind=None,
+ returning=None,
**kwargs):
_ValuesBase.__init__(self, table, values)
self._bind = bind
+ self._returning = returning
if whereclause:
self._whereclause = _literal_as_text(whereclause)
else:
self._whereclause = None
self.inline = inline
- self.kwargs = kwargs
+
+ self.kwargs = self._process_deprecated_kw(kwargs)
def get_children(self, **kwargs):
if self._whereclause is not None:
__visit_name__ = 'delete'
- def __init__(self, table, whereclause, bind=None, **kwargs):
+ def __init__(self,
+ table,
+ whereclause,
+ bind=None,
+ returning =None,
+ **kwargs):
self._bind = bind
self.table = table
+ self._returning = returning
+
if whereclause:
self._whereclause = _literal_as_text(whereclause)
else:
self._whereclause = None
- self.kwargs = kwargs
+ self.kwargs = self._process_deprecated_kw(kwargs)
def get_children(self, **kwargs):
if self._whereclause is not None:
column('description', String(128)),
)
- u = update(table1, values=dict(name='foo'), firebird_returning=[table1.c.myid, table1.c.name])
+ u = update(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING mytable.myid, mytable.name")
- u = update(table1, values=dict(name='foo'), firebird_returning=[table1])
+ u = update(table1, values=dict(name='foo')).returning(table1)
self.assert_compile(u, "UPDATE mytable SET name=:name "\
"RETURNING mytable.myid, mytable.name, mytable.description")
- u = update(table1, values=dict(name='foo'), firebird_returning=[func.length(table1.c.name)])
+ u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING char_length(mytable.name)")
def test_insert_returning(self):
column('description', String(128)),
)
- i = insert(table1, values=dict(name='foo'), firebird_returning=[table1.c.myid, table1.c.name])
+ i = insert(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING mytable.myid, mytable.name")
- i = insert(table1, values=dict(name='foo'), firebird_returning=[table1])
+ i = insert(table1, values=dict(name='foo')).returning(table1)
self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) "\
"RETURNING mytable.myid, mytable.name, mytable.description")
- i = insert(table1, values=dict(name='foo'), firebird_returning=[func.length(table1.c.name)])
+ i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING char_length(mytable.name)")
-class ReturningTest(TestBase, AssertsExecutionResults):
- __only_on__ = 'firebird'
-
- @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
- def test_update_returning(self):
- meta = MetaData(testing.db)
- table = Table('tables', meta,
- Column('id', Integer, Sequence('gen_tables_id'), primary_key=True),
- Column('persons', Integer),
- Column('full', Boolean)
- )
- table.create()
- try:
- table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
-
- result = table.update(table.c.persons > 4, dict(full=True), firebird_returning=[table.c.id]).execute()
- eq_(result.fetchall(), [(1,)])
-
- result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
- eq_(result2.fetchall(), [(1,True),(2,False)])
- finally:
- table.drop()
-
- @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
- def test_insert_returning(self):
- meta = MetaData(testing.db)
- table = Table('tables', meta,
- Column('id', Integer, Sequence('gen_tables_id'), primary_key=True),
- Column('persons', Integer),
- Column('full', Boolean)
- )
- table.create()
- try:
- result = table.insert(firebird_returning=[table.c.id]).execute({'persons': 1, 'full': False})
-
- eq_(result.fetchall(), [(1,)])
-
- # Multiple inserts only return the last row
- result2 = table.insert(firebird_returning=[table]).execute(
- [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}])
-
- eq_(result2.fetchall(), [(3,3,True)])
-
- result3 = table.insert(firebird_returning=[table.c.id]).execute({'persons': 4, 'full': False})
- eq_([dict(row) for row in result3], [{'id': 4}])
-
- result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, 1) returning persons')
- eq_([dict(row) for row in result4], [{'persons': 10}])
- finally:
- table.drop()
-
- @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
- def test_delete_returning(self):
- meta = MetaData(testing.db)
- table = Table('tables', meta,
- Column('id', Integer, Sequence('gen_tables_id'), primary_key=True),
- Column('persons', Integer),
- Column('full', Boolean)
- )
- table.create()
- try:
- table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
-
- result = table.delete(table.c.persons > 4, firebird_returning=[table.c.id]).execute()
- eq_(result.fetchall(), [(1,)])
-
- result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
- eq_(result2.fetchall(), [(2,False),])
- finally:
- table.drop()
class MiscTest(TestBase):
select([extract(field, t.c.col1)]),
'SELECT DATEPART("%s", t.col1) AS anon_1 FROM t' % field)
+ def test_update_returning(self):
+ table1 = table('mytable',
+ column('myid', Integer),
+ column('name', String(128)),
+ column('description', String(128)),
+ )
+
+ u = update(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
+ self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT inserted.myid, inserted.name")
+
+ u = update(table1, values=dict(name='foo')).returning(table1)
+ self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT inserted.myid, "
+ "inserted.name, inserted.description")
+
+ u = update(table1, values=dict(name='foo')).returning(table1).where(table1.c.name=='bar')
+ self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT inserted.myid, "
+ "inserted.name, inserted.description WHERE mytable.name = :name_1")
+
+ u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
+ self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT LEN(inserted.name)")
+
+ def test_insert_returning(self):
+ table1 = table('mytable',
+ column('myid', Integer),
+ column('name', String(128)),
+ column('description', String(128)),
+ )
+
+ i = insert(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
+ self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT inserted.myid, inserted.name VALUES (:name)")
+
+ i = insert(table1, values=dict(name='foo')).returning(table1)
+ self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT inserted.myid, "
+ "inserted.name, inserted.description VALUES (:name)")
+
+ i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
+ self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT LEN(inserted.name) VALUES (:name)")
+
+
class IdentityInsertTest(TestBase, AssertsCompiledSQL):
__only_on__ = 'mssql'
column('description', String(128)),
)
- u = update(table1, values=dict(name='foo'), postgresql_returning=[table1.c.myid, table1.c.name])
+ u = update(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect)
- u = update(table1, values=dict(name='foo'), postgresql_returning=[table1])
+ u = update(table1, values=dict(name='foo')).returning(table1)
self.assert_compile(u, "UPDATE mytable SET name=%(name)s "\
"RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect)
- u = update(table1, values=dict(name='foo'), postgresql_returning=[func.length(table1.c.name)])
+ u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING length(mytable.name)", dialect=dialect)
column('description', String(128)),
)
- i = insert(table1, values=dict(name='foo'), postgresql_returning=[table1.c.myid, table1.c.name])
+ i = insert(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING mytable.myid, mytable.name", dialect=dialect)
- i = insert(table1, values=dict(name='foo'), postgresql_returning=[table1])
+ i = insert(table1, values=dict(name='foo')).returning(table1)
self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) "\
"RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect)
- i = insert(table1, values=dict(name='foo'), postgresql_returning=[func.length(table1.c.name)])
+ i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name)", dialect=dialect)
- @testing.uses_deprecated(r".*'postgres_returning' argument has been renamed.*")
+ @testing.uses_deprecated(r".*argument is deprecated. Please use statement.returning.*")
def test_old_returning_names(self):
dialect = postgresql.dialect()
table1 = table('mytable',
u = update(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name])
self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect)
+ u = update(table1, values=dict(name='foo'), postgresql_returning=[table1.c.myid, table1.c.name])
+ self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect)
+
i = insert(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name])
self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING mytable.myid, mytable.name", dialect=dialect)
"SELECT EXTRACT(%s FROM t.col1::timestamp) AS anon_1 "
"FROM t" % field)
-class ReturningTest(TestBase, AssertsExecutionResults):
- __only_on__ = 'postgresql'
-
- @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
- def test_update_returning(self):
- meta = MetaData(testing.db)
- table = Table('tables', meta,
- Column('id', Integer, primary_key=True),
- Column('persons', Integer),
- Column('full', Boolean)
- )
- table.create()
- try:
- table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
-
- result = table.update(table.c.persons > 4, dict(full=True), postgresql_returning=[table.c.id]).execute()
- eq_(result.fetchall(), [(1,)])
-
- result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
- eq_(result2.fetchall(), [(1,True),(2,False)])
- finally:
- table.drop()
-
- @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
- def test_insert_returning(self):
- meta = MetaData(testing.db)
- table = Table('tables', meta,
- Column('id', Integer, primary_key=True),
- Column('persons', Integer),
- Column('full', Boolean)
- )
- table.create()
- try:
- result = table.insert(postgresql_returning=[table.c.id]).execute({'persons': 1, 'full': False})
-
- eq_(result.fetchall(), [(1,)])
-
- @testing.fails_on('postgresql', 'Known limitation of psycopg2')
- def test_executemany():
- # return value is documented as failing with psycopg2/executemany
- result2 = table.insert(postgresql_returning=[table]).execute(
- [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}])
- eq_(result2.fetchall(), [(2, 2, False), (3,3,True)])
-
- test_executemany()
-
- result3 = table.insert(postgresql_returning=[(table.c.id*2).label('double_id')]).execute({'persons': 4, 'full': False})
- eq_([dict(row) for row in result3], [{'double_id':8}])
-
- result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, true) returning persons')
- eq_([dict(row) for row in result4], [{'persons': 10}])
- finally:
- table.drop()
-
class InsertTest(TestBase, AssertsExecutionResults):
__only_on__ = 'postgresql'
--- /dev/null
+from sqlalchemy.test.testing import eq_
+from sqlalchemy import *
+from sqlalchemy.test import *
+from sqlalchemy.test.schema import Table, Column
+from sqlalchemy.types import TypeDecorator
+
+class ReturningTest(TestBase, AssertsExecutionResults):
+ __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access')
+
+ def setup(self):
+ meta = MetaData(testing.db)
+ global table, GoofyType
+
+ class GoofyType(TypeDecorator):
+ impl = String
+
+ def process_bind_param(self, value, dialect):
+ if value is None:
+ return None
+ return "FOO" + value
+
+ def process_result_value(self, value, dialect):
+ if value is None:
+ return None
+ return value + "BAR"
+
+ table = Table('tables', meta,
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+ Column('persons', Integer),
+ Column('full', Boolean),
+ Column('goofy', GoofyType(50))
+ )
+ table.create()
+
+ def teardown(self):
+ table.drop()
+
+ @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
+ @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+ def test_column_targeting(self):
+ result = table.insert().returning(table.c.id, table.c.full).execute({'persons': 1, 'full': False})
+
+ row = result.first()
+ assert row[table.c.id] == row['id'] == 1
+ assert row[table.c.full] == row['full'] == False
+
+ result = table.insert().values(persons=5, full=True, goofy="somegoofy").\
+ returning(table.c.persons, table.c.full, table.c.goofy).execute()
+ row = result.first()
+ assert row[table.c.persons] == row['persons'] == 5
+ assert row[table.c.full] == row['full'] == True
+ assert row[table.c.goofy] == row['goofy'] == "FOOsomegoofyBAR"
+
+ @testing.fails_on('firebird', "fb can't handle returning x AS y")
+ @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
+ @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+ def test_labeling(self):
+ result = table.insert().values(persons=6).\
+ returning(table.c.persons.label('lala')).execute()
+ row = result.first()
+ assert row['lala'] == 6
+
+ @testing.fails_on('firebird', "fb/kintersbasdb can't handle the bind params")
+ @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
+ @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+ def test_anon_expressions(self):
+ result = table.insert().values(goofy="someOTHERgoofy").\
+ returning(func.lower(table.c.goofy, type_=GoofyType)).execute()
+ row = result.first()
+ assert row[0] == "foosomeothergoofyBAR"
+
+ result = table.insert().values(persons=12).\
+ returning(table.c.persons + 18).execute()
+ row = result.first()
+ assert row[0] == 30
+
+ @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
+ @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+ def test_update_returning(self):
+ table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
+
+ result = table.update(table.c.persons > 4, dict(full=True)).returning(table.c.id).execute()
+ eq_(result.fetchall(), [(1,)])
+
+ result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
+ eq_(result2.fetchall(), [(1,True),(2,False)])
+
+ @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
+ @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+ def test_insert_returning(self):
+ result = table.insert().returning(table.c.id).execute({'persons': 1, 'full': False})
+
+ eq_(result.fetchall(), [(1,)])
+
+ @testing.fails_on('postgresql', '')
+ @testing.fails_on('oracle', '')
+ def test_executemany():
+ # return value is documented as failing with psycopg2/executemany
+ result2 = table.insert().returning(table).execute(
+ [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}])
+
+ if testing.against('firebird', 'mssql'):
+ # Multiple inserts only return the last row
+ eq_(result2.fetchall(), [(3,3,True, None)])
+ else:
+ # nobody does this as far as we know (pg8000?)
+ eq_(result2.fetchall(), [(2, 2, False, None), (3,3,True, None)])
+
+ test_executemany()
+
+ result3 = table.insert().returning(table.c.id).execute({'persons': 4, 'full': False})
+ eq_([dict(row) for row in result3], [{'id': 4}])
+
+ @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
+ @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+ @testing.fails_on_everything_except('postgresql', 'firebird')
+ def test_literal_returning(self):
+ if testing.against("postgresql"):
+ literal_true = "true"
+ else:
+ literal_true = "1"
+
+ result4 = testing.db.execute('insert into tables (id, persons, "full") '
+ 'values (5, 10, %s) returning persons' % literal_true)
+ eq_([dict(row) for row in result4], [{'persons': 10}])
+
+ @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
+ @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+ def test_delete_returning(self):
+ table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
+
+ result = table.delete(table.c.persons > 4).returning(table.c.id).execute()
+ eq_(result.fetchall(), [(1,)])
+
+ result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
+ eq_(result2.fetchall(), [(2,False),])