]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fix concat() operator, tests
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 14 Aug 2012 17:47:58 +0000 (13:47 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 14 Aug 2012 17:47:58 +0000 (13:47 -0400)
- [feature] Custom unary operators can now be
  used by combining operators.custom_op() with
  UnaryExpression().
- clean up the operator dispatch system and make it more consistent.
This does change the compiler contract for custom ops.

CHANGES
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/sql/test_operators.py

diff --git a/CHANGES b/CHANGES
index c898769289c7228293c01a1e2f85c88c00865712..f0c80ee8f9c310a6f4a87895d43ba0a74f3f77d0 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -313,6 +313,10 @@ underneath "0.7.xx".
     also customizable via the "precedence" argument
     on the ``op()`` method.  [ticket:2537]
 
+  - [feature] Custom unary operators can now be
+    used by combining operators.custom_op() with
+    UnaryExpression().
+
   - [changed] Most classes in expression.sql
     are no longer preceded with an underscore,
     i.e. Label, SelectBase, Generative, CompareMixin.
index cc50fbe6b78aabd9406742af211a982b78e7e428..83f6346a787cf9b6c75c4c0a67ab253f6532572f 100644 (file)
@@ -748,12 +748,12 @@ class MSSQLCompiler(compiler.SQLCompiler):
     def visit_char_length_func(self, fn, **kw):
         return "LEN%s" % self.function_argspec(fn, **kw)
 
-    def visit_concat_op(self, binary, **kw):
+    def visit_concat_op_binary(self, binary, operator, **kw):
         return "%s + %s" % \
                 (self.process(binary.left, **kw),
                 self.process(binary.right, **kw))
 
-    def visit_match_op(self, binary, **kw):
+    def visit_match_op_binary(self, binary, operator, **kw):
         return "CONTAINS (%s, %s)" % (
                                         self.process(binary.left, **kw),
                                         self.process(binary.right, **kw))
@@ -969,14 +969,14 @@ class MSSQLStrictCompiler(MSSQLCompiler):
     """
     ansi_bind_rules = True
 
-    def visit_in_op(self, binary, **kw):
+    def visit_in_op_binary(self, binary, operator, **kw):
         kw['literal_binds'] = True
         return "%s IN %s" % (
                                 self.process(binary.left, **kw),
                                 self.process(binary.right, **kw)
             )
 
-    def visit_notin_op(self, binary, **kw):
+    def visit_notin_op_binary(self, binary, operator, **kw):
         kw['literal_binds'] = True
         return "%s NOT IN %s" % (
                                 self.process(binary.left, **kw),
index d8bef2d4b38e8ae53114b101c82fab1639c2acd4..2c440605eefba18535e9ae76ce05b22069a148a4 100644 (file)
@@ -1254,11 +1254,13 @@ class MySQLCompiler(compiler.SQLCompiler):
     def visit_sysdate_func(self, fn, **kw):
         return "SYSDATE()"
 
-    def visit_concat_op(self, binary, **kw):
-        return "concat(%s, %s)" % (self.process(binary.left), self.process(binary.right))
+    def visit_concat_op_binary(self, binary, operator, **kw):
+        return "concat(%s, %s)" % (self.process(binary.left),
+                                                self.process(binary.right))
 
-    def visit_match_op(self, binary, **kw):
-        return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (self.process(binary.left), self.process(binary.right))
+    def visit_match_op_binary(self, binary, operator, **kw):
+        return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % \
+                    (self.process(binary.left), self.process(binary.right))
 
     def get_from_hint_text(self, table, text):
         return text
index 3c6ec55afc2391454b24b5907391746ac0e34327..7279f254bb8feb5a13086a79f84621b4349f53b9 100644 (file)
@@ -398,7 +398,7 @@ class OracleCompiler(compiler.SQLCompiler):
     compound_keywords = util.update_copy(
         compiler.SQLCompiler.compound_keywords,
         {
-        expression.CompoundSelect.EXCEPT : 'MINUS'
+        expression.CompoundSelect.EXCEPT: 'MINUS'
         }
     )
 
@@ -416,8 +416,9 @@ class OracleCompiler(compiler.SQLCompiler):
     def visit_char_length_func(self, fn, **kw):
         return "LENGTH" + self.function_argspec(fn, **kw)
 
-    def visit_match_op(self, binary, **kw):
-        return "CONTAINS (%s, %s)" % (self.process(binary.left), self.process(binary.right))
+    def visit_match_op_binary(self, binary, operator, **kw):
+        return "CONTAINS (%s, %s)" % (self.process(binary.left),
+                                        self.process(binary.right))
 
     def get_select_hint_text(self, byfroms):
         return " ".join(
index d779935ce4d030a29f74963772bba588cbea29c7..a9ff988e81802d6a53aecf8cd456b3cd7975f0a2 100644 (file)
@@ -619,12 +619,12 @@ ischema_names = {
 
 class PGCompiler(compiler.SQLCompiler):
 
-    def visit_match_op(self, binary, **kw):
+    def visit_match_op_binary(self, binary, operator, **kw):
         return "%s @@ to_tsquery(%s)" % (
                         self.process(binary.left),
                         self.process(binary.right))
 
-    def visit_ilike_op(self, binary, **kw):
+    def visit_ilike_op_binary(self, binary, operator, **kw):
         escape = binary.modifiers.get("escape", None)
         return '%s ILIKE %s' % \
                 (self.process(binary.left), self.process(binary.right)) \
@@ -632,7 +632,7 @@ class PGCompiler(compiler.SQLCompiler):
                         (' ESCAPE ' + self.render_literal_value(escape, None))
                         or '')
 
-    def visit_notilike_op(self, binary, **kw):
+    def visit_notilike_op_binary(self, binary, operator, **kw):
         escape = binary.modifiers.get("escape", None)
         return '%s NOT ILIKE %s' % \
                 (self.process(binary.left), self.process(binary.right)) \
index 2588de6b4ab9eb3f99df02469a314b8a9ee2830e..300cdb6b4d7aef05a0b642c6ce2c018f7fef25c7 100644 (file)
@@ -65,7 +65,7 @@ BIND_TEMPLATES = {
 }
 
 
-OPERATORS =  {
+OPERATORS = {
     # binary
     operators.and_ : ' AND ',
     operators.or_ : ' OR ',
@@ -594,7 +594,7 @@ class SQLCompiler(engine.Compiled):
                         self.limit_clause(cs) or ""
 
         if self.ctes and \
-            compound_index==1 and not entry:
+            compound_index == 1 and not entry:
             text = self._render_cte_clause() + text
 
         self.stack.pop(-1)
@@ -604,12 +604,29 @@ class SQLCompiler(engine.Compiled):
             return text
 
     def visit_unary(self, unary, **kw):
-        s = unary.element._compiler_dispatch(self, **kw)
         if unary.operator:
-            s = OPERATORS[unary.operator] + s
-        if unary.modifier:
-            s = s + OPERATORS[unary.modifier]
-        return s
+            if unary.modifier:
+                raise exc.CompileError(
+                        "Unary expression does not support operator "
+                        "and modifier simultaneously")
+            disp = getattr(self, "visit_%s_unary_operator" %
+                                    unary.operator.__name__, None)
+            if disp:
+                return disp(unary, unary.operator, **kw)
+            else:
+                return self._generate_generic_unary_operator(unary,
+                                    OPERATORS[unary.operator], **kw)
+        elif unary.modifier:
+            disp = getattr(self, "visit_%s_unary_modifier" %
+                                    unary.modifier.__name__, None)
+            if disp:
+                return disp(unary, unary.modifier, **kw)
+            else:
+                return self._generate_generic_unary_modifier(unary,
+                                    OPERATORS[unary.modifier], **kw)
+        else:
+            raise exc.CompileError(
+                            "Unary expression has no operator or modifier")
 
     def visit_binary(self, binary, **kw):
         # don't allow "? = ?" to render
@@ -618,16 +635,38 @@ class SQLCompiler(engine.Compiled):
             isinstance(binary.right, sql.BindParameter):
             kw['literal_binds'] = True
 
-        return self._operator_dispatch(binary.operator,
-                    binary,
-                    lambda opstr: binary.left._compiler_dispatch(self, **kw) +
-                                        opstr +
-                                    binary.right._compiler_dispatch(
-                                            self, **kw),
-                    **kw
-        )
+        operator = binary.operator
+        disp = getattr(self, "visit_%s_binary" % operator.__name__, None)
+        if disp:
+            return disp(binary, operator, **kw)
+        else:
+            return self._generate_generic_binary(binary,
+                                OPERATORS[operator], **kw)
+
+    def visit_custom_op_binary(self, element, operator, **kw):
+        return self._generate_generic_binary(element,
+                            " " + operator.opstring + " ", **kw)
 
-    def visit_like_op(self, binary, **kw):
+    def visit_custom_op_unary_operator(self, element, operator, **kw):
+        return self._generate_generic_unary_operator(element,
+                            operator.opstring + " ", **kw)
+
+    def visit_custom_op_unary_modifier(self, element, operator, **kw):
+        return self._generate_generic_unary_modifier(element,
+                            " " + operator.opstring, **kw)
+
+    def _generate_generic_binary(self, binary, opstring, **kw):
+        return binary.left._compiler_dispatch(self, **kw) + \
+                                        opstring + \
+                            binary.right._compiler_dispatch(self, **kw)
+
+    def _generate_generic_unary_operator(self, unary, opstring, **kw):
+        return opstring + unary.element._compiler_dispatch(self, **kw)
+
+    def _generate_generic_unary_modifier(self, unary, opstring, **kw):
+        return unary.element._compiler_dispatch(self, **kw) + opstring
+
+    def visit_like_op_binary(self, binary, operator, **kw):
         escape = binary.modifiers.get("escape", None)
         return '%s LIKE %s' % (
                             binary.left._compiler_dispatch(self, **kw),
@@ -636,7 +675,7 @@ class SQLCompiler(engine.Compiled):
                     (' ESCAPE ' + self.render_literal_value(escape, None))
                     or '')
 
-    def visit_notlike_op(self, binary, **kw):
+    def visit_notlike_op_binary(self, binary, operator, **kw):
         escape = binary.modifiers.get("escape", None)
         return '%s NOT LIKE %s' % (
                             binary.left._compiler_dispatch(self, **kw),
@@ -645,7 +684,7 @@ class SQLCompiler(engine.Compiled):
                     (' ESCAPE ' + self.render_literal_value(escape, None))
                     or '')
 
-    def visit_ilike_op(self, binary, **kw):
+    def visit_ilike_op_binary(self, binary, operator, **kw):
         escape = binary.modifiers.get("escape", None)
         return 'lower(%s) LIKE lower(%s)' % (
                             binary.left._compiler_dispatch(self, **kw),
@@ -654,7 +693,7 @@ class SQLCompiler(engine.Compiled):
                     (' ESCAPE ' + self.render_literal_value(escape, None))
                     or '')
 
-    def visit_notilike_op(self, binary, **kw):
+    def visit_notilike_op_binary(self, binary, operator, **kw):
         escape = binary.modifiers.get("escape", None)
         return 'lower(%s) NOT LIKE lower(%s)' % (
                             binary.left._compiler_dispatch(self, **kw),
@@ -663,16 +702,6 @@ class SQLCompiler(engine.Compiled):
                     (' ESCAPE ' + self.render_literal_value(escape, None))
                     or '')
 
-    def visit_custom_op(self, element, dispatch_operator, dispatch_fn, **kw):
-        return dispatch_fn(" " + dispatch_operator.opstring + " ")
-
-    def _operator_dispatch(self, operator, element, fn, **kw):
-        disp = getattr(self, "visit_%s" % operator.__name__, None)
-        if disp:
-            kw.update(dispatch_operator=operator, dispatch_fn=fn)
-            return disp(element, **kw)
-        else:
-            return fn(OPERATORS[operator])
 
     def visit_bindparam(self, bindparam, within_columns_clause=False,
                                             literal_binds=False, **kwargs):
index 6021b40b1b41e2cb06c8a14518cacf87f1186973..31e24e564c88bde8d7668b329f9b3c9c2c0a1fbb 100644 (file)
@@ -2057,6 +2057,7 @@ class _DefaultColumnComparator(ColumnOperators):
         "mod": (__operate,),
         "truediv": (__operate,),
         "custom_op": (__operate,),
+        "concat_op": (__operate,),
         "lt": (__compare, operators.ge),
         "le": (__compare, operators.gt),
         "ne": (__compare, operators.eq),
@@ -3475,7 +3476,14 @@ class Extract(ColumnElement):
 
 
 class UnaryExpression(ColumnElement):
+    """Define a 'unary' expression.
 
+    A unary expression has a single column expression
+    and an operator.  The operator can be placed on the left
+    (where it is called the 'operator') or right (where it is called the
+    'modifier') of the column expression.
+
+    """
     __visit_name__ = 'unary'
 
     def __init__(self, element, operator=None, modifier=None,
index 02acda0f11fe83c1a21a88d9dc7f4643072eb79d..b369dfc1144692d67aa221cc9fe50553ce16bb6c 100644 (file)
@@ -1,8 +1,11 @@
-from test.lib import fixtures
+from test.lib import fixtures, testing
+from test.lib.testing import assert_raises_message
 from sqlalchemy.sql import column, desc, asc, literal, collate
 from sqlalchemy.sql.expression import BinaryExpression, \
-                ClauseList, Grouping, _DefaultColumnComparator
+                ClauseList, Grouping, _DefaultColumnComparator,\
+                UnaryExpression
 from sqlalchemy.sql import operators
+from sqlalchemy import exc
 from sqlalchemy.schema import Column, Table, MetaData
 from sqlalchemy.types import Integer, TypeEngine, TypeDecorator
 
@@ -54,6 +57,65 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
             collate(left, right)
         )
 
+    def test_concat(self):
+        self._do_operate_test(operators.concat_op)
+
+class CustomUnaryOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
+    __dialect__ = 'default'
+
+    def _factorial_fixture(self):
+        class MyInteger(Integer):
+            class comparator_factory(Integer.Comparator):
+                def factorial(self):
+                    return UnaryExpression(self.expr,
+                                modifier=operators.custom_op("!"),
+                                type_=MyInteger)
+
+                def factorial_prefix(self):
+                    return UnaryExpression(self.expr,
+                                operator=operators.custom_op("!!"),
+                                type_=MyInteger)
+
+        return MyInteger
+
+    def test_factorial(self):
+        col = column('somecol', self._factorial_fixture())
+        self.assert_compile(
+            col.factorial(),
+            "somecol !"
+        )
+
+    def test_double_factorial(self):
+        col = column('somecol', self._factorial_fixture())
+        self.assert_compile(
+            col.factorial().factorial(),
+            "somecol ! !"
+        )
+
+    def test_factorial_prefix(self):
+        col = column('somecol', self._factorial_fixture())
+        self.assert_compile(
+            col.factorial_prefix(),
+            "!! somecol"
+        )
+
+    def test_unary_no_ops(self):
+        assert_raises_message(
+            exc.CompileError,
+            "Unary expression has no operator or modifier",
+            UnaryExpression(literal("x")).compile
+        )
+
+    def test_unary_both_ops(self):
+        assert_raises_message(
+            exc.CompileError,
+            "Unary expression does not support operator and "
+                "modifier simultaneously",
+            UnaryExpression(literal("x"),
+                    operator=operators.custom_op("x"),
+                    modifier=operators.custom_op("y")).compile
+        )
+
 class _CustomComparatorTests(object):
     def test_override_builtin(self):
         c1 = Column('foo', self._add_override_factory())