]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Repair caching / traversals for values
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 1 Apr 2020 22:31:16 +0000 (18:31 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 1 Apr 2020 23:25:23 +0000 (19:25 -0400)
The test suite wasn't running the copy_internals most fixtures,
enable that and try to get all cases working.

Set up selectable.values to do tuple conversion within compilation
step.  at the same time, disable caching for selectable.values
for the moment and make it equivalent to dml_multi_values.

fix cache / compare / copy cases for dml_values and dml_multi_values
which weren't fully tested or covered.

Change-Id: I484ca6e9cb2b66c2e6a321698f2abc0838db1460

lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/sql/traversals.py
test/sql/test_compare.py

index 799fca2f583441b924b44ed1260acd934ced7ce9..b93ed88905586ebcfd3ed7df321cacc5d8373d0f 100644 (file)
@@ -2324,9 +2324,14 @@ class SQLCompiler(Compiled):
         return text
 
     def visit_values(self, element, asfrom=False, from_linter=None, **kw):
+
         v = "VALUES %s" % ", ".join(
-            self.process(elem, literal_binds=element.literal_binds)
-            for elem in element._data
+            self.process(
+                elements.Tuple(*elem).self_group(),
+                literal_binds=element.literal_binds,
+            )
+            for chunk in element._data
+            for elem in chunk
         )
 
         if isinstance(element.name, elements._truncated_label):
index e39d61fdbb4602afdf07a9511833959edc280e3a..a0df45b52746d87f1d3dca8c95f383b118aa1389 100644 (file)
@@ -47,7 +47,6 @@ from .elements import ColumnClause
 from .elements import GroupedElement
 from .elements import Grouping
 from .elements import literal_column
-from .elements import Tuple
 from .elements import UnaryExpression
 from .visitors import InternalTraversal
 from .. import exc
@@ -1264,14 +1263,16 @@ class AliasedReturnsRows(NoInit, FromClause):
         self.element._generate_fromclause_column_proxies(self)
 
     def _copy_internals(self, clone=_clone, **kw):
-        element = clone(self.element, **kw)
+        existing_element = self.element
+
+        super(AliasedReturnsRows, self)._copy_internals(clone=clone, **kw)
 
         # the element clone is usually against a Table that returns the
         # same object.  don't reset exported .c. collections and other
-        # memoized details if nothing changed
-        if element is not self.element:
+        # memoized details if it was not changed.  this saves a lot on
+        # performance.
+        if existing_element is not self.element:
             self._reset_column_collection()
-            self.element = element
 
     @property
     def _from_objects(self):
@@ -1528,15 +1529,6 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows):
             self._suffixes = _suffixes
         super(CTE, self)._init(selectable, name=name)
 
-    def _copy_internals(self, clone=_clone, **kw):
-        super(CTE, self)._copy_internals(clone, **kw)
-        # TODO: I don't like that we can't use the traversal data here
-        if self._cte_alias is not None:
-            self._cte_alias = clone(self._cte_alias, **kw)
-        self._restates = frozenset(
-            [clone(elem, **kw) for elem in self._restates]
-        )
-
     def alias(self, name=None, flat=False):
         """Return an :class:`.Alias` of this :class:`.CTE`.
 
@@ -2064,7 +2056,7 @@ class Values(Generative, FromClause):
 
     _traverse_internals = [
         ("_column_args", InternalTraversal.dp_clauseelement_list,),
-        ("_data", InternalTraversal.dp_clauseelement_list),
+        ("_data", InternalTraversal.dp_dml_multi_values),
         ("name", InternalTraversal.dp_string),
         ("literal_binds", InternalTraversal.dp_boolean),
     ]
@@ -2155,7 +2147,7 @@ class Values(Generative, FromClause):
 
         """
 
-        self._data += tuple(Tuple(*row).self_group() for row in values)
+        self._data += (values,)
 
     def _populate_column_collection(self):
         for c in self._column_args:
index 9ac6cda978025e689358753e5c09f84956fb3b95..032488826dcbc5f0d3ea38d5fdea588db81a0ace 100644 (file)
@@ -7,6 +7,7 @@ from .visitors import ExtendedInternalTraversal
 from .visitors import InternalTraversal
 from .. import util
 from ..inspection import inspect
+from ..util import collections_abc
 from ..util import HasMemoized
 
 SKIP_TRAVERSE = util.symbol("skip_traverse")
@@ -533,18 +534,12 @@ class _CopyInternals(InternalTraversal):
         ]
 
     def visit_dml_values(self, parent, element, clone=_clone, **kw):
-        # sequence of dictionaries
-        return [
-            {
-                (
-                    clone(key, **kw)
-                    if hasattr(key, "__clause_element__")
-                    else key
-                ): clone(value, **kw)
-                for key, value in sub_element.items()
-            }
-            for sub_element in element
-        ]
+        return {
+            (
+                clone(key, **kw) if hasattr(key, "__clause_element__") else key
+            ): clone(value, **kw)
+            for key, value in element.items()
+        }
 
     def visit_dml_multi_values(self, parent, element, clone=_clone, **kw):
         # sequence of sequences, each sequence contains a list/dict/tuple
@@ -552,15 +547,10 @@ class _CopyInternals(InternalTraversal):
         def copy(elem):
             if isinstance(elem, (list, tuple)):
                 return [
-                    (
-                        clone(key, **kw)
-                        if hasattr(key, "__clause_element__")
-                        else key,
-                        clone(value, **kw)
-                        if hasattr(value, "__clause_element__")
-                        else value,
-                    )
-                    for key, value in elem
+                    clone(value, **kw)
+                    if hasattr(value, "__clause_element__")
+                    else value
+                    for value in elem
                 ]
             elif isinstance(elem, dict):
                 return {
@@ -573,7 +563,7 @@ class _CopyInternals(InternalTraversal):
                         if hasattr(value, "__clause_element__")
                         else value
                     )
-                    for key, value in elem
+                    for key, value in elem.items()
                 }
             else:
                 # TODO: use abc classes
@@ -939,30 +929,41 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
         for (lk, lv), (rk, rv) in util.zip_longest(
             left, right, fillvalue=(None, None)
         ):
-            lkce = hasattr(lk, "__clause_element__")
-            rkce = hasattr(rk, "__clause_element__")
-            if lkce != rkce:
-                return COMPARE_FAILED
-            elif lkce and not self.compare_inner(lk, rk, **kw):
-                return COMPARE_FAILED
-            elif not lkce and lk != rk:
-                return COMPARE_FAILED
-            elif not self.compare_inner(lv, rv, **kw):
+            if not self._compare_dml_values_or_ce(lk, rk, **kw):
                 return COMPARE_FAILED
 
+    def _compare_dml_values_or_ce(self, lv, rv, **kw):
+        lvce = hasattr(lv, "__clause_element__")
+        rvce = hasattr(rv, "__clause_element__")
+        if lvce != rvce:
+            return False
+        elif lvce and not self.compare_inner(lv, rv, **kw):
+            return False
+        elif not lvce and lv != rv:
+            return False
+        elif not self.compare_inner(lv, rv, **kw):
+            return False
+
+        return True
+
     def visit_dml_values(self, left_parent, left, right_parent, right, **kw):
         if left is None or right is None or len(left) != len(right):
             return COMPARE_FAILED
 
-        for lk in left:
-            lv = left[lk]
+        if isinstance(left, collections_abc.Sequence):
+            for lv, rv in zip(left, right):
+                if not self._compare_dml_values_or_ce(lv, rv, **kw):
+                    return COMPARE_FAILED
+        else:
+            for lk in left:
+                lv = left[lk]
 
-            if lk not in right:
-                return COMPARE_FAILED
-            rv = right[lk]
+                if lk not in right:
+                    return COMPARE_FAILED
+                rv = right[lk]
 
-            if not self.compare_inner(lv, rv, **kw):
-                return COMPARE_FAILED
+                if not self._compare_dml_values_or_ce(lv, rv, **kw):
+                    return COMPARE_FAILED
 
     def visit_dml_multi_values(
         self, left_parent, left, right_parent, right, **kw
index 2800f8248c823ca653226b62cd4b8cb62b6517fc..3a6feac018e48baf5e642c6ae108ddf4da706852 100644 (file)
@@ -55,11 +55,13 @@ from sqlalchemy.sql.functions import ReturnTypeFromArgs
 from sqlalchemy.sql.selectable import _OffsetLimitParam
 from sqlalchemy.sql.selectable import AliasedReturnsRows
 from sqlalchemy.sql.selectable import FromGrouping
+from sqlalchemy.sql.selectable import Select
 from sqlalchemy.sql.selectable import Selectable
 from sqlalchemy.sql.selectable import SelectStatementGrouping
 from sqlalchemy.sql.visitors import InternalTraversal
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_not_
 from sqlalchemy.testing import is_true
@@ -372,6 +374,15 @@ class CoreFixtures(object):
             table_b.insert()
             .values([{"a": 5, "b": 10}, {"a": 8, "b": 12}])
             ._annotate({"nocache": True}),
+            table_b.insert()
+            .values([{"a": 9, "b": 10}, {"a": 8, "b": 7}])
+            ._annotate({"nocache": True}),
+            table_b.insert()
+            .values([(5, 10), (8, 12)])
+            ._annotate({"nocache": True}),
+            table_b.insert()
+            .values([(5, 9), (5, 12)])
+            ._annotate({"nocache": True}),
         ),
         lambda: (
             table_b.update(),
@@ -404,6 +415,51 @@ class CoreFixtures(object):
             table_b.delete().where(table_b.c.a == 5),
             table_b.delete().where(table_b.c.b == 5),
         ),
+        lambda: (
+            values(
+                column("mykey", Integer),
+                column("mytext", String),
+                column("myint", Integer),
+                name="myvalues",
+            )
+            .data([(1, "textA", 99), (2, "textB", 88)])
+            ._annotate({"nocache": True}),
+            values(
+                column("mykey", Integer),
+                column("mytext", String),
+                column("myint", Integer),
+                name="myothervalues",
+            )
+            .data([(1, "textA", 99), (2, "textB", 88)])
+            ._annotate({"nocache": True}),
+            values(
+                column("mykey", Integer),
+                column("mytext", String),
+                column("myint", Integer),
+                name="myvalues",
+            )
+            .data([(1, "textA", 89), (2, "textG", 88)])
+            ._annotate({"nocache": True}),
+            values(
+                column("mykey", Integer),
+                column("mynottext", String),
+                column("myint", Integer),
+                name="myvalues",
+            )
+            .data([(1, "textA", 99), (2, "textB", 88)])
+            ._annotate({"nocache": True}),
+            # TODO: difference in type
+            # values(
+            #    [
+            #        column("mykey", Integer),
+            #        column("mytext", Text),
+            #        column("myint", Integer),
+            #    ],
+            #    (1, "textA", 99),
+            #    (2, "textB", 88),
+            #    alias_name="myvalues",
+            # ),
+        ),
         lambda: (
             select([table_a.c.a]),
             select([table_a.c.a]).prefix_with("foo"),
@@ -482,43 +538,6 @@ class CoreFixtures(object):
             table("a", column("q"), column("y", Integer)),
         ),
         lambda: (table_a, table_b),
-        lambda: (
-            values(
-                column("mykey", Integer),
-                column("mytext", String),
-                column("myint", Integer),
-                name="myvalues",
-            ).data([(1, "textA", 99), (2, "textB", 88)]),
-            values(
-                column("mykey", Integer),
-                column("mytext", String),
-                column("myint", Integer),
-                name="myothervalues",
-            ).data([(1, "textA", 99), (2, "textB", 88)]),
-            values(
-                column("mykey", Integer),
-                column("mytext", String),
-                column("myint", Integer),
-                name="myvalues",
-            ).data([(1, "textA", 89), (2, "textG", 88)]),
-            values(
-                column("mykey", Integer),
-                column("mynottext", String),
-                column("myint", Integer),
-                name="myvalues",
-            ).data([(1, "textA", 99), (2, "textB", 88)]),
-            # TODO: difference in type
-            # values(
-            #    [
-            #        column("mykey", Integer),
-            #        column("mytext", Text),
-            #        column("myint", Integer),
-            #    ],
-            #    (1, "textA", 99),
-            #    (2, "textB", 88),
-            #    alias_name="myvalues",
-            # ),
-        ),
     ]
 
     dont_compare_values_fixtures = [
@@ -697,10 +716,36 @@ class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase):
             index_elements=[table_a.c.a], set_={"name": "foo"}
         ),
         mysql.insert(table_a).on_duplicate_key_update(updated_once=None),
+        table_a.insert().values(  # multivalues doesn't cache
+            [
+                {"name": "some name"},
+                {"name": "some other name"},
+                {"name": "yet another name"},
+            ]
+        ),
     )
     def test_dml_not_cached_yet(self, dml_stmt):
         eq_(dml_stmt._generate_cache_key(), None)
 
+    def test_values_doesnt_caches_right_now(self):
+        v1 = values(
+            column("mykey", Integer),
+            column("mytext", String),
+            column("myint", Integer),
+            name="myvalues",
+        ).data([(1, "textA", 99), (2, "textB", 88)])
+
+        is_(v1._generate_cache_key(), None)
+
+        large_v1 = values(
+            column("mykey", Integer),
+            column("mytext", String),
+            column("myint", Integer),
+            name="myvalues",
+        ).data([(i, "data %s" % i, i * 5) for i in range(500)])
+
+        is_(large_v1._generate_cache_key(), None)
+
     def test_cache_key(self):
         for fixtures_, compare_values in [
             (self.fixtures, True),
@@ -912,50 +957,38 @@ class CompareAndCopyTest(CoreFixtures, fixtures.TestBase):
                 case_a = fixture()
                 case_b = fixture()
 
-                assert case_a[0].compare(
-                    case_b[0], compare_values=compare_values
-                )
+                for idx in range(len(case_a)):
+                    assert case_a[idx].compare(
+                        case_b[idx], compare_values=compare_values
+                    )
 
-                clone = visitors.replacement_traverse(
-                    case_a[0], {}, lambda elem: None
-                )
+                    clone = visitors.replacement_traverse(
+                        case_a[idx], {}, lambda elem: None
+                    )
 
-                assert clone.compare(case_b[0], compare_values=compare_values)
-
-                stack = [clone]
-                seen = {clone}
-                found_elements = False
-                while stack:
-                    obj = stack.pop(0)
-
-                    items = [
-                        subelem
-                        for key, elem in clone.__dict__.items()
-                        if key != "_is_clone_of" and elem is not None
-                        for subelem in util.to_list(elem)
-                        if (
-                            isinstance(subelem, (ColumnElement, ClauseList))
-                            and subelem not in seen
-                            and not isinstance(subelem, Immutable)
-                            and subelem is not case_a[0]
-                        )
-                    ]
-                    stack.extend(items)
-                    seen.update(items)
-
-                    if obj is not clone:
-                        found_elements = True
-                        # ensure the element will not compare as true
-                        obj.compare = lambda other, **kw: False
-                        obj.__visit_name__ = "dont_match"
-
-                if found_elements:
-                    assert not clone.compare(
-                        case_b[0], compare_values=compare_values
+                    assert clone.compare(
+                        case_b[idx], compare_values=compare_values
                     )
-                assert case_a[0].compare(
-                    case_b[0], compare_values=compare_values
-                )
+
+                    assert case_a[idx].compare(
+                        case_b[idx], compare_values=compare_values
+                    )
+
+                    # copy internals of Select is very different than other
+                    # elements and additionally this is extremely well tested
+                    # in test_selectable and test_external_traversal, so
+                    # skip these
+                    if isinstance(case_a[idx], Select):
+                        continue
+
+                    for elema, elemb in zip(
+                        visitors.iterate(case_a[idx], {}),
+                        visitors.iterate(clone, {}),
+                    ):
+                        if isinstance(elema, ClauseElement) and not isinstance(
+                            elema, Immutable
+                        ):
+                            assert elema is not elemb
 
 
 class CompareClausesTest(fixtures.TestBase):