--- /dev/null
+.. change::
+ :tags: usecase, postgresql
+ :tickets: 8574
+
+ :class:`_postgresql.aggregate_order_by` now supports cache generation.
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
+from __future__ import annotations
from itertools import zip_longest
+from typing import TYPE_CHECKING
from .array import ARRAY
from ...sql import coercions
from ...sql import roles
from ...sql import schema
from ...sql.schema import ColumnCollectionConstraint
+from ...sql.visitors import InternalTraversal
+
+if TYPE_CHECKING:
+ from ...sql.visitors import _TraverseInternalsType
class aggregate_order_by(expression.ColumnElement):
__visit_name__ = "aggregate_order_by"
stringify_dialect = "postgresql"
- inherit_cache = False
+ _traverse_internals: _TraverseInternalsType = [
+ ("target", InternalTraversal.dp_clauseelement),
+ ("type", InternalTraversal.dp_type),
+ ("order_by", InternalTraversal.dp_clauseelement),
+ ]
def __init__(self, target, *order_by):
self.target = coercions.expect(roles.ExpressionElementRole, target)
from __future__ import annotations
+import itertools
import re
import sys
from . import assertions
from . import config
from . import schema
+from .assertions import eq_
+from .assertions import ne_
from .entities import BasicEntity
from .entities import ComparableEntity
from .entities import ComparableMixin # noqa
from ..orm import MappedAsDataclass
from ..orm import registry
from ..schema import sort_tables_and_constraints
+from ..sql import visitors
+from ..sql.elements import ClauseElement
@config.mark_base_test_class()
Computed("normal * 42", persisted=True),
)
)
+
+
+class CacheKeyFixture:
+ def _compare_equal(self, a, b, compare_values):
+ a_key = a._generate_cache_key()
+ b_key = b._generate_cache_key()
+
+ if a_key is None:
+ assert a._annotations.get("nocache")
+
+ assert b_key is None
+ else:
+
+ eq_(a_key.key, b_key.key)
+ eq_(hash(a_key.key), hash(b_key.key))
+
+ for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
+ assert a_param.compare(b_param, compare_values=compare_values)
+ return a_key, b_key
+
+ def _run_cache_key_fixture(self, fixture, compare_values):
+ case_a = fixture()
+ case_b = fixture()
+
+ for a, b in itertools.combinations_with_replacement(
+ range(len(case_a)), 2
+ ):
+ if a == b:
+ a_key, b_key = self._compare_equal(
+ case_a[a], case_b[b], compare_values
+ )
+ if a_key is None:
+ continue
+ else:
+ a_key = case_a[a]._generate_cache_key()
+ b_key = case_b[b]._generate_cache_key()
+
+ if a_key is None or b_key is None:
+ if a_key is None:
+ assert case_a[a]._annotations.get("nocache")
+ if b_key is None:
+ assert case_b[b]._annotations.get("nocache")
+ continue
+
+ if a_key.key == b_key.key:
+ for a_param, b_param in zip(
+ a_key.bindparams, b_key.bindparams
+ ):
+ if not a_param.compare(
+ b_param, compare_values=compare_values
+ ):
+ break
+ else:
+ # this fails unconditionally since we could not
+ # find bound parameter values that differed.
+ # Usually we intended to get two distinct keys here
+ # so the failure will be more descriptive using the
+ # ne_() assertion.
+ ne_(a_key.key, b_key.key)
+ else:
+ ne_(a_key.key, b_key.key)
+
+ # ClauseElement-specific test to ensure the cache key
+ # collected all the bound parameters that aren't marked
+ # as "literal execute"
+ if isinstance(case_a[a], ClauseElement) and isinstance(
+ case_b[b], ClauseElement
+ ):
+ assert_a_params = []
+ assert_b_params = []
+
+ for elem in visitors.iterate(case_a[a]):
+ if elem.__visit_name__ == "bindparam":
+ assert_a_params.append(elem)
+
+ for elem in visitors.iterate(case_b[b]):
+ if elem.__visit_name__ == "bindparam":
+ assert_b_params.append(elem)
+
+ # note we're asserting the order of the params as well as
+ # if there are dupes or not. ordering has to be
+ # deterministic and matches what a traversal would provide.
+ eq_(
+ sorted(a_key.bindparams, key=lambda b: b.key),
+ sorted(
+ util.unique_list(assert_a_params), key=lambda b: b.key
+ ),
+ )
+ eq_(
+ sorted(b_key.bindparams, key=lambda b: b.key),
+ sorted(
+ util.unique_list(assert_b_params), key=lambda b: b.key
+ ),
+ )
+
+ def _run_cache_key_equal_fixture(self, fixture, compare_values):
+ case_a = fixture()
+ case_b = fixture()
+
+ for a, b in itertools.combinations_with_replacement(
+ range(len(case_a)), 2
+ ):
+ self._compare_equal(case_a[a], case_b[b], compare_values)
"SELECT 1 " + exp,
checkparams=params,
)
+
+
+class CacheKeyTest(fixtures.CacheKeyFixture, fixtures.TestBase):
+ def test_aggregate_order_by(self):
+ """test #8574"""
+
+ self._run_cache_key_fixture(
+ lambda: (
+ aggregate_order_by(column("a"), column("a")),
+ aggregate_order_by(column("a"), column("b")),
+ aggregate_order_by(column("a"), column("a").desc()),
+ aggregate_order_by(column("a"), column("a").nulls_first()),
+ aggregate_order_by(
+ column("a"), column("a").desc().nulls_first()
+ ),
+ aggregate_order_by(column("a", Integer), column("b")),
+ aggregate_order_by(column("a"), column("b"), column("c")),
+ aggregate_order_by(column("a"), column("c"), column("b")),
+ aggregate_order_by(
+ column("a"), column("b").desc(), column("c")
+ ),
+ aggregate_order_by(
+ column("a"), column("b").nulls_first(), column("c")
+ ),
+ aggregate_order_by(
+ column("a"), column("b").desc().nulls_first(), column("c")
+ ),
+ aggregate_order_by(
+ column("a", Integer), column("a"), column("b")
+ ),
+ ),
+ compare_values=False,
+ )
from test.orm import _fixtures
from .inheritance import _poly_fixtures
from .test_query import QueryTest
-from ..sql.test_compare import CacheKeyFixture
def stmt_20(*elements):
)
-class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
+class CacheKeyTest(fixtures.CacheKeyFixture, _fixtures.FixtureTest):
run_setup_mappers = "once"
run_inserts = None
run_deletes = None
)
-class PolyCacheKeyTest(CacheKeyFixture, _poly_fixtures._Polymorphic):
+class PolyCacheKeyTest(fixtures.CacheKeyFixture, _poly_fixtures._Polymorphic):
run_setup_mappers = "once"
run_inserts = None
run_deletes = None
from sqlalchemy.testing import is_
from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
+from sqlalchemy.testing.fixtures import CacheKeyFixture
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
from .test_options import PathTest
from .test_options import QueryTest as OptionsQueryTest
from .test_query import QueryTest
-from ..sql.test_compare import CacheKeyFixture
if True:
# hack - zimports won't stop reformatting this to be too-long for now
from sqlalchemy import TypeDecorator
from sqlalchemy import union
from sqlalchemy import union_all
-from sqlalchemy import util
from sqlalchemy import values
from sqlalchemy.dialects import mysql
from sqlalchemy.dialects import postgresql
]
-class CacheKeyFixture:
- def _compare_equal(self, a, b, compare_values):
- a_key = a._generate_cache_key()
- b_key = b._generate_cache_key()
-
- if a_key is None:
- assert a._annotations.get("nocache")
-
- assert b_key is None
- else:
-
- eq_(a_key.key, b_key.key)
- eq_(hash(a_key.key), hash(b_key.key))
-
- for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
- assert a_param.compare(b_param, compare_values=compare_values)
- return a_key, b_key
-
- def _run_cache_key_fixture(self, fixture, compare_values):
- case_a = fixture()
- case_b = fixture()
-
- for a, b in itertools.combinations_with_replacement(
- range(len(case_a)), 2
- ):
- if a == b:
- a_key, b_key = self._compare_equal(
- case_a[a], case_b[b], compare_values
- )
- if a_key is None:
- continue
- else:
- a_key = case_a[a]._generate_cache_key()
- b_key = case_b[b]._generate_cache_key()
-
- if a_key is None or b_key is None:
- if a_key is None:
- assert case_a[a]._annotations.get("nocache")
- if b_key is None:
- assert case_b[b]._annotations.get("nocache")
- continue
-
- if a_key.key == b_key.key:
- for a_param, b_param in zip(
- a_key.bindparams, b_key.bindparams
- ):
- if not a_param.compare(
- b_param, compare_values=compare_values
- ):
- break
- else:
- # this fails unconditionally since we could not
- # find bound parameter values that differed.
- # Usually we intended to get two distinct keys here
- # so the failure will be more descriptive using the
- # ne_() assertion.
- ne_(a_key.key, b_key.key)
- else:
- ne_(a_key.key, b_key.key)
-
- # ClauseElement-specific test to ensure the cache key
- # collected all the bound parameters that aren't marked
- # as "literal execute"
- if isinstance(case_a[a], ClauseElement) and isinstance(
- case_b[b], ClauseElement
- ):
- assert_a_params = []
- assert_b_params = []
-
- for elem in visitors.iterate(case_a[a]):
- if elem.__visit_name__ == "bindparam":
- assert_a_params.append(elem)
-
- for elem in visitors.iterate(case_b[b]):
- if elem.__visit_name__ == "bindparam":
- assert_b_params.append(elem)
-
- # note we're asserting the order of the params as well as
- # if there are dupes or not. ordering has to be
- # deterministic and matches what a traversal would provide.
- eq_(
- sorted(a_key.bindparams, key=lambda b: b.key),
- sorted(
- util.unique_list(assert_a_params), key=lambda b: b.key
- ),
- )
- eq_(
- sorted(b_key.bindparams, key=lambda b: b.key),
- sorted(
- util.unique_list(assert_b_params), key=lambda b: b.key
- ),
- )
-
- def _run_cache_key_equal_fixture(self, fixture, compare_values):
- case_a = fixture()
- case_b = fixture()
-
- for a, b in itertools.combinations_with_replacement(
- range(len(case_a)), 2
- ):
- self._compare_equal(case_a[a], case_b[b], compare_values)
-
-
-class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase):
+class CacheKeyTest(fixtures.CacheKeyFixture, CoreFixtures, fixtures.TestBase):
# we are slightly breaking the policy of not having external dialect
# stuff in here, but use pg/mysql as test cases to ensure that these
# objects don't report an inaccurate cache key, which is dependent