From: Mike Bayer Date: Wed, 9 Feb 2011 20:45:15 +0000 (-0500) Subject: - The compiler extension now supports overriding the default X-Git-Tag: rel_0_7b1~25 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3f30fb065c3b6baa8d7870bb0682d20ad37a62a2;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - The compiler extension now supports overriding the default compilation of expression._BindParamClause including that the auto-generated binds within the VALUES/SET clause of an insert()/update() statement will also use the new compilation rules. [ticket:2042] --- diff --git a/CHANGES b/CHANGES index dec011438b..ceba445c54 100644 --- a/CHANGES +++ b/CHANGES @@ -231,6 +231,12 @@ CHANGES the extension compiles and runs on Python 2.4. [ticket:2023] + - The compiler extension now supports overriding the default + compilation of expression._BindParamClause including that + the auto-generated binds within the VALUES/SET clause + of an insert()/update() statement will also use the new + compilation rules. [ticket:2042] + - postgresql - When explicit sequence execution derives the name of the auto-generated sequence of a SERIAL column, diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index d906bf5d46..1ab0ba4054 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -579,8 +579,9 @@ class SQLCompiler(engine.Compiled): else: return fn(" " + operator + " ") - def visit_bindparam(self, bindparam, within_columns_clause=False, + def visit_bindparam(self, bindparam, within_columns_clause=False, literal_binds=False, **kwargs): + if literal_binds or \ (within_columns_clause and \ self.ansi_bind_rules): @@ -591,6 +592,7 @@ class SQLCompiler(engine.Compiled): within_columns_clause=True, **kwargs) name = self._truncate_bindparam(bindparam) + if name in self.binds: existing = self.binds[name] if existing is not bindparam: @@ -600,7 +602,8 @@ class SQLCompiler(engine.Compiled): "unique bind parameter of the same name" % bindparam.key ) - elif getattr(existing, '_is_crud', False): + elif getattr(existing, '_is_crud', False) or \ + getattr(bindparam, '_is_crud', False): raise exc.CompileError( "bindparam() name '%s' is reserved " "for automatic usage in the VALUES or SET " @@ -992,18 +995,8 @@ class SQLCompiler(engine.Compiled): bindparam = sql.bindparam(col.key, value, type_=col.type, required=required) bindparam._is_crud = True - if col.key in self.binds: - raise exc.CompileError( - "bindparam() name '%s' is reserved " - "for automatic usage in the VALUES or SET clause of this " - "insert/update statement. Please use a " - "name other than column name when using bindparam() " - "with insert() or update() (for example, 'b_%s')." - % (col.key, col.key) - ) + return bindparam._compiler_dispatch(self) - self.binds[col.key] = bindparam - return self.bindparam_string(self._truncate_bindparam(bindparam)) def _get_colparams(self, stmt): """create a set of tuples representing column/string pairs for use diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 8011aa109d..0c6be97d78 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -44,22 +44,25 @@ class VisitableType(type): super(VisitableType, cls).__init__(clsname, bases, clsdict) return - # set up an optimized visit dispatch function - # for use by the compiler - if '__visit_name__' in cls.__dict__: - visit_name = cls.__visit_name__ - if isinstance(visit_name, str): - getter = operator.attrgetter("visit_%s" % visit_name) - def _compiler_dispatch(self, visitor, **kw): - return getter(visitor)(self, **kw) - else: - def _compiler_dispatch(self, visitor, **kw): - return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw) - - cls._compiler_dispatch = _compiler_dispatch + _generate_dispatch(cls) super(VisitableType, cls).__init__(clsname, bases, clsdict) +def _generate_dispatch(cls): + # set up an optimized visit dispatch function + # for use by the compiler + if '__visit_name__' in cls.__dict__: + visit_name = cls.__visit_name__ + if isinstance(visit_name, str): + getter = operator.attrgetter("visit_%s" % visit_name) + def _compiler_dispatch(self, visitor, **kw): + return getter(visitor)(self, **kw) + else: + def _compiler_dispatch(self, visitor, **kw): + return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw) + + cls._compiler_dispatch = _compiler_dispatch + class Visitable(object): """Base class for visitable objects, applies the ``VisitableType`` metaclass. diff --git a/test/aaa_profiling/test_compiler.py b/test/aaa_profiling/test_compiler.py index 18d0118f14..697ab5952c 100644 --- a/test/aaa_profiling/test_compiler.py +++ b/test/aaa_profiling/test_compiler.py @@ -32,12 +32,12 @@ class CompileTest(TestBase, AssertsExecutionResults): cls.dialect = default.DefaultDialect() - @profiling.function_call_count(versions={'2.7':58, '2.6':58, - '3':64}) + @profiling.function_call_count(versions={'2.7':62, '2.6':62, + '3':68}) def test_insert(self): t1.insert().compile(dialect=self.dialect) - @profiling.function_call_count(versions={'2.6':49, '2.7':49}) + @profiling.function_call_count(versions={'2.6':53, '2.7':53}) def test_update(self): t1.update().compile(dialect=self.dialect) diff --git a/test/ext/test_compiler.py b/test/ext/test_compiler.py index d9bb778db8..116b0f2293 100644 --- a/test/ext/test_compiler.py +++ b/test/ext/test_compiler.py @@ -1,10 +1,12 @@ from sqlalchemy import * from sqlalchemy.types import TypeEngine from sqlalchemy.sql.expression import ClauseElement, ColumnClause,\ - FunctionElement, Select + FunctionElement, Select, \ + _BindParamClause + from sqlalchemy.schema import DDLElement from sqlalchemy.ext.compiler import compiles -from sqlalchemy.sql import table, column +from sqlalchemy.sql import table, column, visitors from test.lib import * class UserDefinedTest(TestBase, AssertsCompiledSQL): @@ -128,36 +130,6 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL): if hasattr(Select, '_compiler_dispatcher'): del Select._compiler_dispatcher - def test_default_on_existing(self): - """test that the existing compiler function remains - as 'default' when overriding the compilation of an - existing construct.""" - - - t1 = table('t1', column('c1'), column('c2')) - - dispatch = Select._compiler_dispatch - try: - - @compiles(Select, 'sqlite') - def compile(element, compiler, **kw): - return "OVERRIDE" - - s1 = select([t1]) - self.assert_compile( - s1, "SELECT t1.c1, t1.c2 FROM t1", - ) - - from sqlalchemy.dialects.sqlite import base as sqlite - self.assert_compile( - s1, "OVERRIDE", - dialect=sqlite.dialect() - ) - finally: - Select._compiler_dispatch = dispatch - if hasattr(Select, '_compiler_dispatcher'): - del Select._compiler_dispatcher - def test_dialect_specific(self): class AddThingy(DDLElement): __visit_name__ = 'add_thingy' @@ -290,3 +262,66 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL): 'SELECT FOOsub1, sub2, FOOsubsub1', use_default_dialect=True ) + + +class DefaultOnExistingTest(TestBase, AssertsCompiledSQL): + """Test replacement of default compilation on existing constructs.""" + + def teardown(self): + for cls in (Select, _BindParamClause): + if hasattr(cls, '_compiler_dispatcher'): + visitors._generate_dispatch(cls) + del cls._compiler_dispatcher + + def test_select(self): + t1 = table('t1', column('c1'), column('c2')) + + @compiles(Select, 'sqlite') + def compile(element, compiler, **kw): + return "OVERRIDE" + + s1 = select([t1]) + self.assert_compile( + s1, "SELECT t1.c1, t1.c2 FROM t1", + ) + + from sqlalchemy.dialects.sqlite import base as sqlite + self.assert_compile( + s1, "OVERRIDE", + dialect=sqlite.dialect() + ) + + def test_binds_in_select(self): + t = table('t', + column('a'), + column('b'), + column('c') + ) + + @compiles(_BindParamClause) + def gen_bind(element, compiler, **kw): + return "BIND(%s)" % compiler.visit_bindparam(element, **kw) + + self.assert_compile( + t.select().where(t.c.c == 5), + "SELECT t.a, t.b, t.c FROM t WHERE t.c = BIND(:c_1)", + use_default_dialect=True + ) + + def test_binds_in_dml(self): + t = table('t', + column('a'), + column('b'), + column('c') + ) + + @compiles(_BindParamClause) + def gen_bind(element, compiler, **kw): + return "BIND(%s)" % compiler.visit_bindparam(element, **kw) + + self.assert_compile( + t.insert(), + "INSERT INTO t (a, b) VALUES (BIND(:a), BIND(:b))", + {'a':1, 'b':2}, + use_default_dialect=True + ) diff --git a/test/lib/testing.py b/test/lib/testing.py index cdd5ee258f..36a8c8d1a4 100644 --- a/test/lib/testing.py +++ b/test/lib/testing.py @@ -626,7 +626,9 @@ class TestBase(object): assert val, msg class AssertsCompiledSQL(object): - def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None, use_default_dialect=False): + def assert_compile(self, clause, result, params=None, + checkparams=None, dialect=None, + use_default_dialect=False): if use_default_dialect: dialect = default.DefaultDialect()