the extension compiles and runs on Python 2.4.
[ticket:2023]
+ - The compiler extension now supports overriding the default
+ compilation of expression._BindParamClause including that
+ the auto-generated binds within the VALUES/SET clause
+ of an insert()/update() statement will also use the new
+ compilation rules. [ticket:2042]
+
- postgresql
- When explicit sequence execution derives the name
of the auto-generated sequence of a SERIAL column,
else:
return fn(" " + operator + " ")
- def visit_bindparam(self, bindparam, within_columns_clause=False,
+ def visit_bindparam(self, bindparam, within_columns_clause=False,
literal_binds=False, **kwargs):
+
if literal_binds or \
(within_columns_clause and \
self.ansi_bind_rules):
return self.render_literal_bindparam(bindparam, within_columns_clause=True, **kwargs)
name = self._truncate_bindparam(bindparam)
+
if name in self.binds:
existing = self.binds[name]
if existing is not bindparam:
"Bind parameter '%s' conflicts with "
"unique bind parameter of the same name" % bindparam.key
)
- elif getattr(existing, '_is_crud', False):
+ elif getattr(existing, '_is_crud', False) or \
+ getattr(bindparam, '_is_crud', False):
raise exc.CompileError(
"bindparam() name '%s' is reserved "
"for automatic usage in the VALUES or SET clause of this "
def _create_crud_bind_param(self, col, value, required=False):
bindparam = sql.bindparam(col.key, value, type_=col.type, required=required)
bindparam._is_crud = True
- if col.key in self.binds:
- raise exc.CompileError(
- "bindparam() name '%s' is reserved "
- "for automatic usage in the VALUES or SET clause of this "
- "insert/update statement. Please use a "
- "name other than column name when using bindparam() "
- "with insert() or update() (for example, 'b_%s')."
- % (col.key, col.key)
- )
+ return bindparam._compiler_dispatch(self)
- self.binds[col.key] = bindparam
- return self.bindparam_string(self._truncate_bindparam(bindparam))
def _get_colparams(self, stmt):
"""create a set of tuples representing column/string pairs for use
super(VisitableType, cls).__init__(clsname, bases, clsdict)
return
- # set up an optimized visit dispatch function
- # for use by the compiler
- if '__visit_name__' in cls.__dict__:
- visit_name = cls.__visit_name__
- if isinstance(visit_name, str):
- getter = operator.attrgetter("visit_%s" % visit_name)
- def _compiler_dispatch(self, visitor, **kw):
- return getter(visitor)(self, **kw)
- else:
- def _compiler_dispatch(self, visitor, **kw):
- return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)
-
- cls._compiler_dispatch = _compiler_dispatch
+ _generate_dispatch(cls)
super(VisitableType, cls).__init__(clsname, bases, clsdict)
+def _generate_dispatch(cls):
+ # set up an optimized visit dispatch function
+ # for use by the compiler
+ if '__visit_name__' in cls.__dict__:
+ visit_name = cls.__visit_name__
+ if isinstance(visit_name, str):
+ getter = operator.attrgetter("visit_%s" % visit_name)
+ def _compiler_dispatch(self, visitor, **kw):
+ return getter(visitor)(self, **kw)
+ else:
+ def _compiler_dispatch(self, visitor, **kw):
+ return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)
+
+ cls._compiler_dispatch = _compiler_dispatch
+
class Visitable(object):
"""Base class for visitable objects, applies the
``VisitableType`` metaclass.
assert val, msg
class AssertsCompiledSQL(object):
- def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None, use_default_dialect=False):
+ def assert_compile(self, clause, result, params=None,
+ checkparams=None, dialect=None,
+ use_default_dialect=False):
if use_default_dialect:
dialect = default.DefaultDialect()
for t in types.type_map.values():
t._type_affinity
- @profiling.function_call_count(69, {'2.4': 44,
+ @profiling.function_call_count(73, {'2.4': 44,
'3.0':77, '3.1':77})
def test_insert(self):
t1.insert().compile()
- @profiling.function_call_count(69, {'2.4': 45})
+ @profiling.function_call_count(73, {'2.4': 45})
def test_update(self):
t1.update().compile()
from sqlalchemy import *
from sqlalchemy.types import TypeEngine
from sqlalchemy.sql.expression import ClauseElement, ColumnClause,\
- FunctionElement, Select
+ FunctionElement, Select,\
+ _BindParamClause
from sqlalchemy.schema import DDLElement
from sqlalchemy.ext.compiler import compiles
-from sqlalchemy.sql import table, column
+from sqlalchemy.sql import table, column, visitors
from sqlalchemy.test import *
class UserDefinedTest(TestBase, AssertsCompiledSQL):
if hasattr(Select, '_compiler_dispatcher'):
del Select._compiler_dispatcher
- def test_default_on_existing(self):
- """test that the existing compiler function remains
- as 'default' when overriding the compilation of an
- existing construct."""
-
-
- t1 = table('t1', column('c1'), column('c2'))
-
- dispatch = Select._compiler_dispatch
- try:
-
- @compiles(Select, 'sqlite')
- def compile(element, compiler, **kw):
- return "OVERRIDE"
-
- s1 = select([t1])
- self.assert_compile(
- s1, "SELECT t1.c1, t1.c2 FROM t1",
- )
-
- from sqlalchemy.dialects.sqlite import base as sqlite
- self.assert_compile(
- s1, "OVERRIDE",
- dialect=sqlite.dialect()
- )
- finally:
- Select._compiler_dispatch = dispatch
- if hasattr(Select, '_compiler_dispatcher'):
- del Select._compiler_dispatcher
-
def test_dialect_specific(self):
class AddThingy(DDLElement):
__visit_name__ = 'add_thingy'
'SELECT FOOsub1, sub2, FOOsubsub1',
use_default_dialect=True
)
+
+
+class DefaultOnExistingTest(TestBase, AssertsCompiledSQL):
+ """Test replacement of default compilation on existing constructs."""
+
+ def teardown(self):
+ for cls in (Select, _BindParamClause):
+ if hasattr(cls, '_compiler_dispatcher'):
+ visitors._generate_dispatch(cls)
+ del cls._compiler_dispatcher
+
+ def test_select(self):
+ t1 = table('t1', column('c1'), column('c2'))
+
+ @compiles(Select, 'sqlite')
+ def compile(element, compiler, **kw):
+ return "OVERRIDE"
+
+ s1 = select([t1])
+ self.assert_compile(
+ s1, "SELECT t1.c1, t1.c2 FROM t1",
+ )
+
+ from sqlalchemy.dialects.sqlite import base as sqlite
+ self.assert_compile(
+ s1, "OVERRIDE",
+ dialect=sqlite.dialect()
+ )
+
+ def test_binds_in_select(self):
+ t = table('t',
+ column('a'),
+ column('b'),
+ column('c')
+ )
+
+ @compiles(_BindParamClause)
+ def gen_bind(element, compiler, **kw):
+ return "BIND(%s)" % compiler.visit_bindparam(element, **kw)
+
+ self.assert_compile(
+ t.select().where(t.c.c == 5),
+ "SELECT t.a, t.b, t.c FROM t WHERE t.c = BIND(:c_1)",
+ use_default_dialect=True
+ )
+
+ def test_binds_in_dml(self):
+ t = table('t',
+ column('a'),
+ column('b'),
+ column('c')
+ )
+
+ @compiles(_BindParamClause)
+ def gen_bind(element, compiler, **kw):
+ return "BIND(%s)" % compiler.visit_bindparam(element, **kw)
+
+ self.assert_compile(
+ t.insert(),
+ "INSERT INTO t (a, b) VALUES (BIND(:a), BIND(:b))",
+ {'a':1, 'b':2},
+ use_default_dialect=True
+ )