]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- moved _FigureVisitName into visitiors.VisitorType, added Visitor base class to...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 25 Oct 2008 19:44:21 +0000 (19:44 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 25 Oct 2008 19:44:21 +0000 (19:44 +0000)
- implemented _generative decorator for select/update/insert/delete constructs
- other minutiae

lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/sql/visitors.py
test/profiling/zoomark.py

index 567b1d43ae4ba9e7f2e911c2b237cf44a0da755d..9e8df78ac660c328903157b47e7588de026d46eb 100644 (file)
@@ -40,10 +40,9 @@ __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index',
            'DefaulClause', 'FetchedValue', 'ColumnDefault', 'DDL']
 
 
-class SchemaItem(object):
+class SchemaItem(visitors.Visitable):
     """Base class for items that define a database schema."""
 
-    __metaclass__ = expression._FigureVisitName
     quote = None
 
     def _init_items(self, *args):
@@ -87,7 +86,7 @@ def _get_table_key(name, schema):
     else:
         return schema + "." + name
 
-class _TableSingleton(expression._FigureVisitName):
+class _TableSingleton(visitors.VisitableType):
     """A metaclass used by the ``Table`` object to provide singleton behavior."""
 
     def __call__(self, name, metadata, *args, **kwargs):
@@ -2050,3 +2049,4 @@ def _bind_or_error(schemaitem):
                'assign %s to enable implicit execution.') % (item, bindable)
         raise exc.UnboundExecutionError(msg)
     return bind
+
index b4069e6fd6d2ec2f688d309a649fd89d129fb473..63557f24b53af9f64862edf258eb7f5e57c39bb1 100644 (file)
@@ -20,7 +20,7 @@ is otherwise internal to SQLAlchemy.
 
 import string, re
 from sqlalchemy import schema, engine, util, exc
-from sqlalchemy.sql import operators, functions, util as sql_util
+from sqlalchemy.sql import operators, functions, util as sql_util, visitors
 from sqlalchemy.sql import expression as sql
 
 RESERVED_WORDS = set([
@@ -110,10 +110,9 @@ FUNCTIONS = {
 }
 
 
-class _CompileLabel(object):
+class _CompileLabel(visitors.Visitable):
     """lightweight label object which acts as an expression._Label."""
 
-    __metaclass__ = sql._FigureVisitName
     __visit_name__ = 'label'
     __slots__ = 'element', 'name'
     
index d937d05079e23f324f6909e412edce72cc25b28e..52291c487e81cae7294650b36177a78b76c987b4 100644 (file)
@@ -29,7 +29,8 @@ import itertools, re
 from operator import attrgetter
 
 from sqlalchemy import util, exc
-from sqlalchemy.sql import operators, visitors
+from sqlalchemy.sql import operators
+from sqlalchemy.sql.visitors import Visitable, cloned_traverse
 from sqlalchemy import types as sqltypes
 
 functions, schema, sql_util = None, None, None
@@ -876,7 +877,7 @@ def _compound_select(keyword, *selects, **kwargs):
     return CompoundSelect(keyword, *selects, **kwargs)
 
 def _is_literal(element):
-    return not isinstance(element, ClauseElement)
+    return not isinstance(element, Visitable) and not hasattr(element, '__clause_element__')
 
 def _from_objects(*elements, **kwargs):
     return itertools.chain(*[element._get_from_objects(**kwargs) for element in elements])
@@ -890,7 +891,7 @@ def _labeled(element):
 def _literal_as_text(element):
     if hasattr(element, '__clause_element__'):
         return element.__clause_element__()
-    elif not isinstance(element, ClauseElement):
+    elif not isinstance(element, Visitable):
         return _TextClause(unicode(element))
     else:
         return element
@@ -898,7 +899,7 @@ def _literal_as_text(element):
 def _literal_as_column(element):
     if hasattr(element, '__clause_element__'):
         return element.__clause_element__()
-    elif not isinstance(element, ClauseElement):
+    elif not isinstance(element, Visitable):
         return literal_column(str(element))
     else:
         return element
@@ -906,7 +907,7 @@ def _literal_as_column(element):
 def _literal_as_binds(element, name=None, type_=None):
     if hasattr(element, '__clause_element__'):
         return element.__clause_element__()
-    elif not isinstance(element, ClauseElement):
+    elif not isinstance(element, Visitable):
         if element is None:
             return null()
         else:
@@ -917,15 +918,18 @@ def _literal_as_binds(element, name=None, type_=None):
 def _no_literals(element):
     if hasattr(element, '__clause_element__'):
         return element.__clause_element__()
-    elif not isinstance(element, ClauseElement):
-        raise exc.ArgumentError("Ambiguous literal: %r.  Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element)
+    elif not isinstance(element, Visitable):
+        raise exc.ArgumentError("Ambiguous literal: %r.  Use the 'text()' function "
+                "to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element)
     else:
         return element
     
 def _corresponding_column_or_error(fromclause, column, require_embedded=False):
     c = fromclause.corresponding_column(column, require_embedded=require_embedded)
     if not c:
-        raise exc.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description))
+        raise exc.InvalidRequestError("Given column '%s', attached to table '%s', "
+                "failed to locate a corresponding column from table '%s'" 
+                % (column, getattr(column, 'table', None), fromclause.description))
     return c
 
 def _selectable(element):
@@ -934,39 +938,15 @@ def _selectable(element):
     elif isinstance(element, Selectable):
         return element
     else:
-        raise exc.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element))
+        raise exc.ArgumentError("Object %r is not a Selectable and does not implement `__selectable__()`" % element)
 
 def is_column(col):
     """True if ``col`` is an instance of ``ColumnElement``."""
     return isinstance(col, ColumnElement)
 
-class _FigureVisitName(type):
-    def __init__(cls, clsname, bases, dict):
-        if not '__visit_name__' in cls.__dict__:
-            m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname)
-            x = m.group(1)
-            x = re.sub(r'(?!^)[A-Z]', lambda m:'_'+m.group(0).lower(), x)
-            cls.__visit_name__ = x.lower()
-        
-        # set up an optimized visit dispatch function
-        # for use by the compiler
-        visit_name = cls.__dict__["__visit_name__"]
-        if isinstance(visit_name, str):
-            func_text = "def _compiler_dispatch(self, visitor, **kw):\n"\
-            "    return visitor.visit_%s(self, **kw)" % visit_name
-        else:
-            func_text = "def _compiler_dispatch(self, visitor, **kw):\n"\
-            "    return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)"
     
-        env = locals().copy()
-        exec func_text in env
-        cls._compiler_dispatch = env['_compiler_dispatch']
-        
-        super(_FigureVisitName, cls).__init__(clsname, bases, dict)
-
-class ClauseElement(object):
+class ClauseElement(Visitable):
     """Base class for elements of a programmatically constructed SQL expression."""
-    __metaclass__ = _FigureVisitName
     _annotations = {}
     supports_execution = False
     
@@ -976,6 +956,7 @@ class ClauseElement(object):
         This method may be used by a generative API.  Its also used as
         part of the "deep" copy afforded by a traversal that combines
         the _copy_internals() method.
+        
         """
         c = self.__class__.__new__(self.__class__)
         c.__dict__ = self.__dict__.copy()
@@ -1001,8 +982,8 @@ class ClauseElement(object):
         should be added to the ``FROM`` list of a query, when this
         ``ClauseElement`` is placed in the column clause of a
         ``Select`` statement.
+        
         """
-
         raise NotImplementedError(repr(self))
     
     def _annotate(self, values):
@@ -1049,7 +1030,7 @@ class ClauseElement(object):
                 bind.value = kwargs[bind.key]
             if unique:
                 bind._convert_to_unique()
-        return visitors.cloned_traverse(self, {}, {'bindparam':visit_bindparam})
+        return cloned_traverse(self, {}, {'bindparam':visit_bindparam})
 
     def compare(self, other):
         """Compare this ClauseElement to the given ClauseElement.
@@ -2637,14 +2618,13 @@ class _ColumnClause(_Immutable, ColumnElement):
       rules applied regardless of case sensitive settings.  the
       ``literal_column()`` function is usually used to create such a
       ``_ColumnClause``.
+      
     """
-
     def __init__(self, text, selectable=None, type_=None, is_literal=False):
         ColumnElement.__init__(self)
         self.key = self.name = text
         self.table = selectable
         self.type = sqltypes.to_instance(type_)
-        self.__label = None
         self.is_literal = is_literal
 
     @util.memoized_property
@@ -2655,23 +2635,25 @@ class _ColumnClause(_Immutable, ColumnElement):
     def _label(self):
         if self.is_literal:
             return None
-        if not self.__label:
-            if self.table and self.table.named_with_column:
-                if getattr(self.table, 'schema', None):
-                    self.__label = self.table.schema + "_" + self.table.name + "_" + self.name
-                else:
-                    self.__label = self.table.name + "_" + self.name
-                    
-                if self.__label in self.table.c:
-                    label = self.__label
-                    counter = 1
-                    while label in self.table.c:
-                        label = self.__label + "_" + str(counter)
-                        counter += 1
-                    self.__label = label
+            
+        elif self.table and self.table.named_with_column:
+            if getattr(self.table, 'schema', None):
+                label = self.table.schema + "_" + self.table.name + "_" + self.name
             else:
-                self.__label = self.name
-        return self.__label
+                label = self.table.name + "_" + self.name
+                
+            if label in self.table.c:
+                # TODO: coverage does not seem to be present for this
+                _label = label
+                counter = 1
+                while _label in self.table.c:
+                    _label = label + "_" + str(counter)
+                    counter += 1
+                label = _label
+            return label
+            
+        else:
+            return self.name
 
     def label(self, name):
         if name is None:
@@ -2723,7 +2705,7 @@ class TableClause(_Immutable, FromClause):
     def _export_columns(self):
         raise NotImplementedError()
 
-    @property
+    @util.memoized_property
     def description(self):
         return self.name.encode('ascii', 'backslashreplace')
 
@@ -2756,6 +2738,14 @@ class TableClause(_Immutable, FromClause):
     def _get_from_objects(self, **modifiers):
         return [self]
 
+@util.decorator
+def _generative(fn, *args, **kw):
+    """Mark a method as generative."""
+
+    self = args[0]._generate()
+    fn(self, *args[1:], **kw)
+    return self
+
 class _SelectBaseMixin(object):
     """Base class for ``Select`` and ``CompoundSelects``."""
 
@@ -2784,6 +2774,7 @@ class _SelectBaseMixin(object):
         """
         return _ScalarSelect(self)
 
+    @_generative
     def apply_labels(self):
         """return a new selectable with the 'use_labels' flag set to True.
 
@@ -2793,9 +2784,7 @@ class _SelectBaseMixin(object):
         among the individual FROM clauses.
 
         """
-        s = self._generate()
-        s.use_labels = True
-        return s
+        self.use_labels = True
 
     def label(self, name):
         """return a 'scalar' representation of this selectable, embedded as a subquery
@@ -2806,12 +2795,11 @@ class _SelectBaseMixin(object):
         """
         return self.as_scalar().label(name)
 
+    @_generative
     def autocommit(self):
         """return a new selectable with the 'autocommit' flag set to True."""
 
-        s = self._generate()
-        s._autocommit = True
-        return s
+        self._autocommit = True
 
     def _generate(self):
         s = self.__class__.__new__(self.__class__)
@@ -2819,39 +2807,35 @@ class _SelectBaseMixin(object):
         s._reset_exported()
         return s
 
+    @_generative
     def limit(self, limit):
         """return a new selectable with the given LIMIT criterion applied."""
 
-        s = self._generate()
-        s._limit = limit
-        return s
+        self._limit = limit
 
+    @_generative
     def offset(self, offset):
         """return a new selectable with the given OFFSET criterion applied."""
 
-        s = self._generate()
-        s._offset = offset
-        return s
+        self._offset = offset
 
+    @_generative
     def order_by(self, *clauses):
         """return a new selectable with the given list of ORDER BY criterion applied.
 
         The criterion will be appended to any pre-existing ORDER BY criterion.
 
         """
-        s = self._generate()
-        s.append_order_by(*clauses)
-        return s
+        self.append_order_by(*clauses)
 
+    @_generative
     def group_by(self, *clauses):
         """return a new selectable with the given list of GROUP BY criterion applied.
 
         The criterion will be appended to any pre-existing GROUP BY criterion.
 
         """
-        s = self._generate()
-        s.append_group_by(*clauses)
-        return s
+        self.append_group_by(*clauses)
 
     def append_order_by(self, *clauses):
         """Append the given ORDER BY criterion applied to this selectable.
@@ -3112,72 +3096,67 @@ class Select(_SelectBaseMixin, FromClause):
             self._raw_columns + list(self._froms) + \
             [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None]
 
+    @_generative
     def column(self, column):
         """return a new select() construct with the given column expression added to its columns clause."""
 
-        s = self._generate()
         column = _literal_as_column(column)
 
         if isinstance(column, _ScalarSelect):
             column = column.self_group(against=operators.comma_op)
 
-        s._raw_columns = s._raw_columns + [column]
-        s._froms = s._froms.union(_from_objects(column))
-        return s
+        self._raw_columns = self._raw_columns + [column]
+        self._froms = self._froms.union(_from_objects(column))
 
+    @_generative
     def with_only_columns(self, columns):
         """return a new select() construct with its columns clause replaced with the given columns."""
-        s = self._generate()
-        s._raw_columns = [
+
+        self._raw_columns = [
                 isinstance(c, _ScalarSelect) and c.self_group(against=operators.comma_op) or c
                 for c in
                 [_literal_as_column(c) for c in columns]
             ]
-        return s
 
+    @_generative
     def where(self, whereclause):
         """return a new select() construct with the given expression added to its WHERE clause, joined
         to the existing clause via AND, if any."""
 
-        s = self._generate()
-        s.append_whereclause(whereclause)
-        return s
+        self.append_whereclause(whereclause)
 
+    @_generative
     def having(self, having):
         """return a new select() construct with the given expression added to its HAVING clause, joined
         to the existing clause via AND, if any."""
 
-        s = self._generate()
-        s.append_having(having)
-        return s
+        self.append_having(having)
 
+    @_generative
     def distinct(self):
         """return a new select() construct which will apply DISTINCT to its columns clause."""
 
-        s = self._generate()
-        s._distinct = True
-        return s
+        self._distinct = True
 
+    @_generative
     def prefix_with(self, clause):
         """return a new select() construct which will apply the given expression to the start of its
         columns clause, not using any commas."""
 
-        s = self._generate()
         clause = _literal_as_text(clause)
-        s._prefixes = s._prefixes + [clause]
-        return s
+        self._prefixes = self._prefixes + [clause]
 
+    @_generative
     def select_from(self, fromclause):
         """return a new select() construct with the given FROM expression applied to its list of
         FROM objects."""
 
-        s = self._generate()
         if _is_literal(fromclause):
             fromclause = _TextClause(fromclause)
 
-        s._froms = s._froms.union([fromclause])
-        return s
+        self._froms = self._froms.union([fromclause])
 
+    @_generative
     def correlate(self, *fromclauses):
         """return a new select() construct which will correlate the given FROM clauses to that
         of an enclosing select(), if a match is found.
@@ -3192,13 +3171,11 @@ class Select(_SelectBaseMixin, FromClause):
         If the fromclause is None, correlation is disabled for the returned select().
         
         """
-        s = self._generate()
-        s._should_correlate = False
+        self._should_correlate = False
         if fromclauses == (None,):
-            s._correlate = set()
+            self._correlate = set()
         else:
-            s._correlate = s._correlate.union(fromclauses)
-        return s
+            self._correlate = self._correlate.union(fromclauses)
 
     def append_correlation(self, fromclause):
         """append the given correlation expression to this select() construct."""
@@ -3416,16 +3393,15 @@ class Insert(_ValuesBase):
     def _copy_internals(self, clone=_clone):
         self.parameters = self.parameters.copy()
 
+    @_generative
     def prefix_with(self, clause):
         """Add a word or expression between INSERT and INTO. Generative.
 
         If multiple prefixes are supplied, they will be separated with
         spaces.
         """
-        gen = self._generate()
         clause = _literal_as_text(clause)
-        gen._prefixes = self._prefixes + [clause]
-        return gen
+        self._prefixes = self._prefixes + [clause]
 
 class Update(_ValuesBase):
     def __init__(self, table, whereclause, values=None, inline=False, bind=None, **kwargs):
@@ -3450,16 +3426,15 @@ class Update(_ValuesBase):
         self._whereclause = clone(self._whereclause)
         self.parameters = self.parameters.copy()
 
+    @_generative
     def where(self, whereclause):
         """return a new update() construct with the given expression added to its WHERE clause, joined
         to the existing clause via AND, if any."""
         
-        s = self._generate()
-        if s._whereclause is not None:
-            s._whereclause = and_(s._whereclause, _literal_as_text(whereclause))
+        if self._whereclause is not None:
+            self._whereclause = and_(self._whereclause, _literal_as_text(whereclause))
         else:
-            s._whereclause = _literal_as_text(whereclause)
-        return s
+            self._whereclause = _literal_as_text(whereclause)
 
 
 class Delete(_UpdateBase):
@@ -3479,16 +3454,15 @@ class Delete(_UpdateBase):
         else:
             return ()
 
+    @_generative
     def where(self, whereclause):
         """return a new delete() construct with the given expression added to its WHERE clause, joined
         to the existing clause via AND, if any."""
         
-        s = self._generate()
-        if s._whereclause is not None:
-            s._whereclause = and_(s._whereclause, _literal_as_text(whereclause))
+        if self._whereclause is not None:
+            self._whereclause = and_(self._whereclause, _literal_as_text(whereclause))
         else:
-            s._whereclause = _literal_as_text(whereclause)
-        return s
+            self._whereclause = _literal_as_text(whereclause)
         
     def _copy_internals(self, clone=_clone):
         self._whereclause = clone(self._whereclause)
index c7a0f142d0b9e4996597e40796e97e4821a3a5d8..87901b266ca913121d302a29a4cff8629e72ac2f 100644 (file)
@@ -1,11 +1,11 @@
 from sqlalchemy import types as sqltypes
 from sqlalchemy.sql.expression import (
-    ClauseList, _FigureVisitName, _Function, _literal_as_binds, text
+    ClauseList, _Function, _literal_as_binds, text
     )
 from sqlalchemy.sql import operators
+from sqlalchemy.sql.visitors import VisitableType
 
-
-class _GenericMeta(_FigureVisitName):
+class _GenericMeta(VisitableType):
     def __init__(cls, clsname, bases, dict):
         cls.__visit_name__ = 'function'
         type.__init__(cls, clsname, bases, dict)
index 3f06535de74e69b7a26e6e3ddaee9e9da04cbe51..371c6ec4befb7aa2d577c1011b56e8a51de6394b 100644 (file)
@@ -1,10 +1,39 @@
 from collections import deque
+import re
+from sqlalchemy import util
+
+class VisitableType(type):
+    def __init__(cls, clsname, bases, dict):
+        if not '__visit_name__' in cls.__dict__:
+            m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname)
+            x = m.group(1)
+            x = re.sub(r'(?!^)[A-Z]', lambda m:'_'+m.group(0).lower(), x)
+            cls.__visit_name__ = x.lower()
+        
+        # set up an optimized visit dispatch function
+        # for use by the compiler
+        visit_name = cls.__dict__["__visit_name__"]
+        if isinstance(visit_name, str):
+            func_text = "def _compiler_dispatch(self, visitor, **kw):\n"\
+            "    return visitor.visit_%s(self, **kw)" % visit_name
+        else:
+            func_text = "def _compiler_dispatch(self, visitor, **kw):\n"\
+            "    return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)"
+    
+        env = locals().copy()
+        exec func_text in env
+        cls._compiler_dispatch = env['_compiler_dispatch']
+        
+        super(VisitableType, cls).__init__(clsname, bases, dict)
+
+class Visitable(object):
+    __metaclass__ = VisitableType
 
 class ClauseVisitor(object):
     __traverse_options__ = {}
     
     def traverse_single(self, obj):
-        for v in self._iterate_visitors:
+        for v in self._visitor_iterator:
             meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
             if meth:
                 return meth(obj)
@@ -17,29 +46,33 @@ class ClauseVisitor(object):
     def traverse(self, obj):
         """traverse and visit the given expression structure."""
 
+        return traverse(obj, self.__traverse_options__, self._visitor_dict)
+    
+    @util.memoized_property
+    def _visitor_dict(self):
         visitors = {}
 
         for name in dir(self):
             if name.startswith('visit_'):
                 visitors[name[6:]] = getattr(self, name)
-            
-        return traverse(obj, self.__traverse_options__, visitors)
-
-    def _iterate_visitors(self):
+        return visitors
+        
+    @property
+    def _visitor_iterator(self):
         """iterate through this visitor and each 'chained' visitor."""
         
         v = self
         while v:
             yield v
             v = getattr(v, '_next', None)
-    _iterate_visitors = property(_iterate_visitors)
 
     def chain(self, visitor):
         """'chain' an additional ClauseVisitor onto this ClauseVisitor.
         
         the chained visitor will receive all visit events after this one.
+        
         """
-        tail = list(self._iterate_visitors)[-1]
+        tail = list(self._visitor_iterator)[-1]
         tail._next = visitor
         return self
 
@@ -52,13 +85,7 @@ class CloningVisitor(ClauseVisitor):
     def traverse(self, obj):
         """traverse and visit the given expression structure."""
 
-        visitors = {}
-
-        for name in dir(self):
-            if name.startswith('visit_'):
-                visitors[name[6:]] = getattr(self, name)
-            
-        return cloned_traverse(obj, self.__traverse_options__, visitors)
+        return cloned_traverse(obj, self.__traverse_options__, self._visitor_dict)
 
 class ReplacingCloningVisitor(CloningVisitor):
     def replace(self, elem):
@@ -74,7 +101,7 @@ class ReplacingCloningVisitor(CloningVisitor):
         """traverse and visit the given expression structure."""
 
         def replace(elem):
-            for v in self._iterate_visitors:
+            for v in self._visitor_iterator:
                 e = v.replace(elem)
                 if e:
                     return e
index 6cf0771c4d0289b07c08775d361ee636563020ef..41f6cdc5dfd7c53547113798619c24a29dae0673 100644 (file)
@@ -324,7 +324,7 @@ class ZooMarkTest(TestBase):
     def test_profile_1_create_tables(self):
         self.test_baseline_1_create_tables()
 
-    @profiling.function_call_count(5726, {'2.4': 3844})
+    @profiling.function_call_count(5726, {'2.4': 3650})
     def test_profile_1a_populate(self):
         self.test_baseline_1a_populate()