]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- functions and operators generated by the compiler now use (almost) regular
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 20 Jun 2009 17:20:09 +0000 (17:20 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 20 Jun 2009 17:20:09 +0000 (17:20 +0000)
dispatch functions of the form "visit_<opname>" and "visit_<funcname>_fn"
to provide customed processing.  This replaces the need to copy the "functions"
and "operators" dictionaries in compiler subclasses with straightforward
visitor methods, and also allows compiler subclasses complete control over
rendering, as the full _Function or _BinaryExpression object is passed in.
- move the pool assertion to be module-level, zoomark tests keep the connection
open across tests.

17 files changed:
06CHANGES
lib/sqlalchemy/dialects/firebird/base.py
lib/sqlalchemy/dialects/maxdb/base.py
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/mysqldb.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgres/base.py
lib/sqlalchemy/dialects/postgres/pg8000.py
lib/sqlalchemy/dialects/postgres/psycopg2.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/dialects/sybase/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/test/config.py
lib/sqlalchemy/test/noseplugin.py
test/aaa_profiling/test_pool.py
test/sql/test_functions.py

index c71eb1fdba3a46eea777a3fa80f938e2c2b03d72..6e717f13b7f9d53103f8def7e00b584e3b81b2d5 100644 (file)
--- a/06CHANGES
+++ b/06CHANGES
     - Dialect.get_rowcount() has been renamed to a descriptor "rowcount", and calls 
       cursor.rowcount directly.  Dialects which need to hardwire a rowcount in for 
       certain calls should override the method to provide different behavior.
-    
+    - functions and operators generated by the compiler now use (almost) regular
+      dispatch functions of the form "visit_<opname>" and "visit_<funcname>_fn" 
+      to provide customed processing.  This replaces the need to copy the "functions" 
+      and "operators" dictionaries in compiler subclasses with straightforward
+      visitor methods, and also allows compiler subclasses complete control over 
+      rendering, as the full _Function or _BinaryExpression object is passed in.
     
 - mysql
     - all the _detect_XXX() functions now run once underneath dialect.initialize()
index 6122c5a0cf394c55850bc7f8ea7b13af658ceb7d..a77801fcfc2eede8bab2c7e740c2dbee8027c47e 100644 (file)
@@ -621,25 +621,14 @@ class FBDialect(default.DefaultDialect):
         connection.commit(True)
 
 
-def _substring(s, start, length=None):
-    "Helper function to handle Firebird 2 SUBSTRING builtin"
-
-    if length is None:
-        return "SUBSTRING(%s FROM %s)" % (s, start)
-    else:
-        return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
-
-
 class FBCompiler(sql.compiler.SQLCompiler):
     """Firebird specific idiosincrasies"""
 
-    # Firebird lacks a builtin modulo operator, but there is
-    # an equivalent function in the ib_udf library.
-    operators = sql.compiler.SQLCompiler.operators.copy()
-    operators.update({
-        sql.operators.mod : lambda x, y:"mod(%s, %s)" % (x, y)
-        })
-
+    def visit_mod(self, binary, **kw):
+        # Firebird lacks a builtin modulo operator, but there is
+        # an equivalent function in the ib_udf library.
+        return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right))
+        
     def visit_alias(self, alias, asfrom=False, **kwargs):
         # Override to not use the AS keyword which FB 1.5 does not like
         if asfrom:
@@ -647,9 +636,24 @@ class FBCompiler(sql.compiler.SQLCompiler):
         else:
             return self.process(alias.original, **kwargs)
 
-    functions = sql.compiler.SQLCompiler.functions.copy()
-    functions['substring'] = _substring
+    def visit_substring_func(self, func, **kw):
+        s = self.process(func.clauses.clauses[0])
+        start = self.process(func.clauses.clauses[1])
+        if len(func.clauses.clauses) > 2:
+            length = self.process(func.clauses.clauses[2])
+            return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
+        else:
+            return "SUBSTRING(%s FROM %s)" % (s, start)
 
+    # TODO: auto-detect this or something
+    LENGTH_FUNCTION_NAME = 'char_length'
+        
+    def visit_length_func(self, function, **kw):
+        return self.LENGTH_FUNCTION_NAME + self.function_argspec(function)
+
+    def visit_char_length_func(self, function, **kw):
+        return self.LENGTH_FUNCTION_NAME + self.function_argspec(function)
+        
     def function_argspec(self, func):
         if func.clauses:
             return self.process(func.clause_expr)
@@ -682,17 +686,6 @@ class FBCompiler(sql.compiler.SQLCompiler):
 
         return ""
 
-    LENGTH_FUNCTION_NAME = 'char_length'
-    def function_string(self, func):
-        """Substitute the ``length`` function.
-
-        On newer FB there is a ``char_length`` function, while older
-        ones need the ``strlen`` UDF.
-        """
-
-        if func.name == 'length':
-            return self.LENGTH_FUNCTION_NAME + '%(expr)s'
-        return super(FBCompiler, self).function_string(func)
 
     def _append_returning(self, text, stmt):
         returning_cols = stmt.kwargs[RETURNING_KW_NAME]
index 0d70c3df4e94fe6b8e2df8037b6ef194b8983903..e81ecccb9b4e08a48bd42a2f117708c9b46c9b57 100644 (file)
@@ -453,8 +453,6 @@ class MaxDBResultProxy(engine_base.ResultProxy):
     _process_row = MaxDBCachedColumnRow
 
 class MaxDBCompiler(compiler.SQLCompiler):
-    operators = compiler.SQLCompiler.operators.copy()
-    operators[sql_operators.mod] = lambda x, y: 'mod(%s, %s)' % (x, y)
 
     function_conversion = {
         'CURRENT_DATE': 'DATE',
@@ -469,6 +467,9 @@ class MaxDBCompiler(compiler.SQLCompiler):
         'TIMEZONE', 'TRANSACTION', 'TRUE', 'USER', 'UID', 'USERGROUP',
         'UTCDATE', 'UTCDIFF'])
 
+    def visit_mod(self, binary, **kw):
+        return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right))
+        
     def default_from(self):
         return ' FROM DUAL'
 
@@ -491,11 +492,13 @@ class MaxDBCompiler(compiler.SQLCompiler):
         else:
             return " WITH LOCK EXCLUSIVE"
 
-    def apply_function_parens(self, func):
-        if func.name.upper() in self.bare_functions:
-            return len(func.clauses) > 0
+    def function_argspec(self, fn, **kw):
+        if fn.name.upper() in self.bare_functions:
+            return ""
+        elif len(fn.clauses) > 0:
+            return compiler.SQLCompiler.function_argspec(self, fn, **kw)
         else:
-            return True
+            return ""
 
     def visit_function(self, fn, **kw):
         transform = self.function_conversion.get(fn.name.upper(), None)
index 11ef0725de63cbcf41a210775d8241677b5a8b8d..ed4355f41a571f5ec6ce107d59c2cbd88ccc9857 100644 (file)
@@ -872,21 +872,6 @@ ischema_names = {
 }
 
 class MSSQLCompiler(compiler.SQLCompiler):
-    operators = compiler.OPERATORS.copy()
-    operators.update({
-        sql_operators.concat_op: '+',
-        sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y)
-    })
-
-    functions = compiler.SQLCompiler.functions.copy()
-    functions.update (
-        {
-            sql_functions.now: 'CURRENT_TIMESTAMP',
-            sql_functions.current_date: 'GETDATE()',
-            'length': lambda x: "LEN(%s)" % x,
-            sql_functions.char_length: lambda x: "LEN(%s)" % x
-        }
-    )
 
     extract_map = compiler.SQLCompiler.extract_map.copy()
     extract_map.update ({
@@ -900,6 +885,24 @@ class MSSQLCompiler(compiler.SQLCompiler):
         super(MSSQLCompiler, self).__init__(*args, **kwargs)
         self.tablealiases = {}
 
+    def visit_now_func(self, fn, **kw):
+        return "CURRENT_TIMESTAMP"
+        
+    def visit_current_date_func(self, fn, **kw):
+        return "GETDATE()"
+        
+    def visit_length_func(self, fn, **kw):
+        return "LEN%s" % self.function_argspec(fn, **kw)
+        
+    def visit_char_length_func(self, fn, **kw):
+        return "LEN%s" % self.function_argspec(fn, **kw)
+        
+    def visit_concat_op(self, binary):
+        return "%s + %s" % (self.process(binary.left), self.process(binary.right))
+        
+    def visit_match_op(self, binary):
+        return "CONTAINS (%s, %s)" % (self.process(binary.left), self.process(binary.right))
+        
     def get_select_precolumns(self, select):
         """ MS-SQL puts TOP, it's version of LIMIT here """
         if select._distinct or select._limit:
index 47ce09a86627ee9e1ec0d6f61a057cdca39be452..a4cb70eafd5870e1653ee537b52f1c48f56af667 100644 (file)
@@ -1285,26 +1285,24 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
         return AUTOCOMMIT_RE.match(statement)
 
 class MySQLCompiler(compiler.SQLCompiler):
-    operators = util.update_copy(
-        compiler.SQLCompiler.operators,
-        {
-            sql_operators.concat_op: lambda x, y: "concat(%s, %s)" % (x, y),
-            sql_operators.match_op: lambda x, y: "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (x, y)
-        }
-    )
-
-    functions = util.update_copy(
-        compiler.SQLCompiler.functions,
-        {
-            sql_functions.random: 'rand%(expr)s',
-            "utc_timestamp":"UTC_TIMESTAMP"
-        })
 
     extract_map = compiler.SQLCompiler.extract_map.copy()
     extract_map.update ({
         'milliseconds': 'millisecond',
     })
-
+    
+    def visit_random_func(self, fn, **kw):
+        return "rand%s" % self.function_argspec(fn)
+    
+    def visit_utc_timestamp_func(self, fn, **kw):
+        return "UTC_TIMESTAMP"
+        
+    def visit_concat_op(self, binary, **kw):
+        return "concat(%s, %s)" % (self.process(binary.left), self.process(binary.right))
+        
+    def visit_match_op(self, binary, **kw):
+        return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (self.process(binary.left), self.process(binary.right))
+        
     def visit_typeclause(self, typeclause):
         type_ = typeclause.type.dialect_impl(self.dialect)
         if isinstance(type_, MSInteger):
index 11c331915c98ffa89b0fd8f388760647ddb59fec..adaea792ac83f22f75dcf7d88dcc97e1eaacf77a 100644 (file)
@@ -40,12 +40,8 @@ class MySQL_mysqldbExecutionContext(MySQLExecutionContext):
         
         
 class MySQL_mysqldbCompiler(MySQLCompiler):
-    operators = util.update_copy(
-        MySQLCompiler.operators,
-        {
-            sql_operators.mod: '%%',
-        }
-    )
+    def visit_mod(self, binary, **kw):
+        return self.process(binary.left) + " %% " + self.process(binary.right)
     
     def post_process_text(self, text):
         return text.replace('%', '%%')
index 5de31d510b7ca45ff26f2af1a40e8a4a8afeb14b..4bc664b5064cfea04731965ca49a54a5cb8da9a5 100644 (file)
@@ -248,26 +248,26 @@ class OracleCompiler(compiler.SQLCompiler):
     the use_ansi flag is False.
     """
 
-    operators = util.update_copy(
-        compiler.SQLCompiler.operators,
-        {
-            sql_operators.mod : lambda x, y:"mod(%s, %s)" % (x, y),
-            sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y)
-        }
-    )
-
-    functions = util.update_copy(
-        compiler.SQLCompiler.functions,
-        {
-            sql_functions.now : 'CURRENT_TIMESTAMP'
-        }
-    )
-
     def __init__(self, *args, **kwargs):
         super(OracleCompiler, self).__init__(*args, **kwargs)
         self.__wheres = {}
         self._quoted_bind_names = {}
 
+    def visit_mod(self, binary, **kw):
+        return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right))
+    
+    def visit_now_func(self, fn, **kw):
+        return "CURRENT_TIMESTAMP"
+        
+    def visit_match_op(self, binary, **kw):
+        return "CONTAINS (%s, %s)" % (self.process(binary.left), self.process(binary.right))
+    
+    def function_argspec(self, fn, **kw):
+        if len(fn.clauses) > 0:
+            return compiler.SQLCompiler.function_argspec(self, fn, **kw)
+        else:
+            return ""
+        
     def bindparam_string(self, name):
         if self.preparer._bindparam_requires_quotes(name):
             quoted_name = '"%s"' % name
@@ -284,9 +284,6 @@ class OracleCompiler(compiler.SQLCompiler):
 
         return " FROM DUAL"
 
-    def apply_function_parens(self, func):
-        return len(func.clauses) > 0
-
     def visit_join(self, join, **kwargs):
         if self.dialect.use_ansi:
             return compiler.SQLCompiler.visit_join(self, join, **kwargs)
index 30a1877807c2dc97ea2c3eda2c9c2b9a2f7121c8..a10baf5d43f92f8da1f5dbeee43cb2271d0ff195 100644 (file)
@@ -192,16 +192,18 @@ ischema_names = {
 
 class PGCompiler(compiler.SQLCompiler):
     
-    operators = util.update_copy(
-        compiler.SQLCompiler.operators,
-        {
-            sql_operators.mod : '%%',
-        
-            sql_operators.ilike_op: lambda x, y, escape=None: '%s ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
-            sql_operators.notilike_op: lambda x, y, escape=None: '%s NOT ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
-            sql_operators.match_op: lambda x, y: '%s @@ to_tsquery(%s)' % (x, y),
-        }
-    )
+    def visit_match_op(self, binary, **kw):
+        return "%s @@ to_tsquery(%s)" % (self.process(binary.left), self.process(binary.right))
+
+    def visit_ilike_op(self, binary, **kw):
+        escape = binary.modifiers.get("escape", None)
+        return '%s ILIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+            + (escape and ' ESCAPE \'%s\'' % escape or '')
+
+    def visit_notilike_op(self, binary, **kw):
+        escape = binary.modifiers.get("escape", None)
+        return '%s NOT ILIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+            + (escape and ' ESCAPE \'%s\'' % escape or '')
 
     def post_process_text(self, text):
         if '%%' in text:
index 47ccab3f8bffb112c71c9e2ea411b3e797b4040b..d42fd937c929c1c81eb0acc860d845d7e3408e2a 100644 (file)
@@ -22,7 +22,7 @@ from sqlalchemy.engine import default
 import decimal
 from sqlalchemy import util
 from sqlalchemy import types as sqltypes
-from sqlalchemy.dialects.postgres.base import PGDialect
+from sqlalchemy.dialects.postgres.base import PGDialect, PGCompiler
 
 class PGNumeric(sqltypes.Numeric):
     def bind_processor(self, dialect):
@@ -42,6 +42,11 @@ class PGNumeric(sqltypes.Numeric):
 class Postgres_pg8000ExecutionContext(default.DefaultExecutionContext):
     pass
 
+class Postgres_pg8000Compiler(PGCompiler):
+    def visit_mod(self, binary, **kw):
+        return self.process(binary.left) + " %% " + self.process(binary.right)
+    
+    
 class Postgres_pg8000(PGDialect):
     driver = 'pg8000'
 
@@ -52,6 +57,8 @@ class Postgres_pg8000(PGDialect):
     default_paramstyle = 'format'
     supports_sane_multi_rowcount = False
     execution_ctx_cls = Postgres_pg8000ExecutionContext
+    statement_compiler = Postgres_pg8000Compiler
+    
     colspecs = util.update_copy(
         PGDialect.colspecs,
         {
index c865c9decfbe1b8e7f1556d1b648302f238773a2..a3ffdd84bf68ccdde4187430cce1575a2b4944aa 100644 (file)
@@ -94,12 +94,8 @@ class Postgres_psycopg2ExecutionContext(default.DefaultExecutionContext):
             return base.ResultProxy(self)
 
 class Postgres_psycopg2Compiler(PGCompiler):
-    operators = util.update_copy(
-        PGCompiler.operators, 
-        {
-            sql_operators.mod : '%%',
-        }
-    )
+    def visit_mod(self, binary, **kw):
+        return self.process(binary.left) + " %% " + self.process(binary.right)
     
     def post_process_text(self, text):
         return text.replace('%', '%%')
index 83c405f69f49445058794604f6c4974a95f7890e..8a414535bfddbea1c2775a41649f0f224474231c 100644 (file)
@@ -169,14 +169,6 @@ ischema_names = {
 
 
 class SQLiteCompiler(compiler.SQLCompiler):
-    functions = compiler.SQLCompiler.functions.copy()
-    functions.update (
-        {
-            sql_functions.now: 'CURRENT_TIMESTAMP',
-            sql_functions.char_length: 'length%(expr)s'
-        }
-    )
-
     extract_map = compiler.SQLCompiler.extract_map.copy()
     extract_map.update({
         'month': '%m',
@@ -191,6 +183,12 @@ class SQLiteCompiler(compiler.SQLCompiler):
         'week': '%W'
     })
 
+    def visit_now_func(self, fn, **kw):
+        return "CURRENT_TIMESTAMP"
+    
+    def visit_char_length_func(self, fn, **kw):
+        return "length%s" % self.funtion_argspec(fn)
+        
     def visit_cast(self, cast, **kwargs):
         if self.dialect.supports_cast:
             return super(SQLiteCompiler, self).visit_cast(cast)
index 4204530ae27629cf87f4db652353b7d45a39869b..775d2b3b09faebbba851de4bab3c9e17c6f65853 100644 (file)
@@ -212,10 +212,6 @@ class SybaseExecutionContext(default.DefaultExecutionContext):
 
 
 class SybaseSQLCompiler(compiler.SQLCompiler):
-    operators = compiler.SQLCompiler.operators.copy()
-    operators.update({
-        sql_operators.mod: lambda x, y: "MOD(%s, %s)" % (x, y),
-    })
 
     extract_map = compiler.SQLCompiler.extract_map.copy()
     extract_map.update ({
@@ -224,6 +220,9 @@ class SybaseSQLCompiler(compiler.SQLCompiler):
         'milliseconds': 'millisecond'
     })
 
+    def visit_mod(self, binary, **kw):
+        return "MOD(%s, %s)" % (self.process(binary.left), self.process(binary.right))
+
     def bindparam_string(self, name):
         res = super(SybaseSQLCompiler, self).bindparam_string(name)
         if name.lower().startswith('literal'):
index 71e97262c628e515fa7db032efcbd4c006c46d90..1863d38cf3fa77a851d8e87a1eb641bea1abe495 100644 (file)
@@ -58,42 +58,43 @@ BIND_TEMPLATES = {
 
 
 OPERATORS =  {
-    operators.and_ : 'AND',
-    operators.or_ : 'OR',
-    operators.inv : 'NOT',
-    operators.add : '+',
-    operators.mul : '*',
-    operators.sub : '-',
+    # binary
+    operators.and_ : ' AND ',
+    operators.or_ : ' OR ',
+    operators.add : ' + ',
+    operators.mul : ' * ',
+    operators.sub : ' - ',
     # Py2K
-    operators.div : '/',
+    operators.div : ' / ',
     # end Py2K
-    operators.mod : '%',
-    operators.truediv : '/',
-    operators.lt : '<',
-    operators.le : '<=',
-    operators.ne : '!=',
-    operators.gt : '>',
-    operators.ge : '>=',
-    operators.eq : '=',
-    operators.distinct_op : 'DISTINCT',
-    operators.concat_op : '||',
-    operators.like_op : lambda x, y, escape=None: '%s LIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
-    operators.notlike_op : lambda x, y, escape=None: '%s NOT LIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
-    operators.ilike_op : lambda x, y, escape=None: "lower(%s) LIKE lower(%s)" % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
-    operators.notilike_op : lambda x, y, escape=None: "lower(%s) NOT LIKE lower(%s)" % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
-    operators.between_op : 'BETWEEN',
-    operators.match_op : 'MATCH',
-    operators.in_op : 'IN',
-    operators.notin_op : 'NOT IN',
+    operators.mod : ' % ',
+    operators.truediv : ' / ',
+    operators.lt : ' < ',
+    operators.le : ' <= ',
+    operators.ne : ' != ',
+    operators.gt : ' > ',
+    operators.ge : ' >= ',
+    operators.eq : ' = ',
+    operators.concat_op : ' || ',
+    operators.between_op : ' BETWEEN ',
+    operators.match_op : ' MATCH ',
+    operators.in_op : ' IN ',
+    operators.notin_op : ' NOT IN ',
     operators.comma_op : ', ',
-    operators.desc_op : 'DESC',
-    operators.asc_op : 'ASC',
-    operators.from_ : 'FROM',
-    operators.as_ : 'AS',
-    operators.exists : 'EXISTS',
-    operators.is_ : 'IS',
-    operators.isnot : 'IS NOT',
-    operators.collate : 'COLLATE',
+    operators.from_ : ' FROM ',
+    operators.as_ : ' AS ',
+    operators.is_ : ' IS ',
+    operators.isnot : ' IS NOT ',
+    operators.collate : ' COLLATE ',
+
+    # unary
+    operators.exists : 'EXISTS ',
+    operators.distinct_op : 'DISTINCT ',
+    operators.inv : 'NOT ',
+
+    # modifiers
+    operators.desc_op : ' DESC',
+    operators.asc_op : ' ASC',
 }
 
 FUNCTIONS = {
@@ -150,8 +151,6 @@ class SQLCompiler(engine.Compiled):
 
     """
 
-    operators = OPERATORS
-    functions = FUNCTIONS
     extract_map = EXTRACT_MAP
 
     # class-level defaults which can be set at the instance
@@ -270,9 +269,7 @@ class SQLCompiler(engine.Compiled):
             if result_map is not None:
                 result_map[labelname.lower()] = (label.name, (label, label.element, labelname), label.element.type)
 
-            return self.process(label.element) + " " + \
-                        self.operator_string(operators.as_) + " " + \
-                        self.preparer.format_label(label, labelname)
+            return self.process(label.element) + OPERATORS[operators.as_] + self.preparer.format_label(label, labelname)
         else:
             return self.process(label.element)
             
@@ -344,10 +341,8 @@ class SQLCompiler(engine.Compiled):
         sep = clauselist.operator
         if sep is None:
             sep = " "
-        elif sep is operators.comma_op:
-            sep = ', '
         else:
-            sep = " " + self.operator_string(clauselist.operator) + " "
+            sep = OPERATORS[clauselist.operator]
         return sep.join(s for s in (self.process(c) for c in clauselist.clauses)
                         if s is not None)
 
@@ -373,19 +368,16 @@ class SQLCompiler(engine.Compiled):
         if result_map is not None:
             result_map[func.name.lower()] = (func.name, None, func.type)
 
-        name = self.function_string(func)
-
-        if util.callable(name):
-            return name(*[self.process(x) for x in func.clauses])
+        disp = getattr(self, "visit_%s_func" % func.name.lower(), None)
+        if disp:
+            return disp(func, **kwargs)
         else:
-            return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)}
+            name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s")
+            return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func, **kwargs)}
 
     def function_argspec(self, func, **kwargs):
         return self.process(func.clause_expr, **kwargs)
 
-    def function_string(self, func):
-        return self.functions.get(func.__class__, self.functions.get(func.name, func.name + "%(expr)s"))
-
     def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs):
         entry = self.stack and self.stack[-1] or {}
         self.stack.append({'from':entry.get('from', None), 'iswrapper':True})
@@ -408,21 +400,49 @@ class SQLCompiler(engine.Compiled):
     def visit_unary(self, unary, **kwargs):
         s = self.process(unary.element)
         if unary.operator:
-            s = self.operator_string(unary.operator) + " " + s
+            s = OPERATORS[unary.operator] + s
         if unary.modifier:
-            s = s + " " + self.operator_string(unary.modifier)
+            s = s + OPERATORS[unary.modifier]
         return s
 
     def visit_binary(self, binary, **kwargs):
-        op = self.operator_string(binary.operator)
-        if util.callable(op):
-            return op(self.process(binary.left), self.process(binary.right), **binary.modifiers)
-        else:
-            return self.process(binary.left) + " " + op + " " + self.process(binary.right)
+        
+        return self._operator_dispatch(binary.operator,
+                    binary,
+                    lambda opstr: self.process(binary.left) + opstr + self.process(binary.right),
+                    **kwargs
+        )
 
-    def operator_string(self, operator):
-        return self.operators.get(operator, str(operator))
+    def visit_like_op(self, binary, **kw):
+        escape = binary.modifiers.get("escape", None)
+        return '%s LIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+            + (escape and ' ESCAPE \'%s\'' % escape or '')
 
+    def visit_notlike_op(self, binary, **kw):
+        escape = binary.modifiers.get("escape", None)
+        return '%s NOT LIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+            + (escape and ' ESCAPE \'%s\'' % escape or '')
+        
+    def visit_ilike_op(self, binary, **kw):
+        escape = binary.modifiers.get("escape", None)
+        return 'lower(%s) LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \
+            + (escape and ' ESCAPE \'%s\'' % escape or '')
+    
+    def visit_notilike_op(self, binary, **kw):
+        escape = binary.modifiers.get("escape", None)
+        return 'lower(%s) NOT LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \
+            + (escape and ' ESCAPE \'%s\'' % escape or '')
+        
+    def _operator_dispatch(self, operator, element, fn, **kw):
+        if util.callable(operator):
+            disp = getattr(self, "visit_%s" % operator.__name__, None)
+            if disp:
+                return disp(element, **kw)
+            else:
+                return fn(OPERATORS[operator])
+        else:
+            return fn(" " + operator + " ")
+        
     def visit_bindparam(self, bindparam, **kwargs):
         name = self._truncate_bindparam(bindparam)
         if name in self.binds:
index b4c2e9def46581567e881d1d0f559f5ff6f3acd2..6d60642a59e71c5787f50a36899bf529a6f5a865 100644 (file)
@@ -18,6 +18,9 @@ base_config = """
 sqlite=sqlite:///:memory:
 sqlite_file=sqlite:///querytest.db
 postgres=postgres://scott:tiger@127.0.0.1:5432/test
+pg8000=postgres+pg8000://scott:tiger@127.0.0.1:5432/test
+postgres_jython=postgres+zxjdbc://scott:tiger@127.0.0.1:5432/test
+mysql_jython=mysql+zxjdbc://scott:tiger@127.0.0.1:5432/test
 mysql=mysql://scott:tiger@127.0.0.1:3306/test
 oracle=oracle://scott:tiger@127.0.0.1:1521
 oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
index dbad80409c5124aed82ee5ce687bb7d0b76b697f..02a1b25bcc53e97564600d8082a85f7270757192 100644 (file)
@@ -149,6 +149,8 @@ class NoseSQLAlchemy(Plugin):
 
     def afterTest(self, test):
         testing.resetwarnings()
+        
+    def afterContext(self):
         testing.global_cleanup_assertions()
         
     #def handleError(self, test, err):
index 1d11dd766d82cbe648b18e030ae7c8998bba3f41..6ae3edc989ddb33e9b93dae2acfe150f186f51ac 100644 (file)
@@ -5,6 +5,9 @@ from sqlalchemy.pool import QueuePool
 
 class QueuePoolTest(TestBase, AssertsExecutionResults):
     class Connection(object):
+        def rollback(self):
+            pass
+            
         def close(self):
             pass
 
@@ -23,7 +26,7 @@ class QueuePoolTest(TestBase, AssertsExecutionResults):
         conn = pool.connect()
         conn.close()
 
-        @profiling.function_call_count(31, {'2.4': 21})
+        @profiling.function_call_count(29, {'2.4': 21})
         def go():
             conn2 = pool.connect()
             return conn2
index 68b367b3eaf242aa1dc2e486a031260ecd63ca74..0659a2fa707588da43e80203e504ad94076e48ea 100644 (file)
@@ -24,7 +24,7 @@ class CompileTest(TestBase, AssertsCompiledSQL):
             bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
             self.assert_compile(func.current_timestamp(), "CURRENT_TIMESTAMP", dialect=dialect)
             self.assert_compile(func.localtime(), "LOCALTIME", dialect=dialect)
-            if isinstance(dialect, firebird.dialect):
+            if isinstance(dialect, (firebird.dialect, maxdb.dialect, oracle.dialect)):
                 self.assert_compile(func.nosuchfunction(), "nosuchfunction", dialect=dialect)
             else:
                 self.assert_compile(func.nosuchfunction(), "nosuchfunction()", dialect=dialect)
@@ -64,7 +64,7 @@ class CompileTest(TestBase, AssertsCompiledSQL):
             ('random()', sqlite.dialect()),
             ('random()', postgres.dialect()),
             ('rand()', mysql.dialect()),
-            ('random()', oracle.dialect())
+            ('random', oracle.dialect())
         ]:
             self.assert_compile(func.random(), ret, dialect=dialect)