From: Mike Bayer Date: Sat, 20 Jun 2009 17:20:09 +0000 (+0000) Subject: - functions and operators generated by the compiler now use (almost) regular X-Git-Tag: rel_0_6_6~167 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=008c924f6e65680e3bbf16693b365cb25dbb404e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - functions and operators generated by the compiler now use (almost) regular dispatch functions of the form "visit_" and "visit__fn" to provide customed processing. This replaces the need to copy the "functions" and "operators" dictionaries in compiler subclasses with straightforward visitor methods, and also allows compiler subclasses complete control over rendering, as the full _Function or _BinaryExpression object is passed in. - move the pool assertion to be module-level, zoomark tests keep the connection open across tests. --- diff --git a/06CHANGES b/06CHANGES index c71eb1fdba..6e717f13b7 100644 --- a/06CHANGES +++ b/06CHANGES @@ -73,7 +73,12 @@ - Dialect.get_rowcount() has been renamed to a descriptor "rowcount", and calls cursor.rowcount directly. Dialects which need to hardwire a rowcount in for certain calls should override the method to provide different behavior. - + - functions and operators generated by the compiler now use (almost) regular + dispatch functions of the form "visit_" and "visit__fn" + to provide customed processing. This replaces the need to copy the "functions" + and "operators" dictionaries in compiler subclasses with straightforward + visitor methods, and also allows compiler subclasses complete control over + rendering, as the full _Function or _BinaryExpression object is passed in. - mysql - all the _detect_XXX() functions now run once underneath dialect.initialize() diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py index 6122c5a0cf..a77801fcfc 100644 --- a/lib/sqlalchemy/dialects/firebird/base.py +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -621,25 +621,14 @@ class FBDialect(default.DefaultDialect): connection.commit(True) -def _substring(s, start, length=None): - "Helper function to handle Firebird 2 SUBSTRING builtin" - - if length is None: - return "SUBSTRING(%s FROM %s)" % (s, start) - else: - return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length) - - class FBCompiler(sql.compiler.SQLCompiler): """Firebird specific idiosincrasies""" - # Firebird lacks a builtin modulo operator, but there is - # an equivalent function in the ib_udf library. - operators = sql.compiler.SQLCompiler.operators.copy() - operators.update({ - sql.operators.mod : lambda x, y:"mod(%s, %s)" % (x, y) - }) - + def visit_mod(self, binary, **kw): + # Firebird lacks a builtin modulo operator, but there is + # an equivalent function in the ib_udf library. + return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right)) + def visit_alias(self, alias, asfrom=False, **kwargs): # Override to not use the AS keyword which FB 1.5 does not like if asfrom: @@ -647,9 +636,24 @@ class FBCompiler(sql.compiler.SQLCompiler): else: return self.process(alias.original, **kwargs) - functions = sql.compiler.SQLCompiler.functions.copy() - functions['substring'] = _substring + def visit_substring_func(self, func, **kw): + s = self.process(func.clauses.clauses[0]) + start = self.process(func.clauses.clauses[1]) + if len(func.clauses.clauses) > 2: + length = self.process(func.clauses.clauses[2]) + return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length) + else: + return "SUBSTRING(%s FROM %s)" % (s, start) + # TODO: auto-detect this or something + LENGTH_FUNCTION_NAME = 'char_length' + + def visit_length_func(self, function, **kw): + return self.LENGTH_FUNCTION_NAME + self.function_argspec(function) + + def visit_char_length_func(self, function, **kw): + return self.LENGTH_FUNCTION_NAME + self.function_argspec(function) + def function_argspec(self, func): if func.clauses: return self.process(func.clause_expr) @@ -682,17 +686,6 @@ class FBCompiler(sql.compiler.SQLCompiler): return "" - LENGTH_FUNCTION_NAME = 'char_length' - def function_string(self, func): - """Substitute the ``length`` function. - - On newer FB there is a ``char_length`` function, while older - ones need the ``strlen`` UDF. - """ - - if func.name == 'length': - return self.LENGTH_FUNCTION_NAME + '%(expr)s' - return super(FBCompiler, self).function_string(func) def _append_returning(self, text, stmt): returning_cols = stmt.kwargs[RETURNING_KW_NAME] diff --git a/lib/sqlalchemy/dialects/maxdb/base.py b/lib/sqlalchemy/dialects/maxdb/base.py index 0d70c3df4e..e81ecccb9b 100644 --- a/lib/sqlalchemy/dialects/maxdb/base.py +++ b/lib/sqlalchemy/dialects/maxdb/base.py @@ -453,8 +453,6 @@ class MaxDBResultProxy(engine_base.ResultProxy): _process_row = MaxDBCachedColumnRow class MaxDBCompiler(compiler.SQLCompiler): - operators = compiler.SQLCompiler.operators.copy() - operators[sql_operators.mod] = lambda x, y: 'mod(%s, %s)' % (x, y) function_conversion = { 'CURRENT_DATE': 'DATE', @@ -469,6 +467,9 @@ class MaxDBCompiler(compiler.SQLCompiler): 'TIMEZONE', 'TRANSACTION', 'TRUE', 'USER', 'UID', 'USERGROUP', 'UTCDATE', 'UTCDIFF']) + def visit_mod(self, binary, **kw): + return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right)) + def default_from(self): return ' FROM DUAL' @@ -491,11 +492,13 @@ class MaxDBCompiler(compiler.SQLCompiler): else: return " WITH LOCK EXCLUSIVE" - def apply_function_parens(self, func): - if func.name.upper() in self.bare_functions: - return len(func.clauses) > 0 + def function_argspec(self, fn, **kw): + if fn.name.upper() in self.bare_functions: + return "" + elif len(fn.clauses) > 0: + return compiler.SQLCompiler.function_argspec(self, fn, **kw) else: - return True + return "" def visit_function(self, fn, **kw): transform = self.function_conversion.get(fn.name.upper(), None) diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 11ef0725de..ed4355f41a 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -872,21 +872,6 @@ ischema_names = { } class MSSQLCompiler(compiler.SQLCompiler): - operators = compiler.OPERATORS.copy() - operators.update({ - sql_operators.concat_op: '+', - sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y) - }) - - functions = compiler.SQLCompiler.functions.copy() - functions.update ( - { - sql_functions.now: 'CURRENT_TIMESTAMP', - sql_functions.current_date: 'GETDATE()', - 'length': lambda x: "LEN(%s)" % x, - sql_functions.char_length: lambda x: "LEN(%s)" % x - } - ) extract_map = compiler.SQLCompiler.extract_map.copy() extract_map.update ({ @@ -900,6 +885,24 @@ class MSSQLCompiler(compiler.SQLCompiler): super(MSSQLCompiler, self).__init__(*args, **kwargs) self.tablealiases = {} + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_current_date_func(self, fn, **kw): + return "GETDATE()" + + def visit_length_func(self, fn, **kw): + return "LEN%s" % self.function_argspec(fn, **kw) + + def visit_char_length_func(self, fn, **kw): + return "LEN%s" % self.function_argspec(fn, **kw) + + def visit_concat_op(self, binary): + return "%s + %s" % (self.process(binary.left), self.process(binary.right)) + + def visit_match_op(self, binary): + return "CONTAINS (%s, %s)" % (self.process(binary.left), self.process(binary.right)) + def get_select_precolumns(self, select): """ MS-SQL puts TOP, it's version of LIMIT here """ if select._distinct or select._limit: diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 47ce09a866..a4cb70eafd 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1285,26 +1285,24 @@ class MySQLExecutionContext(default.DefaultExecutionContext): return AUTOCOMMIT_RE.match(statement) class MySQLCompiler(compiler.SQLCompiler): - operators = util.update_copy( - compiler.SQLCompiler.operators, - { - sql_operators.concat_op: lambda x, y: "concat(%s, %s)" % (x, y), - sql_operators.match_op: lambda x, y: "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (x, y) - } - ) - - functions = util.update_copy( - compiler.SQLCompiler.functions, - { - sql_functions.random: 'rand%(expr)s', - "utc_timestamp":"UTC_TIMESTAMP" - }) extract_map = compiler.SQLCompiler.extract_map.copy() extract_map.update ({ 'milliseconds': 'millisecond', }) - + + def visit_random_func(self, fn, **kw): + return "rand%s" % self.function_argspec(fn) + + def visit_utc_timestamp_func(self, fn, **kw): + return "UTC_TIMESTAMP" + + def visit_concat_op(self, binary, **kw): + return "concat(%s, %s)" % (self.process(binary.left), self.process(binary.right)) + + def visit_match_op(self, binary, **kw): + return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (self.process(binary.left), self.process(binary.right)) + def visit_typeclause(self, typeclause): type_ = typeclause.type.dialect_impl(self.dialect) if isinstance(type_, MSInteger): diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index 11c331915c..adaea792ac 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -40,12 +40,8 @@ class MySQL_mysqldbExecutionContext(MySQLExecutionContext): class MySQL_mysqldbCompiler(MySQLCompiler): - operators = util.update_copy( - MySQLCompiler.operators, - { - sql_operators.mod: '%%', - } - ) + def visit_mod(self, binary, **kw): + return self.process(binary.left) + " %% " + self.process(binary.right) def post_process_text(self, text): return text.replace('%', '%%') diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 5de31d510b..4bc664b506 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -248,26 +248,26 @@ class OracleCompiler(compiler.SQLCompiler): the use_ansi flag is False. """ - operators = util.update_copy( - compiler.SQLCompiler.operators, - { - sql_operators.mod : lambda x, y:"mod(%s, %s)" % (x, y), - sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y) - } - ) - - functions = util.update_copy( - compiler.SQLCompiler.functions, - { - sql_functions.now : 'CURRENT_TIMESTAMP' - } - ) - def __init__(self, *args, **kwargs): super(OracleCompiler, self).__init__(*args, **kwargs) self.__wheres = {} self._quoted_bind_names = {} + def visit_mod(self, binary, **kw): + return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right)) + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_match_op(self, binary, **kw): + return "CONTAINS (%s, %s)" % (self.process(binary.left), self.process(binary.right)) + + def function_argspec(self, fn, **kw): + if len(fn.clauses) > 0: + return compiler.SQLCompiler.function_argspec(self, fn, **kw) + else: + return "" + def bindparam_string(self, name): if self.preparer._bindparam_requires_quotes(name): quoted_name = '"%s"' % name @@ -284,9 +284,6 @@ class OracleCompiler(compiler.SQLCompiler): return " FROM DUAL" - def apply_function_parens(self, func): - return len(func.clauses) > 0 - def visit_join(self, join, **kwargs): if self.dialect.use_ansi: return compiler.SQLCompiler.visit_join(self, join, **kwargs) diff --git a/lib/sqlalchemy/dialects/postgres/base.py b/lib/sqlalchemy/dialects/postgres/base.py index 30a1877807..a10baf5d43 100644 --- a/lib/sqlalchemy/dialects/postgres/base.py +++ b/lib/sqlalchemy/dialects/postgres/base.py @@ -192,16 +192,18 @@ ischema_names = { class PGCompiler(compiler.SQLCompiler): - operators = util.update_copy( - compiler.SQLCompiler.operators, - { - sql_operators.mod : '%%', - - sql_operators.ilike_op: lambda x, y, escape=None: '%s ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - sql_operators.notilike_op: lambda x, y, escape=None: '%s NOT ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - sql_operators.match_op: lambda x, y: '%s @@ to_tsquery(%s)' % (x, y), - } - ) + def visit_match_op(self, binary, **kw): + return "%s @@ to_tsquery(%s)" % (self.process(binary.left), self.process(binary.right)) + + def visit_ilike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return '%s ILIKE %s' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def visit_notilike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return '%s NOT ILIKE %s' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') def post_process_text(self, text): if '%%' in text: diff --git a/lib/sqlalchemy/dialects/postgres/pg8000.py b/lib/sqlalchemy/dialects/postgres/pg8000.py index 47ccab3f8b..d42fd937c9 100644 --- a/lib/sqlalchemy/dialects/postgres/pg8000.py +++ b/lib/sqlalchemy/dialects/postgres/pg8000.py @@ -22,7 +22,7 @@ from sqlalchemy.engine import default import decimal from sqlalchemy import util from sqlalchemy import types as sqltypes -from sqlalchemy.dialects.postgres.base import PGDialect +from sqlalchemy.dialects.postgres.base import PGDialect, PGCompiler class PGNumeric(sqltypes.Numeric): def bind_processor(self, dialect): @@ -42,6 +42,11 @@ class PGNumeric(sqltypes.Numeric): class Postgres_pg8000ExecutionContext(default.DefaultExecutionContext): pass +class Postgres_pg8000Compiler(PGCompiler): + def visit_mod(self, binary, **kw): + return self.process(binary.left) + " %% " + self.process(binary.right) + + class Postgres_pg8000(PGDialect): driver = 'pg8000' @@ -52,6 +57,8 @@ class Postgres_pg8000(PGDialect): default_paramstyle = 'format' supports_sane_multi_rowcount = False execution_ctx_cls = Postgres_pg8000ExecutionContext + statement_compiler = Postgres_pg8000Compiler + colspecs = util.update_copy( PGDialect.colspecs, { diff --git a/lib/sqlalchemy/dialects/postgres/psycopg2.py b/lib/sqlalchemy/dialects/postgres/psycopg2.py index c865c9decf..a3ffdd84bf 100644 --- a/lib/sqlalchemy/dialects/postgres/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgres/psycopg2.py @@ -94,12 +94,8 @@ class Postgres_psycopg2ExecutionContext(default.DefaultExecutionContext): return base.ResultProxy(self) class Postgres_psycopg2Compiler(PGCompiler): - operators = util.update_copy( - PGCompiler.operators, - { - sql_operators.mod : '%%', - } - ) + def visit_mod(self, binary, **kw): + return self.process(binary.left) + " %% " + self.process(binary.right) def post_process_text(self, text): return text.replace('%', '%%') diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 83c405f69f..8a414535bf 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -169,14 +169,6 @@ ischema_names = { class SQLiteCompiler(compiler.SQLCompiler): - functions = compiler.SQLCompiler.functions.copy() - functions.update ( - { - sql_functions.now: 'CURRENT_TIMESTAMP', - sql_functions.char_length: 'length%(expr)s' - } - ) - extract_map = compiler.SQLCompiler.extract_map.copy() extract_map.update({ 'month': '%m', @@ -191,6 +183,12 @@ class SQLiteCompiler(compiler.SQLCompiler): 'week': '%W' }) + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_char_length_func(self, fn, **kw): + return "length%s" % self.funtion_argspec(fn) + def visit_cast(self, cast, **kwargs): if self.dialect.supports_cast: return super(SQLiteCompiler, self).visit_cast(cast) diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py index 4204530ae2..775d2b3b09 100644 --- a/lib/sqlalchemy/dialects/sybase/base.py +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -212,10 +212,6 @@ class SybaseExecutionContext(default.DefaultExecutionContext): class SybaseSQLCompiler(compiler.SQLCompiler): - operators = compiler.SQLCompiler.operators.copy() - operators.update({ - sql_operators.mod: lambda x, y: "MOD(%s, %s)" % (x, y), - }) extract_map = compiler.SQLCompiler.extract_map.copy() extract_map.update ({ @@ -224,6 +220,9 @@ class SybaseSQLCompiler(compiler.SQLCompiler): 'milliseconds': 'millisecond' }) + def visit_mod(self, binary, **kw): + return "MOD(%s, %s)" % (self.process(binary.left), self.process(binary.right)) + def bindparam_string(self, name): res = super(SybaseSQLCompiler, self).bindparam_string(name) if name.lower().startswith('literal'): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 71e97262c6..1863d38cf3 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -58,42 +58,43 @@ BIND_TEMPLATES = { OPERATORS = { - operators.and_ : 'AND', - operators.or_ : 'OR', - operators.inv : 'NOT', - operators.add : '+', - operators.mul : '*', - operators.sub : '-', + # binary + operators.and_ : ' AND ', + operators.or_ : ' OR ', + operators.add : ' + ', + operators.mul : ' * ', + operators.sub : ' - ', # Py2K - operators.div : '/', + operators.div : ' / ', # end Py2K - operators.mod : '%', - operators.truediv : '/', - operators.lt : '<', - operators.le : '<=', - operators.ne : '!=', - operators.gt : '>', - operators.ge : '>=', - operators.eq : '=', - operators.distinct_op : 'DISTINCT', - operators.concat_op : '||', - operators.like_op : lambda x, y, escape=None: '%s LIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - operators.notlike_op : lambda x, y, escape=None: '%s NOT LIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - operators.ilike_op : lambda x, y, escape=None: "lower(%s) LIKE lower(%s)" % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - operators.notilike_op : lambda x, y, escape=None: "lower(%s) NOT LIKE lower(%s)" % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - operators.between_op : 'BETWEEN', - operators.match_op : 'MATCH', - operators.in_op : 'IN', - operators.notin_op : 'NOT IN', + operators.mod : ' % ', + operators.truediv : ' / ', + operators.lt : ' < ', + operators.le : ' <= ', + operators.ne : ' != ', + operators.gt : ' > ', + operators.ge : ' >= ', + operators.eq : ' = ', + operators.concat_op : ' || ', + operators.between_op : ' BETWEEN ', + operators.match_op : ' MATCH ', + operators.in_op : ' IN ', + operators.notin_op : ' NOT IN ', operators.comma_op : ', ', - operators.desc_op : 'DESC', - operators.asc_op : 'ASC', - operators.from_ : 'FROM', - operators.as_ : 'AS', - operators.exists : 'EXISTS', - operators.is_ : 'IS', - operators.isnot : 'IS NOT', - operators.collate : 'COLLATE', + operators.from_ : ' FROM ', + operators.as_ : ' AS ', + operators.is_ : ' IS ', + operators.isnot : ' IS NOT ', + operators.collate : ' COLLATE ', + + # unary + operators.exists : 'EXISTS ', + operators.distinct_op : 'DISTINCT ', + operators.inv : 'NOT ', + + # modifiers + operators.desc_op : ' DESC', + operators.asc_op : ' ASC', } FUNCTIONS = { @@ -150,8 +151,6 @@ class SQLCompiler(engine.Compiled): """ - operators = OPERATORS - functions = FUNCTIONS extract_map = EXTRACT_MAP # class-level defaults which can be set at the instance @@ -270,9 +269,7 @@ class SQLCompiler(engine.Compiled): if result_map is not None: result_map[labelname.lower()] = (label.name, (label, label.element, labelname), label.element.type) - return self.process(label.element) + " " + \ - self.operator_string(operators.as_) + " " + \ - self.preparer.format_label(label, labelname) + return self.process(label.element) + OPERATORS[operators.as_] + self.preparer.format_label(label, labelname) else: return self.process(label.element) @@ -344,10 +341,8 @@ class SQLCompiler(engine.Compiled): sep = clauselist.operator if sep is None: sep = " " - elif sep is operators.comma_op: - sep = ', ' else: - sep = " " + self.operator_string(clauselist.operator) + " " + sep = OPERATORS[clauselist.operator] return sep.join(s for s in (self.process(c) for c in clauselist.clauses) if s is not None) @@ -373,19 +368,16 @@ class SQLCompiler(engine.Compiled): if result_map is not None: result_map[func.name.lower()] = (func.name, None, func.type) - name = self.function_string(func) - - if util.callable(name): - return name(*[self.process(x) for x in func.clauses]) + disp = getattr(self, "visit_%s_func" % func.name.lower(), None) + if disp: + return disp(func, **kwargs) else: - return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)} + name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s") + return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func, **kwargs)} def function_argspec(self, func, **kwargs): return self.process(func.clause_expr, **kwargs) - def function_string(self, func): - return self.functions.get(func.__class__, self.functions.get(func.name, func.name + "%(expr)s")) - def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs): entry = self.stack and self.stack[-1] or {} self.stack.append({'from':entry.get('from', None), 'iswrapper':True}) @@ -408,21 +400,49 @@ class SQLCompiler(engine.Compiled): def visit_unary(self, unary, **kwargs): s = self.process(unary.element) if unary.operator: - s = self.operator_string(unary.operator) + " " + s + s = OPERATORS[unary.operator] + s if unary.modifier: - s = s + " " + self.operator_string(unary.modifier) + s = s + OPERATORS[unary.modifier] return s def visit_binary(self, binary, **kwargs): - op = self.operator_string(binary.operator) - if util.callable(op): - return op(self.process(binary.left), self.process(binary.right), **binary.modifiers) - else: - return self.process(binary.left) + " " + op + " " + self.process(binary.right) + + return self._operator_dispatch(binary.operator, + binary, + lambda opstr: self.process(binary.left) + opstr + self.process(binary.right), + **kwargs + ) - def operator_string(self, operator): - return self.operators.get(operator, str(operator)) + def visit_like_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return '%s LIKE %s' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + def visit_notlike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return '%s NOT LIKE %s' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def visit_ilike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return 'lower(%s) LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def visit_notilike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return 'lower(%s) NOT LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def _operator_dispatch(self, operator, element, fn, **kw): + if util.callable(operator): + disp = getattr(self, "visit_%s" % operator.__name__, None) + if disp: + return disp(element, **kw) + else: + return fn(OPERATORS[operator]) + else: + return fn(" " + operator + " ") + def visit_bindparam(self, bindparam, **kwargs): name = self._truncate_bindparam(bindparam) if name in self.binds: diff --git a/lib/sqlalchemy/test/config.py b/lib/sqlalchemy/test/config.py index b4c2e9def4..6d60642a59 100644 --- a/lib/sqlalchemy/test/config.py +++ b/lib/sqlalchemy/test/config.py @@ -18,6 +18,9 @@ base_config = """ sqlite=sqlite:///:memory: sqlite_file=sqlite:///querytest.db postgres=postgres://scott:tiger@127.0.0.1:5432/test +pg8000=postgres+pg8000://scott:tiger@127.0.0.1:5432/test +postgres_jython=postgres+zxjdbc://scott:tiger@127.0.0.1:5432/test +mysql_jython=mysql+zxjdbc://scott:tiger@127.0.0.1:5432/test mysql=mysql://scott:tiger@127.0.0.1:3306/test oracle=oracle://scott:tiger@127.0.0.1:1521 oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0 diff --git a/lib/sqlalchemy/test/noseplugin.py b/lib/sqlalchemy/test/noseplugin.py index dbad80409c..02a1b25bcc 100644 --- a/lib/sqlalchemy/test/noseplugin.py +++ b/lib/sqlalchemy/test/noseplugin.py @@ -149,6 +149,8 @@ class NoseSQLAlchemy(Plugin): def afterTest(self, test): testing.resetwarnings() + + def afterContext(self): testing.global_cleanup_assertions() #def handleError(self, test, err): diff --git a/test/aaa_profiling/test_pool.py b/test/aaa_profiling/test_pool.py index 1d11dd766d..6ae3edc989 100644 --- a/test/aaa_profiling/test_pool.py +++ b/test/aaa_profiling/test_pool.py @@ -5,6 +5,9 @@ from sqlalchemy.pool import QueuePool class QueuePoolTest(TestBase, AssertsExecutionResults): class Connection(object): + def rollback(self): + pass + def close(self): pass @@ -23,7 +26,7 @@ class QueuePoolTest(TestBase, AssertsExecutionResults): conn = pool.connect() conn.close() - @profiling.function_call_count(31, {'2.4': 21}) + @profiling.function_call_count(29, {'2.4': 21}) def go(): conn2 = pool.connect() return conn2 diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index 68b367b3ea..0659a2fa70 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -24,7 +24,7 @@ class CompileTest(TestBase, AssertsCompiledSQL): bindtemplate = BIND_TEMPLATES[dialect.paramstyle] self.assert_compile(func.current_timestamp(), "CURRENT_TIMESTAMP", dialect=dialect) self.assert_compile(func.localtime(), "LOCALTIME", dialect=dialect) - if isinstance(dialect, firebird.dialect): + if isinstance(dialect, (firebird.dialect, maxdb.dialect, oracle.dialect)): self.assert_compile(func.nosuchfunction(), "nosuchfunction", dialect=dialect) else: self.assert_compile(func.nosuchfunction(), "nosuchfunction()", dialect=dialect) @@ -64,7 +64,7 @@ class CompileTest(TestBase, AssertsCompiledSQL): ('random()', sqlite.dialect()), ('random()', postgres.dialect()), ('rand()', mysql.dialect()), - ('random()', oracle.dialect()) + ('random', oracle.dialect()) ]: self.assert_compile(func.random(), ret, dialect=dialect)