from operator import attrgetter
from sqlalchemy import util, exc
-from sqlalchemy.sql import operators, visitors
+from sqlalchemy.sql import operators
+from sqlalchemy.sql.visitors import Visitable, cloned_traverse
from sqlalchemy import types as sqltypes
functions, schema, sql_util = None, None, None
return CompoundSelect(keyword, *selects, **kwargs)
def _is_literal(element):
- return not isinstance(element, ClauseElement)
+ return not isinstance(element, Visitable) and not hasattr(element, '__clause_element__')
def _from_objects(*elements, **kwargs):
return itertools.chain(*[element._get_from_objects(**kwargs) for element in elements])
def _literal_as_text(element):
if hasattr(element, '__clause_element__'):
return element.__clause_element__()
- elif not isinstance(element, ClauseElement):
+ elif not isinstance(element, Visitable):
return _TextClause(unicode(element))
else:
return element
def _literal_as_column(element):
if hasattr(element, '__clause_element__'):
return element.__clause_element__()
- elif not isinstance(element, ClauseElement):
+ elif not isinstance(element, Visitable):
return literal_column(str(element))
else:
return element
def _literal_as_binds(element, name=None, type_=None):
if hasattr(element, '__clause_element__'):
return element.__clause_element__()
- elif not isinstance(element, ClauseElement):
+ elif not isinstance(element, Visitable):
if element is None:
return null()
else:
def _no_literals(element):
if hasattr(element, '__clause_element__'):
return element.__clause_element__()
- elif not isinstance(element, ClauseElement):
- raise exc.ArgumentError("Ambiguous literal: %r. Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element)
+ elif not isinstance(element, Visitable):
+ raise exc.ArgumentError("Ambiguous literal: %r. Use the 'text()' function "
+ "to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element)
else:
return element
def _corresponding_column_or_error(fromclause, column, require_embedded=False):
c = fromclause.corresponding_column(column, require_embedded=require_embedded)
if not c:
- raise exc.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description))
+ raise exc.InvalidRequestError("Given column '%s', attached to table '%s', "
+ "failed to locate a corresponding column from table '%s'"
+ % (column, getattr(column, 'table', None), fromclause.description))
return c
def _selectable(element):
elif isinstance(element, Selectable):
return element
else:
- raise exc.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element))
+ raise exc.ArgumentError("Object %r is not a Selectable and does not implement `__selectable__()`" % element)
def is_column(col):
"""True if ``col`` is an instance of ``ColumnElement``."""
return isinstance(col, ColumnElement)
-class _FigureVisitName(type):
- def __init__(cls, clsname, bases, dict):
- if not '__visit_name__' in cls.__dict__:
- m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname)
- x = m.group(1)
- x = re.sub(r'(?!^)[A-Z]', lambda m:'_'+m.group(0).lower(), x)
- cls.__visit_name__ = x.lower()
-
- # set up an optimized visit dispatch function
- # for use by the compiler
- visit_name = cls.__dict__["__visit_name__"]
- if isinstance(visit_name, str):
- func_text = "def _compiler_dispatch(self, visitor, **kw):\n"\
- " return visitor.visit_%s(self, **kw)" % visit_name
- else:
- func_text = "def _compiler_dispatch(self, visitor, **kw):\n"\
- " return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)"
- env = locals().copy()
- exec func_text in env
- cls._compiler_dispatch = env['_compiler_dispatch']
-
- super(_FigureVisitName, cls).__init__(clsname, bases, dict)
-
-class ClauseElement(object):
+class ClauseElement(Visitable):
"""Base class for elements of a programmatically constructed SQL expression."""
- __metaclass__ = _FigureVisitName
_annotations = {}
supports_execution = False
This method may be used by a generative API. Its also used as
part of the "deep" copy afforded by a traversal that combines
the _copy_internals() method.
+
"""
c = self.__class__.__new__(self.__class__)
c.__dict__ = self.__dict__.copy()
should be added to the ``FROM`` list of a query, when this
``ClauseElement`` is placed in the column clause of a
``Select`` statement.
+
"""
-
raise NotImplementedError(repr(self))
def _annotate(self, values):
bind.value = kwargs[bind.key]
if unique:
bind._convert_to_unique()
- return visitors.cloned_traverse(self, {}, {'bindparam':visit_bindparam})
+ return cloned_traverse(self, {}, {'bindparam':visit_bindparam})
def compare(self, other):
"""Compare this ClauseElement to the given ClauseElement.
rules applied regardless of case sensitive settings. the
``literal_column()`` function is usually used to create such a
``_ColumnClause``.
+
"""
-
def __init__(self, text, selectable=None, type_=None, is_literal=False):
ColumnElement.__init__(self)
self.key = self.name = text
self.table = selectable
self.type = sqltypes.to_instance(type_)
- self.__label = None
self.is_literal = is_literal
@util.memoized_property
def _label(self):
if self.is_literal:
return None
- if not self.__label:
- if self.table and self.table.named_with_column:
- if getattr(self.table, 'schema', None):
- self.__label = self.table.schema + "_" + self.table.name + "_" + self.name
- else:
- self.__label = self.table.name + "_" + self.name
-
- if self.__label in self.table.c:
- label = self.__label
- counter = 1
- while label in self.table.c:
- label = self.__label + "_" + str(counter)
- counter += 1
- self.__label = label
+
+ elif self.table and self.table.named_with_column:
+ if getattr(self.table, 'schema', None):
+ label = self.table.schema + "_" + self.table.name + "_" + self.name
else:
- self.__label = self.name
- return self.__label
+ label = self.table.name + "_" + self.name
+
+ if label in self.table.c:
+ # TODO: coverage does not seem to be present for this
+ _label = label
+ counter = 1
+ while _label in self.table.c:
+ _label = label + "_" + str(counter)
+ counter += 1
+ label = _label
+ return label
+
+ else:
+ return self.name
def label(self, name):
if name is None:
def _export_columns(self):
raise NotImplementedError()
- @property
+ @util.memoized_property
def description(self):
return self.name.encode('ascii', 'backslashreplace')
def _get_from_objects(self, **modifiers):
return [self]
+@util.decorator
+def _generative(fn, *args, **kw):
+ """Mark a method as generative."""
+
+ self = args[0]._generate()
+ fn(self, *args[1:], **kw)
+ return self
+
class _SelectBaseMixin(object):
"""Base class for ``Select`` and ``CompoundSelects``."""
"""
return _ScalarSelect(self)
+ @_generative
def apply_labels(self):
"""return a new selectable with the 'use_labels' flag set to True.
among the individual FROM clauses.
"""
- s = self._generate()
- s.use_labels = True
- return s
+ self.use_labels = True
def label(self, name):
"""return a 'scalar' representation of this selectable, embedded as a subquery
"""
return self.as_scalar().label(name)
+ @_generative
def autocommit(self):
"""return a new selectable with the 'autocommit' flag set to True."""
- s = self._generate()
- s._autocommit = True
- return s
+ self._autocommit = True
def _generate(self):
s = self.__class__.__new__(self.__class__)
s._reset_exported()
return s
+ @_generative
def limit(self, limit):
"""return a new selectable with the given LIMIT criterion applied."""
- s = self._generate()
- s._limit = limit
- return s
+ self._limit = limit
+ @_generative
def offset(self, offset):
"""return a new selectable with the given OFFSET criterion applied."""
- s = self._generate()
- s._offset = offset
- return s
+ self._offset = offset
+ @_generative
def order_by(self, *clauses):
"""return a new selectable with the given list of ORDER BY criterion applied.
The criterion will be appended to any pre-existing ORDER BY criterion.
"""
- s = self._generate()
- s.append_order_by(*clauses)
- return s
+ self.append_order_by(*clauses)
+ @_generative
def group_by(self, *clauses):
"""return a new selectable with the given list of GROUP BY criterion applied.
The criterion will be appended to any pre-existing GROUP BY criterion.
"""
- s = self._generate()
- s.append_group_by(*clauses)
- return s
+ self.append_group_by(*clauses)
def append_order_by(self, *clauses):
"""Append the given ORDER BY criterion applied to this selectable.
self._raw_columns + list(self._froms) + \
[x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None]
+ @_generative
def column(self, column):
"""return a new select() construct with the given column expression added to its columns clause."""
- s = self._generate()
column = _literal_as_column(column)
if isinstance(column, _ScalarSelect):
column = column.self_group(against=operators.comma_op)
- s._raw_columns = s._raw_columns + [column]
- s._froms = s._froms.union(_from_objects(column))
- return s
+ self._raw_columns = self._raw_columns + [column]
+ self._froms = self._froms.union(_from_objects(column))
+ @_generative
def with_only_columns(self, columns):
"""return a new select() construct with its columns clause replaced with the given columns."""
- s = self._generate()
- s._raw_columns = [
+
+ self._raw_columns = [
isinstance(c, _ScalarSelect) and c.self_group(against=operators.comma_op) or c
for c in
[_literal_as_column(c) for c in columns]
]
- return s
+ @_generative
def where(self, whereclause):
"""return a new select() construct with the given expression added to its WHERE clause, joined
to the existing clause via AND, if any."""
- s = self._generate()
- s.append_whereclause(whereclause)
- return s
+ self.append_whereclause(whereclause)
+ @_generative
def having(self, having):
"""return a new select() construct with the given expression added to its HAVING clause, joined
to the existing clause via AND, if any."""
- s = self._generate()
- s.append_having(having)
- return s
+ self.append_having(having)
+ @_generative
def distinct(self):
"""return a new select() construct which will apply DISTINCT to its columns clause."""
- s = self._generate()
- s._distinct = True
- return s
+ self._distinct = True
+ @_generative
def prefix_with(self, clause):
"""return a new select() construct which will apply the given expression to the start of its
columns clause, not using any commas."""
- s = self._generate()
clause = _literal_as_text(clause)
- s._prefixes = s._prefixes + [clause]
- return s
+ self._prefixes = self._prefixes + [clause]
+ @_generative
def select_from(self, fromclause):
"""return a new select() construct with the given FROM expression applied to its list of
FROM objects."""
- s = self._generate()
if _is_literal(fromclause):
fromclause = _TextClause(fromclause)
- s._froms = s._froms.union([fromclause])
- return s
+ self._froms = self._froms.union([fromclause])
+ @_generative
def correlate(self, *fromclauses):
"""return a new select() construct which will correlate the given FROM clauses to that
of an enclosing select(), if a match is found.
If the fromclause is None, correlation is disabled for the returned select().
"""
- s = self._generate()
- s._should_correlate = False
+ self._should_correlate = False
if fromclauses == (None,):
- s._correlate = set()
+ self._correlate = set()
else:
- s._correlate = s._correlate.union(fromclauses)
- return s
+ self._correlate = self._correlate.union(fromclauses)
def append_correlation(self, fromclause):
"""append the given correlation expression to this select() construct."""
def _copy_internals(self, clone=_clone):
self.parameters = self.parameters.copy()
+ @_generative
def prefix_with(self, clause):
"""Add a word or expression between INSERT and INTO. Generative.
If multiple prefixes are supplied, they will be separated with
spaces.
"""
- gen = self._generate()
clause = _literal_as_text(clause)
- gen._prefixes = self._prefixes + [clause]
- return gen
+ self._prefixes = self._prefixes + [clause]
class Update(_ValuesBase):
def __init__(self, table, whereclause, values=None, inline=False, bind=None, **kwargs):
self._whereclause = clone(self._whereclause)
self.parameters = self.parameters.copy()
+ @_generative
def where(self, whereclause):
"""return a new update() construct with the given expression added to its WHERE clause, joined
to the existing clause via AND, if any."""
- s = self._generate()
- if s._whereclause is not None:
- s._whereclause = and_(s._whereclause, _literal_as_text(whereclause))
+ if self._whereclause is not None:
+ self._whereclause = and_(self._whereclause, _literal_as_text(whereclause))
else:
- s._whereclause = _literal_as_text(whereclause)
- return s
+ self._whereclause = _literal_as_text(whereclause)
class Delete(_UpdateBase):
else:
return ()
+ @_generative
def where(self, whereclause):
"""return a new delete() construct with the given expression added to its WHERE clause, joined
to the existing clause via AND, if any."""
- s = self._generate()
- if s._whereclause is not None:
- s._whereclause = and_(s._whereclause, _literal_as_text(whereclause))
+ if self._whereclause is not None:
+ self._whereclause = and_(self._whereclause, _literal_as_text(whereclause))
else:
- s._whereclause = _literal_as_text(whereclause)
- return s
+ self._whereclause = _literal_as_text(whereclause)
def _copy_internals(self, clone=_clone):
self._whereclause = clone(self._whereclause)
from collections import deque
+import re
+from sqlalchemy import util
+
+class VisitableType(type):
+ def __init__(cls, clsname, bases, dict):
+ if not '__visit_name__' in cls.__dict__:
+ m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname)
+ x = m.group(1)
+ x = re.sub(r'(?!^)[A-Z]', lambda m:'_'+m.group(0).lower(), x)
+ cls.__visit_name__ = x.lower()
+
+ # set up an optimized visit dispatch function
+ # for use by the compiler
+ visit_name = cls.__dict__["__visit_name__"]
+ if isinstance(visit_name, str):
+ func_text = "def _compiler_dispatch(self, visitor, **kw):\n"\
+ " return visitor.visit_%s(self, **kw)" % visit_name
+ else:
+ func_text = "def _compiler_dispatch(self, visitor, **kw):\n"\
+ " return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)"
+
+ env = locals().copy()
+ exec func_text in env
+ cls._compiler_dispatch = env['_compiler_dispatch']
+
+ super(VisitableType, cls).__init__(clsname, bases, dict)
+
+class Visitable(object):
+ __metaclass__ = VisitableType
class ClauseVisitor(object):
__traverse_options__ = {}
def traverse_single(self, obj):
- for v in self._iterate_visitors:
+ for v in self._visitor_iterator:
meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
if meth:
return meth(obj)
def traverse(self, obj):
"""traverse and visit the given expression structure."""
+ return traverse(obj, self.__traverse_options__, self._visitor_dict)
+
+ @util.memoized_property
+ def _visitor_dict(self):
visitors = {}
for name in dir(self):
if name.startswith('visit_'):
visitors[name[6:]] = getattr(self, name)
-
- return traverse(obj, self.__traverse_options__, visitors)
-
- def _iterate_visitors(self):
+ return visitors
+
+ @property
+ def _visitor_iterator(self):
"""iterate through this visitor and each 'chained' visitor."""
v = self
while v:
yield v
v = getattr(v, '_next', None)
- _iterate_visitors = property(_iterate_visitors)
def chain(self, visitor):
"""'chain' an additional ClauseVisitor onto this ClauseVisitor.
the chained visitor will receive all visit events after this one.
+
"""
- tail = list(self._iterate_visitors)[-1]
+ tail = list(self._visitor_iterator)[-1]
tail._next = visitor
return self
def traverse(self, obj):
"""traverse and visit the given expression structure."""
- visitors = {}
-
- for name in dir(self):
- if name.startswith('visit_'):
- visitors[name[6:]] = getattr(self, name)
-
- return cloned_traverse(obj, self.__traverse_options__, visitors)
+ return cloned_traverse(obj, self.__traverse_options__, self._visitor_dict)
class ReplacingCloningVisitor(CloningVisitor):
def replace(self, elem):
"""traverse and visit the given expression structure."""
def replace(elem):
- for v in self._iterate_visitors:
+ for v in self._visitor_iterator:
e = v.replace(elem)
if e:
return e