]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- The compiler extension now supports overriding the default
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 9 Feb 2011 21:04:29 +0000 (16:04 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 9 Feb 2011 21:04:29 +0000 (16:04 -0500)
compilation of expression._BindParamClause including that
the auto-generated binds within the VALUES/SET clause
of an insert()/update() statement will also use the new
compilation rules. [ticket:2042]

CHANGES
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/visitors.py
lib/sqlalchemy/test/testing.py
test/aaa_profiling/test_compiler.py
test/ext/test_compiler.py

diff --git a/CHANGES b/CHANGES
index 8d38759a33902fa1e4a51edd0c2f5d964e85af56..380a4881ff319f13e9e36d1af86bef186196e052 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -34,6 +34,12 @@ CHANGES
     the extension compiles and runs on Python 2.4.
     [ticket:2023]
 
+  - The compiler extension now supports overriding the default
+    compilation of expression._BindParamClause including that
+    the auto-generated binds within the VALUES/SET clause
+    of an insert()/update() statement will also use the new
+    compilation rules. [ticket:2042]
+
 - postgresql
   - When explicit sequence execution derives the name 
     of the auto-generated sequence of a SERIAL column, 
index 8363bab5421c5604703a3a2ace759c124bc0b966..ed06e32c90d05b8e6eab3bb427bd0efb74ce4e40 100644 (file)
@@ -543,8 +543,9 @@ class SQLCompiler(engine.Compiled):
         else:
             return fn(" " + operator + " ")
 
-    def visit_bindparam(self, bindparam, within_columns_clause=False, 
+    def visit_bindparam(self, bindparam, within_columns_clause=False,
                                             literal_binds=False, **kwargs):
+
         if literal_binds or \
             (within_columns_clause and \
                 self.ansi_bind_rules):
@@ -554,6 +555,7 @@ class SQLCompiler(engine.Compiled):
             return self.render_literal_bindparam(bindparam, within_columns_clause=True, **kwargs)
 
         name = self._truncate_bindparam(bindparam)
+
         if name in self.binds:
             existing = self.binds[name]
             if existing is not bindparam:
@@ -562,7 +564,8 @@ class SQLCompiler(engine.Compiled):
                             "Bind parameter '%s' conflicts with "
                             "unique bind parameter of the same name" % bindparam.key
                         )
-                elif getattr(existing, '_is_crud', False):
+                elif getattr(existing, '_is_crud', False) or \
+                    getattr(bindparam, '_is_crud', False):
                     raise exc.CompileError(
                             "bindparam() name '%s' is reserved "
                             "for automatic usage in the VALUES or SET clause of this "
@@ -923,18 +926,8 @@ class SQLCompiler(engine.Compiled):
     def _create_crud_bind_param(self, col, value, required=False):
         bindparam = sql.bindparam(col.key, value, type_=col.type, required=required)
         bindparam._is_crud = True
-        if col.key in self.binds:
-            raise exc.CompileError(
-                    "bindparam() name '%s' is reserved "
-                    "for automatic usage in the VALUES or SET clause of this "
-                    "insert/update statement.   Please use a " 
-                    "name other than column name when using bindparam() "
-                    "with insert() or update() (for example, 'b_%s')."
-                    % (col.key, col.key)
-                )
+        return bindparam._compiler_dispatch(self)
 
-        self.binds[col.key] = bindparam
-        return self.bindparam_string(self._truncate_bindparam(bindparam))
 
     def _get_colparams(self, stmt):
         """create a set of tuples representing column/string pairs for use
index 8011aa109d960748a9a5f5c46361cb9f9ba61d65..0c6be97d78a72d06824ac039cc5812795c8f963d 100644 (file)
@@ -44,22 +44,25 @@ class VisitableType(type):
             super(VisitableType, cls).__init__(clsname, bases, clsdict)
             return
 
-        # set up an optimized visit dispatch function
-        # for use by the compiler
-        if '__visit_name__' in cls.__dict__:
-            visit_name = cls.__visit_name__
-            if isinstance(visit_name, str):
-                getter = operator.attrgetter("visit_%s" % visit_name)
-                def _compiler_dispatch(self, visitor, **kw):
-                    return getter(visitor)(self, **kw)
-            else:
-                def _compiler_dispatch(self, visitor, **kw):
-                    return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)
-
-            cls._compiler_dispatch = _compiler_dispatch
+        _generate_dispatch(cls)
 
         super(VisitableType, cls).__init__(clsname, bases, clsdict)
 
+def _generate_dispatch(cls):
+    # set up an optimized visit dispatch function
+    # for use by the compiler
+    if '__visit_name__' in cls.__dict__:
+        visit_name = cls.__visit_name__
+        if isinstance(visit_name, str):
+            getter = operator.attrgetter("visit_%s" % visit_name)
+            def _compiler_dispatch(self, visitor, **kw):
+                return getter(visitor)(self, **kw)
+        else:
+            def _compiler_dispatch(self, visitor, **kw):
+                return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)
+
+        cls._compiler_dispatch = _compiler_dispatch
+
 class Visitable(object):
     """Base class for visitable objects, applies the
     ``VisitableType`` metaclass.
index a6e02d5a86a5dd1013ea36d5a5000def9b761b9c..f6b6ec450c38bc713c7b402df23a196e0c95f881 100644 (file)
@@ -634,7 +634,9 @@ class TestBase(object):
         assert val, msg
 
 class AssertsCompiledSQL(object):
-    def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None, use_default_dialect=False):
+    def assert_compile(self, clause, result, params=None, 
+                            checkparams=None, dialect=None, 
+                            use_default_dialect=False):
         if use_default_dialect:
             dialect = default.DefaultDialect()
 
index c8ae6cf171bf7051299f6948eb3185b0370f735a..83cbcaa56e9e22fb0ddc85c737a5478555ff663d 100644 (file)
@@ -30,12 +30,12 @@ class CompileTest(TestBase, AssertsExecutionResults):
         for t in types.type_map.values():
             t._type_affinity
 
-    @profiling.function_call_count(69, {'2.4': 44, 
+    @profiling.function_call_count(73, {'2.4': 44, 
                                             '3.0':77, '3.1':77})
     def test_insert(self):
         t1.insert().compile()
 
-    @profiling.function_call_count(69, {'2.4': 45})
+    @profiling.function_call_count(73, {'2.4': 45})
     def test_update(self):
         t1.update().compile()
 
index ff6ad51f677f2f45e5cdf61122710e05fb376862..0e9b31da10327ddb0d8dd0d73dbe4a210656d12b 100644 (file)
@@ -1,10 +1,11 @@
 from sqlalchemy import *
 from sqlalchemy.types import TypeEngine
 from sqlalchemy.sql.expression import ClauseElement, ColumnClause,\
-                                    FunctionElement, Select
+                                    FunctionElement, Select,\
+                                    _BindParamClause
 from sqlalchemy.schema import DDLElement
 from sqlalchemy.ext.compiler import compiles
-from sqlalchemy.sql import table, column
+from sqlalchemy.sql import table, column, visitors
 from sqlalchemy.test import *
 
 class UserDefinedTest(TestBase, AssertsCompiledSQL):
@@ -128,36 +129,6 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL):
             if hasattr(Select, '_compiler_dispatcher'):
                 del Select._compiler_dispatcher
 
-    def test_default_on_existing(self):
-        """test that the existing compiler function remains
-        as 'default' when overriding the compilation of an
-        existing construct."""
-
-
-        t1 = table('t1', column('c1'), column('c2'))
-
-        dispatch = Select._compiler_dispatch
-        try:
-
-            @compiles(Select, 'sqlite')
-            def compile(element, compiler, **kw):
-                return "OVERRIDE"
-
-            s1 = select([t1])
-            self.assert_compile(
-                s1, "SELECT t1.c1, t1.c2 FROM t1",
-            )
-
-            from sqlalchemy.dialects.sqlite import base as sqlite
-            self.assert_compile(
-                s1, "OVERRIDE",
-                dialect=sqlite.dialect()
-            )
-        finally:
-            Select._compiler_dispatch = dispatch
-            if hasattr(Select, '_compiler_dispatcher'):
-                del Select._compiler_dispatcher
-
     def test_dialect_specific(self):
         class AddThingy(DDLElement):
             __visit_name__ = 'add_thingy'
@@ -290,3 +261,66 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL):
             'SELECT FOOsub1, sub2, FOOsubsub1',
             use_default_dialect=True
         )
+
+
+class DefaultOnExistingTest(TestBase, AssertsCompiledSQL):
+    """Test replacement of default compilation on existing constructs."""
+
+    def teardown(self):
+        for cls in (Select, _BindParamClause):
+            if hasattr(cls, '_compiler_dispatcher'):
+                visitors._generate_dispatch(cls)
+                del cls._compiler_dispatcher
+
+    def test_select(self):
+        t1 = table('t1', column('c1'), column('c2'))
+
+        @compiles(Select, 'sqlite')
+        def compile(element, compiler, **kw):
+            return "OVERRIDE"
+
+        s1 = select([t1])
+        self.assert_compile(
+            s1, "SELECT t1.c1, t1.c2 FROM t1",
+        )
+
+        from sqlalchemy.dialects.sqlite import base as sqlite
+        self.assert_compile(
+            s1, "OVERRIDE",
+            dialect=sqlite.dialect()
+        )
+
+    def test_binds_in_select(self):
+        t = table('t',
+            column('a'),
+            column('b'),
+            column('c')
+        )
+
+        @compiles(_BindParamClause)
+        def gen_bind(element, compiler, **kw):
+            return "BIND(%s)" % compiler.visit_bindparam(element, **kw)
+
+        self.assert_compile(
+            t.select().where(t.c.c == 5), 
+            "SELECT t.a, t.b, t.c FROM t WHERE t.c = BIND(:c_1)",
+            use_default_dialect=True
+        )
+
+    def test_binds_in_dml(self):
+        t = table('t',
+            column('a'),
+            column('b'),
+            column('c')
+        )
+
+        @compiles(_BindParamClause)
+        def gen_bind(element, compiler, **kw):
+            return "BIND(%s)" % compiler.visit_bindparam(element, **kw)
+
+        self.assert_compile(
+            t.insert(), 
+            "INSERT INTO t (a, b) VALUES (BIND(:a), BIND(:b))",
+            {'a':1, 'b':2},
+            use_default_dialect=True
+        )