]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- the compiler extension now allows @compiles decorators
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 30 Mar 2010 14:39:36 +0000 (10:39 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 30 Mar 2010 14:39:36 +0000 (10:39 -0400)
on base classes that extend to child classes, @compiles
decorators on child classes that aren't broken by a
@compiles decorator on the base class.

CHANGES
lib/sqlalchemy/ext/compiler.py
lib/sqlalchemy/sql/visitors.py
test/ext/test_compiler.py

diff --git a/CHANGES b/CHANGES
index dd6211fc1d994b8a74483663958175c065853d21..cca3298f072a0cbbd11d8c0e705899c2a08b0709 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -26,7 +26,13 @@ CHANGES
     will expunge the object if the cascade also includes
     "delete-orphan", or will simply detach it otherwise.
     [ticket:1754]
-    
+
+- ext
+   - the compiler extension now allows @compiles decorators
+     on base classes that extend to child classes, @compiles
+     decorators on child classes that aren't broken by a 
+     @compiles decorator on the base class.
+   
 0.6beta3
 ========
 
index 3226b0efd8ae803b009d4b534be54c6a2720e4c8..dde49e232e51d762c8cad149361d53bd804a1d77 100644 (file)
@@ -165,7 +165,7 @@ A big part of using the compiler extension is subclassing SQLAlchemy expression
 
 def compiles(class_, *specs):
     def decorate(fn):
-        existing = getattr(class_, '_compiler_dispatcher', None)
+        existing = class_.__dict__.get('_compiler_dispatcher', None)
         if not existing:
             existing = _dispatcher()
 
index 4a54375f8babd38610e0a38f25a00af58335a184..799486c02f7346059ab419c540d4fe8d2edc0be9 100644 (file)
@@ -40,16 +40,17 @@ class VisitableType(type):
         
         # set up an optimized visit dispatch function
         # for use by the compiler
-        visit_name = cls.__visit_name__
-        if isinstance(visit_name, str):
-            getter = operator.attrgetter("visit_%s" % visit_name)
-            def _compiler_dispatch(self, visitor, **kw):
-                return getter(visitor)(self, **kw)
-        else:
-            def _compiler_dispatch(self, visitor, **kw):
-                return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)
-    
-        cls._compiler_dispatch = _compiler_dispatch
+        if '__visit_name__' in cls.__dict__:
+            visit_name = cls.__visit_name__
+            if isinstance(visit_name, str):
+                getter = operator.attrgetter("visit_%s" % visit_name)
+                def _compiler_dispatch(self, visitor, **kw):
+                    return getter(visitor)(self, **kw)
+            else:
+                def _compiler_dispatch(self, visitor, **kw):
+                    return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)
+
+            cls._compiler_dispatch = _compiler_dispatch
         
         super(VisitableType, cls).__init__(clsname, bases, clsdict)
 
index d625ae0ca07d00439f8dd0651229eb6aaef7e30c..2d33b382afededc9059c64491b0bc3a02f724356 100644 (file)
@@ -176,6 +176,61 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL):
             "timezone('utc', current_timestamp)",
             dialect=postgresql.dialect()
         )
-            
+    
+    def test_subclasses_one(self):
+        class Base(FunctionElement):
+            name = 'base'
+        
+        class Sub1(Base):
+            name = 'sub1'
+
+        class Sub2(Base):
+            name = 'sub2'
+        
+        @compiles(Base)
+        def visit_base(element, compiler, **kw):
+            return element.name
+
+        @compiles(Sub1)
+        def visit_base(element, compiler, **kw):
+            return "FOO" + element.name
+
+        self.assert_compile(
+            select([Sub1(), Sub2()]),
+            'SELECT FOOsub1, sub2',
+            use_default_dialect=True
+        )
+    
+    def test_subclasses_two(self):
+        class Base(FunctionElement):
+            name = 'base'
         
+        class Sub1(Base):
+            name = 'sub1'
+
+        @compiles(Base)
+        def visit_base(element, compiler, **kw):
+            return element.name
+
+        class Sub2(Base):
+            name = 'sub2'
+        
+        class SubSub1(Sub1):
+            name = 'subsub1'
+            
+        self.assert_compile(
+            select([Sub1(), Sub2(), SubSub1()]),
+            'SELECT sub1, sub2, subsub1',
+            use_default_dialect=True
+        )
+
+        @compiles(Sub1)
+        def visit_base(element, compiler, **kw):
+            return "FOO" + element.name
+
+        self.assert_compile(
+            select([Sub1(), Sub2(), SubSub1()]),
+            'SELECT FOOsub1, sub2, FOOsubsub1',
+            use_default_dialect=True
+        )
         
\ No newline at end of file