- Dialect.get_rowcount() has been renamed to a descriptor "rowcount", and calls
cursor.rowcount directly. Dialects which need to hardwire a rowcount in for
certain calls should override the method to provide different behavior.
-
+ - functions and operators generated by the compiler now use (almost) regular
+ dispatch functions of the form "visit_<opname>" and "visit_<funcname>_fn"
+ to provide customed processing. This replaces the need to copy the "functions"
+ and "operators" dictionaries in compiler subclasses with straightforward
+ visitor methods, and also allows compiler subclasses complete control over
+ rendering, as the full _Function or _BinaryExpression object is passed in.
- mysql
- all the _detect_XXX() functions now run once underneath dialect.initialize()
connection.commit(True)
-def _substring(s, start, length=None):
- "Helper function to handle Firebird 2 SUBSTRING builtin"
-
- if length is None:
- return "SUBSTRING(%s FROM %s)" % (s, start)
- else:
- return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
-
-
class FBCompiler(sql.compiler.SQLCompiler):
"""Firebird specific idiosincrasies"""
- # Firebird lacks a builtin modulo operator, but there is
- # an equivalent function in the ib_udf library.
- operators = sql.compiler.SQLCompiler.operators.copy()
- operators.update({
- sql.operators.mod : lambda x, y:"mod(%s, %s)" % (x, y)
- })
-
+ def visit_mod(self, binary, **kw):
+ # Firebird lacks a builtin modulo operator, but there is
+ # an equivalent function in the ib_udf library.
+ return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right))
+
def visit_alias(self, alias, asfrom=False, **kwargs):
# Override to not use the AS keyword which FB 1.5 does not like
if asfrom:
else:
return self.process(alias.original, **kwargs)
- functions = sql.compiler.SQLCompiler.functions.copy()
- functions['substring'] = _substring
+ def visit_substring_func(self, func, **kw):
+ s = self.process(func.clauses.clauses[0])
+ start = self.process(func.clauses.clauses[1])
+ if len(func.clauses.clauses) > 2:
+ length = self.process(func.clauses.clauses[2])
+ return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
+ else:
+ return "SUBSTRING(%s FROM %s)" % (s, start)
+ # TODO: auto-detect this or something
+ LENGTH_FUNCTION_NAME = 'char_length'
+
+ def visit_length_func(self, function, **kw):
+ return self.LENGTH_FUNCTION_NAME + self.function_argspec(function)
+
+ def visit_char_length_func(self, function, **kw):
+ return self.LENGTH_FUNCTION_NAME + self.function_argspec(function)
+
def function_argspec(self, func):
if func.clauses:
return self.process(func.clause_expr)
return ""
- LENGTH_FUNCTION_NAME = 'char_length'
- def function_string(self, func):
- """Substitute the ``length`` function.
-
- On newer FB there is a ``char_length`` function, while older
- ones need the ``strlen`` UDF.
- """
-
- if func.name == 'length':
- return self.LENGTH_FUNCTION_NAME + '%(expr)s'
- return super(FBCompiler, self).function_string(func)
def _append_returning(self, text, stmt):
returning_cols = stmt.kwargs[RETURNING_KW_NAME]
_process_row = MaxDBCachedColumnRow
class MaxDBCompiler(compiler.SQLCompiler):
- operators = compiler.SQLCompiler.operators.copy()
- operators[sql_operators.mod] = lambda x, y: 'mod(%s, %s)' % (x, y)
function_conversion = {
'CURRENT_DATE': 'DATE',
'TIMEZONE', 'TRANSACTION', 'TRUE', 'USER', 'UID', 'USERGROUP',
'UTCDATE', 'UTCDIFF'])
+ def visit_mod(self, binary, **kw):
+ return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right))
+
def default_from(self):
return ' FROM DUAL'
else:
return " WITH LOCK EXCLUSIVE"
- def apply_function_parens(self, func):
- if func.name.upper() in self.bare_functions:
- return len(func.clauses) > 0
+ def function_argspec(self, fn, **kw):
+ if fn.name.upper() in self.bare_functions:
+ return ""
+ elif len(fn.clauses) > 0:
+ return compiler.SQLCompiler.function_argspec(self, fn, **kw)
else:
- return True
+ return ""
def visit_function(self, fn, **kw):
transform = self.function_conversion.get(fn.name.upper(), None)
}
class MSSQLCompiler(compiler.SQLCompiler):
- operators = compiler.OPERATORS.copy()
- operators.update({
- sql_operators.concat_op: '+',
- sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y)
- })
-
- functions = compiler.SQLCompiler.functions.copy()
- functions.update (
- {
- sql_functions.now: 'CURRENT_TIMESTAMP',
- sql_functions.current_date: 'GETDATE()',
- 'length': lambda x: "LEN(%s)" % x,
- sql_functions.char_length: lambda x: "LEN(%s)" % x
- }
- )
extract_map = compiler.SQLCompiler.extract_map.copy()
extract_map.update ({
super(MSSQLCompiler, self).__init__(*args, **kwargs)
self.tablealiases = {}
+ def visit_now_func(self, fn, **kw):
+ return "CURRENT_TIMESTAMP"
+
+ def visit_current_date_func(self, fn, **kw):
+ return "GETDATE()"
+
+ def visit_length_func(self, fn, **kw):
+ return "LEN%s" % self.function_argspec(fn, **kw)
+
+ def visit_char_length_func(self, fn, **kw):
+ return "LEN%s" % self.function_argspec(fn, **kw)
+
+ def visit_concat_op(self, binary):
+ return "%s + %s" % (self.process(binary.left), self.process(binary.right))
+
+ def visit_match_op(self, binary):
+ return "CONTAINS (%s, %s)" % (self.process(binary.left), self.process(binary.right))
+
def get_select_precolumns(self, select):
""" MS-SQL puts TOP, it's version of LIMIT here """
if select._distinct or select._limit:
return AUTOCOMMIT_RE.match(statement)
class MySQLCompiler(compiler.SQLCompiler):
- operators = util.update_copy(
- compiler.SQLCompiler.operators,
- {
- sql_operators.concat_op: lambda x, y: "concat(%s, %s)" % (x, y),
- sql_operators.match_op: lambda x, y: "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (x, y)
- }
- )
-
- functions = util.update_copy(
- compiler.SQLCompiler.functions,
- {
- sql_functions.random: 'rand%(expr)s',
- "utc_timestamp":"UTC_TIMESTAMP"
- })
extract_map = compiler.SQLCompiler.extract_map.copy()
extract_map.update ({
'milliseconds': 'millisecond',
})
-
+
+ def visit_random_func(self, fn, **kw):
+ return "rand%s" % self.function_argspec(fn)
+
+ def visit_utc_timestamp_func(self, fn, **kw):
+ return "UTC_TIMESTAMP"
+
+ def visit_concat_op(self, binary, **kw):
+ return "concat(%s, %s)" % (self.process(binary.left), self.process(binary.right))
+
+ def visit_match_op(self, binary, **kw):
+ return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (self.process(binary.left), self.process(binary.right))
+
def visit_typeclause(self, typeclause):
type_ = typeclause.type.dialect_impl(self.dialect)
if isinstance(type_, MSInteger):
class MySQL_mysqldbCompiler(MySQLCompiler):
- operators = util.update_copy(
- MySQLCompiler.operators,
- {
- sql_operators.mod: '%%',
- }
- )
+ def visit_mod(self, binary, **kw):
+ return self.process(binary.left) + " %% " + self.process(binary.right)
def post_process_text(self, text):
return text.replace('%', '%%')
the use_ansi flag is False.
"""
- operators = util.update_copy(
- compiler.SQLCompiler.operators,
- {
- sql_operators.mod : lambda x, y:"mod(%s, %s)" % (x, y),
- sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y)
- }
- )
-
- functions = util.update_copy(
- compiler.SQLCompiler.functions,
- {
- sql_functions.now : 'CURRENT_TIMESTAMP'
- }
- )
-
def __init__(self, *args, **kwargs):
super(OracleCompiler, self).__init__(*args, **kwargs)
self.__wheres = {}
self._quoted_bind_names = {}
+ def visit_mod(self, binary, **kw):
+ return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right))
+
+ def visit_now_func(self, fn, **kw):
+ return "CURRENT_TIMESTAMP"
+
+ def visit_match_op(self, binary, **kw):
+ return "CONTAINS (%s, %s)" % (self.process(binary.left), self.process(binary.right))
+
+ def function_argspec(self, fn, **kw):
+ if len(fn.clauses) > 0:
+ return compiler.SQLCompiler.function_argspec(self, fn, **kw)
+ else:
+ return ""
+
def bindparam_string(self, name):
if self.preparer._bindparam_requires_quotes(name):
quoted_name = '"%s"' % name
return " FROM DUAL"
- def apply_function_parens(self, func):
- return len(func.clauses) > 0
-
def visit_join(self, join, **kwargs):
if self.dialect.use_ansi:
return compiler.SQLCompiler.visit_join(self, join, **kwargs)
class PGCompiler(compiler.SQLCompiler):
- operators = util.update_copy(
- compiler.SQLCompiler.operators,
- {
- sql_operators.mod : '%%',
-
- sql_operators.ilike_op: lambda x, y, escape=None: '%s ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
- sql_operators.notilike_op: lambda x, y, escape=None: '%s NOT ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
- sql_operators.match_op: lambda x, y: '%s @@ to_tsquery(%s)' % (x, y),
- }
- )
+ def visit_match_op(self, binary, **kw):
+ return "%s @@ to_tsquery(%s)" % (self.process(binary.left), self.process(binary.right))
+
+ def visit_ilike_op(self, binary, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return '%s ILIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+ + (escape and ' ESCAPE \'%s\'' % escape or '')
+
+ def visit_notilike_op(self, binary, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return '%s NOT ILIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+ + (escape and ' ESCAPE \'%s\'' % escape or '')
def post_process_text(self, text):
if '%%' in text:
import decimal
from sqlalchemy import util
from sqlalchemy import types as sqltypes
-from sqlalchemy.dialects.postgres.base import PGDialect
+from sqlalchemy.dialects.postgres.base import PGDialect, PGCompiler
class PGNumeric(sqltypes.Numeric):
def bind_processor(self, dialect):
class Postgres_pg8000ExecutionContext(default.DefaultExecutionContext):
pass
+class Postgres_pg8000Compiler(PGCompiler):
+ def visit_mod(self, binary, **kw):
+ return self.process(binary.left) + " %% " + self.process(binary.right)
+
+
class Postgres_pg8000(PGDialect):
driver = 'pg8000'
default_paramstyle = 'format'
supports_sane_multi_rowcount = False
execution_ctx_cls = Postgres_pg8000ExecutionContext
+ statement_compiler = Postgres_pg8000Compiler
+
colspecs = util.update_copy(
PGDialect.colspecs,
{
return base.ResultProxy(self)
class Postgres_psycopg2Compiler(PGCompiler):
- operators = util.update_copy(
- PGCompiler.operators,
- {
- sql_operators.mod : '%%',
- }
- )
+ def visit_mod(self, binary, **kw):
+ return self.process(binary.left) + " %% " + self.process(binary.right)
def post_process_text(self, text):
return text.replace('%', '%%')
class SQLiteCompiler(compiler.SQLCompiler):
- functions = compiler.SQLCompiler.functions.copy()
- functions.update (
- {
- sql_functions.now: 'CURRENT_TIMESTAMP',
- sql_functions.char_length: 'length%(expr)s'
- }
- )
-
extract_map = compiler.SQLCompiler.extract_map.copy()
extract_map.update({
'month': '%m',
'week': '%W'
})
+ def visit_now_func(self, fn, **kw):
+ return "CURRENT_TIMESTAMP"
+
+ def visit_char_length_func(self, fn, **kw):
+ return "length%s" % self.funtion_argspec(fn)
+
def visit_cast(self, cast, **kwargs):
if self.dialect.supports_cast:
return super(SQLiteCompiler, self).visit_cast(cast)
class SybaseSQLCompiler(compiler.SQLCompiler):
- operators = compiler.SQLCompiler.operators.copy()
- operators.update({
- sql_operators.mod: lambda x, y: "MOD(%s, %s)" % (x, y),
- })
extract_map = compiler.SQLCompiler.extract_map.copy()
extract_map.update ({
'milliseconds': 'millisecond'
})
+ def visit_mod(self, binary, **kw):
+ return "MOD(%s, %s)" % (self.process(binary.left), self.process(binary.right))
+
def bindparam_string(self, name):
res = super(SybaseSQLCompiler, self).bindparam_string(name)
if name.lower().startswith('literal'):
OPERATORS = {
- operators.and_ : 'AND',
- operators.or_ : 'OR',
- operators.inv : 'NOT',
- operators.add : '+',
- operators.mul : '*',
- operators.sub : '-',
+ # binary
+ operators.and_ : ' AND ',
+ operators.or_ : ' OR ',
+ operators.add : ' + ',
+ operators.mul : ' * ',
+ operators.sub : ' - ',
# Py2K
- operators.div : '/',
+ operators.div : ' / ',
# end Py2K
- operators.mod : '%',
- operators.truediv : '/',
- operators.lt : '<',
- operators.le : '<=',
- operators.ne : '!=',
- operators.gt : '>',
- operators.ge : '>=',
- operators.eq : '=',
- operators.distinct_op : 'DISTINCT',
- operators.concat_op : '||',
- operators.like_op : lambda x, y, escape=None: '%s LIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
- operators.notlike_op : lambda x, y, escape=None: '%s NOT LIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
- operators.ilike_op : lambda x, y, escape=None: "lower(%s) LIKE lower(%s)" % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
- operators.notilike_op : lambda x, y, escape=None: "lower(%s) NOT LIKE lower(%s)" % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
- operators.between_op : 'BETWEEN',
- operators.match_op : 'MATCH',
- operators.in_op : 'IN',
- operators.notin_op : 'NOT IN',
+ operators.mod : ' % ',
+ operators.truediv : ' / ',
+ operators.lt : ' < ',
+ operators.le : ' <= ',
+ operators.ne : ' != ',
+ operators.gt : ' > ',
+ operators.ge : ' >= ',
+ operators.eq : ' = ',
+ operators.concat_op : ' || ',
+ operators.between_op : ' BETWEEN ',
+ operators.match_op : ' MATCH ',
+ operators.in_op : ' IN ',
+ operators.notin_op : ' NOT IN ',
operators.comma_op : ', ',
- operators.desc_op : 'DESC',
- operators.asc_op : 'ASC',
- operators.from_ : 'FROM',
- operators.as_ : 'AS',
- operators.exists : 'EXISTS',
- operators.is_ : 'IS',
- operators.isnot : 'IS NOT',
- operators.collate : 'COLLATE',
+ operators.from_ : ' FROM ',
+ operators.as_ : ' AS ',
+ operators.is_ : ' IS ',
+ operators.isnot : ' IS NOT ',
+ operators.collate : ' COLLATE ',
+
+ # unary
+ operators.exists : 'EXISTS ',
+ operators.distinct_op : 'DISTINCT ',
+ operators.inv : 'NOT ',
+
+ # modifiers
+ operators.desc_op : ' DESC',
+ operators.asc_op : ' ASC',
}
FUNCTIONS = {
"""
- operators = OPERATORS
- functions = FUNCTIONS
extract_map = EXTRACT_MAP
# class-level defaults which can be set at the instance
if result_map is not None:
result_map[labelname.lower()] = (label.name, (label, label.element, labelname), label.element.type)
- return self.process(label.element) + " " + \
- self.operator_string(operators.as_) + " " + \
- self.preparer.format_label(label, labelname)
+ return self.process(label.element) + OPERATORS[operators.as_] + self.preparer.format_label(label, labelname)
else:
return self.process(label.element)
sep = clauselist.operator
if sep is None:
sep = " "
- elif sep is operators.comma_op:
- sep = ', '
else:
- sep = " " + self.operator_string(clauselist.operator) + " "
+ sep = OPERATORS[clauselist.operator]
return sep.join(s for s in (self.process(c) for c in clauselist.clauses)
if s is not None)
if result_map is not None:
result_map[func.name.lower()] = (func.name, None, func.type)
- name = self.function_string(func)
-
- if util.callable(name):
- return name(*[self.process(x) for x in func.clauses])
+ disp = getattr(self, "visit_%s_func" % func.name.lower(), None)
+ if disp:
+ return disp(func, **kwargs)
else:
- return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)}
+ name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s")
+ return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func, **kwargs)}
def function_argspec(self, func, **kwargs):
return self.process(func.clause_expr, **kwargs)
- def function_string(self, func):
- return self.functions.get(func.__class__, self.functions.get(func.name, func.name + "%(expr)s"))
-
def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs):
entry = self.stack and self.stack[-1] or {}
self.stack.append({'from':entry.get('from', None), 'iswrapper':True})
def visit_unary(self, unary, **kwargs):
s = self.process(unary.element)
if unary.operator:
- s = self.operator_string(unary.operator) + " " + s
+ s = OPERATORS[unary.operator] + s
if unary.modifier:
- s = s + " " + self.operator_string(unary.modifier)
+ s = s + OPERATORS[unary.modifier]
return s
def visit_binary(self, binary, **kwargs):
- op = self.operator_string(binary.operator)
- if util.callable(op):
- return op(self.process(binary.left), self.process(binary.right), **binary.modifiers)
- else:
- return self.process(binary.left) + " " + op + " " + self.process(binary.right)
+
+ return self._operator_dispatch(binary.operator,
+ binary,
+ lambda opstr: self.process(binary.left) + opstr + self.process(binary.right),
+ **kwargs
+ )
- def operator_string(self, operator):
- return self.operators.get(operator, str(operator))
+ def visit_like_op(self, binary, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return '%s LIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+ + (escape and ' ESCAPE \'%s\'' % escape or '')
+ def visit_notlike_op(self, binary, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return '%s NOT LIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+ + (escape and ' ESCAPE \'%s\'' % escape or '')
+
+ def visit_ilike_op(self, binary, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return 'lower(%s) LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \
+ + (escape and ' ESCAPE \'%s\'' % escape or '')
+
+ def visit_notilike_op(self, binary, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return 'lower(%s) NOT LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \
+ + (escape and ' ESCAPE \'%s\'' % escape or '')
+
+ def _operator_dispatch(self, operator, element, fn, **kw):
+ if util.callable(operator):
+ disp = getattr(self, "visit_%s" % operator.__name__, None)
+ if disp:
+ return disp(element, **kw)
+ else:
+ return fn(OPERATORS[operator])
+ else:
+ return fn(" " + operator + " ")
+
def visit_bindparam(self, bindparam, **kwargs):
name = self._truncate_bindparam(bindparam)
if name in self.binds:
sqlite=sqlite:///:memory:
sqlite_file=sqlite:///querytest.db
postgres=postgres://scott:tiger@127.0.0.1:5432/test
+pg8000=postgres+pg8000://scott:tiger@127.0.0.1:5432/test
+postgres_jython=postgres+zxjdbc://scott:tiger@127.0.0.1:5432/test
+mysql_jython=mysql+zxjdbc://scott:tiger@127.0.0.1:5432/test
mysql=mysql://scott:tiger@127.0.0.1:3306/test
oracle=oracle://scott:tiger@127.0.0.1:1521
oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
def afterTest(self, test):
testing.resetwarnings()
+
+ def afterContext(self):
testing.global_cleanup_assertions()
#def handleError(self, test, err):
class QueuePoolTest(TestBase, AssertsExecutionResults):
class Connection(object):
+ def rollback(self):
+ pass
+
def close(self):
pass
conn = pool.connect()
conn.close()
- @profiling.function_call_count(31, {'2.4': 21})
+ @profiling.function_call_count(29, {'2.4': 21})
def go():
conn2 = pool.connect()
return conn2
bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
self.assert_compile(func.current_timestamp(), "CURRENT_TIMESTAMP", dialect=dialect)
self.assert_compile(func.localtime(), "LOCALTIME", dialect=dialect)
- if isinstance(dialect, firebird.dialect):
+ if isinstance(dialect, (firebird.dialect, maxdb.dialect, oracle.dialect)):
self.assert_compile(func.nosuchfunction(), "nosuchfunction", dialect=dialect)
else:
self.assert_compile(func.nosuchfunction(), "nosuchfunction()", dialect=dialect)
('random()', sqlite.dialect()),
('random()', postgres.dialect()),
('rand()', mysql.dialect()),
- ('random()', oracle.dialect())
+ ('random', oracle.dialect())
]:
self.assert_compile(func.random(), ret, dialect=dialect)