From: Mike Bayer Date: Tue, 21 Dec 2010 21:34:00 +0000 (-0500) Subject: - apply pep8 to compiler.py X-Git-Tag: rel_0_7b1~127 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=dff4e0591eee3def7c4c38666c8fc581c327aa87;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - apply pep8 to compiler.py - deprecate Compiled.compile() - have __init__ do compilation if statement is present. --- diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index df8bfd4bde..08c747f38e 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -677,8 +677,8 @@ class MSSQLCompiler(compiler.SQLCompiler): }) def __init__(self, *args, **kwargs): - super(MSSQLCompiler, self).__init__(*args, **kwargs) self.tablealiases = {} + super(MSSQLCompiler, self).__init__(*args, **kwargs) def visit_now_func(self, fn, **kw): return "CURRENT_TIMESTAMP" diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 4e11117f7b..1e5285b355 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -675,14 +675,17 @@ class Compiled(object): """ self.dialect = dialect - self.statement = statement self.bind = bind - self.can_execute = statement.supports_execution + if statement is not None: + self.statement = statement + self.can_execute = statement.supports_execution + self.string = self.process(self.statement) + @util.deprecated("0.7", ":class:`.Compiled` objects now compile " + "within the constructor.") def compile(self): """Produce the internal string representation of this element.""" - - self.string = self.process(self.statement) + pass @property def sql_compiler(self): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8474ebaccb..3b43386d54 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -24,7 +24,8 @@ To generate user-defined SQL strings, see import re from sqlalchemy import schema, engine, util, exc -from sqlalchemy.sql import operators, functions, util as sql_util, visitors +from sqlalchemy.sql import operators, functions, util as sql_util, \ + visitors from sqlalchemy.sql import expression as sql import decimal @@ -197,7 +198,8 @@ class SQLCompiler(engine.Compiled): # driver/DB enforces this ansi_bind_rules = False - def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): + def __init__(self, dialect, statement, column_keys=None, + inline=False, **kwargs): """Construct a new ``DefaultCompiler`` object. dialect @@ -211,47 +213,49 @@ class SQLCompiler(engine.Compiled): statement. """ - engine.Compiled.__init__(self, dialect, statement, **kwargs) - self.column_keys = column_keys - # compile INSERT/UPDATE defaults/sequences inlined (no pre-execute) + # compile INSERT/UPDATE defaults/sequences inlined (no pre- + # execute) self.inline = inline or getattr(statement, 'inline', False) - # a dictionary of bind parameter keys to _BindParamClause instances. + # a dictionary of bind parameter keys to _BindParamClause + # instances. self.binds = {} - # a dictionary of _BindParamClause instances to "compiled" names that are - # actually present in the generated SQL + # a dictionary of _BindParamClause instances to "compiled" names + # that are actually present in the generated SQL self.bind_names = util.column_dict() # stack which keeps track of nested SELECT statements self.stack = [] - # relates label names in the final SQL to - # a tuple of local column/label name, ColumnElement object (if any) and TypeEngine. - # ResultProxy uses this for type processing and column targeting + # relates label names in the final SQL to a tuple of local + # column/label name, ColumnElement object (if any) and + # TypeEngine. ResultProxy uses this for type processing and + # column targeting self.result_map = {} # true if the paramstyle is positional - self.positional = self.dialect.positional + self.positional = dialect.positional if self.positional: self.positiontup = [] - - self.bindtemplate = BIND_TEMPLATES[self.dialect.paramstyle] + self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] # an IdentifierPreparer that formats the quoting of identifiers - self.preparer = self.dialect.identifier_preparer + self.preparer = dialect.identifier_preparer + self.label_length = dialect.label_length \ + or dialect.max_identifier_length - self.label_length = self.dialect.label_length or self.dialect.max_identifier_length - - # a map which tracks "anonymous" identifiers that are - # created on the fly here + # a map which tracks "anonymous" identifiers that are created on + # the fly here self.anon_map = util.PopulateDict(self._process_anon) - # a map which tracks "truncated" names based on dialect.label_length - # or dialect.max_identifier_length + # a map which tracks "truncated" names based on + # dialect.label_length or dialect.max_identifier_length self.truncated_names = {} + engine.Compiled.__init__(self, dialect, statement, **kwargs) + @util.memoized_property @@ -284,13 +288,13 @@ class SQLCompiler(engine.Compiled): elif bindparam.required: if _group_number: raise exc.InvalidRequestError( - "A value is required for bind parameter %r, " - "in parameter group %d" % - (bindparam.key, _group_number)) + "A value is required for bind parameter %r, " + "in parameter group %d" % + (bindparam.key, _group_number)) else: raise exc.InvalidRequestError( - "A value is required for bind parameter %r" - % bindparam.key) + "A value is required for bind parameter %r" + % bindparam.key) elif bindparam.callable: pd[name] = bindparam.callable() else: @@ -311,7 +315,8 @@ class SQLCompiler(engine.Compiled): """) def default_from(self): - """Called when a SELECT statement has no froms, and no FROM clause is to be appended. + """Called when a SELECT statement has no froms, and no FROM clause is + to be appended. Gives Oracle a chance to tack on a ``FROM DUAL`` to the string output. @@ -328,12 +333,15 @@ class SQLCompiler(engine.Compiled): # or ORDER BY clause of a select. dialect-specific compilers # can modify this behavior. if within_columns_clause and not within_label_clause: - labelname = isinstance(label.name, sql._generated_label) and \ - self._truncated_identifier("colident", label.name) or label.name + if isinstance(label.name, sql._generated_label): + labelname = self._truncated_identifier("colident", label.name) + else: + labelname = label.name if result_map is not None: result_map[labelname.lower()] = \ - (label.name, (label, label.element, labelname), label.type) + (label.name, (label, label.element, labelname),\ + label.type) return self.process(label.element, within_columns_clause=True, @@ -373,11 +381,12 @@ class SQLCompiler(engine.Compiled): else: schema_prefix = '' tablename = column.table.name - tablename = isinstance(tablename, sql._generated_label) and \ - self._truncated_identifier("alias", tablename) or tablename + if isinstance(tablename, sql._generated_label): + tablename = self._truncated_identifier("alias", tablename) return schema_prefix + \ - self.preparer.quote(tablename, column.table.quote) + "." + name + self.preparer.quote(tablename, column.table.quote) + \ + "." + name def escape_literal_column(self, text): """provide escaping for the literal_column() construct.""" @@ -411,7 +420,8 @@ class SQLCompiler(engine.Compiled): # un-escape any \:params return BIND_PARAMS_ESC.sub(lambda m: m.group(1), - BIND_PARAMS.sub(do_bindparam, self.post_process_text(textclause.text)) + BIND_PARAMS.sub(do_bindparam, + self.post_process_text(textclause.text)) ) def visit_null(self, null, **kwargs): @@ -423,8 +433,11 @@ class SQLCompiler(engine.Compiled): sep = " " else: sep = OPERATORS[clauselist.operator] - return sep.join(s for s in (self.process(c, **kwargs) for c in clauselist.clauses) - if s is not None) + return sep.join( + s for s in + (self.process(c, **kwargs) + for c in clauselist.clauses) + if s is not None) def visit_case(self, clause, **kwargs): x = "CASE " @@ -440,11 +453,13 @@ class SQLCompiler(engine.Compiled): def visit_cast(self, cast, **kwargs): return "CAST(%s AS %s)" % \ - (self.process(cast.clause, **kwargs), self.process(cast.typeclause, **kwargs)) + (self.process(cast.clause, **kwargs), + self.process(cast.typeclause, **kwargs)) def visit_extract(self, extract, **kwargs): field = self.extract_map.get(extract.field, extract.field) - return "EXTRACT(%s FROM %s)" % (field, self.process(extract.expr, **kwargs)) + return "EXTRACT(%s FROM %s)" % (field, + self.process(extract.expr, **kwargs)) def visit_function(self, func, result_map=None, **kwargs): if result_map is not None: @@ -461,7 +476,8 @@ class SQLCompiler(engine.Compiled): def function_argspec(self, func, **kwargs): return self.process(func.clause_expr, **kwargs) - def visit_compound_select(self, cs, asfrom=False, parens=True, compound_index=1, **kwargs): + def visit_compound_select(self, cs, asfrom=False, + parens=True, compound_index=1, **kwargs): entry = self.stack and self.stack[-1] or {} self.stack.append({'from':entry.get('from', None), 'iswrapper':True}) @@ -478,7 +494,8 @@ class SQLCompiler(engine.Compiled): text += " GROUP BY " + group_by text += self.order_by_clause(cs, **kwargs) - text += (cs._limit is not None or cs._offset is not None) and self.limit_clause(cs) or "" + text += (cs._limit is not None or cs._offset is not None) and \ + self.limit_clause(cs) or "" self.stack.pop(-1) if asfrom and parens: @@ -530,8 +547,8 @@ class SQLCompiler(engine.Compiled): def visit_ilike_op(self, binary, **kw): escape = binary.modifiers.get("escape", None) return 'lower(%s) LIKE lower(%s)' % ( - self.process(binary.left, **kw), - self.process(binary.right, **kw)) \ + self.process(binary.left, **kw), + self.process(binary.right, **kw)) \ + (escape and (' ESCAPE ' + self.render_literal_value(escape, None)) or '') @@ -539,8 +556,8 @@ class SQLCompiler(engine.Compiled): def visit_notilike_op(self, binary, **kw): escape = binary.modifiers.get("escape", None) return 'lower(%s) NOT LIKE lower(%s)' % ( - self.process(binary.left, **kw), - self.process(binary.right, **kw)) \ + self.process(binary.left, **kw), + self.process(binary.right, **kw)) \ + (escape and (' ESCAPE ' + self.render_literal_value(escape, None)) or '') @@ -563,7 +580,8 @@ class SQLCompiler(engine.Compiled): if bindparam.value is None: raise exc.CompileError("Bind parameter without a " "renderable value not allowed here.") - return self.render_literal_bindparam(bindparam, within_columns_clause=True, **kwargs) + return self.render_literal_bindparam(bindparam, + within_columns_clause=True, **kwargs) name = self._truncate_bindparam(bindparam) if name in self.binds: @@ -572,17 +590,19 @@ class SQLCompiler(engine.Compiled): if existing.unique or bindparam.unique: raise exc.CompileError( "Bind parameter '%s' conflicts with " - "unique bind parameter of the same name" % bindparam.key + "unique bind parameter of the same name" % + bindparam.key ) elif getattr(existing, '_is_crud', False): raise exc.CompileError( - "bindparam() name '%s' is reserved " - "for automatic usage in the VALUES or SET clause of this " - "insert/update statement. Please use a " - "name other than column name when using bindparam() " - "with insert() or update() (for example, 'b_%s')." - % (bindparam.key, bindparam.key) - ) + "bindparam() name '%s' is reserved " + "for automatic usage in the VALUES or SET " + "clause of this " + "insert/update statement. Please use a " + "name other than column name when using bindparam() " + "with insert() or update() (for example, 'b_%s')." + % (bindparam.key, bindparam.key) + ) self.binds[bindparam.key] = self.binds[name] = bindparam return self.bindparam_string(name) @@ -614,15 +634,17 @@ class SQLCompiler(engine.Compiled): elif isinstance(value, decimal.Decimal): return str(value) else: - raise NotImplementedError("Don't know how to literal-quote value %r" % value) + raise NotImplementedError( + "Don't know how to literal-quote value %r" % value) def _truncate_bindparam(self, bindparam): if bindparam in self.bind_names: return self.bind_names[bindparam] bind_name = bindparam.key - bind_name = isinstance(bind_name, sql._generated_label) and \ - self._truncated_identifier("bindparam", bind_name) or bind_name + if isinstance(bind_name, sql._generated_label): + bind_name = self._truncated_identifier("bindparam", bind_name) + # add to bind_names for translation self.bind_names[bindparam] = bind_name @@ -636,7 +658,8 @@ class SQLCompiler(engine.Compiled): if len(anonname) > self.label_length: counter = self.truncated_names.get(ident_class, 1) - truncname = anonname[0:max(self.label_length - 6, 0)] + "_" + hex(counter)[2:] + truncname = anonname[0:max(self.label_length - 6, 0)] + \ + "_" + hex(counter)[2:] self.truncated_names[ident_class] = counter + 1 else: truncname = anonname @@ -659,14 +682,19 @@ class SQLCompiler(engine.Compiled): else: return self.bindtemplate % {'name':name} - def visit_alias(self, alias, asfrom=False, ashint=False, fromhints=None, **kwargs): + def visit_alias(self, alias, asfrom=False, ashint=False, + fromhints=None, **kwargs): if asfrom or ashint: - alias_name = isinstance(alias.name, sql._generated_label) and \ - self._truncated_identifier("alias", alias.name) or alias.name + if isinstance(alias.name, sql._generated_label): + alias_name = self._truncated_identifier("alias", alias.name) + else: + alias_name = alias.name + if ashint: return self.preparer.format_alias(alias, alias_name) elif asfrom: - ret = self.process(alias.original, asfrom=True, **kwargs) + " AS " + \ + ret = self.process(alias.original, asfrom=True, **kwargs) + \ + " AS " + \ self.preparer.format_alias(alias, alias_name) if fromhints and alias in fromhints: @@ -695,8 +723,10 @@ class SQLCompiler(engine.Compiled): not isinstance(column.table, sql.Select): return _CompileLabel(column, sql._generated_label(column.name)) elif not isinstance(column, - (sql._UnaryExpression, sql._TextClause, sql._BindParamClause)) \ - and (not hasattr(column, 'name') or isinstance(column, sql.Function)): + (sql._UnaryExpression, sql._TextClause, + sql._BindParamClause)) \ + and (not hasattr(column, 'name') or \ + isinstance(column, sql.Function)): return _CompileLabel(column, column.anon_label) else: return column @@ -719,12 +749,13 @@ class SQLCompiler(engine.Compiled): correlate_froms = set(sql._from_objects(*froms)) - # TODO: might want to propagate existing froms for select(select(select)) - # where innermost select should correlate to outermost - # if existingfroms: - # correlate_froms = correlate_froms.union(existingfroms) + # TODO: might want to propagate existing froms for + # select(select(select)) where innermost select should correlate + # to outermost if existingfroms: correlate_froms = + # correlate_froms.union(existingfroms) - self.stack.append({'from':correlate_froms, 'iswrapper':iswrapper}) + self.stack.append({'from': correlate_froms, 'iswrapper' + : iswrapper}) if compound_index==1 and not entry or entry.get('iswrapper', False): column_clause_args = {'result_map':self.result_map} @@ -747,7 +778,8 @@ class SQLCompiler(engine.Compiled): if select._hints: byfrom = dict([ - (from_, hinttext % {'name':self.process(from_, ashint=True)}) + (from_, hinttext % { + 'name':self.process(from_, ashint=True)}) for (from_, dialect), hinttext in select._hints.iteritems() if dialect in ('*', self.dialect.name) @@ -757,7 +789,9 @@ class SQLCompiler(engine.Compiled): text += hint_text + " " if select._prefixes: - text += " ".join(self.process(x, **kwargs) for x in select._prefixes) + " " + text += " ".join( + self.process(x, **kwargs) + for x in select._prefixes) + " " text += self.get_select_precolumns(select) text += ', '.join(inner_columns) @@ -806,8 +840,8 @@ class SQLCompiler(engine.Compiled): return text def get_select_precolumns(self, select): - """Called when building a ``SELECT`` statement, position is just before - column list. + """Called when building a ``SELECT`` statement, position is just + before column list. """ return select._distinct and "DISTINCT " or "" @@ -835,11 +869,14 @@ class SQLCompiler(engine.Compiled): text += " OFFSET " + self.process(sql.literal(select._offset)) return text - def visit_table(self, table, asfrom=False, ashint=False, fromhints=None, **kwargs): + def visit_table(self, table, asfrom=False, ashint=False, + fromhints=None, **kwargs): if asfrom or ashint: if getattr(table, "schema", None): - ret = self.preparer.quote_schema(table.schema, table.quote_schema) + \ - "." + self.preparer.quote(table.name, table.quote) + ret = self.preparer.quote_schema(table.schema, + table.quote_schema) + \ + "." + self.preparer.quote(table.name, + table.quote) else: ret = self.preparer.quote(table.name, table.quote) if fromhints and table in fromhints: @@ -887,7 +924,8 @@ class SQLCompiler(engine.Compiled): if self.returning or insert_stmt._returning: self.returning = self.returning or insert_stmt._returning - returning_clause = self.returning_clause(insert_stmt, self.returning) + returning_clause = self.returning_clause( + insert_stmt, self.returning) if self.returning_precedes_values: text += " " + returning_clause @@ -913,27 +951,31 @@ class SQLCompiler(engine.Compiled): text += ' SET ' + \ ', '.join( - self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1] + self.preparer.quote(c[0].name, c[0].quote) + + '=' + c[1] for c in colparams ) if update_stmt._returning: self.returning = update_stmt._returning if self.returning_precedes_values: - text += " " + self.returning_clause(update_stmt, update_stmt._returning) + text += " " + self.returning_clause( + update_stmt, update_stmt._returning) if update_stmt._whereclause is not None: text += " WHERE " + self.process(update_stmt._whereclause) if self.returning and not self.returning_precedes_values: - text += " " + self.returning_clause(update_stmt, update_stmt._returning) + text += " " + self.returning_clause( + update_stmt, update_stmt._returning) self.stack.pop(-1) return text def _create_crud_bind_param(self, col, value, required=False): - bindparam = sql.bindparam(col.key, value, type_=col.type, required=required) + bindparam = sql.bindparam(col.key, value, + type_=col.type, required=required) bindparam._is_crud = True if col.key in self.binds: raise exc.CompileError( @@ -952,8 +994,8 @@ class SQLCompiler(engine.Compiled): """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. - Also generates the Compiled object's postfetch, prefetch, and returning - column collections, used for default handling and ultimately + Also generates the Compiled object's postfetch, prefetch, and + returning column collections, used for default handling and ultimately populating the ResultProxy's prefetch_cols() and postfetch_cols() collections. @@ -967,7 +1009,8 @@ class SQLCompiler(engine.Compiled): # compiled params - return binds for all columns if self.column_keys is None and stmt.parameters is None: return [ - (c, self._create_crud_bind_param(c, None, required=True)) + (c, self._create_crud_bind_param(c, + None, required=True)) for c in stmt.table.columns ] @@ -980,7 +1023,8 @@ class SQLCompiler(engine.Compiled): else: parameters = dict((sql._column_as_key(key), required) for key in self.column_keys - if not stmt.parameters or key not in stmt.parameters) + if not stmt.parameters or + key not in stmt.parameters) if stmt.parameters is not None: for k, v in stmt.parameters.iteritems(): @@ -1006,7 +1050,8 @@ class SQLCompiler(engine.Compiled): if c.key in parameters: value = parameters[c.key] if sql._is_literal(value): - value = self._create_crud_bind_param(c, value, required=value is required) + value = self._create_crud_bind_param( + c, value, required=value is required) else: self.postfetch.append(c) value = self.process(value.self_group()) @@ -1029,10 +1074,15 @@ class SQLCompiler(engine.Compiled): values.append((c, proc)) self.returning.append(c) elif c.default.is_clause_element: - values.append((c, self.process(c.default.arg.self_group()))) + values.append( + (c, + self.process(c.default.arg.self_group())) + ) self.returning.append(c) else: - values.append((c, self._create_crud_bind_param(c, None))) + values.append( + (c, self._create_crud_bind_param(c, None)) + ) self.prefetch.append(c) else: self.returning.append(c) @@ -1043,9 +1093,12 @@ class SQLCompiler(engine.Compiled): self.dialect.supports_sequences or not c.default.is_sequence ) - ) or self.dialect.preexecute_autoincrement_sequences: + ) or \ + self.dialect.preexecute_autoincrement_sequences: - values.append((c, self._create_crud_bind_param(c, None))) + values.append( + (c, self._create_crud_bind_param(c, None)) + ) self.prefetch.append(c) elif c.default is not None: @@ -1056,13 +1109,17 @@ class SQLCompiler(engine.Compiled): if not c.primary_key: self.postfetch.append(c) elif c.default.is_clause_element: - values.append((c, self.process(c.default.arg.self_group()))) + values.append( + (c, self.process(c.default.arg.self_group())) + ) if not c.primary_key: # dont add primary key column to postfetch self.postfetch.append(c) else: - values.append((c, self._create_crud_bind_param(c, None))) + values.append( + (c, self._create_crud_bind_param(c, None)) + ) self.prefetch.append(c) elif c.server_default is not None: if not c.primary_key: @@ -1071,10 +1128,14 @@ class SQLCompiler(engine.Compiled): elif self.isupdate: if c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: - values.append((c, self.process(c.onupdate.arg.self_group()))) + values.append( + (c, self.process(c.onupdate.arg.self_group())) + ) self.postfetch.append(c) else: - values.append((c, self._create_crud_bind_param(c, None))) + values.append( + (c, self._create_crud_bind_param(c, None)) + ) self.prefetch.append(c) elif c.server_onupdate is not None: self.postfetch.append(c) @@ -1089,13 +1150,15 @@ class SQLCompiler(engine.Compiled): if delete_stmt._returning: self.returning = delete_stmt._returning if self.returning_precedes_values: - text += " " + self.returning_clause(delete_stmt, delete_stmt._returning) + text += " " + self.returning_clause( + delete_stmt, delete_stmt._returning) if delete_stmt._whereclause is not None: text += " WHERE " + self.process(delete_stmt._whereclause) if self.returning and not self.returning_precedes_values: - text += " " + self.returning_clause(delete_stmt, delete_stmt._returning) + text += " " + self.returning_clause( + delete_stmt, delete_stmt._returning) self.stack.pop(-1) @@ -1105,17 +1168,19 @@ class SQLCompiler(engine.Compiled): return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) def visit_rollback_to_savepoint(self, savepoint_stmt): - return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) + return "ROLLBACK TO SAVEPOINT %s" % \ + self.preparer.format_savepoint(savepoint_stmt) def visit_release_savepoint(self, savepoint_stmt): - return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) + return "RELEASE SAVEPOINT %s" % \ + self.preparer.format_savepoint(savepoint_stmt) class DDLCompiler(engine.Compiled): @util.memoized_property def sql_compiler(self): - return self.dialect.statement_compiler(self.dialect, self.statement) + return self.dialect.statement_compiler(self.dialect, None) @property def preparer(self): @@ -1161,11 +1226,13 @@ class DDLCompiler(engine.Compiled): separator = ", \n" text += "\t" + self.get_column_specification( column, - first_pk=column.primary_key and not first_pk + first_pk=column.primary_key and \ + not first_pk ) if column.primary_key: first_pk = True - const = " ".join(self.process(constraint) for constraint in column.constraints) + const = " ".join(self.process(constraint) \ + for constraint in column.constraints) if const: text += " " + const @@ -1184,10 +1251,12 @@ class DDLCompiler(engine.Compiled): if table.primary_key: constraints.append(table.primary_key) - constraints.extend([c for c in table.constraints if c is not table.primary_key]) + constraints.extend([c for c in table.constraints + if c is not table.primary_key]) return ", \n\t".join(p for p in - (self.process(constraint) for constraint in constraints + (self.process(constraint) + for constraint in constraints if ( constraint._create_rule is None or constraint._create_rule(self)) @@ -1230,7 +1299,8 @@ class DDLCompiler(engine.Compiled): def visit_drop_index(self, drop): index = drop.element return "\nDROP INDEX " + \ - self.preparer.quote(self._index_identifier(index.name), index.quote) + self.preparer.quote( + self._index_identifier(index.name), index.quote) def visit_add_constraint(self, create): preparer = self.preparer @@ -1240,7 +1310,8 @@ class DDLCompiler(engine.Compiled): ) def visit_create_sequence(self, create): - text = "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element) + text = "CREATE SEQUENCE %s" % \ + self.preparer.format_sequence(create.element) if create.element.increment is not None: text += " INCREMENT BY %d" % create.element.increment if create.element.start is not None: @@ -1248,7 +1319,8 @@ class DDLCompiler(engine.Compiled): return text def visit_drop_sequence(self, drop): - return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element) + return "DROP SEQUENCE %s" % \ + self.preparer.format_sequence(drop.element) def visit_drop_constraint(self, drop): preparer = self.preparer @@ -1301,7 +1373,8 @@ class DDLCompiler(engine.Compiled): return '' text = "" if constraint.name is not None: - text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint) + text += "CONSTRAINT %s " % \ + self.preparer.format_constraint(constraint) text += "PRIMARY KEY " text += "(%s)" % ', '.join(self.preparer.quote(c.name, c.quote) for c in constraint) @@ -1318,7 +1391,8 @@ class DDLCompiler(engine.Compiled): text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % ( ', '.join(preparer.quote(f.parent.name, f.parent.quote) for f in constraint._elements.values()), - self.define_constraint_remote_table(constraint, remote_table, preparer), + self.define_constraint_remote_table( + constraint, remote_table, preparer), ', '.join(preparer.quote(f.column.name, f.column.quote) for f in constraint._elements.values()) ) @@ -1334,8 +1408,11 @@ class DDLCompiler(engine.Compiled): def visit_unique_constraint(self, constraint): text = "" if constraint.name is not None: - text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint) - text += "UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint)) + text += "CONSTRAINT %s " % \ + self.preparer.format_constraint(constraint) + text += "UNIQUE (%s)" % ( + ', '.join(self.preparer.quote(c.name, c.quote) + for c in constraint)) text += self.define_constraint_deferrability(constraint) return text @@ -1373,9 +1450,12 @@ class GenericTypeCompiler(engine.TypeCompiler): if type_.precision is None: return "NUMERIC" elif type_.scale is None: - return "NUMERIC(%(precision)s)" % {'precision': type_.precision} + return "NUMERIC(%(precision)s)" % \ + {'precision': type_.precision} else: - return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale} + return "NUMERIC(%(precision)s, %(scale)s)" % \ + {'precision': type_.precision, + 'scale' : type_.scale} def visit_DECIMAL(self, type_): return "DECIMAL" @@ -1499,7 +1579,8 @@ class IdentifierPreparer(object): Character that begins a delimited identifier. final_quote - Character that ends a delimited identifier. Defaults to `initial_quote`. + Character that ends a delimited identifier. Defaults to + `initial_quote`. omit_schema Prevent prepending schema name. Useful for databases that do @@ -1539,7 +1620,9 @@ class IdentifierPreparer(object): quoting behavior. """ - return self.initial_quote + self._escape_identifier(value) + self.final_quote + return self.initial_quote + \ + self._escape_identifier(value) + \ + self.final_quote def _requires_quotes(self, value): """Return True if the given identifier requires quoting.""" @@ -1574,8 +1657,10 @@ class IdentifierPreparer(object): def format_sequence(self, sequence, use_schema=True): name = self.quote(sequence.name, sequence.quote) - if not self.omit_schema and use_schema and sequence.schema is not None: - name = self.quote_schema(sequence.schema, sequence.quote) + "." + name + if not self.omit_schema and use_schema and \ + sequence.schema is not None: + name = self.quote_schema(sequence.schema, sequence.quote) + \ + "." + name return name def format_label(self, label, name=None): @@ -1596,24 +1681,33 @@ class IdentifierPreparer(object): if name is None: name = table.name result = self.quote(name, table.quote) - if not self.omit_schema and use_schema and getattr(table, "schema", None): - result = self.quote_schema(table.schema, table.quote_schema) + "." + result + if not self.omit_schema and use_schema \ + and getattr(table, "schema", None): + result = self.quote_schema(table.schema, table.quote_schema) + \ + "." + result return result - def format_column(self, column, use_table=False, name=None, table_name=None): + def format_column(self, column, use_table=False, + name=None, table_name=None): """Prepare a quoted column name.""" if name is None: name = column.name if not getattr(column, 'is_literal', False): if use_table: - return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(name, column.quote) + return self.format_table( + column.table, use_schema=False, + name=table_name) + "." + \ + self.quote(name, column.quote) else: return self.quote(name, column.quote) else: - # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted + # literal textual elements get stuck into ColumnClause alot, + # which shouldnt get quoted + if use_table: - return self.format_table(column.table, use_schema=False, name=table_name) + "." + name + return self.format_table(column.table, + use_schema=False, name=table_name) + '.' + name else: return name @@ -1624,7 +1718,8 @@ class IdentifierPreparer(object): # ('database', 'owner', etc.) could override this and return # a longer sequence. - if not self.omit_schema and use_schema and getattr(table, 'schema', None): + if not self.omit_schema and use_schema and \ + getattr(table, 'schema', None): return (self.quote_schema(table.schema, table.quote_schema), self.format_table(table, use_schema=False)) else: diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index b8c06cb081..bc36e888c6 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1451,10 +1451,10 @@ class ClauseElement(Visitable): bind = self.bind else: dialect = default.DefaultDialect() - compiler = self._compiler(dialect, bind=bind, **kw) - compiler.compile() - return compiler - + c= self._compiler(dialect, bind=bind, **kw) + #c.string = c.process(c.statement) + return c + def _compiler(self, dialect, **kw): """Return a compiler appropriate for this ClauseElement, given a Dialect.""" diff --git a/test/dialect/test_mysql.py b/test/dialect/test_mysql.py index 02b888fed9..0e0f92d3c4 100644 --- a/test/dialect/test_mysql.py +++ b/test/dialect/test_mysql.py @@ -188,8 +188,7 @@ class TypesTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL): table_args.append(Column('c%s' % index, type_(*args, **kw))) numeric_table = Table(*table_args) - gen = testing.db.dialect.ddl_compiler( - testing.db.dialect, numeric_table) + gen = testing.db.dialect.ddl_compiler(testing.db.dialect, None) for col in numeric_table.c: index = int(col.name[1:]) @@ -277,8 +276,7 @@ class TypesTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL): table_args.append(Column('c%s' % index, type_(*args, **kw))) charset_table = Table(*table_args) - gen = testing.db.dialect.ddl_compiler(testing.db.dialect, - charset_table) + gen = testing.db.dialect.ddl_compiler(testing.db.dialect, None) for col in charset_table.c: index = int(col.name[1:]) @@ -1471,5 +1469,6 @@ class MatchTest(TestBase, AssertsCompiledSQL): def colspec(c): - return testing.db.dialect.ddl_compiler(testing.db.dialect, c.table).get_column_specification(c) + return testing.db.dialect.ddl_compiler( + testing.db.dialect, None).get_column_specification(c) diff --git a/test/engine/test_ddlevents.py b/test/engine/test_ddlevents.py index e0e121242e..790bc23bc7 100644 --- a/test/engine/test_ddlevents.py +++ b/test/engine/test_ddlevents.py @@ -408,7 +408,7 @@ class DDLExecutionTest(TestBase): """test the escaping of % characters in the DDL construct.""" default_from = testing.db.dialect.statement_compiler( - testing.db.dialect, DDL("")).default_from() + testing.db.dialect, None).default_from() eq_( testing.db.execute(