]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add CTE cache elements for CompoundSelect, more verify tests
authorFederico Caselli <cfederico87@gmail.com>
Fri, 14 Jun 2024 22:06:46 +0000 (00:06 +0200)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 31 Jul 2024 15:09:43 +0000 (11:09 -0400)
Follow up of :ticket:`11471` to fix caching issue where using the
:meth:`.CompoundSelectState.add_cte` method of the
:class:`.CompoundSelectState` construct would not set a correct cache key
which distinguished between different CTE expressions. Also added tests
that would detect issues similar to the one fixed in :ticket:`11544`.

Fixes: #11471
Change-Id: Iae6a91077c987d83cd70ea826daff42855491330

doc/build/changelog/unreleased_20/11471.rst [new file with mode: 0644]
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/sql/selectable.py
test/sql/test_compare.py

diff --git a/doc/build/changelog/unreleased_20/11471.rst b/doc/build/changelog/unreleased_20/11471.rst
new file mode 100644 (file)
index 0000000..4170de0
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 11471
+
+    Follow up of :ticket:`11471` to fix caching issue where using the
+    :meth:`.CompoundSelectState.add_cte` method of the
+    :class:`.CompoundSelectState` construct would not set a correct cache key
+    which distinguished between different CTE expressions. Also added tests
+    that would detect issues similar to the one fixed in :ticket:`11544`.
index 8a1ffba64c3d93cab4ceb5f0c6220ff68c9d84e3..1ecb680e446766fbd328b22c27fafb6ec202eb8f 100644 (file)
@@ -80,7 +80,6 @@ from .elements import TextClause
 from .selectable import TableClause
 from .type_api import to_instance
 from .visitors import ExternallyTraversible
-from .visitors import InternalTraversal
 from .. import event
 from .. import exc
 from .. import inspection
@@ -102,7 +101,6 @@ if typing.TYPE_CHECKING:
     from .elements import BindParameter
     from .functions import Function
     from .type_api import TypeEngine
-    from .visitors import _TraverseInternalsType
     from .visitors import anon_map
     from ..engine import Connection
     from ..engine import Engine
@@ -395,11 +393,6 @@ class Table(
 
     """
 
-    _traverse_internals: _TraverseInternalsType = (
-        TableClause._traverse_internals
-        + [("schema", InternalTraversal.dp_string)]
-    )
-
     if TYPE_CHECKING:
 
         @util.ro_non_memoized_property
index 6fa29fd767fb6a18a3a462d08a12a0ed5420a8af..3c9ca808a3e1f74c10a8af04493443ac22b7f49b 100644 (file)
@@ -3686,7 +3686,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase, Generic[_SB]):
     __visit_name__ = "select_statement_grouping"
     _traverse_internals: _TraverseInternalsType = [
         ("element", InternalTraversal.dp_clauseelement)
-    ]
+    ] + SupportsCloneAnnotations._clone_annotations_traverse_internals
 
     _is_select_container = True
 
@@ -3766,6 +3766,10 @@ class SelectStatementGrouping(GroupedElement, SelectBase, Generic[_SB]):
     def _from_objects(self) -> List[FromClause]:
         return self.element._from_objects
 
+    def add_cte(self, *ctes: CTE, nest_here: bool = False) -> Self:
+        # SelectStatementGrouping not generative: has no attribute '_generate'
+        raise NotImplementedError
+
 
 class GenerativeSelect(SelectBase, Generative):
     """Base class for SELECT statements where additional elements can be
@@ -4313,17 +4317,21 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
 
     __visit_name__ = "compound_select"
 
-    _traverse_internals: _TraverseInternalsType = [
-        ("selects", InternalTraversal.dp_clauseelement_list),
-        ("_limit_clause", InternalTraversal.dp_clauseelement),
-        ("_offset_clause", InternalTraversal.dp_clauseelement),
-        ("_fetch_clause", InternalTraversal.dp_clauseelement),
-        ("_fetch_clause_options", InternalTraversal.dp_plain_dict),
-        ("_order_by_clauses", InternalTraversal.dp_clauseelement_list),
-        ("_group_by_clauses", InternalTraversal.dp_clauseelement_list),
-        ("_for_update_arg", InternalTraversal.dp_clauseelement),
-        ("keyword", InternalTraversal.dp_string),
-    ] + SupportsCloneAnnotations._clone_annotations_traverse_internals
+    _traverse_internals: _TraverseInternalsType = (
+        [
+            ("selects", InternalTraversal.dp_clauseelement_list),
+            ("_limit_clause", InternalTraversal.dp_clauseelement),
+            ("_offset_clause", InternalTraversal.dp_clauseelement),
+            ("_fetch_clause", InternalTraversal.dp_clauseelement),
+            ("_fetch_clause_options", InternalTraversal.dp_plain_dict),
+            ("_order_by_clauses", InternalTraversal.dp_clauseelement_list),
+            ("_group_by_clauses", InternalTraversal.dp_clauseelement_list),
+            ("_for_update_arg", InternalTraversal.dp_clauseelement),
+            ("keyword", InternalTraversal.dp_string),
+        ]
+        + SupportsCloneAnnotations._clone_annotations_traverse_internals
+        + HasCTE._has_ctes_traverse_internals
+    )
 
     selects: List[SelectBase]
 
index a43ea70e1091351378e0d753749c165951a2f253..f9c435f839b49b14d43cf95061fdef5361beb6cd 100644 (file)
@@ -1,4 +1,5 @@
 import importlib
+from inspect import signature
 import itertools
 import random
 
@@ -35,7 +36,6 @@ from sqlalchemy.schema import Sequence
 from sqlalchemy.sql import bindparam
 from sqlalchemy.sql import ColumnElement
 from sqlalchemy.sql import dml
-from sqlalchemy.sql import elements
 from sqlalchemy.sql import False_
 from sqlalchemy.sql import func
 from sqlalchemy.sql import operators
@@ -43,10 +43,11 @@ from sqlalchemy.sql import roles
 from sqlalchemy.sql import True_
 from sqlalchemy.sql import type_coerce
 from sqlalchemy.sql import visitors
+from sqlalchemy.sql.annotation import Annotated
 from sqlalchemy.sql.base import HasCacheKey
+from sqlalchemy.sql.base import SingletonConstant
 from sqlalchemy.sql.elements import _label_reference
 from sqlalchemy.sql.elements import _textual_label_reference
-from sqlalchemy.sql.elements import Annotated
 from sqlalchemy.sql.elements import BindParameter
 from sqlalchemy.sql.elements import ClauseElement
 from sqlalchemy.sql.elements import ClauseList
@@ -62,10 +63,10 @@ from sqlalchemy.sql.lambdas import lambda_stmt
 from sqlalchemy.sql.lambdas import LambdaElement
 from sqlalchemy.sql.lambdas import LambdaOptions
 from sqlalchemy.sql.selectable import _OffsetLimitParam
-from sqlalchemy.sql.selectable import AliasedReturnsRows
 from sqlalchemy.sql.selectable import FromGrouping
 from sqlalchemy.sql.selectable import LABEL_STYLE_NONE
 from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from sqlalchemy.sql.selectable import NoInit
 from sqlalchemy.sql.selectable import Select
 from sqlalchemy.sql.selectable import Selectable
 from sqlalchemy.sql.selectable import SelectStatementGrouping
@@ -214,6 +215,34 @@ class CoreFixtures:
             .columns(a=Integer())
             .add_cte(table_b.select().where(table_b.c.a > 5).cte()),
         ),
+        lambda: (
+            union(
+                select(table_a).where(table_a.c.a > 1),
+                select(table_a).where(table_a.c.a < 1),
+            ).add_cte(select(table_b).where(table_b.c.a > 1).cte("ttt")),
+            union(
+                select(table_a).where(table_a.c.a > 1),
+                select(table_a).where(table_a.c.a < 1),
+            ).add_cte(select(table_b).where(table_b.c.a < 1).cte("ttt")),
+            union(
+                select(table_a).where(table_a.c.a > 1),
+                select(table_a).where(table_a.c.a < 1),
+            )
+            .add_cte(select(table_b).where(table_b.c.a > 1).cte("ttt"))
+            ._annotate({"foo": "bar"}),
+        ),
+        lambda: (
+            union(
+                select(table_a).where(table_a.c.a > 1),
+                select(table_a).where(table_a.c.a < 1),
+            ).self_group(),
+            union(
+                select(table_a).where(table_a.c.a > 1),
+                select(table_a).where(table_a.c.a < 1),
+            )
+            .self_group()
+            ._annotate({"foo": "bar"}),
+        ),
         lambda: (
             literal(1).op("+")(literal(1)),
             literal(1).op("-")(literal(1)),
@@ -1396,6 +1425,246 @@ class CacheKeyTest(fixtures.CacheKeyFixture, CoreFixtures, fixtures.TestBase):
         is_not(ck3, None)
 
 
+def all_hascachekey_subclasses(ignore_subclasses=()):
+    def find_subclasses(cls: type):
+        for s in class_hierarchy(cls):
+            if (
+                # class_hierarchy may return values that
+                # aren't subclasses of cls
+                not issubclass(s, cls)
+                or "_traverse_internals" not in s.__dict__
+                or any(issubclass(s, ignore) for ignore in ignore_subclasses)
+            ):
+                continue
+            yield s
+
+    return dict.fromkeys(find_subclasses(HasCacheKey))
+
+
+class HasCacheKeySubclass(fixtures.TestBase):
+    custom_traverse = {
+        "AnnotatedFunctionAsBinary": {
+            "sql_function",
+            "left_index",
+            "right_index",
+            "modifiers",
+            "_annotations",
+        },
+        "Annotatednext_value": {"sequence", "_annotations"},
+        "FunctionAsBinary": {
+            "sql_function",
+            "left_index",
+            "right_index",
+            "modifiers",
+        },
+        "next_value": {"sequence"},
+    }
+
+    ignore_keys = {
+        "AnnotatedColumn": {"dialect_options"},
+        "SelectStatementGrouping": {
+            "_independent_ctes",
+            "_independent_ctes_opts",
+        },
+    }
+
+    @testing.combinations(*all_hascachekey_subclasses())
+    def test_traverse_internals(self, cls: type):
+        super_traverse = {}
+        # ignore_super = self.ignore_super.get(cls.__name__, set())
+        for s in cls.mro()[1:]:
+            # if s.__name__ in ignore_super:
+            #     continue
+            if s.__name__ == "Executable":
+                continue
+            for attr in s.__dict__:
+                if not attr.endswith("_traverse_internals"):
+                    continue
+                for k, v in s.__dict__[attr]:
+                    if k not in super_traverse:
+                        super_traverse[k] = v
+        traverse_dict = dict(cls.__dict__["_traverse_internals"])
+        eq_(len(cls.__dict__["_traverse_internals"]), len(traverse_dict))
+        if cls.__name__ in self.custom_traverse:
+            eq_(traverse_dict.keys(), self.custom_traverse[cls.__name__])
+        else:
+            ignore = self.ignore_keys.get(cls.__name__, set())
+
+            left_keys = traverse_dict.keys() | ignore
+            is_true(
+                left_keys >= super_traverse.keys(),
+                f"{left_keys} >= {super_traverse.keys()} - missing: "
+                f"{super_traverse.keys() - left_keys} - ignored {ignore}",
+            )
+
+            subset = {
+                k: v for k, v in traverse_dict.items() if k in super_traverse
+            }
+            eq_(
+                subset,
+                {k: v for k, v in super_traverse.items() if k not in ignore},
+            )
+
+    # name -> (traverse names, init args)
+    custom_init = {
+        "BinaryExpression": (
+            {"right", "operator", "type", "negate", "modifiers", "left"},
+            {"right", "operator", "type_", "negate", "modifiers", "left"},
+        ),
+        "BindParameter": (
+            {"literal_execute", "type", "callable", "value", "key"},
+            {"required", "isoutparam", "literal_execute", "type_", "callable_"}
+            | {"unique", "expanding", "quote", "value", "key"},
+        ),
+        "Cast": ({"type", "clause"}, {"type_", "expression"}),
+        "ClauseList": (
+            {"clauses", "operator"},
+            {"group_contents", "group", "operator", "clauses"},
+        ),
+        "ColumnClause": (
+            {"is_literal", "type", "table", "name"},
+            {"type_", "is_literal", "text"},
+        ),
+        "ExpressionClauseList": (
+            {"clauses", "operator"},
+            {"type_", "operator", "clauses"},
+        ),
+        "FromStatement": (
+            {"_raw_columns", "_with_options", "element"}
+            | {"_propagate_attrs", "_with_context_options"},
+            {"element", "entities"},
+        ),
+        "FunctionAsBinary": (
+            {"modifiers", "sql_function", "right_index", "left_index"},
+            {"right_index", "left_index", "fn"},
+        ),
+        "FunctionElement": (
+            {"clause_expr", "_table_value_type", "_with_ordinality"},
+            {"clauses"},
+        ),
+        "Function": (
+            {"_table_value_type", "clause_expr", "_with_ordinality"}
+            | {"packagenames", "type", "name"},
+            {"type_", "packagenames", "name", "clauses"},
+        ),
+        "Label": ({"_element", "type", "name"}, {"type_", "element", "name"}),
+        "LambdaElement": (
+            {"_resolved"},
+            {"role", "opts", "apply_propagate_attrs", "fn"},
+        ),
+        "Load": (
+            {"propagate_to_loaders", "additional_source_entities"}
+            | {"path", "context"},
+            {"entity"},
+        ),
+        "LoaderCriteriaOption": (
+            {"where_criteria", "entity", "propagate_to_loaders"}
+            | {"root_entity", "include_aliases"},
+            {"where_criteria", "include_aliases", "propagate_to_loaders"}
+            | {"entity_or_base", "loader_only", "track_closure_variables"},
+        ),
+        "NullLambdaStatement": ({"_resolved"}, {"statement"}),
+        "ScalarFunctionColumn": (
+            {"type", "fn", "name"},
+            {"type_", "name", "fn"},
+        ),
+        "ScalarValues": (
+            {"_data", "_column_args", "literal_binds"},
+            {"columns", "data", "literal_binds"},
+        ),
+        "Select": (
+            {
+                "_having_criteria",
+                "_distinct",
+                "_group_by_clauses",
+                "_fetch_clause",
+                "_limit_clause",
+                "_label_style",
+                "_order_by_clauses",
+                "_raw_columns",
+                "_correlate_except",
+                "_statement_hints",
+                "_hints",
+                "_independent_ctes",
+                "_distinct_on",
+                "_with_context_options",
+                "_setup_joins",
+                "_suffixes",
+                "_memoized_select_entities",
+                "_for_update_arg",
+                "_prefixes",
+                "_propagate_attrs",
+                "_with_options",
+                "_independent_ctes_opts",
+                "_offset_clause",
+                "_correlate",
+                "_where_criteria",
+                "_annotations",
+                "_fetch_clause_options",
+                "_from_obj",
+            },
+            {"entities"},
+        ),
+        "TableValuedColumn": (
+            {"scalar_alias", "type", "name"},
+            {"type_", "scalar_alias"},
+        ),
+        "TableValueType": ({"_elements"}, {"elements"}),
+        "TextualSelect": (
+            {"column_args", "_annotations", "_independent_ctes"}
+            | {"element", "_independent_ctes_opts"},
+            {"positional", "columns", "text"},
+        ),
+        "Tuple": ({"clauses", "operator"}, {"clauses", "types"}),
+        "TypeClause": ({"type"}, {"type_"}),
+        "TypeCoerce": ({"type", "clause"}, {"type_", "expression"}),
+        "UnaryExpression": (
+            {"modifier", "element", "operator"},
+            {"operator", "wraps_column_expression"}
+            | {"type_", "modifier", "element"},
+        ),
+        "Values": (
+            {"_column_args", "literal_binds", "name", "_data"},
+            {"columns", "name", "literal_binds"},
+        ),
+        "_FrameClause": (
+            {"upper_integer_bind", "upper_type"}
+            | {"lower_type", "lower_integer_bind"},
+            {"range_"},
+        ),
+        "_MemoizedSelectEntities": (
+            {"_with_options", "_raw_columns", "_setup_joins"},
+            {"args"},
+        ),
+        "next_value": ({"sequence"}, {"seq"}),
+    }
+
+    @testing.combinations(
+        *all_hascachekey_subclasses(
+            ignore_subclasses=[Annotated, NoInit, SingletonConstant]
+        )
+    )
+    def test_init_args_in_traversal(self, cls: type):
+        sig = signature(cls.__init__)
+        init_args = set()
+        for p in sig.parameters.values():
+            if (
+                p.name == "self"
+                or p.name.startswith("_")
+                or p.kind in (p.VAR_KEYWORD,)
+            ):
+                continue
+            init_args.add(p.name)
+
+        names = {n for n, _ in cls.__dict__["_traverse_internals"]}
+        if cls.__name__ in self.custom_init:
+            traverse, inits = self.custom_init[cls.__name__]
+            eq_(names, traverse)
+            eq_(init_args, inits)
+        else:
+            is_true(names.issuperset(init_args), f"{names} : {init_args}")
+
+
 class CompareAndCopyTest(CoreFixtures, fixtures.TestBase):
     @classmethod
     def setup_test_class(cls):
@@ -1411,21 +1680,16 @@ class CompareAndCopyTest(CoreFixtures, fixtures.TestBase):
         also included in the fixtures above.
 
         """
-        need = {
+        need = set(
             cls
-            for cls in class_hierarchy(ClauseElement)
-            if issubclass(cls, (ColumnElement, Selectable, LambdaElement))
-            and (
-                "__init__" in cls.__dict__
-                or issubclass(cls, AliasedReturnsRows)
+            for cls in all_hascachekey_subclasses(
+                ignore_subclasses=[Annotated, NoInit, SingletonConstant]
             )
-            and not issubclass(cls, (Annotated, elements._OverrideBinds))
-            and cls.__module__.startswith("sqlalchemy.")
-            and "orm" not in cls.__module__
+            if "orm" not in cls.__module__
             and "compiler" not in cls.__module__
-            and "crud" not in cls.__module__
-            and "dialects" not in cls.__module__  # TODO: dialects?
-        }.difference({ColumnElement, UnaryExpression})
+            and "dialects" not in cls.__module__
+            and issubclass(cls, (ColumnElement, Selectable, LambdaElement))
+        )
 
         for fixture in self.fixtures + self.dont_compare_values_fixtures:
             case_a = fixture()