--- /dev/null
+.. change::
+ :tags: bug, sql
+ :tickets: 4336
+
+ Reworked the :meth:`.ClauseElement.compare` methods in terms of a new
+ visitor-based approach, and additionally added test coverage ensuring that
+ all :class:`.ClauseElement` subclasses can be accurately compared
+ against each other in terms of structure. Structural comparison
+ capability is used to a small degree within the ORM currently, however
+ it also may form the basis for new caching features.
--- /dev/null
+from collections import deque
+
+from . import operators
+from .. import util
+
+
+SKIP_TRAVERSE = util.symbol("skip_traverse")
+
+
+def compare(obj1, obj2, **kw):
+ if kw.get("use_proxies", False):
+ strategy = ColIdentityComparatorStrategy()
+ else:
+ strategy = StructureComparatorStrategy()
+
+ return strategy.compare(obj1, obj2, **kw)
+
+
+class StructureComparatorStrategy(object):
+ __slots__ = "compare_stack", "cache"
+
+ def __init__(self):
+ self.compare_stack = deque()
+ self.cache = set()
+
+ def compare(self, obj1, obj2, **kw):
+ stack = self.compare_stack
+ cache = self.cache
+
+ stack.append((obj1, obj2))
+
+ while stack:
+ left, right = stack.popleft()
+
+ if left is right:
+ continue
+ elif left is None or right is None:
+ # we know they are different so no match
+ return False
+ elif (left, right) in cache:
+ continue
+ cache.add((left, right))
+
+ visit_name = left.__visit_name__
+
+ # we're not exactly looking for identical types, because
+ # there are things like Column and AnnotatedColumn. So the
+ # visit_name has to at least match up
+ if visit_name != right.__visit_name__:
+ return False
+
+ meth = getattr(self, "compare_%s" % visit_name, None)
+
+ if meth:
+ comparison = meth(left, right, **kw)
+ if comparison is False:
+ return False
+ elif comparison is SKIP_TRAVERSE:
+ continue
+
+ for c1, c2 in util.zip_longest(
+ left.get_children(column_collections=False),
+ right.get_children(column_collections=False),
+ fillvalue=None,
+ ):
+ if c1 is None or c2 is None:
+ # collections are different sizes, comparison fails
+ return False
+ stack.append((c1, c2))
+
+ return True
+
+ def compare_inner(self, obj1, obj2, **kw):
+ stack = self.compare_stack
+ try:
+ self.compare_stack = deque()
+ return self.compare(obj1, obj2, **kw)
+ finally:
+ self.compare_stack = stack
+
+ def _compare_unordered_sequences(self, seq1, seq2, **kw):
+ if seq1 is None:
+ return seq2 is None
+
+ completed = set()
+ for clause in seq1:
+ for other_clause in set(seq2).difference(completed):
+ if self.compare_inner(clause, other_clause, **kw):
+ completed.add(other_clause)
+ break
+ return len(completed) == len(seq1) == len(seq2)
+
+ def compare_bindparam(self, left, right, **kw):
+ # note the ".key" is often generated from id(self) so can't
+ # be compared, as far as determining structure.
+ return (
+ left.type._compare_type_affinity(right.type)
+ and left.value == right.value
+ and left.callable == right.callable
+ and left._orig_key == right._orig_key
+ )
+
+ def compare_clauselist(self, left, right, **kw):
+ if left.operator is right.operator:
+ if operators.is_associative(left.operator):
+ if self._compare_unordered_sequences(
+ left.clauses, right.clauses
+ ):
+ return SKIP_TRAVERSE
+ else:
+ return False
+ else:
+ # normal ordered traversal
+ return True
+ else:
+ return False
+
+ def compare_unary(self, left, right, **kw):
+ if left.operator:
+ disp = self._get_operator_dispatch(
+ left.operator, "unary", "operator"
+ )
+ if disp is not None:
+ result = disp(left, right, left.operator, **kw)
+ if result is not True:
+ return result
+ elif left.modifier:
+ disp = self._get_operator_dispatch(
+ left.modifier, "unary", "modifier"
+ )
+ if disp is not None:
+ result = disp(left, right, left.operator, **kw)
+ if result is not True:
+ return result
+ return (
+ left.operator == right.operator and left.modifier == right.modifier
+ )
+
+ def compare_binary(self, left, right, **kw):
+ disp = self._get_operator_dispatch(left.operator, "binary", None)
+ if disp:
+ result = disp(left, right, left.operator, **kw)
+ if result is not True:
+ return result
+
+ if left.operator == right.operator:
+ if operators.is_commutative(left.operator):
+ if (
+ compare(left.left, right.left, **kw)
+ and compare(left.right, right.right, **kw)
+ ) or (
+ compare(left.left, right.right, **kw)
+ and compare(left.right, right.left, **kw)
+ ):
+ return SKIP_TRAVERSE
+ else:
+ return False
+ else:
+ return True
+ else:
+ return False
+
+ def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
+ # used by compare_binary, compare_unary
+ attrname = "visit_%s_%s%s" % (
+ operator_.__name__,
+ qualifier1,
+ "_" + qualifier2 if qualifier2 else "",
+ )
+ return getattr(self, attrname, None)
+
+ def visit_function_as_comparison_op_binary(
+ self, left, right, operator, **kw
+ ):
+ return (
+ left.left_index == right.left_index
+ and left.right_index == right.right_index
+ )
+
+ def compare_function(self, left, right, **kw):
+ return left.name == right.name
+
+ def compare_column(self, left, right, **kw):
+ if left.table is not None:
+ self.compare_stack.appendleft((left.table, right.table))
+ return (
+ left.key == right.key
+ and left.name == right.name
+ and (
+ left.type._compare_type_affinity(right.type)
+ if left.type is not None
+ else right.type is None
+ )
+ and left.is_literal == right.is_literal
+ )
+
+ def compare_collation(self, left, right, **kw):
+ return left.collation == right.collation
+
+ def compare_type_coerce(self, left, right, **kw):
+ return left.type._compare_type_affinity(right.type)
+
+ @util.dependencies("sqlalchemy.sql.elements")
+ def compare_alias(self, elements, left, right, **kw):
+ return (
+ left.name == right.name
+ if not isinstance(left.name, elements._anonymous_label)
+ else isinstance(right.name, elements._anonymous_label)
+ )
+
+ def compare_extract(self, left, right, **kw):
+ return left.field == right.field
+
+ def compare_textual_label_reference(self, left, right, **kw):
+ return left.element == right.element
+
+ def compare_slice(self, left, right, **kw):
+ return (
+ left.start == right.start
+ and left.stop == right.stop
+ and left.step == right.step
+ )
+
+ def compare_over(self, left, right, **kw):
+ return left.range_ == right.range_ and left.rows == right.rows
+
+ @util.dependencies("sqlalchemy.sql.elements")
+ def compare_label(self, elements, left, right, **kw):
+ return left._type._compare_type_affinity(right._type) and (
+ left.name == right.name
+ if not isinstance(left, elements._anonymous_label)
+ else isinstance(right.name, elements._anonymous_label)
+ )
+
+ def compare_typeclause(self, left, right, **kw):
+ return left.type._compare_type_affinity(right.type)
+
+ def compare_join(self, left, right, **kw):
+ return left.isouter == right.isouter and left.full == right.full
+
+ def compare_table(self, left, right, **kw):
+ if left.name != right.name:
+ return False
+
+ self.compare_stack.extendleft(
+ util.zip_longest(left.columns, right.columns)
+ )
+
+ def compare_compound_select(self, left, right, **kw):
+
+ if not self._compare_unordered_sequences(
+ left.selects, right.selects, **kw
+ ):
+ return False
+
+ if left.keyword != right.keyword:
+ return False
+
+ if left._for_update_arg != right._for_update_arg:
+ return False
+
+ if not self.compare_inner(
+ left._order_by_clause, right._order_by_clause, **kw
+ ):
+ return False
+
+ if not self.compare_inner(
+ left._group_by_clause, right._group_by_clause, **kw
+ ):
+ return False
+
+ return SKIP_TRAVERSE
+
+ def compare_select(self, left, right, **kw):
+ if not self._compare_unordered_sequences(
+ left._correlate, right._correlate
+ ):
+ return False
+ if not self._compare_unordered_sequences(
+ left._correlate_except, right._correlate_except
+ ):
+ return False
+
+ if not self._compare_unordered_sequences(
+ left._from_obj, right._from_obj
+ ):
+ return False
+
+ if left._for_update_arg != right._for_update_arg:
+ return False
+
+ return True
+
+ def compare_text_as_from(self, left, right, **kw):
+ self.compare_stack.extendleft(
+ util.zip_longest(left.column_args, right.column_args)
+ )
+ return left.positional == right.positional
+
+
+class ColIdentityComparatorStrategy(StructureComparatorStrategy):
+ def compare_column_element(
+ self, left, right, use_proxies=True, equivalents=(), **kw
+ ):
+ """Compare ColumnElements using proxies and equivalent collections.
+
+ This is a comparison strategy specific to the ORM.
+ """
+
+ to_compare = (right,)
+ if equivalents and right in equivalents:
+ to_compare = equivalents[right].union(to_compare)
+
+ for oth in to_compare:
+ if use_proxies and left.shares_lineage(oth):
+ return True
+ elif hash(left) == hash(right):
+ return True
+ else:
+ return False
+
+ def compare_column(self, left, right, **kw):
+ return self.compare_column_element(left, right, **kw)
+
+ def compare_label(self, left, right, **kw):
+ return self.compare_column_element(left, right, **kw)
+
+ def compare_table(self, left, right, **kw):
+ # tables compare on identity, since it's not really feasible to
+ # compare them column by column with the above rules
+ return left is right
self.default = original.default
self.type = original.type
+ def compare(self, other, **kw):
+ raise NotImplementedError()
+
+ def _copy_internals(self, other, **kw):
+ raise NotImplementedError()
+
def __eq__(self, other):
return (
isinstance(other, _multiparam_column)
import operator
import re
+from . import clause_compare
from . import operators
from . import type_api
from .annotation import Annotated
(see :class:`.ColumnElement`)
"""
- return self is other
+ return clause_compare.compare(self, other, **kw)
def _copy_internals(self, clone=_clone, **kw):
"""Reassign internal elements to be clones of themselves.
selectable._columns[key] = co
return co
- def compare(self, other, use_proxies=False, equivalents=None, **kw):
- """Compare this ColumnElement to another.
-
- Special arguments understood:
-
- :param use_proxies: when True, consider two columns that
- share a common base column as equivalent (i.e. shares_lineage())
-
- :param equivalents: a dictionary of columns as keys mapped to sets
- of columns. If the given "other" column is present in this
- dictionary, if any of the columns in the corresponding set() pass
- the comparison test, the result is True. This is used to expand the
- comparison to other columns that may be known to be equivalent to
- this one via foreign key or other criterion.
-
- """
- to_compare = (other,)
- if equivalents and other in equivalents:
- to_compare = equivalents[other].union(to_compare)
-
- for oth in to_compare:
- if use_proxies and self.shares_lineage(oth):
- return True
- elif hash(oth) == hash(self):
- return True
- else:
- return False
-
def cast(self, type_):
"""Produce a type cast, i.e. ``CAST(<expression> AS <type>)``.
"%%(%d %s)s" % (id(self), self._orig_key or "param")
)
- def compare(self, other, **kw):
- """Compare this :class:`BindParameter` to the given
- clause."""
-
- return (
- isinstance(other, BindParameter)
- and self.type._compare_type_affinity(other.type)
- and self.value == other.value
- and self.callable == other.callable
- )
-
def __getstate__(self):
"""execute a deferred value for serialization purposes."""
def get_children(self, **kwargs):
return list(self._bindparams.values())
- def compare(self, other):
- return isinstance(other, TextClause) and other.text == self.text
-
class Null(ColumnElement):
"""Represent the NULL keyword in a SQL statement.
return Null()
- def compare(self, other):
- return isinstance(other, Null)
-
class False_(ColumnElement):
"""Represent the ``false`` keyword, or equivalent, in a SQL statement.
return False_()
- def compare(self, other):
- return isinstance(other, False_)
-
class True_(ColumnElement):
"""Represent the ``true`` keyword, or equivalent, in a SQL statement.
return True_()
- def compare(self, other):
- return isinstance(other, True_)
-
class ClauseList(ClauseElement):
"""Describe a list of clauses, separated by an operator.
else:
return self
- def compare(self, other, **kw):
- """Compare this :class:`.ClauseList` to the given :class:`.ClauseList`,
- including a comparison of all the clause items.
-
- """
- if not isinstance(other, ClauseList) and len(self.clauses) == 1:
- return self.clauses[0].compare(other, **kw)
- elif (
- isinstance(other, ClauseList)
- and len(self.clauses) == len(other.clauses)
- and self.operator is other.operator
- ):
-
- if self.operator in (operators.and_, operators.or_):
- completed = set()
- for clause in self.clauses:
- for other_clause in set(other.clauses).difference(
- completed
- ):
- if clause.compare(other_clause, **kw):
- completed.add(other_clause)
- break
- return len(completed) == len(other.clauses)
- else:
- for i in range(0, len(self.clauses)):
- if not self.clauses[i].compare(other.clauses[i], **kw):
- return False
- else:
- return True
- else:
- return False
-
class BooleanClauseList(ClauseList, ColumnElement):
__visit_name__ = "clauselist"
def _copy_internals(self, clone=_clone, **kw):
self.element = clone(self.element, **kw)
+ def get_children(self, **kwargs):
+ return [self.element]
+
@property
def _from_objects(self):
return ()
def get_children(self, **kwargs):
return (self.element,)
- def compare(self, other, **kw):
- """Compare this :class:`UnaryExpression` against the given
- :class:`.ClauseElement`."""
-
- return (
- isinstance(other, UnaryExpression)
- and self.operator == other.operator
- and self.modifier == other.modifier
- and self.element.compare(other.element, **kw)
- )
-
def _negate(self):
if self.negate is not None:
return UnaryExpression(
def get_children(self, **kwargs):
return self.left, self.right
- def compare(self, other, **kw):
- """Compare this :class:`BinaryExpression` against the
- given :class:`BinaryExpression`."""
-
- return (
- isinstance(other, BinaryExpression)
- and self.operator == other.operator
- and (
- self.left.compare(other.left, **kw)
- and self.right.compare(other.right, **kw)
- or (
- operators.is_commutative(self.operator)
- and self.left.compare(other.right, **kw)
- and self.right.compare(other.left, **kw)
- )
- )
- )
-
def self_group(self, against=None):
if operators.is_precedent(self.operator, against):
return Grouping(self)
self.element = state["element"]
self.type = state["type"]
- def compare(self, other, **kw):
- return isinstance(other, Grouping) and self.element.compare(
- other.element
- )
-
RANGE_UNBOUNDED = util.symbol("RANGE_UNBOUNDED")
RANGE_CURRENT = util.symbol("RANGE_CURRENT")
def __init__(self, *clauses, **kwargs):
"""Construct a :class:`.FunctionElement`.
"""
- args = [_literal_as_binds(c, self.name) for c in clauses]
+ args = [
+ _literal_as_binds(c, getattr(self, "name", None)) for c in clauses
+ ]
self._has_args = self._has_args or bool(args)
self.clause_expr = ClauseList(
operator=operators.comma_op, group_contents=True, *args
self.left_index = left_index
self.right_index = right_index
- super(FunctionAsBinary, self).__init__(
- left,
- right,
- operators.function_as_comparison_op,
- type_=sqltypes.BOOLEANTYPE,
- )
+ self.operator = operators.function_as_comparison_op
+ self.type = sqltypes.BOOLEANTYPE
+ self.negate = None
+ self._is_implicitly_boolean = True
+ self.modifiers = {}
@property
def left(self):
def right(self, value):
self.sql_function.clauses.clauses[self.right_index - 1] = value
- def _copy_internals(self, **kw):
- clone = kw.pop("clone")
+ def _copy_internals(self, clone=_clone, **kw):
self.sql_function = clone(self.sql_function, **kw)
- super(FunctionAsBinary, self)._copy_internals(**kw)
+
+ def get_children(self, **kw):
+ yield self.sql_function
class _FunctionGenerator(object):
self._bind = kw.get("bind", None)
self.sequence = seq
+ def compare(self, other, **kw):
+ return (
+ isinstance(other, next_value)
+ and self.sequence.name == other.sequence.name
+ )
+
+ def get_children(self, **kwargs):
+ return []
+
+ def _copy_internals(self, **kw):
+ pass
+
@property
def _from_objects(self):
return []
_associative = _commutative.union([concat_op, and_, or_]).difference([eq, ne])
+
+def is_associative(op):
+ return op in _associative
+
+
_natural_self_precedent = _associative.union(
[getitem, json_getitem_op, json_path_getitem_op]
)
and other.of is self.of
)
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
def __hash__(self):
return id(self)
self._reset_exported()
self.element = clone(self.element, **kw)
+ def get_children(self, column_collections=True, **kw):
+ if column_collections:
+ for c in self.column_args:
+ yield c
+ yield self.element
+
def _scalar_type(self):
return self.column_args[0].type
dialect=postgresql.dialect(),
)
+ def test_functions_args_noname(self):
+ class myfunc(FunctionElement):
+ pass
+
+ @compiles(myfunc)
+ def visit_myfunc(element, compiler, **kw):
+ return "myfunc%s" % (compiler.process(element.clause_expr, **kw),)
+
+ self.assert_compile(myfunc(), "myfunc()")
+
+ self.assert_compile(myfunc(column("x"), column("y")), "myfunc(x, y)")
+
def test_function_calls_base(self):
from sqlalchemy.dialects import mssql
--- /dev/null
+import importlib
+import itertools
+
+from sqlalchemy import and_
+from sqlalchemy import Boolean
+from sqlalchemy import case
+from sqlalchemy import cast
+from sqlalchemy import Column
+from sqlalchemy import column
+from sqlalchemy import dialects
+from sqlalchemy import exists
+from sqlalchemy import extract
+from sqlalchemy import Float
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import or_
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy import table
+from sqlalchemy import text
+from sqlalchemy import tuple_
+from sqlalchemy import union
+from sqlalchemy import union_all
+from sqlalchemy import util
+from sqlalchemy.schema import Sequence
+from sqlalchemy.sql import bindparam
+from sqlalchemy.sql import ColumnElement
+from sqlalchemy.sql import False_
+from sqlalchemy.sql import func
+from sqlalchemy.sql import operators
+from sqlalchemy.sql import True_
+from sqlalchemy.sql import type_coerce
+from sqlalchemy.sql.elements import _label_reference
+from sqlalchemy.sql.elements import _textual_label_reference
+from sqlalchemy.sql.elements import Annotated
+from sqlalchemy.sql.elements import ClauseElement
+from sqlalchemy.sql.elements import ClauseList
+from sqlalchemy.sql.elements import CollationClause
+from sqlalchemy.sql.elements import Immutable
+from sqlalchemy.sql.elements import Null
+from sqlalchemy.sql.elements import Slice
+from sqlalchemy.sql.elements import UnaryExpression
+from sqlalchemy.sql.functions import FunctionElement
+from sqlalchemy.sql.functions import GenericFunction
+from sqlalchemy.sql.functions import ReturnTypeFromArgs
+from sqlalchemy.sql.selectable import _OffsetLimitParam
+from sqlalchemy.sql.selectable import FromGrouping
+from sqlalchemy.sql.selectable import Selectable
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_false
+from sqlalchemy.testing import is_true
+from sqlalchemy.util import class_hierarchy
+
+
+meta = MetaData()
+meta2 = MetaData()
+
+table_a = Table("a", meta, Column("a", Integer), Column("b", String))
+table_a_2 = Table("a", meta2, Column("a", Integer), Column("b", String))
+
+table_b = Table("b", meta, Column("a", Integer), Column("b", Integer))
+
+table_c = Table("c", meta, Column("x", Integer), Column("y", Integer))
+
+table_d = Table("d", meta, Column("y", Integer), Column("z", Integer))
+
+
+class CompareAndCopyTest(fixtures.TestBase):
+
+ # lambdas which return a tuple of ColumnElement objects.
+ # must return at least two objects that should compare differently.
+ # to test more varieties of "difference" additional objects can be added.
+ fixtures = [
+ lambda: (
+ column("q"),
+ column("x"),
+ column("q", Integer),
+ column("q", String),
+ ),
+ lambda: (~column("q", Boolean), ~column("p", Boolean)),
+ lambda: (
+ table_a.c.a.label("foo"),
+ table_a.c.a.label("bar"),
+ table_a.c.b.label("foo"),
+ ),
+ lambda: (
+ _label_reference(table_a.c.a.desc()),
+ _label_reference(table_a.c.a.asc()),
+ ),
+ lambda: (_textual_label_reference("a"), _textual_label_reference("b")),
+ lambda: (
+ text("select a, b from table").columns(a=Integer, b=String),
+ text("select a, b, c from table").columns(
+ a=Integer, b=String, c=Integer
+ ),
+ ),
+ lambda: (
+ column("q") == column("x"),
+ column("q") == column("y"),
+ column("z") == column("x"),
+ ),
+ lambda: (
+ cast(column("q"), Integer),
+ cast(column("q"), Float),
+ cast(column("p"), Integer),
+ ),
+ lambda: (
+ bindparam("x"),
+ bindparam("y"),
+ bindparam("x", type_=Integer),
+ bindparam("x", type_=String),
+ bindparam(None),
+ ),
+ lambda: (_OffsetLimitParam("x"), _OffsetLimitParam("y")),
+ lambda: (func.foo(), func.foo(5), func.bar()),
+ lambda: (func.current_date(), func.current_time()),
+ lambda: (
+ func.next_value(Sequence("q")),
+ func.next_value(Sequence("p")),
+ ),
+ lambda: (True_(), False_()),
+ lambda: (Null(),),
+ lambda: (ReturnTypeFromArgs("foo"), ReturnTypeFromArgs(5)),
+ lambda: (FunctionElement(5), FunctionElement(5, 6)),
+ lambda: (func.count(), func.not_count()),
+ lambda: (func.char_length("abc"), func.char_length("def")),
+ lambda: (GenericFunction("a", "b"), GenericFunction("a")),
+ lambda: (CollationClause("foobar"), CollationClause("batbar")),
+ lambda: (
+ type_coerce(column("q", Integer), String),
+ type_coerce(column("q", Integer), Float),
+ type_coerce(column("z", Integer), Float),
+ ),
+ lambda: (table_a.c.a, table_b.c.a),
+ lambda: (tuple_([1, 2]), tuple_([3, 4])),
+ lambda: (func.array_agg([1, 2]), func.array_agg([3, 4])),
+ lambda: (
+ func.percentile_cont(0.5).within_group(table_a.c.a),
+ func.percentile_cont(0.5).within_group(table_a.c.b),
+ func.percentile_cont(0.5).within_group(table_a.c.a, table_a.c.b),
+ func.percentile_cont(0.5).within_group(
+ table_a.c.a, table_a.c.b, column("q")
+ ),
+ ),
+ lambda: (
+ func.is_equal("a", "b").as_comparison(1, 2),
+ func.is_equal("a", "c").as_comparison(1, 2),
+ func.is_equal("a", "b").as_comparison(2, 1),
+ func.is_equal("a", "b", "c").as_comparison(1, 2),
+ func.foobar("a", "b").as_comparison(1, 2),
+ ),
+ lambda: (
+ func.row_number().over(order_by=table_a.c.a),
+ func.row_number().over(order_by=table_a.c.a, range_=(0, 10)),
+ func.row_number().over(order_by=table_a.c.a, range_=(None, 10)),
+ func.row_number().over(order_by=table_a.c.a, rows=(None, 20)),
+ func.row_number().over(order_by=table_a.c.b),
+ func.row_number().over(
+ order_by=table_a.c.a, partition_by=table_a.c.b
+ ),
+ ),
+ lambda: (
+ func.count(1).filter(table_a.c.a == 5),
+ func.count(1).filter(table_a.c.a == 10),
+ func.foob(1).filter(table_a.c.a == 10),
+ ),
+ lambda: (
+ and_(table_a.c.a == 5, table_a.c.b == table_b.c.a),
+ and_(table_a.c.a == 5, table_a.c.a == table_b.c.a),
+ or_(table_a.c.a == 5, table_a.c.b == table_b.c.a),
+ ClauseList(table_a.c.a == 5, table_a.c.b == table_b.c.a),
+ ClauseList(table_a.c.a == 5, table_a.c.b == table_a.c.a),
+ ),
+ lambda: (
+ case(whens=[(table_a.c.a == 5, 10), (table_a.c.a == 10, 20)]),
+ case(whens=[(table_a.c.a == 18, 10), (table_a.c.a == 10, 20)]),
+ case(whens=[(table_a.c.a == 5, 10), (table_a.c.b == 10, 20)]),
+ case(
+ whens=[
+ (table_a.c.a == 5, 10),
+ (table_a.c.b == 10, 20),
+ (table_a.c.a == 9, 12),
+ ]
+ ),
+ case(
+ whens=[(table_a.c.a == 5, 10), (table_a.c.a == 10, 20)],
+ else_=30,
+ ),
+ case({"wendy": "W", "jack": "J"}, value=table_a.c.a, else_="E"),
+ case({"wendy": "W", "jack": "J"}, value=table_a.c.b, else_="E"),
+ case({"wendy_w": "W", "jack": "J"}, value=table_a.c.a, else_="E"),
+ ),
+ lambda: (
+ extract("foo", table_a.c.a),
+ extract("foo", table_a.c.b),
+ extract("bar", table_a.c.a),
+ ),
+ lambda: (
+ Slice(1, 2, 5),
+ Slice(1, 5, 5),
+ Slice(1, 5, 10),
+ Slice(2, 10, 15),
+ ),
+ lambda: (
+ select([table_a.c.a]),
+ select([table_a.c.a, table_a.c.b]),
+ select([table_a.c.b, table_a.c.a]),
+ select([table_a.c.a]).where(table_a.c.b == 5),
+ select([table_a.c.a])
+ .where(table_a.c.b == 5)
+ .where(table_a.c.a == 10),
+ select([table_a.c.a]).where(table_a.c.b == 5).with_for_update(),
+ select([table_a.c.a])
+ .where(table_a.c.b == 5)
+ .with_for_update(nowait=True),
+ select([table_a.c.a]).where(table_a.c.b == 5).correlate(table_b),
+ select([table_a.c.a])
+ .where(table_a.c.b == 5)
+ .correlate_except(table_b),
+ ),
+ lambda: (
+ table_a.join(table_b, table_a.c.a == table_b.c.a),
+ table_a.join(
+ table_b, and_(table_a.c.a == table_b.c.a, table_a.c.b == 1)
+ ),
+ table_a.outerjoin(table_b, table_a.c.a == table_b.c.a),
+ ),
+ lambda: (
+ table_a.alias("a"),
+ table_a.alias("b"),
+ table_a.alias(),
+ table_b.alias("a"),
+ select([table_a.c.a]).alias("a"),
+ ),
+ lambda: (
+ FromGrouping(table_a.alias("a")),
+ FromGrouping(table_a.alias("b")),
+ ),
+ lambda: (
+ select([table_a.c.a]).as_scalar(),
+ select([table_a.c.a]).where(table_a.c.b == 5).as_scalar(),
+ ),
+ lambda: (
+ exists().where(table_a.c.a == 5),
+ exists().where(table_a.c.b == 5),
+ ),
+ lambda: (
+ union(select([table_a.c.a]), select([table_a.c.b])),
+ union(select([table_a.c.a]), select([table_a.c.b])).order_by("a"),
+ union_all(select([table_a.c.a]), select([table_a.c.b])),
+ union(select([table_a.c.a])),
+ union(
+ select([table_a.c.a]),
+ select([table_a.c.b]).where(table_a.c.b > 5),
+ ),
+ ),
+ lambda: (
+ table("a", column("x"), column("y")),
+ table("a", column("y"), column("x")),
+ table("b", column("x"), column("y")),
+ table("a", column("x"), column("y"), column("z")),
+ table("a", column("x"), column("y", Integer)),
+ table("a", column("q"), column("y", Integer)),
+ ),
+ lambda: (
+ Table("a", MetaData(), Column("q", Integer), Column("b", String)),
+ Table("b", MetaData(), Column("q", Integer), Column("b", String)),
+ ),
+ ]
+
+ @classmethod
+ def setup_class(cls):
+ # TODO: we need to get dialects here somehow, perhaps in test_suite?
+ [
+ importlib.import_module("sqlalchemy.dialects.%s" % d)
+ for d in dialects.__all__
+ if not d.startswith("_")
+ ]
+
+ def test_all_present(self):
+ need = set(
+ cls
+ for cls in class_hierarchy(ClauseElement)
+ if issubclass(cls, (ColumnElement, Selectable))
+ and "__init__" in cls.__dict__
+ and not issubclass(cls, (Annotated))
+ and "orm" not in cls.__module__
+ and "crud" not in cls.__module__
+ and "dialects" not in cls.__module__ # TODO: dialects?
+ ).difference({ColumnElement, UnaryExpression})
+ for fixture in self.fixtures:
+ case_a = fixture()
+ for elem in case_a:
+ for mro in type(elem).__mro__:
+ need.discard(mro)
+
+ is_false(bool(need), "%d Remaining classes: %r" % (len(need), need))
+
+ def test_compare(self):
+ for fixture in self.fixtures:
+ case_a = fixture()
+ case_b = fixture()
+
+ for a, b in itertools.combinations_with_replacement(
+ range(len(case_a)), 2
+ ):
+ if a == b:
+ is_true(
+ case_a[a].compare(
+ case_b[b], arbitrary_expression=True
+ ),
+ "%r != %r" % (case_a[a], case_b[b]),
+ )
+
+ else:
+ is_false(
+ case_a[a].compare(
+ case_b[b], arbitrary_expression=True
+ ),
+ "%r == %r" % (case_a[a], case_b[b]),
+ )
+
+ def test_compare_col_identity(self):
+ stmt1 = (
+ select([table_a.c.a, table_b.c.b])
+ .where(table_a.c.a == table_b.c.b)
+ .alias()
+ )
+ stmt1_c = (
+ select([table_a.c.a, table_b.c.b])
+ .where(table_a.c.a == table_b.c.b)
+ .alias()
+ )
+
+ stmt2 = union(select([table_a]), select([table_b]))
+
+ stmt3 = select([table_b])
+
+ equivalents = {table_a.c.a: [table_b.c.a]}
+
+ is_false(
+ stmt1.compare(stmt2, use_proxies=True, equivalents=equivalents)
+ )
+
+ is_true(
+ stmt1.compare(stmt1_c, use_proxies=True, equivalents=equivalents)
+ )
+ is_true(
+ (table_a.c.a == table_b.c.b).compare(
+ stmt1.c.a == stmt1.c.b,
+ use_proxies=True,
+ equivalents=equivalents,
+ )
+ )
+
+ def test_copy_internals(self):
+ for fixture in self.fixtures:
+ case_a = fixture()
+ case_b = fixture()
+
+ assert case_a[0].compare(case_b[0])
+
+ clone = case_a[0]._clone()
+ clone._copy_internals()
+
+ assert clone.compare(case_b[0])
+
+ stack = [clone]
+ seen = {clone}
+ found_elements = False
+ while stack:
+ obj = stack.pop(0)
+
+ items = [
+ subelem
+ for key, elem in clone.__dict__.items()
+ if key != "_is_clone_of" and elem is not None
+ for subelem in util.to_list(elem)
+ if (
+ isinstance(subelem, (ColumnElement, ClauseList))
+ and subelem not in seen
+ and not isinstance(subelem, Immutable)
+ and subelem is not case_a[0]
+ )
+ ]
+ stack.extend(items)
+ seen.update(items)
+
+ if obj is not clone:
+ found_elements = True
+ # ensure the element will not compare as true
+ obj.compare = lambda other, **kw: False
+ obj.__visit_name__ = "dont_match"
+
+ if found_elements:
+ assert not clone.compare(case_b[0])
+ assert case_a[0].compare(case_b[0])
+
+
+class CompareClausesTest(fixtures.TestBase):
+ def test_compare_comparison_associative(self):
+
+ l1 = table_c.c.x == table_d.c.y
+ l2 = table_d.c.y == table_c.c.x
+ l3 = table_c.c.x == table_d.c.z
+
+ is_true(l1.compare(l1))
+ is_true(l1.compare(l2))
+ is_false(l1.compare(l3))
+
+ def test_compare_clauselist_associative(self):
+
+ l1 = and_(table_c.c.x == table_d.c.y, table_c.c.y == table_d.c.z)
+
+ l2 = and_(table_c.c.y == table_d.c.z, table_c.c.x == table_d.c.y)
+
+ l3 = and_(table_c.c.x == table_d.c.z, table_c.c.y == table_d.c.y)
+
+ is_true(l1.compare(l1))
+ is_true(l1.compare(l2))
+ is_false(l1.compare(l3))
+
+ def test_compare_clauselist_not_associative(self):
+
+ l1 = ClauseList(
+ table_c.c.x, table_c.c.y, table_d.c.y, operator=operators.sub
+ )
+
+ l2 = ClauseList(
+ table_d.c.y, table_c.c.x, table_c.c.y, operator=operators.sub
+ )
+
+ is_true(l1.compare(l1))
+ is_false(l1.compare(l2))
+
+ def test_compare_clauselist_assoc_different_operator(self):
+
+ l1 = and_(table_c.c.x == table_d.c.y, table_c.c.y == table_d.c.z)
+
+ l2 = or_(table_c.c.y == table_d.c.z, table_c.c.x == table_d.c.y)
+
+ is_false(l1.compare(l2))
+
+ def test_compare_clauselist_not_assoc_different_operator(self):
+
+ l1 = ClauseList(
+ table_c.c.x, table_c.c.y, table_d.c.y, operator=operators.sub
+ )
+
+ l2 = ClauseList(
+ table_c.c.x, table_c.c.y, table_d.c.y, operator=operators.div
+ )
+
+ is_false(l1.compare(l2))
+
+ def test_compare_binds(self):
+ b1 = bindparam("foo", type_=Integer())
+ b2 = bindparam("foo", type_=Integer())
+ b3 = bindparam("bar", type_=Integer())
+ b4 = bindparam("foo", type_=String())
+
+ def c1():
+ return 5
+
+ def c2():
+ return 6
+
+ b5 = bindparam("foo", type_=Integer(), callable_=c1)
+ b6 = bindparam("foo", type_=Integer(), callable_=c2)
+ b7 = bindparam("foo", type_=Integer(), callable_=c1)
+
+ b8 = bindparam("foo", type_=Integer, value=5)
+ b9 = bindparam("foo", type_=Integer, value=6)
+
+ is_false(b1.compare(b5))
+ is_true(b5.compare(b7))
+ is_false(b5.compare(b6))
+ is_true(b1.compare(b2))
+
+ # currently not comparing "key", as we often have to compare
+ # anonymous names. however we should really check for that
+ # is_true(b1.compare(b3))
+
+ is_false(b1.compare(b4))
+ is_false(b1.compare(b8))
+ is_false(b8.compare(b9))
+ is_true(b8.compare(b8))
+
+ def test_compare_tables(self):
+ is_true(table_a.compare(table_a_2))
+
+ # the "proxy" version compares schema tables on metadata identity
+ is_false(table_a.compare(table_a_2, use_proxies=True))
+
+ # same for lower case tables since it compares lower case columns
+ # using proxies, which makes it very unlikely to have multiple
+ # table() objects with columns that compare equally
+ is_false(
+ table("a", column("x", Integer), column("q", String)).compare(
+ table("a", column("x", Integer), column("q", String)),
+ use_proxies=True,
+ )
+ )
from sqlalchemy.sql import table
from sqlalchemy.sql import true
from sqlalchemy.sql.elements import _literal_as_text
+from sqlalchemy.sql.elements import BindParameter
from sqlalchemy.sql.elements import Label
from sqlalchemy.sql.expression import BinaryExpression
from sqlalchemy.sql.expression import ClauseList
assert left.comparator.operate(operators.in_op, [1, 2, 3]).compare(
BinaryExpression(
left,
- Grouping(ClauseList(literal(1), literal(2), literal(3))),
+ Grouping(
+ ClauseList(
+ BindParameter("left", value=1, unique=True),
+ BindParameter("left", value=2, unique=True),
+ BindParameter("left", value=3, unique=True),
+ )
+ ),
operators.in_op,
)
)
assert left.comparator.operate(operators.notin_op, [1, 2, 3]).compare(
BinaryExpression(
left,
- Grouping(ClauseList(literal(1), literal(2), literal(3))),
+ Grouping(
+ ClauseList(
+ BindParameter("left", value=1, unique=True),
+ BindParameter("left", value=2, unique=True),
+ BindParameter("left", value=3, unique=True),
+ )
+ ),
operators.notin_op,
)
)
-from sqlalchemy import and_
-from sqlalchemy import bindparam
-from sqlalchemy import Column
-from sqlalchemy import Integer
-from sqlalchemy import MetaData
-from sqlalchemy import or_
-from sqlalchemy import String
-from sqlalchemy import Table
-from sqlalchemy.sql import operators
from sqlalchemy.sql import util as sql_util
-from sqlalchemy.sql.elements import ClauseList
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
-from sqlalchemy.testing import is_false
-from sqlalchemy.testing import is_true
-
-
-class CompareClausesTest(fixtures.TestBase):
- def setup(self):
- m = MetaData()
- self.a = Table("a", m, Column("x", Integer), Column("y", Integer))
-
- self.b = Table("b", m, Column("y", Integer), Column("z", Integer))
-
- def test_compare_clauselist_associative(self):
-
- l1 = and_(self.a.c.x == self.b.c.y, self.a.c.y == self.b.c.z)
-
- l2 = and_(self.a.c.y == self.b.c.z, self.a.c.x == self.b.c.y)
-
- l3 = and_(self.a.c.x == self.b.c.z, self.a.c.y == self.b.c.y)
-
- is_true(l1.compare(l1))
- is_true(l1.compare(l2))
- is_false(l1.compare(l3))
-
- def test_compare_clauselist_not_associative(self):
-
- l1 = ClauseList(
- self.a.c.x, self.a.c.y, self.b.c.y, operator=operators.sub
- )
-
- l2 = ClauseList(
- self.b.c.y, self.a.c.x, self.a.c.y, operator=operators.sub
- )
-
- is_true(l1.compare(l1))
- is_false(l1.compare(l2))
-
- def test_compare_clauselist_assoc_different_operator(self):
-
- l1 = and_(self.a.c.x == self.b.c.y, self.a.c.y == self.b.c.z)
-
- l2 = or_(self.a.c.y == self.b.c.z, self.a.c.x == self.b.c.y)
-
- is_false(l1.compare(l2))
-
- def test_compare_clauselist_not_assoc_different_operator(self):
-
- l1 = ClauseList(
- self.a.c.x, self.a.c.y, self.b.c.y, operator=operators.sub
- )
-
- l2 = ClauseList(
- self.a.c.x, self.a.c.y, self.b.c.y, operator=operators.div
- )
-
- is_false(l1.compare(l2))
-
- def test_compare_binds(self):
- b1 = bindparam("foo", type_=Integer())
- b2 = bindparam("foo", type_=Integer())
- b3 = bindparam("bar", type_=Integer())
- b4 = bindparam("foo", type_=String())
-
- def c1():
- return 5
-
- def c2():
- return 6
-
- b5 = bindparam("foo", type_=Integer(), callable_=c1)
- b6 = bindparam("foo", type_=Integer(), callable_=c2)
- b7 = bindparam("foo", type_=Integer(), callable_=c1)
-
- b8 = bindparam("foo", type_=Integer, value=5)
- b9 = bindparam("foo", type_=Integer, value=6)
-
- is_false(b1.compare(b5))
- is_true(b5.compare(b7))
- is_false(b5.compare(b6))
- is_true(b1.compare(b2))
-
- # currently not comparing "key", as we often have to compare
- # anonymous names. however we should really check for that
- is_true(b1.compare(b3))
-
- is_false(b1.compare(b4))
- is_false(b1.compare(b8))
- is_false(b8.compare(b9))
- is_true(b8.compare(b8))
class MiscTest(fixtures.TestBase):