]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
`aggregate_order_by` now supports cache generation.
authorFederico Caselli <cfederico87@gmail.com>
Sun, 25 Sep 2022 14:37:15 +0000 (16:37 +0200)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 26 Sep 2022 01:14:48 +0000 (21:14 -0400)
also adjusted CacheKeyFixture to be a general purpose
fixture so that sub-components / dialects can run
their own cache key tests.

Fixes: #8574
Change-Id: I6c66107856aee11e548d357cea77bceee3e316a0

doc/build/changelog/unreleased_14/8574.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/ext.py
lib/sqlalchemy/testing/fixtures.py
test/dialect/postgresql/test_compiler.py
test/orm/test_cache_key.py
test/orm/test_deprecations.py
test/sql/test_compare.py

diff --git a/doc/build/changelog/unreleased_14/8574.rst b/doc/build/changelog/unreleased_14/8574.rst
new file mode 100644 (file)
index 0000000..ffc1761
--- /dev/null
@@ -0,0 +1,5 @@
+.. change::
+    :tags: usecase, postgresql
+    :tickets: 8574
+
+    :class:`_postgresql.aggregate_order_by` now supports cache generation.
index 0192cf58157852d8c7bc56f2b38297c61d169123..ebaad273426419c27092c6e6142f8181832efb37 100644 (file)
@@ -5,8 +5,10 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 # mypy: ignore-errors
+from __future__ import annotations
 
 from itertools import zip_longest
+from typing import TYPE_CHECKING
 
 from .array import ARRAY
 from ...sql import coercions
@@ -16,6 +18,10 @@ from ...sql import functions
 from ...sql import roles
 from ...sql import schema
 from ...sql.schema import ColumnCollectionConstraint
+from ...sql.visitors import InternalTraversal
+
+if TYPE_CHECKING:
+    from ...sql.visitors import _TraverseInternalsType
 
 
 class aggregate_order_by(expression.ColumnElement):
@@ -56,7 +62,11 @@ class aggregate_order_by(expression.ColumnElement):
     __visit_name__ = "aggregate_order_by"
 
     stringify_dialect = "postgresql"
-    inherit_cache = False
+    _traverse_internals: _TraverseInternalsType = [
+        ("target", InternalTraversal.dp_clauseelement),
+        ("type", InternalTraversal.dp_type),
+        ("order_by", InternalTraversal.dp_clauseelement),
+    ]
 
     def __init__(self, target, *order_by):
         self.target = coercions.expect(roles.ExpressionElementRole, target)
index 20dee5273bb96150517a7b0e0b3e09acc59f4de1..2a5f97dbb52f5e409053fbd007bd9dcd79725f66 100644 (file)
@@ -9,6 +9,7 @@
 
 from __future__ import annotations
 
+import itertools
 import re
 import sys
 
@@ -16,6 +17,8 @@ import sqlalchemy as sa
 from . import assertions
 from . import config
 from . import schema
+from .assertions import eq_
+from .assertions import ne_
 from .entities import BasicEntity
 from .entities import ComparableEntity
 from .entities import ComparableMixin  # noqa
@@ -28,6 +31,8 @@ from ..orm import DeclarativeBase
 from ..orm import MappedAsDataclass
 from ..orm import registry
 from ..schema import sort_tables_and_constraints
+from ..sql import visitors
+from ..sql.elements import ClauseElement
 
 
 @config.mark_base_test_class()
@@ -881,3 +886,106 @@ class ComputedReflectionFixtureTest(TablesTest):
                         Computed("normal * 42", persisted=True),
                     )
                 )
+
+
+class CacheKeyFixture:
+    def _compare_equal(self, a, b, compare_values):
+        a_key = a._generate_cache_key()
+        b_key = b._generate_cache_key()
+
+        if a_key is None:
+            assert a._annotations.get("nocache")
+
+            assert b_key is None
+        else:
+
+            eq_(a_key.key, b_key.key)
+            eq_(hash(a_key.key), hash(b_key.key))
+
+            for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
+                assert a_param.compare(b_param, compare_values=compare_values)
+        return a_key, b_key
+
+    def _run_cache_key_fixture(self, fixture, compare_values):
+        case_a = fixture()
+        case_b = fixture()
+
+        for a, b in itertools.combinations_with_replacement(
+            range(len(case_a)), 2
+        ):
+            if a == b:
+                a_key, b_key = self._compare_equal(
+                    case_a[a], case_b[b], compare_values
+                )
+                if a_key is None:
+                    continue
+            else:
+                a_key = case_a[a]._generate_cache_key()
+                b_key = case_b[b]._generate_cache_key()
+
+                if a_key is None or b_key is None:
+                    if a_key is None:
+                        assert case_a[a]._annotations.get("nocache")
+                    if b_key is None:
+                        assert case_b[b]._annotations.get("nocache")
+                    continue
+
+                if a_key.key == b_key.key:
+                    for a_param, b_param in zip(
+                        a_key.bindparams, b_key.bindparams
+                    ):
+                        if not a_param.compare(
+                            b_param, compare_values=compare_values
+                        ):
+                            break
+                    else:
+                        # this fails unconditionally since we could not
+                        # find bound parameter values that differed.
+                        # Usually we intended to get two distinct keys here
+                        # so the failure will be more descriptive using the
+                        # ne_() assertion.
+                        ne_(a_key.key, b_key.key)
+                else:
+                    ne_(a_key.key, b_key.key)
+
+            # ClauseElement-specific test to ensure the cache key
+            # collected all the bound parameters that aren't marked
+            # as "literal execute"
+            if isinstance(case_a[a], ClauseElement) and isinstance(
+                case_b[b], ClauseElement
+            ):
+                assert_a_params = []
+                assert_b_params = []
+
+                for elem in visitors.iterate(case_a[a]):
+                    if elem.__visit_name__ == "bindparam":
+                        assert_a_params.append(elem)
+
+                for elem in visitors.iterate(case_b[b]):
+                    if elem.__visit_name__ == "bindparam":
+                        assert_b_params.append(elem)
+
+                # note we're asserting the order of the params as well as
+                # if there are dupes or not.  ordering has to be
+                # deterministic and matches what a traversal would provide.
+                eq_(
+                    sorted(a_key.bindparams, key=lambda b: b.key),
+                    sorted(
+                        util.unique_list(assert_a_params), key=lambda b: b.key
+                    ),
+                )
+                eq_(
+                    sorted(b_key.bindparams, key=lambda b: b.key),
+                    sorted(
+                        util.unique_list(assert_b_params), key=lambda b: b.key
+                    ),
+                )
+
+    def _run_cache_key_equal_fixture(self, fixture, compare_values):
+        case_a = fixture()
+        case_b = fixture()
+
+        for a, b in itertools.combinations_with_replacement(
+            range(len(case_a)), 2
+        ):
+            self._compare_equal(case_a[a], case_b[b], compare_values)
index 67e54e4f5176d795ed2509fc368cccd7deadc8fa..c763dbeacc60f2c09e0f40c70740d5161d5cb9f3 100644 (file)
@@ -3465,3 +3465,36 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             "SELECT 1 " + exp,
             checkparams=params,
         )
+
+
+class CacheKeyTest(fixtures.CacheKeyFixture, fixtures.TestBase):
+    def test_aggregate_order_by(self):
+        """test #8574"""
+
+        self._run_cache_key_fixture(
+            lambda: (
+                aggregate_order_by(column("a"), column("a")),
+                aggregate_order_by(column("a"), column("b")),
+                aggregate_order_by(column("a"), column("a").desc()),
+                aggregate_order_by(column("a"), column("a").nulls_first()),
+                aggregate_order_by(
+                    column("a"), column("a").desc().nulls_first()
+                ),
+                aggregate_order_by(column("a", Integer), column("b")),
+                aggregate_order_by(column("a"), column("b"), column("c")),
+                aggregate_order_by(column("a"), column("c"), column("b")),
+                aggregate_order_by(
+                    column("a"), column("b").desc(), column("c")
+                ),
+                aggregate_order_by(
+                    column("a"), column("b").nulls_first(), column("c")
+                ),
+                aggregate_order_by(
+                    column("a"), column("b").desc().nulls_first(), column("c")
+                ),
+                aggregate_order_by(
+                    column("a", Integer), column("a"), column("b")
+                ),
+            ),
+            compare_values=False,
+        )
index 08fd22dc8960d39f31ea2c0fa29ada5032dea0ac..70770089c3faf918db018bf2d00ebbd11a7b0296 100644 (file)
@@ -42,7 +42,6 @@ from sqlalchemy.testing.fixtures import fixture_session
 from test.orm import _fixtures
 from .inheritance import _poly_fixtures
 from .test_query import QueryTest
-from ..sql.test_compare import CacheKeyFixture
 
 
 def stmt_20(*elements):
@@ -52,7 +51,7 @@ def stmt_20(*elements):
     )
 
 
-class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
+class CacheKeyTest(fixtures.CacheKeyFixture, _fixtures.FixtureTest):
     run_setup_mappers = "once"
     run_inserts = None
     run_deletes = None
@@ -591,7 +590,7 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
         )
 
 
-class PolyCacheKeyTest(CacheKeyFixture, _poly_fixtures._Polymorphic):
+class PolyCacheKeyTest(fixtures.CacheKeyFixture, _poly_fixtures._Polymorphic):
     run_setup_mappers = "once"
     run_inserts = None
     run_deletes = None
index b012009c88feb43bea02398e2a73790d608ab358..71c03aee7810b82ad12febbaa3468ca4b92c0437 100644 (file)
@@ -52,6 +52,7 @@ from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
+from sqlalchemy.testing.fixtures import CacheKeyFixture
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -65,7 +66,6 @@ from .test_options import PathTest as OptionsPathTest
 from .test_options import PathTest
 from .test_options import QueryTest as OptionsQueryTest
 from .test_query import QueryTest
-from ..sql.test_compare import CacheKeyFixture
 
 if True:
     # hack - zimports won't stop reformatting this to be too-long for now
index 18f26887a18262afdb0c57c1e872954884cd424d..30ca5c56995b76eaef9b2a0170b2b80aa12664ff 100644 (file)
@@ -27,7 +27,6 @@ from sqlalchemy import tuple_
 from sqlalchemy import TypeDecorator
 from sqlalchemy import union
 from sqlalchemy import union_all
-from sqlalchemy import util
 from sqlalchemy import values
 from sqlalchemy.dialects import mysql
 from sqlalchemy.dialects import postgresql
@@ -1054,110 +1053,7 @@ class CoreFixtures:
     ]
 
 
-class CacheKeyFixture:
-    def _compare_equal(self, a, b, compare_values):
-        a_key = a._generate_cache_key()
-        b_key = b._generate_cache_key()
-
-        if a_key is None:
-            assert a._annotations.get("nocache")
-
-            assert b_key is None
-        else:
-
-            eq_(a_key.key, b_key.key)
-            eq_(hash(a_key.key), hash(b_key.key))
-
-            for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
-                assert a_param.compare(b_param, compare_values=compare_values)
-        return a_key, b_key
-
-    def _run_cache_key_fixture(self, fixture, compare_values):
-        case_a = fixture()
-        case_b = fixture()
-
-        for a, b in itertools.combinations_with_replacement(
-            range(len(case_a)), 2
-        ):
-            if a == b:
-                a_key, b_key = self._compare_equal(
-                    case_a[a], case_b[b], compare_values
-                )
-                if a_key is None:
-                    continue
-            else:
-                a_key = case_a[a]._generate_cache_key()
-                b_key = case_b[b]._generate_cache_key()
-
-                if a_key is None or b_key is None:
-                    if a_key is None:
-                        assert case_a[a]._annotations.get("nocache")
-                    if b_key is None:
-                        assert case_b[b]._annotations.get("nocache")
-                    continue
-
-                if a_key.key == b_key.key:
-                    for a_param, b_param in zip(
-                        a_key.bindparams, b_key.bindparams
-                    ):
-                        if not a_param.compare(
-                            b_param, compare_values=compare_values
-                        ):
-                            break
-                    else:
-                        # this fails unconditionally since we could not
-                        # find bound parameter values that differed.
-                        # Usually we intended to get two distinct keys here
-                        # so the failure will be more descriptive using the
-                        # ne_() assertion.
-                        ne_(a_key.key, b_key.key)
-                else:
-                    ne_(a_key.key, b_key.key)
-
-            # ClauseElement-specific test to ensure the cache key
-            # collected all the bound parameters that aren't marked
-            # as "literal execute"
-            if isinstance(case_a[a], ClauseElement) and isinstance(
-                case_b[b], ClauseElement
-            ):
-                assert_a_params = []
-                assert_b_params = []
-
-                for elem in visitors.iterate(case_a[a]):
-                    if elem.__visit_name__ == "bindparam":
-                        assert_a_params.append(elem)
-
-                for elem in visitors.iterate(case_b[b]):
-                    if elem.__visit_name__ == "bindparam":
-                        assert_b_params.append(elem)
-
-                # note we're asserting the order of the params as well as
-                # if there are dupes or not.  ordering has to be
-                # deterministic and matches what a traversal would provide.
-                eq_(
-                    sorted(a_key.bindparams, key=lambda b: b.key),
-                    sorted(
-                        util.unique_list(assert_a_params), key=lambda b: b.key
-                    ),
-                )
-                eq_(
-                    sorted(b_key.bindparams, key=lambda b: b.key),
-                    sorted(
-                        util.unique_list(assert_b_params), key=lambda b: b.key
-                    ),
-                )
-
-    def _run_cache_key_equal_fixture(self, fixture, compare_values):
-        case_a = fixture()
-        case_b = fixture()
-
-        for a, b in itertools.combinations_with_replacement(
-            range(len(case_a)), 2
-        ):
-            self._compare_equal(case_a[a], case_b[b], compare_values)
-
-
-class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase):
+class CacheKeyTest(fixtures.CacheKeyFixture, CoreFixtures, fixtures.TestBase):
     # we are slightly breaking the policy of not having external dialect
     # stuff in here, but use pg/mysql as test cases to ensure that these
     # objects don't report an inaccurate cache key, which is dependent