]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- some cleanups in compiler.py
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 18 Aug 2015 21:25:05 +0000 (17:25 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 18 Aug 2015 21:25:05 +0000 (17:25 -0400)
lib/sqlalchemy/sql/compiler.py

index a036dcc4238db4693a812cd599fe7670bd50dcb9..d3c46e6437db706d4777d59adb5cb44900ac7fac 100644 (file)
@@ -1809,6 +1809,22 @@ class SQLCompiler(Compiled):
             join.onclause._compiler_dispatch(self, **kwargs)
         )
 
+    def _setup_crud_hints(self, stmt, table_text):
+        dialect_hints = dict([
+            (table, hint_text)
+            for (table, dialect), hint_text in
+            stmt._hints.items()
+            if dialect in ('*', self.dialect.name)
+        ])
+        if stmt.table in dialect_hints:
+            table_text = self.format_from_hint_text(
+                table_text,
+                stmt.table,
+                dialect_hints[stmt.table],
+                True
+            )
+        return dialect_hints, table_text
+
     def visit_insert(self, insert_stmt, **kw):
         self.stack.append(
             {'correlate_froms': set(),
@@ -1850,19 +1866,10 @@ class SQLCompiler(Compiled):
         table_text = preparer.format_table(insert_stmt.table)
 
         if insert_stmt._hints:
-            dialect_hints = dict([
-                (table, hint_text)
-                for (table, dialect), hint_text in
-                insert_stmt._hints.items()
-                if dialect in ('*', self.dialect.name)
-            ])
-            if insert_stmt.table in dialect_hints:
-                table_text = self.format_from_hint_text(
-                    table_text,
-                    insert_stmt.table,
-                    dialect_hints[insert_stmt.table],
-                    True
-                )
+            dialect_hints, table_text = self._setup_crud_hints(
+                insert_stmt, table_text)
+        else:
+            dialect_hints = None
 
         text += table_text
 
@@ -1954,19 +1961,8 @@ class SQLCompiler(Compiled):
         crud_params = crud._get_crud_params(self, update_stmt, **kw)
 
         if update_stmt._hints:
-            dialect_hints = dict([
-                (table, hint_text)
-                for (table, dialect), hint_text in
-                update_stmt._hints.items()
-                if dialect in ('*', self.dialect.name)
-            ])
-            if update_stmt.table in dialect_hints:
-                table_text = self.format_from_hint_text(
-                    table_text,
-                    update_stmt.table,
-                    dialect_hints[update_stmt.table],
-                    True
-                )
+            dialect_hints, table_text = self._setup_crud_hints(
+                update_stmt, table_text)
         else:
             dialect_hints = None
 
@@ -2035,22 +2031,8 @@ class SQLCompiler(Compiled):
             self, asfrom=True, iscrud=True)
 
         if delete_stmt._hints:
-            dialect_hints = dict([
-                (table, hint_text)
-                for (table, dialect), hint_text in
-                delete_stmt._hints.items()
-                if dialect in ('*', self.dialect.name)
-            ])
-            if delete_stmt.table in dialect_hints:
-                table_text = self.format_from_hint_text(
-                    table_text,
-                    delete_stmt.table,
-                    dialect_hints[delete_stmt.table],
-                    True
-                )
-
-        else:
-            dialect_hints = None
+            dialect_hints, table_text = self._setup_crud_hints(
+                delete_stmt, table_text)
 
         text += table_text
 
@@ -2136,11 +2118,11 @@ class DDLCompiler(Compiled):
         table = create.element
         preparer = self.dialect.identifier_preparer
 
-        text = "\n" + " ".join(['CREATE'] +
-                               table._prefixes +
-                               ['TABLE',
-                                preparer.format_table(table),
-                                "("])
+        text = "\nCREATE "
+        if table._prefixes:
+            text += " ".join(table._prefixes) + " "
+        text += "TABLE " + preparer.format_table(table) + " ("
+
         separator = "\n"
 
         # if only one primary key, specify it along with the column
@@ -2165,8 +2147,8 @@ class DDLCompiler(Compiled):
                     ))
 
         const = self.create_table_constraints(
-            table, _include_foreign_key_constraints=
-            create.include_foreign_key_constraints)
+            table, _include_foreign_key_constraints=  # noqa
+                create.include_foreign_key_constraints)
         if const:
             text += ", \n\t" + const
 
@@ -2220,7 +2202,7 @@ class DDLCompiler(Compiled):
                 and (
                     not self.dialect.supports_alter or
                     not getattr(constraint, 'use_alter', False)
-                )) if p is not None
+            )) if p is not None
         )
 
     def visit_drop_table(self, drop):