]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- apply pep8 to compiler.py
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 21 Dec 2010 21:34:00 +0000 (16:34 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 21 Dec 2010 21:34:00 +0000 (16:34 -0500)
- deprecate Compiled.compile() - have __init__ do compilation
if statement is present.

lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/dialect/test_mysql.py
test/engine/test_ddlevents.py

index df8bfd4bdea9248f05fd27af1b747ffef325fd56..08c747f38e7079df63f43fd7c0a55271bda1a04d 100644 (file)
@@ -677,8 +677,8 @@ class MSSQLCompiler(compiler.SQLCompiler):
     })
 
     def __init__(self, *args, **kwargs):
-        super(MSSQLCompiler, self).__init__(*args, **kwargs)
         self.tablealiases = {}
+        super(MSSQLCompiler, self).__init__(*args, **kwargs)
 
     def visit_now_func(self, fn, **kw):
         return "CURRENT_TIMESTAMP"
index 4e11117f7b11c8d941381ca7d7ae214e36bd50cd..1e5285b3552a939e52a0765608a621782b8e5e6a 100644 (file)
@@ -675,14 +675,17 @@ class Compiled(object):
         """
 
         self.dialect = dialect
-        self.statement = statement
         self.bind = bind
-        self.can_execute = statement.supports_execution
+        if statement is not None:
+            self.statement = statement
+            self.can_execute = statement.supports_execution
+            self.string = self.process(self.statement)
 
+    @util.deprecated("0.7", ":class:`.Compiled` objects now compile "
+                        "within the constructor.")
     def compile(self):
         """Produce the internal string representation of this element."""
-
-        self.string = self.process(self.statement)
+        pass
 
     @property
     def sql_compiler(self):
index 8474ebaccb2146e73241627755e3f88853140755..3b43386d5445f0f18c18e10121fc21db46b40a5f 100644 (file)
@@ -24,7 +24,8 @@ To generate user-defined SQL strings, see
 
 import re
 from sqlalchemy import schema, engine, util, exc
-from sqlalchemy.sql import operators, functions, util as sql_util, visitors
+from sqlalchemy.sql import operators, functions, util as sql_util, \
+    visitors
 from sqlalchemy.sql import expression as sql
 import decimal
 
@@ -197,7 +198,8 @@ class SQLCompiler(engine.Compiled):
     # driver/DB enforces this
     ansi_bind_rules = False
     
-    def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
+    def __init__(self, dialect, statement, column_keys=None, 
+                    inline=False, **kwargs):
         """Construct a new ``DefaultCompiler`` object.
 
         dialect
@@ -211,47 +213,49 @@ class SQLCompiler(engine.Compiled):
           statement.
 
         """
-        engine.Compiled.__init__(self, dialect, statement, **kwargs)
-
         self.column_keys = column_keys
 
-        # compile INSERT/UPDATE defaults/sequences inlined (no pre-execute)
+        # compile INSERT/UPDATE defaults/sequences inlined (no pre-
+        # execute)
         self.inline = inline or getattr(statement, 'inline', False)
 
-        # a dictionary of bind parameter keys to _BindParamClause instances.
+        # a dictionary of bind parameter keys to _BindParamClause
+        # instances.
         self.binds = {}
 
-        # a dictionary of _BindParamClause instances to "compiled" names that are
-        # actually present in the generated SQL
+        # a dictionary of _BindParamClause instances to "compiled" names
+        # that are actually present in the generated SQL
         self.bind_names = util.column_dict()
 
         # stack which keeps track of nested SELECT statements
         self.stack = []
 
-        # relates label names in the final SQL to
-        # a tuple of local column/label name, ColumnElement object (if any) and TypeEngine.
-        # ResultProxy uses this for type processing and column targeting
+        # relates label names in the final SQL to a tuple of local
+        # column/label name, ColumnElement object (if any) and
+        # TypeEngine. ResultProxy uses this for type processing and
+        # column targeting
         self.result_map = {}
 
         # true if the paramstyle is positional
-        self.positional = self.dialect.positional
+        self.positional = dialect.positional
         if self.positional:
             self.positiontup = []
-
-        self.bindtemplate = BIND_TEMPLATES[self.dialect.paramstyle]
+        self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
 
         # an IdentifierPreparer that formats the quoting of identifiers
-        self.preparer = self.dialect.identifier_preparer
+        self.preparer = dialect.identifier_preparer
+        self.label_length = dialect.label_length \
+            or dialect.max_identifier_length
 
-        self.label_length = self.dialect.label_length or self.dialect.max_identifier_length
-        
-        # a map which tracks "anonymous" identifiers that are
-        # created on the fly here
+        # a map which tracks "anonymous" identifiers that are created on
+        # the fly here
         self.anon_map = util.PopulateDict(self._process_anon)
 
-        # a map which tracks "truncated" names based on dialect.label_length
-        # or dialect.max_identifier_length
+        # a map which tracks "truncated" names based on
+        # dialect.label_length or dialect.max_identifier_length
         self.truncated_names = {}
+        engine.Compiled.__init__(self, dialect, statement, **kwargs)
+
         
 
     @util.memoized_property
@@ -284,13 +288,13 @@ class SQLCompiler(engine.Compiled):
                 elif bindparam.required:
                     if _group_number:
                         raise exc.InvalidRequestError(
-                                        "A value is required for bind parameter %r, "
-                                        "in parameter group %d" % 
-                                        (bindparam.key, _group_number))
+                                "A value is required for bind parameter %r, "
+                                "in parameter group %d" % 
+                                (bindparam.key, _group_number))
                     else:
                         raise exc.InvalidRequestError(
-                                        "A value is required for bind parameter %r" 
-                                        % bindparam.key)
+                                "A value is required for bind parameter %r" 
+                                % bindparam.key)
                 elif bindparam.callable:
                     pd[name] = bindparam.callable()
                 else:
@@ -311,7 +315,8 @@ class SQLCompiler(engine.Compiled):
     """)
 
     def default_from(self):
-        """Called when a SELECT statement has no froms, and no FROM clause is to be appended.
+        """Called when a SELECT statement has no froms, and no FROM clause is
+        to be appended.
 
         Gives Oracle a chance to tack on a ``FROM DUAL`` to the string output.
 
@@ -328,12 +333,15 @@ class SQLCompiler(engine.Compiled):
         # or ORDER BY clause of a select.  dialect-specific compilers
         # can modify this behavior.
         if within_columns_clause and not within_label_clause:
-            labelname = isinstance(label.name, sql._generated_label) and \
-                    self._truncated_identifier("colident", label.name) or label.name
+            if isinstance(label.name, sql._generated_label):
+                labelname = self._truncated_identifier("colident", label.name)
+            else:
+                labelname = label.name
 
             if result_map is not None:
                 result_map[labelname.lower()] = \
-                        (label.name, (label, label.element, labelname), label.type)
+                        (label.name, (label, label.element, labelname),\
+                        label.type)
 
             return self.process(label.element, 
                                     within_columns_clause=True,
@@ -373,11 +381,12 @@ class SQLCompiler(engine.Compiled):
             else:
                 schema_prefix = ''
             tablename = column.table.name
-            tablename = isinstance(tablename, sql._generated_label) and \
-                            self._truncated_identifier("alias", tablename) or tablename
+            if isinstance(tablename, sql._generated_label):
+                tablename = self._truncated_identifier("alias", tablename)
             
             return schema_prefix + \
-                    self.preparer.quote(tablename, column.table.quote) + "." + name
+                    self.preparer.quote(tablename, column.table.quote) + \
+                    "." + name
 
     def escape_literal_column(self, text):
         """provide escaping for the literal_column() construct."""
@@ -411,7 +420,8 @@ class SQLCompiler(engine.Compiled):
 
         # un-escape any \:params
         return BIND_PARAMS_ESC.sub(lambda m: m.group(1),
-            BIND_PARAMS.sub(do_bindparam, self.post_process_text(textclause.text))
+            BIND_PARAMS.sub(do_bindparam,
+             self.post_process_text(textclause.text))
         )
 
     def visit_null(self, null, **kwargs):
@@ -423,8 +433,11 @@ class SQLCompiler(engine.Compiled):
             sep = " "
         else:
             sep = OPERATORS[clauselist.operator]
-        return sep.join(s for s in (self.process(c, **kwargs) for c in clauselist.clauses)
-                        if s is not None)
+        return sep.join(
+                    s for s in 
+                    (self.process(c, **kwargs) 
+                    for c in clauselist.clauses)
+                    if s is not None)
 
     def visit_case(self, clause, **kwargs):
         x = "CASE "
@@ -440,11 +453,13 @@ class SQLCompiler(engine.Compiled):
 
     def visit_cast(self, cast, **kwargs):
         return "CAST(%s AS %s)" % \
-                    (self.process(cast.clause, **kwargs), self.process(cast.typeclause, **kwargs))
+                    (self.process(cast.clause, **kwargs),
+                    self.process(cast.typeclause, **kwargs))
 
     def visit_extract(self, extract, **kwargs):
         field = self.extract_map.get(extract.field, extract.field)
-        return "EXTRACT(%s FROM %s)" % (field, self.process(extract.expr, **kwargs))
+        return "EXTRACT(%s FROM %s)" % (field, 
+                            self.process(extract.expr, **kwargs))
 
     def visit_function(self, func, result_map=None, **kwargs):
         if result_map is not None:
@@ -461,7 +476,8 @@ class SQLCompiler(engine.Compiled):
     def function_argspec(self, func, **kwargs):
         return self.process(func.clause_expr, **kwargs)
 
-    def visit_compound_select(self, cs, asfrom=False, parens=True, compound_index=1, **kwargs):
+    def visit_compound_select(self, cs, asfrom=False, 
+                            parens=True, compound_index=1, **kwargs):
         entry = self.stack and self.stack[-1] or {}
         self.stack.append({'from':entry.get('from', None), 'iswrapper':True})
 
@@ -478,7 +494,8 @@ class SQLCompiler(engine.Compiled):
             text += " GROUP BY " + group_by
 
         text += self.order_by_clause(cs, **kwargs)
-        text += (cs._limit is not None or cs._offset is not None) and self.limit_clause(cs) or ""
+        text += (cs._limit is not None or cs._offset is not None) and \
+                        self.limit_clause(cs) or ""
 
         self.stack.pop(-1)
         if asfrom and parens:
@@ -530,8 +547,8 @@ class SQLCompiler(engine.Compiled):
     def visit_ilike_op(self, binary, **kw):
         escape = binary.modifiers.get("escape", None)
         return 'lower(%s) LIKE lower(%s)' % (
-                                            self.process(binary.left, **kw), 
-                                            self.process(binary.right, **kw)) \
+                                        self.process(binary.left, **kw), 
+                                        self.process(binary.right, **kw)) \
             + (escape and 
                     (' ESCAPE ' + self.render_literal_value(escape, None))
                     or '')
@@ -539,8 +556,8 @@ class SQLCompiler(engine.Compiled):
     def visit_notilike_op(self, binary, **kw):
         escape = binary.modifiers.get("escape", None)
         return 'lower(%s) NOT LIKE lower(%s)' % (
-                                            self.process(binary.left, **kw), 
-                                            self.process(binary.right, **kw)) \
+                                    self.process(binary.left, **kw), 
+                                    self.process(binary.right, **kw)) \
             + (escape and 
                     (' ESCAPE ' + self.render_literal_value(escape, None))
                     or '')
@@ -563,7 +580,8 @@ class SQLCompiler(engine.Compiled):
             if bindparam.value is None:
                 raise exc.CompileError("Bind parameter without a "
                                         "renderable value not allowed here.")
-            return self.render_literal_bindparam(bindparam, within_columns_clause=True, **kwargs)
+            return self.render_literal_bindparam(bindparam,
+                            within_columns_clause=True, **kwargs)
             
         name = self._truncate_bindparam(bindparam)
         if name in self.binds:
@@ -572,17 +590,19 @@ class SQLCompiler(engine.Compiled):
                 if existing.unique or bindparam.unique:
                     raise exc.CompileError(
                             "Bind parameter '%s' conflicts with "
-                            "unique bind parameter of the same name" % bindparam.key
+                            "unique bind parameter of the same name" %
+                            bindparam.key
                         )
                 elif getattr(existing, '_is_crud', False):
                     raise exc.CompileError(
-                            "bindparam() name '%s' is reserved "
-                            "for automatic usage in the VALUES or SET clause of this "
-                            "insert/update statement.   Please use a " 
-                            "name other than column name when using bindparam() "
-                            "with insert() or update() (for example, 'b_%s')."
-                            % (bindparam.key, bindparam.key)
-                        )
+                        "bindparam() name '%s' is reserved "
+                        "for automatic usage in the VALUES or SET "
+                        "clause of this "
+                        "insert/update statement.   Please use a " 
+                        "name other than column name when using bindparam() "
+                        "with insert() or update() (for example, 'b_%s')."
+                        % (bindparam.key, bindparam.key)
+                    )
                     
         self.binds[bindparam.key] = self.binds[name] = bindparam
         return self.bindparam_string(name)
@@ -614,15 +634,17 @@ class SQLCompiler(engine.Compiled):
         elif isinstance(value, decimal.Decimal):
             return str(value)
         else:
-            raise NotImplementedError("Don't know how to literal-quote value %r" % value)
+            raise NotImplementedError(
+                        "Don't know how to literal-quote value %r" % value)
         
     def _truncate_bindparam(self, bindparam):
         if bindparam in self.bind_names:
             return self.bind_names[bindparam]
 
         bind_name = bindparam.key
-        bind_name = isinstance(bind_name, sql._generated_label) and \
-                        self._truncated_identifier("bindparam", bind_name) or bind_name
+        if isinstance(bind_name, sql._generated_label):
+            bind_name = self._truncated_identifier("bindparam", bind_name)
+
         # add to bind_names for translation
         self.bind_names[bindparam] = bind_name
 
@@ -636,7 +658,8 @@ class SQLCompiler(engine.Compiled):
 
         if len(anonname) > self.label_length:
             counter = self.truncated_names.get(ident_class, 1)
-            truncname = anonname[0:max(self.label_length - 6, 0)] + "_" + hex(counter)[2:]
+            truncname = anonname[0:max(self.label_length - 6, 0)] + \
+                                "_" + hex(counter)[2:]
             self.truncated_names[ident_class] = counter + 1
         else:
             truncname = anonname
@@ -659,14 +682,19 @@ class SQLCompiler(engine.Compiled):
         else:
             return self.bindtemplate % {'name':name}
 
-    def visit_alias(self, alias, asfrom=False, ashint=False, fromhints=None, **kwargs):
+    def visit_alias(self, alias, asfrom=False, ashint=False, 
+                                fromhints=None, **kwargs):
         if asfrom or ashint:
-            alias_name = isinstance(alias.name, sql._generated_label) and \
-                            self._truncated_identifier("alias", alias.name) or alias.name
+            if isinstance(alias.name, sql._generated_label):
+                alias_name = self._truncated_identifier("alias", alias.name)
+            else:
+                alias_name = alias.name
+
         if ashint:
             return self.preparer.format_alias(alias, alias_name)
         elif asfrom:
-            ret = self.process(alias.original, asfrom=True, **kwargs) + " AS " + \
+            ret = self.process(alias.original, asfrom=True, **kwargs) + \
+                                " AS " + \
                     self.preparer.format_alias(alias, alias_name)
                     
             if fromhints and alias in fromhints:
@@ -695,8 +723,10 @@ class SQLCompiler(engine.Compiled):
             not isinstance(column.table, sql.Select):
             return _CompileLabel(column, sql._generated_label(column.name))
         elif not isinstance(column, 
-                    (sql._UnaryExpression, sql._TextClause, sql._BindParamClause)) \
-                and (not hasattr(column, 'name') or isinstance(column, sql.Function)):
+                    (sql._UnaryExpression, sql._TextClause,
+                        sql._BindParamClause)) \
+                and (not hasattr(column, 'name') or \
+                        isinstance(column, sql.Function)):
             return _CompileLabel(column, column.anon_label)
         else:
             return column
@@ -719,12 +749,13 @@ class SQLCompiler(engine.Compiled):
 
         correlate_froms = set(sql._from_objects(*froms))
 
-        # TODO: might want to propagate existing froms for select(select(select))
-        # where innermost select should correlate to outermost
-        # if existingfroms:
-        #     correlate_froms = correlate_froms.union(existingfroms)
+        # TODO: might want to propagate existing froms for
+        # select(select(select)) where innermost select should correlate
+        # to outermost if existingfroms: correlate_froms =
+        # correlate_froms.union(existingfroms)
 
-        self.stack.append({'from':correlate_froms, 'iswrapper':iswrapper})
+        self.stack.append({'from': correlate_froms, 'iswrapper'
+                          : iswrapper})
 
         if compound_index==1 and not entry or entry.get('iswrapper', False):
             column_clause_args = {'result_map':self.result_map}
@@ -747,7 +778,8 @@ class SQLCompiler(engine.Compiled):
 
         if select._hints:
             byfrom = dict([
-                            (from_, hinttext % {'name':self.process(from_, ashint=True)}) 
+                            (from_, hinttext % {
+                                'name':self.process(from_, ashint=True)}) 
                             for (from_, dialect), hinttext in 
                             select._hints.iteritems() 
                             if dialect in ('*', self.dialect.name)
@@ -757,7 +789,9 @@ class SQLCompiler(engine.Compiled):
                 text += hint_text + " "
                 
         if select._prefixes:
-            text += " ".join(self.process(x, **kwargs) for x in select._prefixes) + " "
+            text += " ".join(
+                            self.process(x, **kwargs) 
+                            for x in select._prefixes) + " "
         text += self.get_select_precolumns(select)
         text += ', '.join(inner_columns)
 
@@ -806,8 +840,8 @@ class SQLCompiler(engine.Compiled):
             return text
 
     def get_select_precolumns(self, select):
-        """Called when building a ``SELECT`` statement, position is just before 
-        column list.
+        """Called when building a ``SELECT`` statement, position is just
+        before column list.
         
         """
         return select._distinct and "DISTINCT " or ""
@@ -835,11 +869,14 @@ class SQLCompiler(engine.Compiled):
             text += " OFFSET " + self.process(sql.literal(select._offset))
         return text
 
-    def visit_table(self, table, asfrom=False, ashint=False, fromhints=None, **kwargs):
+    def visit_table(self, table, asfrom=False, ashint=False, 
+                        fromhints=None, **kwargs):
         if asfrom or ashint:
             if getattr(table, "schema", None):
-                ret = self.preparer.quote_schema(table.schema, table.quote_schema) + \
-                                "." + self.preparer.quote(table.name, table.quote)
+                ret = self.preparer.quote_schema(table.schema,
+                                table.quote_schema) + \
+                                "." + self.preparer.quote(table.name,
+                                                table.quote)
             else:
                 ret = self.preparer.quote(table.name, table.quote)
             if fromhints and table in fromhints:
@@ -887,7 +924,8 @@ class SQLCompiler(engine.Compiled):
 
         if self.returning or insert_stmt._returning:
             self.returning = self.returning or insert_stmt._returning
-            returning_clause = self.returning_clause(insert_stmt, self.returning)
+            returning_clause = self.returning_clause(
+                                    insert_stmt, self.returning)
             
             if self.returning_precedes_values:
                 text += " " + returning_clause
@@ -913,27 +951,31 @@ class SQLCompiler(engine.Compiled):
         
         text += ' SET ' + \
                 ', '.join(
-                        self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1]
+                        self.preparer.quote(c[0].name, c[0].quote) + 
+                        '=' + c[1]
                       for c in colparams
                 )
 
         if update_stmt._returning:
             self.returning = update_stmt._returning
             if self.returning_precedes_values:
-                text += " " + self.returning_clause(update_stmt, update_stmt._returning)
+                text += " " + self.returning_clause(
+                                    update_stmt, update_stmt._returning)
                 
         if update_stmt._whereclause is not None:
             text += " WHERE " + self.process(update_stmt._whereclause)
 
         if self.returning and not self.returning_precedes_values:
-            text += " " + self.returning_clause(update_stmt, update_stmt._returning)
+            text += " " + self.returning_clause(
+                                    update_stmt, update_stmt._returning)
             
         self.stack.pop(-1)
 
         return text
 
     def _create_crud_bind_param(self, col, value, required=False):
-        bindparam = sql.bindparam(col.key, value, type_=col.type, required=required)
+        bindparam = sql.bindparam(col.key, value, 
+                            type_=col.type, required=required)
         bindparam._is_crud = True
         if col.key in self.binds:
             raise exc.CompileError(
@@ -952,8 +994,8 @@ class SQLCompiler(engine.Compiled):
         """create a set of tuples representing column/string pairs for use
         in an INSERT or UPDATE statement.
 
-        Also generates the Compiled object's postfetch, prefetch, and returning
-        column collections, used for default handling and ultimately
+        Also generates the Compiled object's postfetch, prefetch, and
+        returning column collections, used for default handling and ultimately
         populating the ResultProxy's prefetch_cols() and postfetch_cols()
         collections.
 
@@ -967,7 +1009,8 @@ class SQLCompiler(engine.Compiled):
         # compiled params - return binds for all columns
         if self.column_keys is None and stmt.parameters is None:
             return [
-                        (c, self._create_crud_bind_param(c, None, required=True)) 
+                        (c, self._create_crud_bind_param(c, 
+                                    None, required=True)) 
                         for c in stmt.table.columns
                     ]
 
@@ -980,7 +1023,8 @@ class SQLCompiler(engine.Compiled):
         else:
             parameters = dict((sql._column_as_key(key), required)
                               for key in self.column_keys 
-                              if not stmt.parameters or key not in stmt.parameters)
+                              if not stmt.parameters or 
+                              key not in stmt.parameters)
 
         if stmt.parameters is not None:
             for k, v in stmt.parameters.iteritems():
@@ -1006,7 +1050,8 @@ class SQLCompiler(engine.Compiled):
             if c.key in parameters:
                 value = parameters[c.key]
                 if sql._is_literal(value):
-                    value = self._create_crud_bind_param(c, value, required=value is required)
+                    value = self._create_crud_bind_param(
+                                    c, value, required=value is required)
                 else:
                     self.postfetch.append(c)
                     value = self.process(value.self_group())
@@ -1029,10 +1074,15 @@ class SQLCompiler(engine.Compiled):
                                     values.append((c, proc))
                                 self.returning.append(c)
                             elif c.default.is_clause_element:
-                                values.append((c, self.process(c.default.arg.self_group())))
+                                values.append(
+                                    (c,
+                                    self.process(c.default.arg.self_group()))
+                                )
                                 self.returning.append(c)
                             else:
-                                values.append((c, self._create_crud_bind_param(c, None)))
+                                values.append(
+                                    (c, self._create_crud_bind_param(c, None))
+                                )
                                 self.prefetch.append(c)
                         else:
                             self.returning.append(c)
@@ -1043,9 +1093,12 @@ class SQLCompiler(engine.Compiled):
                                     self.dialect.supports_sequences or 
                                     not c.default.is_sequence
                                 )
-                            ) or self.dialect.preexecute_autoincrement_sequences:
+                            ) or \
+                             self.dialect.preexecute_autoincrement_sequences:
 
-                            values.append((c, self._create_crud_bind_param(c, None)))
+                            values.append(
+                                (c, self._create_crud_bind_param(c, None))
+                            )
                             self.prefetch.append(c)
                 
                 elif c.default is not None:
@@ -1056,13 +1109,17 @@ class SQLCompiler(engine.Compiled):
                             if not c.primary_key:
                                 self.postfetch.append(c)
                     elif c.default.is_clause_element:
-                        values.append((c, self.process(c.default.arg.self_group())))
+                        values.append(
+                            (c, self.process(c.default.arg.self_group()))
+                        )
                     
                         if not c.primary_key:
                             # dont add primary key column to postfetch
                             self.postfetch.append(c)
                     else:
-                        values.append((c, self._create_crud_bind_param(c, None)))
+                        values.append(
+                            (c, self._create_crud_bind_param(c, None))
+                        )
                         self.prefetch.append(c)
                 elif c.server_default is not None:
                     if not c.primary_key:
@@ -1071,10 +1128,14 @@ class SQLCompiler(engine.Compiled):
             elif self.isupdate:
                 if c.onupdate is not None and not c.onupdate.is_sequence:
                     if c.onupdate.is_clause_element:
-                        values.append((c, self.process(c.onupdate.arg.self_group())))
+                        values.append(
+                            (c, self.process(c.onupdate.arg.self_group()))
+                        )
                         self.postfetch.append(c)
                     else:
-                        values.append((c, self._create_crud_bind_param(c, None)))
+                        values.append(
+                            (c, self._create_crud_bind_param(c, None))
+                        )
                         self.prefetch.append(c)
                 elif c.server_onupdate is not None:
                     self.postfetch.append(c)
@@ -1089,13 +1150,15 @@ class SQLCompiler(engine.Compiled):
         if delete_stmt._returning:
             self.returning = delete_stmt._returning
             if self.returning_precedes_values:
-                text += " " + self.returning_clause(delete_stmt, delete_stmt._returning)
+                text += " " + self.returning_clause(
+                                delete_stmt, delete_stmt._returning)
                 
         if delete_stmt._whereclause is not None:
             text += " WHERE " + self.process(delete_stmt._whereclause)
 
         if self.returning and not self.returning_precedes_values:
-            text += " " + self.returning_clause(delete_stmt, delete_stmt._returning)
+            text += " " + self.returning_clause(
+                                delete_stmt, delete_stmt._returning)
             
         self.stack.pop(-1)
 
@@ -1105,17 +1168,19 @@ class SQLCompiler(engine.Compiled):
         return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
 
     def visit_rollback_to_savepoint(self, savepoint_stmt):
-        return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
+        return "ROLLBACK TO SAVEPOINT %s" % \
+                self.preparer.format_savepoint(savepoint_stmt)
 
     def visit_release_savepoint(self, savepoint_stmt):
-        return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
+        return "RELEASE SAVEPOINT %s" % \
+                self.preparer.format_savepoint(savepoint_stmt)
 
 
 class DDLCompiler(engine.Compiled):
     
     @util.memoized_property
     def sql_compiler(self):
-        return self.dialect.statement_compiler(self.dialect, self.statement)
+        return self.dialect.statement_compiler(self.dialect, None)
         
     @property
     def preparer(self):
@@ -1161,11 +1226,13 @@ class DDLCompiler(engine.Compiled):
             separator = ", \n"
             text += "\t" + self.get_column_specification(
                                             column, 
-                                            first_pk=column.primary_key and not first_pk
+                                            first_pk=column.primary_key and \
+                                            not first_pk
                                         )
             if column.primary_key:
                 first_pk = True
-            const = " ".join(self.process(constraint) for constraint in column.constraints)
+            const = " ".join(self.process(constraint) \
+                            for constraint in column.constraints)
             if const:
                 text += " " + const
 
@@ -1184,10 +1251,12 @@ class DDLCompiler(engine.Compiled):
         if table.primary_key:
             constraints.append(table.primary_key)
             
-        constraints.extend([c for c in table.constraints if c is not table.primary_key])
+        constraints.extend([c for c in table.constraints 
+                                if c is not table.primary_key])
         
         return ", \n\t".join(p for p in
-                        (self.process(constraint) for constraint in constraints 
+                        (self.process(constraint) 
+                        for constraint in constraints 
                         if (
                             constraint._create_rule is None or
                             constraint._create_rule(self))
@@ -1230,7 +1299,8 @@ class DDLCompiler(engine.Compiled):
     def visit_drop_index(self, drop):
         index = drop.element
         return "\nDROP INDEX " + \
-                    self.preparer.quote(self._index_identifier(index.name), index.quote)
+                    self.preparer.quote(
+                            self._index_identifier(index.name), index.quote)
 
     def visit_add_constraint(self, create):
         preparer = self.preparer
@@ -1240,7 +1310,8 @@ class DDLCompiler(engine.Compiled):
         )
 
     def visit_create_sequence(self, create):
-        text = "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element)
+        text = "CREATE SEQUENCE %s" % \
+                self.preparer.format_sequence(create.element)
         if create.element.increment is not None:
             text += " INCREMENT BY %d" % create.element.increment
         if create.element.start is not None:
@@ -1248,7 +1319,8 @@ class DDLCompiler(engine.Compiled):
         return text
         
     def visit_drop_sequence(self, drop):
-        return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
+        return "DROP SEQUENCE %s" % \
+                self.preparer.format_sequence(drop.element)
 
     def visit_drop_constraint(self, drop):
         preparer = self.preparer
@@ -1301,7 +1373,8 @@ class DDLCompiler(engine.Compiled):
             return ''
         text = ""
         if constraint.name is not None:
-            text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint)
+            text += "CONSTRAINT %s " % \
+                    self.preparer.format_constraint(constraint)
         text += "PRIMARY KEY "
         text += "(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
                                        for c in constraint)
@@ -1318,7 +1391,8 @@ class DDLCompiler(engine.Compiled):
         text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
             ', '.join(preparer.quote(f.parent.name, f.parent.quote)
                       for f in constraint._elements.values()),
-            self.define_constraint_remote_table(constraint, remote_table, preparer),
+            self.define_constraint_remote_table(
+                            constraint, remote_table, preparer),
             ', '.join(preparer.quote(f.column.name, f.column.quote)
                       for f in constraint._elements.values())
         )
@@ -1334,8 +1408,11 @@ class DDLCompiler(engine.Compiled):
     def visit_unique_constraint(self, constraint):
         text = ""
         if constraint.name is not None:
-            text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint)
-        text += "UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint))
+            text += "CONSTRAINT %s " % \
+                    self.preparer.format_constraint(constraint)
+        text += "UNIQUE (%s)" % (
+                    ', '.join(self.preparer.quote(c.name, c.quote) 
+                            for c in constraint))
         text += self.define_constraint_deferrability(constraint)
         return text
 
@@ -1373,9 +1450,12 @@ class GenericTypeCompiler(engine.TypeCompiler):
         if type_.precision is None:
             return "NUMERIC"
         elif type_.scale is None:
-            return "NUMERIC(%(precision)s)" % {'precision': type_.precision}
+            return "NUMERIC(%(precision)s)" % \
+                        {'precision': type_.precision}
         else:
-            return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale}
+            return "NUMERIC(%(precision)s, %(scale)s)" % \
+                        {'precision': type_.precision, 
+                        'scale' : type_.scale}
 
     def visit_DECIMAL(self, type_):
         return "DECIMAL"
@@ -1499,7 +1579,8 @@ class IdentifierPreparer(object):
           Character that begins a delimited identifier.
 
         final_quote
-          Character that ends a delimited identifier. Defaults to `initial_quote`.
+          Character that ends a delimited identifier. Defaults to
+          `initial_quote`.
 
         omit_schema
           Prevent prepending schema name. Useful for databases that do
@@ -1539,7 +1620,9 @@ class IdentifierPreparer(object):
         quoting behavior.
         """
 
-        return self.initial_quote + self._escape_identifier(value) + self.final_quote
+        return self.initial_quote + \
+                    self._escape_identifier(value) + \
+                    self.final_quote
 
     def _requires_quotes(self, value):
         """Return True if the given identifier requires quoting."""
@@ -1574,8 +1657,10 @@ class IdentifierPreparer(object):
 
     def format_sequence(self, sequence, use_schema=True):
         name = self.quote(sequence.name, sequence.quote)
-        if not self.omit_schema and use_schema and sequence.schema is not None:
-            name = self.quote_schema(sequence.schema, sequence.quote) + "." + name
+        if not self.omit_schema and use_schema and \
+            sequence.schema is not None:
+            name = self.quote_schema(sequence.schema, sequence.quote) + \
+                        "." + name
         return name
 
     def format_label(self, label, name=None):
@@ -1596,24 +1681,33 @@ class IdentifierPreparer(object):
         if name is None:
             name = table.name
         result = self.quote(name, table.quote)
-        if not self.omit_schema and use_schema and getattr(table, "schema", None):
-            result = self.quote_schema(table.schema, table.quote_schema) + "." + result
+        if not self.omit_schema and use_schema \
+            and getattr(table, "schema", None):
+            result = self.quote_schema(table.schema, table.quote_schema) + \
+                                "." + result
         return result
 
-    def format_column(self, column, use_table=False, name=None, table_name=None):
+    def format_column(self, column, use_table=False, 
+                            name=None, table_name=None):
         """Prepare a quoted column name."""
 
         if name is None:
             name = column.name
         if not getattr(column, 'is_literal', False):
             if use_table:
-                return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(name, column.quote)
+                return self.format_table(
+                            column.table, use_schema=False, 
+                            name=table_name) + "." + \
+                            self.quote(name, column.quote)
             else:
                 return self.quote(name, column.quote)
         else:
-            # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted
+            # literal textual elements get stuck into ColumnClause alot,
+            # which shouldnt get quoted
+
             if use_table:
-                return self.format_table(column.table, use_schema=False, name=table_name) + "." + name
+                return self.format_table(column.table,
+                        use_schema=False, name=table_name) + '.' + name
             else:
                 return name
 
@@ -1624,7 +1718,8 @@ class IdentifierPreparer(object):
         # ('database', 'owner', etc.) could override this and return
         # a longer sequence.
 
-        if not self.omit_schema and use_schema and getattr(table, 'schema', None):
+        if not self.omit_schema and use_schema and \
+                getattr(table, 'schema', None):
             return (self.quote_schema(table.schema, table.quote_schema),
                     self.format_table(table, use_schema=False))
         else:
index b8c06cb08131b973982fd070d085236e273f0778..bc36e888c62e49c65432b0753e17768add402171 100644 (file)
@@ -1451,10 +1451,10 @@ class ClauseElement(Visitable):
                 bind = self.bind
             else:
                 dialect = default.DefaultDialect()
-        compiler = self._compiler(dialect, bind=bind, **kw)
-        compiler.compile()
-        return compiler
-    
+        c= self._compiler(dialect, bind=bind, **kw)
+        #c.string = c.process(c.statement)
+        return c
+        
     def _compiler(self, dialect, **kw):
         """Return a compiler appropriate for this ClauseElement, given a
         Dialect."""
index 02b888fed93ca395af6726bcb81d75dc7cab3202..0e0f92d3c43f6087938a7d8e64844bb35bb04984 100644 (file)
@@ -188,8 +188,7 @@ class TypesTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
             table_args.append(Column('c%s' % index, type_(*args, **kw)))
 
         numeric_table = Table(*table_args)
-        gen = testing.db.dialect.ddl_compiler(
-                testing.db.dialect, numeric_table)
+        gen = testing.db.dialect.ddl_compiler(testing.db.dialect, None)
 
         for col in numeric_table.c:
             index = int(col.name[1:])
@@ -277,8 +276,7 @@ class TypesTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
             table_args.append(Column('c%s' % index, type_(*args, **kw)))
 
         charset_table = Table(*table_args)
-        gen = testing.db.dialect.ddl_compiler(testing.db.dialect,
-                charset_table)
+        gen = testing.db.dialect.ddl_compiler(testing.db.dialect, None)
 
         for col in charset_table.c:
             index = int(col.name[1:])
@@ -1471,5 +1469,6 @@ class MatchTest(TestBase, AssertsCompiledSQL):
 
 
 def colspec(c):
-    return testing.db.dialect.ddl_compiler(testing.db.dialect, c.table).get_column_specification(c)
+    return testing.db.dialect.ddl_compiler(
+                    testing.db.dialect, None).get_column_specification(c)
 
index e0e121242eae918b0ed378c84522246e14155e1f..790bc23bc767963a539e25545facb16ac09de973 100644 (file)
@@ -408,7 +408,7 @@ class DDLExecutionTest(TestBase):
         """test the escaping of % characters in the DDL construct."""
         
         default_from = testing.db.dialect.statement_compiler(
-                            testing.db.dialect, DDL("")).default_from()
+                            testing.db.dialect, None).default_from()
         
         eq_(
             testing.db.execute(