From eec0260e124d4d001e3bc8d1e638d92ac3949a9b Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 17 Jan 2009 07:03:36 +0000 Subject: [PATCH] - added a compiler extension that allows easy creation of user-defined compilers, which register themselves with custom ClauseElement subclasses such that the compiler is invoked along with the primary compiler. The compilers can also be registered on a per-dialect basis. This provides a supported path for SQLAlchemy extensions such as ALTER TABLE extensions and other SQL constructs. --- doc/build/reference/ext/compiler.rst | 5 + doc/build/reference/ext/index.rst | 1 + lib/sqlalchemy/connectors/pyodbc.py | 6 +- lib/sqlalchemy/dialects/mysql/pyodbc.py | 2 + lib/sqlalchemy/engine/default.py | 1 + lib/sqlalchemy/ext/compiler.py | 163 ++++++++++++++++++++++++ test/ext/alltests.py | 1 + test/ext/compiler.py | 133 +++++++++++++++++++ 8 files changed, 311 insertions(+), 1 deletion(-) create mode 100644 doc/build/reference/ext/compiler.rst create mode 100644 lib/sqlalchemy/ext/compiler.py create mode 100644 test/ext/compiler.py diff --git a/doc/build/reference/ext/compiler.rst b/doc/build/reference/ext/compiler.rst new file mode 100644 index 0000000000..95ce639b09 --- /dev/null +++ b/doc/build/reference/ext/compiler.rst @@ -0,0 +1,5 @@ +compiler +======== + +.. automodule:: sqlalchemy.ext.compiler + :members: \ No newline at end of file diff --git a/doc/build/reference/ext/index.rst b/doc/build/reference/ext/index.rst index 6dc6444225..b15253ec59 100644 --- a/doc/build/reference/ext/index.rst +++ b/doc/build/reference/ext/index.rst @@ -16,4 +16,5 @@ core behavior. orderinglist serializer sqlsoup + compiler diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py index 1204cfbfd3..10725f45a8 100644 --- a/lib/sqlalchemy/connectors/pyodbc.py +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -12,6 +12,10 @@ class PyODBCConnector(Connector): supports_unicode_statements = supports_unicode default_paramstyle = 'named' + # for non-DSN connections, this should + # hold the desired driver name + pyodbc_driver_name = None + @classmethod def dbapi(cls): return __import__('pyodbc') @@ -34,7 +38,7 @@ class PyODBCConnector(Connector): if 'port' in keys and not 'port' in query: port = ',%d' % int(keys.pop('port')) - connectors = ["DRIVER={%s}" % keys.pop('driver'), + connectors = ["DRIVER={%s}" % keys.pop('driver', self.pyodbc_driver_name), 'Server=%s%s' % (keys.pop('host', ''), port), 'Database=%s' % keys.pop('database', '') ] diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py index 4eb7657073..3b9b373610 100644 --- a/lib/sqlalchemy/dialects/mysql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py @@ -12,6 +12,8 @@ class MySQL_pyodbcExecutionContext(MySQLExecutionContext): class MySQL_pyodbc(PyODBCConnector, MySQLDialect): supports_unicode_statements = False execution_ctx_cls = MySQL_pyodbcExecutionContext + + pyodbc_driver_name = "MySQL" def __init__(self, **kw): # deal with http://code.google.com/p/pyodbc/issues/detail?id=25 diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 8be0a2d85f..1f602eb6d3 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -34,6 +34,7 @@ class DefaultDialect(base.Dialect): supports_unicode_statements = False supports_unicode_binds = False + name = 'default' max_identifier_length = 9999 supports_sane_rowcount = True supports_sane_multi_rowcount = True diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py new file mode 100644 index 0000000000..365cc70bdf --- /dev/null +++ b/lib/sqlalchemy/ext/compiler.py @@ -0,0 +1,163 @@ +"""Provides an API for creation of custom ClauseElements and compilers. + +Synopsis +======== + +Usage involves the creation of one or more :class:`~sqlalchemy.sql.expression.ClauseElement` +subclasses and a :class:`~UserDefinedCompiler` class:: + + from sqlalchemy.ext.compiler import UserDefinedCompiler + from sqlalchemy.sql.expression import ColumnClause + + class MyColumn(ColumnClause): + __visit_name__ = 'mycolumn' + + def __init__(self, text): + ColumnClause.__init__(self, text) + + class MyCompiler(UserDefinedCompiler): + compile_elements = [MyColumn] + + def visit_mycolumn(self, element, **kw): + return "[%s]" % element.name + +Above, ``MyColumn`` extends :class:`~sqlalchemy.sql.expression.ColumnClause`, the +base expression element for column objects. The ``MyCompiler`` class registers +itself with the ``MyColumn`` class so that it is invoked when the object +is compiled to a string:: + + from sqlalchemy import select + + s = select([MyColumn('x'), MyColumn('y')]) + print str(s) + +Produces:: + + SELECT [x], [y] + +User defined compilers are associated with the :class:`~sqlalchemy.engine.Compiled` +object that is responsible for the current compile, and can compile sub elements using +the :meth:`UserDefinedCompiler.process` method:: + + class InsertFromSelect(ClauseElement): + __visit_name__ = 'insert_from_select' + def __init__(self, table, select): + self.table = table + self.select = select + + class MyCompiler(UserDefinedCompiler): + compile_elements = [InsertFromSelect] + + def visit_insert_from_select(self, element, **kw): + return "INSERT INTO %s (%s)" % ( + self.process(element.table, asfrom=True), + self.process(element.select) + ) + +A single compiler can be made to service any number of elements as in this DDL example:: + + from sqlalchemy.schema import DDLElement + class AlterTable(DDLElement): + __visit_name__ = 'alter_table' + + def __init__(self, table, cmd): + self.table = table + self.cmd = cmd + + class AlterColumn(DDLElement): + __visit_name__ = 'alter_column' + + def __init__(self, column, cmd): + self.column = column + self.cmd = cmd + + class AlterCompiler(UserDefinedCompiler): + compile_elements = [AlterTable, AlterColumn] + + def visit_alter_table(self, element, **kw): + return "ALTER TABLE %s ..." % element.table.name + + def visit_alter_column(self, element, **kw): + return "ALTER COLUMN %s ..." % element.column.name + +Compilers can also be made dialect-specific. The appropriate compiler will be invoked +for the dialect in use:: + + class PGAlterCompiler(AlterCompiler): + compile_elements = [AlterTable, AlterColumn] + dialect = 'postgres' + + def visit_alter_table(self, element, **kw): + return "ALTER PG TABLE %s ..." % element.table.name + +The above compiler will be invoked when any ``postgres`` dialect is used. Note +that it extends the ``AlterCompiler`` so that the ``AlterColumn`` construct +will be serviced by the generic ``AlterCompiler.visit_alter_column()`` method. +Subclassing is not required for dialect-specific compilers, but is recommended. + +""" +from sqlalchemy import util +from sqlalchemy.engine.base import Compiled +import weakref + +def _spawn_compiler(clauseelement, compiler): + if not hasattr(compiler, '_user_compilers'): + compiler._user_compilers = {} + try: + return compiler._user_compilers[clauseelement._user_compiler_registry] + except KeyError: + registry = clauseelement._user_compiler_registry + cls = registry.get_compiler_cls(compiler.dialect) + compiler._user_compilers[registry] = user_compiler = cls(compiler) + return user_compiler + +class _CompilerRegistry(object): + def __init__(self): + self.user_compilers = {} + + def get_compiler_cls(self, dialect): + if dialect.name in self.user_compilers: + return self.user_compilers[dialect.name] + else: + return self.user_compilers['*'] + +class _UserDefinedMeta(type): + def __init__(cls, classname, bases, dict_): + if cls.compile_elements: + if not hasattr(cls.compile_elements[0], '_user_compiler_registry'): + registry = _CompilerRegistry() + def compiler_dispatch(element, visitor, **kw): + compiler = _spawn_compiler(element, visitor) + return getattr(compiler, 'visit_%s' % element.__visit_name__)(element, **kw) + + for elem in cls.compile_elements: + if hasattr(elem, '_user_compiler_registry'): + raise exceptions.InvalidRequestError("Detected an existing UserDefinedCompiler registry on class %r" % elem) + elem._user_compiler_registry = registry + elem._compiler_dispatch = compiler_dispatch + else: + registry = cls.compile_elements[0]._user_compiler_registry + + if hasattr(cls, 'dialect'): + registry.user_compilers[cls.dialect] = cls + else: + registry.user_compilers['*'] = cls + return type.__init__(cls, classname, bases, dict_) + +class UserDefinedCompiler(Compiled): + __metaclass__ = _UserDefinedMeta + compile_elements = [] + + def __init__(self, parent_compiler): + Compiled.__init__(self, parent_compiler.dialect, parent_compiler.statement, parent_compiler.bind) + self.compiler = weakref.ref(parent_compiler) + + def compile(self): + raise NotImplementedError() + + def process(self, obj, **kwargs): + return obj._compiler_dispatch(self.compiler(), **kwargs) + + def __str__(self): + return self.compiler().string or '' + \ No newline at end of file diff --git a/test/ext/alltests.py b/test/ext/alltests.py index 4733292483..a1f1be60d3 100644 --- a/test/ext/alltests.py +++ b/test/ext/alltests.py @@ -9,6 +9,7 @@ def suite(): 'ext.orderinglist', 'ext.associationproxy', 'ext.serializer', + 'ext.compiler', ) if sys.version_info < (2, 4): diff --git a/test/ext/compiler.py b/test/ext/compiler.py new file mode 100644 index 0000000000..79e1041c10 --- /dev/null +++ b/test/ext/compiler.py @@ -0,0 +1,133 @@ +import testenv; testenv.configure_for_tests() +from sqlalchemy import * +from sqlalchemy.sql.expression import ClauseElement, ColumnClause +from sqlalchemy.schema import DDLElement +from sqlalchemy.ext.compiler import UserDefinedCompiler +from sqlalchemy.ext import compiler +from sqlalchemy.sql import table, column +from testlib import * +import gc + +class UserDefinedTest(TestBase, AssertsCompiledSQL): + + def test_column(self): + + class MyThingy(ColumnClause): + __visit_name__ = 'thingy' + + def __init__(self, arg= None): + super(MyThingy, self).__init__(arg or 'MYTHINGY!') + + class MyCompiler(UserDefinedCompiler): + compile_elements = [MyThingy] + + def visit_thingy(self, thingy, **kw): + return ">>%s<<" % thingy.name + + + self.assert_compile( + select([column('foo'), MyThingy()]), + "SELECT foo, >>MYTHINGY!<<" + ) + + self.assert_compile( + select([MyThingy('x'), MyThingy('y')]).where(MyThingy() == 5), + "SELECT >>x<<, >>y<< WHERE >>MYTHINGY!<< = :MYTHINGY!_1" + ) + + def test_stateful(self): + class MyThingy(ColumnClause): + __visit_name__ = 'thingy' + + def __init__(self): + super(MyThingy, self).__init__('MYTHINGY!') + + class MyCompiler(UserDefinedCompiler): + compile_elements = [MyThingy] + + def __init__(self, parent_compiler): + UserDefinedCompiler.__init__(self, parent_compiler) + self.counter = 0 + + def visit_thingy(self, thingy, **kw): + self.counter += 1 + return str(self.counter) + + self.assert_compile( + select([column('foo'), MyThingy()]).order_by(desc(MyThingy())), + "SELECT foo, 1 ORDER BY 2 DESC" + ) + + self.assert_compile( + select([MyThingy(), MyThingy()]).where(MyThingy() == 5), + "SELECT 1, 2 WHERE 3 = :MYTHINGY!_1" + ) + + def test_callout_to_compiler(self): + class InsertFromSelect(ClauseElement): + __visit_name__ = 'insert_from_select' + def __init__(self, table, select): + self.table = table + self.select = select + + class MyCompiler(UserDefinedCompiler): + compile_elements = [InsertFromSelect] + + def visit_insert_from_select(self, element): + return "INSERT INTO %s (%s)" % ( + self.process(element.table, asfrom=True), + self.process(element.select) + ) + + t1 = table("mytable", column('x'), column('y'), column('z')) + self.assert_compile( + InsertFromSelect( + t1, + select([t1]).where(t1.c.x>5) + ), + "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z FROM mytable WHERE mytable.x > :x_1)" + ) + + def test_ddl(self): + class AddThingy(DDLElement): + __visit_name__ = 'add_thingy' + + class DropThingy(DDLElement): + __visit_name__ = 'drop_thingy' + + class MyCompiler(UserDefinedCompiler): + compile_elements = [AddThingy, DropThingy] + + def visit_add_thingy(self, thingy, **kw): + return "ADD THINGY" + + def visit_drop_thingy(self, thingy, **kw): + return "DROP THINGY" + + class MyPGCompiler(MyCompiler): + dialect = 'postgres' + + def visit_add_thingy(self, thingy, **kw): + return "ADD SPECIAL PG THINGY" + + self.assert_compile(AddThingy(), + "ADD THINGY" + ) + + self.assert_compile(DropThingy(), + "DROP THINGY" + ) + + self.assert_compile(AddThingy(), + "ADD SPECIAL PG THINGY", + dialect=create_engine('postgres://').dialect + ) + + self.assert_compile(DropThingy(), + "DROP THINGY", + dialect=create_engine('postgres://').dialect + ) + + +if __name__ == '__main__': + testenv.main() \ No newline at end of file -- 2.47.3