From: Federico Caselli Date: Fri, 14 Jun 2024 22:06:46 +0000 (+0200) Subject: add CTE cache elements for CompoundSelect, more verify tests X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=881be0a21633b3fee101cb34cc611904b8cba618;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git add CTE cache elements for CompoundSelect, more verify tests Follow up of :ticket:`11471` to fix caching issue where using the :meth:`.CompoundSelectState.add_cte` method of the :class:`.CompoundSelectState` construct would not set a correct cache key which distinguished between different CTE expressions. Also added tests that would detect issues similar to the one fixed in :ticket:`11544`. Fixes: #11471 Change-Id: Iae6a91077c987d83cd70ea826daff42855491330 --- diff --git a/doc/build/changelog/unreleased_20/11471.rst b/doc/build/changelog/unreleased_20/11471.rst new file mode 100644 index 0000000000..4170de0298 --- /dev/null +++ b/doc/build/changelog/unreleased_20/11471.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, sql + :tickets: 11471 + + Follow up of :ticket:`11471` to fix caching issue where using the + :meth:`.CompoundSelectState.add_cte` method of the + :class:`.CompoundSelectState` construct would not set a correct cache key + which distinguished between different CTE expressions. Also added tests + that would detect issues similar to the one fixed in :ticket:`11544`. diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 8a1ffba64c..1ecb680e44 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -80,7 +80,6 @@ from .elements import TextClause from .selectable import TableClause from .type_api import to_instance from .visitors import ExternallyTraversible -from .visitors import InternalTraversal from .. import event from .. import exc from .. import inspection @@ -102,7 +101,6 @@ if typing.TYPE_CHECKING: from .elements import BindParameter from .functions import Function from .type_api import TypeEngine - from .visitors import _TraverseInternalsType from .visitors import anon_map from ..engine import Connection from ..engine import Engine @@ -395,11 +393,6 @@ class Table( """ - _traverse_internals: _TraverseInternalsType = ( - TableClause._traverse_internals - + [("schema", InternalTraversal.dp_string)] - ) - if TYPE_CHECKING: @util.ro_non_memoized_property diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 6fa29fd767..3c9ca808a3 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -3686,7 +3686,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase, Generic[_SB]): __visit_name__ = "select_statement_grouping" _traverse_internals: _TraverseInternalsType = [ ("element", InternalTraversal.dp_clauseelement) - ] + ] + SupportsCloneAnnotations._clone_annotations_traverse_internals _is_select_container = True @@ -3766,6 +3766,10 @@ class SelectStatementGrouping(GroupedElement, SelectBase, Generic[_SB]): def _from_objects(self) -> List[FromClause]: return self.element._from_objects + def add_cte(self, *ctes: CTE, nest_here: bool = False) -> Self: + # SelectStatementGrouping not generative: has no attribute '_generate' + raise NotImplementedError + class GenerativeSelect(SelectBase, Generative): """Base class for SELECT statements where additional elements can be @@ -4313,17 +4317,21 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): __visit_name__ = "compound_select" - _traverse_internals: _TraverseInternalsType = [ - ("selects", InternalTraversal.dp_clauseelement_list), - ("_limit_clause", InternalTraversal.dp_clauseelement), - ("_offset_clause", InternalTraversal.dp_clauseelement), - ("_fetch_clause", InternalTraversal.dp_clauseelement), - ("_fetch_clause_options", InternalTraversal.dp_plain_dict), - ("_order_by_clauses", InternalTraversal.dp_clauseelement_list), - ("_group_by_clauses", InternalTraversal.dp_clauseelement_list), - ("_for_update_arg", InternalTraversal.dp_clauseelement), - ("keyword", InternalTraversal.dp_string), - ] + SupportsCloneAnnotations._clone_annotations_traverse_internals + _traverse_internals: _TraverseInternalsType = ( + [ + ("selects", InternalTraversal.dp_clauseelement_list), + ("_limit_clause", InternalTraversal.dp_clauseelement), + ("_offset_clause", InternalTraversal.dp_clauseelement), + ("_fetch_clause", InternalTraversal.dp_clauseelement), + ("_fetch_clause_options", InternalTraversal.dp_plain_dict), + ("_order_by_clauses", InternalTraversal.dp_clauseelement_list), + ("_group_by_clauses", InternalTraversal.dp_clauseelement_list), + ("_for_update_arg", InternalTraversal.dp_clauseelement), + ("keyword", InternalTraversal.dp_string), + ] + + SupportsCloneAnnotations._clone_annotations_traverse_internals + + HasCTE._has_ctes_traverse_internals + ) selects: List[SelectBase] diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index a43ea70e10..f9c435f839 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -1,4 +1,5 @@ import importlib +from inspect import signature import itertools import random @@ -35,7 +36,6 @@ from sqlalchemy.schema import Sequence from sqlalchemy.sql import bindparam from sqlalchemy.sql import ColumnElement from sqlalchemy.sql import dml -from sqlalchemy.sql import elements from sqlalchemy.sql import False_ from sqlalchemy.sql import func from sqlalchemy.sql import operators @@ -43,10 +43,11 @@ from sqlalchemy.sql import roles from sqlalchemy.sql import True_ from sqlalchemy.sql import type_coerce from sqlalchemy.sql import visitors +from sqlalchemy.sql.annotation import Annotated from sqlalchemy.sql.base import HasCacheKey +from sqlalchemy.sql.base import SingletonConstant from sqlalchemy.sql.elements import _label_reference from sqlalchemy.sql.elements import _textual_label_reference -from sqlalchemy.sql.elements import Annotated from sqlalchemy.sql.elements import BindParameter from sqlalchemy.sql.elements import ClauseElement from sqlalchemy.sql.elements import ClauseList @@ -62,10 +63,10 @@ from sqlalchemy.sql.lambdas import lambda_stmt from sqlalchemy.sql.lambdas import LambdaElement from sqlalchemy.sql.lambdas import LambdaOptions from sqlalchemy.sql.selectable import _OffsetLimitParam -from sqlalchemy.sql.selectable import AliasedReturnsRows from sqlalchemy.sql.selectable import FromGrouping from sqlalchemy.sql.selectable import LABEL_STYLE_NONE from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from sqlalchemy.sql.selectable import NoInit from sqlalchemy.sql.selectable import Select from sqlalchemy.sql.selectable import Selectable from sqlalchemy.sql.selectable import SelectStatementGrouping @@ -214,6 +215,34 @@ class CoreFixtures: .columns(a=Integer()) .add_cte(table_b.select().where(table_b.c.a > 5).cte()), ), + lambda: ( + union( + select(table_a).where(table_a.c.a > 1), + select(table_a).where(table_a.c.a < 1), + ).add_cte(select(table_b).where(table_b.c.a > 1).cte("ttt")), + union( + select(table_a).where(table_a.c.a > 1), + select(table_a).where(table_a.c.a < 1), + ).add_cte(select(table_b).where(table_b.c.a < 1).cte("ttt")), + union( + select(table_a).where(table_a.c.a > 1), + select(table_a).where(table_a.c.a < 1), + ) + .add_cte(select(table_b).where(table_b.c.a > 1).cte("ttt")) + ._annotate({"foo": "bar"}), + ), + lambda: ( + union( + select(table_a).where(table_a.c.a > 1), + select(table_a).where(table_a.c.a < 1), + ).self_group(), + union( + select(table_a).where(table_a.c.a > 1), + select(table_a).where(table_a.c.a < 1), + ) + .self_group() + ._annotate({"foo": "bar"}), + ), lambda: ( literal(1).op("+")(literal(1)), literal(1).op("-")(literal(1)), @@ -1396,6 +1425,246 @@ class CacheKeyTest(fixtures.CacheKeyFixture, CoreFixtures, fixtures.TestBase): is_not(ck3, None) +def all_hascachekey_subclasses(ignore_subclasses=()): + def find_subclasses(cls: type): + for s in class_hierarchy(cls): + if ( + # class_hierarchy may return values that + # aren't subclasses of cls + not issubclass(s, cls) + or "_traverse_internals" not in s.__dict__ + or any(issubclass(s, ignore) for ignore in ignore_subclasses) + ): + continue + yield s + + return dict.fromkeys(find_subclasses(HasCacheKey)) + + +class HasCacheKeySubclass(fixtures.TestBase): + custom_traverse = { + "AnnotatedFunctionAsBinary": { + "sql_function", + "left_index", + "right_index", + "modifiers", + "_annotations", + }, + "Annotatednext_value": {"sequence", "_annotations"}, + "FunctionAsBinary": { + "sql_function", + "left_index", + "right_index", + "modifiers", + }, + "next_value": {"sequence"}, + } + + ignore_keys = { + "AnnotatedColumn": {"dialect_options"}, + "SelectStatementGrouping": { + "_independent_ctes", + "_independent_ctes_opts", + }, + } + + @testing.combinations(*all_hascachekey_subclasses()) + def test_traverse_internals(self, cls: type): + super_traverse = {} + # ignore_super = self.ignore_super.get(cls.__name__, set()) + for s in cls.mro()[1:]: + # if s.__name__ in ignore_super: + # continue + if s.__name__ == "Executable": + continue + for attr in s.__dict__: + if not attr.endswith("_traverse_internals"): + continue + for k, v in s.__dict__[attr]: + if k not in super_traverse: + super_traverse[k] = v + traverse_dict = dict(cls.__dict__["_traverse_internals"]) + eq_(len(cls.__dict__["_traverse_internals"]), len(traverse_dict)) + if cls.__name__ in self.custom_traverse: + eq_(traverse_dict.keys(), self.custom_traverse[cls.__name__]) + else: + ignore = self.ignore_keys.get(cls.__name__, set()) + + left_keys = traverse_dict.keys() | ignore + is_true( + left_keys >= super_traverse.keys(), + f"{left_keys} >= {super_traverse.keys()} - missing: " + f"{super_traverse.keys() - left_keys} - ignored {ignore}", + ) + + subset = { + k: v for k, v in traverse_dict.items() if k in super_traverse + } + eq_( + subset, + {k: v for k, v in super_traverse.items() if k not in ignore}, + ) + + # name -> (traverse names, init args) + custom_init = { + "BinaryExpression": ( + {"right", "operator", "type", "negate", "modifiers", "left"}, + {"right", "operator", "type_", "negate", "modifiers", "left"}, + ), + "BindParameter": ( + {"literal_execute", "type", "callable", "value", "key"}, + {"required", "isoutparam", "literal_execute", "type_", "callable_"} + | {"unique", "expanding", "quote", "value", "key"}, + ), + "Cast": ({"type", "clause"}, {"type_", "expression"}), + "ClauseList": ( + {"clauses", "operator"}, + {"group_contents", "group", "operator", "clauses"}, + ), + "ColumnClause": ( + {"is_literal", "type", "table", "name"}, + {"type_", "is_literal", "text"}, + ), + "ExpressionClauseList": ( + {"clauses", "operator"}, + {"type_", "operator", "clauses"}, + ), + "FromStatement": ( + {"_raw_columns", "_with_options", "element"} + | {"_propagate_attrs", "_with_context_options"}, + {"element", "entities"}, + ), + "FunctionAsBinary": ( + {"modifiers", "sql_function", "right_index", "left_index"}, + {"right_index", "left_index", "fn"}, + ), + "FunctionElement": ( + {"clause_expr", "_table_value_type", "_with_ordinality"}, + {"clauses"}, + ), + "Function": ( + {"_table_value_type", "clause_expr", "_with_ordinality"} + | {"packagenames", "type", "name"}, + {"type_", "packagenames", "name", "clauses"}, + ), + "Label": ({"_element", "type", "name"}, {"type_", "element", "name"}), + "LambdaElement": ( + {"_resolved"}, + {"role", "opts", "apply_propagate_attrs", "fn"}, + ), + "Load": ( + {"propagate_to_loaders", "additional_source_entities"} + | {"path", "context"}, + {"entity"}, + ), + "LoaderCriteriaOption": ( + {"where_criteria", "entity", "propagate_to_loaders"} + | {"root_entity", "include_aliases"}, + {"where_criteria", "include_aliases", "propagate_to_loaders"} + | {"entity_or_base", "loader_only", "track_closure_variables"}, + ), + "NullLambdaStatement": ({"_resolved"}, {"statement"}), + "ScalarFunctionColumn": ( + {"type", "fn", "name"}, + {"type_", "name", "fn"}, + ), + "ScalarValues": ( + {"_data", "_column_args", "literal_binds"}, + {"columns", "data", "literal_binds"}, + ), + "Select": ( + { + "_having_criteria", + "_distinct", + "_group_by_clauses", + "_fetch_clause", + "_limit_clause", + "_label_style", + "_order_by_clauses", + "_raw_columns", + "_correlate_except", + "_statement_hints", + "_hints", + "_independent_ctes", + "_distinct_on", + "_with_context_options", + "_setup_joins", + "_suffixes", + "_memoized_select_entities", + "_for_update_arg", + "_prefixes", + "_propagate_attrs", + "_with_options", + "_independent_ctes_opts", + "_offset_clause", + "_correlate", + "_where_criteria", + "_annotations", + "_fetch_clause_options", + "_from_obj", + }, + {"entities"}, + ), + "TableValuedColumn": ( + {"scalar_alias", "type", "name"}, + {"type_", "scalar_alias"}, + ), + "TableValueType": ({"_elements"}, {"elements"}), + "TextualSelect": ( + {"column_args", "_annotations", "_independent_ctes"} + | {"element", "_independent_ctes_opts"}, + {"positional", "columns", "text"}, + ), + "Tuple": ({"clauses", "operator"}, {"clauses", "types"}), + "TypeClause": ({"type"}, {"type_"}), + "TypeCoerce": ({"type", "clause"}, {"type_", "expression"}), + "UnaryExpression": ( + {"modifier", "element", "operator"}, + {"operator", "wraps_column_expression"} + | {"type_", "modifier", "element"}, + ), + "Values": ( + {"_column_args", "literal_binds", "name", "_data"}, + {"columns", "name", "literal_binds"}, + ), + "_FrameClause": ( + {"upper_integer_bind", "upper_type"} + | {"lower_type", "lower_integer_bind"}, + {"range_"}, + ), + "_MemoizedSelectEntities": ( + {"_with_options", "_raw_columns", "_setup_joins"}, + {"args"}, + ), + "next_value": ({"sequence"}, {"seq"}), + } + + @testing.combinations( + *all_hascachekey_subclasses( + ignore_subclasses=[Annotated, NoInit, SingletonConstant] + ) + ) + def test_init_args_in_traversal(self, cls: type): + sig = signature(cls.__init__) + init_args = set() + for p in sig.parameters.values(): + if ( + p.name == "self" + or p.name.startswith("_") + or p.kind in (p.VAR_KEYWORD,) + ): + continue + init_args.add(p.name) + + names = {n for n, _ in cls.__dict__["_traverse_internals"]} + if cls.__name__ in self.custom_init: + traverse, inits = self.custom_init[cls.__name__] + eq_(names, traverse) + eq_(init_args, inits) + else: + is_true(names.issuperset(init_args), f"{names} : {init_args}") + + class CompareAndCopyTest(CoreFixtures, fixtures.TestBase): @classmethod def setup_test_class(cls): @@ -1411,21 +1680,16 @@ class CompareAndCopyTest(CoreFixtures, fixtures.TestBase): also included in the fixtures above. """ - need = { + need = set( cls - for cls in class_hierarchy(ClauseElement) - if issubclass(cls, (ColumnElement, Selectable, LambdaElement)) - and ( - "__init__" in cls.__dict__ - or issubclass(cls, AliasedReturnsRows) + for cls in all_hascachekey_subclasses( + ignore_subclasses=[Annotated, NoInit, SingletonConstant] ) - and not issubclass(cls, (Annotated, elements._OverrideBinds)) - and cls.__module__.startswith("sqlalchemy.") - and "orm" not in cls.__module__ + if "orm" not in cls.__module__ and "compiler" not in cls.__module__ - and "crud" not in cls.__module__ - and "dialects" not in cls.__module__ # TODO: dialects? - }.difference({ColumnElement, UnaryExpression}) + and "dialects" not in cls.__module__ + and issubclass(cls, (ColumnElement, Selectable, LambdaElement)) + ) for fixture in self.fixtures + self.dont_compare_values_fixtures: case_a = fixture()