From: Mike Bayer Date: Fri, 2 Jul 2010 18:07:42 +0000 (-0400) Subject: - The 'default' compiler is automatically copied over X-Git-Tag: rel_0_6_2~10 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0025a6a50eabe323c353681a1dd3949c8e57bb9b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - The 'default' compiler is automatically copied over when overriding the compilation of a built in clause construct, so no KeyError is raised if the user-defined compiler is specific to certain backends and compilation for a different backend is invoked. [ticket:1838] --- diff --git a/CHANGES b/CHANGES index 22badc790f..c7e4251db2 100644 --- a/CHANGES +++ b/CHANGES @@ -170,7 +170,15 @@ CHANGES subclass. It cannot, however, define one that is not present in the __table__, and the error message here now works. [ticket:1821] - + +- compiler extension + - The 'default' compiler is automatically copied over + when overriding the compilation of a built in + clause construct, so no KeyError is raised if the + user-defined compiler is specific to certain + backends and compilation for a different backend + is invoked. [ticket:1838] + - documentation - Added documentation for the Inspector. [ticket:1820] diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index 68c434fd91..12f1e443d0 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -198,9 +198,13 @@ A big part of using the compiler extension is subclassing SQLAlchemy expression def compiles(class_, *specs): def decorate(fn): existing = class_.__dict__.get('_compiler_dispatcher', None) + existing_dispatch = class_.__dict__.get('_compiler_dispatch') if not existing: existing = _dispatcher() - + + if existing_dispatch: + existing.specs['default'] = existing_dispatch + # TODO: why is the lambda needed ? setattr(class_, '_compiler_dispatch', lambda *arg, **kw: existing(*arg, **kw)) setattr(class_, '_compiler_dispatcher', existing) @@ -208,6 +212,7 @@ def compiles(class_, *specs): if specs: for s in specs: existing.specs[s] = fn + else: existing.specs['default'] = fn return fn diff --git a/test/ext/test_compiler.py b/test/ext/test_compiler.py index fa1e3c1623..3ed84fe61c 100644 --- a/test/ext/test_compiler.py +++ b/test/ext/test_compiler.py @@ -125,7 +125,39 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL): ) finally: Select._compiler_dispatch = dispatch + if hasattr(Select, '_compiler_dispatcher'): + del Select._compiler_dispatcher + def test_default_on_existing(self): + """test that the existing compiler function remains + as 'default' when overriding the compilation of an + existing construct.""" + + + t1 = table('t1', column('c1'), column('c2')) + + dispatch = Select._compiler_dispatch + try: + + @compiles(Select, 'sqlite') + def compile(element, compiler, **kw): + return "OVERRIDE" + + s1 = select([t1]) + self.assert_compile( + s1, "SELECT t1.c1, t1.c2 FROM t1", + ) + + from sqlalchemy.dialects.sqlite import base as sqlite + self.assert_compile( + s1, "OVERRIDE", + dialect=sqlite.dialect() + ) + finally: + Select._compiler_dispatch = dispatch + if hasattr(Select, '_compiler_dispatcher'): + del Select._compiler_dispatcher + def test_dialect_specific(self): class AddThingy(DDLElement): __visit_name__ = 'add_thingy'