]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure comparison includes "don't compare values" feature
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 20 Dec 2019 15:17:17 +0000 (10:17 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 20 Dec 2019 15:17:17 +0000 (10:17 -0500)
upcoming changes for "expanding IN in all cases" and
"lambda elements" both rely upon comparisons that work
across changing bound values, so commit the testing fixture
ahead of time.   Additionally, repair the feature itself
within traversals.

Change-Id: Ie65a512dc64745614180da77435f9f745ce78c71

lib/sqlalchemy/sql/traversals.py
test/orm/test_cache_key.py
test/sql/test_compare.py

index b5701dbdf385d6810745e43278fbe7baaf79125e..84a5623d36bb8daa7a46c34b5388daa2ae8a80d8 100644 (file)
@@ -743,6 +743,14 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
         else:
             return COMPARE_FAILED
 
+    def compare_bindparam(self, left, right, **kw):
+        compare_values = kw.pop("compare_values", True)
+        if compare_values:
+            return []
+        else:
+            # this means, "skip these, we already compared"
+            return ["callable", "value"]
+
 
 class ColIdentityComparatorStrategy(TraversalComparatorStrategy):
     def compare_column_element(
index 79a94848ea9a33d8d552cfdae339673b21c6baa0..72a1f4c8ee86fb60da53843d3f4a6b9b593e6014 100644 (file)
@@ -23,7 +23,8 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
         User, Address, Keyword = self.classes("User", "Address", "Keyword")
 
         self._run_cache_key_fixture(
-            lambda: (inspect(User), inspect(Address), inspect(aliased(User)))
+            lambda: (inspect(User), inspect(Address), inspect(aliased(User))),
+            compare_values=True,
         )
 
     def test_attributes(self):
@@ -40,7 +41,8 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
                 User.addresses,
                 Address.email_address,
                 aliased(User).addresses,
-            )
+            ),
+            compare_values=True,
         )
 
     def test_unbound_options(self):
@@ -68,7 +70,8 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
                 .defer(Item.description),
                 defaultload(User.orders).defaultload(Order.items),
                 defaultload(User.orders),
-            )
+            ),
+            compare_values=True,
         )
 
     def test_bound_options(self):
@@ -94,7 +97,8 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
                 .defer(Item.description),
                 Load(User).defaultload(User.orders).defaultload(Order.items),
                 Load(User).defaultload(User.orders),
-            )
+            ),
+            compare_values=True,
         )
 
     def test_bound_options_equiv_on_strname(self):
index 520133272f1612280dba1136af5e3f9c88685d44..f8fc43ba54dafbf09f6d46e4d3fd0a0363d322a5 100644 (file)
@@ -1,5 +1,6 @@
 import importlib
 import itertools
+import random
 
 from sqlalchemy import and_
 from sqlalchemy import Boolean
@@ -55,6 +56,7 @@ from sqlalchemy.sql.visitors import InternalTraversal
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_false
+from sqlalchemy.testing import is_not_
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import ne_
 from sqlalchemy.util import class_hierarchy
@@ -121,13 +123,13 @@ class CoreFixtures(object):
                 a=Integer, b=String, c=Integer
             ),
             text("select a, b, c from table where foo=:bar").bindparams(
-                bindparam("bar", Integer)
+                bindparam("bar", type_=Integer)
             ),
             text("select a, b, c from table where foo=:foo").bindparams(
-                bindparam("foo", Integer)
+                bindparam("foo", type_=Integer)
             ),
             text("select a, b, c from table where foo=:bar").bindparams(
-                bindparam("bar", String)
+                bindparam("bar", type_=String)
             ),
         ),
         lambda: (
@@ -138,6 +140,8 @@ class CoreFixtures(object):
             column("z") - column("x"),
             column("x") - column("z"),
             column("z") > column("x"),
+            column("x").in_([5, 7]),
+            column("x").in_([10, 7, 8]),
             # note these two are mathematically equivalent but for now they
             # are considered to be different
             column("z") >= column("x"),
@@ -195,7 +199,7 @@ class CoreFixtures(object):
             type_coerce(column("z", Integer), Float),
         ),
         lambda: (table_a.c.a, table_b.c.a),
-        lambda: (tuple_([1, 2]), tuple_([3, 4])),
+        lambda: (tuple_(1, 2), tuple_(3, 4)),
         lambda: (func.array_agg([1, 2]), func.array_agg([3, 4])),
         lambda: (
             func.percentile_cont(0.5).within_group(table_a.c.a),
@@ -384,6 +388,24 @@ class CoreFixtures(object):
         lambda: (table_a, table_b),
     ]
 
+    dont_compare_values_fixtures = [
+        lambda: (
+            # same number of params each time, so compare for IN
+            # with legacy behavior of bind for each value works
+            column("x").in_(random.choices(range(10), k=3)),
+            # expanding IN places the whole list into a single parameter
+            # so it can be of arbitrary length as well
+            column("x").in_(
+                bindparam(
+                    "q",
+                    random.choices(range(10), k=random.randint(0, 7)),
+                    expanding=True,
+                )
+            ),
+            column("x") == random.randint(1, 10),
+        )
+    ]
+
     def _complex_fixtures():
         def one():
             a1 = table_a.alias()
@@ -439,7 +461,7 @@ class CoreFixtures(object):
 
 
 class CacheKeyFixture(object):
-    def _run_cache_key_fixture(self, fixture):
+    def _run_cache_key_fixture(self, fixture, compare_values):
         case_a = fixture()
         case_b = fixture()
 
@@ -449,13 +471,18 @@ class CacheKeyFixture(object):
             if a == b:
                 a_key = case_a[a]._generate_cache_key()
                 b_key = case_b[b]._generate_cache_key()
+                is_not_(a_key, None)
+                is_not_(b_key, None)
+
                 eq_(a_key.key, b_key.key)
                 eq_(hash(a_key), hash(b_key))
 
                 for a_param, b_param in zip(
                     a_key.bindparams, b_key.bindparams
                 ):
-                    assert a_param.compare(b_param, compare_values=False)
+                    assert a_param.compare(
+                        b_param, compare_values=compare_values
+                    )
             else:
                 a_key = case_a[a]._generate_cache_key()
                 b_key = case_b[b]._generate_cache_key()
@@ -464,7 +491,9 @@ class CacheKeyFixture(object):
                     for a_param, b_param in zip(
                         a_key.bindparams, b_key.bindparams
                     ):
-                        if not a_param.compare(b_param, compare_values=True):
+                        if not a_param.compare(
+                            b_param, compare_values=compare_values
+                        ):
                             break
                     else:
                         # this fails unconditionally since we could not
@@ -491,9 +520,10 @@ class CacheKeyFixture(object):
                 )
 
                 # note we're asserting the order of the params as well as
-                # if there are dupes or not.  ordering has to be deterministic
-                # and matches what a traversal would provide.
-                # regular traverse_depthfirst does produce dupes in cases like
+                # if there are dupes or not.  ordering has to be
+                # deterministic and matches what a traversal would provide.
+                # regular traverse_depthfirst does produce dupes in cases
+                # like
                 # select([some_alias]).
                 #    select_from(join(some_alias, other_table))
                 # where a bound parameter is inside of some_alias.  the
@@ -514,8 +544,12 @@ class CacheKeyFixture(object):
 
 class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase):
     def test_cache_key(self):
-        for fixture in self.fixtures:
-            self._run_cache_key_fixture(fixture)
+        for fixtures_, compare_values in [
+            (self.fixtures, True),
+            (self.dont_compare_values_fixtures, False),
+        ]:
+            for fixture in fixtures_:
+                self._run_cache_key_fixture(fixture, compare_values)
 
     def test_cache_key_unknown_traverse(self):
         class Foobar1(ClauseElement):
@@ -602,7 +636,8 @@ class CompareAndCopyTest(CoreFixtures, fixtures.TestBase):
             and "crud" not in cls.__module__
             and "dialects" not in cls.__module__  # TODO: dialects?
         ).difference({ColumnElement, UnaryExpression})
-        for fixture in self.fixtures:
+
+        for fixture in self.fixtures + self.dont_compare_values_fixtures:
             case_a = fixture()
             for elem in case_a:
                 for mro in type(elem).__mro__:
@@ -610,25 +645,37 @@ class CompareAndCopyTest(CoreFixtures, fixtures.TestBase):
 
         is_false(bool(need), "%d Remaining classes: %r" % (len(need), need))
 
-    def test_compare(self):
-        for fixture in self.fixtures:
-            case_a = fixture()
-            case_b = fixture()
-
-            for a, b in itertools.combinations_with_replacement(
-                range(len(case_a)), 2
-            ):
-                if a == b:
-                    is_true(
-                        case_a[a].compare(case_b[b], compare_annotations=True),
-                        "%r != %r" % (case_a[a], case_b[b]),
-                    )
+    def test_compare_labels(self):
+        for fixtures_, compare_values in [
+            (self.fixtures, True),
+            (self.dont_compare_values_fixtures, False),
+        ]:
+            for fixture in fixtures_:
+                case_a = fixture()
+                case_b = fixture()
+
+                for a, b in itertools.combinations_with_replacement(
+                    range(len(case_a)), 2
+                ):
+                    if a == b:
+                        is_true(
+                            case_a[a].compare(
+                                case_b[b],
+                                compare_annotations=True,
+                                compare_values=compare_values,
+                            ),
+                            "%r != %r" % (case_a[a], case_b[b]),
+                        )
 
-                else:
-                    is_false(
-                        case_a[a].compare(case_b[b], compare_annotations=True),
-                        "%r == %r" % (case_a[a], case_b[b]),
-                    )
+                    else:
+                        is_false(
+                            case_a[a].compare(
+                                case_b[b],
+                                compare_annotations=True,
+                                compare_values=compare_values,
+                            ),
+                            "%r == %r" % (case_a[a], case_b[b]),
+                        )
 
     def test_compare_col_identity(self):
         stmt1 = (
@@ -662,48 +709,58 @@ class CompareAndCopyTest(CoreFixtures, fixtures.TestBase):
         )
 
     def test_copy_internals(self):
-        for fixture in self.fixtures:
-            case_a = fixture()
-            case_b = fixture()
-
-            assert case_a[0].compare(case_b[0])
+        for fixtures_, compare_values in [
+            (self.fixtures, True),
+            (self.dont_compare_values_fixtures, False),
+        ]:
+            for fixture in fixtures_:
+                case_a = fixture()
+                case_b = fixture()
+
+                assert case_a[0].compare(
+                    case_b[0], compare_values=compare_values
+                )
 
-            clone = visitors.replacement_traverse(
-                case_a[0], {}, lambda elem: None
-            )
+                clone = visitors.replacement_traverse(
+                    case_a[0], {}, lambda elem: None
+                )
 
-            assert clone.compare(case_b[0])
-
-            stack = [clone]
-            seen = {clone}
-            found_elements = False
-            while stack:
-                obj = stack.pop(0)
-
-                items = [
-                    subelem
-                    for key, elem in clone.__dict__.items()
-                    if key != "_is_clone_of" and elem is not None
-                    for subelem in util.to_list(elem)
-                    if (
-                        isinstance(subelem, (ColumnElement, ClauseList))
-                        and subelem not in seen
-                        and not isinstance(subelem, Immutable)
-                        and subelem is not case_a[0]
+                assert clone.compare(case_b[0], compare_values=compare_values)
+
+                stack = [clone]
+                seen = {clone}
+                found_elements = False
+                while stack:
+                    obj = stack.pop(0)
+
+                    items = [
+                        subelem
+                        for key, elem in clone.__dict__.items()
+                        if key != "_is_clone_of" and elem is not None
+                        for subelem in util.to_list(elem)
+                        if (
+                            isinstance(subelem, (ColumnElement, ClauseList))
+                            and subelem not in seen
+                            and not isinstance(subelem, Immutable)
+                            and subelem is not case_a[0]
+                        )
+                    ]
+                    stack.extend(items)
+                    seen.update(items)
+
+                    if obj is not clone:
+                        found_elements = True
+                        # ensure the element will not compare as true
+                        obj.compare = lambda other, **kw: False
+                        obj.__visit_name__ = "dont_match"
+
+                if found_elements:
+                    assert not clone.compare(
+                        case_b[0], compare_values=compare_values
                     )
-                ]
-                stack.extend(items)
-                seen.update(items)
-
-                if obj is not clone:
-                    found_elements = True
-                    # ensure the element will not compare as true
-                    obj.compare = lambda other, **kw: False
-                    obj.__visit_name__ = "dont_match"
-
-            if found_elements:
-                assert not clone.compare(case_b[0])
-            assert case_a[0].compare(case_b[0])
+                assert case_a[0].compare(
+                    case_b[0], compare_values=compare_values
+                )
 
 
 class CompareClausesTest(fixtures.TestBase):