]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure @compiles calls down to the original compilation scheme
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 29 Jun 2016 15:11:17 +0000 (11:11 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 29 Jun 2016 15:14:28 +0000 (11:14 -0400)
Made a slight behavioral change in the ``sqlalchemy.ext.compiler``
extension, whereby the existing compilation schemes for an established
construct would be removed if that construct was itself didn't already
have its own dedicated ``__visit_name__``.  This was a
rare occurrence in 1.0, however in 1.1 :class:`.postgresql.ARRAY`
subclasses :class:`.sqltypes.ARRAY` and has this behavior.
As a result, setting up a compilation handler for another dialect
such as SQLite would render the main :class:`.postgresql.ARRAY`
object no longer compilable.

Fixes: #3732
Change-Id: If2c1ada4eeb09157885888e41f529173902f2b49

doc/build/changelog/changelog_11.rst
lib/sqlalchemy/ext/compiler.py
test/ext/test_compiler.py

index 9ae17fde090d90ad62191c8b7f44a041b7b80ca1..7fea05d4d952b46cff9a7bcda150c9e3d1110d8f 100644 (file)
 .. changelog::
     :version: 1.1.0b2
 
+    .. change::
+        :tags: bug, ext, postgresql
+        :tickets: 3732
+
+        Made a slight behavioral change in the ``sqlalchemy.ext.compiler``
+        extension, whereby the existing compilation schemes for an established
+        construct would be removed if that construct was itself didn't already
+        have its own dedicated ``__visit_name__``.  This was a
+        rare occurrence in 1.0, however in 1.1 :class:`.postgresql.ARRAY`
+        subclasses :class:`.sqltypes.ARRAY` and has this behavior.
+        As a result, setting up a compilation handler for another dialect
+        such as SQLite would render the main :class:`.postgresql.ARRAY`
+        object no longer compilable.
+
     .. change::
         :tags: bug, sql
         :tickets: 3730
index 86156be1fba97f44079b8935caa3936d6b64030a..5ef4e1d2a0e13ecfac718aac5ccea0301ef43364 100644 (file)
@@ -410,13 +410,25 @@ def compiles(class_, *specs):
     given :class:`.ClauseElement` type."""
 
     def decorate(fn):
+        # get an existing @compiles handler
         existing = class_.__dict__.get('_compiler_dispatcher', None)
-        existing_dispatch = class_.__dict__.get('_compiler_dispatch')
+
+        # get the original handler.  All ClauseElement classes have one
+        # of these, but some TypeEngine classes will not.
+        existing_dispatch = getattr(class_, '_compiler_dispatch', None)
+
         if not existing:
             existing = _dispatcher()
 
             if existing_dispatch:
-                existing.specs['default'] = existing_dispatch
+                def _wrap_existing_dispatch(element, compiler, **kw):
+                    try:
+                        return existing_dispatch(element, compiler, **kw)
+                    except exc.UnsupportedCompilationError:
+                        raise exc.CompileError(
+                            "%s construct has no default "
+                            "compilation handler." % type(element))
+                existing.specs['default'] = _wrap_existing_dispatch
 
             # TODO: why is the lambda needed ?
             setattr(class_, '_compiler_dispatch',
@@ -458,4 +470,5 @@ class _dispatcher(object):
                 raise exc.CompileError(
                     "%s construct has no default "
                     "compilation handler." % type(element))
+
         return fn(element, compiler, **kw)
index f381ca185b1fa562f41d2b78a9b4bb6c550d4b38..02b9f3a433cdd155987ef7fe89238ffab38e2845 100644 (file)
@@ -2,15 +2,17 @@ from sqlalchemy import *
 from sqlalchemy.types import TypeEngine
 from sqlalchemy.sql.expression import ClauseElement, ColumnClause,\
                                     FunctionElement, Select, \
-                                    BindParameter
+                                    BindParameter, ColumnElement
 
 from sqlalchemy.schema import DDLElement, CreateColumn, CreateTable
 from sqlalchemy.ext.compiler import compiles, deregister
 from sqlalchemy import exc
-from sqlalchemy.sql import table, column, visitors
+from sqlalchemy.testing import eq_
+from sqlalchemy.sql import table, column
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import fixtures, AssertsCompiledSQL
 
+
 class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = 'default'
 
@@ -123,7 +125,7 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
             "FROM mytable WHERE mytable.x > :x_1)"
         )
 
-    def test_no_default_message(self):
+    def test_no_default_but_has_a_visit(self):
         class MyThingy(ColumnClause):
             pass
 
@@ -131,11 +133,52 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
         def visit_thingy(thingy, compiler, **kw):
             return "mythingy"
 
+        eq_(str(MyThingy('x')), "x")
+
+    def test_no_default_has_no_visit(self):
+        class MyThingy(TypeEngine):
+            pass
+
+        @compiles(MyThingy, 'postgresql')
+        def visit_thingy(thingy, compiler, **kw):
+            return "mythingy"
+
         assert_raises_message(
             exc.CompileError,
             "<class 'test.ext.test_compiler..*MyThingy'> "
             "construct has no default compilation handler.",
-            str, MyThingy('x')
+            str, MyThingy()
+        )
+
+    def test_no_default_message(self):
+        class MyThingy(ClauseElement):
+            pass
+
+        @compiles(MyThingy, 'postgresql')
+        def visit_thingy(thingy, compiler, **kw):
+            return "mythingy"
+
+        assert_raises_message(
+            exc.CompileError,
+            "<class 'test.ext.test_compiler..*MyThingy'> "
+            "construct has no default compilation handler.",
+            str, MyThingy()
+        )
+
+    def test_default_subclass(self):
+        from sqlalchemy.types import ARRAY
+
+        class MyArray(ARRAY):
+            pass
+
+        @compiles(MyArray, "sqlite")
+        def sl_array(elem, compiler, **kw):
+            return "array"
+
+        self.assert_compile(
+            MyArray(Integer),
+            "INTEGER[]",
+            dialect="postgresql"
         )
 
     def test_annotations(self):