From: Idan Kamara Date: Wed, 5 Dec 2012 21:45:49 +0000 (+0200) Subject: compiler: adjust _get_colparams to return the columns and parameters in separate... X-Git-Tag: rel_0_8_0b2~19 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=51839352a4a9d4b87bdca6c148ec0fd847b8630b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git compiler: adjust _get_colparams to return the columns and parameters in separate lists --- diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 102b44a7e3..6f7f1dadd2 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1275,9 +1275,9 @@ class SQLCompiler(engine.Compiled): def visit_insert(self, insert_stmt, **kw): self.isinsert = True - colparams = self._get_colparams(insert_stmt) + cols, params = self._get_colparams(insert_stmt) - if not colparams and \ + if not cols and \ not self.dialect.supports_default_values and \ not self.dialect.supports_empty_insert: raise exc.CompileError("The version of %s you are using does " @@ -1313,9 +1313,9 @@ class SQLCompiler(engine.Compiled): text += table_text - if colparams or not supports_default_values: - text += " (%s)" % ', '.join([preparer.format_column(c[0]) - for c in colparams]) + if cols or not supports_default_values: + text += " (%s)" % ', '.join([preparer.format_column(c) + for c in cols]) if self.returning or insert_stmt._returning: self.returning = self.returning or insert_stmt._returning @@ -1325,11 +1325,11 @@ class SQLCompiler(engine.Compiled): if self.returning_precedes_values: text += " " + returning_clause - if not colparams and supports_default_values: + if not cols and supports_default_values: text += " DEFAULT VALUES" else: text += " VALUES (%s)" % \ - ', '.join([c[1] for c in colparams]) + ', '.join(params[0]) if self.returning and not self.returning_precedes_values: text += " " + returning_clause @@ -1373,7 +1373,7 @@ class SQLCompiler(engine.Compiled): extra_froms = update_stmt._extra_froms - colparams = self._get_colparams(update_stmt, extra_froms) + cols, params = self._get_colparams(update_stmt, extra_froms) text = "UPDATE " @@ -1406,10 +1406,13 @@ class SQLCompiler(engine.Compiled): text += ' SET ' include_table = extra_froms and \ self.render_table_with_column_in_update_from + colparams = [] + if params: + colparams = zip(cols, params[0]) text += ', '.join( - c[0]._compiler_dispatch(self, + c._compiler_dispatch(self, include_table=include_table) + - '=' + c[1] for c in colparams + '=' + p for c, p in colparams ) if update_stmt._returning: @@ -1467,11 +1470,9 @@ class SQLCompiler(engine.Compiled): # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if self.column_keys is None and stmt.parameters is None: - return [ - (c, self._create_crud_bind_param(c, - None, required=True)) - for c in stmt.table.columns - ] + values = [self._create_crud_bind_param(c, None, required=True) + for c in stmt.table.columns] + return list(stmt.table.columns), [values] required = object() @@ -1486,6 +1487,7 @@ class SQLCompiler(engine.Compiled): key not in stmt.parameters) # create a list of column assignment clauses as tuples + columns = [] values = [] if stmt.parameters is not None: @@ -1502,7 +1504,8 @@ class SQLCompiler(engine.Compiled): else: v = self.process(v.self_group()) - values.append((k, v)) + columns.append(k) + values.append(v) need_pks = self.isinsert and \ not self.inline and \ @@ -1536,7 +1539,8 @@ class SQLCompiler(engine.Compiled): else: self.postfetch.append(c) value = self.process(value.self_group()) - values.append((c, value)) + columns.append(c) + values.append(value) # determine tables which are actually # to be updated - process onupdate and # server_onupdate for these @@ -1546,14 +1550,12 @@ class SQLCompiler(engine.Compiled): continue elif c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: - values.append( - (c, self.process(c.onupdate.arg.self_group())) - ) + columns.apppend(c) + values.append(self.process(c.onupdate.arg.self_group())) self.postfetch.append(c) else: - values.append( - (c, self._create_crud_bind_param(c, None)) - ) + columns.append(c) + values.append(self._create_crud_bind_param(c, None)) self.prefetch.append(c) elif c.server_onupdate is not None: self.postfetch.append(c) @@ -1573,7 +1575,8 @@ class SQLCompiler(engine.Compiled): else: self.postfetch.append(c) value = self.process(value.self_group()) - values.append((c, value)) + columns.append(c) + values.append(value) elif self.isinsert: if c.primary_key and \ @@ -1591,18 +1594,16 @@ class SQLCompiler(engine.Compiled): (not c.default.optional or \ not self.dialect.sequences_optional): proc = self.process(c.default) - values.append((c, proc)) + columns.append(c) + values.append(proc) self.returning.append(c) elif c.default.is_clause_element: - values.append( - (c, - self.process(c.default.arg.self_group())) - ) + columns.append(c) + values.append(self.process(c.default.arg.self_group())) self.returning.append(c) else: - values.append( - (c, self._create_crud_bind_param(c, None)) - ) + columns.append(c) + values.append(self._create_crud_bind_param(c, None)) self.prefetch.append(c) else: self.returning.append(c) @@ -1613,10 +1614,8 @@ class SQLCompiler(engine.Compiled): self.dialect.preexecute_autoincrement_sequences ): - values.append( - (c, self._create_crud_bind_param(c, None)) - ) - + columns.append(c) + values.append(self._create_crud_bind_param(c, None)) self.prefetch.append(c) elif c.default is not None: @@ -1625,21 +1624,20 @@ class SQLCompiler(engine.Compiled): (not c.default.optional or \ not self.dialect.sequences_optional): proc = self.process(c.default) - values.append((c, proc)) + columns.append(c) + values.append(proc) if not c.primary_key: self.postfetch.append(c) elif c.default.is_clause_element: - values.append( - (c, self.process(c.default.arg.self_group())) - ) + columns.append(c) + values.append(self.process(c.default.arg.self_group())) if not c.primary_key: # dont add primary key column to postfetch self.postfetch.append(c) else: - values.append( - (c, self._create_crud_bind_param(c, None)) - ) + columns.append(c) + values.append(self._create_crud_bind_param(c, None)) self.prefetch.append(c) elif c.server_default is not None: if not c.primary_key: @@ -1648,14 +1646,12 @@ class SQLCompiler(engine.Compiled): elif self.isupdate: if c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: - values.append( - (c, self.process(c.onupdate.arg.self_group())) - ) + columns.append(c) + values.append(self.process(c.onupdate.arg.self_group())) self.postfetch.append(c) else: - values.append( - (c, self._create_crud_bind_param(c, None)) - ) + columns.append(c) + values.append(self._create_crud_bind_param(c, None)) self.prefetch.append(c) elif c.server_onupdate is not None: self.postfetch.append(c) @@ -1670,7 +1666,10 @@ class SQLCompiler(engine.Compiled): (", ".join(check)) ) - return values + if values: + values = [values] + + return columns, values def visit_delete(self, delete_stmt, **kw): self.stack.append({'from': set([delete_stmt.table])})