]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added a compiler extension that allows easy creation of user-defined compilers,
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 17 Jan 2009 07:03:36 +0000 (07:03 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 17 Jan 2009 07:03:36 +0000 (07:03 +0000)
which register themselves with custom ClauseElement subclasses such that the compiler
is invoked along with the primary compiler.  The compilers can also be registered
on a per-dialect basis.

This provides a supported path for SQLAlchemy extensions such as ALTER TABLE
extensions and other SQL constructs.

doc/build/reference/ext/compiler.rst [new file with mode: 0644]
doc/build/reference/ext/index.rst
lib/sqlalchemy/connectors/pyodbc.py
lib/sqlalchemy/dialects/mysql/pyodbc.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/ext/compiler.py [new file with mode: 0644]
test/ext/alltests.py
test/ext/compiler.py [new file with mode: 0644]

diff --git a/doc/build/reference/ext/compiler.rst b/doc/build/reference/ext/compiler.rst
new file mode 100644 (file)
index 0000000..95ce639
--- /dev/null
@@ -0,0 +1,5 @@
+compiler
+========
+
+.. automodule:: sqlalchemy.ext.compiler
+    :members:
\ No newline at end of file
index 6dc644422523bb2e3d382c314b708cdcacd25196..b15253ec5901439351645692aca50512b75d2b67 100644 (file)
@@ -16,4 +16,5 @@ core behavior.
     orderinglist
     serializer
     sqlsoup
+    compiler
 
index 1204cfbfd3dda33bd8549c0154a1bca879c1a82f..10725f45a87c46f89b349dd08ffa93bbbdd23c9d 100644 (file)
@@ -12,6 +12,10 @@ class PyODBCConnector(Connector):
     supports_unicode_statements = supports_unicode
     default_paramstyle = 'named'
     
+    # for non-DSN connections, this should
+    # hold the desired driver name
+    pyodbc_driver_name = None
+    
     @classmethod
     def dbapi(cls):
         return __import__('pyodbc')
@@ -34,7 +38,7 @@ class PyODBCConnector(Connector):
                 if 'port' in keys and not 'port' in query:
                     port = ',%d' % int(keys.pop('port'))
 
-                connectors = ["DRIVER={%s}" % keys.pop('driver'),
+                connectors = ["DRIVER={%s}" % keys.pop('driver', self.pyodbc_driver_name),
                               'Server=%s%s' % (keys.pop('host', ''), port),
                               'Database=%s' % keys.pop('database', '') ]
 
index 4eb7657073f129bfa39553995bab5f16881acccd..3b9b373610d8055a08f77160b79ec7d0c90a5eb5 100644 (file)
@@ -12,6 +12,8 @@ class MySQL_pyodbcExecutionContext(MySQLExecutionContext):
 class MySQL_pyodbc(PyODBCConnector, MySQLDialect):
     supports_unicode_statements = False
     execution_ctx_cls = MySQL_pyodbcExecutionContext
+
+    pyodbc_driver_name = "MySQL"
     
     def __init__(self, **kw):
         # deal with http://code.google.com/p/pyodbc/issues/detail?id=25
index 8be0a2d85fd9b80400438411afeb1e6c6cde747a..1f602eb6d3e8fcf0eafe666bd69a92dfa02a74ff 100644 (file)
@@ -34,6 +34,7 @@ class DefaultDialect(base.Dialect):
     supports_unicode_statements = False
     supports_unicode_binds = False
     
+    name = 'default'
     max_identifier_length = 9999
     supports_sane_rowcount = True
     supports_sane_multi_rowcount = True
diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py
new file mode 100644 (file)
index 0000000..365cc70
--- /dev/null
@@ -0,0 +1,163 @@
+"""Provides an API for creation of custom ClauseElements and compilers.
+
+Synopsis
+========
+
+Usage involves the creation of one or more :class:`~sqlalchemy.sql.expression.ClauseElement`
+subclasses and a :class:`~UserDefinedCompiler` class::
+
+    from sqlalchemy.ext.compiler import UserDefinedCompiler
+    from sqlalchemy.sql.expression import ColumnClause
+    
+    class MyColumn(ColumnClause):
+        __visit_name__ = 'mycolumn'
+        
+        def __init__(self, text):
+            ColumnClause.__init__(self, text)
+            
+    class MyCompiler(UserDefinedCompiler):
+        compile_elements = [MyColumn]
+        
+        def visit_mycolumn(self, element, **kw):
+            return "[%s]" % element.name
+            
+Above, ``MyColumn`` extends :class:`~sqlalchemy.sql.expression.ColumnClause`, the
+base expression element for column objects.  The ``MyCompiler`` class registers
+itself with the ``MyColumn`` class so that it is invoked when the object 
+is compiled to a string::
+
+    from sqlalchemy import select
+    
+    s = select([MyColumn('x'), MyColumn('y')])
+    print str(s)
+    
+Produces::
+
+    SELECT [x], [y]
+
+User defined compilers are associated with the :class:`~sqlalchemy.engine.Compiled`
+object that is responsible for the current compile, and can compile sub elements using
+the :meth:`UserDefinedCompiler.process` method::
+
+    class InsertFromSelect(ClauseElement):
+        __visit_name__ = 'insert_from_select'
+        def __init__(self, table, select):
+            self.table = table
+            self.select = select
+
+    class MyCompiler(UserDefinedCompiler):
+        compile_elements = [InsertFromSelect]
+    
+        def visit_insert_from_select(self, element, **kw):
+            return "INSERT INTO %s (%s)" % (
+                self.process(element.table, asfrom=True),
+                self.process(element.select)
+            )
+
+A single compiler can be made to service any number of elements as in this DDL example::
+
+    from sqlalchemy.schema import DDLElement
+    class AlterTable(DDLElement):
+        __visit_name__ = 'alter_table'
+        
+        def __init__(self, table, cmd):
+            self.table = table
+            self.cmd = cmd
+
+    class AlterColumn(DDLElement):
+        __visit_name__ = 'alter_column'
+
+        def __init__(self, column, cmd):
+            self.column = column
+            self.cmd = cmd
+
+    class AlterCompiler(UserDefinedCompiler):
+        compile_elements = [AlterTable, AlterColumn]
+        
+        def visit_alter_table(self, element, **kw):
+            return "ALTER TABLE %s ..." % element.table.name
+
+        def visit_alter_column(self, element, **kw):
+            return "ALTER COLUMN %s ..." % element.column.name
+
+Compilers can also be made dialect-specific.  The appropriate compiler will be invoked
+for the dialect in use::
+    
+    class PGAlterCompiler(AlterCompiler):
+        compile_elements = [AlterTable, AlterColumn]
+        dialect = 'postgres'
+        
+        def visit_alter_table(self, element, **kw):
+            return "ALTER PG TABLE %s ..." % element.table.name
+
+The above compiler will be invoked when any ``postgres`` dialect is used. Note
+that it extends the ``AlterCompiler`` so that the ``AlterColumn`` construct
+will be serviced by the generic ``AlterCompiler.visit_alter_column()`` method. 
+Subclassing is not required for dialect-specific compilers, but is recommended.
+
+"""
+from sqlalchemy import util
+from sqlalchemy.engine.base import Compiled
+import weakref
+
+def _spawn_compiler(clauseelement, compiler):
+    if not hasattr(compiler, '_user_compilers'):
+        compiler._user_compilers = {}
+    try:
+        return compiler._user_compilers[clauseelement._user_compiler_registry]
+    except KeyError:
+        registry = clauseelement._user_compiler_registry
+        cls = registry.get_compiler_cls(compiler.dialect)
+        compiler._user_compilers[registry] = user_compiler = cls(compiler)
+        return user_compiler
+
+class _CompilerRegistry(object):
+    def __init__(self):
+        self.user_compilers = {}
+        
+    def get_compiler_cls(self, dialect):
+        if dialect.name in self.user_compilers:
+            return self.user_compilers[dialect.name]
+        else:
+            return self.user_compilers['*']
+
+class _UserDefinedMeta(type):
+    def __init__(cls, classname, bases, dict_):
+        if cls.compile_elements:
+            if not hasattr(cls.compile_elements[0], '_user_compiler_registry'):
+                registry = _CompilerRegistry()
+                def compiler_dispatch(element, visitor, **kw):
+                    compiler = _spawn_compiler(element, visitor)
+                    return getattr(compiler, 'visit_%s' % element.__visit_name__)(element, **kw)
+                
+                for elem in cls.compile_elements:
+                    if hasattr(elem, '_user_compiler_registry'):
+                        raise exceptions.InvalidRequestError("Detected an existing UserDefinedCompiler registry on class %r" % elem)
+                    elem._user_compiler_registry = registry
+                    elem._compiler_dispatch = compiler_dispatch
+            else:
+                registry = cls.compile_elements[0]._user_compiler_registry
+        
+            if hasattr(cls, 'dialect'):
+                registry.user_compilers[cls.dialect] = cls
+            else:
+                registry.user_compilers['*'] = cls
+        return type.__init__(cls, classname, bases, dict_)
+
+class UserDefinedCompiler(Compiled):
+    __metaclass__ = _UserDefinedMeta
+    compile_elements = []
+    
+    def __init__(self, parent_compiler):
+        Compiled.__init__(self, parent_compiler.dialect, parent_compiler.statement, parent_compiler.bind)
+        self.compiler = weakref.ref(parent_compiler)
+        
+    def compile(self):
+        raise NotImplementedError()
+
+    def process(self, obj, **kwargs):
+        return obj._compiler_dispatch(self.compiler(), **kwargs)
+
+    def __str__(self):
+        return self.compiler().string or ''
+    
\ No newline at end of file
index 4733292483dc3adbbdb5ddec36b6aefc4798b59d..a1f1be60d31f60f59796cd1980371b50d10f5a03 100644 (file)
@@ -9,6 +9,7 @@ def suite():
         'ext.orderinglist',
         'ext.associationproxy',
         'ext.serializer',
+        'ext.compiler',
         )
 
     if sys.version_info < (2, 4):
diff --git a/test/ext/compiler.py b/test/ext/compiler.py
new file mode 100644 (file)
index 0000000..79e1041
--- /dev/null
@@ -0,0 +1,133 @@
+import testenv; testenv.configure_for_tests()
+from sqlalchemy import *
+from sqlalchemy.sql.expression import ClauseElement, ColumnClause
+from sqlalchemy.schema import DDLElement
+from sqlalchemy.ext.compiler import UserDefinedCompiler
+from sqlalchemy.ext import compiler
+from sqlalchemy.sql import table, column
+from testlib import *
+import gc
+
+class UserDefinedTest(TestBase, AssertsCompiledSQL):
+    
+    def test_column(self):
+        
+        class MyThingy(ColumnClause):
+            __visit_name__ = 'thingy'
+            
+            def __init__(self, arg= None):
+                super(MyThingy, self).__init__(arg or 'MYTHINGY!')
+                
+        class MyCompiler(UserDefinedCompiler):
+            compile_elements = [MyThingy]
+            
+            def visit_thingy(self, thingy, **kw):
+                return ">>%s<<" % thingy.name
+                
+        
+        self.assert_compile(
+            select([column('foo'), MyThingy()]),
+            "SELECT foo, >>MYTHINGY!<<"
+        )
+
+        self.assert_compile(
+            select([MyThingy('x'), MyThingy('y')]).where(MyThingy() == 5),
+            "SELECT >>x<<, >>y<< WHERE >>MYTHINGY!<< = :MYTHINGY!_1"
+        )
+
+    def test_stateful(self):
+        class MyThingy(ColumnClause):
+            __visit_name__ = 'thingy'
+            
+            def __init__(self):
+                super(MyThingy, self).__init__('MYTHINGY!')
+                
+        class MyCompiler(UserDefinedCompiler):
+            compile_elements = [MyThingy]
+            
+            def __init__(self, parent_compiler):
+                UserDefinedCompiler.__init__(self, parent_compiler)
+                self.counter = 0
+                
+            def visit_thingy(self, thingy, **kw):
+                self.counter += 1
+                return str(self.counter)
+
+        self.assert_compile(
+            select([column('foo'), MyThingy()]).order_by(desc(MyThingy())),
+            "SELECT foo, 1 ORDER BY 2 DESC"
+        )
+
+        self.assert_compile(
+            select([MyThingy(), MyThingy()]).where(MyThingy() == 5),
+            "SELECT 1, 2 WHERE 3 = :MYTHINGY!_1"
+        )
+
+    def test_callout_to_compiler(self):
+        class InsertFromSelect(ClauseElement):
+            __visit_name__ = 'insert_from_select'
+            def __init__(self, table, select):
+                self.table = table
+                self.select = select
+        
+        class MyCompiler(UserDefinedCompiler):
+            compile_elements = [InsertFromSelect]
+            
+            def visit_insert_from_select(self, element):
+                return "INSERT INTO %s (%s)" % (
+                    self.process(element.table, asfrom=True),
+                    self.process(element.select)
+                )
+        
+        t1 = table("mytable", column('x'), column('y'), column('z'))
+        self.assert_compile(
+            InsertFromSelect(
+                t1,
+                select([t1]).where(t1.c.x>5)
+            ),
+            "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z FROM mytable WHERE mytable.x > :x_1)"
+        )
+
+    def test_ddl(self):
+        class AddThingy(DDLElement):
+            __visit_name__ = 'add_thingy'
+
+        class DropThingy(DDLElement):
+            __visit_name__ = 'drop_thingy'
+
+        class MyCompiler(UserDefinedCompiler):
+            compile_elements = [AddThingy, DropThingy]
+
+            def visit_add_thingy(self, thingy, **kw):
+                return "ADD THINGY"
+
+            def visit_drop_thingy(self, thingy, **kw):
+                return "DROP THINGY"
+
+        class MyPGCompiler(MyCompiler):
+            dialect = 'postgres'
+            
+            def visit_add_thingy(self, thingy, **kw):
+                return "ADD SPECIAL PG THINGY"
+
+        self.assert_compile(AddThingy(),
+            "ADD THINGY"
+        )
+
+        self.assert_compile(DropThingy(),
+            "DROP THINGY"
+        )
+
+        self.assert_compile(AddThingy(),
+            "ADD SPECIAL PG THINGY",
+            dialect=create_engine('postgres://').dialect
+        )
+
+        self.assert_compile(DropThingy(),
+            "DROP THINGY",
+            dialect=create_engine('postgres://').dialect
+        )
+        
+        
+if __name__ == '__main__':
+    testenv.main()
\ No newline at end of file