]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Reimplement .compare() in terms of a visitor
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 17 Apr 2019 17:37:39 +0000 (13:37 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 29 Apr 2019 15:54:25 +0000 (11:54 -0400)
Reworked the :meth:`.ClauseElement.compare` methods in terms of a new
visitor-based approach, and additionally added test coverage ensuring that
all :class:`.ClauseElement` subclasses can be accurately compared
against each other in terms of structure.   Structural comparison
capability is used to a small degree within the ORM currently, however
it also may form the basis for new caching features.

Fixes: #4336
Change-Id: I581b667d8e1642a6c27165cc9f4aded1c66effc6

doc/build/changelog/unreleased_13/4336.rst [new file with mode: 0644]
lib/sqlalchemy/sql/clause_compare.py [new file with mode: 0644]
lib/sqlalchemy/sql/crud.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/sql/operators.py
lib/sqlalchemy/sql/selectable.py
test/ext/test_compiler.py
test/sql/test_compare.py [new file with mode: 0644]
test/sql/test_operators.py
test/sql/test_utils.py

diff --git a/doc/build/changelog/unreleased_13/4336.rst b/doc/build/changelog/unreleased_13/4336.rst
new file mode 100644 (file)
index 0000000..8a99435
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+   :tags: bug, sql
+   :tickets: 4336
+
+   Reworked the :meth:`.ClauseElement.compare` methods in terms of a new
+   visitor-based approach, and additionally added test coverage ensuring that
+   all :class:`.ClauseElement` subclasses can be accurately compared
+   against each other in terms of structure.   Structural comparison
+   capability is used to a small degree within the ORM currently, however
+   it also may form the basis for new caching features.
diff --git a/lib/sqlalchemy/sql/clause_compare.py b/lib/sqlalchemy/sql/clause_compare.py
new file mode 100644 (file)
index 0000000..87f9fb2
--- /dev/null
@@ -0,0 +1,331 @@
+from collections import deque
+
+from . import operators
+from .. import util
+
+
+SKIP_TRAVERSE = util.symbol("skip_traverse")
+
+
+def compare(obj1, obj2, **kw):
+    if kw.get("use_proxies", False):
+        strategy = ColIdentityComparatorStrategy()
+    else:
+        strategy = StructureComparatorStrategy()
+
+    return strategy.compare(obj1, obj2, **kw)
+
+
+class StructureComparatorStrategy(object):
+    __slots__ = "compare_stack", "cache"
+
+    def __init__(self):
+        self.compare_stack = deque()
+        self.cache = set()
+
+    def compare(self, obj1, obj2, **kw):
+        stack = self.compare_stack
+        cache = self.cache
+
+        stack.append((obj1, obj2))
+
+        while stack:
+            left, right = stack.popleft()
+
+            if left is right:
+                continue
+            elif left is None or right is None:
+                # we know they are different so no match
+                return False
+            elif (left, right) in cache:
+                continue
+            cache.add((left, right))
+
+            visit_name = left.__visit_name__
+
+            # we're not exactly looking for identical types, because
+            # there are things like Column and AnnotatedColumn.  So the
+            # visit_name has to at least match up
+            if visit_name != right.__visit_name__:
+                return False
+
+            meth = getattr(self, "compare_%s" % visit_name, None)
+
+            if meth:
+                comparison = meth(left, right, **kw)
+                if comparison is False:
+                    return False
+                elif comparison is SKIP_TRAVERSE:
+                    continue
+
+            for c1, c2 in util.zip_longest(
+                left.get_children(column_collections=False),
+                right.get_children(column_collections=False),
+                fillvalue=None,
+            ):
+                if c1 is None or c2 is None:
+                    # collections are different sizes, comparison fails
+                    return False
+                stack.append((c1, c2))
+
+        return True
+
+    def compare_inner(self, obj1, obj2, **kw):
+        stack = self.compare_stack
+        try:
+            self.compare_stack = deque()
+            return self.compare(obj1, obj2, **kw)
+        finally:
+            self.compare_stack = stack
+
+    def _compare_unordered_sequences(self, seq1, seq2, **kw):
+        if seq1 is None:
+            return seq2 is None
+
+        completed = set()
+        for clause in seq1:
+            for other_clause in set(seq2).difference(completed):
+                if self.compare_inner(clause, other_clause, **kw):
+                    completed.add(other_clause)
+                    break
+        return len(completed) == len(seq1) == len(seq2)
+
+    def compare_bindparam(self, left, right, **kw):
+        # note the ".key" is often generated from id(self) so can't
+        # be compared, as far as determining structure.
+        return (
+            left.type._compare_type_affinity(right.type)
+            and left.value == right.value
+            and left.callable == right.callable
+            and left._orig_key == right._orig_key
+        )
+
+    def compare_clauselist(self, left, right, **kw):
+        if left.operator is right.operator:
+            if operators.is_associative(left.operator):
+                if self._compare_unordered_sequences(
+                    left.clauses, right.clauses
+                ):
+                    return SKIP_TRAVERSE
+                else:
+                    return False
+            else:
+                # normal ordered traversal
+                return True
+        else:
+            return False
+
+    def compare_unary(self, left, right, **kw):
+        if left.operator:
+            disp = self._get_operator_dispatch(
+                left.operator, "unary", "operator"
+            )
+            if disp is not None:
+                result = disp(left, right, left.operator, **kw)
+                if result is not True:
+                    return result
+        elif left.modifier:
+            disp = self._get_operator_dispatch(
+                left.modifier, "unary", "modifier"
+            )
+            if disp is not None:
+                result = disp(left, right, left.operator, **kw)
+                if result is not True:
+                    return result
+        return (
+            left.operator == right.operator and left.modifier == right.modifier
+        )
+
+    def compare_binary(self, left, right, **kw):
+        disp = self._get_operator_dispatch(left.operator, "binary", None)
+        if disp:
+            result = disp(left, right, left.operator, **kw)
+            if result is not True:
+                return result
+
+        if left.operator == right.operator:
+            if operators.is_commutative(left.operator):
+                if (
+                    compare(left.left, right.left, **kw)
+                    and compare(left.right, right.right, **kw)
+                ) or (
+                    compare(left.left, right.right, **kw)
+                    and compare(left.right, right.left, **kw)
+                ):
+                    return SKIP_TRAVERSE
+                else:
+                    return False
+            else:
+                return True
+        else:
+            return False
+
+    def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
+        # used by compare_binary, compare_unary
+        attrname = "visit_%s_%s%s" % (
+            operator_.__name__,
+            qualifier1,
+            "_" + qualifier2 if qualifier2 else "",
+        )
+        return getattr(self, attrname, None)
+
+    def visit_function_as_comparison_op_binary(
+        self, left, right, operator, **kw
+    ):
+        return (
+            left.left_index == right.left_index
+            and left.right_index == right.right_index
+        )
+
+    def compare_function(self, left, right, **kw):
+        return left.name == right.name
+
+    def compare_column(self, left, right, **kw):
+        if left.table is not None:
+            self.compare_stack.appendleft((left.table, right.table))
+        return (
+            left.key == right.key
+            and left.name == right.name
+            and (
+                left.type._compare_type_affinity(right.type)
+                if left.type is not None
+                else right.type is None
+            )
+            and left.is_literal == right.is_literal
+        )
+
+    def compare_collation(self, left, right, **kw):
+        return left.collation == right.collation
+
+    def compare_type_coerce(self, left, right, **kw):
+        return left.type._compare_type_affinity(right.type)
+
+    @util.dependencies("sqlalchemy.sql.elements")
+    def compare_alias(self, elements, left, right, **kw):
+        return (
+            left.name == right.name
+            if not isinstance(left.name, elements._anonymous_label)
+            else isinstance(right.name, elements._anonymous_label)
+        )
+
+    def compare_extract(self, left, right, **kw):
+        return left.field == right.field
+
+    def compare_textual_label_reference(self, left, right, **kw):
+        return left.element == right.element
+
+    def compare_slice(self, left, right, **kw):
+        return (
+            left.start == right.start
+            and left.stop == right.stop
+            and left.step == right.step
+        )
+
+    def compare_over(self, left, right, **kw):
+        return left.range_ == right.range_ and left.rows == right.rows
+
+    @util.dependencies("sqlalchemy.sql.elements")
+    def compare_label(self, elements, left, right, **kw):
+        return left._type._compare_type_affinity(right._type) and (
+            left.name == right.name
+            if not isinstance(left, elements._anonymous_label)
+            else isinstance(right.name, elements._anonymous_label)
+        )
+
+    def compare_typeclause(self, left, right, **kw):
+        return left.type._compare_type_affinity(right.type)
+
+    def compare_join(self, left, right, **kw):
+        return left.isouter == right.isouter and left.full == right.full
+
+    def compare_table(self, left, right, **kw):
+        if left.name != right.name:
+            return False
+
+        self.compare_stack.extendleft(
+            util.zip_longest(left.columns, right.columns)
+        )
+
+    def compare_compound_select(self, left, right, **kw):
+
+        if not self._compare_unordered_sequences(
+            left.selects, right.selects, **kw
+        ):
+            return False
+
+        if left.keyword != right.keyword:
+            return False
+
+        if left._for_update_arg != right._for_update_arg:
+            return False
+
+        if not self.compare_inner(
+            left._order_by_clause, right._order_by_clause, **kw
+        ):
+            return False
+
+        if not self.compare_inner(
+            left._group_by_clause, right._group_by_clause, **kw
+        ):
+            return False
+
+        return SKIP_TRAVERSE
+
+    def compare_select(self, left, right, **kw):
+        if not self._compare_unordered_sequences(
+            left._correlate, right._correlate
+        ):
+            return False
+        if not self._compare_unordered_sequences(
+            left._correlate_except, right._correlate_except
+        ):
+            return False
+
+        if not self._compare_unordered_sequences(
+            left._from_obj, right._from_obj
+        ):
+            return False
+
+        if left._for_update_arg != right._for_update_arg:
+            return False
+
+        return True
+
+    def compare_text_as_from(self, left, right, **kw):
+        self.compare_stack.extendleft(
+            util.zip_longest(left.column_args, right.column_args)
+        )
+        return left.positional == right.positional
+
+
+class ColIdentityComparatorStrategy(StructureComparatorStrategy):
+    def compare_column_element(
+        self, left, right, use_proxies=True, equivalents=(), **kw
+    ):
+        """Compare ColumnElements using proxies and equivalent collections.
+
+        This is a comparison strategy specific to the ORM.
+        """
+
+        to_compare = (right,)
+        if equivalents and right in equivalents:
+            to_compare = equivalents[right].union(to_compare)
+
+        for oth in to_compare:
+            if use_proxies and left.shares_lineage(oth):
+                return True
+            elif hash(left) == hash(right):
+                return True
+        else:
+            return False
+
+    def compare_column(self, left, right, **kw):
+        return self.compare_column_element(left, right, **kw)
+
+    def compare_label(self, left, right, **kw):
+        return self.compare_column_element(left, right, **kw)
+
+    def compare_table(self, left, right, **kw):
+        # tables compare on identity, since it's not really feasible to
+        # compare them column by column with the above rules
+        return left is right
index 6c9b8ee5bcf4fc27b7e8e6ed00a2d9b28236b109..552f61b4a068d29f7c8cbb3a0745e731e6ba426d 100644 (file)
@@ -482,6 +482,12 @@ class _multiparam_column(elements.ColumnElement):
         self.default = original.default
         self.type = original.type
 
+    def compare(self, other, **kw):
+        raise NotImplementedError()
+
+    def _copy_internals(self, other, **kw):
+        raise NotImplementedError()
+
     def __eq__(self, other):
         return (
             isinstance(other, _multiparam_column)
index b0d0feff5d9b5a29fb56d85b3e07ef7772c42111..38c7cf840e0029025657f33ace1d524caa6d6c95 100644 (file)
@@ -17,6 +17,7 @@ import numbers
 import operator
 import re
 
+from . import clause_compare
 from . import operators
 from . import type_api
 from .annotation import Annotated
@@ -341,7 +342,7 @@ class ClauseElement(Visitable):
         (see :class:`.ColumnElement`)
 
         """
-        return self is other
+        return clause_compare.compare(self, other, **kw)
 
     def _copy_internals(self, clone=_clone, **kw):
         """Reassign internal elements to be clones of themselves.
@@ -810,34 +811,6 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
         selectable._columns[key] = co
         return co
 
-    def compare(self, other, use_proxies=False, equivalents=None, **kw):
-        """Compare this ColumnElement to another.
-
-        Special arguments understood:
-
-        :param use_proxies: when True, consider two columns that
-          share a common base column as equivalent (i.e. shares_lineage())
-
-        :param equivalents: a dictionary of columns as keys mapped to sets
-          of columns. If the given "other" column is present in this
-          dictionary, if any of the columns in the corresponding set() pass
-          the comparison test, the result is True. This is used to expand the
-          comparison to other columns that may be known to be equivalent to
-          this one via foreign key or other criterion.
-
-        """
-        to_compare = (other,)
-        if equivalents and other in equivalents:
-            to_compare = equivalents[other].union(to_compare)
-
-        for oth in to_compare:
-            if use_proxies and self.shares_lineage(oth):
-                return True
-            elif hash(oth) == hash(self):
-                return True
-        else:
-            return False
-
     def cast(self, type_):
         """Produce a type cast, i.e. ``CAST(<expression> AS <type>)``.
 
@@ -1226,17 +1199,6 @@ class BindParameter(ColumnElement):
                 "%%(%d %s)s" % (id(self), self._orig_key or "param")
             )
 
-    def compare(self, other, **kw):
-        """Compare this :class:`BindParameter` to the given
-        clause."""
-
-        return (
-            isinstance(other, BindParameter)
-            and self.type._compare_type_affinity(other.type)
-            and self.value == other.value
-            and self.callable == other.callable
-        )
-
     def __getstate__(self):
         """execute a deferred value for serialization purposes."""
 
@@ -1696,9 +1658,6 @@ class TextClause(Executable, ClauseElement):
     def get_children(self, **kwargs):
         return list(self._bindparams.values())
 
-    def compare(self, other):
-        return isinstance(other, TextClause) and other.text == self.text
-
 
 class Null(ColumnElement):
     """Represent the NULL keyword in a SQL statement.
@@ -1720,9 +1679,6 @@ class Null(ColumnElement):
 
         return Null()
 
-    def compare(self, other):
-        return isinstance(other, Null)
-
 
 class False_(ColumnElement):
     """Represent the ``false`` keyword, or equivalent, in a SQL statement.
@@ -1779,9 +1735,6 @@ class False_(ColumnElement):
 
         return False_()
 
-    def compare(self, other):
-        return isinstance(other, False_)
-
 
 class True_(ColumnElement):
     """Represent the ``true`` keyword, or equivalent, in a SQL statement.
@@ -1845,9 +1798,6 @@ class True_(ColumnElement):
 
         return True_()
 
-    def compare(self, other):
-        return isinstance(other, True_)
-
 
 class ClauseList(ClauseElement):
     """Describe a list of clauses, separated by an operator.
@@ -1908,38 +1858,6 @@ class ClauseList(ClauseElement):
         else:
             return self
 
-    def compare(self, other, **kw):
-        """Compare this :class:`.ClauseList` to the given :class:`.ClauseList`,
-        including a comparison of all the clause items.
-
-        """
-        if not isinstance(other, ClauseList) and len(self.clauses) == 1:
-            return self.clauses[0].compare(other, **kw)
-        elif (
-            isinstance(other, ClauseList)
-            and len(self.clauses) == len(other.clauses)
-            and self.operator is other.operator
-        ):
-
-            if self.operator in (operators.and_, operators.or_):
-                completed = set()
-                for clause in self.clauses:
-                    for other_clause in set(other.clauses).difference(
-                        completed
-                    ):
-                        if clause.compare(other_clause, **kw):
-                            completed.add(other_clause)
-                            break
-                return len(completed) == len(other.clauses)
-            else:
-                for i in range(0, len(self.clauses)):
-                    if not self.clauses[i].compare(other.clauses[i], **kw):
-                        return False
-                else:
-                    return True
-        else:
-            return False
-
 
 class BooleanClauseList(ClauseList, ColumnElement):
     __visit_name__ = "clauselist"
@@ -2606,6 +2524,9 @@ class _label_reference(ColumnElement):
     def _copy_internals(self, clone=_clone, **kw):
         self.element = clone(self.element, **kw)
 
+    def get_children(self, **kwargs):
+        return [self.element]
+
     @property
     def _from_objects(self):
         return ()
@@ -2885,17 +2806,6 @@ class UnaryExpression(ColumnElement):
     def get_children(self, **kwargs):
         return (self.element,)
 
-    def compare(self, other, **kw):
-        """Compare this :class:`UnaryExpression` against the given
-        :class:`.ClauseElement`."""
-
-        return (
-            isinstance(other, UnaryExpression)
-            and self.operator == other.operator
-            and self.modifier == other.modifier
-            and self.element.compare(other.element, **kw)
-        )
-
     def _negate(self):
         if self.negate is not None:
             return UnaryExpression(
@@ -3103,24 +3013,6 @@ class BinaryExpression(ColumnElement):
     def get_children(self, **kwargs):
         return self.left, self.right
 
-    def compare(self, other, **kw):
-        """Compare this :class:`BinaryExpression` against the
-        given :class:`BinaryExpression`."""
-
-        return (
-            isinstance(other, BinaryExpression)
-            and self.operator == other.operator
-            and (
-                self.left.compare(other.left, **kw)
-                and self.right.compare(other.right, **kw)
-                or (
-                    operators.is_commutative(self.operator)
-                    and self.left.compare(other.right, **kw)
-                    and self.right.compare(other.left, **kw)
-                )
-            )
-        )
-
     def self_group(self, against=None):
         if operators.is_precedent(self.operator, against):
             return Grouping(self)
@@ -3213,11 +3105,6 @@ class Grouping(ColumnElement):
         self.element = state["element"]
         self.type = state["type"]
 
-    def compare(self, other, **kw):
-        return isinstance(other, Grouping) and self.element.compare(
-            other.element
-        )
-
 
 RANGE_UNBOUNDED = util.symbol("RANGE_UNBOUNDED")
 RANGE_CURRENT = util.symbol("RANGE_CURRENT")
index fcc843d9195343c02d1c45fe103e0a88132b0650..f48a20ec7f5498428d71db665b88b93d7d6c04fc 100644 (file)
@@ -74,7 +74,9 @@ class FunctionElement(Executable, ColumnElement, FromClause):
     def __init__(self, *clauses, **kwargs):
         """Construct a :class:`.FunctionElement`.
         """
-        args = [_literal_as_binds(c, self.name) for c in clauses]
+        args = [
+            _literal_as_binds(c, getattr(self, "name", None)) for c in clauses
+        ]
         self._has_args = self._has_args or bool(args)
         self.clause_expr = ClauseList(
             operator=operators.comma_op, group_contents=True, *args
@@ -376,12 +378,11 @@ class FunctionAsBinary(BinaryExpression):
         self.left_index = left_index
         self.right_index = right_index
 
-        super(FunctionAsBinary, self).__init__(
-            left,
-            right,
-            operators.function_as_comparison_op,
-            type_=sqltypes.BOOLEANTYPE,
-        )
+        self.operator = operators.function_as_comparison_op
+        self.type = sqltypes.BOOLEANTYPE
+        self.negate = None
+        self._is_implicitly_boolean = True
+        self.modifiers = {}
 
     @property
     def left(self):
@@ -399,10 +400,11 @@ class FunctionAsBinary(BinaryExpression):
     def right(self, value):
         self.sql_function.clauses.clauses[self.right_index - 1] = value
 
-    def _copy_internals(self, **kw):
-        clone = kw.pop("clone")
+    def _copy_internals(self, clone=_clone, **kw):
         self.sql_function = clone(self.sql_function, **kw)
-        super(FunctionAsBinary, self)._copy_internals(**kw)
+
+    def get_children(self, **kw):
+        yield self.sql_function
 
 
 class _FunctionGenerator(object):
@@ -682,6 +684,18 @@ class next_value(GenericFunction):
         self._bind = kw.get("bind", None)
         self.sequence = seq
 
+    def compare(self, other, **kw):
+        return (
+            isinstance(other, next_value)
+            and self.sequence.name == other.sequence.name
+        )
+
+    def get_children(self, **kwargs):
+        return []
+
+    def _copy_internals(self, **kw):
+        pass
+
     @property
     def _from_objects(self):
         return []
index 4206de4603e413b79ad753721b99a98e338a06b8..8479c1d5943f27f77d9caec3f61f186f13823c9d 100644 (file)
@@ -1414,6 +1414,11 @@ def mirror(op):
 
 _associative = _commutative.union([concat_op, and_, or_]).difference([eq, ne])
 
+
+def is_associative(op):
+    return op in _associative
+
+
 _natural_self_precedent = _associative.union(
     [getitem, json_getitem_op, json_path_getitem_op]
 )
index d4528f0c317c85e1911d344f26d7cc7fb2d43e47..796e2b2720658aab3222c6752d62322df68dfd08 100644 (file)
@@ -1994,6 +1994,9 @@ class ForUpdateArg(ClauseElement):
             and other.of is self.of
         )
 
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
     def __hash__(self):
         return id(self)
 
@@ -3941,6 +3944,12 @@ class TextAsFrom(SelectBase):
         self._reset_exported()
         self.element = clone(self.element, **kw)
 
+    def get_children(self, column_collections=True, **kw):
+        if column_collections:
+            for c in self.column_args:
+                yield c
+        yield self.element
+
     def _scalar_type(self):
         return self.column_args[0].type
 
index ac6af0a96daa16fbb08a1ddb88ae44e1ba827d5a..ccd79f8d11088a14eb3fac9994c8fd48424904d4 100644 (file)
@@ -274,6 +274,18 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
             dialect=postgresql.dialect(),
         )
 
+    def test_functions_args_noname(self):
+        class myfunc(FunctionElement):
+            pass
+
+        @compiles(myfunc)
+        def visit_myfunc(element, compiler, **kw):
+            return "myfunc%s" % (compiler.process(element.clause_expr, **kw),)
+
+        self.assert_compile(myfunc(), "myfunc()")
+
+        self.assert_compile(myfunc(column("x"), column("y")), "myfunc(x, y)")
+
     def test_function_calls_base(self):
         from sqlalchemy.dialects import mssql
 
diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py
new file mode 100644 (file)
index 0000000..8e62d5d
--- /dev/null
@@ -0,0 +1,504 @@
+import importlib
+import itertools
+
+from sqlalchemy import and_
+from sqlalchemy import Boolean
+from sqlalchemy import case
+from sqlalchemy import cast
+from sqlalchemy import Column
+from sqlalchemy import column
+from sqlalchemy import dialects
+from sqlalchemy import exists
+from sqlalchemy import extract
+from sqlalchemy import Float
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import or_
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy import table
+from sqlalchemy import text
+from sqlalchemy import tuple_
+from sqlalchemy import union
+from sqlalchemy import union_all
+from sqlalchemy import util
+from sqlalchemy.schema import Sequence
+from sqlalchemy.sql import bindparam
+from sqlalchemy.sql import ColumnElement
+from sqlalchemy.sql import False_
+from sqlalchemy.sql import func
+from sqlalchemy.sql import operators
+from sqlalchemy.sql import True_
+from sqlalchemy.sql import type_coerce
+from sqlalchemy.sql.elements import _label_reference
+from sqlalchemy.sql.elements import _textual_label_reference
+from sqlalchemy.sql.elements import Annotated
+from sqlalchemy.sql.elements import ClauseElement
+from sqlalchemy.sql.elements import ClauseList
+from sqlalchemy.sql.elements import CollationClause
+from sqlalchemy.sql.elements import Immutable
+from sqlalchemy.sql.elements import Null
+from sqlalchemy.sql.elements import Slice
+from sqlalchemy.sql.elements import UnaryExpression
+from sqlalchemy.sql.functions import FunctionElement
+from sqlalchemy.sql.functions import GenericFunction
+from sqlalchemy.sql.functions import ReturnTypeFromArgs
+from sqlalchemy.sql.selectable import _OffsetLimitParam
+from sqlalchemy.sql.selectable import FromGrouping
+from sqlalchemy.sql.selectable import Selectable
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_false
+from sqlalchemy.testing import is_true
+from sqlalchemy.util import class_hierarchy
+
+
+meta = MetaData()
+meta2 = MetaData()
+
+table_a = Table("a", meta, Column("a", Integer), Column("b", String))
+table_a_2 = Table("a", meta2, Column("a", Integer), Column("b", String))
+
+table_b = Table("b", meta, Column("a", Integer), Column("b", Integer))
+
+table_c = Table("c", meta, Column("x", Integer), Column("y", Integer))
+
+table_d = Table("d", meta, Column("y", Integer), Column("z", Integer))
+
+
+class CompareAndCopyTest(fixtures.TestBase):
+
+    # lambdas which return a tuple of ColumnElement objects.
+    # must return at least two objects that should compare differently.
+    # to test more varieties of "difference" additional objects can be added.
+    fixtures = [
+        lambda: (
+            column("q"),
+            column("x"),
+            column("q", Integer),
+            column("q", String),
+        ),
+        lambda: (~column("q", Boolean), ~column("p", Boolean)),
+        lambda: (
+            table_a.c.a.label("foo"),
+            table_a.c.a.label("bar"),
+            table_a.c.b.label("foo"),
+        ),
+        lambda: (
+            _label_reference(table_a.c.a.desc()),
+            _label_reference(table_a.c.a.asc()),
+        ),
+        lambda: (_textual_label_reference("a"), _textual_label_reference("b")),
+        lambda: (
+            text("select a, b from table").columns(a=Integer, b=String),
+            text("select a, b, c from table").columns(
+                a=Integer, b=String, c=Integer
+            ),
+        ),
+        lambda: (
+            column("q") == column("x"),
+            column("q") == column("y"),
+            column("z") == column("x"),
+        ),
+        lambda: (
+            cast(column("q"), Integer),
+            cast(column("q"), Float),
+            cast(column("p"), Integer),
+        ),
+        lambda: (
+            bindparam("x"),
+            bindparam("y"),
+            bindparam("x", type_=Integer),
+            bindparam("x", type_=String),
+            bindparam(None),
+        ),
+        lambda: (_OffsetLimitParam("x"), _OffsetLimitParam("y")),
+        lambda: (func.foo(), func.foo(5), func.bar()),
+        lambda: (func.current_date(), func.current_time()),
+        lambda: (
+            func.next_value(Sequence("q")),
+            func.next_value(Sequence("p")),
+        ),
+        lambda: (True_(), False_()),
+        lambda: (Null(),),
+        lambda: (ReturnTypeFromArgs("foo"), ReturnTypeFromArgs(5)),
+        lambda: (FunctionElement(5), FunctionElement(5, 6)),
+        lambda: (func.count(), func.not_count()),
+        lambda: (func.char_length("abc"), func.char_length("def")),
+        lambda: (GenericFunction("a", "b"), GenericFunction("a")),
+        lambda: (CollationClause("foobar"), CollationClause("batbar")),
+        lambda: (
+            type_coerce(column("q", Integer), String),
+            type_coerce(column("q", Integer), Float),
+            type_coerce(column("z", Integer), Float),
+        ),
+        lambda: (table_a.c.a, table_b.c.a),
+        lambda: (tuple_([1, 2]), tuple_([3, 4])),
+        lambda: (func.array_agg([1, 2]), func.array_agg([3, 4])),
+        lambda: (
+            func.percentile_cont(0.5).within_group(table_a.c.a),
+            func.percentile_cont(0.5).within_group(table_a.c.b),
+            func.percentile_cont(0.5).within_group(table_a.c.a, table_a.c.b),
+            func.percentile_cont(0.5).within_group(
+                table_a.c.a, table_a.c.b, column("q")
+            ),
+        ),
+        lambda: (
+            func.is_equal("a", "b").as_comparison(1, 2),
+            func.is_equal("a", "c").as_comparison(1, 2),
+            func.is_equal("a", "b").as_comparison(2, 1),
+            func.is_equal("a", "b", "c").as_comparison(1, 2),
+            func.foobar("a", "b").as_comparison(1, 2),
+        ),
+        lambda: (
+            func.row_number().over(order_by=table_a.c.a),
+            func.row_number().over(order_by=table_a.c.a, range_=(0, 10)),
+            func.row_number().over(order_by=table_a.c.a, range_=(None, 10)),
+            func.row_number().over(order_by=table_a.c.a, rows=(None, 20)),
+            func.row_number().over(order_by=table_a.c.b),
+            func.row_number().over(
+                order_by=table_a.c.a, partition_by=table_a.c.b
+            ),
+        ),
+        lambda: (
+            func.count(1).filter(table_a.c.a == 5),
+            func.count(1).filter(table_a.c.a == 10),
+            func.foob(1).filter(table_a.c.a == 10),
+        ),
+        lambda: (
+            and_(table_a.c.a == 5, table_a.c.b == table_b.c.a),
+            and_(table_a.c.a == 5, table_a.c.a == table_b.c.a),
+            or_(table_a.c.a == 5, table_a.c.b == table_b.c.a),
+            ClauseList(table_a.c.a == 5, table_a.c.b == table_b.c.a),
+            ClauseList(table_a.c.a == 5, table_a.c.b == table_a.c.a),
+        ),
+        lambda: (
+            case(whens=[(table_a.c.a == 5, 10), (table_a.c.a == 10, 20)]),
+            case(whens=[(table_a.c.a == 18, 10), (table_a.c.a == 10, 20)]),
+            case(whens=[(table_a.c.a == 5, 10), (table_a.c.b == 10, 20)]),
+            case(
+                whens=[
+                    (table_a.c.a == 5, 10),
+                    (table_a.c.b == 10, 20),
+                    (table_a.c.a == 9, 12),
+                ]
+            ),
+            case(
+                whens=[(table_a.c.a == 5, 10), (table_a.c.a == 10, 20)],
+                else_=30,
+            ),
+            case({"wendy": "W", "jack": "J"}, value=table_a.c.a, else_="E"),
+            case({"wendy": "W", "jack": "J"}, value=table_a.c.b, else_="E"),
+            case({"wendy_w": "W", "jack": "J"}, value=table_a.c.a, else_="E"),
+        ),
+        lambda: (
+            extract("foo", table_a.c.a),
+            extract("foo", table_a.c.b),
+            extract("bar", table_a.c.a),
+        ),
+        lambda: (
+            Slice(1, 2, 5),
+            Slice(1, 5, 5),
+            Slice(1, 5, 10),
+            Slice(2, 10, 15),
+        ),
+        lambda: (
+            select([table_a.c.a]),
+            select([table_a.c.a, table_a.c.b]),
+            select([table_a.c.b, table_a.c.a]),
+            select([table_a.c.a]).where(table_a.c.b == 5),
+            select([table_a.c.a])
+            .where(table_a.c.b == 5)
+            .where(table_a.c.a == 10),
+            select([table_a.c.a]).where(table_a.c.b == 5).with_for_update(),
+            select([table_a.c.a])
+            .where(table_a.c.b == 5)
+            .with_for_update(nowait=True),
+            select([table_a.c.a]).where(table_a.c.b == 5).correlate(table_b),
+            select([table_a.c.a])
+            .where(table_a.c.b == 5)
+            .correlate_except(table_b),
+        ),
+        lambda: (
+            table_a.join(table_b, table_a.c.a == table_b.c.a),
+            table_a.join(
+                table_b, and_(table_a.c.a == table_b.c.a, table_a.c.b == 1)
+            ),
+            table_a.outerjoin(table_b, table_a.c.a == table_b.c.a),
+        ),
+        lambda: (
+            table_a.alias("a"),
+            table_a.alias("b"),
+            table_a.alias(),
+            table_b.alias("a"),
+            select([table_a.c.a]).alias("a"),
+        ),
+        lambda: (
+            FromGrouping(table_a.alias("a")),
+            FromGrouping(table_a.alias("b")),
+        ),
+        lambda: (
+            select([table_a.c.a]).as_scalar(),
+            select([table_a.c.a]).where(table_a.c.b == 5).as_scalar(),
+        ),
+        lambda: (
+            exists().where(table_a.c.a == 5),
+            exists().where(table_a.c.b == 5),
+        ),
+        lambda: (
+            union(select([table_a.c.a]), select([table_a.c.b])),
+            union(select([table_a.c.a]), select([table_a.c.b])).order_by("a"),
+            union_all(select([table_a.c.a]), select([table_a.c.b])),
+            union(select([table_a.c.a])),
+            union(
+                select([table_a.c.a]),
+                select([table_a.c.b]).where(table_a.c.b > 5),
+            ),
+        ),
+        lambda: (
+            table("a", column("x"), column("y")),
+            table("a", column("y"), column("x")),
+            table("b", column("x"), column("y")),
+            table("a", column("x"), column("y"), column("z")),
+            table("a", column("x"), column("y", Integer)),
+            table("a", column("q"), column("y", Integer)),
+        ),
+        lambda: (
+            Table("a", MetaData(), Column("q", Integer), Column("b", String)),
+            Table("b", MetaData(), Column("q", Integer), Column("b", String)),
+        ),
+    ]
+
+    @classmethod
+    def setup_class(cls):
+        # TODO: we need to get dialects here somehow, perhaps in test_suite?
+        [
+            importlib.import_module("sqlalchemy.dialects.%s" % d)
+            for d in dialects.__all__
+            if not d.startswith("_")
+        ]
+
+    def test_all_present(self):
+        need = set(
+            cls
+            for cls in class_hierarchy(ClauseElement)
+            if issubclass(cls, (ColumnElement, Selectable))
+            and "__init__" in cls.__dict__
+            and not issubclass(cls, (Annotated))
+            and "orm" not in cls.__module__
+            and "crud" not in cls.__module__
+            and "dialects" not in cls.__module__  # TODO: dialects?
+        ).difference({ColumnElement, UnaryExpression})
+        for fixture in self.fixtures:
+            case_a = fixture()
+            for elem in case_a:
+                for mro in type(elem).__mro__:
+                    need.discard(mro)
+
+        is_false(bool(need), "%d Remaining classes: %r" % (len(need), need))
+
+    def test_compare(self):
+        for fixture in self.fixtures:
+            case_a = fixture()
+            case_b = fixture()
+
+            for a, b in itertools.combinations_with_replacement(
+                range(len(case_a)), 2
+            ):
+                if a == b:
+                    is_true(
+                        case_a[a].compare(
+                            case_b[b], arbitrary_expression=True
+                        ),
+                        "%r != %r" % (case_a[a], case_b[b]),
+                    )
+
+                else:
+                    is_false(
+                        case_a[a].compare(
+                            case_b[b], arbitrary_expression=True
+                        ),
+                        "%r == %r" % (case_a[a], case_b[b]),
+                    )
+
+    def test_compare_col_identity(self):
+        stmt1 = (
+            select([table_a.c.a, table_b.c.b])
+            .where(table_a.c.a == table_b.c.b)
+            .alias()
+        )
+        stmt1_c = (
+            select([table_a.c.a, table_b.c.b])
+            .where(table_a.c.a == table_b.c.b)
+            .alias()
+        )
+
+        stmt2 = union(select([table_a]), select([table_b]))
+
+        stmt3 = select([table_b])
+
+        equivalents = {table_a.c.a: [table_b.c.a]}
+
+        is_false(
+            stmt1.compare(stmt2, use_proxies=True, equivalents=equivalents)
+        )
+
+        is_true(
+            stmt1.compare(stmt1_c, use_proxies=True, equivalents=equivalents)
+        )
+        is_true(
+            (table_a.c.a == table_b.c.b).compare(
+                stmt1.c.a == stmt1.c.b,
+                use_proxies=True,
+                equivalents=equivalents,
+            )
+        )
+
+    def test_copy_internals(self):
+        for fixture in self.fixtures:
+            case_a = fixture()
+            case_b = fixture()
+
+            assert case_a[0].compare(case_b[0])
+
+            clone = case_a[0]._clone()
+            clone._copy_internals()
+
+            assert clone.compare(case_b[0])
+
+            stack = [clone]
+            seen = {clone}
+            found_elements = False
+            while stack:
+                obj = stack.pop(0)
+
+                items = [
+                    subelem
+                    for key, elem in clone.__dict__.items()
+                    if key != "_is_clone_of" and elem is not None
+                    for subelem in util.to_list(elem)
+                    if (
+                        isinstance(subelem, (ColumnElement, ClauseList))
+                        and subelem not in seen
+                        and not isinstance(subelem, Immutable)
+                        and subelem is not case_a[0]
+                    )
+                ]
+                stack.extend(items)
+                seen.update(items)
+
+                if obj is not clone:
+                    found_elements = True
+                    # ensure the element will not compare as true
+                    obj.compare = lambda other, **kw: False
+                    obj.__visit_name__ = "dont_match"
+
+            if found_elements:
+                assert not clone.compare(case_b[0])
+            assert case_a[0].compare(case_b[0])
+
+
+class CompareClausesTest(fixtures.TestBase):
+    def test_compare_comparison_associative(self):
+
+        l1 = table_c.c.x == table_d.c.y
+        l2 = table_d.c.y == table_c.c.x
+        l3 = table_c.c.x == table_d.c.z
+
+        is_true(l1.compare(l1))
+        is_true(l1.compare(l2))
+        is_false(l1.compare(l3))
+
+    def test_compare_clauselist_associative(self):
+
+        l1 = and_(table_c.c.x == table_d.c.y, table_c.c.y == table_d.c.z)
+
+        l2 = and_(table_c.c.y == table_d.c.z, table_c.c.x == table_d.c.y)
+
+        l3 = and_(table_c.c.x == table_d.c.z, table_c.c.y == table_d.c.y)
+
+        is_true(l1.compare(l1))
+        is_true(l1.compare(l2))
+        is_false(l1.compare(l3))
+
+    def test_compare_clauselist_not_associative(self):
+
+        l1 = ClauseList(
+            table_c.c.x, table_c.c.y, table_d.c.y, operator=operators.sub
+        )
+
+        l2 = ClauseList(
+            table_d.c.y, table_c.c.x, table_c.c.y, operator=operators.sub
+        )
+
+        is_true(l1.compare(l1))
+        is_false(l1.compare(l2))
+
+    def test_compare_clauselist_assoc_different_operator(self):
+
+        l1 = and_(table_c.c.x == table_d.c.y, table_c.c.y == table_d.c.z)
+
+        l2 = or_(table_c.c.y == table_d.c.z, table_c.c.x == table_d.c.y)
+
+        is_false(l1.compare(l2))
+
+    def test_compare_clauselist_not_assoc_different_operator(self):
+
+        l1 = ClauseList(
+            table_c.c.x, table_c.c.y, table_d.c.y, operator=operators.sub
+        )
+
+        l2 = ClauseList(
+            table_c.c.x, table_c.c.y, table_d.c.y, operator=operators.div
+        )
+
+        is_false(l1.compare(l2))
+
+    def test_compare_binds(self):
+        b1 = bindparam("foo", type_=Integer())
+        b2 = bindparam("foo", type_=Integer())
+        b3 = bindparam("bar", type_=Integer())
+        b4 = bindparam("foo", type_=String())
+
+        def c1():
+            return 5
+
+        def c2():
+            return 6
+
+        b5 = bindparam("foo", type_=Integer(), callable_=c1)
+        b6 = bindparam("foo", type_=Integer(), callable_=c2)
+        b7 = bindparam("foo", type_=Integer(), callable_=c1)
+
+        b8 = bindparam("foo", type_=Integer, value=5)
+        b9 = bindparam("foo", type_=Integer, value=6)
+
+        is_false(b1.compare(b5))
+        is_true(b5.compare(b7))
+        is_false(b5.compare(b6))
+        is_true(b1.compare(b2))
+
+        # currently not comparing "key", as we often have to compare
+        # anonymous names.  however we should really check for that
+        # is_true(b1.compare(b3))
+
+        is_false(b1.compare(b4))
+        is_false(b1.compare(b8))
+        is_false(b8.compare(b9))
+        is_true(b8.compare(b8))
+
+    def test_compare_tables(self):
+        is_true(table_a.compare(table_a_2))
+
+        # the "proxy" version compares schema tables on metadata identity
+        is_false(table_a.compare(table_a_2, use_proxies=True))
+
+        # same for lower case tables since it compares lower case columns
+        # using proxies, which makes it very unlikely to have multiple
+        # table() objects with columns that compare equally
+        is_false(
+            table("a", column("x", Integer), column("q", String)).compare(
+                table("a", column("x", Integer), column("q", String)),
+                use_proxies=True,
+            )
+        )
index 82c69003bd346b98d8e7386ced159b5d88da5950..c6eff6ac93c92d4e484ac13c5d6fb53324ff6c1b 100644 (file)
@@ -38,6 +38,7 @@ from sqlalchemy.sql import sqltypes
 from sqlalchemy.sql import table
 from sqlalchemy.sql import true
 from sqlalchemy.sql.elements import _literal_as_text
+from sqlalchemy.sql.elements import BindParameter
 from sqlalchemy.sql.elements import Label
 from sqlalchemy.sql.expression import BinaryExpression
 from sqlalchemy.sql.expression import ClauseList
@@ -193,7 +194,13 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
         assert left.comparator.operate(operators.in_op, [1, 2, 3]).compare(
             BinaryExpression(
                 left,
-                Grouping(ClauseList(literal(1), literal(2), literal(3))),
+                Grouping(
+                    ClauseList(
+                        BindParameter("left", value=1, unique=True),
+                        BindParameter("left", value=2, unique=True),
+                        BindParameter("left", value=3, unique=True),
+                    )
+                ),
                 operators.in_op,
             )
         )
@@ -204,7 +211,13 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
         assert left.comparator.operate(operators.notin_op, [1, 2, 3]).compare(
             BinaryExpression(
                 left,
-                Grouping(ClauseList(literal(1), literal(2), literal(3))),
+                Grouping(
+                    ClauseList(
+                        BindParameter("left", value=1, unique=True),
+                        BindParameter("left", value=2, unique=True),
+                        BindParameter("left", value=3, unique=True),
+                    )
+                ),
                 operators.notin_op,
             )
         )
index 023c483fcf2c89de1cdb76ecf713343769bfe3d9..988d5331eb6ad6a5c3b23ed5773916b0d4af7b9e 100644 (file)
@@ -1,105 +1,7 @@
-from sqlalchemy import and_
-from sqlalchemy import bindparam
-from sqlalchemy import Column
-from sqlalchemy import Integer
-from sqlalchemy import MetaData
-from sqlalchemy import or_
-from sqlalchemy import String
-from sqlalchemy import Table
-from sqlalchemy.sql import operators
 from sqlalchemy.sql import util as sql_util
-from sqlalchemy.sql.elements import ClauseList
 from sqlalchemy.sql.elements import ColumnElement
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
-from sqlalchemy.testing import is_false
-from sqlalchemy.testing import is_true
-
-
-class CompareClausesTest(fixtures.TestBase):
-    def setup(self):
-        m = MetaData()
-        self.a = Table("a", m, Column("x", Integer), Column("y", Integer))
-
-        self.b = Table("b", m, Column("y", Integer), Column("z", Integer))
-
-    def test_compare_clauselist_associative(self):
-
-        l1 = and_(self.a.c.x == self.b.c.y, self.a.c.y == self.b.c.z)
-
-        l2 = and_(self.a.c.y == self.b.c.z, self.a.c.x == self.b.c.y)
-
-        l3 = and_(self.a.c.x == self.b.c.z, self.a.c.y == self.b.c.y)
-
-        is_true(l1.compare(l1))
-        is_true(l1.compare(l2))
-        is_false(l1.compare(l3))
-
-    def test_compare_clauselist_not_associative(self):
-
-        l1 = ClauseList(
-            self.a.c.x, self.a.c.y, self.b.c.y, operator=operators.sub
-        )
-
-        l2 = ClauseList(
-            self.b.c.y, self.a.c.x, self.a.c.y, operator=operators.sub
-        )
-
-        is_true(l1.compare(l1))
-        is_false(l1.compare(l2))
-
-    def test_compare_clauselist_assoc_different_operator(self):
-
-        l1 = and_(self.a.c.x == self.b.c.y, self.a.c.y == self.b.c.z)
-
-        l2 = or_(self.a.c.y == self.b.c.z, self.a.c.x == self.b.c.y)
-
-        is_false(l1.compare(l2))
-
-    def test_compare_clauselist_not_assoc_different_operator(self):
-
-        l1 = ClauseList(
-            self.a.c.x, self.a.c.y, self.b.c.y, operator=operators.sub
-        )
-
-        l2 = ClauseList(
-            self.a.c.x, self.a.c.y, self.b.c.y, operator=operators.div
-        )
-
-        is_false(l1.compare(l2))
-
-    def test_compare_binds(self):
-        b1 = bindparam("foo", type_=Integer())
-        b2 = bindparam("foo", type_=Integer())
-        b3 = bindparam("bar", type_=Integer())
-        b4 = bindparam("foo", type_=String())
-
-        def c1():
-            return 5
-
-        def c2():
-            return 6
-
-        b5 = bindparam("foo", type_=Integer(), callable_=c1)
-        b6 = bindparam("foo", type_=Integer(), callable_=c2)
-        b7 = bindparam("foo", type_=Integer(), callable_=c1)
-
-        b8 = bindparam("foo", type_=Integer, value=5)
-        b9 = bindparam("foo", type_=Integer, value=6)
-
-        is_false(b1.compare(b5))
-        is_true(b5.compare(b7))
-        is_false(b5.compare(b6))
-        is_true(b1.compare(b2))
-
-        # currently not comparing "key", as we often have to compare
-        # anonymous names.  however we should really check for that
-        is_true(b1.compare(b3))
-
-        is_false(b1.compare(b4))
-        is_false(b1.compare(b8))
-        is_false(b8.compare(b9))
-        is_true(b8.compare(b8))
 
 
 class MiscTest(fixtures.TestBase):