From 49b6c50016c8a038a6df7104560bb3945debe064 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 1 Apr 2020 18:31:16 -0400 Subject: [PATCH] Repair caching / traversals for values The test suite wasn't running the copy_internals most fixtures, enable that and try to get all cases working. Set up selectable.values to do tuple conversion within compilation step. at the same time, disable caching for selectable.values for the moment and make it equivalent to dml_multi_values. fix cache / compare / copy cases for dml_values and dml_multi_values which weren't fully tested or covered. Change-Id: I484ca6e9cb2b66c2e6a321698f2abc0838db1460 --- lib/sqlalchemy/sql/compiler.py | 9 +- lib/sqlalchemy/sql/selectable.py | 24 ++-- lib/sqlalchemy/sql/traversals.py | 77 ++++++------- test/sql/test_compare.py | 189 ++++++++++++++++++------------- 4 files changed, 165 insertions(+), 134 deletions(-) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 799fca2f58..b93ed88905 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -2324,9 +2324,14 @@ class SQLCompiler(Compiled): return text def visit_values(self, element, asfrom=False, from_linter=None, **kw): + v = "VALUES %s" % ", ".join( - self.process(elem, literal_binds=element.literal_binds) - for elem in element._data + self.process( + elements.Tuple(*elem).self_group(), + literal_binds=element.literal_binds, + ) + for chunk in element._data + for elem in chunk ) if isinstance(element.name, elements._truncated_label): diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index e39d61fdbb..a0df45b527 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -47,7 +47,6 @@ from .elements import ColumnClause from .elements import GroupedElement from .elements import Grouping from .elements import literal_column -from .elements import Tuple from .elements import UnaryExpression from .visitors import InternalTraversal from .. import exc @@ -1264,14 +1263,16 @@ class AliasedReturnsRows(NoInit, FromClause): self.element._generate_fromclause_column_proxies(self) def _copy_internals(self, clone=_clone, **kw): - element = clone(self.element, **kw) + existing_element = self.element + + super(AliasedReturnsRows, self)._copy_internals(clone=clone, **kw) # the element clone is usually against a Table that returns the # same object. don't reset exported .c. collections and other - # memoized details if nothing changed - if element is not self.element: + # memoized details if it was not changed. this saves a lot on + # performance. + if existing_element is not self.element: self._reset_column_collection() - self.element = element @property def _from_objects(self): @@ -1528,15 +1529,6 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows): self._suffixes = _suffixes super(CTE, self)._init(selectable, name=name) - def _copy_internals(self, clone=_clone, **kw): - super(CTE, self)._copy_internals(clone, **kw) - # TODO: I don't like that we can't use the traversal data here - if self._cte_alias is not None: - self._cte_alias = clone(self._cte_alias, **kw) - self._restates = frozenset( - [clone(elem, **kw) for elem in self._restates] - ) - def alias(self, name=None, flat=False): """Return an :class:`.Alias` of this :class:`.CTE`. @@ -2064,7 +2056,7 @@ class Values(Generative, FromClause): _traverse_internals = [ ("_column_args", InternalTraversal.dp_clauseelement_list,), - ("_data", InternalTraversal.dp_clauseelement_list), + ("_data", InternalTraversal.dp_dml_multi_values), ("name", InternalTraversal.dp_string), ("literal_binds", InternalTraversal.dp_boolean), ] @@ -2155,7 +2147,7 @@ class Values(Generative, FromClause): """ - self._data += tuple(Tuple(*row).self_group() for row in values) + self._data += (values,) def _populate_column_collection(self): for c in self._column_args: diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 9ac6cda978..032488826d 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -7,6 +7,7 @@ from .visitors import ExtendedInternalTraversal from .visitors import InternalTraversal from .. import util from ..inspection import inspect +from ..util import collections_abc from ..util import HasMemoized SKIP_TRAVERSE = util.symbol("skip_traverse") @@ -533,18 +534,12 @@ class _CopyInternals(InternalTraversal): ] def visit_dml_values(self, parent, element, clone=_clone, **kw): - # sequence of dictionaries - return [ - { - ( - clone(key, **kw) - if hasattr(key, "__clause_element__") - else key - ): clone(value, **kw) - for key, value in sub_element.items() - } - for sub_element in element - ] + return { + ( + clone(key, **kw) if hasattr(key, "__clause_element__") else key + ): clone(value, **kw) + for key, value in element.items() + } def visit_dml_multi_values(self, parent, element, clone=_clone, **kw): # sequence of sequences, each sequence contains a list/dict/tuple @@ -552,15 +547,10 @@ class _CopyInternals(InternalTraversal): def copy(elem): if isinstance(elem, (list, tuple)): return [ - ( - clone(key, **kw) - if hasattr(key, "__clause_element__") - else key, - clone(value, **kw) - if hasattr(value, "__clause_element__") - else value, - ) - for key, value in elem + clone(value, **kw) + if hasattr(value, "__clause_element__") + else value + for value in elem ] elif isinstance(elem, dict): return { @@ -573,7 +563,7 @@ class _CopyInternals(InternalTraversal): if hasattr(value, "__clause_element__") else value ) - for key, value in elem + for key, value in elem.items() } else: # TODO: use abc classes @@ -939,30 +929,41 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): for (lk, lv), (rk, rv) in util.zip_longest( left, right, fillvalue=(None, None) ): - lkce = hasattr(lk, "__clause_element__") - rkce = hasattr(rk, "__clause_element__") - if lkce != rkce: - return COMPARE_FAILED - elif lkce and not self.compare_inner(lk, rk, **kw): - return COMPARE_FAILED - elif not lkce and lk != rk: - return COMPARE_FAILED - elif not self.compare_inner(lv, rv, **kw): + if not self._compare_dml_values_or_ce(lk, rk, **kw): return COMPARE_FAILED + def _compare_dml_values_or_ce(self, lv, rv, **kw): + lvce = hasattr(lv, "__clause_element__") + rvce = hasattr(rv, "__clause_element__") + if lvce != rvce: + return False + elif lvce and not self.compare_inner(lv, rv, **kw): + return False + elif not lvce and lv != rv: + return False + elif not self.compare_inner(lv, rv, **kw): + return False + + return True + def visit_dml_values(self, left_parent, left, right_parent, right, **kw): if left is None or right is None or len(left) != len(right): return COMPARE_FAILED - for lk in left: - lv = left[lk] + if isinstance(left, collections_abc.Sequence): + for lv, rv in zip(left, right): + if not self._compare_dml_values_or_ce(lv, rv, **kw): + return COMPARE_FAILED + else: + for lk in left: + lv = left[lk] - if lk not in right: - return COMPARE_FAILED - rv = right[lk] + if lk not in right: + return COMPARE_FAILED + rv = right[lk] - if not self.compare_inner(lv, rv, **kw): - return COMPARE_FAILED + if not self._compare_dml_values_or_ce(lv, rv, **kw): + return COMPARE_FAILED def visit_dml_multi_values( self, left_parent, left, right_parent, right, **kw diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 2800f8248c..3a6feac018 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -55,11 +55,13 @@ from sqlalchemy.sql.functions import ReturnTypeFromArgs from sqlalchemy.sql.selectable import _OffsetLimitParam from sqlalchemy.sql.selectable import AliasedReturnsRows from sqlalchemy.sql.selectable import FromGrouping +from sqlalchemy.sql.selectable import Select from sqlalchemy.sql.selectable import Selectable from sqlalchemy.sql.selectable import SelectStatementGrouping from sqlalchemy.sql.visitors import InternalTraversal from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_not_ from sqlalchemy.testing import is_true @@ -372,6 +374,15 @@ class CoreFixtures(object): table_b.insert() .values([{"a": 5, "b": 10}, {"a": 8, "b": 12}]) ._annotate({"nocache": True}), + table_b.insert() + .values([{"a": 9, "b": 10}, {"a": 8, "b": 7}]) + ._annotate({"nocache": True}), + table_b.insert() + .values([(5, 10), (8, 12)]) + ._annotate({"nocache": True}), + table_b.insert() + .values([(5, 9), (5, 12)]) + ._annotate({"nocache": True}), ), lambda: ( table_b.update(), @@ -404,6 +415,51 @@ class CoreFixtures(object): table_b.delete().where(table_b.c.a == 5), table_b.delete().where(table_b.c.b == 5), ), + lambda: ( + values( + column("mykey", Integer), + column("mytext", String), + column("myint", Integer), + name="myvalues", + ) + .data([(1, "textA", 99), (2, "textB", 88)]) + ._annotate({"nocache": True}), + values( + column("mykey", Integer), + column("mytext", String), + column("myint", Integer), + name="myothervalues", + ) + .data([(1, "textA", 99), (2, "textB", 88)]) + ._annotate({"nocache": True}), + values( + column("mykey", Integer), + column("mytext", String), + column("myint", Integer), + name="myvalues", + ) + .data([(1, "textA", 89), (2, "textG", 88)]) + ._annotate({"nocache": True}), + values( + column("mykey", Integer), + column("mynottext", String), + column("myint", Integer), + name="myvalues", + ) + .data([(1, "textA", 99), (2, "textB", 88)]) + ._annotate({"nocache": True}), + # TODO: difference in type + # values( + # [ + # column("mykey", Integer), + # column("mytext", Text), + # column("myint", Integer), + # ], + # (1, "textA", 99), + # (2, "textB", 88), + # alias_name="myvalues", + # ), + ), lambda: ( select([table_a.c.a]), select([table_a.c.a]).prefix_with("foo"), @@ -482,43 +538,6 @@ class CoreFixtures(object): table("a", column("q"), column("y", Integer)), ), lambda: (table_a, table_b), - lambda: ( - values( - column("mykey", Integer), - column("mytext", String), - column("myint", Integer), - name="myvalues", - ).data([(1, "textA", 99), (2, "textB", 88)]), - values( - column("mykey", Integer), - column("mytext", String), - column("myint", Integer), - name="myothervalues", - ).data([(1, "textA", 99), (2, "textB", 88)]), - values( - column("mykey", Integer), - column("mytext", String), - column("myint", Integer), - name="myvalues", - ).data([(1, "textA", 89), (2, "textG", 88)]), - values( - column("mykey", Integer), - column("mynottext", String), - column("myint", Integer), - name="myvalues", - ).data([(1, "textA", 99), (2, "textB", 88)]), - # TODO: difference in type - # values( - # [ - # column("mykey", Integer), - # column("mytext", Text), - # column("myint", Integer), - # ], - # (1, "textA", 99), - # (2, "textB", 88), - # alias_name="myvalues", - # ), - ), ] dont_compare_values_fixtures = [ @@ -697,10 +716,36 @@ class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase): index_elements=[table_a.c.a], set_={"name": "foo"} ), mysql.insert(table_a).on_duplicate_key_update(updated_once=None), + table_a.insert().values( # multivalues doesn't cache + [ + {"name": "some name"}, + {"name": "some other name"}, + {"name": "yet another name"}, + ] + ), ) def test_dml_not_cached_yet(self, dml_stmt): eq_(dml_stmt._generate_cache_key(), None) + def test_values_doesnt_caches_right_now(self): + v1 = values( + column("mykey", Integer), + column("mytext", String), + column("myint", Integer), + name="myvalues", + ).data([(1, "textA", 99), (2, "textB", 88)]) + + is_(v1._generate_cache_key(), None) + + large_v1 = values( + column("mykey", Integer), + column("mytext", String), + column("myint", Integer), + name="myvalues", + ).data([(i, "data %s" % i, i * 5) for i in range(500)]) + + is_(large_v1._generate_cache_key(), None) + def test_cache_key(self): for fixtures_, compare_values in [ (self.fixtures, True), @@ -912,50 +957,38 @@ class CompareAndCopyTest(CoreFixtures, fixtures.TestBase): case_a = fixture() case_b = fixture() - assert case_a[0].compare( - case_b[0], compare_values=compare_values - ) + for idx in range(len(case_a)): + assert case_a[idx].compare( + case_b[idx], compare_values=compare_values + ) - clone = visitors.replacement_traverse( - case_a[0], {}, lambda elem: None - ) + clone = visitors.replacement_traverse( + case_a[idx], {}, lambda elem: None + ) - assert clone.compare(case_b[0], compare_values=compare_values) - - stack = [clone] - seen = {clone} - found_elements = False - while stack: - obj = stack.pop(0) - - items = [ - subelem - for key, elem in clone.__dict__.items() - if key != "_is_clone_of" and elem is not None - for subelem in util.to_list(elem) - if ( - isinstance(subelem, (ColumnElement, ClauseList)) - and subelem not in seen - and not isinstance(subelem, Immutable) - and subelem is not case_a[0] - ) - ] - stack.extend(items) - seen.update(items) - - if obj is not clone: - found_elements = True - # ensure the element will not compare as true - obj.compare = lambda other, **kw: False - obj.__visit_name__ = "dont_match" - - if found_elements: - assert not clone.compare( - case_b[0], compare_values=compare_values + assert clone.compare( + case_b[idx], compare_values=compare_values ) - assert case_a[0].compare( - case_b[0], compare_values=compare_values - ) + + assert case_a[idx].compare( + case_b[idx], compare_values=compare_values + ) + + # copy internals of Select is very different than other + # elements and additionally this is extremely well tested + # in test_selectable and test_external_traversal, so + # skip these + if isinstance(case_a[idx], Select): + continue + + for elema, elemb in zip( + visitors.iterate(case_a[idx], {}), + visitors.iterate(clone, {}), + ): + if isinstance(elema, ClauseElement) and not isinstance( + elema, Immutable + ): + assert elema is not elemb class CompareClausesTest(fixtures.TestBase): -- 2.47.3