]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed bug where the :meth:`.Operators.__and__`,
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 24 May 2014 18:35:28 +0000 (14:35 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 24 May 2014 18:42:21 +0000 (14:42 -0400)
:meth:`.Operators.__or__` and :meth:`.Operators.__invert__`
operator overload methods could not be overridden within a custom
:class:`.TypeEngine.Comparator` implementation.
fixes #3012

doc/build/changelog/changelog_09.rst
lib/sqlalchemy/sql/default_comparator.py
lib/sqlalchemy/sql/elements.py
test/sql/test_operators.py

index bb9da76af24a0c71735b7f1183df46d6ab90db60..33ba65d7107809359919f34f8ea245348a1248f5 100644 (file)
 .. changelog::
     :version: 0.9.5
 
+    .. change::
+        :tags: bug, sql
+        :tickets: 3012
+        :versions: 1.0.0
+
+        Fixed bug where the :meth:`.Operators.__and__`,
+        :meth:`.Operators.__or__` and :meth:`.Operators.__invert__`
+        operator overload methods could not be overridden within a custom
+        :class:`.TypeEngine.Comparator` implementation.
+
     .. change::
         :tags: feature, postgresql
         :tickets: 2785
index 6bc7fb5807ba348975320ae5a92945619b98075d..20f13f70fa084d3e0f8b1c20c03ca735bb188ef4 100644 (file)
@@ -13,7 +13,7 @@ from . import type_api
 from .elements import BindParameter, True_, False_, BinaryExpression, \
         Null, _const_expr, _clause_element_as_expr, \
         ClauseList, ColumnElement, TextClause, UnaryExpression, \
-        collate, _is_literal, _literal_as_text, ClauseElement
+        collate, _is_literal, _literal_as_text, ClauseElement, and_, or_
 from .selectable import SelectBase, Alias, Selectable, ScalarSelect
 
 class _DefaultColumnComparator(operators.ColumnOperators):
@@ -124,6 +124,14 @@ class _DefaultColumnComparator(operators.ColumnOperators):
 
         return BinaryExpression(left, right, op, type_=result_type)
 
+    def _conjunction_operate(self, expr, op, other, **kw):
+        if op is operators.and_:
+            return and_(expr, other)
+        elif op is operators.or_:
+            return or_(expr, other)
+        else:
+            raise NotImplementedError()
+
     def _scalar(self, expr, op, fn, **kw):
         return fn(expr)
 
@@ -190,6 +198,13 @@ class _DefaultColumnComparator(operators.ColumnOperators):
         raise NotImplementedError("Operator '%s' is not supported on "
                             "this expression" % op.__name__)
 
+    def _inv_impl(self, expr, op, **kw):
+        """See :meth:`.ColumnOperators.__inv__`."""
+        if hasattr(expr, 'negation_clause'):
+            return expr.negation_clause
+        else:
+            return expr._negate()
+
     def _neg_impl(self, expr, op, **kw):
         """See :meth:`.ColumnOperators.__neg__`."""
         return UnaryExpression(expr, operator=operators.neg)
@@ -226,6 +241,9 @@ class _DefaultColumnComparator(operators.ColumnOperators):
     # a mapping of operators with the method they use, along with
     # their negated operator for comparison operators
     operators = {
+        "and_": (_conjunction_operate,),
+        "or_": (_conjunction_operate,),
+        "inv": (_inv_impl,),
         "add": (_binary_operate,),
         "mul": (_binary_operate,),
         "sub": (_binary_operate,),
index 13cf9aa918eaac4f9d845f863c8c3b82c18c30c3..ce6056418fb1dd4ca2bfc674e574a386eda60c7e 100644 (file)
@@ -505,9 +505,21 @@ class ClauseElement(Visitable):
             return unicode(self.compile()).encode('ascii', 'backslashreplace')
 
     def __and__(self, other):
+        """'and' at the ClauseElement level.
+
+        .. deprecated:: 0.9.5 - conjunctions are intended to be
+           at the :class:`.ColumnElement`. level
+
+        """
         return and_(self, other)
 
     def __or__(self, other):
+        """'or' at the ClauseElement level.
+
+        .. deprecated:: 0.9.5 - conjunctions are intended to be
+           at the :class:`.ColumnElement`. level
+
+        """
         return or_(self, other)
 
     def __invert__(self):
@@ -516,17 +528,18 @@ class ClauseElement(Visitable):
         else:
             return self._negate()
 
-    def __bool__(self):
-        raise TypeError("Boolean value of this clause is not defined")
-
-    __nonzero__ = __bool__
-
     def _negate(self):
         return UnaryExpression(
                     self.self_group(against=operators.inv),
                     operator=operators.inv,
                     negate=None)
 
+    def __bool__(self):
+        raise TypeError("Boolean value of this clause is not defined")
+
+    __nonzero__ = __bool__
+
+
     def __repr__(self):
         friendly = getattr(self, 'description', None)
         if friendly is None:
@@ -536,8 +549,7 @@ class ClauseElement(Visitable):
                 self.__module__, self.__class__.__name__, id(self), friendly)
 
 
-
-class ColumnElement(ClauseElement, operators.ColumnOperators):
+class ColumnElement(operators.ColumnOperators, ClauseElement):
     """Represent a column-oriented SQL expression suitable for usage in the
     "columns" clause, WHERE clause etc. of a statement.
 
@@ -1503,6 +1515,8 @@ class TextClause(Executable, ClauseElement):
     def get_children(self, **kwargs):
         return list(self._bindparams.values())
 
+    def compare(self, other):
+        return isinstance(other, TextClause) and other.text == self.text
 
 class Null(ColumnElement):
     """Represent the NULL keyword in a SQL statement.
index 63475fcefd028bbeb98bbf7c31ca9b446f599053..1006dc33ab5dfa49d01e2c8c26f43df89855f7bb 100644 (file)
@@ -212,6 +212,11 @@ class CustomUnaryOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                                 operator=operators.custom_op("!!"),
                                 type_=MyInteger)
 
+                def __invert__(self):
+                    return UnaryExpression(self.expr,
+                                operator=operators.custom_op("!!!"),
+                                type_=MyInteger)
+
         return MyInteger
 
     def test_factorial(self):
@@ -235,6 +240,20 @@ class CustomUnaryOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             "!! somecol"
         )
 
+    def test_factorial_invert(self):
+        col = column('somecol', self._factorial_fixture())
+        self.assert_compile(
+            ~col,
+            "!!! somecol"
+        )
+
+    def test_double_factorial_invert(self):
+        col = column('somecol', self._factorial_fixture())
+        self.assert_compile(
+            ~(~col),
+            "!!! (!!! somecol)"
+        )
+
     def test_unary_no_ops(self):
         assert_raises_message(
             exc.CompileError,
@@ -263,6 +282,7 @@ class _CustomComparatorTests(object):
             )
         proxied = t.select().c.foo
         self._assert_add_override(proxied)
+        self._assert_and_override(proxied)
 
     def test_alias_proxy(self):
         t = Table('t', MetaData(),
@@ -270,22 +290,32 @@ class _CustomComparatorTests(object):
             )
         proxied = t.alias().c.foo
         self._assert_add_override(proxied)
+        self._assert_and_override(proxied)
 
     def test_binary_propagate(self):
         c1 = Column('foo', self._add_override_factory())
         self._assert_add_override(c1 - 6)
+        self._assert_and_override(c1 - 6)
 
     def test_reverse_binary_propagate(self):
         c1 = Column('foo', self._add_override_factory())
         self._assert_add_override(6 - c1)
+        self._assert_and_override(6 - c1)
 
     def test_binary_multi_propagate(self):
         c1 = Column('foo', self._add_override_factory())
         self._assert_add_override((c1 - 6) + 5)
+        self._assert_and_override((c1 - 6) + 5)
 
     def test_no_boolean_propagate(self):
         c1 = Column('foo', self._add_override_factory())
         self._assert_not_add_override(c1 == 56)
+        self._assert_not_and_override(c1 == 56)
+
+    def _assert_and_override(self, expr):
+        assert (expr & text("5")).compare(
+            expr.op("goofy_and")(text("5"))
+        )
 
     def _assert_add_override(self, expr):
         assert (expr + 5).compare(
@@ -297,6 +327,11 @@ class _CustomComparatorTests(object):
             expr.op("goofy")(5)
         )
 
+    def _assert_not_and_override(self, expr):
+        assert not (expr & text("5")).compare(
+            expr.op("goofy_and")(text("5"))
+        )
+
 class CustomComparatorTest(_CustomComparatorTests, fixtures.TestBase):
     def _add_override_factory(self):
 
@@ -308,6 +343,8 @@ class CustomComparatorTest(_CustomComparatorTests, fixtures.TestBase):
                 def __add__(self, other):
                     return self.expr.op("goofy")(other)
 
+                def __and__(self, other):
+                    return self.expr.op("goofy_and")(other)
 
         return MyInteger
 
@@ -325,6 +362,9 @@ class TypeDecoratorComparatorTest(_CustomComparatorTests, fixtures.TestBase):
                 def __add__(self, other):
                     return self.expr.op("goofy")(other)
 
+                def __and__(self, other):
+                    return self.expr.op("goofy_and")(other)
+
 
         return MyInteger
 
@@ -339,6 +379,9 @@ class CustomEmbeddedinTypeDecoratorTest(_CustomComparatorTests, fixtures.TestBas
                 def __add__(self, other):
                     return self.expr.op("goofy")(other)
 
+                def __and__(self, other):
+                    return self.expr.op("goofy_and")(other)
+
 
         class MyDecInteger(TypeDecorator):
             impl = MyInteger
@@ -364,6 +407,12 @@ class NewOperatorTest(_CustomComparatorTests, fixtures.TestBase):
     def _assert_not_add_override(self, expr):
         assert not hasattr(expr, "foob")
 
+    def _assert_and_override(self, expr):
+        pass
+
+    def _assert_not_and_override(self, expr):
+        pass
+
 class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = 'default'