]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- turn __visit_name__ into an explicit member.
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 11 Dec 2008 23:28:01 +0000 (23:28 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 11 Dec 2008 23:28:01 +0000 (23:28 +0000)
[ticket:1244]

lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/sql/visitors.py
test/sql/generative.py

index da534a2fa0289c830feabb5a935c1cb3e0c636c0..dc523b36c698334ad2997f1df30d4e6c9c4339d0 100644 (file)
@@ -43,6 +43,7 @@ __all__.sort()
 class SchemaItem(visitors.Visitable):
     """Base class for items that define a database schema."""
 
+    __visit_name__ = 'schema_item'
     quote = None
 
     def _init_items(self, *args):
@@ -121,6 +122,8 @@ class Table(SchemaItem, expression.TableClause):
 
     __metaclass__ = _TableSingleton
 
+    __visit_name__ = 'table'
+
     ddl_events = ('before-create', 'after-create', 'before-drop', 'after-drop')
 
     def __init__(self, name, metadata, *args, **kwargs):
@@ -408,6 +411,8 @@ class Table(SchemaItem, expression.TableClause):
 class Column(SchemaItem, expression.ColumnClause):
     """Represents a column in a database table."""
 
+    __visit_name__ = 'column'
+
     def __init__(self, *args, **kwargs):
         """
         Construct a new ``Column`` object.
@@ -761,6 +766,9 @@ class ForeignKey(SchemaItem):
     Further examples of foreign key configuration are in :ref:`metadata_foreignkeys`.
 
     """
+
+    __visit_name__ = 'foreign_key'
+
     def __init__(self, column, constraint=None, use_alter=False, name=None, onupdate=None, ondelete=None, deferrable=None, initially=None):
         """
         Construct a column-level FOREIGN KEY.
@@ -921,6 +929,8 @@ class ForeignKey(SchemaItem):
 class DefaultGenerator(SchemaItem):
     """Base class for column *default* values."""
 
+    __visit_name__ = 'default_generator'
+
     def __init__(self, for_update=False, metadata=None):
         self.for_update = for_update
         self.metadata = util.assert_arg_type(metadata, (MetaData, type(None)), 'metadata')
@@ -1001,6 +1011,8 @@ class ColumnDefault(DefaultGenerator):
 class Sequence(DefaultGenerator):
     """Represents a named database sequence."""
 
+    __visit_name__ = 'sequence'
+
     def __init__(self, name, start=None, increment=None, schema=None,
                  optional=False, quote=None, **kwargs):
         super(Sequence, self).__init__(**kwargs)
@@ -1078,6 +1090,8 @@ class Constraint(SchemaItem):
     underying columns.
     """
 
+    __visit_name__ = 'constraint'
+
     def __init__(self, name=None, deferrable=None, initially=None):
         """Create a SQL constraint.
 
@@ -1175,6 +1189,7 @@ class ForeignKeyConstraint(Constraint):
     Examples of foreign key configuration are in :ref:`metadata_foreignkeys`.
     
     """
+    __visit_name__ = 'foreign_key_constraint'
 
     def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None, use_alter=False, deferrable=None, initially=None):
         """Construct a composite-capable FOREIGN KEY.
@@ -1252,6 +1267,8 @@ class PrimaryKeyConstraint(Constraint):
     multiple-column PrimaryKeyConstraint.
     """
 
+    __visit_name__ = 'primary_key_constraint'
+
     def __init__(self, *columns, **kwargs):
         """Construct a composite-capable PRIMARY KEY.
 
@@ -1315,6 +1332,8 @@ class UniqueConstraint(Constraint):
     UniqueConstraint.
     """
 
+    __visit_name__ = 'unique_constraint'
+
     def __init__(self, *columns, **kwargs):
         """Construct a UNIQUE constraint.
 
@@ -1365,6 +1384,8 @@ class Index(SchemaItem):
     a shorthand equivalent for an unnamed, single column Index.
     """
 
+    __visit_name__ = 'index'
+
     def __init__(self, name, *columns, **kwargs):
         """Construct an index object.
 
index 49629069cd31bc24d1ebb2e7d88d27721bdb827b..6bda4c82c3978f7534f944a42f19b98a5563e114 100644 (file)
@@ -961,6 +961,8 @@ def is_column(col):
 class ClauseElement(Visitable):
     """Base class for elements of a programmatically constructed SQL expression."""
 
+    __visit_name__ = 'clause'
+
     _annotations = {}
     supports_execution = False
     _from_objects = []
@@ -1567,6 +1569,7 @@ class ColumnElement(ClauseElement, _CompareMixin):
 
     """
 
+    __visit_name__ = 'column'
     primary_key = False
     foreign_keys = []
     quote = None
@@ -1734,6 +1737,7 @@ class ColumnSet(util.OrderedSet):
 
 class Selectable(ClauseElement):
     """mark a class as being selectable"""
+    __visit_name__ = 'selectable'
 
 class FromClause(Selectable):
     """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement."""
@@ -2062,6 +2066,8 @@ class _Null(ColumnElement):
 
     """
 
+    __visit_name__ = 'null'
+
     def __init__(self):
         self.type = sqltypes.NULLTYPE
 
@@ -2211,6 +2217,8 @@ class _Function(_CalculatedClause, FromClause):
 
     """
 
+    __visit_name__ = 'function'
+
     def __init__(self, name, *clauses, **kwargs):
         self.packagenames = kwargs.get('packagenames', None) or []
         self.name = name
@@ -2237,6 +2245,8 @@ class _Function(_CalculatedClause, FromClause):
 
 class _Cast(ColumnElement):
 
+    __visit_name__ = 'cast'
+
     def __init__(self, clause, totype, **kwargs):
         self.type = sqltypes.to_instance(totype)
         self.clause = _literal_as_binds(clause, None)
@@ -2255,6 +2265,9 @@ class _Cast(ColumnElement):
 
 
 class _UnaryExpression(ColumnElement):
+
+    __visit_name__ = 'unary'
+
     def __init__(self, element, operator=None, modifier=None, type_=None, negate=None):
         self.operator = operator
         self.modifier = modifier
@@ -2304,6 +2317,8 @@ class _UnaryExpression(ColumnElement):
 class _BinaryExpression(ColumnElement):
     """Represent an expression that is ``LEFT <operator> RIGHT``."""
 
+    __visit_name__ = 'binary'
+
     def __init__(self, left, right, operator, type_=None, negate=None, modifiers=None):
         self.left = _literal_as_text(left).self_group(against=operator)
         self.right = _literal_as_text(right).self_group(against=operator)
@@ -2408,6 +2423,7 @@ class Join(FromClause):
     off all ``FromClause`` subclasses.
 
     """
+    __visit_name__ = 'join'
 
     def __init__(self, left, right, onclause=None, isouter=False):
         self.left = _selectable(left)
@@ -2528,6 +2544,7 @@ class Alias(FromClause):
 
     """
 
+    __visit_name__ = 'alias'
     named_with_column = True
 
     def __init__(self, selectable, alias=None):
@@ -2584,6 +2601,8 @@ class Alias(FromClause):
 class _Grouping(ColumnElement):
     """Represent a grouping within a column expression"""
 
+    __visit_name__ = 'grouping'
+
     def __init__(self, element):
         self.element = element
         self.type = getattr(element, 'type', None)
@@ -2656,6 +2675,8 @@ class _Label(ColumnElement):
 
     """
 
+    __visit_name__ = 'label'
+
     def __init__(self, name, element, type_=None):
         while isinstance(element, _Label):
             element = element.element
@@ -2729,6 +2750,8 @@ class ColumnClause(_Immutable, ColumnElement):
       ``ColumnClause``.
 
     """
+    __visit_name__ = 'column'
+
     def __init__(self, text, selectable=None, type_=None, is_literal=False):
         self.key = self.name = text
         self.table = selectable
@@ -2801,6 +2824,8 @@ class TableClause(_Immutable, FromClause):
 
     """
 
+    __visit_name__ = 'table'
+
     named_with_column = True
 
     def __init__(self, name, *columns):
@@ -2994,7 +3019,6 @@ class _SelectBaseMixin(object):
 
 
 class _ScalarSelect(_Grouping):
-    __visit_name__ = 'grouping'
     _from_objects = []
 
     def __init__(self, element):
@@ -3020,6 +3044,8 @@ class _ScalarSelect(_Grouping):
 class CompoundSelect(_SelectBaseMixin, FromClause):
     """Forms the basis of ``UNION``, ``UNION ALL``, and other SELECT-based set operations."""
 
+    __visit_name__ = 'compound_select'
+
     def __init__(self, keyword, *selects, **kwargs):
         self._should_correlate = kwargs.pop('correlate', False)
         self.keyword = keyword
@@ -3087,6 +3113,9 @@ class Select(_SelectBaseMixin, FromClause):
     ability to execute themselves and return a result set.
 
     """
+
+    __visit_name__ = 'select'
+
     def __init__(self, columns, whereclause=None, from_obj=None, distinct=False, having=None, correlate=True, prefixes=None, **kwargs):
         """Construct a Select object.
 
@@ -3444,6 +3473,8 @@ class Select(_SelectBaseMixin, FromClause):
 class _UpdateBase(ClauseElement):
     """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements."""
 
+    __visit_name__ = 'update_base'
+
     supports_execution = True
     _autocommit = True
 
@@ -3477,6 +3508,9 @@ class _UpdateBase(ClauseElement):
     bind = property(bind, _set_bind)
 
 class _ValuesBase(_UpdateBase):
+
+    __visit_name__ = 'values_base'
+
     def __init__(self, table, values):
         self.table = table
         self.parameters = self._process_colparams(values)
@@ -3512,6 +3546,8 @@ class Insert(_ValuesBase):
     The ``Insert`` object is created using the :func:`insert()` function.
 
     """
+    __visit_name__ = 'insert'
+
     def __init__(self, table, values=None, inline=False, bind=None, prefixes=None, **kwargs):
         _ValuesBase.__init__(self, table, values)
         self._bind = bind
@@ -3550,6 +3586,8 @@ class Update(_ValuesBase):
     The ``Update`` object is created using the :func:`update()` function.
 
     """
+    __visit_name__ = 'update'
+
     def __init__(self, table, whereclause, values=None, inline=False, bind=None, **kwargs):
         _ValuesBase.__init__(self, table, values)
         self._bind = bind
@@ -3589,6 +3627,8 @@ class Delete(_UpdateBase):
 
     """
 
+    __visit_name__ = 'delete'
+
     def __init__(self, table, whereclause, bind=None, **kwargs):
         self._bind = bind
         self.table = table
@@ -3619,6 +3659,7 @@ class Delete(_UpdateBase):
         self._whereclause = clone(self._whereclause)
 
 class _IdentifiedClause(ClauseElement):
+    __visit_name__ = 'identified'
     supports_execution = True
     _autocommit = False
     quote = None
@@ -3627,10 +3668,10 @@ class _IdentifiedClause(ClauseElement):
         self.ident = ident
 
 class SavepointClause(_IdentifiedClause):
-    pass
+    __visit_name__ = 'savepoint'
 
 class RollbackToSavepointClause(_IdentifiedClause):
-    pass
+    __visit_name__ = 'rollback_to_savepoint'
 
 class ReleaseSavepointClause(_IdentifiedClause):
-    pass
+    __visit_name__ = 'release_savepoint'
index 8b62910ace19b8ba700a9daa388c9bdb343c5f02..b57b242f521e4edd267da52d0f20108d609118c9 100644 (file)
@@ -6,10 +6,6 @@ from sqlalchemy.sql import operators
 from sqlalchemy.sql.visitors import VisitableType
 
 class _GenericMeta(VisitableType):
-    def __init__(cls, clsname, bases, dict):
-        cls.__visit_name__ = 'function'
-        type.__init__(cls, clsname, bases, dict)
-
     def __call__(self, *args, **kwargs):
         args = [_literal_as_binds(c) for c in args]
         return type.__call__(self, *args, **kwargs)
index 547753b283072a3635f4596b1252012c71a3cd9f..9b9b9ec094b65e59fe05c9f0cabb3639a2d24781 100644 (file)
@@ -121,7 +121,7 @@ def join_condition(a, b, ignore_nonexistent_tables=False):
     else:
         return sql.and_(*crit)
 
-    
+
 class Annotated(object):
     """clones a ClauseElement and applies an 'annotations' dictionary.
     
@@ -146,7 +146,9 @@ class Annotated(object):
             try:
                 cls = annotated_classes[element.__class__]
             except KeyError:
-                raise KeyError("Class %s has not defined an Annotated subclass" % element.__class__)
+                cls = annotated_classes[element.__class__] = type.__new__(type, 
+                        "Annotated%s" % element.__class__.__name__, 
+                        (Annotated, element.__class__), {})
             return object.__new__(cls)
 
     def __init__(self, element, values):
index 5d1d53cf8e59e7ebe6ab3825726b79492148f685..17b9c59d56de0f0e1b1c74f45628edbc36b9e33d 100644 (file)
@@ -28,21 +28,22 @@ __all__ = ['VisitableType', 'Visitable', 'ClauseVisitor',
     'cloned_traverse', 'replacement_traverse']
     
 class VisitableType(type):
-    """Metaclass which applies a `__visit_name__` attribute and 
-    `_compiler_dispatch` method to classes.
+    """Metaclass which checks for a `__visit_name__` attribute and
+    applies `_compiler_dispatch` method to classes.
     
     """
     
-    def __init__(cls, clsname, bases, dict):
-        if not '__visit_name__' in cls.__dict__:
-            m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname)
-            x = m.group(1)
-            x = re.sub(r'(?!^)[A-Z]', lambda m:'_'+m.group(0).lower(), x)
-            cls.__visit_name__ = x.lower()
+    def __init__(cls, clsname, bases, clsdict):
+        if cls.__name__ == 'Visitable':
+            super(VisitableType, cls).__init__(clsname, bases, clsdict)
+            return
+        
+        assert hasattr(cls, '__visit_name__'), "`Visitable` descendants " \
+                                               "should define `__visit_name__`"
         
         # set up an optimized visit dispatch function
         # for use by the compiler
-        visit_name = cls.__dict__["__visit_name__"]
+        visit_name = cls.__visit_name__
         if isinstance(visit_name, str):
             getter = operator.attrgetter("visit_%s" % visit_name)
             def _compiler_dispatch(self, visitor, **kw):
@@ -53,7 +54,7 @@ class VisitableType(type):
     
         cls._compiler_dispatch = _compiler_dispatch
         
-        super(VisitableType, cls).__init__(clsname, bases, dict)
+        super(VisitableType, cls).__init__(clsname, bases, clsdict)
 
 class Visitable(object):
     """Base class for visitable objects, applies the
index 4edf334f667db5220baf54047e45712fc60db91f..daa2432da3346a6d971ff330b8cdc80555571895 100644 (file)
@@ -18,6 +18,8 @@ class TraversalTest(TestBase, AssertsExecutionResults):
         # establish two ficticious ClauseElements.
         # define deep equality semantics as well as deep identity semantics.
         class A(ClauseElement):
+            __visit_name__ = 'a'
+
             def __init__(self, expr):
                 self.expr = expr
 
@@ -34,6 +36,8 @@ class TraversalTest(TestBase, AssertsExecutionResults):
                 return "A(%s)" % repr(self.expr)
 
         class B(ClauseElement):
+            __visit_name__ = 'b'
+
             def __init__(self, *items):
                 self.items = items
 
@@ -137,6 +141,19 @@ class TraversalTest(TestBase, AssertsExecutionResults):
         assert struct != s3
         assert struct3 == s3
 
+    def test_visit_name(self):
+        # override fns in testlib/schema.py
+        from sqlalchemy import Column
+
+        class CustomObj(Column):
+            pass
+            
+        assert CustomObj.__visit_name__ == Column.__visit_name__ == 'column'
+        
+        foo, bar = CustomObj('foo', String), CustomObj('bar', String)
+        bin = foo == bar
+        s = set(ClauseVisitor().iterate(bin))
+        assert set(ClauseVisitor().iterate(bin)) == set([foo, bar, bin])
 
 class ClauseTest(TestBase, AssertsCompiledSQL):
     """test copy-in-place behavior of various ClauseElements."""