mssql gets extra label stuff to deal with column adaption (not sure if column adaption should
blow away labels like that...). fixes potential column targeting issues on all platforms
+ fixes mssql failures
def returning_clause(self, stmt):
returning_cols = stmt._returning
- def flatten_columnlist(collist):
- for c in collist:
- if isinstance(c, expression.Selectable):
- for co in c.columns:
- yield co
- else:
- yield c
-
columns = [
- self.process(c, within_columns_clause=True, result_map=self.result_map)
- for c in flatten_columnlist(returning_cols)
+ self.process(
+ self.label_select_column(None, c, asfrom=False),
+ within_columns_clause=True,
+ result_map=self.result_map
+ )
+ for c in expression._select_iterables(returning_cols)
]
return 'RETURNING ' + ', '.join(columns)
"""
import datetime, decimal, inspect, operator, sys, re
+import itertools
from sqlalchemy import sql, schema as sa_schema, exc, util
from sqlalchemy.sql import select, compiler, expression, \
def returning_clause(self, stmt):
returning_cols = stmt._returning
- def flatten_columnlist(collist):
- for c in collist:
- if isinstance(c, expression.Selectable):
- for co in c.columns:
- yield co
- else:
- yield c
-
if self.isinsert or self.isupdate:
target = stmt.table.alias("inserted")
else:
target = stmt.table.alias("deleted")
adapter = sql_util.ClauseAdapter(target)
+ def col_label(col):
+ adapted = adapter.traverse(c)
+ if isinstance(c, expression._Label):
+ return adapted.label(c.key)
+ else:
+ return self.label_select_column(None, adapted, asfrom=False)
+
columns = [
- self.process(adapter.traverse(c), within_columns_clause=True, result_map=self.result_map)
- for c in flatten_columnlist(returning_cols)
+ self.process(
+ col_label(c),
+ within_columns_clause=True,
+ result_map=self.result_map
+ )
+ for c in expression._select_iterables(returning_cols)
]
-
return 'OUTPUT ' + ', '.join(columns)
def label_select_column(self, select, column, asfrom):
def returning_clause(self, stmt):
returning_cols = stmt._returning
- def flatten_columnlist(collist):
- for c in collist:
- if isinstance(c, expression.Selectable):
- for co in c.columns:
- yield co
- else:
- yield c
-
def create_out_param(col, i):
bindparam = sql.outparam("ret_%d" % i, type_=col.type)
self.binds[bindparam.key] = bindparam
return self.bindparam_string(self._truncate_bindparam(bindparam))
+ columnlist = list(expression._select_iterables(returning_cols))
+
# within_columns_clause =False so that labels (foo AS bar) don't render
- columns = [self.process(c, within_columns_clause=False) for c in flatten_columnlist(returning_cols)]
+ columns = [self.process(c, within_columns_clause=False) for c in columnlist]
- binds = [create_out_param(c, i) for i, c in enumerate(flatten_columnlist(returning_cols))]
+ binds = [create_out_param(c, i) for i, c in enumerate(columnlist)]
return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds)
def returning_clause(self, stmt):
returning_cols = stmt._returning
- def flatten_columnlist(collist):
- for c in collist:
- if isinstance(c, expression.Selectable):
- for co in c.columns:
- yield co
- else:
- yield c
-
columns = [
- self.process(c, within_columns_clause=True, result_map=self.result_map)
- for c in flatten_columnlist(returning_cols)
+ self.process(
+ self.label_select_column(None, c, asfrom=False),
+ within_columns_clause=True,
+ result_map=self.result_map)
+ for c in expression._select_iterables(returning_cols)
]
return 'RETURNING ' + ', '.join(columns)
def visit_sequence(self, seq):
if not seq.optional:
- return self.execute_string(("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)))
+ return self.execute_string(("select nextval('%s')" % \
+ self.dialect.identifier_preparer.format_sequence(seq)))
else:
return None
self.rowcount
self.close() # autoclose
return
-
+
self._props = util.populate_column_dict(None)
self._props.creator = self.__key_fallback()
self.keys = []
if isinstance(column, sql._Label):
return column
- if select.use_labels and column._label:
+ if select and select.use_labels and column._label:
return _CompileLabel(column, column._label)
if \
"""
return itertools.chain(*[x._cloned_set for x in elements])
+def _select_iterables(elements):
+ """expand tables into individual columns in the
+ given list of column expressions.
+
+ """
+ return itertools.chain(*[c._select_iterable for c in elements])
+
def _cloned_intersection(a, b):
"""return the intersection of sets a and b, counting
any overlap between 'cloned' predecessors.
be rendered into the columns clause of the resulting SELECT statement.
"""
- return itertools.chain(*[c._select_iterable for c in self._raw_columns])
+ return _select_iterables(self._raw_columns)
def is_derived_from(self, fromclause):
if self in fromclause._cloned_set:
"RETURNING mytable.myid, mytable.name, mytable.description")
u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
- self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING char_length(mytable.name)")
+ self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING char_length(mytable.name) AS length_1")
def test_insert_returning(self):
table1 = table('mytable',
"RETURNING mytable.myid, mytable.name, mytable.description")
i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
- self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING char_length(mytable.name)")
+ self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING char_length(mytable.name) AS length_1")
"inserted.name, inserted.description WHERE mytable.name = :name_1")
u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
- self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT LEN(inserted.name)")
+ self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT LEN(inserted.name) AS length_1")
def test_insert_returning(self):
table1 = table('mytable',
"inserted.name, inserted.description VALUES (:name)")
i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
- self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT LEN(inserted.name) VALUES (:name)")
+ self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT LEN(inserted.name) AS length_1 VALUES (:name)")
"RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect)
u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
- self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING length(mytable.name)", dialect=dialect)
+ self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING length(mytable.name) AS length_1", dialect=dialect)
def test_insert_returning(self):
"RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect)
i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
- self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name)", dialect=dialect)
+ self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name) AS length_1", dialect=dialect)
@testing.uses_deprecated(r".*argument is deprecated. Please use statement.returning.*")
def test_old_returning_names(self):