From 6480e75a7c53db5fc31bbe87a1c68535caf61143 Mon Sep 17 00:00:00 2001 From: Diana Clarke Date: Mon, 19 Nov 2012 11:32:00 -0500 Subject: [PATCH] just a pep8 pass of lib/sqlalchemy/sql/ --- lib/sqlalchemy/sql/__init__.py | 1 - lib/sqlalchemy/sql/compiler.py | 127 +++++++++++++-------------- lib/sqlalchemy/sql/expression.py | 146 +++++++++++++++++++++++++------ lib/sqlalchemy/sql/functions.py | 28 +++++- lib/sqlalchemy/sql/operators.py | 45 ++++++++-- lib/sqlalchemy/sql/util.py | 57 ++++++++---- lib/sqlalchemy/sql/visitors.py | 44 +++++++--- 7 files changed, 322 insertions(+), 126 deletions(-) diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index e25bb31600..d0ffd8076e 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -66,4 +66,3 @@ from .visitors import ClauseVisitor __tmp = locals().keys() __all__ = sorted([i for i in __tmp if not i.startswith('__')]) - diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 74127c86a0..102b44a7e3 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -57,59 +57,59 @@ BIND_PARAMS = re.compile(r'(? ', - operators.ge : ' >= ', - operators.eq : ' = ', - operators.concat_op : ' || ', - operators.between_op : ' BETWEEN ', - operators.match_op : ' MATCH ', - operators.in_op : ' IN ', - operators.notin_op : ' NOT IN ', - operators.comma_op : ', ', - operators.from_ : ' FROM ', - operators.as_ : ' AS ', - operators.is_ : ' IS ', - operators.isnot : ' IS NOT ', - operators.collate : ' COLLATE ', + operators.mod: ' % ', + operators.truediv: ' / ', + operators.neg: '-', + operators.lt: ' < ', + operators.le: ' <= ', + operators.ne: ' != ', + operators.gt: ' > ', + operators.ge: ' >= ', + operators.eq: ' = ', + operators.concat_op: ' || ', + operators.between_op: ' BETWEEN ', + operators.match_op: ' MATCH ', + operators.in_op: ' IN ', + operators.notin_op: ' NOT IN ', + operators.comma_op: ', ', + operators.from_: ' FROM ', + operators.as_: ' AS ', + operators.is_: ' IS ', + operators.isnot: ' IS NOT ', + operators.collate: ' COLLATE ', # unary - operators.exists : 'EXISTS ', - operators.distinct_op : 'DISTINCT ', - operators.inv : 'NOT ', + operators.exists: 'EXISTS ', + operators.distinct_op: 'DISTINCT ', + operators.inv: 'NOT ', # modifiers - operators.desc_op : ' DESC', - operators.asc_op : ' ASC', - operators.nullsfirst_op : ' NULLS FIRST', - operators.nullslast_op : ' NULLS LAST', + operators.desc_op: ' DESC', + operators.asc_op: ' ASC', + operators.nullsfirst_op: ' NULLS FIRST', + operators.nullslast_op: ' NULLS LAST', } FUNCTIONS = { - functions.coalesce : 'coalesce%(expr)s', + functions.coalesce: 'coalesce%(expr)s', functions.current_date: 'CURRENT_DATE', functions.current_time: 'CURRENT_TIME', functions.current_timestamp: 'CURRENT_TIMESTAMP', @@ -118,7 +118,7 @@ FUNCTIONS = { functions.localtimestamp: 'LOCALTIMESTAMP', functions.random: 'random%(expr)s', functions.sysdate: 'sysdate', - functions.session_user :'SESSION_USER', + functions.session_user: 'SESSION_USER', functions.user: 'USER' } @@ -141,14 +141,15 @@ EXTRACT_MAP = { } COMPOUND_KEYWORDS = { - sql.CompoundSelect.UNION : 'UNION', - sql.CompoundSelect.UNION_ALL : 'UNION ALL', - sql.CompoundSelect.EXCEPT : 'EXCEPT', - sql.CompoundSelect.EXCEPT_ALL : 'EXCEPT ALL', - sql.CompoundSelect.INTERSECT : 'INTERSECT', - sql.CompoundSelect.INTERSECT_ALL : 'INTERSECT ALL' + sql.CompoundSelect.UNION: 'UNION', + sql.CompoundSelect.UNION_ALL: 'UNION ALL', + sql.CompoundSelect.EXCEPT: 'EXCEPT', + sql.CompoundSelect.EXCEPT_ALL: 'EXCEPT ALL', + sql.CompoundSelect.INTERSECT: 'INTERSECT', + sql.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL' } + class _CompileLabel(visitors.Visitable): """lightweight label object which acts as an expression.Label.""" @@ -297,16 +298,16 @@ class SQLCompiler(engine.Compiled): poscount = itertools.count(1) self.string = re.sub( r'\[_POSITION\]', - lambda m:str(util.next(poscount)), + lambda m: str(util.next(poscount)), self.string) @util.memoized_property def _bind_processors(self): return dict( (key, value) for key, value in - ( (self.bind_names[bindparam], + ((self.bind_names[bindparam], bindparam.type._cached_bind_processor(self.dialect)) - for bindparam in self.bind_names ) + for bindparam in self.bind_names) if value is not None ) @@ -750,7 +751,6 @@ class SQLCompiler(engine.Compiled): (' ESCAPE ' + self.render_literal_value(escape, None)) or '') - def visit_bindparam(self, bindparam, within_columns_clause=False, literal_binds=False, skip_bind_expression=False, @@ -873,7 +873,7 @@ class SQLCompiler(engine.Compiled): positional_names.append(name) else: self.positiontup.append(name) - return self.bindtemplate % {'name':name} + return self.bindtemplate % {'name': name} def visit_cte(self, cte, asfrom=False, ashint=False, fromhints=None, @@ -1240,7 +1240,7 @@ class SQLCompiler(engine.Compiled): def limit_clause(self, select): text = "" if select._limit is not None: - text += "\n LIMIT " + self.process(sql.literal(select._limit)) + text += "\n LIMIT " + self.process(sql.literal(select._limit)) if select._offset is not None: if select._limit is None: text += "\n LIMIT -1" @@ -1449,7 +1449,6 @@ class SQLCompiler(engine.Compiled): bindparam._is_crud = True return bindparam._compiler_dispatch(self) - def _get_colparams(self, stmt, extra_tables=None): """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. @@ -1505,7 +1504,6 @@ class SQLCompiler(engine.Compiled): values.append((k, v)) - need_pks = self.isinsert and \ not self.inline and \ not stmt._returning @@ -1534,7 +1532,7 @@ class SQLCompiler(engine.Compiled): value = normalized_params[c] if sql._is_literal(value): value = self._create_crud_bind_param( - c, value, required=value is required) + c, value, required=value is required) else: self.postfetch.append(c) value = self.process(value.self_group()) @@ -1775,10 +1773,12 @@ class DDLCompiler(engine.Compiled): return self.sql_compiler.post_process_text(ddl.statement % context) def visit_create_schema(self, create): - return "CREATE SCHEMA " + self.preparer.format_schema(create.element, create.quote) + schema = self.preparer.format_schema(create.element, create.quote) + return "CREATE SCHEMA " + schema def visit_drop_schema(self, drop): - text = "DROP SCHEMA " + self.preparer.format_schema(drop.element, drop.quote) + schema = self.preparer.format_schema(drop.element, drop.quote) + text = "DROP SCHEMA " + schema if drop.cascade: text += " CASCADE" return text @@ -1921,9 +1921,7 @@ class DDLCompiler(engine.Compiled): index_name = schema_name + "." + index_name return index_name - def visit_add_constraint(self, create): - preparer = self.preparer return "ALTER TABLE %s ADD %s" % ( self.preparer.format_table(create.element.table), self.process(create.element) @@ -1943,7 +1941,6 @@ class DDLCompiler(engine.Compiled): self.preparer.format_sequence(drop.element) def visit_drop_constraint(self, drop): - preparer = self.preparer return "ALTER TABLE %s DROP CONSTRAINT %s%s" % ( self.preparer.format_table(drop.element.table), self.preparer.format_constraint(drop.element), @@ -2084,7 +2081,7 @@ class GenericTypeCompiler(engine.TypeCompiler): else: return "NUMERIC(%(precision)s, %(scale)s)" % \ {'precision': type_.precision, - 'scale' : type_.scale} + 'scale': type_.scale} def visit_DECIMAL(self, type_): return "DECIMAL" @@ -2152,7 +2149,6 @@ class GenericTypeCompiler(engine.TypeCompiler): def visit_BOOLEAN(self, type_): return "BOOLEAN" - def visit_large_binary(self, type_): return self.visit_BLOB(type_) @@ -2210,6 +2206,7 @@ class GenericTypeCompiler(engine.TypeCompiler): def visit_user_defined(self, type_): return type_.get_col_spec() + class IdentifierPreparer(object): """Handle quoting and case-folding of identifiers based on options.""" @@ -2388,9 +2385,9 @@ class IdentifierPreparer(object): r'(?:' r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s' r'|([^\.]+))(?=\.|$))+' % - { 'initial': initial, - 'final': final, - 'escaped': escaped_final }) + {'initial': initial, + 'final': final, + 'escaped': escaped_final}) return r def unformat_identifiers(self, identifiers): diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 1d3be7de17..3dc8dfea49 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -57,6 +57,7 @@ __all__ = [ PARSE_AUTOCOMMIT = util.symbol('PARSE_AUTOCOMMIT') NO_ARG = util.symbol('NO_ARG') + def nullsfirst(column): """Return a NULLS FIRST ``ORDER BY`` clause element. @@ -71,6 +72,7 @@ def nullsfirst(column): """ return UnaryExpression(column, modifier=operators.nullsfirst_op) + def nullslast(column): """Return a NULLS LAST ``ORDER BY`` clause element. @@ -85,6 +87,7 @@ def nullslast(column): """ return UnaryExpression(column, modifier=operators.nullslast_op) + def desc(column): """Return a descending ``ORDER BY`` clause element. @@ -99,6 +102,7 @@ def desc(column): """ return UnaryExpression(column, modifier=operators.desc_op) + def asc(column): """Return an ascending ``ORDER BY`` clause element. @@ -113,6 +117,7 @@ def asc(column): """ return UnaryExpression(column, modifier=operators.asc_op) + def outerjoin(left, right, onclause=None): """Return an ``OUTER JOIN`` clause element. @@ -137,6 +142,7 @@ def outerjoin(left, right, onclause=None): """ return Join(left, right, onclause, isouter=True) + def join(left, right, onclause=None, isouter=False): """Return a ``JOIN`` clause element (regular inner join). @@ -162,6 +168,7 @@ def join(left, right, onclause=None, isouter=False): """ return Join(left, right, onclause, isouter) + def select(columns=None, whereclause=None, from_obj=[], **kwargs): """Returns a ``SELECT`` clause element. @@ -297,6 +304,7 @@ def select(columns=None, whereclause=None, from_obj=[], **kwargs): return Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs) + def subquery(alias, *args, **kwargs): """Return an :class:`.Alias` object derived from a :class:`.Select`. @@ -312,6 +320,7 @@ def subquery(alias, *args, **kwargs): """ return Select(*args, **kwargs).alias(alias) + def insert(table, values=None, inline=False, **kwargs): """Represent an ``INSERT`` statement via the :class:`.Insert` SQL construct. @@ -358,6 +367,7 @@ def insert(table, values=None, inline=False, **kwargs): """ return Insert(table, values, inline=inline, **kwargs) + def update(table, whereclause=None, values=None, inline=False, **kwargs): """Represent an ``UPDATE`` statement via the :class:`.Update` SQL construct. @@ -470,6 +480,7 @@ def update(table, whereclause=None, values=None, inline=False, **kwargs): inline=inline, **kwargs) + def delete(table, whereclause=None, **kwargs): """Represent a ``DELETE`` statement via the :class:`.Delete` SQL construct. @@ -491,6 +502,7 @@ def delete(table, whereclause=None, **kwargs): """ return Delete(table, whereclause, **kwargs) + def and_(*clauses): """Join a list of clauses together using the ``AND`` operator. @@ -503,6 +515,7 @@ def and_(*clauses): return clauses[0] return BooleanClauseList(operator=operators.and_, *clauses) + def or_(*clauses): """Join a list of clauses together using the ``OR`` operator. @@ -515,6 +528,7 @@ def or_(*clauses): return clauses[0] return BooleanClauseList(operator=operators.or_, *clauses) + def not_(clause): """Return a negation of the given clause, i.e. ``NOT(clause)``. @@ -525,6 +539,7 @@ def not_(clause): """ return operators.inv(_literal_as_binds(clause)) + def distinct(expr): """Return a ``DISTINCT`` clause. @@ -541,6 +556,7 @@ def distinct(expr): return UnaryExpression(expr, operator=operators.distinct_op, type_=expr.type) + def between(ctest, cleft, cright): """Return a ``BETWEEN`` predicate clause. @@ -554,6 +570,7 @@ def between(ctest, cleft, cright): ctest = _literal_as_binds(ctest) return ctest.between(cleft, cright) + def case(whens, value=None, else_=None): """Produce a ``CASE`` statement. @@ -608,6 +625,7 @@ def case(whens, value=None, else_=None): return Case(whens, value=value, else_=else_) + def cast(clause, totype, **kwargs): """Return a ``CAST`` function. @@ -624,11 +642,13 @@ def cast(clause, totype, **kwargs): """ return Cast(clause, totype, **kwargs) + def extract(field, expr): """Return the clause ``extract(field FROM expr)``.""" return Extract(field, expr) + def collate(expression, collation): """Return the clause ``expression COLLATE collation``. @@ -648,6 +668,7 @@ def collate(expression, collation): _literal_as_text(collation), operators.collate, type_=expr.type) + def exists(*args, **kwargs): """Return an ``EXISTS`` clause as applied to a :class:`.Select` object. @@ -667,6 +688,7 @@ def exists(*args, **kwargs): """ return Exists(*args, **kwargs) + def union(*selects, **kwargs): """Return a ``UNION`` of multiple selectables. @@ -686,6 +708,7 @@ def union(*selects, **kwargs): """ return CompoundSelect(CompoundSelect.UNION, *selects, **kwargs) + def union_all(*selects, **kwargs): """Return a ``UNION ALL`` of multiple selectables. @@ -705,6 +728,7 @@ def union_all(*selects, **kwargs): """ return CompoundSelect(CompoundSelect.UNION_ALL, *selects, **kwargs) + def except_(*selects, **kwargs): """Return an ``EXCEPT`` of multiple selectables. @@ -721,6 +745,7 @@ def except_(*selects, **kwargs): """ return CompoundSelect(CompoundSelect.EXCEPT, *selects, **kwargs) + def except_all(*selects, **kwargs): """Return an ``EXCEPT ALL`` of multiple selectables. @@ -737,6 +762,7 @@ def except_all(*selects, **kwargs): """ return CompoundSelect(CompoundSelect.EXCEPT_ALL, *selects, **kwargs) + def intersect(*selects, **kwargs): """Return an ``INTERSECT`` of multiple selectables. @@ -753,6 +779,7 @@ def intersect(*selects, **kwargs): """ return CompoundSelect(CompoundSelect.INTERSECT, *selects, **kwargs) + def intersect_all(*selects, **kwargs): """Return an ``INTERSECT ALL`` of multiple selectables. @@ -769,6 +796,7 @@ def intersect_all(*selects, **kwargs): """ return CompoundSelect(CompoundSelect.INTERSECT_ALL, *selects, **kwargs) + def alias(selectable, name=None): """Return an :class:`.Alias` object. @@ -826,6 +854,7 @@ def literal(value, type_=None): """ return BindParameter(None, value, type_=type_, unique=True) + def tuple_(*expr): """Return a SQL tuple. @@ -846,6 +875,7 @@ def tuple_(*expr): """ return Tuple(*expr) + def type_coerce(expr, type_): """Coerce the given expression into the given type, on the Python side only. @@ -919,6 +949,7 @@ def label(name, obj): """ return Label(name, obj) + def column(text, type_=None): """Return a textual column clause, as would be in the columns clause of a ``SELECT`` statement. @@ -947,6 +978,7 @@ def column(text, type_=None): """ return ColumnClause(text, type_=type_) + def literal_column(text, type_=None): """Return a textual column expression, as would be in the columns clause of a ``SELECT`` statement. @@ -970,15 +1002,18 @@ def literal_column(text, type_=None): """ return ColumnClause(text, type_=type_, is_literal=True) + def table(name, *columns): """Represent a textual table clause. - The object returned is an instance of :class:`.TableClause`, which represents the - "syntactical" portion of the schema-level :class:`~.schema.Table` object. + The object returned is an instance of :class:`.TableClause`, which + represents the "syntactical" portion of the schema-level + :class:`~.schema.Table` object. It may be used to construct lightweight table constructs. Note that the :func:`~.expression.table` function is not part of - the ``sqlalchemy`` namespace. It must be imported from the ``sql`` package:: + the ``sqlalchemy`` namespace. It must be imported from the + ``sql`` package:: from sqlalchemy.sql import table, column @@ -991,6 +1026,7 @@ def table(name, *columns): """ return TableClause(name, *columns) + def bindparam(key, value=NO_ARG, type_=None, unique=False, required=NO_ARG, quote=None, callable_=None): """Create a bind parameter clause with the given key. @@ -1009,8 +1045,8 @@ def bindparam(key, value=NO_ARG, type_=None, unique=False, required=NO_ARG, compilation/execution. Defaults to ``None``, however if neither ``value`` nor - ``callable`` are passed explicitly, the ``required`` flag will be set to - ``True`` which has the effect of requiring a value be present + ``callable`` are passed explicitly, the ``required`` flag will be + set to ``True`` which has the effect of requiring a value be present when the statement is actually executed. .. versionchanged:: 0.8 The ``required`` flag is set to ``True`` @@ -1062,6 +1098,7 @@ def bindparam(key, value=NO_ARG, type_=None, unique=False, required=NO_ARG, unique=unique, required=required, quote=quote) + def outparam(key, type_=None): """Create an 'OUT' parameter for usage in functions (stored procedures), for databases which support them. @@ -1075,6 +1112,7 @@ def outparam(key, type_=None): return BindParameter( key, None, type_=type_, unique=False, isoutparam=True) + def text(text, bind=None, *args, **kwargs): """Create a SQL construct that is represented by a literal string. @@ -1171,6 +1209,7 @@ def text(text, bind=None, *args, **kwargs): """ return TextClause(text, bind=bind, *args, **kwargs) + def over(func, partition_by=None, order_by=None): """Produce an OVER clause against a function. @@ -1201,12 +1240,14 @@ def over(func, partition_by=None, order_by=None): """ return Over(func, partition_by=partition_by, order_by=order_by) + def null(): """Return a :class:`Null` object, which compiles to ``NULL``. """ return Null() + def true(): """Return a :class:`True_` object, which compiles to ``true``, or the boolean equivalent for the target dialect. @@ -1214,6 +1255,7 @@ def true(): """ return True_() + def false(): """Return a :class:`False_` object, which compiles to ``false``, or the boolean equivalent for the target dialect. @@ -1221,6 +1263,7 @@ def false(): """ return False_() + class _FunctionGenerator(object): """Generate :class:`.Function` objects based on getattr calls.""" @@ -1333,6 +1376,7 @@ func = _FunctionGenerator() # TODO: use UnaryExpression for this instead ? modifier = _FunctionGenerator(group=False) + class _truncated_label(unicode): """A unicode subclass used to identify symbolic " "names that may require truncation.""" @@ -1346,6 +1390,7 @@ class _truncated_label(unicode): # compiler _generated_label = _truncated_label + class _anonymous_label(_truncated_label): """A unicode subclass used to identify anonymously generated names.""" @@ -1363,6 +1408,7 @@ class _anonymous_label(_truncated_label): def apply_map(self, map_): return self % map_ + def _as_truncated(value): """coerce the given value to :class:`._truncated_label`. @@ -1376,6 +1422,7 @@ def _as_truncated(value): else: return _truncated_label(value) + def _string_or_unprintable(element): if isinstance(element, basestring): return element @@ -1385,9 +1432,11 @@ def _string_or_unprintable(element): except: return "unprintable element %r" % element + def _clone(element, **kw): return element._clone() + def _expand_cloned(elements): """expand the given set of ClauseElements to be the set of all 'cloned' predecessors. @@ -1395,6 +1444,7 @@ def _expand_cloned(elements): """ return itertools.chain(*[x._cloned_set for x in elements]) + def _select_iterables(elements): """expand tables into individual columns in the given list of column expressions. @@ -1402,6 +1452,7 @@ def _select_iterables(elements): """ return itertools.chain(*[c._select_iterable for c in elements]) + def _cloned_intersection(a, b): """return the intersection of sets a and b, counting any overlap between 'cloned' predecessors. @@ -1413,15 +1464,18 @@ def _cloned_intersection(a, b): return set(elem for elem in a if all_overlap.intersection(elem._cloned_set)) + def _from_objects(*elements): return itertools.chain(*[element._from_objects for element in elements]) + def _labeled(element): if not hasattr(element, 'name'): return element.label(None) else: return element + # there is some inconsistency here between the usage of # inspect() vs. checking for Visitable and __clause_element__. # Ideally all functions here would derive from inspect(), @@ -1432,6 +1486,7 @@ def _labeled(element): # _interpret_as_from() where we'd like to be able to receive ORM entities # that have no defined namespace, hence inspect() is needed there. + def _column_as_key(element): if isinstance(element, basestring): return element @@ -1442,12 +1497,14 @@ def _column_as_key(element): except AttributeError: return None + def _clause_element_as_expr(element): if hasattr(element, '__clause_element__'): return element.__clause_element__() else: return element + def _literal_as_text(element): if isinstance(element, Visitable): return element @@ -1462,6 +1519,7 @@ def _literal_as_text(element): "SQL expression object or string expected." ) + def _no_literals(element): if hasattr(element, '__clause_element__'): return element.__clause_element__() @@ -1473,16 +1531,19 @@ def _no_literals(element): else: return element + def _is_literal(element): return not isinstance(element, Visitable) and \ not hasattr(element, '__clause_element__') + def _only_column_elements_or_none(element, name): if element is None: return None else: return _only_column_elements(element, name) + def _only_column_elements(element, name): if hasattr(element, '__clause_element__'): element = element.__clause_element__() @@ -1504,6 +1565,7 @@ def _literal_as_binds(element, name=None, type_=None): else: return element + def _interpret_as_column_or_from(element): if isinstance(element, Visitable): return element @@ -1519,6 +1581,7 @@ def _interpret_as_column_or_from(element): return literal_column(str(element)) + def _interpret_as_from(element): insp = inspection.inspect(element, raiseerr=False) if insp is None: @@ -1528,6 +1591,7 @@ def _interpret_as_from(element): return insp.selectable raise exc.ArgumentError("FROM expression expected") + def _const_expr(element): if element is None: return null() @@ -1564,6 +1628,7 @@ def _corresponding_column_or_error(fromclause, column, ) return c + @util.decorator def _generative(fn, *args, **kw): """Mark a method as generative.""" @@ -1798,7 +1863,6 @@ class ClauseElement(Visitable): """ return self - def compile(self, bind=None, dialect=None, **kw): """Compile this SQL expression. @@ -2007,15 +2071,14 @@ class _DefaultColumnComparator(operators.ColumnOperators): # as_scalar() to produce a multi- column selectable that # does not export itself as a FROM clause - return self._boolean_compare(expr, op, seq_or_selectable.as_scalar(), - negate=negate_op, **kw) + return self._boolean_compare( + expr, op, seq_or_selectable.as_scalar(), + negate=negate_op, **kw) elif isinstance(seq_or_selectable, (Selectable, TextClause)): return self._boolean_compare(expr, op, seq_or_selectable, negate=negate_op, **kw) - # Handle non selectable arguments as sequences - args = [] for o in seq_or_selectable: if not _is_literal(o): @@ -2120,7 +2183,6 @@ class _DefaultColumnComparator(operators.ColumnOperators): "rshift": (_unsupported_impl,), } - def _check_literal(self, expr, operator, other): if isinstance(other, (ColumnElement, TextClause)): if isinstance(other, BindParameter) and \ @@ -2152,15 +2214,15 @@ class ColumnElement(ClauseElement, ColumnOperators): :class:`.Column` object, :class:`.ColumnElement` serves as the basis for any unit that may be present in a SQL expression, including the expressions themselves, SQL functions, bound parameters, - literal expressions, keywords such as ``NULL``, etc. :class:`.ColumnElement` - is the ultimate base class for all such elements. + literal expressions, keywords such as ``NULL``, etc. + :class:`.ColumnElement` is the ultimate base class for all such elements. A :class:`.ColumnElement` provides the ability to generate new :class:`.ColumnElement` objects using Python expressions. This means that Python operators such as ``==``, ``!=`` and ``<`` are overloaded to mimic SQL operations, - and allow the instantiation of further :class:`.ColumnElement` instances which - are composed from other, more fundamental :class:`.ColumnElement` + and allow the instantiation of further :class:`.ColumnElement` instances + which are composed from other, more fundamental :class:`.ColumnElement` objects. For example, two :class:`.ColumnClause` objects can be added together with the addition operator ``+`` to produce a :class:`.BinaryExpression`. @@ -2181,7 +2243,6 @@ class ColumnElement(ClauseElement, ColumnOperators): discussion of this concept can be found at `Expression Transformations `_. - """ __visit_name__ = 'column' @@ -2338,6 +2399,7 @@ class ColumnElement(ClauseElement, ColumnOperators): return _anonymous_label('%%(%d %s)s' % (id(self), getattr(self, 'name', 'anon'))) + class ColumnCollection(util.OrderedProperties): """An ordered dictionary that stores a list of ColumnElement instances. @@ -2459,6 +2521,7 @@ class ColumnCollection(util.OrderedProperties): def as_immutable(self): return ImmutableColumnCollection(self._data, self._all_cols) + class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection): def __init__(self, data, colset): util.ImmutableProperties.__init__(self, data) @@ -2489,6 +2552,7 @@ class ColumnSet(util.ordered_column_set): def __hash__(self): return hash(tuple(x for x in self)) + class Selectable(ClauseElement): """mark a class as being selectable""" __visit_name__ = 'selectable' @@ -2499,6 +2563,7 @@ class Selectable(ClauseElement): def selectable(self): return self + class FromClause(Selectable): """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement. @@ -2790,6 +2855,7 @@ class FromClause(Selectable): else: return None + class BindParameter(ColumnElement): """Represent a bind parameter. @@ -2938,6 +3004,7 @@ class BindParameter(ColumnElement): return 'BindParameter(%r, %r, type_=%r)' % (self.key, self.value, self.type) + class TypeClause(ClauseElement): """Handle a type keyword in a SQL statement. @@ -2983,8 +3050,9 @@ class Executable(Generative): Execution options can be set on a per-statement or per :class:`.Connection` basis. Additionally, the - :class:`.Engine` and ORM :class:`~.orm.query.Query` objects provide access - to execution options which they in turn configure upon connections. + :class:`.Engine` and ORM :class:`~.orm.query.Query` objects provide + access to execution options which they in turn configure upon + connections. The :meth:`execution_options` method is generative. A new instance of this statement is returned that contains the options:: @@ -3064,6 +3132,7 @@ class Executable(Generative): # legacy, some outside users may be calling this _Executable = Executable + class TextClause(Executable, ClauseElement): """Represent a literal SQL text fragment. @@ -3162,6 +3231,7 @@ class Null(ColumnElement): def compare(self, other): return isinstance(other, Null) + class False_(ColumnElement): """Represent the ``false`` keyword in a SQL statement. @@ -3174,6 +3244,7 @@ class False_(ColumnElement): def __init__(self): self.type = sqltypes.BOOLEANTYPE + class True_(ColumnElement): """Represent the ``true`` keyword in a SQL statement. @@ -3262,6 +3333,7 @@ class ClauseList(ClauseElement): else: return False + class BooleanClauseList(ClauseList, ColumnElement): __visit_name__ = 'clauselist' @@ -3280,6 +3352,7 @@ class BooleanClauseList(ClauseList, ColumnElement): else: return super(BooleanClauseList, self).self_group(against=against) + class Tuple(ClauseList, ColumnElement): def __init__(self, *clauses, **kw): @@ -3360,6 +3433,7 @@ class Case(ColumnElement): return list(itertools.chain(*[x._from_objects for x in self.get_children()])) + class FunctionElement(Executable, ColumnElement, FromClause): """Base for SQL function-oriented constructs. @@ -3717,6 +3791,7 @@ class BinaryExpression(ColumnElement): else: return super(BinaryExpression, self)._negate() + class Exists(UnaryExpression): __visit_name__ = UnaryExpression.__visit_name__ _from_objects = [] @@ -3746,9 +3821,9 @@ class Exists(UnaryExpression): return e def select_from(self, clause): - """return a new :class:`.Exists` construct, applying the given expression - to the :meth:`.Select.select_from` method of the select statement - contained. + """return a new :class:`.Exists` construct, applying the given + expression to the :meth:`.Select.select_from` method of the select + statement contained. """ e = self._clone() @@ -3764,6 +3839,7 @@ class Exists(UnaryExpression): e.element = self.element.where(clause).self_group() return e + class Join(FromClause): """represent a ``JOIN`` construct between two :class:`.FromClause` elements. @@ -3916,6 +3992,7 @@ class Join(FromClause): self.left._from_objects + \ self.right._from_objects + class Alias(FromClause): """Represents an table or selectable alias (AS). @@ -4009,6 +4086,7 @@ class Alias(FromClause): def bind(self): return self.element.bind + class CTE(Alias): """Represent a Common Table Expression. @@ -4093,6 +4171,7 @@ class Grouping(ColumnElement): return isinstance(other, Grouping) and \ self.element.compare(other.element) + class FromGrouping(FromClause): """Represent a grouping of a FROM clause""" __visit_name__ = 'grouping' @@ -4141,6 +4220,7 @@ class FromGrouping(FromClause): def __setstate__(self, state): self.element = state['element'] + class Over(ColumnElement): """Represent an OVER clause. @@ -4187,6 +4267,7 @@ class Over(ColumnElement): if c is not None] )) + class Label(ColumnElement): """Represents a column label (AS). @@ -4260,6 +4341,7 @@ class Label(ColumnElement): e.type = self._type return e + class ColumnClause(Immutable, ColumnElement): """Represents a generic column expression from any textual string. @@ -4403,7 +4485,6 @@ class ColumnClause(Immutable, ColumnElement): else: return name - def _bind_param(self, operator, obj): return BindParameter(self.name, obj, _compared_to_operator=operator, @@ -4431,6 +4512,7 @@ class ColumnClause(Immutable, ColumnElement): selectable._columns[c.key] = c return c + class TableClause(Immutable, FromClause): """Represents a minimal "table" construct. @@ -4564,6 +4646,7 @@ class TableClause(Immutable, FromClause): def _from_objects(self): return [self] + class SelectBase(Executable, FromClause): """Base class for :class:`.Select` and ``CompoundSelects``.""" @@ -4871,6 +4954,7 @@ class ScalarSelect(Generative, Grouping): def self_group(self, **kwargs): return self + class CompoundSelect(SelectBase): """Forms the basis of ``UNION``, ``UNION ALL``, and other SELECT-based set operations.""" @@ -4984,6 +5068,7 @@ class CompoundSelect(SelectBase): self._bind = bind bind = property(bind, _set_bind) + class HasPrefixes(object): _prefixes = () @@ -5020,6 +5105,7 @@ class HasPrefixes(object): self._prefixes = self._prefixes + tuple( [(_literal_as_text(p), dialect) for p in prefixes]) + class Select(HasPrefixes, SelectBase): """Represents a ``SELECT`` statement. @@ -5332,8 +5418,8 @@ class Select(HasPrefixes, SelectBase): other either based on foreign key, or via a simple equality comparison in the WHERE clause of the statement. The primary purpose of this method is to automatically construct a select statement - with all uniquely-named columns, without the need to use table-qualified - labels as :meth:`.apply_labels` does. + with all uniquely-named columns, without the need to use + table-qualified labels as :meth:`.apply_labels` does. When columns are omitted based on foreign key, the referred-to column is the one that's kept. When columns are omitted based on @@ -5488,7 +5574,6 @@ class Select(HasPrefixes, SelectBase): else: self._distinct = True - @_generative def select_from(self, fromclause): """return a new :func:`.select` construct with the @@ -5733,6 +5818,7 @@ class Select(HasPrefixes, SelectBase): self._bind = bind bind = property(bind, _set_bind) + class UpdateBase(HasPrefixes, Executable, ClauseElement): """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements. @@ -5779,7 +5865,6 @@ class UpdateBase(HasPrefixes, Executable, ClauseElement): self._bind = bind bind = property(bind, _set_bind) - _returning_re = re.compile(r'(?:firebird|postgres(?:ql)?)_returning') def _process_deprecated_kw(self, kwargs): @@ -5866,6 +5951,7 @@ class UpdateBase(HasPrefixes, Executable, ClauseElement): self._hints = self._hints.union( {(selectable, dialect_name): text}) + class ValuesBase(UpdateBase): """Supplies support for :meth:`.ValuesBase.values` to INSERT and UPDATE constructs.""" @@ -5923,6 +6009,7 @@ class ValuesBase(UpdateBase): self.parameters.update(self._process_colparams(v)) self.parameters.update(kwargs) + class Insert(ValuesBase): """Represent an INSERT construct. @@ -5936,7 +6023,6 @@ class Insert(ValuesBase): """ __visit_name__ = 'insert' - def __init__(self, table, values=None, @@ -6032,6 +6118,7 @@ class Update(ValuesBase): return froms + class Delete(UpdateBase): """Represent a DELETE construct. @@ -6083,6 +6170,7 @@ class Delete(UpdateBase): # TODO: coverage self._whereclause = clone(self._whereclause, **kw) + class _IdentifiedClause(Executable, ClauseElement): __visit_name__ = 'identified' @@ -6093,12 +6181,15 @@ class _IdentifiedClause(Executable, ClauseElement): def __init__(self, ident): self.ident = ident + class SavepointClause(_IdentifiedClause): __visit_name__ = 'savepoint' + class RollbackToSavepointClause(_IdentifiedClause): __visit_name__ = 'rollback_to_savepoint' + class ReleaseSavepointClause(_IdentifiedClause): __visit_name__ = 'release_savepoint' @@ -6123,4 +6214,3 @@ _Exists = Exists _Grouping = Grouping _FromGrouping = FromGrouping _ScalarSelect = ScalarSelect - diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index d26589bd9d..fd6607be04 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -15,6 +15,7 @@ from .. import util _registry = util.defaultdict(dict) + def register_function(identifier, fn, package="_default"): """Associate a callable with a particular func. name. @@ -39,6 +40,7 @@ class _GenericMeta(VisitableType): register_function(identifier, cls, package) super(_GenericMeta, cls).__init__(clsname, bases, clsdict) + class GenericFunction(Function): """Define a 'generic' function. @@ -113,6 +115,7 @@ class GenericFunction(Function): __metaclass__ = _GenericMeta coerce_arguments = True + def __init__(self, *args, **kwargs): parsed_args = kwargs.pop('_parsed_args', None) if parsed_args is None: @@ -129,6 +132,7 @@ class GenericFunction(Function): register_function("cast", cast) register_function("extract", extract) + class next_value(GenericFunction): """Represent the 'next value', given a :class:`.Sequence` as it's single argument. @@ -151,10 +155,12 @@ class next_value(GenericFunction): def _from_objects(self): return [] + class AnsiFunction(GenericFunction): def __init__(self, **kwargs): GenericFunction.__init__(self, **kwargs) + class ReturnTypeFromArgs(GenericFunction): """Define a function whose return type is the same as its arguments.""" @@ -164,15 +170,19 @@ class ReturnTypeFromArgs(GenericFunction): kwargs['_parsed_args'] = args GenericFunction.__init__(self, *args, **kwargs) + class coalesce(ReturnTypeFromArgs): pass + class max(ReturnTypeFromArgs): pass + class min(ReturnTypeFromArgs): pass + class sum(ReturnTypeFromArgs): pass @@ -180,21 +190,27 @@ class sum(ReturnTypeFromArgs): class now(GenericFunction): type = sqltypes.DateTime + class concat(GenericFunction): type = sqltypes.String + class char_length(GenericFunction): type = sqltypes.Integer def __init__(self, arg, **kwargs): GenericFunction.__init__(self, arg, **kwargs) + class random(GenericFunction): pass + class count(GenericFunction): - """The ANSI COUNT aggregate function. With no arguments, emits COUNT \*.""" + """The ANSI COUNT aggregate function. With no arguments, + emits COUNT \*. + """ type = sqltypes.Integer def __init__(self, expression=None, **kwargs): @@ -202,30 +218,38 @@ class count(GenericFunction): expression = literal_column('*') GenericFunction.__init__(self, expression, **kwargs) + class current_date(AnsiFunction): type = sqltypes.Date + class current_time(AnsiFunction): type = sqltypes.Time + class current_timestamp(AnsiFunction): type = sqltypes.DateTime + class current_user(AnsiFunction): type = sqltypes.String + class localtime(AnsiFunction): type = sqltypes.DateTime + class localtimestamp(AnsiFunction): type = sqltypes.DateTime + class session_user(AnsiFunction): type = sqltypes.String + class sysdate(AnsiFunction): type = sqltypes.DateTime + class user(AnsiFunction): type = sqltypes.String - diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 7513c0b82c..0f90f50ab1 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -395,8 +395,8 @@ class ColumnOperators(Operators): def notlike(self, other, escape=None): """implement the ``NOT LIKE`` operator. - This is equivalent to using negation with :meth:`.ColumnOperators.like`, - i.e. ``~x.like(y)``. + This is equivalent to using negation with + :meth:`.ColumnOperators.like`, i.e. ``~x.like(y)``. .. versionadded:: 0.8 @@ -410,8 +410,8 @@ class ColumnOperators(Operators): def notilike(self, other, escape=None): """implement the ``NOT ILIKE`` operator. - This is equivalent to using negation with :meth:`.ColumnOperators.ilike`, - i.e. ``~x.ilike(y)``. + This is equivalent to using negation with + :meth:`.ColumnOperators.ilike`, i.e. ``~x.ilike(y)``. .. versionadded:: 0.8 @@ -549,7 +549,10 @@ class ColumnOperators(Operators): return self.operate(between_op, cleft, cright) def distinct(self): - """Produce a :func:`~.expression.distinct` clause against the parent object.""" + """Produce a :func:`~.expression.distinct` clause against the + parent object. + + """ return self.operate(distinct_op) def __add__(self, other): @@ -612,100 +615,132 @@ class ColumnOperators(Operators): """ return self.reverse_operate(truediv, other) + def from_(): raise NotImplementedError() + def as_(): raise NotImplementedError() + def exists(): raise NotImplementedError() + def is_(a, b): return a.is_(b) + def isnot(a, b): return a.isnot(b) + def collate(a, b): return a.collate(b) + def op(a, opstring, b): return a.op(opstring)(b) + def like_op(a, b, escape=None): return a.like(b, escape=escape) + def notlike_op(a, b, escape=None): return a.notlike(b, escape=escape) + def ilike_op(a, b, escape=None): return a.ilike(b, escape=escape) + def notilike_op(a, b, escape=None): return a.notilike(b, escape=escape) + def between_op(a, b, c): return a.between(b, c) + def in_op(a, b): return a.in_(b) + def notin_op(a, b): return a.notin_(b) + def distinct_op(a): return a.distinct() + def startswith_op(a, b, escape=None): return a.startswith(b, escape=escape) + def notstartswith_op(a, b, escape=None): return ~a.startswith(b, escape=escape) + def endswith_op(a, b, escape=None): return a.endswith(b, escape=escape) + def notendswith_op(a, b, escape=None): return ~a.endswith(b, escape=escape) + def contains_op(a, b, escape=None): return a.contains(b, escape=escape) + def notcontains_op(a, b, escape=None): return ~a.contains(b, escape=escape) + def match_op(a, b): return a.match(b) + def comma_op(a, b): raise NotImplementedError() + def concat_op(a, b): return a.concat(b) + def desc_op(a): return a.desc() + def asc_op(a): return a.asc() + def nullsfirst_op(a): return a.nullsfirst() + def nullslast_op(a): return a.nullslast() + _commutative = set([eq, ne, add, mul]) _comparison = set([eq, ne, lt, gt, ge, le]) + def is_comparison(op): return op in _comparison + def is_commutative(op): return op in _commutative + def is_ordering_modifier(op): return op in (asc_op, desc_op, nullsfirst_op, nullslast_op) diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 2c07690122..29504cd711 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -12,12 +12,14 @@ from collections import deque """Utility functions that build upon SQL and Schema constructs.""" + def sort_tables(tables, skip_fn=None): """sort a collection of Table objects in order of their foreign-key dependency.""" tables = list(tables) tuples = [] + def visit_foreign_key(fkey): if fkey.use_alter: return @@ -40,6 +42,7 @@ def sort_tables(tables, skip_fn=None): return list(topological.sort(tuples, tables)) + def find_join_source(clauses, join_to): """Given a list of FROM clauses and a selectable, return the first index and element from the list of @@ -102,6 +105,7 @@ def visit_binary_product(fn, expr): """ stack = [] + def visit(element): if isinstance(element, (expression.ScalarSelect)): # we dont want to dig into correlated subqueries, @@ -124,6 +128,7 @@ def visit_binary_product(fn, expr): yield e list(visit(expr)) + def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, include_selects=False, include_crud=False): @@ -139,7 +144,7 @@ def find_tables(clause, check_columns=False, _visitors['join'] = tables.append if include_aliases: - _visitors['alias'] = tables.append + _visitors['alias'] = tables.append if include_crud: _visitors['insert'] = _visitors['update'] = \ @@ -152,16 +157,18 @@ def find_tables(clause, check_columns=False, _visitors['table'] = tables.append - visitors.traverse(clause, {'column_collections':False}, _visitors) + visitors.traverse(clause, {'column_collections': False}, _visitors) return tables + def find_columns(clause): """locate Column objects within the given expression.""" cols = util.column_set() - visitors.traverse(clause, {}, {'column':cols.add}) + visitors.traverse(clause, {}, {'column': cols.add}) return cols + def unwrap_order_by(clause): """Break up an 'order by' expression into individual column-expressions, without DESC/ASC/NULLS FIRST/NULLS LAST""" @@ -181,6 +188,7 @@ def unwrap_order_by(clause): stack.append(c) return cols + def clause_is_present(clause, search): """Given a target clause and a second to search within, return True if the target is plainly present in the search without any @@ -213,12 +221,14 @@ def bind_values(clause): """ v = [] + def visit_bindparam(bind): v.append(bind.effective_value) - visitors.traverse(clause, {}, {'bindparam':visit_bindparam}) + visitors.traverse(clause, {}, {'bindparam': visit_bindparam}) return v + def _quote_ddl_expr(element): if isinstance(element, basestring): element = element.replace("'", "''") @@ -226,6 +236,7 @@ def _quote_ddl_expr(element): else: return repr(element) + class _repr_params(object): """A string view of bound parameters, truncating display to the given number of 'multi' parameter sets. @@ -239,9 +250,10 @@ class _repr_params(object): if isinstance(self.params, (list, tuple)) and \ len(self.params) > self.batches and \ isinstance(self.params[0], (list, dict, tuple)): + msg = " ... displaying %i of %i total bound parameter sets ... " return ' '.join(( repr(self.params[:self.batches - 2])[0:-1], - " ... displaying %i of %i total bound parameter sets ... " % (self.batches, len(self.params)), + msg % (self.batches, len(self.params)), repr(self.params[-2:])[1:] )) else: @@ -268,8 +280,12 @@ def expression_as_ddl(clause): return visitors.replacement_traverse(clause, {}, repl) + def adapt_criterion_to_null(crit, nulls): - """given criterion containing bind params, convert selected elements to IS NULL.""" + """given criterion containing bind params, convert selected elements + to IS NULL. + + """ def visit_binary(binary): if isinstance(binary.left, expression.BindParameter) \ @@ -285,7 +301,7 @@ def adapt_criterion_to_null(crit, nulls): binary.operator = operators.is_ binary.negate = operators.isnot - return visitors.cloned_traverse(crit, {}, {'binary':visit_binary}) + return visitors.cloned_traverse(crit, {}, {'binary': visit_binary}) def join_condition(a, b, ignore_nonexistent_tables=False, @@ -325,7 +341,7 @@ def join_condition(a, b, ignore_nonexistent_tables=False, continue for fk in sorted( b.foreign_keys, - key=lambda fk:fk.parent._creation_order): + key=lambda fk: fk.parent._creation_order): if consider_as_foreign_keys is not None and \ fk.parent not in consider_as_foreign_keys: continue @@ -343,7 +359,7 @@ def join_condition(a, b, ignore_nonexistent_tables=False, if left is not b: for fk in sorted( left.foreign_keys, - key=lambda fk:fk.parent._creation_order): + key=lambda fk: fk.parent._creation_order): if consider_as_foreign_keys is not None and \ fk.parent not in consider_as_foreign_keys: continue @@ -473,6 +489,7 @@ class Annotated(object): else: return hash(other) == hash(self) + class AnnotatedColumnElement(Annotated): def __init__(self, element, values): Annotated.__init__(self, element, values) @@ -506,6 +523,7 @@ for cls in expression.__dict__.values() + [schema.Column, schema.Table]: " pass" % (cls.__name__, annotation_cls) in locals() exec "annotated_classes[cls] = Annotated%s" % (cls.__name__,) + def _deep_annotate(element, annotations, exclude=None): """Deep copy the given ClauseElement, annotating each element with the given annotations dictionary. @@ -529,6 +547,7 @@ def _deep_annotate(element, annotations, exclude=None): element = clone(element) return element + def _deep_deannotate(element, values=None): """Deep copy the given element, removing annotations.""" @@ -554,6 +573,7 @@ def _deep_deannotate(element, values=None): element = clone(element) return element + def _shallow_annotate(element, annotations): """Annotate the given ClauseElement and copy its internals so that internal objects refer to the new annotated object. @@ -566,6 +586,7 @@ def _shallow_annotate(element, annotations): element._copy_internals() return element + def splice_joins(left, right, stop_on=None): if left is None: return right @@ -590,12 +611,15 @@ def splice_joins(left, right, stop_on=None): return ret + def reduce_columns(columns, *clauses, **kw): - """given a list of columns, return a 'reduced' set based on natural equivalents. + """given a list of columns, return a 'reduced' set based on natural + equivalents. the set is reduced to the smallest list of columns which have no natural - equivalent present in the list. A "natural equivalent" means that two columns - will ultimately represent the same value because they are related by a foreign key. + equivalent present in the list. A "natural equivalent" means that two + columns will ultimately represent the same value because they are related + by a foreign key. \*clauses is an optional list of join clauses which will be traversed to further identify columns that are "equivalent". @@ -659,6 +683,7 @@ def reduce_columns(columns, *clauses, **kw): return expression.ColumnSet(columns.difference(omit)) + def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_referenced_keys=None, any_operator=False): """traverse an expression and locate binary criterion pairs.""" @@ -705,7 +730,7 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, elif binary.right.references(binary.left): pairs.append((binary.left, binary.right)) pairs = [] - visitors.traverse(expression, {}, {'binary':visit_binary}) + visitors.traverse(expression, {}, {'binary': visit_binary}) return pairs @@ -768,7 +793,7 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): include=None, exclude=None, include_fn=None, exclude_fn=None, adapt_on_names=False): - self.__traverse_options__ = {'stop_on':[selectable]} + self.__traverse_options__ = {'stop_on': [selectable]} self.selectable = selectable if include: assert not include_fn @@ -783,7 +808,8 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): self.equivalents = util.column_dict(equivalents or {}) self.adapt_on_names = adapt_on_names - def _corresponding_column(self, col, require_embedded, _seen=util.EMPTY_SET): + def _corresponding_column(self, col, require_embedded, + _seen=util.EMPTY_SET): newcol = self.selectable.corresponding_column( col, require_embedded=require_embedded) @@ -811,6 +837,7 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): else: return self._corresponding_column(col, True) + class ColumnAdapter(ClauseAdapter): """Extends ClauseAdapter with extra utility functions. diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 6f2c829921..09c50a9348 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -24,7 +24,6 @@ http://techspot.zzzeek.org/2008/01/23/expression-transformations/ """ from collections import deque -import re from .. import util import operator @@ -33,6 +32,7 @@ __all__ = ['VisitableType', 'Visitable', 'ClauseVisitor', 'iterate_depthfirst', 'traverse_using', 'traverse', 'cloned_traverse', 'replacement_traverse'] + class VisitableType(type): """Metaclass which assigns a `_compiler_dispatch` method to classes having a `__visit_name__` attribute. @@ -43,7 +43,8 @@ class VisitableType(type): def _compiler_dispatch (self, visitor, **kw): '''Look for an attribute named "visit_" + self.__visit_name__ on the visitor, and call it with the same kw params.''' - return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw) + visit_attr = 'visit_%s' % self.__visit_name__ + return getattr(visitor, visit_attr)(self, **kw) Classes having no __visit_name__ attribute will remain unaffected. """ @@ -68,6 +69,7 @@ def _generate_dispatch(cls): # the string name of the class's __visit_name__ is known at # this early stage (import time) so it can be pre-constructed. getter = operator.attrgetter("visit_%s" % visit_name) + def _compiler_dispatch(self, visitor, **kw): return getter(visitor)(self, **kw) else: @@ -75,14 +77,16 @@ def _generate_dispatch(cls): # __visit_name__ is not yet a string. As a result, the visit # string has to be recalculated with each compilation. def _compiler_dispatch(self, visitor, **kw): - return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw) + visit_attr = 'visit_%s' % self.__visit_name__ + return getattr(visitor, visit_attr)(self, **kw) - _compiler_dispatch.__doc__ = \ + _compiler_dispatch.__doc__ = \ """Look for an attribute named "visit_" + self.__visit_name__ on the visitor, and call it with the same kw params. """ cls._compiler_dispatch = _compiler_dispatch + class Visitable(object): """Base class for visitable objects, applies the ``VisitableType`` metaclass. @@ -91,6 +95,7 @@ class Visitable(object): __metaclass__ = VisitableType + class ClauseVisitor(object): """Base class for visitor objects which can traverse using the traverse() function. @@ -106,8 +111,10 @@ class ClauseVisitor(object): return meth(obj, **kw) def iterate(self, obj): - """traverse the given expression structure, returning an iterator of all elements.""" + """traverse the given expression structure, returning an iterator + of all elements. + """ return iterate(obj, self.__traverse_options__) def traverse(self, obj): @@ -143,6 +150,7 @@ class ClauseVisitor(object): tail._next = visitor return self + class CloningVisitor(ClauseVisitor): """Base class for visitor objects which can traverse using the cloned_traverse() function. @@ -150,14 +158,18 @@ class CloningVisitor(ClauseVisitor): """ def copy_and_process(self, list_): - """Apply cloned traversal to the given list of elements, and return the new list.""" + """Apply cloned traversal to the given list of elements, and return + the new list. + """ return [self.traverse(x) for x in list_] def traverse(self, obj): """traverse and visit the given expression structure.""" - return cloned_traverse(obj, self.__traverse_options__, self._visitor_dict) + return cloned_traverse( + obj, self.__traverse_options__, self._visitor_dict) + class ReplacingCloningVisitor(CloningVisitor): """Base class for visitor objects which can traverse using @@ -184,6 +196,7 @@ class ReplacingCloningVisitor(CloningVisitor): return e return replacement_traverse(obj, self.__traverse_options__, replace) + def iterate(obj, opts): """traverse the given expression structure, returning an iterator. @@ -197,6 +210,7 @@ def iterate(obj, opts): for c in t.get_children(**opts): stack.append(c) + def iterate_depthfirst(obj, opts): """traverse the given expression structure, returning an iterator. @@ -212,25 +226,35 @@ def iterate_depthfirst(obj, opts): stack.append(c) return iter(traversal) + def traverse_using(iterator, obj, visitors): - """visit the given expression structure using the given iterator of objects.""" + """visit the given expression structure using the given iterator of + objects. + """ for target in iterator: meth = visitors.get(target.__visit_name__, None) if meth: meth(target) return obj + def traverse(obj, opts, visitors): - """traverse and visit the given expression structure using the default iterator.""" + """traverse and visit the given expression structure using the default + iterator. + """ return traverse_using(iterate(obj, opts), obj, visitors) + def traverse_depthfirst(obj, opts, visitors): - """traverse and visit the given expression structure using the depth-first iterator.""" + """traverse and visit the given expression structure using the + depth-first iterator. + """ return traverse_using(iterate_depthfirst(obj, opts), obj, visitors) + def cloned_traverse(obj, opts, visitors): """clone the given expression structure, allowing modifications by visitors.""" -- 2.47.2