]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
`aggregate_order_by` now supports cache generation.
authorFederico Caselli <cfederico87@gmail.com>
Sun, 25 Sep 2022 14:37:15 +0000 (16:37 +0200)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 26 Sep 2022 01:16:41 +0000 (21:16 -0400)
also adjusted CacheKeyFixture to be a general purpose
fixture so that sub-components / dialects can run
their own cache key tests.

Fixes: #8574
Change-Id: I6c66107856aee11e548d357cea77bceee3e316a0
(cherry picked from commit 7980b677085fc759a0406f6778b9729955f3c7f6)

doc/build/changelog/unreleased_14/8574.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/ext.py
lib/sqlalchemy/testing/fixtures.py
test/dialect/postgresql/test_compiler.py
test/orm/test_cache_key.py
test/orm/test_deprecations.py
test/sql/test_compare.py

diff --git a/doc/build/changelog/unreleased_14/8574.rst b/doc/build/changelog/unreleased_14/8574.rst
new file mode 100644 (file)
index 0000000..ffc1761
--- /dev/null
@@ -0,0 +1,5 @@
+.. change::
+    :tags: usecase, postgresql
+    :tickets: 8574
+
+    :class:`_postgresql.aggregate_order_by` now supports cache generation.
index 9e52ee1ee9f241a18edd83afbe9980551a1e3d1c..e6b992e88a91ed434464eeb5753ec9527df33f9c 100644 (file)
@@ -14,6 +14,7 @@ from ...sql import functions
 from ...sql import roles
 from ...sql import schema
 from ...sql.schema import ColumnCollectionConstraint
+from ...sql.visitors import InternalTraversal
 
 
 class aggregate_order_by(expression.ColumnElement):
@@ -54,7 +55,11 @@ class aggregate_order_by(expression.ColumnElement):
     __visit_name__ = "aggregate_order_by"
 
     stringify_dialect = "postgresql"
-    inherit_cache = False
+    _traverse_internals = [
+        ("target", InternalTraversal.dp_clauseelement),
+        ("type", InternalTraversal.dp_type),
+        ("order_by", InternalTraversal.dp_clauseelement),
+    ]
 
     def __init__(self, target, *order_by):
         self.target = coercions.expect(roles.ExpressionElementRole, target)
index 0a2d63b5480d8707c40f82819bf0a2f46fa327f9..999647b5b193d03073aadd4ea44a015342c2fcfb 100644 (file)
@@ -6,6 +6,7 @@
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 
 import contextlib
+import itertools
 import re
 import sys
 
@@ -13,6 +14,8 @@ import sqlalchemy as sa
 from . import assertions
 from . import config
 from . import schema
+from .assertions import eq_
+from .assertions import ne_
 from .entities import BasicEntity
 from .entities import ComparableEntity
 from .entities import ComparableMixin  # noqa
@@ -24,6 +27,8 @@ from ..orm import declarative_base
 from ..orm import registry
 from ..orm.decl_api import DeclarativeMeta
 from ..schema import sort_tables_and_constraints
+from ..sql import visitors
+from ..sql.elements import ClauseElement
 
 
 @config.mark_base_test_class()
@@ -868,3 +873,106 @@ class ComputedReflectionFixtureTest(TablesTest):
                         Computed("normal * 42", persisted=True),
                     )
                 )
+
+
+class CacheKeyFixture(object):
+    def _compare_equal(self, a, b, compare_values):
+        a_key = a._generate_cache_key()
+        b_key = b._generate_cache_key()
+
+        if a_key is None:
+            assert a._annotations.get("nocache")
+
+            assert b_key is None
+        else:
+
+            eq_(a_key.key, b_key.key)
+            eq_(hash(a_key.key), hash(b_key.key))
+
+            for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
+                assert a_param.compare(b_param, compare_values=compare_values)
+        return a_key, b_key
+
+    def _run_cache_key_fixture(self, fixture, compare_values):
+        case_a = fixture()
+        case_b = fixture()
+
+        for a, b in itertools.combinations_with_replacement(
+            range(len(case_a)), 2
+        ):
+            if a == b:
+                a_key, b_key = self._compare_equal(
+                    case_a[a], case_b[b], compare_values
+                )
+                if a_key is None:
+                    continue
+            else:
+                a_key = case_a[a]._generate_cache_key()
+                b_key = case_b[b]._generate_cache_key()
+
+                if a_key is None or b_key is None:
+                    if a_key is None:
+                        assert case_a[a]._annotations.get("nocache")
+                    if b_key is None:
+                        assert case_b[b]._annotations.get("nocache")
+                    continue
+
+                if a_key.key == b_key.key:
+                    for a_param, b_param in zip(
+                        a_key.bindparams, b_key.bindparams
+                    ):
+                        if not a_param.compare(
+                            b_param, compare_values=compare_values
+                        ):
+                            break
+                    else:
+                        # this fails unconditionally since we could not
+                        # find bound parameter values that differed.
+                        # Usually we intended to get two distinct keys here
+                        # so the failure will be more descriptive using the
+                        # ne_() assertion.
+                        ne_(a_key.key, b_key.key)
+                else:
+                    ne_(a_key.key, b_key.key)
+
+            # ClauseElement-specific test to ensure the cache key
+            # collected all the bound parameters that aren't marked
+            # as "literal execute"
+            if isinstance(case_a[a], ClauseElement) and isinstance(
+                case_b[b], ClauseElement
+            ):
+                assert_a_params = []
+                assert_b_params = []
+
+                for elem in visitors.iterate(case_a[a]):
+                    if elem.__visit_name__ == "bindparam":
+                        assert_a_params.append(elem)
+
+                for elem in visitors.iterate(case_b[b]):
+                    if elem.__visit_name__ == "bindparam":
+                        assert_b_params.append(elem)
+
+                # note we're asserting the order of the params as well as
+                # if there are dupes or not.  ordering has to be
+                # deterministic and matches what a traversal would provide.
+                eq_(
+                    sorted(a_key.bindparams, key=lambda b: b.key),
+                    sorted(
+                        util.unique_list(assert_a_params), key=lambda b: b.key
+                    ),
+                )
+                eq_(
+                    sorted(b_key.bindparams, key=lambda b: b.key),
+                    sorted(
+                        util.unique_list(assert_b_params), key=lambda b: b.key
+                    ),
+                )
+
+    def _run_cache_key_equal_fixture(self, fixture, compare_values):
+        case_a = fixture()
+        case_b = fixture()
+
+        for a, b in itertools.combinations_with_replacement(
+            range(len(case_a)), 2
+        ):
+            self._compare_equal(case_a[a], case_b[b], compare_values)
index d85ae9152fd4ee2be7fb3ae045fa2451f0892409..897909b158bcc1f7143e2480a8bd1c6d44fcf7d7 100644 (file)
@@ -3347,3 +3347,36 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             "SELECT 1 " + exp,
             checkparams=params,
         )
+
+
+class CacheKeyTest(fixtures.CacheKeyFixture, fixtures.TestBase):
+    def test_aggregate_order_by(self):
+        """test #8574"""
+
+        self._run_cache_key_fixture(
+            lambda: (
+                aggregate_order_by(column("a"), column("a")),
+                aggregate_order_by(column("a"), column("b")),
+                aggregate_order_by(column("a"), column("a").desc()),
+                aggregate_order_by(column("a"), column("a").nulls_first()),
+                aggregate_order_by(
+                    column("a"), column("a").desc().nulls_first()
+                ),
+                aggregate_order_by(column("a", Integer), column("b")),
+                aggregate_order_by(column("a"), column("b"), column("c")),
+                aggregate_order_by(column("a"), column("c"), column("b")),
+                aggregate_order_by(
+                    column("a"), column("b").desc(), column("c")
+                ),
+                aggregate_order_by(
+                    column("a"), column("b").nulls_first(), column("c")
+                ),
+                aggregate_order_by(
+                    column("a"), column("b").desc().nulls_first(), column("c")
+                ),
+                aggregate_order_by(
+                    column("a", Integer), column("a"), column("b")
+                ),
+            ),
+            compare_values=False,
+        )
index daf963952c817e020bf8ae1a2a33b8707033ca96..23fec61d2a09da9c3b5d37ab7e97f6c3bc411d15 100644 (file)
@@ -42,7 +42,6 @@ from sqlalchemy.testing.fixtures import fixture_session
 from test.orm import _fixtures
 from .inheritance import _poly_fixtures
 from .test_query import QueryTest
-from ..sql.test_compare import CacheKeyFixture
 
 
 def stmt_20(*elements):
@@ -52,7 +51,7 @@ def stmt_20(*elements):
     )
 
 
-class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
+class CacheKeyTest(fixtures.CacheKeyFixture, _fixtures.FixtureTest):
     run_setup_mappers = "once"
     run_inserts = None
     run_deletes = None
@@ -586,7 +585,7 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
         )
 
 
-class PolyCacheKeyTest(CacheKeyFixture, _poly_fixtures._Polymorphic):
+class PolyCacheKeyTest(fixtures.CacheKeyFixture, _poly_fixtures._Polymorphic):
     run_setup_mappers = "once"
     run_inserts = None
     run_deletes = None
index 8febf3b3fcf83862df9ea2cf630bd6f9a266f2c5..bbfcd0cfd39c10d55c2d533ac0a1ba0e86102a69 100644 (file)
@@ -77,6 +77,7 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.fixtures import CacheKeyFixture
 from sqlalchemy.testing.fixtures import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.mock import call
@@ -105,7 +106,6 @@ from .test_options import PathTest
 from .test_options import QueryTest as OptionsQueryTest
 from .test_query import QueryTest
 from .test_transaction import _LocalFixture
-from ..sql.test_compare import CacheKeyFixture
 
 
 join_aliased_dep = (
index f73e9864d37691e64c969d0c0fc8d9f96a65acf4..6cee271c9c12b3e752fa9508157bf44b14d6f57c 100644 (file)
@@ -1053,110 +1053,7 @@ class CoreFixtures(object):
     ]
 
 
-class CacheKeyFixture(object):
-    def _compare_equal(self, a, b, compare_values):
-        a_key = a._generate_cache_key()
-        b_key = b._generate_cache_key()
-
-        if a_key is None:
-            assert a._annotations.get("nocache")
-
-            assert b_key is None
-        else:
-
-            eq_(a_key.key, b_key.key)
-            eq_(hash(a_key.key), hash(b_key.key))
-
-            for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
-                assert a_param.compare(b_param, compare_values=compare_values)
-        return a_key, b_key
-
-    def _run_cache_key_fixture(self, fixture, compare_values):
-        case_a = fixture()
-        case_b = fixture()
-
-        for a, b in itertools.combinations_with_replacement(
-            range(len(case_a)), 2
-        ):
-            if a == b:
-                a_key, b_key = self._compare_equal(
-                    case_a[a], case_b[b], compare_values
-                )
-                if a_key is None:
-                    continue
-            else:
-                a_key = case_a[a]._generate_cache_key()
-                b_key = case_b[b]._generate_cache_key()
-
-                if a_key is None or b_key is None:
-                    if a_key is None:
-                        assert case_a[a]._annotations.get("nocache")
-                    if b_key is None:
-                        assert case_b[b]._annotations.get("nocache")
-                    continue
-
-                if a_key.key == b_key.key:
-                    for a_param, b_param in zip(
-                        a_key.bindparams, b_key.bindparams
-                    ):
-                        if not a_param.compare(
-                            b_param, compare_values=compare_values
-                        ):
-                            break
-                    else:
-                        # this fails unconditionally since we could not
-                        # find bound parameter values that differed.
-                        # Usually we intended to get two distinct keys here
-                        # so the failure will be more descriptive using the
-                        # ne_() assertion.
-                        ne_(a_key.key, b_key.key)
-                else:
-                    ne_(a_key.key, b_key.key)
-
-            # ClauseElement-specific test to ensure the cache key
-            # collected all the bound parameters that aren't marked
-            # as "literal execute"
-            if isinstance(case_a[a], ClauseElement) and isinstance(
-                case_b[b], ClauseElement
-            ):
-                assert_a_params = []
-                assert_b_params = []
-
-                for elem in visitors.iterate(case_a[a]):
-                    if elem.__visit_name__ == "bindparam":
-                        assert_a_params.append(elem)
-
-                for elem in visitors.iterate(case_b[b]):
-                    if elem.__visit_name__ == "bindparam":
-                        assert_b_params.append(elem)
-
-                # note we're asserting the order of the params as well as
-                # if there are dupes or not.  ordering has to be
-                # deterministic and matches what a traversal would provide.
-                eq_(
-                    sorted(a_key.bindparams, key=lambda b: b.key),
-                    sorted(
-                        util.unique_list(assert_a_params), key=lambda b: b.key
-                    ),
-                )
-                eq_(
-                    sorted(b_key.bindparams, key=lambda b: b.key),
-                    sorted(
-                        util.unique_list(assert_b_params), key=lambda b: b.key
-                    ),
-                )
-
-    def _run_cache_key_equal_fixture(self, fixture, compare_values):
-        case_a = fixture()
-        case_b = fixture()
-
-        for a, b in itertools.combinations_with_replacement(
-            range(len(case_a)), 2
-        ):
-            self._compare_equal(case_a[a], case_b[b], compare_values)
-
-
-class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase):
+class CacheKeyTest(fixtures.CacheKeyFixture, CoreFixtures, fixtures.TestBase):
     # we are slightly breaking the policy of not having external dialect
     # stuff in here, but use pg/mysql as test cases to ensure that these
     # objects don't report an inaccurate cache key, which is dependent