]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
a pared down ext.compiler with minimal boilerplate.
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 31 Mar 2009 18:57:22 +0000 (18:57 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 31 Mar 2009 18:57:22 +0000 (18:57 +0000)
lib/sqlalchemy/ext/compiler.py
test/ext/compiler.py

index 365cc70bdf054f19b7cfcccc8f6665bbd141b2bb..f97dfd5377ca0d3b5bdceb8de6b29ebd4933b12f 100644 (file)
@@ -4,25 +4,20 @@ Synopsis
 ========
 
 Usage involves the creation of one or more :class:`~sqlalchemy.sql.expression.ClauseElement`
-subclasses and a :class:`~UserDefinedCompiler` class::
+subclasses and one or more callables defining its compilation::
 
-    from sqlalchemy.ext.compiler import UserDefinedCompiler
+    from sqlalchemy.ext.compiler import compiles
     from sqlalchemy.sql.expression import ColumnClause
     
     class MyColumn(ColumnClause):
-        __visit_name__ = 'mycolumn'
-        
-        def __init__(self, text):
-            ColumnClause.__init__(self, text)
+        pass
             
-    class MyCompiler(UserDefinedCompiler):
-        compile_elements = [MyColumn]
+    @compiles(MyColumn)
+    def compile_mycolumn(element, compiler, **kw):
+        return "[%s]" % element.name
         
-        def visit_mycolumn(self, element, **kw):
-            return "[%s]" % element.name
-            
 Above, ``MyColumn`` extends :class:`~sqlalchemy.sql.expression.ColumnClause`, the
-base expression element for column objects.  The ``MyCompiler`` class registers
+base expression element for column objects.  The ``compiles`` decorator registers
 itself with the ``MyColumn`` class so that it is invoked when the object 
 is compiled to a string::
 
@@ -35,129 +30,81 @@ Produces::
 
     SELECT [x], [y]
 
-User defined compilers are associated with the :class:`~sqlalchemy.engine.Compiled`
-object that is responsible for the current compile, and can compile sub elements using
-the :meth:`UserDefinedCompiler.process` method::
-
-    class InsertFromSelect(ClauseElement):
-        __visit_name__ = 'insert_from_select'
-        def __init__(self, table, select):
-            self.table = table
-            self.select = select
-
-    class MyCompiler(UserDefinedCompiler):
-        compile_elements = [InsertFromSelect]
-    
-        def visit_insert_from_select(self, element, **kw):
-            return "INSERT INTO %s (%s)" % (
-                self.process(element.table, asfrom=True),
-                self.process(element.select)
-            )
-
-A single compiler can be made to service any number of elements as in this DDL example::
+Compilers can also be made dialect-specific.  The appropriate compiler will be invoked
+for the dialect in use::
 
     from sqlalchemy.schema import DDLElement
-    class AlterTable(DDLElement):
-        __visit_name__ = 'alter_table'
-        
-        def __init__(self, table, cmd):
-            self.table = table
-            self.cmd = cmd
 
     class AlterColumn(DDLElement):
-        __visit_name__ = 'alter_column'
 
         def __init__(self, column, cmd):
             self.column = column
             self.cmd = cmd
 
-    class AlterCompiler(UserDefinedCompiler):
-        compile_elements = [AlterTable, AlterColumn]
-        
-        def visit_alter_table(self, element, **kw):
-            return "ALTER TABLE %s ..." % element.table.name
+    @compiles(AlterColumn)
+    def visit_alter_column(element, compiler, **kw):
+        return "ALTER COLUMN %s ..." % element.column.name
 
-        def visit_alter_column(self, element, **kw):
-            return "ALTER COLUMN %s ..." % element.column.name
+    @compiles(AlterColumn, 'postgres')
+    def visit_alter_column(element, compiler, **kw):
+        return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name, element.column.name)
 
-Compilers can also be made dialect-specific.  The appropriate compiler will be invoked
-for the dialect in use::
+The second ``visit_alter_table`` will be invoked when any ``postgres`` dialect is used.
+
+The ``compiler`` argument is the :class:`~sqlalchemy.engine.base.Compiled` object
+in use.  This object can be inspected for any information about the in-progress 
+compilation, including ``compiler.dialect``, ``compiler.statement`` etc.
+The :class:`~sqlalchemy.sql.compiler.SQLCompiler` and :class:`~sqlalchemy.sql.compiler.DDLCompiler`
+both include a ``process()`` method which can be used for compilation of embedded attributes::
+
+    class InsertFromSelect(ClauseElement):
+        def __init__(self, table, select):
+            self.table = table
+            self.select = select
+
+    @compiles(InsertFromSelect)
+    def visit_insert_from_select(element, compiler, **kw):
+        return "INSERT INTO %s (%s)" % (
+            compiler.process(element.table, asfrom=True),
+            compiler.process(element.select)
+        )
+
+    insert = InsertFromSelect(t1, select([t1]).where(t1.c.x>5))
+    print insert
     
-    class PGAlterCompiler(AlterCompiler):
-        compile_elements = [AlterTable, AlterColumn]
-        dialect = 'postgres'
-        
-        def visit_alter_table(self, element, **kw):
-            return "ALTER PG TABLE %s ..." % element.table.name
+Produces::
+
+    "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z FROM mytable WHERE mytable.x > :x_1)"
 
-The above compiler will be invoked when any ``postgres`` dialect is used. Note
-that it extends the ``AlterCompiler`` so that the ``AlterColumn`` construct
-will be serviced by the generic ``AlterCompiler.visit_alter_column()`` method. 
-Subclassing is not required for dialect-specific compilers, but is recommended.
 
 """
-from sqlalchemy import util
-from sqlalchemy.engine.base import Compiled
-import weakref
-
-def _spawn_compiler(clauseelement, compiler):
-    if not hasattr(compiler, '_user_compilers'):
-        compiler._user_compilers = {}
-    try:
-        return compiler._user_compilers[clauseelement._user_compiler_registry]
-    except KeyError:
-        registry = clauseelement._user_compiler_registry
-        cls = registry.get_compiler_cls(compiler.dialect)
-        compiler._user_compilers[registry] = user_compiler = cls(compiler)
-        return user_compiler
-
-class _CompilerRegistry(object):
-    def __init__(self):
-        self.user_compilers = {}
+
+def compiles(class_, *specs):
+    def decorate(fn):
+        existing = getattr(class_, '_compiler_dispatcher', None)
+        if not existing:
+            existing = _dispatcher()
+
+            # TODO: why is the lambda needed ?
+            setattr(class_, '_compiler_dispatch', lambda *arg, **kw: existing(*arg, **kw))
+            setattr(class_, '_compiler_dispatcher', existing)
         
-    def get_compiler_cls(self, dialect):
-        if dialect.name in self.user_compilers:
-            return self.user_compilers[dialect.name]
+        if specs:
+            for s in specs:
+                existing.specs[s] = fn
         else:
-            return self.user_compilers['*']
-
-class _UserDefinedMeta(type):
-    def __init__(cls, classname, bases, dict_):
-        if cls.compile_elements:
-            if not hasattr(cls.compile_elements[0], '_user_compiler_registry'):
-                registry = _CompilerRegistry()
-                def compiler_dispatch(element, visitor, **kw):
-                    compiler = _spawn_compiler(element, visitor)
-                    return getattr(compiler, 'visit_%s' % element.__visit_name__)(element, **kw)
-                
-                for elem in cls.compile_elements:
-                    if hasattr(elem, '_user_compiler_registry'):
-                        raise exceptions.InvalidRequestError("Detected an existing UserDefinedCompiler registry on class %r" % elem)
-                    elem._user_compiler_registry = registry
-                    elem._compiler_dispatch = compiler_dispatch
-            else:
-                registry = cls.compile_elements[0]._user_compiler_registry
-        
-            if hasattr(cls, 'dialect'):
-                registry.user_compilers[cls.dialect] = cls
-            else:
-                registry.user_compilers['*'] = cls
-        return type.__init__(cls, classname, bases, dict_)
-
-class UserDefinedCompiler(Compiled):
-    __metaclass__ = _UserDefinedMeta
-    compile_elements = []
+            existing.specs['default'] = fn
+        return fn
+    return decorate
     
-    def __init__(self, parent_compiler):
-        Compiled.__init__(self, parent_compiler.dialect, parent_compiler.statement, parent_compiler.bind)
-        self.compiler = weakref.ref(parent_compiler)
+class _dispatcher(object):
+    def __init__(self):
+        self.specs = {}
+    
+    def __call__(self, element, compiler, **kw):
+        # TODO: yes, this could also switch off of DBAPI in use.
+        fn = self.specs.get(compiler.dialect.name, None)
+        if not fn:
+            fn = self.specs['default']
+        return fn(element, compiler, **kw)
         
-    def compile(self):
-        raise NotImplementedError()
-
-    def process(self, obj, **kwargs):
-        return obj._compiler_dispatch(self.compiler(), **kwargs)
-
-    def __str__(self):
-        return self.compiler().string or ''
-    
\ No newline at end of file
index d3965bb1529910becd04b3f02edf817e7ba2ae39..c8fefdf06f48db76b9ddeaa07907413d1a2f231a 100644 (file)
@@ -2,28 +2,21 @@ import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.sql.expression import ClauseElement, ColumnClause
 from sqlalchemy.schema import DDLElement
-from sqlalchemy.ext.compiler import UserDefinedCompiler
-from sqlalchemy.ext import compiler
+from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.sql import table, column
 from testlib import *
-import gc
 
 class UserDefinedTest(TestBase, AssertsCompiledSQL):
 
     def test_column(self):
 
         class MyThingy(ColumnClause):
-            __visit_name__ = 'thingy'
-
             def __init__(self, arg= None):
                 super(MyThingy, self).__init__(arg or 'MYTHINGY!')
 
-        class MyCompiler(UserDefinedCompiler):
-            compile_elements = [MyThingy]
-
-            def visit_thingy(self, thingy, **kw):
-                return ">>%s<<" % thingy.name
-
+        @compiles(MyThingy)
+        def visit_thingy(thingy, compiler, **kw):
+            return ">>%s<<" % thingy.name
 
         self.assert_compile(
             select([column('foo'), MyThingy()]),
@@ -37,21 +30,15 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL):
 
     def test_stateful(self):
         class MyThingy(ColumnClause):
-            __visit_name__ = 'thingy'
-
             def __init__(self):
                 super(MyThingy, self).__init__('MYTHINGY!')
 
-        class MyCompiler(UserDefinedCompiler):
-            compile_elements = [MyThingy]
-
-            def __init__(self, parent_compiler):
-                UserDefinedCompiler.__init__(self, parent_compiler)
-                self.counter = 0
-
-            def visit_thingy(self, thingy, **kw):
-                self.counter += 1
-                return str(self.counter)
+        @compiles(MyThingy)
+        def visit_thingy(thingy, compiler, **kw):
+            if not hasattr(compiler, 'counter'):
+                compiler.counter = 0
+            compiler.counter += 1
+            return str(compiler.counter)
 
         self.assert_compile(
             select([column('foo'), MyThingy()]).order_by(desc(MyThingy())),
@@ -65,19 +52,16 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL):
 
     def test_callout_to_compiler(self):
         class InsertFromSelect(ClauseElement):
-            __visit_name__ = 'insert_from_select'
             def __init__(self, table, select):
                 self.table = table
                 self.select = select
 
-        class MyCompiler(UserDefinedCompiler):
-            compile_elements = [InsertFromSelect]
-
-            def visit_insert_from_select(self, element):
-                return "INSERT INTO %s (%s)" % (
-                    self.process(element.table, asfrom=True),
-                    self.process(element.select)
-                )
+        @compiles(InsertFromSelect)
+        def visit_insert_from_select(element, compiler, **kw):
+            return "INSERT INTO %s (%s)" % (
+                compiler.process(element.table, asfrom=True),
+                compiler.process(element.select)
+            )
 
         t1 = table("mytable", column('x'), column('y'), column('z'))
         self.assert_compile(
@@ -88,27 +72,24 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL):
             "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z FROM mytable WHERE mytable.x > :x_1)"
         )
 
-    def test_ddl(self):
+    def test_dialect_specific(self):
         class AddThingy(DDLElement):
             __visit_name__ = 'add_thingy'
 
         class DropThingy(DDLElement):
             __visit_name__ = 'drop_thingy'
 
-        class MyCompiler(UserDefinedCompiler):
-            compile_elements = [AddThingy, DropThingy]
-
-            def visit_add_thingy(self, thingy, **kw):
-                return "ADD THINGY"
-
-            def visit_drop_thingy(self, thingy, **kw):
-                return "DROP THINGY"
+        @compiles(AddThingy, 'sqlite')
+        def visit_add_thingy(thingy, compiler, **kw):
+            return "ADD SPECIAL SL THINGY"
 
-        class MySqliteCompiler(MyCompiler):
-            dialect = 'sqlite'
+        @compiles(AddThingy)
+        def visit_add_thingy(thingy, compiler, **kw):
+            return "ADD THINGY"
 
-            def visit_add_thingy(self, thingy, **kw):
-                return "ADD SPECIAL SL THINGY"
+        @compiles(DropThingy)
+        def visit_drop_thingy(thingy, compiler, **kw):
+            return "DROP THINGY"
 
         self.assert_compile(AddThingy(),
             "ADD THINGY"
@@ -129,6 +110,18 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL):
             dialect=base.dialect()
         )
 
+        @compiles(DropThingy, 'sqlite')
+        def visit_drop_thingy(thingy, compiler, **kw):
+            return "DROP SPECIAL SL THINGY"
+
+        self.assert_compile(DropThingy(),
+            "DROP SPECIAL SL THINGY",
+            dialect=base.dialect()
+        )
+
+        self.assert_compile(DropThingy(),
+            "DROP THINGY",
+        )
 
 if __name__ == '__main__':
     testenv.main()