From: Federico Caselli Date: Sun, 25 Sep 2022 14:37:15 +0000 (+0200) Subject: `aggregate_order_by` now supports cache generation. X-Git-Tag: rel_1_4_42~18^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e3a71aadd7637824e5a6937118668f304460d003;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git `aggregate_order_by` now supports cache generation. also adjusted CacheKeyFixture to be a general purpose fixture so that sub-components / dialects can run their own cache key tests. Fixes: #8574 Change-Id: I6c66107856aee11e548d357cea77bceee3e316a0 (cherry picked from commit 7980b677085fc759a0406f6778b9729955f3c7f6) --- diff --git a/doc/build/changelog/unreleased_14/8574.rst b/doc/build/changelog/unreleased_14/8574.rst new file mode 100644 index 0000000000..ffc1761c30 --- /dev/null +++ b/doc/build/changelog/unreleased_14/8574.rst @@ -0,0 +1,5 @@ +.. change:: + :tags: usecase, postgresql + :tickets: 8574 + + :class:`_postgresql.aggregate_order_by` now supports cache generation. diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py index 9e52ee1ee9..e6b992e88a 100644 --- a/lib/sqlalchemy/dialects/postgresql/ext.py +++ b/lib/sqlalchemy/dialects/postgresql/ext.py @@ -14,6 +14,7 @@ from ...sql import functions from ...sql import roles from ...sql import schema from ...sql.schema import ColumnCollectionConstraint +from ...sql.visitors import InternalTraversal class aggregate_order_by(expression.ColumnElement): @@ -54,7 +55,11 @@ class aggregate_order_by(expression.ColumnElement): __visit_name__ = "aggregate_order_by" stringify_dialect = "postgresql" - inherit_cache = False + _traverse_internals = [ + ("target", InternalTraversal.dp_clauseelement), + ("type", InternalTraversal.dp_type), + ("order_by", InternalTraversal.dp_clauseelement), + ] def __init__(self, target, *order_by): self.target = coercions.expect(roles.ExpressionElementRole, target) diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 0a2d63b548..999647b5b1 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -6,6 +6,7 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php import contextlib +import itertools import re import sys @@ -13,6 +14,8 @@ import sqlalchemy as sa from . import assertions from . import config from . import schema +from .assertions import eq_ +from .assertions import ne_ from .entities import BasicEntity from .entities import ComparableEntity from .entities import ComparableMixin # noqa @@ -24,6 +27,8 @@ from ..orm import declarative_base from ..orm import registry from ..orm.decl_api import DeclarativeMeta from ..schema import sort_tables_and_constraints +from ..sql import visitors +from ..sql.elements import ClauseElement @config.mark_base_test_class() @@ -868,3 +873,106 @@ class ComputedReflectionFixtureTest(TablesTest): Computed("normal * 42", persisted=True), ) ) + + +class CacheKeyFixture(object): + def _compare_equal(self, a, b, compare_values): + a_key = a._generate_cache_key() + b_key = b._generate_cache_key() + + if a_key is None: + assert a._annotations.get("nocache") + + assert b_key is None + else: + + eq_(a_key.key, b_key.key) + eq_(hash(a_key.key), hash(b_key.key)) + + for a_param, b_param in zip(a_key.bindparams, b_key.bindparams): + assert a_param.compare(b_param, compare_values=compare_values) + return a_key, b_key + + def _run_cache_key_fixture(self, fixture, compare_values): + case_a = fixture() + case_b = fixture() + + for a, b in itertools.combinations_with_replacement( + range(len(case_a)), 2 + ): + if a == b: + a_key, b_key = self._compare_equal( + case_a[a], case_b[b], compare_values + ) + if a_key is None: + continue + else: + a_key = case_a[a]._generate_cache_key() + b_key = case_b[b]._generate_cache_key() + + if a_key is None or b_key is None: + if a_key is None: + assert case_a[a]._annotations.get("nocache") + if b_key is None: + assert case_b[b]._annotations.get("nocache") + continue + + if a_key.key == b_key.key: + for a_param, b_param in zip( + a_key.bindparams, b_key.bindparams + ): + if not a_param.compare( + b_param, compare_values=compare_values + ): + break + else: + # this fails unconditionally since we could not + # find bound parameter values that differed. + # Usually we intended to get two distinct keys here + # so the failure will be more descriptive using the + # ne_() assertion. + ne_(a_key.key, b_key.key) + else: + ne_(a_key.key, b_key.key) + + # ClauseElement-specific test to ensure the cache key + # collected all the bound parameters that aren't marked + # as "literal execute" + if isinstance(case_a[a], ClauseElement) and isinstance( + case_b[b], ClauseElement + ): + assert_a_params = [] + assert_b_params = [] + + for elem in visitors.iterate(case_a[a]): + if elem.__visit_name__ == "bindparam": + assert_a_params.append(elem) + + for elem in visitors.iterate(case_b[b]): + if elem.__visit_name__ == "bindparam": + assert_b_params.append(elem) + + # note we're asserting the order of the params as well as + # if there are dupes or not. ordering has to be + # deterministic and matches what a traversal would provide. + eq_( + sorted(a_key.bindparams, key=lambda b: b.key), + sorted( + util.unique_list(assert_a_params), key=lambda b: b.key + ), + ) + eq_( + sorted(b_key.bindparams, key=lambda b: b.key), + sorted( + util.unique_list(assert_b_params), key=lambda b: b.key + ), + ) + + def _run_cache_key_equal_fixture(self, fixture, compare_values): + case_a = fixture() + case_b = fixture() + + for a, b in itertools.combinations_with_replacement( + range(len(case_a)), 2 + ): + self._compare_equal(case_a[a], case_b[b], compare_values) diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index d85ae9152f..897909b158 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -3347,3 +3347,36 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): "SELECT 1 " + exp, checkparams=params, ) + + +class CacheKeyTest(fixtures.CacheKeyFixture, fixtures.TestBase): + def test_aggregate_order_by(self): + """test #8574""" + + self._run_cache_key_fixture( + lambda: ( + aggregate_order_by(column("a"), column("a")), + aggregate_order_by(column("a"), column("b")), + aggregate_order_by(column("a"), column("a").desc()), + aggregate_order_by(column("a"), column("a").nulls_first()), + aggregate_order_by( + column("a"), column("a").desc().nulls_first() + ), + aggregate_order_by(column("a", Integer), column("b")), + aggregate_order_by(column("a"), column("b"), column("c")), + aggregate_order_by(column("a"), column("c"), column("b")), + aggregate_order_by( + column("a"), column("b").desc(), column("c") + ), + aggregate_order_by( + column("a"), column("b").nulls_first(), column("c") + ), + aggregate_order_by( + column("a"), column("b").desc().nulls_first(), column("c") + ), + aggregate_order_by( + column("a", Integer), column("a"), column("b") + ), + ), + compare_values=False, + ) diff --git a/test/orm/test_cache_key.py b/test/orm/test_cache_key.py index daf963952c..23fec61d2a 100644 --- a/test/orm/test_cache_key.py +++ b/test/orm/test_cache_key.py @@ -42,7 +42,6 @@ from sqlalchemy.testing.fixtures import fixture_session from test.orm import _fixtures from .inheritance import _poly_fixtures from .test_query import QueryTest -from ..sql.test_compare import CacheKeyFixture def stmt_20(*elements): @@ -52,7 +51,7 @@ def stmt_20(*elements): ) -class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest): +class CacheKeyTest(fixtures.CacheKeyFixture, _fixtures.FixtureTest): run_setup_mappers = "once" run_inserts = None run_deletes = None @@ -586,7 +585,7 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest): ) -class PolyCacheKeyTest(CacheKeyFixture, _poly_fixtures._Polymorphic): +class PolyCacheKeyTest(fixtures.CacheKeyFixture, _poly_fixtures._Polymorphic): run_setup_mappers = "once" run_inserts = None run_deletes = None diff --git a/test/orm/test_deprecations.py b/test/orm/test_deprecations.py index 8febf3b3fc..bbfcd0cfd3 100644 --- a/test/orm/test_deprecations.py +++ b/test/orm/test_deprecations.py @@ -77,6 +77,7 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.fixtures import CacheKeyFixture from sqlalchemy.testing.fixtures import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.mock import call @@ -105,7 +106,6 @@ from .test_options import PathTest from .test_options import QueryTest as OptionsQueryTest from .test_query import QueryTest from .test_transaction import _LocalFixture -from ..sql.test_compare import CacheKeyFixture join_aliased_dep = ( diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index f73e9864d3..6cee271c9c 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -1053,110 +1053,7 @@ class CoreFixtures(object): ] -class CacheKeyFixture(object): - def _compare_equal(self, a, b, compare_values): - a_key = a._generate_cache_key() - b_key = b._generate_cache_key() - - if a_key is None: - assert a._annotations.get("nocache") - - assert b_key is None - else: - - eq_(a_key.key, b_key.key) - eq_(hash(a_key.key), hash(b_key.key)) - - for a_param, b_param in zip(a_key.bindparams, b_key.bindparams): - assert a_param.compare(b_param, compare_values=compare_values) - return a_key, b_key - - def _run_cache_key_fixture(self, fixture, compare_values): - case_a = fixture() - case_b = fixture() - - for a, b in itertools.combinations_with_replacement( - range(len(case_a)), 2 - ): - if a == b: - a_key, b_key = self._compare_equal( - case_a[a], case_b[b], compare_values - ) - if a_key is None: - continue - else: - a_key = case_a[a]._generate_cache_key() - b_key = case_b[b]._generate_cache_key() - - if a_key is None or b_key is None: - if a_key is None: - assert case_a[a]._annotations.get("nocache") - if b_key is None: - assert case_b[b]._annotations.get("nocache") - continue - - if a_key.key == b_key.key: - for a_param, b_param in zip( - a_key.bindparams, b_key.bindparams - ): - if not a_param.compare( - b_param, compare_values=compare_values - ): - break - else: - # this fails unconditionally since we could not - # find bound parameter values that differed. - # Usually we intended to get two distinct keys here - # so the failure will be more descriptive using the - # ne_() assertion. - ne_(a_key.key, b_key.key) - else: - ne_(a_key.key, b_key.key) - - # ClauseElement-specific test to ensure the cache key - # collected all the bound parameters that aren't marked - # as "literal execute" - if isinstance(case_a[a], ClauseElement) and isinstance( - case_b[b], ClauseElement - ): - assert_a_params = [] - assert_b_params = [] - - for elem in visitors.iterate(case_a[a]): - if elem.__visit_name__ == "bindparam": - assert_a_params.append(elem) - - for elem in visitors.iterate(case_b[b]): - if elem.__visit_name__ == "bindparam": - assert_b_params.append(elem) - - # note we're asserting the order of the params as well as - # if there are dupes or not. ordering has to be - # deterministic and matches what a traversal would provide. - eq_( - sorted(a_key.bindparams, key=lambda b: b.key), - sorted( - util.unique_list(assert_a_params), key=lambda b: b.key - ), - ) - eq_( - sorted(b_key.bindparams, key=lambda b: b.key), - sorted( - util.unique_list(assert_b_params), key=lambda b: b.key - ), - ) - - def _run_cache_key_equal_fixture(self, fixture, compare_values): - case_a = fixture() - case_b = fixture() - - for a, b in itertools.combinations_with_replacement( - range(len(case_a)), 2 - ): - self._compare_equal(case_a[a], case_b[b], compare_values) - - -class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase): +class CacheKeyTest(fixtures.CacheKeyFixture, CoreFixtures, fixtures.TestBase): # we are slightly breaking the policy of not having external dialect # stuff in here, but use pg/mysql as test cases to ensure that these # objects don't report an inaccurate cache key, which is dependent