:class:`~sqlalchemy.sql.compiler.GenericTypeCompiler` - renders
type specification strings.
-To generate user-defined SQL strings, see
+To generate user-defined SQL strings, see
:module:`~sqlalchemy.ext.compiler`.
"""
driver/DB enforces this
"""
- def __init__(self, dialect, statement, column_keys=None,
+ def __init__(self, dialect, statement, column_keys=None,
inline=False, **kwargs):
"""Construct a new ``DefaultCompiler`` object.
self.positiontup = []
self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
- # collect CTEs to tack on top of a SELECT
- self.ctes = util.OrderedDict()
- self.ctes_recursive = False
- if self.positional:
- self.cte_positional = []
+ self.ctes = None
# an IdentifierPreparer that formats the quoting of identifiers
self.preparer = dialect.identifier_preparer
if self.positional and dialect.paramstyle == 'numeric':
self._apply_numbered_params()
+ @util.memoized_instancemethod
+ def _init_cte_state(self):
+ """Initialize collections related to CTEs only if
+ a CTE is located, to save on the overhead of
+ these collections otherwise.
+
+ """
+ # collect CTEs to tack on top of a SELECT
+ self.ctes = util.OrderedDict()
+ self.ctes_by_name = {}
+ self.ctes_recursive = False
+ if self.positional:
+ self.cte_positional = []
+
def _apply_numbered_params(self):
poscount = itertools.count(1)
self.string = re.sub(
- r'\[_POSITION\]',
- lambda m:str(util.next(poscount)),
+ r'\[_POSITION\]',
+ lambda m:str(util.next(poscount)),
self.string)
@util.memoized_property
if _group_number:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r, "
- "in parameter group %d" %
+ "in parameter group %d" %
(bindparam.key, _group_number))
else:
raise exc.InvalidRequestError(
- "A value is required for bind parameter %r"
+ "A value is required for bind parameter %r"
% bindparam.key)
else:
pd[name] = bindparam.effective_value
if _group_number:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r, "
- "in parameter group %d" %
+ "in parameter group %d" %
(bindparam.key, _group_number))
else:
raise exc.InvalidRequestError(
- "A value is required for bind parameter %r"
+ "A value is required for bind parameter %r"
% bindparam.key)
pd[self.bind_names[bindparam]] = bindparam.effective_value
return pd
@property
def params(self):
- """Return the bind param dictionary embedded into this
+ """Return the bind param dictionary embedded into this
compiled object, for those values that are present."""
return self.construct_params(_check=False)
def visit_grouping(self, grouping, asfrom=False, **kwargs):
return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
- def visit_label(self, label, result_map=None,
- within_label_clause=False,
+ def visit_label(self, label, result_map=None,
+ within_label_clause=False,
within_columns_clause=False, **kw):
# only render labels within the columns clause
# or ORDER BY clause of a select. dialect-specific compilers
if result_map is not None:
result_map[labelname.lower()] = (
- label.name,
- (label, label.element, labelname, ) +
+ label.name,
+ (label, label.element, labelname, ) +
label._alt_names,
label.type)
- return label.element._compiler_dispatch(self,
+ return label.element._compiler_dispatch(self,
within_columns_clause=True,
- within_label_clause=True,
+ within_label_clause=True,
**kw) + \
OPERATORS[operators.as_] + \
self.preparer.format_label(label, labelname)
else:
- return label.element._compiler_dispatch(self,
- within_columns_clause=False,
+ return label.element._compiler_dispatch(self,
+ within_columns_clause=False,
**kw)
def visit_column(self, column, result_map=None, **kwargs):
name = self._truncated_identifier("colident", name)
if result_map is not None:
- result_map[name.lower()] = (orig_name,
- (column, name, column.key),
+ result_map[name.lower()] = (orig_name,
+ (column, name, column.key),
column.type)
if is_literal:
else:
if table.schema:
schema_prefix = self.preparer.quote_schema(
- table.schema,
+ table.schema,
table.quote_schema) + '.'
else:
schema_prefix = ''
else:
sep = OPERATORS[clauselist.operator]
return sep.join(
- s for s in
- (c._compiler_dispatch(self, **kwargs)
+ s for s in
+ (c._compiler_dispatch(self, **kwargs)
for c in clauselist.clauses)
if s)
def visit_extract(self, extract, **kwargs):
field = self.extract_map.get(extract.field, extract.field)
- return "EXTRACT(%s FROM %s)" % (field,
+ return "EXTRACT(%s FROM %s)" % (field,
extract.expr._compiler_dispatch(self, **kwargs))
def visit_function(self, func, result_map=None, **kwargs):
def function_argspec(self, func, **kwargs):
return func.clause_expr._compiler_dispatch(self, **kwargs)
- def visit_compound_select(self, cs, asfrom=False,
+ def visit_compound_select(self, cs, asfrom=False,
parens=True, compound_index=1, **kwargs):
entry = self.stack and self.stack[-1] or {}
self.stack.append({'from':entry.get('from', None), 'iswrapper':True})
keyword = self.compound_keywords.get(cs.keyword)
text = (" " + keyword + " ").join(
- (c._compiler_dispatch(self,
- asfrom=asfrom, parens=False,
+ (c._compiler_dispatch(self,
+ asfrom=asfrom, parens=False,
compound_index=i, **kwargs)
for i, c in enumerate(cs.selects))
)
return self._operator_dispatch(binary.operator,
binary,
- lambda opstr: binary.left._compiler_dispatch(self, **kw) +
- opstr +
+ lambda opstr: binary.left._compiler_dispatch(self, **kw) +
+ opstr +
binary.right._compiler_dispatch(
self, **kw),
**kw
def visit_like_op(self, binary, **kw):
escape = binary.modifiers.get("escape", None)
return '%s LIKE %s' % (
- binary.left._compiler_dispatch(self, **kw),
+ binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw)) \
- + (escape and
+ + (escape and
(' ESCAPE ' + self.render_literal_value(escape, None))
or '')
def visit_notlike_op(self, binary, **kw):
escape = binary.modifiers.get("escape", None)
return '%s NOT LIKE %s' % (
- binary.left._compiler_dispatch(self, **kw),
+ binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw)) \
- + (escape and
+ + (escape and
(' ESCAPE ' + self.render_literal_value(escape, None))
or '')
def visit_ilike_op(self, binary, **kw):
escape = binary.modifiers.get("escape", None)
return 'lower(%s) LIKE lower(%s)' % (
- binary.left._compiler_dispatch(self, **kw),
+ binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw)) \
- + (escape and
+ + (escape and
(' ESCAPE ' + self.render_literal_value(escape, None))
or '')
def visit_notilike_op(self, binary, **kw):
escape = binary.modifiers.get("escape", None)
return 'lower(%s) NOT LIKE lower(%s)' % (
- binary.left._compiler_dispatch(self, **kw),
+ binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw)) \
- + (escape and
+ + (escape and
(' ESCAPE ' + self.render_literal_value(escape, None))
or '')
"bindparam() name '%s' is reserved "
"for automatic usage in the VALUES or SET "
"clause of this "
- "insert/update statement. Please use a "
+ "insert/update statement. Please use a "
"name other than column name when using bindparam() "
"with insert() or update() (for example, 'b_%s')."
% (bindparam.key, bindparam.key)
self.positiontup.append(name)
return self.bindtemplate % {'name':name}
- def visit_cte(self, cte, asfrom=False, ashint=False,
- fromhints=None, **kwargs):
+ def visit_cte(self, cte, asfrom=False, ashint=False,
+ fromhints=None,
+ **kwargs):
+ self._init_cte_state()
if self.positional:
kwargs['positional_names'] = self.cte_positional
cte_name = self._truncated_identifier("alias", cte.name)
else:
cte_name = cte.name
+
+ if cte_name in self.ctes_by_name:
+ existing_cte = self.ctes_by_name[cte_name]
+ # we've generated a same-named CTE that we are enclosed in,
+ # or this is the same CTE. just return the name.
+ if cte in existing_cte._restates or cte is existing_cte:
+ return cte_name
+ elif existing_cte in cte._restates:
+ # we've generated a same-named CTE that is
+ # enclosed in us - we take precedence, so
+ # discard the text for the "inner".
+ del self.ctes[existing_cte]
+ else:
+ raise exc.CompileError(
+ "Multiple, unrelated CTEs found with "
+ "the same name: %r" %
+ cte_name)
+
+ self.ctes_by_name[cte_name] = cte
+
if cte.cte_alias:
if isinstance(cte.cte_alias, sql._truncated_label):
cte_alias = self._truncated_identifier("alias", cte.cte_alias)
col_source = cte.original.selects[0]
else:
assert False
- recur_cols = [c for c in
+ recur_cols = [c for c in
util.unique_list(col_source.inner_columns)
if c is not None]
text += "(%s)" % (", ".join(
- self.preparer.format_column(ident)
+ self.preparer.format_column(ident)
for ident in recur_cols))
text += " AS \n" + \
cte.original._compiler_dispatch(
return self.preparer.format_alias(cte, cte_name)
return text
- def visit_alias(self, alias, asfrom=False, ashint=False,
+ def visit_alias(self, alias, asfrom=False, ashint=False,
fromhints=None, **kwargs):
if asfrom or ashint:
if isinstance(alias.name, sql._truncated_label):
if ashint:
return self.preparer.format_alias(alias, alias_name)
elif asfrom:
- ret = alias.original._compiler_dispatch(self,
+ ret = alias.original._compiler_dispatch(self,
asfrom=True, **kwargs) + \
" AS " + \
self.preparer.format_alias(alias, alias_name)
select.use_labels and \
column._label:
return _CompileLabel(
- column,
- column._label,
+ column,
+ column._label,
alt_names=(column._key_label, )
)
not column.is_literal and \
column.table is not None and \
not isinstance(column.table, sql.Select):
- return _CompileLabel(column, sql._as_truncated(column.name),
+ return _CompileLabel(column, sql._as_truncated(column.name),
alt_names=(column.key,))
- elif not isinstance(column,
+ elif not isinstance(column,
(sql._UnaryExpression, sql._TextClause)) \
and (not hasattr(column, 'name') or \
isinstance(column, sql.Function)):
def get_crud_hint_text(self, table, text):
return None
- def visit_select(self, select, asfrom=False, parens=True,
- iswrapper=False, fromhints=None,
- compound_index=1,
+ def visit_select(self, select, asfrom=False, parens=True,
+ iswrapper=False, fromhints=None,
+ compound_index=1,
positional_names=None, **kwargs):
entry = self.stack and self.stack[-1] or {}
: iswrapper})
if compound_index==1 and not entry or entry.get('iswrapper', False):
- column_clause_args = {'result_map':self.result_map,
+ column_clause_args = {'result_map':self.result_map,
'positional_names':positional_names}
else:
column_clause_args = {'positional_names':positional_names}
self.label_select_column(select, co, asfrom=asfrom).\
_compiler_dispatch(self,
within_columns_clause=True,
- **column_clause_args)
+ **column_clause_args)
for co in util.unique_list(select.inner_columns)
]
if c is not None
(from_, hinttext % {
'name':from_._compiler_dispatch(
self, ashint=True)
- })
- for (from_, dialect), hinttext in
- select._hints.iteritems()
+ })
+ for (from_, dialect), hinttext in
+ select._hints.iteritems()
if dialect in ('*', self.dialect.name)
])
hint_text = self.get_select_hint_text(byfrom)
if select._prefixes:
text += " ".join(
- x._compiler_dispatch(self, **kwargs)
+ x._compiler_dispatch(self, **kwargs)
for x in select._prefixes) + " "
text += self.get_select_precolumns(select)
text += ', '.join(inner_columns)
text += " \nFROM "
if select._hints:
- text += ', '.join([f._compiler_dispatch(self,
- asfrom=True, fromhints=byfrom,
- **kwargs)
+ text += ', '.join([f._compiler_dispatch(self,
+ asfrom=True, fromhints=byfrom,
+ **kwargs)
for f in froms])
else:
- text += ', '.join([f._compiler_dispatch(self,
- asfrom=True, **kwargs)
+ text += ', '.join([f._compiler_dispatch(self,
+ asfrom=True, **kwargs)
for f in froms])
else:
text += self.default_from()
text += " OFFSET " + self.process(sql.literal(select._offset))
return text
- def visit_table(self, table, asfrom=False, ashint=False,
+ def visit_table(self, table, asfrom=False, ashint=False,
fromhints=None, **kwargs):
if asfrom or ashint:
if getattr(table, "schema", None):
def visit_join(self, join, asfrom=False, **kwargs):
return (
- join.left._compiler_dispatch(self, asfrom=True, **kwargs) +
- (join.isouter and " LEFT OUTER JOIN " or " JOIN ") +
- join.right._compiler_dispatch(self, asfrom=True, **kwargs) +
- " ON " +
+ join.left._compiler_dispatch(self, asfrom=True, **kwargs) +
+ (join.isouter and " LEFT OUTER JOIN " or " JOIN ") +
+ join.right._compiler_dispatch(self, asfrom=True, **kwargs) +
+ " ON " +
join.onclause._compiler_dispatch(self, **kwargs)
)
not self.dialect.supports_default_values and \
not self.dialect.supports_empty_insert:
raise exc.CompileError("The version of %s you are using does "
- "not support empty inserts." %
+ "not support empty inserts." %
self.dialect.name)
preparer = self.preparer
if insert_stmt._hints:
dialect_hints = dict([
(table, hint_text)
- for (table, dialect), hint_text in
+ for (table, dialect), hint_text in
insert_stmt._hints.items()
if dialect in ('*', self.dialect.name)
])
if insert_stmt.table in dialect_hints:
text += " " + self.get_crud_hint_text(
- insert_stmt.table,
+ insert_stmt.table,
dialect_hints[insert_stmt.table]
)
"""Provide a hook for MySQL to add LIMIT to the UPDATE"""
return None
- def update_tables_clause(self, update_stmt, from_table,
+ def update_tables_clause(self, update_stmt, from_table,
extra_froms, **kw):
"""Provide a hook to override the initial table clause
in an UPDATE statement.
"""
return self.preparer.format_table(from_table)
- def update_from_clause(self, update_stmt,
- from_table, extra_froms,
+ def update_from_clause(self, update_stmt,
+ from_table, extra_froms,
from_hints,
**kw):
- """Provide a hook to override the generation of an
+ """Provide a hook to override the generation of an
UPDATE..FROM clause.
MySQL and MSSQL override this.
"""
return "FROM " + ', '.join(
- t._compiler_dispatch(self, asfrom=True,
- fromhints=from_hints, **kw)
+ t._compiler_dispatch(self, asfrom=True,
+ fromhints=from_hints, **kw)
for t in extra_froms)
def visit_update(self, update_stmt, **kw):
colparams = self._get_colparams(update_stmt, extra_froms)
text = "UPDATE " + self.update_tables_clause(
- update_stmt,
- update_stmt.table,
+ update_stmt,
+ update_stmt.table,
extra_froms, **kw)
if update_stmt._hints:
dialect_hints = dict([
(table, hint_text)
- for (table, dialect), hint_text in
+ for (table, dialect), hint_text in
update_stmt._hints.items()
if dialect in ('*', self.dialect.name)
])
if update_stmt.table in dialect_hints:
text += " " + self.get_crud_hint_text(
- update_stmt.table,
+ update_stmt.table,
dialect_hints[update_stmt.table]
)
else:
text += ' SET '
if extra_froms and self.render_table_with_column_in_update_from:
text += ', '.join(
- self.visit_column(c[0]) +
+ self.visit_column(c[0]) +
'=' + c[1] for c in colparams
)
else:
text += ', '.join(
- self.preparer.quote(c[0].name, c[0].quote) +
+ self.preparer.quote(c[0].name, c[0].quote) +
'=' + c[1] for c in colparams
)
if extra_froms:
extra_from_text = self.update_from_clause(
- update_stmt,
- update_stmt.table,
- extra_froms,
+ update_stmt,
+ update_stmt.table,
+ extra_froms,
dialect_hints, **kw)
if extra_from_text:
text += " " + extra_from_text
return text
def _create_crud_bind_param(self, col, value, required=False):
- bindparam = sql.bindparam(col.key, value,
+ bindparam = sql.bindparam(col.key, value,
type_=col.type, required=required)
bindparam._is_crud = True
return bindparam._compiler_dispatch(self)
# compiled params - return binds for all columns
if self.column_keys is None and stmt.parameters is None:
return [
- (c, self._create_crud_bind_param(c,
- None, required=True))
+ (c, self._create_crud_bind_param(c,
+ None, required=True))
for c in stmt.table.columns
]
parameters = {}
else:
parameters = dict((sql._column_as_key(key), required)
- for key in self.column_keys
- if not stmt.parameters or
+ for key in self.column_keys
+ if not stmt.parameters or
key not in stmt.parameters)
if stmt.parameters is not None:
postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid
check_columns = {}
- # special logic that only occurs for multi-table UPDATE
+ # special logic that only occurs for multi-table UPDATE
# statements
if extra_tables and stmt.parameters:
assert self.isupdate
value = self.process(value.self_group())
values.append((c, value))
# determine tables which are actually
- # to be updated - process onupdate and
+ # to be updated - process onupdate and
# server_onupdate for these
for t in affected_tables:
for c in t.c:
self.postfetch.append(c)
# iterating through columns at the top to maintain ordering.
- # otherwise we might iterate through individual sets of
+ # otherwise we might iterate through individual sets of
# "defaults", "primary key cols", etc.
for c in stmt.table.columns:
if c.key in parameters and c.key not in check_columns:
if c.primary_key and \
need_pks and \
(
- implicit_returning or
- not postfetch_lastrowid or
+ implicit_returning or
+ not postfetch_lastrowid or
c is not stmt.table._autoincrement_column
):
).difference(check_columns)
if check:
util.warn(
- "Unconsumed column names: %s" %
+ "Unconsumed column names: %s" %
(", ".join(check))
)
if delete_stmt._hints:
dialect_hints = dict([
(table, hint_text)
- for (table, dialect), hint_text in
+ for (table, dialect), hint_text in
delete_stmt._hints.items()
if dialect in ('*', self.dialect.name)
])
if delete_stmt.table in dialect_hints:
text += " " + self.get_crud_hint_text(
- delete_stmt.table,
+ delete_stmt.table,
dialect_hints[delete_stmt.table]
)
else:
text += separator
separator = ", \n"
text += "\t" + self.get_column_specification(
- column,
+ column,
first_pk=column.primary_key and \
not first_pk
)
text += " " + const
except exc.CompileError, ce:
# Py3K
- #raise exc.CompileError("(in table '%s', column '%s'): %s"
+ #raise exc.CompileError("(in table '%s', column '%s'): %s"
# % (
- # table.description,
- # column.name,
+ # table.description,
+ # column.name,
# ce.args[0]
# )) from ce
# Py2K
- raise exc.CompileError("(in table '%s', column '%s'): %s"
+ raise exc.CompileError("(in table '%s', column '%s'): %s"
% (
- table.description,
+ table.description,
column.name,
ce.args[0]
)), None, sys.exc_info()[2]
if table.primary_key:
constraints.append(table.primary_key)
- constraints.extend([c for c in table._sorted_constraints
+ constraints.extend([c for c in table._sorted_constraints
if c is not table.primary_key])
return ", \n\t".join(p for p in
- (self.process(constraint)
- for constraint in constraints
+ (self.process(constraint)
+ for constraint in constraints
if (
constraint._create_rule is None or
constraint._create_rule(self))
and (
- not self.dialect.supports_alter or
+ not self.dialect.supports_alter or
not getattr(constraint, 'use_alter', False)
)) if p is not None
)
if index.unique:
text += "UNIQUE "
text += "INDEX %s ON %s (%s)" \
- % (preparer.quote(self._index_identifier(index.name),
+ % (preparer.quote(self._index_identifier(index.name),
index.quote),
preparer.format_table(index.table),
', '.join(preparer.quote(c.name, c.quote)
text += "CONSTRAINT %s " % \
self.preparer.format_constraint(constraint)
text += "UNIQUE (%s)" % (
- ', '.join(self.preparer.quote(c.name, c.quote)
+ ', '.join(self.preparer.quote(c.name, c.quote)
for c in constraint))
text += self.define_constraint_deferrability(constraint)
return text
{'precision': type_.precision}
else:
return "NUMERIC(%(precision)s, %(scale)s)" % \
- {'precision': type_.precision,
+ {'precision': type_.precision,
'scale' : type_.scale}
def visit_DECIMAL(self, type_):
def visit_large_binary(self, type_):
return self.visit_BLOB(type_)
- def visit_boolean(self, type_):
+ def visit_boolean(self, type_):
return self.visit_BOOLEAN(type_)
- def visit_time(self, type_):
+ def visit_time(self, type_):
return self.visit_TIME(type_)
- def visit_datetime(self, type_):
+ def visit_datetime(self, type_):
return self.visit_DATETIME(type_)
- def visit_date(self, type_):
+ def visit_date(self, type_):
return self.visit_DATE(type_)
- def visit_big_integer(self, type_):
+ def visit_big_integer(self, type_):
return self.visit_BIGINT(type_)
- def visit_small_integer(self, type_):
+ def visit_small_integer(self, type_):
return self.visit_SMALLINT(type_)
- def visit_integer(self, type_):
+ def visit_integer(self, type_):
return self.visit_INTEGER(type_)
def visit_real(self, type_):
def visit_float(self, type_):
return self.visit_FLOAT(type_)
- def visit_numeric(self, type_):
+ def visit_numeric(self, type_):
return self.visit_NUMERIC(type_)
- def visit_string(self, type_):
+ def visit_string(self, type_):
return self.visit_VARCHAR(type_)
- def visit_unicode(self, type_):
+ def visit_unicode(self, type_):
return self.visit_VARCHAR(type_)
- def visit_text(self, type_):
+ def visit_text(self, type_):
return self.visit_TEXT(type_)
- def visit_unicode_text(self, type_):
+ def visit_unicode_text(self, type_):
return self.visit_TEXT(type_)
def visit_enum(self, type_):
illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
- def __init__(self, dialect, initial_quote='"',
+ def __init__(self, dialect, initial_quote='"',
final_quote=None, escape_quote='"', omit_schema=False):
"""Construct a new ``IdentifierPreparer`` object.
def quote_schema(self, schema, force):
"""Quote a schema.
- Subclasses should override this to provide database-dependent
+ Subclasses should override this to provide database-dependent
quoting behavior.
"""
return self.quote(schema, force)
return self.quote(name, quote)
- def format_column(self, column, use_table=False,
+ def format_column(self, column, use_table=False,
name=None, table_name=None):
"""Prepare a quoted column name."""
if not getattr(column, 'is_literal', False):
if use_table:
return self.format_table(
- column.table, use_schema=False,
+ column.table, use_schema=False,
name=table_name) + "." + \
self.quote(name, column.quote)
else: