]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
compiler: adjust _get_colparams to return the columns and parameters in separate...
authorIdan Kamara <idankk86@gmail.com>
Wed, 5 Dec 2012 21:45:49 +0000 (23:45 +0200)
committerIdan Kamara <idankk86@gmail.com>
Wed, 5 Dec 2012 21:45:49 +0000 (23:45 +0200)
lib/sqlalchemy/sql/compiler.py

index 102b44a7e3c3bf0a3e311129d0a94a4983a4856a..6f7f1dadd21f2e0d12af5bc78fa6df7bb2fcc5c7 100644 (file)
@@ -1275,9 +1275,9 @@ class SQLCompiler(engine.Compiled):
 
     def visit_insert(self, insert_stmt, **kw):
         self.isinsert = True
-        colparams = self._get_colparams(insert_stmt)
+        cols, params = self._get_colparams(insert_stmt)
 
-        if not colparams and \
+        if not cols and \
                 not self.dialect.supports_default_values and \
                 not self.dialect.supports_empty_insert:
             raise exc.CompileError("The version of %s you are using does "
@@ -1313,9 +1313,9 @@ class SQLCompiler(engine.Compiled):
 
         text += table_text
 
-        if colparams or not supports_default_values:
-            text += " (%s)" % ', '.join([preparer.format_column(c[0])
-                       for c in colparams])
+        if cols or not supports_default_values:
+            text += " (%s)" % ', '.join([preparer.format_column(c)
+                       for c in cols])
 
         if self.returning or insert_stmt._returning:
             self.returning = self.returning or insert_stmt._returning
@@ -1325,11 +1325,11 @@ class SQLCompiler(engine.Compiled):
             if self.returning_precedes_values:
                 text += " " + returning_clause
 
-        if not colparams and supports_default_values:
+        if not cols and supports_default_values:
             text += " DEFAULT VALUES"
         else:
             text += " VALUES (%s)" % \
-                     ', '.join([c[1] for c in colparams])
+                     ', '.join(params[0])
 
         if self.returning and not self.returning_precedes_values:
             text += " " + returning_clause
@@ -1373,7 +1373,7 @@ class SQLCompiler(engine.Compiled):
 
         extra_froms = update_stmt._extra_froms
 
-        colparams = self._get_colparams(update_stmt, extra_froms)
+        cols, params = self._get_colparams(update_stmt, extra_froms)
 
         text = "UPDATE "
 
@@ -1406,10 +1406,13 @@ class SQLCompiler(engine.Compiled):
         text += ' SET '
         include_table = extra_froms and \
                         self.render_table_with_column_in_update_from
+        colparams = []
+        if params:
+            colparams = zip(cols, params[0])
         text += ', '.join(
-                        c[0]._compiler_dispatch(self,
+                        c._compiler_dispatch(self,
                             include_table=include_table) +
-                        '=' + c[1] for c in colparams
+                        '=' + p for c, p in colparams
                         )
 
         if update_stmt._returning:
@@ -1467,11 +1470,9 @@ class SQLCompiler(engine.Compiled):
         # no parameters in the statement, no parameters in the
         # compiled params - return binds for all columns
         if self.column_keys is None and stmt.parameters is None:
-            return [
-                        (c, self._create_crud_bind_param(c,
-                                    None, required=True))
-                        for c in stmt.table.columns
-                    ]
+            values = [self._create_crud_bind_param(c, None, required=True)
+                      for c in stmt.table.columns]
+            return list(stmt.table.columns), [values]
 
         required = object()
 
@@ -1486,6 +1487,7 @@ class SQLCompiler(engine.Compiled):
                               key not in stmt.parameters)
 
         # create a list of column assignment clauses as tuples
+        columns = []
         values = []
 
         if stmt.parameters is not None:
@@ -1502,7 +1504,8 @@ class SQLCompiler(engine.Compiled):
                     else:
                         v = self.process(v.self_group())
 
-                    values.append((k, v))
+                    columns.append(k)
+                    values.append(v)
 
         need_pks = self.isinsert and \
                         not self.inline and \
@@ -1536,7 +1539,8 @@ class SQLCompiler(engine.Compiled):
                         else:
                             self.postfetch.append(c)
                             value = self.process(value.self_group())
-                        values.append((c, value))
+                        columns.append(c)
+                        values.append(value)
             # determine tables which are actually
             # to be updated - process onupdate and
             # server_onupdate for these
@@ -1546,14 +1550,12 @@ class SQLCompiler(engine.Compiled):
                         continue
                     elif c.onupdate is not None and not c.onupdate.is_sequence:
                         if c.onupdate.is_clause_element:
-                            values.append(
-                                (c, self.process(c.onupdate.arg.self_group()))
-                            )
+                            columns.apppend(c)
+                            values.append(self.process(c.onupdate.arg.self_group()))
                             self.postfetch.append(c)
                         else:
-                            values.append(
-                                (c, self._create_crud_bind_param(c, None))
-                            )
+                            columns.append(c)
+                            values.append(self._create_crud_bind_param(c, None))
                             self.prefetch.append(c)
                     elif c.server_onupdate is not None:
                         self.postfetch.append(c)
@@ -1573,7 +1575,8 @@ class SQLCompiler(engine.Compiled):
                 else:
                     self.postfetch.append(c)
                     value = self.process(value.self_group())
-                values.append((c, value))
+                columns.append(c)
+                values.append(value)
 
             elif self.isinsert:
                 if c.primary_key and \
@@ -1591,18 +1594,16 @@ class SQLCompiler(engine.Compiled):
                                     (not c.default.optional or \
                                     not self.dialect.sequences_optional):
                                     proc = self.process(c.default)
-                                    values.append((c, proc))
+                                    columns.append(c)
+                                    values.append(proc)
                                 self.returning.append(c)
                             elif c.default.is_clause_element:
-                                values.append(
-                                    (c,
-                                    self.process(c.default.arg.self_group()))
-                                )
+                                columns.append(c)
+                                values.append(self.process(c.default.arg.self_group()))
                                 self.returning.append(c)
                             else:
-                                values.append(
-                                    (c, self._create_crud_bind_param(c, None))
-                                )
+                                columns.append(c)
+                                values.append(self._create_crud_bind_param(c, None))
                                 self.prefetch.append(c)
                         else:
                             self.returning.append(c)
@@ -1613,10 +1614,8 @@ class SQLCompiler(engine.Compiled):
                                 self.dialect.preexecute_autoincrement_sequences
                             ):
 
-                            values.append(
-                                (c, self._create_crud_bind_param(c, None))
-                            )
-
+                            columns.append(c)
+                            values.append(self._create_crud_bind_param(c, None))
                             self.prefetch.append(c)
 
                 elif c.default is not None:
@@ -1625,21 +1624,20 @@ class SQLCompiler(engine.Compiled):
                             (not c.default.optional or \
                             not self.dialect.sequences_optional):
                             proc = self.process(c.default)
-                            values.append((c, proc))
+                            columns.append(c)
+                            values.append(proc)
                             if not c.primary_key:
                                 self.postfetch.append(c)
                     elif c.default.is_clause_element:
-                        values.append(
-                            (c, self.process(c.default.arg.self_group()))
-                        )
+                        columns.append(c)
+                        values.append(self.process(c.default.arg.self_group()))
 
                         if not c.primary_key:
                             # dont add primary key column to postfetch
                             self.postfetch.append(c)
                     else:
-                        values.append(
-                            (c, self._create_crud_bind_param(c, None))
-                        )
+                        columns.append(c)
+                        values.append(self._create_crud_bind_param(c, None))
                         self.prefetch.append(c)
                 elif c.server_default is not None:
                     if not c.primary_key:
@@ -1648,14 +1646,12 @@ class SQLCompiler(engine.Compiled):
             elif self.isupdate:
                 if c.onupdate is not None and not c.onupdate.is_sequence:
                     if c.onupdate.is_clause_element:
-                        values.append(
-                            (c, self.process(c.onupdate.arg.self_group()))
-                        )
+                        columns.append(c)
+                        values.append(self.process(c.onupdate.arg.self_group()))
                         self.postfetch.append(c)
                     else:
-                        values.append(
-                            (c, self._create_crud_bind_param(c, None))
-                        )
+                        columns.append(c)
+                        values.append(self._create_crud_bind_param(c, None))
                         self.prefetch.append(c)
                 elif c.server_onupdate is not None:
                     self.postfetch.append(c)
@@ -1670,7 +1666,10 @@ class SQLCompiler(engine.Compiled):
                     (", ".join(check))
                 )
 
-        return values
+        if values:
+            values = [values]
+
+        return columns, values
 
     def visit_delete(self, delete_stmt, **kw):
         self.stack.append({'from': set([delete_stmt.table])})