]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- The 'default' compiler is automatically copied over
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 2 Jul 2010 18:07:42 +0000 (14:07 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 2 Jul 2010 18:07:42 +0000 (14:07 -0400)
when overriding the compilation of a built in
clause construct, so no KeyError is raised if the
user-defined compiler is specific to certain
backends and compilation for a different backend
is invoked. [ticket:1838]

CHANGES
lib/sqlalchemy/ext/compiler.py
test/ext/test_compiler.py

diff --git a/CHANGES b/CHANGES
index 22badc790f30ce50d540753fed5c5f75900782d4..c7e4251db2a15b8a61d1f9d5ffa2dd8e4cbddd8d 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -170,7 +170,15 @@ CHANGES
      subclass.  It cannot, however, define one that is 
      not present in the __table__, and the error message
      here now works.  [ticket:1821]
-     
+
+- compiler extension
+  - The 'default' compiler is automatically copied over
+    when overriding the compilation of a built in
+    clause construct, so no KeyError is raised if the
+    user-defined compiler is specific to certain 
+    backends and compilation for a different backend
+    is invoked. [ticket:1838]
+    
 - documentation
   - Added documentation for the Inspector. [ticket:1820]
 
index 68c434fd916819a0c6be6ec02231eca2f3a11bb0..12f1e443d0afa2ae1d8dc5ece23754416074d383 100644 (file)
@@ -198,9 +198,13 @@ A big part of using the compiler extension is subclassing SQLAlchemy expression
 def compiles(class_, *specs):
     def decorate(fn):
         existing = class_.__dict__.get('_compiler_dispatcher', None)
+        existing_dispatch = class_.__dict__.get('_compiler_dispatch')
         if not existing:
             existing = _dispatcher()
-
+            
+            if existing_dispatch:
+                existing.specs['default'] = existing_dispatch
+                
             # TODO: why is the lambda needed ?
             setattr(class_, '_compiler_dispatch', lambda *arg, **kw: existing(*arg, **kw))
             setattr(class_, '_compiler_dispatcher', existing)
@@ -208,6 +212,7 @@ def compiles(class_, *specs):
         if specs:
             for s in specs:
                 existing.specs[s] = fn
+
         else:
             existing.specs['default'] = fn
         return fn
index fa1e3c1623e5efbaf115b5379e418d3f13d469f0..3ed84fe61ccd946286414a24d398ff44c999bc5f 100644 (file)
@@ -125,7 +125,39 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL):
             )
         finally:
             Select._compiler_dispatch = dispatch
+            if hasattr(Select, '_compiler_dispatcher'):
+                del Select._compiler_dispatcher
             
+    def test_default_on_existing(self):
+        """test that the existing compiler function remains
+        as 'default' when overriding the compilation of an
+        existing construct."""
+        
+
+        t1 = table('t1', column('c1'), column('c2'))
+        
+        dispatch = Select._compiler_dispatch
+        try:
+            
+            @compiles(Select, 'sqlite')
+            def compile(element, compiler, **kw):
+                return "OVERRIDE"
+            
+            s1 = select([t1])
+            self.assert_compile(
+                s1, "SELECT t1.c1, t1.c2 FROM t1",
+            )
+
+            from sqlalchemy.dialects.sqlite import base as sqlite
+            self.assert_compile(
+                s1, "OVERRIDE",
+                dialect=sqlite.dialect()
+            )
+        finally:
+            Select._compiler_dispatch = dispatch
+            if hasattr(Select, '_compiler_dispatcher'):
+                del Select._compiler_dispatcher
+        
     def test_dialect_specific(self):
         class AddThingy(DDLElement):
             __visit_name__ = 'add_thingy'