From: Mike Bayer Date: Wed, 17 Apr 2019 17:37:39 +0000 (-0400) Subject: Reimplement .compare() in terms of a visitor X-Git-Tag: rel_1_4_0b1~893^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=099522075088a3e1a333a2285c10a8a33b203c19;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Reimplement .compare() in terms of a visitor Reworked the :meth:`.ClauseElement.compare` methods in terms of a new visitor-based approach, and additionally added test coverage ensuring that all :class:`.ClauseElement` subclasses can be accurately compared against each other in terms of structure. Structural comparison capability is used to a small degree within the ORM currently, however it also may form the basis for new caching features. Fixes: #4336 Change-Id: I581b667d8e1642a6c27165cc9f4aded1c66effc6 --- diff --git a/doc/build/changelog/unreleased_13/4336.rst b/doc/build/changelog/unreleased_13/4336.rst new file mode 100644 index 0000000000..8a994357f5 --- /dev/null +++ b/doc/build/changelog/unreleased_13/4336.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, sql + :tickets: 4336 + + Reworked the :meth:`.ClauseElement.compare` methods in terms of a new + visitor-based approach, and additionally added test coverage ensuring that + all :class:`.ClauseElement` subclasses can be accurately compared + against each other in terms of structure. Structural comparison + capability is used to a small degree within the ORM currently, however + it also may form the basis for new caching features. diff --git a/lib/sqlalchemy/sql/clause_compare.py b/lib/sqlalchemy/sql/clause_compare.py new file mode 100644 index 0000000000..87f9fb2df5 --- /dev/null +++ b/lib/sqlalchemy/sql/clause_compare.py @@ -0,0 +1,331 @@ +from collections import deque + +from . import operators +from .. import util + + +SKIP_TRAVERSE = util.symbol("skip_traverse") + + +def compare(obj1, obj2, **kw): + if kw.get("use_proxies", False): + strategy = ColIdentityComparatorStrategy() + else: + strategy = StructureComparatorStrategy() + + return strategy.compare(obj1, obj2, **kw) + + +class StructureComparatorStrategy(object): + __slots__ = "compare_stack", "cache" + + def __init__(self): + self.compare_stack = deque() + self.cache = set() + + def compare(self, obj1, obj2, **kw): + stack = self.compare_stack + cache = self.cache + + stack.append((obj1, obj2)) + + while stack: + left, right = stack.popleft() + + if left is right: + continue + elif left is None or right is None: + # we know they are different so no match + return False + elif (left, right) in cache: + continue + cache.add((left, right)) + + visit_name = left.__visit_name__ + + # we're not exactly looking for identical types, because + # there are things like Column and AnnotatedColumn. So the + # visit_name has to at least match up + if visit_name != right.__visit_name__: + return False + + meth = getattr(self, "compare_%s" % visit_name, None) + + if meth: + comparison = meth(left, right, **kw) + if comparison is False: + return False + elif comparison is SKIP_TRAVERSE: + continue + + for c1, c2 in util.zip_longest( + left.get_children(column_collections=False), + right.get_children(column_collections=False), + fillvalue=None, + ): + if c1 is None or c2 is None: + # collections are different sizes, comparison fails + return False + stack.append((c1, c2)) + + return True + + def compare_inner(self, obj1, obj2, **kw): + stack = self.compare_stack + try: + self.compare_stack = deque() + return self.compare(obj1, obj2, **kw) + finally: + self.compare_stack = stack + + def _compare_unordered_sequences(self, seq1, seq2, **kw): + if seq1 is None: + return seq2 is None + + completed = set() + for clause in seq1: + for other_clause in set(seq2).difference(completed): + if self.compare_inner(clause, other_clause, **kw): + completed.add(other_clause) + break + return len(completed) == len(seq1) == len(seq2) + + def compare_bindparam(self, left, right, **kw): + # note the ".key" is often generated from id(self) so can't + # be compared, as far as determining structure. + return ( + left.type._compare_type_affinity(right.type) + and left.value == right.value + and left.callable == right.callable + and left._orig_key == right._orig_key + ) + + def compare_clauselist(self, left, right, **kw): + if left.operator is right.operator: + if operators.is_associative(left.operator): + if self._compare_unordered_sequences( + left.clauses, right.clauses + ): + return SKIP_TRAVERSE + else: + return False + else: + # normal ordered traversal + return True + else: + return False + + def compare_unary(self, left, right, **kw): + if left.operator: + disp = self._get_operator_dispatch( + left.operator, "unary", "operator" + ) + if disp is not None: + result = disp(left, right, left.operator, **kw) + if result is not True: + return result + elif left.modifier: + disp = self._get_operator_dispatch( + left.modifier, "unary", "modifier" + ) + if disp is not None: + result = disp(left, right, left.operator, **kw) + if result is not True: + return result + return ( + left.operator == right.operator and left.modifier == right.modifier + ) + + def compare_binary(self, left, right, **kw): + disp = self._get_operator_dispatch(left.operator, "binary", None) + if disp: + result = disp(left, right, left.operator, **kw) + if result is not True: + return result + + if left.operator == right.operator: + if operators.is_commutative(left.operator): + if ( + compare(left.left, right.left, **kw) + and compare(left.right, right.right, **kw) + ) or ( + compare(left.left, right.right, **kw) + and compare(left.right, right.left, **kw) + ): + return SKIP_TRAVERSE + else: + return False + else: + return True + else: + return False + + def _get_operator_dispatch(self, operator_, qualifier1, qualifier2): + # used by compare_binary, compare_unary + attrname = "visit_%s_%s%s" % ( + operator_.__name__, + qualifier1, + "_" + qualifier2 if qualifier2 else "", + ) + return getattr(self, attrname, None) + + def visit_function_as_comparison_op_binary( + self, left, right, operator, **kw + ): + return ( + left.left_index == right.left_index + and left.right_index == right.right_index + ) + + def compare_function(self, left, right, **kw): + return left.name == right.name + + def compare_column(self, left, right, **kw): + if left.table is not None: + self.compare_stack.appendleft((left.table, right.table)) + return ( + left.key == right.key + and left.name == right.name + and ( + left.type._compare_type_affinity(right.type) + if left.type is not None + else right.type is None + ) + and left.is_literal == right.is_literal + ) + + def compare_collation(self, left, right, **kw): + return left.collation == right.collation + + def compare_type_coerce(self, left, right, **kw): + return left.type._compare_type_affinity(right.type) + + @util.dependencies("sqlalchemy.sql.elements") + def compare_alias(self, elements, left, right, **kw): + return ( + left.name == right.name + if not isinstance(left.name, elements._anonymous_label) + else isinstance(right.name, elements._anonymous_label) + ) + + def compare_extract(self, left, right, **kw): + return left.field == right.field + + def compare_textual_label_reference(self, left, right, **kw): + return left.element == right.element + + def compare_slice(self, left, right, **kw): + return ( + left.start == right.start + and left.stop == right.stop + and left.step == right.step + ) + + def compare_over(self, left, right, **kw): + return left.range_ == right.range_ and left.rows == right.rows + + @util.dependencies("sqlalchemy.sql.elements") + def compare_label(self, elements, left, right, **kw): + return left._type._compare_type_affinity(right._type) and ( + left.name == right.name + if not isinstance(left, elements._anonymous_label) + else isinstance(right.name, elements._anonymous_label) + ) + + def compare_typeclause(self, left, right, **kw): + return left.type._compare_type_affinity(right.type) + + def compare_join(self, left, right, **kw): + return left.isouter == right.isouter and left.full == right.full + + def compare_table(self, left, right, **kw): + if left.name != right.name: + return False + + self.compare_stack.extendleft( + util.zip_longest(left.columns, right.columns) + ) + + def compare_compound_select(self, left, right, **kw): + + if not self._compare_unordered_sequences( + left.selects, right.selects, **kw + ): + return False + + if left.keyword != right.keyword: + return False + + if left._for_update_arg != right._for_update_arg: + return False + + if not self.compare_inner( + left._order_by_clause, right._order_by_clause, **kw + ): + return False + + if not self.compare_inner( + left._group_by_clause, right._group_by_clause, **kw + ): + return False + + return SKIP_TRAVERSE + + def compare_select(self, left, right, **kw): + if not self._compare_unordered_sequences( + left._correlate, right._correlate + ): + return False + if not self._compare_unordered_sequences( + left._correlate_except, right._correlate_except + ): + return False + + if not self._compare_unordered_sequences( + left._from_obj, right._from_obj + ): + return False + + if left._for_update_arg != right._for_update_arg: + return False + + return True + + def compare_text_as_from(self, left, right, **kw): + self.compare_stack.extendleft( + util.zip_longest(left.column_args, right.column_args) + ) + return left.positional == right.positional + + +class ColIdentityComparatorStrategy(StructureComparatorStrategy): + def compare_column_element( + self, left, right, use_proxies=True, equivalents=(), **kw + ): + """Compare ColumnElements using proxies and equivalent collections. + + This is a comparison strategy specific to the ORM. + """ + + to_compare = (right,) + if equivalents and right in equivalents: + to_compare = equivalents[right].union(to_compare) + + for oth in to_compare: + if use_proxies and left.shares_lineage(oth): + return True + elif hash(left) == hash(right): + return True + else: + return False + + def compare_column(self, left, right, **kw): + return self.compare_column_element(left, right, **kw) + + def compare_label(self, left, right, **kw): + return self.compare_column_element(left, right, **kw) + + def compare_table(self, left, right, **kw): + # tables compare on identity, since it's not really feasible to + # compare them column by column with the above rules + return left is right diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 6c9b8ee5bc..552f61b4a0 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -482,6 +482,12 @@ class _multiparam_column(elements.ColumnElement): self.default = original.default self.type = original.type + def compare(self, other, **kw): + raise NotImplementedError() + + def _copy_internals(self, other, **kw): + raise NotImplementedError() + def __eq__(self, other): return ( isinstance(other, _multiparam_column) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index b0d0feff5d..38c7cf840e 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -17,6 +17,7 @@ import numbers import operator import re +from . import clause_compare from . import operators from . import type_api from .annotation import Annotated @@ -341,7 +342,7 @@ class ClauseElement(Visitable): (see :class:`.ColumnElement`) """ - return self is other + return clause_compare.compare(self, other, **kw) def _copy_internals(self, clone=_clone, **kw): """Reassign internal elements to be clones of themselves. @@ -810,34 +811,6 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): selectable._columns[key] = co return co - def compare(self, other, use_proxies=False, equivalents=None, **kw): - """Compare this ColumnElement to another. - - Special arguments understood: - - :param use_proxies: when True, consider two columns that - share a common base column as equivalent (i.e. shares_lineage()) - - :param equivalents: a dictionary of columns as keys mapped to sets - of columns. If the given "other" column is present in this - dictionary, if any of the columns in the corresponding set() pass - the comparison test, the result is True. This is used to expand the - comparison to other columns that may be known to be equivalent to - this one via foreign key or other criterion. - - """ - to_compare = (other,) - if equivalents and other in equivalents: - to_compare = equivalents[other].union(to_compare) - - for oth in to_compare: - if use_proxies and self.shares_lineage(oth): - return True - elif hash(oth) == hash(self): - return True - else: - return False - def cast(self, type_): """Produce a type cast, i.e. ``CAST( AS )``. @@ -1226,17 +1199,6 @@ class BindParameter(ColumnElement): "%%(%d %s)s" % (id(self), self._orig_key or "param") ) - def compare(self, other, **kw): - """Compare this :class:`BindParameter` to the given - clause.""" - - return ( - isinstance(other, BindParameter) - and self.type._compare_type_affinity(other.type) - and self.value == other.value - and self.callable == other.callable - ) - def __getstate__(self): """execute a deferred value for serialization purposes.""" @@ -1696,9 +1658,6 @@ class TextClause(Executable, ClauseElement): def get_children(self, **kwargs): return list(self._bindparams.values()) - def compare(self, other): - return isinstance(other, TextClause) and other.text == self.text - class Null(ColumnElement): """Represent the NULL keyword in a SQL statement. @@ -1720,9 +1679,6 @@ class Null(ColumnElement): return Null() - def compare(self, other): - return isinstance(other, Null) - class False_(ColumnElement): """Represent the ``false`` keyword, or equivalent, in a SQL statement. @@ -1779,9 +1735,6 @@ class False_(ColumnElement): return False_() - def compare(self, other): - return isinstance(other, False_) - class True_(ColumnElement): """Represent the ``true`` keyword, or equivalent, in a SQL statement. @@ -1845,9 +1798,6 @@ class True_(ColumnElement): return True_() - def compare(self, other): - return isinstance(other, True_) - class ClauseList(ClauseElement): """Describe a list of clauses, separated by an operator. @@ -1908,38 +1858,6 @@ class ClauseList(ClauseElement): else: return self - def compare(self, other, **kw): - """Compare this :class:`.ClauseList` to the given :class:`.ClauseList`, - including a comparison of all the clause items. - - """ - if not isinstance(other, ClauseList) and len(self.clauses) == 1: - return self.clauses[0].compare(other, **kw) - elif ( - isinstance(other, ClauseList) - and len(self.clauses) == len(other.clauses) - and self.operator is other.operator - ): - - if self.operator in (operators.and_, operators.or_): - completed = set() - for clause in self.clauses: - for other_clause in set(other.clauses).difference( - completed - ): - if clause.compare(other_clause, **kw): - completed.add(other_clause) - break - return len(completed) == len(other.clauses) - else: - for i in range(0, len(self.clauses)): - if not self.clauses[i].compare(other.clauses[i], **kw): - return False - else: - return True - else: - return False - class BooleanClauseList(ClauseList, ColumnElement): __visit_name__ = "clauselist" @@ -2606,6 +2524,9 @@ class _label_reference(ColumnElement): def _copy_internals(self, clone=_clone, **kw): self.element = clone(self.element, **kw) + def get_children(self, **kwargs): + return [self.element] + @property def _from_objects(self): return () @@ -2885,17 +2806,6 @@ class UnaryExpression(ColumnElement): def get_children(self, **kwargs): return (self.element,) - def compare(self, other, **kw): - """Compare this :class:`UnaryExpression` against the given - :class:`.ClauseElement`.""" - - return ( - isinstance(other, UnaryExpression) - and self.operator == other.operator - and self.modifier == other.modifier - and self.element.compare(other.element, **kw) - ) - def _negate(self): if self.negate is not None: return UnaryExpression( @@ -3103,24 +3013,6 @@ class BinaryExpression(ColumnElement): def get_children(self, **kwargs): return self.left, self.right - def compare(self, other, **kw): - """Compare this :class:`BinaryExpression` against the - given :class:`BinaryExpression`.""" - - return ( - isinstance(other, BinaryExpression) - and self.operator == other.operator - and ( - self.left.compare(other.left, **kw) - and self.right.compare(other.right, **kw) - or ( - operators.is_commutative(self.operator) - and self.left.compare(other.right, **kw) - and self.right.compare(other.left, **kw) - ) - ) - ) - def self_group(self, against=None): if operators.is_precedent(self.operator, against): return Grouping(self) @@ -3213,11 +3105,6 @@ class Grouping(ColumnElement): self.element = state["element"] self.type = state["type"] - def compare(self, other, **kw): - return isinstance(other, Grouping) and self.element.compare( - other.element - ) - RANGE_UNBOUNDED = util.symbol("RANGE_UNBOUNDED") RANGE_CURRENT = util.symbol("RANGE_CURRENT") diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index fcc843d919..f48a20ec7f 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -74,7 +74,9 @@ class FunctionElement(Executable, ColumnElement, FromClause): def __init__(self, *clauses, **kwargs): """Construct a :class:`.FunctionElement`. """ - args = [_literal_as_binds(c, self.name) for c in clauses] + args = [ + _literal_as_binds(c, getattr(self, "name", None)) for c in clauses + ] self._has_args = self._has_args or bool(args) self.clause_expr = ClauseList( operator=operators.comma_op, group_contents=True, *args @@ -376,12 +378,11 @@ class FunctionAsBinary(BinaryExpression): self.left_index = left_index self.right_index = right_index - super(FunctionAsBinary, self).__init__( - left, - right, - operators.function_as_comparison_op, - type_=sqltypes.BOOLEANTYPE, - ) + self.operator = operators.function_as_comparison_op + self.type = sqltypes.BOOLEANTYPE + self.negate = None + self._is_implicitly_boolean = True + self.modifiers = {} @property def left(self): @@ -399,10 +400,11 @@ class FunctionAsBinary(BinaryExpression): def right(self, value): self.sql_function.clauses.clauses[self.right_index - 1] = value - def _copy_internals(self, **kw): - clone = kw.pop("clone") + def _copy_internals(self, clone=_clone, **kw): self.sql_function = clone(self.sql_function, **kw) - super(FunctionAsBinary, self)._copy_internals(**kw) + + def get_children(self, **kw): + yield self.sql_function class _FunctionGenerator(object): @@ -682,6 +684,18 @@ class next_value(GenericFunction): self._bind = kw.get("bind", None) self.sequence = seq + def compare(self, other, **kw): + return ( + isinstance(other, next_value) + and self.sequence.name == other.sequence.name + ) + + def get_children(self, **kwargs): + return [] + + def _copy_internals(self, **kw): + pass + @property def _from_objects(self): return [] diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 4206de4603..8479c1d594 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -1414,6 +1414,11 @@ def mirror(op): _associative = _commutative.union([concat_op, and_, or_]).difference([eq, ne]) + +def is_associative(op): + return op in _associative + + _natural_self_precedent = _associative.union( [getitem, json_getitem_op, json_path_getitem_op] ) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index d4528f0c31..796e2b2720 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1994,6 +1994,9 @@ class ForUpdateArg(ClauseElement): and other.of is self.of ) + def __ne__(self, other): + return not self.__eq__(other) + def __hash__(self): return id(self) @@ -3941,6 +3944,12 @@ class TextAsFrom(SelectBase): self._reset_exported() self.element = clone(self.element, **kw) + def get_children(self, column_collections=True, **kw): + if column_collections: + for c in self.column_args: + yield c + yield self.element + def _scalar_type(self): return self.column_args[0].type diff --git a/test/ext/test_compiler.py b/test/ext/test_compiler.py index ac6af0a96d..ccd79f8d11 100644 --- a/test/ext/test_compiler.py +++ b/test/ext/test_compiler.py @@ -274,6 +274,18 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL): dialect=postgresql.dialect(), ) + def test_functions_args_noname(self): + class myfunc(FunctionElement): + pass + + @compiles(myfunc) + def visit_myfunc(element, compiler, **kw): + return "myfunc%s" % (compiler.process(element.clause_expr, **kw),) + + self.assert_compile(myfunc(), "myfunc()") + + self.assert_compile(myfunc(column("x"), column("y")), "myfunc(x, y)") + def test_function_calls_base(self): from sqlalchemy.dialects import mssql diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py new file mode 100644 index 0000000000..8e62d5d82a --- /dev/null +++ b/test/sql/test_compare.py @@ -0,0 +1,504 @@ +import importlib +import itertools + +from sqlalchemy import and_ +from sqlalchemy import Boolean +from sqlalchemy import case +from sqlalchemy import cast +from sqlalchemy import Column +from sqlalchemy import column +from sqlalchemy import dialects +from sqlalchemy import exists +from sqlalchemy import extract +from sqlalchemy import Float +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import or_ +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy import table +from sqlalchemy import text +from sqlalchemy import tuple_ +from sqlalchemy import union +from sqlalchemy import union_all +from sqlalchemy import util +from sqlalchemy.schema import Sequence +from sqlalchemy.sql import bindparam +from sqlalchemy.sql import ColumnElement +from sqlalchemy.sql import False_ +from sqlalchemy.sql import func +from sqlalchemy.sql import operators +from sqlalchemy.sql import True_ +from sqlalchemy.sql import type_coerce +from sqlalchemy.sql.elements import _label_reference +from sqlalchemy.sql.elements import _textual_label_reference +from sqlalchemy.sql.elements import Annotated +from sqlalchemy.sql.elements import ClauseElement +from sqlalchemy.sql.elements import ClauseList +from sqlalchemy.sql.elements import CollationClause +from sqlalchemy.sql.elements import Immutable +from sqlalchemy.sql.elements import Null +from sqlalchemy.sql.elements import Slice +from sqlalchemy.sql.elements import UnaryExpression +from sqlalchemy.sql.functions import FunctionElement +from sqlalchemy.sql.functions import GenericFunction +from sqlalchemy.sql.functions import ReturnTypeFromArgs +from sqlalchemy.sql.selectable import _OffsetLimitParam +from sqlalchemy.sql.selectable import FromGrouping +from sqlalchemy.sql.selectable import Selectable +from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_false +from sqlalchemy.testing import is_true +from sqlalchemy.util import class_hierarchy + + +meta = MetaData() +meta2 = MetaData() + +table_a = Table("a", meta, Column("a", Integer), Column("b", String)) +table_a_2 = Table("a", meta2, Column("a", Integer), Column("b", String)) + +table_b = Table("b", meta, Column("a", Integer), Column("b", Integer)) + +table_c = Table("c", meta, Column("x", Integer), Column("y", Integer)) + +table_d = Table("d", meta, Column("y", Integer), Column("z", Integer)) + + +class CompareAndCopyTest(fixtures.TestBase): + + # lambdas which return a tuple of ColumnElement objects. + # must return at least two objects that should compare differently. + # to test more varieties of "difference" additional objects can be added. + fixtures = [ + lambda: ( + column("q"), + column("x"), + column("q", Integer), + column("q", String), + ), + lambda: (~column("q", Boolean), ~column("p", Boolean)), + lambda: ( + table_a.c.a.label("foo"), + table_a.c.a.label("bar"), + table_a.c.b.label("foo"), + ), + lambda: ( + _label_reference(table_a.c.a.desc()), + _label_reference(table_a.c.a.asc()), + ), + lambda: (_textual_label_reference("a"), _textual_label_reference("b")), + lambda: ( + text("select a, b from table").columns(a=Integer, b=String), + text("select a, b, c from table").columns( + a=Integer, b=String, c=Integer + ), + ), + lambda: ( + column("q") == column("x"), + column("q") == column("y"), + column("z") == column("x"), + ), + lambda: ( + cast(column("q"), Integer), + cast(column("q"), Float), + cast(column("p"), Integer), + ), + lambda: ( + bindparam("x"), + bindparam("y"), + bindparam("x", type_=Integer), + bindparam("x", type_=String), + bindparam(None), + ), + lambda: (_OffsetLimitParam("x"), _OffsetLimitParam("y")), + lambda: (func.foo(), func.foo(5), func.bar()), + lambda: (func.current_date(), func.current_time()), + lambda: ( + func.next_value(Sequence("q")), + func.next_value(Sequence("p")), + ), + lambda: (True_(), False_()), + lambda: (Null(),), + lambda: (ReturnTypeFromArgs("foo"), ReturnTypeFromArgs(5)), + lambda: (FunctionElement(5), FunctionElement(5, 6)), + lambda: (func.count(), func.not_count()), + lambda: (func.char_length("abc"), func.char_length("def")), + lambda: (GenericFunction("a", "b"), GenericFunction("a")), + lambda: (CollationClause("foobar"), CollationClause("batbar")), + lambda: ( + type_coerce(column("q", Integer), String), + type_coerce(column("q", Integer), Float), + type_coerce(column("z", Integer), Float), + ), + lambda: (table_a.c.a, table_b.c.a), + lambda: (tuple_([1, 2]), tuple_([3, 4])), + lambda: (func.array_agg([1, 2]), func.array_agg([3, 4])), + lambda: ( + func.percentile_cont(0.5).within_group(table_a.c.a), + func.percentile_cont(0.5).within_group(table_a.c.b), + func.percentile_cont(0.5).within_group(table_a.c.a, table_a.c.b), + func.percentile_cont(0.5).within_group( + table_a.c.a, table_a.c.b, column("q") + ), + ), + lambda: ( + func.is_equal("a", "b").as_comparison(1, 2), + func.is_equal("a", "c").as_comparison(1, 2), + func.is_equal("a", "b").as_comparison(2, 1), + func.is_equal("a", "b", "c").as_comparison(1, 2), + func.foobar("a", "b").as_comparison(1, 2), + ), + lambda: ( + func.row_number().over(order_by=table_a.c.a), + func.row_number().over(order_by=table_a.c.a, range_=(0, 10)), + func.row_number().over(order_by=table_a.c.a, range_=(None, 10)), + func.row_number().over(order_by=table_a.c.a, rows=(None, 20)), + func.row_number().over(order_by=table_a.c.b), + func.row_number().over( + order_by=table_a.c.a, partition_by=table_a.c.b + ), + ), + lambda: ( + func.count(1).filter(table_a.c.a == 5), + func.count(1).filter(table_a.c.a == 10), + func.foob(1).filter(table_a.c.a == 10), + ), + lambda: ( + and_(table_a.c.a == 5, table_a.c.b == table_b.c.a), + and_(table_a.c.a == 5, table_a.c.a == table_b.c.a), + or_(table_a.c.a == 5, table_a.c.b == table_b.c.a), + ClauseList(table_a.c.a == 5, table_a.c.b == table_b.c.a), + ClauseList(table_a.c.a == 5, table_a.c.b == table_a.c.a), + ), + lambda: ( + case(whens=[(table_a.c.a == 5, 10), (table_a.c.a == 10, 20)]), + case(whens=[(table_a.c.a == 18, 10), (table_a.c.a == 10, 20)]), + case(whens=[(table_a.c.a == 5, 10), (table_a.c.b == 10, 20)]), + case( + whens=[ + (table_a.c.a == 5, 10), + (table_a.c.b == 10, 20), + (table_a.c.a == 9, 12), + ] + ), + case( + whens=[(table_a.c.a == 5, 10), (table_a.c.a == 10, 20)], + else_=30, + ), + case({"wendy": "W", "jack": "J"}, value=table_a.c.a, else_="E"), + case({"wendy": "W", "jack": "J"}, value=table_a.c.b, else_="E"), + case({"wendy_w": "W", "jack": "J"}, value=table_a.c.a, else_="E"), + ), + lambda: ( + extract("foo", table_a.c.a), + extract("foo", table_a.c.b), + extract("bar", table_a.c.a), + ), + lambda: ( + Slice(1, 2, 5), + Slice(1, 5, 5), + Slice(1, 5, 10), + Slice(2, 10, 15), + ), + lambda: ( + select([table_a.c.a]), + select([table_a.c.a, table_a.c.b]), + select([table_a.c.b, table_a.c.a]), + select([table_a.c.a]).where(table_a.c.b == 5), + select([table_a.c.a]) + .where(table_a.c.b == 5) + .where(table_a.c.a == 10), + select([table_a.c.a]).where(table_a.c.b == 5).with_for_update(), + select([table_a.c.a]) + .where(table_a.c.b == 5) + .with_for_update(nowait=True), + select([table_a.c.a]).where(table_a.c.b == 5).correlate(table_b), + select([table_a.c.a]) + .where(table_a.c.b == 5) + .correlate_except(table_b), + ), + lambda: ( + table_a.join(table_b, table_a.c.a == table_b.c.a), + table_a.join( + table_b, and_(table_a.c.a == table_b.c.a, table_a.c.b == 1) + ), + table_a.outerjoin(table_b, table_a.c.a == table_b.c.a), + ), + lambda: ( + table_a.alias("a"), + table_a.alias("b"), + table_a.alias(), + table_b.alias("a"), + select([table_a.c.a]).alias("a"), + ), + lambda: ( + FromGrouping(table_a.alias("a")), + FromGrouping(table_a.alias("b")), + ), + lambda: ( + select([table_a.c.a]).as_scalar(), + select([table_a.c.a]).where(table_a.c.b == 5).as_scalar(), + ), + lambda: ( + exists().where(table_a.c.a == 5), + exists().where(table_a.c.b == 5), + ), + lambda: ( + union(select([table_a.c.a]), select([table_a.c.b])), + union(select([table_a.c.a]), select([table_a.c.b])).order_by("a"), + union_all(select([table_a.c.a]), select([table_a.c.b])), + union(select([table_a.c.a])), + union( + select([table_a.c.a]), + select([table_a.c.b]).where(table_a.c.b > 5), + ), + ), + lambda: ( + table("a", column("x"), column("y")), + table("a", column("y"), column("x")), + table("b", column("x"), column("y")), + table("a", column("x"), column("y"), column("z")), + table("a", column("x"), column("y", Integer)), + table("a", column("q"), column("y", Integer)), + ), + lambda: ( + Table("a", MetaData(), Column("q", Integer), Column("b", String)), + Table("b", MetaData(), Column("q", Integer), Column("b", String)), + ), + ] + + @classmethod + def setup_class(cls): + # TODO: we need to get dialects here somehow, perhaps in test_suite? + [ + importlib.import_module("sqlalchemy.dialects.%s" % d) + for d in dialects.__all__ + if not d.startswith("_") + ] + + def test_all_present(self): + need = set( + cls + for cls in class_hierarchy(ClauseElement) + if issubclass(cls, (ColumnElement, Selectable)) + and "__init__" in cls.__dict__ + and not issubclass(cls, (Annotated)) + and "orm" not in cls.__module__ + and "crud" not in cls.__module__ + and "dialects" not in cls.__module__ # TODO: dialects? + ).difference({ColumnElement, UnaryExpression}) + for fixture in self.fixtures: + case_a = fixture() + for elem in case_a: + for mro in type(elem).__mro__: + need.discard(mro) + + is_false(bool(need), "%d Remaining classes: %r" % (len(need), need)) + + def test_compare(self): + for fixture in self.fixtures: + case_a = fixture() + case_b = fixture() + + for a, b in itertools.combinations_with_replacement( + range(len(case_a)), 2 + ): + if a == b: + is_true( + case_a[a].compare( + case_b[b], arbitrary_expression=True + ), + "%r != %r" % (case_a[a], case_b[b]), + ) + + else: + is_false( + case_a[a].compare( + case_b[b], arbitrary_expression=True + ), + "%r == %r" % (case_a[a], case_b[b]), + ) + + def test_compare_col_identity(self): + stmt1 = ( + select([table_a.c.a, table_b.c.b]) + .where(table_a.c.a == table_b.c.b) + .alias() + ) + stmt1_c = ( + select([table_a.c.a, table_b.c.b]) + .where(table_a.c.a == table_b.c.b) + .alias() + ) + + stmt2 = union(select([table_a]), select([table_b])) + + stmt3 = select([table_b]) + + equivalents = {table_a.c.a: [table_b.c.a]} + + is_false( + stmt1.compare(stmt2, use_proxies=True, equivalents=equivalents) + ) + + is_true( + stmt1.compare(stmt1_c, use_proxies=True, equivalents=equivalents) + ) + is_true( + (table_a.c.a == table_b.c.b).compare( + stmt1.c.a == stmt1.c.b, + use_proxies=True, + equivalents=equivalents, + ) + ) + + def test_copy_internals(self): + for fixture in self.fixtures: + case_a = fixture() + case_b = fixture() + + assert case_a[0].compare(case_b[0]) + + clone = case_a[0]._clone() + clone._copy_internals() + + assert clone.compare(case_b[0]) + + stack = [clone] + seen = {clone} + found_elements = False + while stack: + obj = stack.pop(0) + + items = [ + subelem + for key, elem in clone.__dict__.items() + if key != "_is_clone_of" and elem is not None + for subelem in util.to_list(elem) + if ( + isinstance(subelem, (ColumnElement, ClauseList)) + and subelem not in seen + and not isinstance(subelem, Immutable) + and subelem is not case_a[0] + ) + ] + stack.extend(items) + seen.update(items) + + if obj is not clone: + found_elements = True + # ensure the element will not compare as true + obj.compare = lambda other, **kw: False + obj.__visit_name__ = "dont_match" + + if found_elements: + assert not clone.compare(case_b[0]) + assert case_a[0].compare(case_b[0]) + + +class CompareClausesTest(fixtures.TestBase): + def test_compare_comparison_associative(self): + + l1 = table_c.c.x == table_d.c.y + l2 = table_d.c.y == table_c.c.x + l3 = table_c.c.x == table_d.c.z + + is_true(l1.compare(l1)) + is_true(l1.compare(l2)) + is_false(l1.compare(l3)) + + def test_compare_clauselist_associative(self): + + l1 = and_(table_c.c.x == table_d.c.y, table_c.c.y == table_d.c.z) + + l2 = and_(table_c.c.y == table_d.c.z, table_c.c.x == table_d.c.y) + + l3 = and_(table_c.c.x == table_d.c.z, table_c.c.y == table_d.c.y) + + is_true(l1.compare(l1)) + is_true(l1.compare(l2)) + is_false(l1.compare(l3)) + + def test_compare_clauselist_not_associative(self): + + l1 = ClauseList( + table_c.c.x, table_c.c.y, table_d.c.y, operator=operators.sub + ) + + l2 = ClauseList( + table_d.c.y, table_c.c.x, table_c.c.y, operator=operators.sub + ) + + is_true(l1.compare(l1)) + is_false(l1.compare(l2)) + + def test_compare_clauselist_assoc_different_operator(self): + + l1 = and_(table_c.c.x == table_d.c.y, table_c.c.y == table_d.c.z) + + l2 = or_(table_c.c.y == table_d.c.z, table_c.c.x == table_d.c.y) + + is_false(l1.compare(l2)) + + def test_compare_clauselist_not_assoc_different_operator(self): + + l1 = ClauseList( + table_c.c.x, table_c.c.y, table_d.c.y, operator=operators.sub + ) + + l2 = ClauseList( + table_c.c.x, table_c.c.y, table_d.c.y, operator=operators.div + ) + + is_false(l1.compare(l2)) + + def test_compare_binds(self): + b1 = bindparam("foo", type_=Integer()) + b2 = bindparam("foo", type_=Integer()) + b3 = bindparam("bar", type_=Integer()) + b4 = bindparam("foo", type_=String()) + + def c1(): + return 5 + + def c2(): + return 6 + + b5 = bindparam("foo", type_=Integer(), callable_=c1) + b6 = bindparam("foo", type_=Integer(), callable_=c2) + b7 = bindparam("foo", type_=Integer(), callable_=c1) + + b8 = bindparam("foo", type_=Integer, value=5) + b9 = bindparam("foo", type_=Integer, value=6) + + is_false(b1.compare(b5)) + is_true(b5.compare(b7)) + is_false(b5.compare(b6)) + is_true(b1.compare(b2)) + + # currently not comparing "key", as we often have to compare + # anonymous names. however we should really check for that + # is_true(b1.compare(b3)) + + is_false(b1.compare(b4)) + is_false(b1.compare(b8)) + is_false(b8.compare(b9)) + is_true(b8.compare(b8)) + + def test_compare_tables(self): + is_true(table_a.compare(table_a_2)) + + # the "proxy" version compares schema tables on metadata identity + is_false(table_a.compare(table_a_2, use_proxies=True)) + + # same for lower case tables since it compares lower case columns + # using proxies, which makes it very unlikely to have multiple + # table() objects with columns that compare equally + is_false( + table("a", column("x", Integer), column("q", String)).compare( + table("a", column("x", Integer), column("q", String)), + use_proxies=True, + ) + ) diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index 82c69003bd..c6eff6ac93 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -38,6 +38,7 @@ from sqlalchemy.sql import sqltypes from sqlalchemy.sql import table from sqlalchemy.sql import true from sqlalchemy.sql.elements import _literal_as_text +from sqlalchemy.sql.elements import BindParameter from sqlalchemy.sql.elements import Label from sqlalchemy.sql.expression import BinaryExpression from sqlalchemy.sql.expression import ClauseList @@ -193,7 +194,13 @@ class DefaultColumnComparatorTest(fixtures.TestBase): assert left.comparator.operate(operators.in_op, [1, 2, 3]).compare( BinaryExpression( left, - Grouping(ClauseList(literal(1), literal(2), literal(3))), + Grouping( + ClauseList( + BindParameter("left", value=1, unique=True), + BindParameter("left", value=2, unique=True), + BindParameter("left", value=3, unique=True), + ) + ), operators.in_op, ) ) @@ -204,7 +211,13 @@ class DefaultColumnComparatorTest(fixtures.TestBase): assert left.comparator.operate(operators.notin_op, [1, 2, 3]).compare( BinaryExpression( left, - Grouping(ClauseList(literal(1), literal(2), literal(3))), + Grouping( + ClauseList( + BindParameter("left", value=1, unique=True), + BindParameter("left", value=2, unique=True), + BindParameter("left", value=3, unique=True), + ) + ), operators.notin_op, ) ) diff --git a/test/sql/test_utils.py b/test/sql/test_utils.py index 023c483fcf..988d5331eb 100644 --- a/test/sql/test_utils.py +++ b/test/sql/test_utils.py @@ -1,105 +1,7 @@ -from sqlalchemy import and_ -from sqlalchemy import bindparam -from sqlalchemy import Column -from sqlalchemy import Integer -from sqlalchemy import MetaData -from sqlalchemy import or_ -from sqlalchemy import String -from sqlalchemy import Table -from sqlalchemy.sql import operators from sqlalchemy.sql import util as sql_util -from sqlalchemy.sql.elements import ClauseList from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures -from sqlalchemy.testing import is_false -from sqlalchemy.testing import is_true - - -class CompareClausesTest(fixtures.TestBase): - def setup(self): - m = MetaData() - self.a = Table("a", m, Column("x", Integer), Column("y", Integer)) - - self.b = Table("b", m, Column("y", Integer), Column("z", Integer)) - - def test_compare_clauselist_associative(self): - - l1 = and_(self.a.c.x == self.b.c.y, self.a.c.y == self.b.c.z) - - l2 = and_(self.a.c.y == self.b.c.z, self.a.c.x == self.b.c.y) - - l3 = and_(self.a.c.x == self.b.c.z, self.a.c.y == self.b.c.y) - - is_true(l1.compare(l1)) - is_true(l1.compare(l2)) - is_false(l1.compare(l3)) - - def test_compare_clauselist_not_associative(self): - - l1 = ClauseList( - self.a.c.x, self.a.c.y, self.b.c.y, operator=operators.sub - ) - - l2 = ClauseList( - self.b.c.y, self.a.c.x, self.a.c.y, operator=operators.sub - ) - - is_true(l1.compare(l1)) - is_false(l1.compare(l2)) - - def test_compare_clauselist_assoc_different_operator(self): - - l1 = and_(self.a.c.x == self.b.c.y, self.a.c.y == self.b.c.z) - - l2 = or_(self.a.c.y == self.b.c.z, self.a.c.x == self.b.c.y) - - is_false(l1.compare(l2)) - - def test_compare_clauselist_not_assoc_different_operator(self): - - l1 = ClauseList( - self.a.c.x, self.a.c.y, self.b.c.y, operator=operators.sub - ) - - l2 = ClauseList( - self.a.c.x, self.a.c.y, self.b.c.y, operator=operators.div - ) - - is_false(l1.compare(l2)) - - def test_compare_binds(self): - b1 = bindparam("foo", type_=Integer()) - b2 = bindparam("foo", type_=Integer()) - b3 = bindparam("bar", type_=Integer()) - b4 = bindparam("foo", type_=String()) - - def c1(): - return 5 - - def c2(): - return 6 - - b5 = bindparam("foo", type_=Integer(), callable_=c1) - b6 = bindparam("foo", type_=Integer(), callable_=c2) - b7 = bindparam("foo", type_=Integer(), callable_=c1) - - b8 = bindparam("foo", type_=Integer, value=5) - b9 = bindparam("foo", type_=Integer, value=6) - - is_false(b1.compare(b5)) - is_true(b5.compare(b7)) - is_false(b5.compare(b6)) - is_true(b1.compare(b2)) - - # currently not comparing "key", as we often have to compare - # anonymous names. however we should really check for that - is_true(b1.compare(b3)) - - is_false(b1.compare(b4)) - is_false(b1.compare(b8)) - is_false(b8.compare(b9)) - is_true(b8.compare(b8)) class MiscTest(fixtures.TestBase):