From: Mike Bayer Date: Fri, 20 Dec 2019 15:17:17 +0000 (-0500) Subject: Ensure comparison includes "don't compare values" feature X-Git-Tag: rel_1_4_0b1~588 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e6afc0a8cf7a8fb18855cab9da488a0d48c42386;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Ensure comparison includes "don't compare values" feature upcoming changes for "expanding IN in all cases" and "lambda elements" both rely upon comparisons that work across changing bound values, so commit the testing fixture ahead of time. Additionally, repair the feature itself within traversals. Change-Id: Ie65a512dc64745614180da77435f9f745ce78c71 --- diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index b5701dbdf3..84a5623d36 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -743,6 +743,14 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): else: return COMPARE_FAILED + def compare_bindparam(self, left, right, **kw): + compare_values = kw.pop("compare_values", True) + if compare_values: + return [] + else: + # this means, "skip these, we already compared" + return ["callable", "value"] + class ColIdentityComparatorStrategy(TraversalComparatorStrategy): def compare_column_element( diff --git a/test/orm/test_cache_key.py b/test/orm/test_cache_key.py index 79a94848ea..72a1f4c8ee 100644 --- a/test/orm/test_cache_key.py +++ b/test/orm/test_cache_key.py @@ -23,7 +23,8 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest): User, Address, Keyword = self.classes("User", "Address", "Keyword") self._run_cache_key_fixture( - lambda: (inspect(User), inspect(Address), inspect(aliased(User))) + lambda: (inspect(User), inspect(Address), inspect(aliased(User))), + compare_values=True, ) def test_attributes(self): @@ -40,7 +41,8 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest): User.addresses, Address.email_address, aliased(User).addresses, - ) + ), + compare_values=True, ) def test_unbound_options(self): @@ -68,7 +70,8 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest): .defer(Item.description), defaultload(User.orders).defaultload(Order.items), defaultload(User.orders), - ) + ), + compare_values=True, ) def test_bound_options(self): @@ -94,7 +97,8 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest): .defer(Item.description), Load(User).defaultload(User.orders).defaultload(Order.items), Load(User).defaultload(User.orders), - ) + ), + compare_values=True, ) def test_bound_options_equiv_on_strname(self): diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 520133272f..f8fc43ba54 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -1,5 +1,6 @@ import importlib import itertools +import random from sqlalchemy import and_ from sqlalchemy import Boolean @@ -55,6 +56,7 @@ from sqlalchemy.sql.visitors import InternalTraversal from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_false +from sqlalchemy.testing import is_not_ from sqlalchemy.testing import is_true from sqlalchemy.testing import ne_ from sqlalchemy.util import class_hierarchy @@ -121,13 +123,13 @@ class CoreFixtures(object): a=Integer, b=String, c=Integer ), text("select a, b, c from table where foo=:bar").bindparams( - bindparam("bar", Integer) + bindparam("bar", type_=Integer) ), text("select a, b, c from table where foo=:foo").bindparams( - bindparam("foo", Integer) + bindparam("foo", type_=Integer) ), text("select a, b, c from table where foo=:bar").bindparams( - bindparam("bar", String) + bindparam("bar", type_=String) ), ), lambda: ( @@ -138,6 +140,8 @@ class CoreFixtures(object): column("z") - column("x"), column("x") - column("z"), column("z") > column("x"), + column("x").in_([5, 7]), + column("x").in_([10, 7, 8]), # note these two are mathematically equivalent but for now they # are considered to be different column("z") >= column("x"), @@ -195,7 +199,7 @@ class CoreFixtures(object): type_coerce(column("z", Integer), Float), ), lambda: (table_a.c.a, table_b.c.a), - lambda: (tuple_([1, 2]), tuple_([3, 4])), + lambda: (tuple_(1, 2), tuple_(3, 4)), lambda: (func.array_agg([1, 2]), func.array_agg([3, 4])), lambda: ( func.percentile_cont(0.5).within_group(table_a.c.a), @@ -384,6 +388,24 @@ class CoreFixtures(object): lambda: (table_a, table_b), ] + dont_compare_values_fixtures = [ + lambda: ( + # same number of params each time, so compare for IN + # with legacy behavior of bind for each value works + column("x").in_(random.choices(range(10), k=3)), + # expanding IN places the whole list into a single parameter + # so it can be of arbitrary length as well + column("x").in_( + bindparam( + "q", + random.choices(range(10), k=random.randint(0, 7)), + expanding=True, + ) + ), + column("x") == random.randint(1, 10), + ) + ] + def _complex_fixtures(): def one(): a1 = table_a.alias() @@ -439,7 +461,7 @@ class CoreFixtures(object): class CacheKeyFixture(object): - def _run_cache_key_fixture(self, fixture): + def _run_cache_key_fixture(self, fixture, compare_values): case_a = fixture() case_b = fixture() @@ -449,13 +471,18 @@ class CacheKeyFixture(object): if a == b: a_key = case_a[a]._generate_cache_key() b_key = case_b[b]._generate_cache_key() + is_not_(a_key, None) + is_not_(b_key, None) + eq_(a_key.key, b_key.key) eq_(hash(a_key), hash(b_key)) for a_param, b_param in zip( a_key.bindparams, b_key.bindparams ): - assert a_param.compare(b_param, compare_values=False) + assert a_param.compare( + b_param, compare_values=compare_values + ) else: a_key = case_a[a]._generate_cache_key() b_key = case_b[b]._generate_cache_key() @@ -464,7 +491,9 @@ class CacheKeyFixture(object): for a_param, b_param in zip( a_key.bindparams, b_key.bindparams ): - if not a_param.compare(b_param, compare_values=True): + if not a_param.compare( + b_param, compare_values=compare_values + ): break else: # this fails unconditionally since we could not @@ -491,9 +520,10 @@ class CacheKeyFixture(object): ) # note we're asserting the order of the params as well as - # if there are dupes or not. ordering has to be deterministic - # and matches what a traversal would provide. - # regular traverse_depthfirst does produce dupes in cases like + # if there are dupes or not. ordering has to be + # deterministic and matches what a traversal would provide. + # regular traverse_depthfirst does produce dupes in cases + # like # select([some_alias]). # select_from(join(some_alias, other_table)) # where a bound parameter is inside of some_alias. the @@ -514,8 +544,12 @@ class CacheKeyFixture(object): class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase): def test_cache_key(self): - for fixture in self.fixtures: - self._run_cache_key_fixture(fixture) + for fixtures_, compare_values in [ + (self.fixtures, True), + (self.dont_compare_values_fixtures, False), + ]: + for fixture in fixtures_: + self._run_cache_key_fixture(fixture, compare_values) def test_cache_key_unknown_traverse(self): class Foobar1(ClauseElement): @@ -602,7 +636,8 @@ class CompareAndCopyTest(CoreFixtures, fixtures.TestBase): and "crud" not in cls.__module__ and "dialects" not in cls.__module__ # TODO: dialects? ).difference({ColumnElement, UnaryExpression}) - for fixture in self.fixtures: + + for fixture in self.fixtures + self.dont_compare_values_fixtures: case_a = fixture() for elem in case_a: for mro in type(elem).__mro__: @@ -610,25 +645,37 @@ class CompareAndCopyTest(CoreFixtures, fixtures.TestBase): is_false(bool(need), "%d Remaining classes: %r" % (len(need), need)) - def test_compare(self): - for fixture in self.fixtures: - case_a = fixture() - case_b = fixture() - - for a, b in itertools.combinations_with_replacement( - range(len(case_a)), 2 - ): - if a == b: - is_true( - case_a[a].compare(case_b[b], compare_annotations=True), - "%r != %r" % (case_a[a], case_b[b]), - ) + def test_compare_labels(self): + for fixtures_, compare_values in [ + (self.fixtures, True), + (self.dont_compare_values_fixtures, False), + ]: + for fixture in fixtures_: + case_a = fixture() + case_b = fixture() + + for a, b in itertools.combinations_with_replacement( + range(len(case_a)), 2 + ): + if a == b: + is_true( + case_a[a].compare( + case_b[b], + compare_annotations=True, + compare_values=compare_values, + ), + "%r != %r" % (case_a[a], case_b[b]), + ) - else: - is_false( - case_a[a].compare(case_b[b], compare_annotations=True), - "%r == %r" % (case_a[a], case_b[b]), - ) + else: + is_false( + case_a[a].compare( + case_b[b], + compare_annotations=True, + compare_values=compare_values, + ), + "%r == %r" % (case_a[a], case_b[b]), + ) def test_compare_col_identity(self): stmt1 = ( @@ -662,48 +709,58 @@ class CompareAndCopyTest(CoreFixtures, fixtures.TestBase): ) def test_copy_internals(self): - for fixture in self.fixtures: - case_a = fixture() - case_b = fixture() - - assert case_a[0].compare(case_b[0]) + for fixtures_, compare_values in [ + (self.fixtures, True), + (self.dont_compare_values_fixtures, False), + ]: + for fixture in fixtures_: + case_a = fixture() + case_b = fixture() + + assert case_a[0].compare( + case_b[0], compare_values=compare_values + ) - clone = visitors.replacement_traverse( - case_a[0], {}, lambda elem: None - ) + clone = visitors.replacement_traverse( + case_a[0], {}, lambda elem: None + ) - assert clone.compare(case_b[0]) - - stack = [clone] - seen = {clone} - found_elements = False - while stack: - obj = stack.pop(0) - - items = [ - subelem - for key, elem in clone.__dict__.items() - if key != "_is_clone_of" and elem is not None - for subelem in util.to_list(elem) - if ( - isinstance(subelem, (ColumnElement, ClauseList)) - and subelem not in seen - and not isinstance(subelem, Immutable) - and subelem is not case_a[0] + assert clone.compare(case_b[0], compare_values=compare_values) + + stack = [clone] + seen = {clone} + found_elements = False + while stack: + obj = stack.pop(0) + + items = [ + subelem + for key, elem in clone.__dict__.items() + if key != "_is_clone_of" and elem is not None + for subelem in util.to_list(elem) + if ( + isinstance(subelem, (ColumnElement, ClauseList)) + and subelem not in seen + and not isinstance(subelem, Immutable) + and subelem is not case_a[0] + ) + ] + stack.extend(items) + seen.update(items) + + if obj is not clone: + found_elements = True + # ensure the element will not compare as true + obj.compare = lambda other, **kw: False + obj.__visit_name__ = "dont_match" + + if found_elements: + assert not clone.compare( + case_b[0], compare_values=compare_values ) - ] - stack.extend(items) - seen.update(items) - - if obj is not clone: - found_elements = True - # ensure the element will not compare as true - obj.compare = lambda other, **kw: False - obj.__visit_name__ = "dont_match" - - if found_elements: - assert not clone.compare(case_b[0]) - assert case_a[0].compare(case_b[0]) + assert case_a[0].compare( + case_b[0], compare_values=compare_values + ) class CompareClausesTest(fixtures.TestBase):