__visit_name__ = "select_statement_grouping"
_traverse_internals: _TraverseInternalsType = [
("element", InternalTraversal.dp_clauseelement)
- ]
+ ] + SupportsCloneAnnotations._clone_annotations_traverse_internals
_is_select_container = True
def _from_objects(self) -> List[FromClause]:
return self.element._from_objects
+ def add_cte(self, *ctes: CTE, nest_here: bool = False) -> Self:
+ # SelectStatementGrouping not generative: has no attribute '_generate'
+ raise NotImplementedError
+
class GenerativeSelect(SelectBase, Generative):
"""Base class for SELECT statements where additional elements can be
__visit_name__ = "compound_select"
- _traverse_internals: _TraverseInternalsType = [
- ("selects", InternalTraversal.dp_clauseelement_list),
- ("_limit_clause", InternalTraversal.dp_clauseelement),
- ("_offset_clause", InternalTraversal.dp_clauseelement),
- ("_fetch_clause", InternalTraversal.dp_clauseelement),
- ("_fetch_clause_options", InternalTraversal.dp_plain_dict),
- ("_order_by_clauses", InternalTraversal.dp_clauseelement_list),
- ("_group_by_clauses", InternalTraversal.dp_clauseelement_list),
- ("_for_update_arg", InternalTraversal.dp_clauseelement),
- ("keyword", InternalTraversal.dp_string),
- ] + SupportsCloneAnnotations._clone_annotations_traverse_internals
+ _traverse_internals: _TraverseInternalsType = (
+ [
+ ("selects", InternalTraversal.dp_clauseelement_list),
+ ("_limit_clause", InternalTraversal.dp_clauseelement),
+ ("_offset_clause", InternalTraversal.dp_clauseelement),
+ ("_fetch_clause", InternalTraversal.dp_clauseelement),
+ ("_fetch_clause_options", InternalTraversal.dp_plain_dict),
+ ("_order_by_clauses", InternalTraversal.dp_clauseelement_list),
+ ("_group_by_clauses", InternalTraversal.dp_clauseelement_list),
+ ("_for_update_arg", InternalTraversal.dp_clauseelement),
+ ("keyword", InternalTraversal.dp_string),
+ ]
+ + SupportsCloneAnnotations._clone_annotations_traverse_internals
+ + HasCTE._has_ctes_traverse_internals
+ )
selects: List[SelectBase]
import importlib
+from inspect import signature
import itertools
import random
from sqlalchemy.sql import bindparam
from sqlalchemy.sql import ColumnElement
from sqlalchemy.sql import dml
-from sqlalchemy.sql import elements
from sqlalchemy.sql import False_
from sqlalchemy.sql import func
from sqlalchemy.sql import operators
from sqlalchemy.sql import True_
from sqlalchemy.sql import type_coerce
from sqlalchemy.sql import visitors
+from sqlalchemy.sql.annotation import Annotated
from sqlalchemy.sql.base import HasCacheKey
+from sqlalchemy.sql.base import SingletonConstant
from sqlalchemy.sql.elements import _label_reference
from sqlalchemy.sql.elements import _textual_label_reference
-from sqlalchemy.sql.elements import Annotated
from sqlalchemy.sql.elements import BindParameter
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.elements import ClauseList
from sqlalchemy.sql.lambdas import LambdaElement
from sqlalchemy.sql.lambdas import LambdaOptions
from sqlalchemy.sql.selectable import _OffsetLimitParam
-from sqlalchemy.sql.selectable import AliasedReturnsRows
from sqlalchemy.sql.selectable import FromGrouping
from sqlalchemy.sql.selectable import LABEL_STYLE_NONE
from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from sqlalchemy.sql.selectable import NoInit
from sqlalchemy.sql.selectable import Select
from sqlalchemy.sql.selectable import Selectable
from sqlalchemy.sql.selectable import SelectStatementGrouping
.columns(a=Integer())
.add_cte(table_b.select().where(table_b.c.a > 5).cte()),
),
+ lambda: (
+ union(
+ select(table_a).where(table_a.c.a > 1),
+ select(table_a).where(table_a.c.a < 1),
+ ).add_cte(select(table_b).where(table_b.c.a > 1).cte("ttt")),
+ union(
+ select(table_a).where(table_a.c.a > 1),
+ select(table_a).where(table_a.c.a < 1),
+ ).add_cte(select(table_b).where(table_b.c.a < 1).cte("ttt")),
+ union(
+ select(table_a).where(table_a.c.a > 1),
+ select(table_a).where(table_a.c.a < 1),
+ )
+ .add_cte(select(table_b).where(table_b.c.a > 1).cte("ttt"))
+ ._annotate({"foo": "bar"}),
+ ),
+ lambda: (
+ union(
+ select(table_a).where(table_a.c.a > 1),
+ select(table_a).where(table_a.c.a < 1),
+ ).self_group(),
+ union(
+ select(table_a).where(table_a.c.a > 1),
+ select(table_a).where(table_a.c.a < 1),
+ )
+ .self_group()
+ ._annotate({"foo": "bar"}),
+ ),
lambda: (
literal(1).op("+")(literal(1)),
literal(1).op("-")(literal(1)),
is_not(ck3, None)
+def all_hascachekey_subclasses(ignore_subclasses=()):
+ def find_subclasses(cls: type):
+ for s in class_hierarchy(cls):
+ if (
+ # class_hierarchy may return values that
+ # aren't subclasses of cls
+ not issubclass(s, cls)
+ or "_traverse_internals" not in s.__dict__
+ or any(issubclass(s, ignore) for ignore in ignore_subclasses)
+ ):
+ continue
+ yield s
+
+ return dict.fromkeys(find_subclasses(HasCacheKey))
+
+
+class HasCacheKeySubclass(fixtures.TestBase):
+ custom_traverse = {
+ "AnnotatedFunctionAsBinary": {
+ "sql_function",
+ "left_index",
+ "right_index",
+ "modifiers",
+ "_annotations",
+ },
+ "Annotatednext_value": {"sequence", "_annotations"},
+ "FunctionAsBinary": {
+ "sql_function",
+ "left_index",
+ "right_index",
+ "modifiers",
+ },
+ "next_value": {"sequence"},
+ }
+
+ ignore_keys = {
+ "AnnotatedColumn": {"dialect_options"},
+ "SelectStatementGrouping": {
+ "_independent_ctes",
+ "_independent_ctes_opts",
+ },
+ }
+
+ @testing.combinations(*all_hascachekey_subclasses())
+ def test_traverse_internals(self, cls: type):
+ super_traverse = {}
+ # ignore_super = self.ignore_super.get(cls.__name__, set())
+ for s in cls.mro()[1:]:
+ # if s.__name__ in ignore_super:
+ # continue
+ if s.__name__ == "Executable":
+ continue
+ for attr in s.__dict__:
+ if not attr.endswith("_traverse_internals"):
+ continue
+ for k, v in s.__dict__[attr]:
+ if k not in super_traverse:
+ super_traverse[k] = v
+ traverse_dict = dict(cls.__dict__["_traverse_internals"])
+ eq_(len(cls.__dict__["_traverse_internals"]), len(traverse_dict))
+ if cls.__name__ in self.custom_traverse:
+ eq_(traverse_dict.keys(), self.custom_traverse[cls.__name__])
+ else:
+ ignore = self.ignore_keys.get(cls.__name__, set())
+
+ left_keys = traverse_dict.keys() | ignore
+ is_true(
+ left_keys >= super_traverse.keys(),
+ f"{left_keys} >= {super_traverse.keys()} - missing: "
+ f"{super_traverse.keys() - left_keys} - ignored {ignore}",
+ )
+
+ subset = {
+ k: v for k, v in traverse_dict.items() if k in super_traverse
+ }
+ eq_(
+ subset,
+ {k: v for k, v in super_traverse.items() if k not in ignore},
+ )
+
+ # name -> (traverse names, init args)
+ custom_init = {
+ "BinaryExpression": (
+ {"right", "operator", "type", "negate", "modifiers", "left"},
+ {"right", "operator", "type_", "negate", "modifiers", "left"},
+ ),
+ "BindParameter": (
+ {"literal_execute", "type", "callable", "value", "key"},
+ {"required", "isoutparam", "literal_execute", "type_", "callable_"}
+ | {"unique", "expanding", "quote", "value", "key"},
+ ),
+ "Cast": ({"type", "clause"}, {"type_", "expression"}),
+ "ClauseList": (
+ {"clauses", "operator"},
+ {"group_contents", "group", "operator", "clauses"},
+ ),
+ "ColumnClause": (
+ {"is_literal", "type", "table", "name"},
+ {"type_", "is_literal", "text"},
+ ),
+ "ExpressionClauseList": (
+ {"clauses", "operator"},
+ {"type_", "operator", "clauses"},
+ ),
+ "FromStatement": (
+ {"_raw_columns", "_with_options", "element"}
+ | {"_propagate_attrs", "_with_context_options"},
+ {"element", "entities"},
+ ),
+ "FunctionAsBinary": (
+ {"modifiers", "sql_function", "right_index", "left_index"},
+ {"right_index", "left_index", "fn"},
+ ),
+ "FunctionElement": (
+ {"clause_expr", "_table_value_type", "_with_ordinality"},
+ {"clauses"},
+ ),
+ "Function": (
+ {"_table_value_type", "clause_expr", "_with_ordinality"}
+ | {"packagenames", "type", "name"},
+ {"type_", "packagenames", "name", "clauses"},
+ ),
+ "Label": ({"_element", "type", "name"}, {"type_", "element", "name"}),
+ "LambdaElement": (
+ {"_resolved"},
+ {"role", "opts", "apply_propagate_attrs", "fn"},
+ ),
+ "Load": (
+ {"propagate_to_loaders", "additional_source_entities"}
+ | {"path", "context"},
+ {"entity"},
+ ),
+ "LoaderCriteriaOption": (
+ {"where_criteria", "entity", "propagate_to_loaders"}
+ | {"root_entity", "include_aliases"},
+ {"where_criteria", "include_aliases", "propagate_to_loaders"}
+ | {"entity_or_base", "loader_only", "track_closure_variables"},
+ ),
+ "NullLambdaStatement": ({"_resolved"}, {"statement"}),
+ "ScalarFunctionColumn": (
+ {"type", "fn", "name"},
+ {"type_", "name", "fn"},
+ ),
+ "ScalarValues": (
+ {"_data", "_column_args", "literal_binds"},
+ {"columns", "data", "literal_binds"},
+ ),
+ "Select": (
+ {
+ "_having_criteria",
+ "_distinct",
+ "_group_by_clauses",
+ "_fetch_clause",
+ "_limit_clause",
+ "_label_style",
+ "_order_by_clauses",
+ "_raw_columns",
+ "_correlate_except",
+ "_statement_hints",
+ "_hints",
+ "_independent_ctes",
+ "_distinct_on",
+ "_with_context_options",
+ "_setup_joins",
+ "_suffixes",
+ "_memoized_select_entities",
+ "_for_update_arg",
+ "_prefixes",
+ "_propagate_attrs",
+ "_with_options",
+ "_independent_ctes_opts",
+ "_offset_clause",
+ "_correlate",
+ "_where_criteria",
+ "_annotations",
+ "_fetch_clause_options",
+ "_from_obj",
+ },
+ {"entities"},
+ ),
+ "TableValuedColumn": (
+ {"scalar_alias", "type", "name"},
+ {"type_", "scalar_alias"},
+ ),
+ "TableValueType": ({"_elements"}, {"elements"}),
+ "TextualSelect": (
+ {"column_args", "_annotations", "_independent_ctes"}
+ | {"element", "_independent_ctes_opts"},
+ {"positional", "columns", "text"},
+ ),
+ "Tuple": ({"clauses", "operator"}, {"clauses", "types"}),
+ "TypeClause": ({"type"}, {"type_"}),
+ "TypeCoerce": ({"type", "clause"}, {"type_", "expression"}),
+ "UnaryExpression": (
+ {"modifier", "element", "operator"},
+ {"operator", "wraps_column_expression"}
+ | {"type_", "modifier", "element"},
+ ),
+ "Values": (
+ {"_column_args", "literal_binds", "name", "_data"},
+ {"columns", "name", "literal_binds"},
+ ),
+ "_FrameClause": (
+ {"upper_integer_bind", "upper_type"}
+ | {"lower_type", "lower_integer_bind"},
+ {"range_"},
+ ),
+ "_MemoizedSelectEntities": (
+ {"_with_options", "_raw_columns", "_setup_joins"},
+ {"args"},
+ ),
+ "next_value": ({"sequence"}, {"seq"}),
+ }
+
+ @testing.combinations(
+ *all_hascachekey_subclasses(
+ ignore_subclasses=[Annotated, NoInit, SingletonConstant]
+ )
+ )
+ def test_init_args_in_traversal(self, cls: type):
+ sig = signature(cls.__init__)
+ init_args = set()
+ for p in sig.parameters.values():
+ if (
+ p.name == "self"
+ or p.name.startswith("_")
+ or p.kind in (p.VAR_KEYWORD,)
+ ):
+ continue
+ init_args.add(p.name)
+
+ names = {n for n, _ in cls.__dict__["_traverse_internals"]}
+ if cls.__name__ in self.custom_init:
+ traverse, inits = self.custom_init[cls.__name__]
+ eq_(names, traverse)
+ eq_(init_args, inits)
+ else:
+ is_true(names.issuperset(init_args), f"{names} : {init_args}")
+
+
class CompareAndCopyTest(CoreFixtures, fixtures.TestBase):
@classmethod
def setup_test_class(cls):
also included in the fixtures above.
"""
- need = {
+ need = set(
cls
- for cls in class_hierarchy(ClauseElement)
- if issubclass(cls, (ColumnElement, Selectable, LambdaElement))
- and (
- "__init__" in cls.__dict__
- or issubclass(cls, AliasedReturnsRows)
+ for cls in all_hascachekey_subclasses(
+ ignore_subclasses=[Annotated, NoInit, SingletonConstant]
)
- and not issubclass(cls, (Annotated, elements._OverrideBinds))
- and cls.__module__.startswith("sqlalchemy.")
- and "orm" not in cls.__module__
+ if "orm" not in cls.__module__
and "compiler" not in cls.__module__
- and "crud" not in cls.__module__
- and "dialects" not in cls.__module__ # TODO: dialects?
- }.difference({ColumnElement, UnaryExpression})
+ and "dialects" not in cls.__module__
+ and issubclass(cls, (ColumnElement, Selectable, LambdaElement))
+ )
for fixture in self.fixtures + self.dont_compare_values_fixtures:
case_a = fixture()