From: Mike Bayer Date: Tue, 30 Mar 2010 14:39:36 +0000 (-0400) Subject: - the compiler extension now allows @compiles decorators X-Git-Tag: rel_0_6_0~82^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=43b9f0d116580474ac56c532a1427a4cdeb3748b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - the compiler extension now allows @compiles decorators on base classes that extend to child classes, @compiles decorators on child classes that aren't broken by a @compiles decorator on the base class. --- diff --git a/CHANGES b/CHANGES index dd6211fc1d..cca3298f07 100644 --- a/CHANGES +++ b/CHANGES @@ -26,7 +26,13 @@ CHANGES will expunge the object if the cascade also includes "delete-orphan", or will simply detach it otherwise. [ticket:1754] - + +- ext + - the compiler extension now allows @compiles decorators + on base classes that extend to child classes, @compiles + decorators on child classes that aren't broken by a + @compiles decorator on the base class. + 0.6beta3 ======== diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index 3226b0efd8..dde49e232e 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -165,7 +165,7 @@ A big part of using the compiler extension is subclassing SQLAlchemy expression def compiles(class_, *specs): def decorate(fn): - existing = getattr(class_, '_compiler_dispatcher', None) + existing = class_.__dict__.get('_compiler_dispatcher', None) if not existing: existing = _dispatcher() diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 4a54375f8b..799486c02f 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -40,16 +40,17 @@ class VisitableType(type): # set up an optimized visit dispatch function # for use by the compiler - visit_name = cls.__visit_name__ - if isinstance(visit_name, str): - getter = operator.attrgetter("visit_%s" % visit_name) - def _compiler_dispatch(self, visitor, **kw): - return getter(visitor)(self, **kw) - else: - def _compiler_dispatch(self, visitor, **kw): - return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw) - - cls._compiler_dispatch = _compiler_dispatch + if '__visit_name__' in cls.__dict__: + visit_name = cls.__visit_name__ + if isinstance(visit_name, str): + getter = operator.attrgetter("visit_%s" % visit_name) + def _compiler_dispatch(self, visitor, **kw): + return getter(visitor)(self, **kw) + else: + def _compiler_dispatch(self, visitor, **kw): + return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw) + + cls._compiler_dispatch = _compiler_dispatch super(VisitableType, cls).__init__(clsname, bases, clsdict) diff --git a/test/ext/test_compiler.py b/test/ext/test_compiler.py index d625ae0ca0..2d33b382af 100644 --- a/test/ext/test_compiler.py +++ b/test/ext/test_compiler.py @@ -176,6 +176,61 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL): "timezone('utc', current_timestamp)", dialect=postgresql.dialect() ) - + + def test_subclasses_one(self): + class Base(FunctionElement): + name = 'base' + + class Sub1(Base): + name = 'sub1' + + class Sub2(Base): + name = 'sub2' + + @compiles(Base) + def visit_base(element, compiler, **kw): + return element.name + + @compiles(Sub1) + def visit_base(element, compiler, **kw): + return "FOO" + element.name + + self.assert_compile( + select([Sub1(), Sub2()]), + 'SELECT FOOsub1, sub2', + use_default_dialect=True + ) + + def test_subclasses_two(self): + class Base(FunctionElement): + name = 'base' + class Sub1(Base): + name = 'sub1' + + @compiles(Base) + def visit_base(element, compiler, **kw): + return element.name + + class Sub2(Base): + name = 'sub2' + + class SubSub1(Sub1): + name = 'subsub1' + + self.assert_compile( + select([Sub1(), Sub2(), SubSub1()]), + 'SELECT sub1, sub2, subsub1', + use_default_dialect=True + ) + + @compiles(Sub1) + def visit_base(element, compiler, **kw): + return "FOO" + element.name + + self.assert_compile( + select([Sub1(), Sub2(), SubSub1()]), + 'SELECT FOOsub1, sub2, FOOsubsub1', + use_default_dialect=True + ) \ No newline at end of file