From: Mike Bayer Date: Fri, 25 Mar 2022 21:08:48 +0000 (-0400) Subject: pep-484: the pep-484ening, SQL part three X-Git-Tag: rel_2_0_0b1~390^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4e754a8914a1c2c16c97bdf363d2e24bfa823730;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git pep-484: the pep-484ening, SQL part three hitting DML which is causing us to open up the ColumnCollection structure a bit, as we do put anonymous column expressions with None here. However, we still want Table /TableClause to have named column collections that don't return None, so parametrize the "key" in this collection also. * rename some "immutable" elements to "readonly". we change the contents of immutablecolumncollection underneath, so it's not "immutable" Change-Id: I2593995a4e5c6eae874bed5bf76117198be8ae97 --- diff --git a/lib/sqlalchemy/cyextension/immutabledict.pyx b/lib/sqlalchemy/cyextension/immutabledict.pyx index 861e7574da..6ab2553112 100644 --- a/lib/sqlalchemy/cyextension/immutabledict.pyx +++ b/lib/sqlalchemy/cyextension/immutabledict.pyx @@ -1,18 +1,24 @@ from cpython.dict cimport PyDict_New, PyDict_Update, PyDict_Size +def _readonly_fn(obj): + raise TypeError( + "%s object is immutable and/or readonly" % obj.__class__.__name__) + + def _immutable_fn(obj): - raise TypeError("%s object is immutable" % obj.__class__.__name__) + raise TypeError( + "%s object is immutable" % obj.__class__.__name__) -class ImmutableContainer: +class ReadOnlyContainer: __slots__ = () - def _immutable(self, *a,**kw): - _immutable_fn(self) + def _readonly(self, *a,**kw): + _readonly_fn(self) - __delitem__ = __setitem__ = __setattr__ = _immutable + __delitem__ = __setitem__ = __setattr__ = _readonly class ImmutableDictBase(dict): diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 65cb57e10a..85ce91deb1 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -54,7 +54,6 @@ from ..sql import expression from ..sql._typing import is_tuple_type from ..sql.compiler import DDLCompiler from ..sql.compiler import SQLCompiler -from ..sql.elements import ColumnClause from ..sql.elements import quoted_name from ..sql.schema import default_is_scalar @@ -88,6 +87,7 @@ if typing.TYPE_CHECKING: from ..sql.dml import DMLState from ..sql.dml import UpdateBase from ..sql.elements import BindParameter + from ..sql.roles import ColumnsClauseRole from ..sql.schema import Column from ..sql.schema import ColumnDefault from ..sql.type_api import _BindProcessorType @@ -1166,7 +1166,7 @@ class DefaultExecutionContext(ExecutionContext): return () @util.memoized_property - def returning_cols(self) -> Optional[Sequence[ColumnClause[Any]]]: + def returning_cols(self) -> Optional[Sequence[ColumnsClauseRole]]: if TYPE_CHECKING: assert isinstance(self.compiled, SQLCompiler) return self.compiled.returning diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index 92b3ce54f7..5ca8b03dd6 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -813,6 +813,7 @@ from typing import Generic from typing import List from typing import Optional from typing import overload +from typing import Sequence from typing import Tuple from typing import Type from typing import TYPE_CHECKING @@ -824,15 +825,20 @@ from ..orm import attributes from ..orm import InspectionAttrExtensionType from ..orm import interfaces from ..orm import ORMDescriptor -from ..sql._typing import is_has_column_element_clause_element +from ..sql._typing import is_has_clause_element from ..sql.elements import ColumnElement from ..sql.elements import SQLCoreOperations from ..util.typing import Literal from ..util.typing import Protocol + if TYPE_CHECKING: from ..orm.util import AliasedInsp + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _DMLColumnArgument + from ..sql._typing import _HasClauseElement from ..sql.operators import OperatorType + from ..sql.roles import ColumnsClauseRole _T = TypeVar("_T", bound=Any) _T_co = TypeVar("_T_co", bound=Any, covariant=True) @@ -878,10 +884,12 @@ class _HybridSetterType(Protocol[_T_con]): ... -class _HybridUpdaterType(Protocol[_T]): +class _HybridUpdaterType(Protocol[_T_con]): def __call__( - self, cls: Type[Any], value: Union[_T, SQLCoreOperations[_T]] - ) -> List[Tuple[SQLCoreOperations[_T], Any]]: + self, + cls: Type[Any], + value: Union[_T_con, _ColumnExpressionArgument[_T_con]], + ) -> List[Tuple[_DMLColumnArgument, Any]]: ... @@ -890,8 +898,10 @@ class _HybridDeleterType(Protocol[_T_co]): ... -class _HybridExprCallableType(Protocol[_T]): - def __call__(self, cls: Any) -> SQLCoreOperations[_T]: +class _HybridExprCallableType(Protocol[_T_co]): + def __call__( + self, cls: Any + ) -> Union[_HasClauseElement, ColumnElement[_T_co]]: ... @@ -1273,17 +1283,21 @@ class Comparator(interfaces.PropComparator[_T]): :class:`~.orm.interfaces.PropComparator` classes for usage with hybrids.""" - def __init__(self, expression: SQLCoreOperations[_T]): + def __init__( + self, expression: Union[_HasClauseElement, ColumnElement[_T]] + ): self.expression = expression - def __clause_element__(self) -> ColumnElement[_T]: + def __clause_element__(self) -> ColumnsClauseRole: expr = self.expression - if is_has_column_element_clause_element(expr): - expr = expr.__clause_element__() + if is_has_clause_element(expr): + ret_expr = expr.__clause_element__() + else: + if TYPE_CHECKING: + assert isinstance(expr, ColumnElement) + ret_expr = expr - elif TYPE_CHECKING: - assert isinstance(expr, ColumnElement) - return expr + return ret_expr @util.non_memoized_property def property(self) -> Any: @@ -1298,7 +1312,7 @@ class ExprComparator(Comparator[_T]): def __init__( self, cls: Type[Any], - expression: SQLCoreOperations[_T], + expression: Union[_HasClauseElement, ColumnElement[_T]], hybrid: hybrid_property[_T], ): self.cls = cls @@ -1314,7 +1328,7 @@ class ExprComparator(Comparator[_T]): def _bulk_update_tuples( self, value: Any - ) -> List[Tuple[SQLCoreOperations[_T], Any]]: + ) -> Sequence[Tuple[_DMLColumnArgument, Any]]: if isinstance(self.expression, attributes.QueryableAttribute): return self.expression._bulk_update_tuples(value) elif self.hybrid.update_expr is not None: diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 2b6ca400e9..3d34927105 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -68,6 +68,7 @@ from ..sql import traversals from ..sql import visitors if typing.TYPE_CHECKING: + from ..sql.dml import _DMLColumnElement from ..sql.elements import ColumnElement from ..sql.elements import SQLCoreOperations @@ -281,7 +282,7 @@ class QueryableAttribute( def _bulk_update_tuples( self, value: Any - ) -> List[Tuple[SQLCoreOperations[_T], Any]]: + ) -> List[Tuple[_DMLColumnElement, Any]]: """Return setter tuples for a bulk UPDATE.""" return self.comparator._bulk_update_tuples(value) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index d797741873..b4228323b4 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -24,6 +24,7 @@ from typing import Any from typing import cast from typing import List from typing import Optional +from typing import Sequence from typing import Tuple from typing import Type from typing import TypeVar @@ -50,7 +51,6 @@ from .. import util from ..sql import operators from ..sql import roles from ..sql import visitors -from ..sql._typing import _ColumnsClauseElement from ..sql.base import ExecutableOption from ..sql.cache_key import HasCacheKey from ..sql.elements import SQLCoreOperations @@ -60,6 +60,8 @@ from ..util.typing import TypedDict if typing.TYPE_CHECKING: from .decl_api import RegistryType + from ..sql._typing import _ColumnsClauseArgument + from ..sql._typing import _DMLColumnArgument _T = TypeVar("_T", bound=Any) @@ -90,8 +92,8 @@ class ORMColumnDescription(TypedDict): name: str type: Union[Type, TypeEngine] aliased: bool - expr: _ColumnsClauseElement - entity: Optional[_ColumnsClauseElement] + expr: _ColumnsClauseArgument + entity: Optional[_ColumnsClauseArgument] class _IntrospectsAnnotations: @@ -468,7 +470,7 @@ class PropComparator(SQLORMOperations[_T]): def _bulk_update_tuples( self, value: Any - ) -> List[Tuple[SQLCoreOperations[_T], Any]]: + ) -> Sequence[Tuple[_DMLColumnArgument, Any]]: """Receive a SQL expression that represents a value in the SET clause of an UPDATE statement. diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 7d1fc76436..e463dcdb57 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1749,6 +1749,7 @@ class Mapper( col.key = col._tq_key_label = key self.columns.add(col, key) + for col in prop.columns + prop._orig_columns: for col in col.proxy_set: self._columntoproperty[col] = prop @@ -2381,7 +2382,7 @@ class Mapper( yield c @HasMemoized.memoized_attribute - def attrs(self) -> util.ImmutableProperties["MapperProperty"]: + def attrs(self) -> util.ReadOnlyProperties["MapperProperty"]: """A namespace of all :class:`.MapperProperty` objects associated this mapper. @@ -2416,7 +2417,7 @@ class Mapper( """ self._check_configure() - return util.ImmutableProperties(self._props) + return util.ReadOnlyProperties(self._props) @HasMemoized.memoized_attribute def all_orm_descriptors(self): @@ -2484,7 +2485,7 @@ class Mapper( :attr:`_orm.Mapper.attrs` """ - return util.ImmutableProperties( + return util.ReadOnlyProperties( dict(self.class_manager._all_sqla_attributes()) ) @@ -2571,7 +2572,7 @@ class Mapper( def _filter_properties(self, type_): self._check_configure() - return util.ImmutableProperties( + return util.ReadOnlyProperties( util.OrderedDict( (k, v) for k, v in self._props.items() if isinstance(v, type_) ) diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 6478aac15b..f2cddad53b 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -2206,7 +2206,7 @@ class BulkORMUpdate(ORMDMLState, UpdateDMLState, BulkUDCompileState): if opt._is_criteria_option: opt.get_global_criteria(extra_criteria_attributes) - if not statement._preserve_parameter_order and statement._values: + if statement._values: self._resolved_values = dict(self._resolved_values) new_stmt = sql.Update.__new__(sql.Update) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index ea5d5406ef..18a14012f8 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -60,7 +60,7 @@ from ..sql import roles from ..sql import Select from ..sql import util as sql_util from ..sql import visitors -from ..sql._typing import _FromClauseElement +from ..sql._typing import _FromClauseArgument from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import _entity_namespace_key from ..sql.base import _generative @@ -2018,7 +2018,7 @@ class Query( @_generative @_assertions(_no_clauseelement_condition) def select_from( - self: SelfQuery, *from_obj: _FromClauseElement + self: SelfQuery, *from_obj: _FromClauseArgument ) -> SelfQuery: r"""Set the FROM clause of this :class:`.Query` explicitly. diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 4140d52c5c..58820fef62 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -56,7 +56,7 @@ from ..sql import coercions from ..sql import dml from ..sql import roles from ..sql import visitors -from ..sql._typing import _ColumnsClauseElement +from ..sql._typing import _ColumnsClauseArgument from ..sql.base import CompileState from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..util.typing import Literal @@ -2040,7 +2040,7 @@ class Session(_SessionClassMethods): ) def query( - self, *entities: "_ColumnsClauseElement", **kwargs: Any + self, *entities: "_ColumnsClauseArgument", **kwargs: Any ) -> "Query": """Return a new :class:`_query.Query` object corresponding to this :class:`_orm.Session`. diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index 58fa3e41a1..c3e4e299ab 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -122,7 +122,7 @@ class InstanceState(interfaces.InspectionAttrInfo): since the last flush. """ - return util.ImmutableProperties( + return util.ReadOnlyProperties( dict((key, AttributeState(self, key)) for key in self.manager) ) diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 2b49d44004..baca8f5476 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -632,7 +632,6 @@ class AliasedInsp( ORMEntityColumnsClauseRole, ORMFromClauseRole, sql_base.HasCacheKey, - roles.HasFromClauseElement, InspectionAttr, MemoizedSlots, ): diff --git a/lib/sqlalchemy/sql/_dml_constructors.py b/lib/sqlalchemy/sql/_dml_constructors.py index a8c24413fc..835819bacb 100644 --- a/lib/sqlalchemy/sql/_dml_constructors.py +++ b/lib/sqlalchemy/sql/_dml_constructors.py @@ -112,82 +112,10 @@ def update(table): object representing the database table to be updated. - :param whereclause: Optional SQL expression describing the ``WHERE`` - condition of the ``UPDATE`` statement; is equivalent to using the - more modern :meth:`~Update.where()` method to specify the ``WHERE`` - clause. - - :param values: - Optional dictionary which specifies the ``SET`` conditions of the - ``UPDATE``. If left as ``None``, the ``SET`` - conditions are determined from those parameters passed to the - statement during the execution and/or compilation of the - statement. When compiled standalone without any parameters, - the ``SET`` clause generates for all columns. - - Modern applications may prefer to use the generative - :meth:`_expression.Update.values` method to set the values of the - UPDATE statement. - - :param inline: - if True, SQL defaults present on :class:`_schema.Column` objects via - the ``default`` keyword will be compiled 'inline' into the statement - and not pre-executed. This means that their values will not - be available in the dictionary returned from - :meth:`_engine.CursorResult.last_updated_params`. - - :param preserve_parameter_order: if True, the update statement is - expected to receive parameters **only** via the - :meth:`_expression.Update.values` method, - and they must be passed as a Python - ``list`` of 2-tuples. The rendered UPDATE statement will emit the SET - clause for each referenced column maintaining this order. - - .. versionadded:: 1.0.10 - - .. seealso:: - - :ref:`updates_order_parameters` - illustrates the - :meth:`_expression.Update.ordered_values` method. - - If both ``values`` and compile-time bind parameters are present, the - compile-time bind parameters override the information specified - within ``values`` on a per-key basis. - - The keys within ``values`` can be either :class:`_schema.Column` - objects or their string identifiers (specifically the "key" of the - :class:`_schema.Column`, normally but not necessarily equivalent to - its "name"). Normally, the - :class:`_schema.Column` objects used here are expected to be - part of the target :class:`_schema.Table` that is the table - to be updated. However when using MySQL, a multiple-table - UPDATE statement can refer to columns from any of - the tables referred to in the WHERE clause. - - The values referred to in ``values`` are typically: - - * a literal data value (i.e. string, number, etc.) - * a SQL expression, such as a related :class:`_schema.Column`, - a scalar-returning :func:`_expression.select` construct, - etc. - - When combining :func:`_expression.select` constructs within the - values clause of an :func:`_expression.update` - construct, the subquery represented - by the :func:`_expression.select` should be *correlated* to the - parent table, that is, providing criterion which links the table inside - the subquery to the outer table being updated:: - - users.update().values( - name=select(addresses.c.email_address).\ - where(addresses.c.user_id==users.c.id).\ - scalar_subquery() - ) .. seealso:: - :ref:`inserts_and_updates` - SQL Expression - Language Tutorial + :ref:`tutorial_core_update_delete` - in the :ref:`unified_tutorial` """ @@ -210,24 +138,12 @@ def delete(table): :meth:`_expression.TableClause.delete` method on :class:`_schema.Table`. - .. seealso:: - - :ref:`inserts_and_updates` - in the - :ref:`1.x tutorial ` - - :ref:`tutorial_core_update_delete` - in the :ref:`unified_tutorial` - - :param table: The table to delete rows from. - :param whereclause: Optional SQL expression describing the ``WHERE`` - condition of the ``DELETE`` statement; is equivalent to using the - more modern :meth:`~Delete.where()` method to specify the ``WHERE`` - clause. - .. seealso:: - :ref:`deletes` - SQL Expression Tutorial + :ref:`tutorial_core_update_delete` - in the :ref:`unified_tutorial` + """ return Delete(table) diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index aabd3871e1..f647ae927a 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -48,7 +48,7 @@ from ..util.typing import Literal if typing.TYPE_CHECKING: from . import sqltypes - from ._typing import _ColumnExpression + from ._typing import _ColumnExpressionArgument from ._typing import _TypeEngineArgument from .elements import BinaryExpression from .functions import FunctionElement @@ -58,7 +58,7 @@ if typing.TYPE_CHECKING: _T = TypeVar("_T") -def all_(expr: _ColumnExpression[_T]) -> CollectionAggregate[bool]: +def all_(expr: _ColumnExpressionArgument[_T]) -> CollectionAggregate[bool]: """Produce an ALL expression. For dialects such as that of PostgreSQL, this operator applies @@ -112,7 +112,7 @@ def all_(expr: _ColumnExpression[_T]) -> CollectionAggregate[bool]: return CollectionAggregate._create_all(expr) -def and_(*clauses: _ColumnExpression[bool]) -> BooleanClauseList: +def and_(*clauses: _ColumnExpressionArgument[bool]) -> ColumnElement[bool]: r"""Produce a conjunction of expressions joined by ``AND``. E.g.:: @@ -173,7 +173,7 @@ def and_(*clauses: _ColumnExpression[bool]) -> BooleanClauseList: return BooleanClauseList.and_(*clauses) -def any_(expr: _ColumnExpression[_T]) -> CollectionAggregate[bool]: +def any_(expr: _ColumnExpressionArgument[_T]) -> CollectionAggregate[bool]: """Produce an ANY expression. For dialects such as that of PostgreSQL, this operator applies @@ -227,7 +227,7 @@ def any_(expr: _ColumnExpression[_T]) -> CollectionAggregate[bool]: return CollectionAggregate._create_any(expr) -def asc(column: _ColumnExpression[_T]) -> UnaryExpression[_T]: +def asc(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: """Produce an ascending ``ORDER BY`` clause element. e.g.:: @@ -266,7 +266,7 @@ def asc(column: _ColumnExpression[_T]) -> UnaryExpression[_T]: def collate( - expression: _ColumnExpression[str], collation: str + expression: _ColumnExpressionArgument[str], collation: str ) -> BinaryExpression[str]: """Return the clause ``expression COLLATE collation``. @@ -289,7 +289,7 @@ def collate( def between( - expr: _ColumnExpression[_T], + expr: _ColumnExpressionArgument[_T], lower_bound: Any, upper_bound: Any, symmetric: bool = False, @@ -364,17 +364,19 @@ def outparam( return BindParameter(key, None, type_=type_, unique=False, isoutparam=True) +# mypy insists that BinaryExpression and _HasClauseElement protocol overlap. +# they do not. at all. bug in mypy? @overload -def not_(clause: BinaryExpression[_T]) -> BinaryExpression[_T]: +def not_(clause: BinaryExpression[_T]) -> BinaryExpression[_T]: # type: ignore ... @overload -def not_(clause: _ColumnExpression[_T]) -> ColumnElement[_T]: +def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]: ... -def not_(clause: _ColumnExpression[_T]) -> ColumnElement[_T]: +def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]: """Return a negation of the given clause, i.e. ``NOT(clause)``. The ``~`` operator is also overloaded on all @@ -646,7 +648,7 @@ def bindparam( def case( *whens: Union[ - typing_Tuple[_ColumnExpression[bool], Any], Mapping[Any, Any] + typing_Tuple[_ColumnExpressionArgument[bool], Any], Mapping[Any, Any] ], value: Optional[Any] = None, else_: Optional[Any] = None, @@ -775,7 +777,7 @@ def case( def cast( - expression: _ColumnExpression[Any], + expression: _ColumnExpressionArgument[Any], type_: _TypeEngineArgument[_T], ) -> Cast[_T]: r"""Produce a ``CAST`` expression. @@ -932,7 +934,7 @@ def column( return ColumnClause(text, type_, is_literal, _selectable) -def desc(column: _ColumnExpression[_T]) -> UnaryExpression[_T]: +def desc(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: """Produce a descending ``ORDER BY`` clause element. e.g.:: @@ -971,7 +973,7 @@ def desc(column: _ColumnExpression[_T]) -> UnaryExpression[_T]: return UnaryExpression._create_desc(column) -def distinct(expr: _ColumnExpression[_T]) -> UnaryExpression[_T]: +def distinct(expr: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: """Produce an column-expression-level unary ``DISTINCT`` clause. This applies the ``DISTINCT`` keyword to an individual column @@ -1010,7 +1012,7 @@ def distinct(expr: _ColumnExpression[_T]) -> UnaryExpression[_T]: return UnaryExpression._create_distinct(expr) -def extract(field: str, expr: _ColumnExpression[Any]) -> Extract: +def extract(field: str, expr: _ColumnExpressionArgument[Any]) -> Extract: """Return a :class:`.Extract` construct. This is typically available as :func:`.extract` @@ -1090,7 +1092,7 @@ def false() -> False_: def funcfilter( - func: FunctionElement[_T], *criterion: _ColumnExpression[bool] + func: FunctionElement[_T], *criterion: _ColumnExpressionArgument[bool] ) -> FunctionFilter[_T]: """Produce a :class:`.FunctionFilter` object against a function. @@ -1122,7 +1124,7 @@ def funcfilter( def label( name: str, - element: _ColumnExpression[_T], + element: _ColumnExpressionArgument[_T], type_: Optional[_TypeEngineArgument[_T]] = None, ) -> "Label[_T]": """Return a :class:`Label` object for the @@ -1149,7 +1151,7 @@ def null() -> Null: return Null._instance() -def nulls_first(column: _ColumnExpression[_T]) -> UnaryExpression[_T]: +def nulls_first(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: """Produce the ``NULLS FIRST`` modifier for an ``ORDER BY`` expression. :func:`.nulls_first` is intended to modify the expression produced @@ -1193,7 +1195,7 @@ def nulls_first(column: _ColumnExpression[_T]) -> UnaryExpression[_T]: return UnaryExpression._create_nulls_first(column) -def nulls_last(column: _ColumnExpression[_T]) -> UnaryExpression[_T]: +def nulls_last(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: """Produce the ``NULLS LAST`` modifier for an ``ORDER BY`` expression. :func:`.nulls_last` is intended to modify the expression produced @@ -1237,7 +1239,7 @@ def nulls_last(column: _ColumnExpression[_T]) -> UnaryExpression[_T]: return UnaryExpression._create_nulls_last(column) -def or_(*clauses: _ColumnExpression[bool]) -> BooleanClauseList: +def or_(*clauses: _ColumnExpressionArgument[bool]) -> ColumnElement[bool]: """Produce a conjunction of expressions joined by ``OR``. E.g.:: @@ -1291,10 +1293,16 @@ def or_(*clauses: _ColumnExpression[bool]) -> BooleanClauseList: def over( element: FunctionElement[_T], partition_by: Optional[ - Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]] + Union[ + Iterable[_ColumnExpressionArgument[Any]], + _ColumnExpressionArgument[Any], + ] ] = None, order_by: Optional[ - Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]] + Union[ + Iterable[_ColumnExpressionArgument[Any]], + _ColumnExpressionArgument[Any], + ] ] = None, range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, @@ -1502,7 +1510,7 @@ def true() -> True_: def tuple_( - *clauses: _ColumnExpression[Any], + *clauses: _ColumnExpressionArgument[Any], types: Optional[Sequence[_TypeEngineArgument[Any]]] = None, ) -> Tuple: """Return a :class:`.Tuple`. @@ -1531,7 +1539,7 @@ def tuple_( def type_coerce( - expression: _ColumnExpression[Any], + expression: _ColumnExpressionArgument[Any], type_: _TypeEngineArgument[_T], ) -> TypeCoerce[_T]: r"""Associate a SQL expression with a particular type, without rendering @@ -1612,7 +1620,7 @@ def type_coerce( def within_group( - element: FunctionElement[_T], *order_by: _ColumnExpression[Any] + element: FunctionElement[_T], *order_by: _ColumnExpressionArgument[Any] ) -> WithinGroup[_T]: r"""Produce a :class:`.WithinGroup` object against a function. diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py index e9acc7e6dc..a17ee4ce86 100644 --- a/lib/sqlalchemy/sql/_selectable_constructors.py +++ b/lib/sqlalchemy/sql/_selectable_constructors.py @@ -12,7 +12,7 @@ from typing import Optional from . import coercions from . import roles -from ._typing import _ColumnsClauseElement +from ._typing import _ColumnsClauseArgument from .elements import ColumnClause from .selectable import Alias from .selectable import CompoundSelect @@ -281,7 +281,7 @@ def outerjoin(left, right, onclause=None, full=False): return Join(left, right, onclause, isouter=True, full=full) -def select(*entities: _ColumnsClauseElement) -> Select: +def select(*entities: _ColumnsClauseArgument) -> Select: r"""Construct a new :class:`_expression.Select`. diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index b50a7bf6a1..a5da878027 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import Any +from typing import Iterable from typing import Type from typing import TYPE_CHECKING from typing import TypeVar @@ -10,9 +11,17 @@ from . import roles from .. import util from ..inspection import Inspectable from ..util.typing import Literal +from ..util.typing import Protocol if TYPE_CHECKING: + from .elements import ClauseElement + from .elements import ColumnClause + from .elements import ColumnElement from .elements import quoted_name + from .elements import SQLCoreOperations + from .elements import TextClause + from .roles import ColumnsClauseRole + from .roles import FromClauseRole from .schema import DefaultGenerator from .schema import Sequence from .selectable import FromClause @@ -24,31 +33,61 @@ if TYPE_CHECKING: _T = TypeVar("_T", bound=Any) -_ColumnsClauseElement = Union[ + +class _HasClauseElement(Protocol): + """indicates a class that has a __clause_element__() method""" + + def __clause_element__(self) -> ColumnsClauseRole: + ... + + +# convention: +# XYZArgument - something that the end user is passing to a public API method +# XYZElement - the internal representation that we use for the thing. +# the coercions system is responsible for converting from XYZArgument to +# XYZElement. + +_ColumnsClauseArgument = Union[ Literal["*", 1], roles.ColumnsClauseRole, Type[Any], - Inspectable[roles.HasColumnElementClauseElement], + Inspectable[_HasClauseElement], + _HasClauseElement, ] -_FromClauseElement = Union[ - roles.FromClauseRole, Type[Any], Inspectable[roles.HasFromClauseElement] + +_SelectIterable = Iterable[Union["ColumnElement[Any]", "TextClause"]] + +_FromClauseArgument = Union[ + roles.FromClauseRole, + Type[Any], + Inspectable[_HasClauseElement], + _HasClauseElement, ] -_ColumnExpression = Union[ - roles.ExpressionElementRole[_T], - Inspectable[roles.HasColumnElementClauseElement], +_ColumnExpressionArgument = Union[ + "ColumnElement[_T]", _HasClauseElement, roles.ExpressionElementRole[_T] ] +_DMLColumnArgument = Union[str, "ColumnClause[Any]", _HasClauseElement] + _PropagateAttrsType = util.immutabledict[str, Any] _TypeEngineArgument = Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] -def is_named_from_clause(t: FromClause) -> TypeGuard[NamedFromClause]: +def is_named_from_clause(t: FromClauseRole) -> TypeGuard[NamedFromClause]: return t.named_with_column -def has_schema_attr(t: FromClause) -> TypeGuard[TableClause]: +def is_column_element(c: ClauseElement) -> TypeGuard[ColumnElement[Any]]: + return c._is_column_element + + +def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]: + return c._is_text_clause + + +def has_schema_attr(t: FromClauseRole) -> TypeGuard[TableClause]: return hasattr(t, "schema") @@ -60,11 +99,5 @@ def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]: return t._is_tuple_type -def is_has_clause_element(s: object) -> TypeGuard[roles.HasClauseElement]: - return hasattr(s, "__clause_element__") - - -def is_has_column_element_clause_element( - s: object, -) -> TypeGuard[roles.HasColumnElementClauseElement]: +def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement]: return hasattr(s, "__clause_element__") diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index f37ae9a60d..f1919d1d39 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -25,6 +25,7 @@ from typing import overload from typing import Sequence from typing import Tuple from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from . import operators @@ -35,9 +36,9 @@ from .visitors import InternalTraversal from .. import util from ..util.typing import Literal -if typing.TYPE_CHECKING: +if TYPE_CHECKING: + from .base import _EntityNamespace from .visitors import _TraverseInternalsType - from ..util.typing import Self _AnnotationDict = Mapping[str, Any] @@ -192,7 +193,12 @@ class SupportsWrappingAnnotations(SupportsAnnotations): __slots__ = () _constructor: Callable[..., SupportsWrappingAnnotations] - entity_namespace: Mapping[str, Any] + + if TYPE_CHECKING: + + @util.ro_non_memoized_property + def entity_namespace(self) -> _EntityNamespace: + ... def _annotate(self, values: _AnnotationDict) -> Annotated: """return a copy of this ClauseElement with annotations @@ -380,8 +386,8 @@ class Annotated(SupportsAnnotations): else: return hash(other) == hash(self) - @property - def entity_namespace(self) -> Mapping[str, Any]: + @util.ro_non_memoized_property + def entity_namespace(self) -> _EntityNamespace: if "entity_namespace" in self._annotations: return cast( SupportsWrappingAnnotations, diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 8f51359155..19e4c13d22 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -62,10 +62,14 @@ if TYPE_CHECKING: from . import coercions from . import elements from . import type_api + from ._typing import _ColumnsClauseArgument + from ._typing import _SelectIterable from .elements import BindParameter from .elements import ColumnClause from .elements import ColumnElement + from .elements import NamedColumn from .elements import SQLCoreOperations + from .selectable import FromClause from ..engine import Connection from ..engine import Result from ..engine.base import _CompiledCacheType @@ -91,6 +95,8 @@ class _NoArg(Enum): NO_ARG = _NoArg.NO_ARG +_T = TypeVar("_T", bound=Any) + _Fn = TypeVar("_Fn", bound=Callable[..., Any]) _AmbiguousTableNameMap = MutableMapping[str, str] @@ -102,7 +108,9 @@ class _EntityNamespace(Protocol): class _HasEntityNamespace(Protocol): - entity_namespace: _EntityNamespace + @util.ro_non_memoized_property + def entity_namespace(self) -> _EntityNamespace: + ... def _is_has_entity_namespace(element: Any) -> TypeGuard[_HasEntityNamespace]: @@ -136,8 +144,8 @@ class SingletonConstant(Immutable): _singleton: SingletonConstant - def __new__(cls, *arg, **kw): - return cls._singleton + def __new__(cls: _T, *arg: Any, **kw: Any) -> _T: + return cast(_T, cls._singleton) @util.non_memoized_property def proxy_set(self) -> FrozenSet[ColumnElement[Any]]: @@ -159,13 +167,15 @@ class SingletonConstant(Immutable): cls._singleton = obj -def _from_objects(*elements): +def _from_objects(*elements: ColumnElement[Any]) -> Iterator[FromClause]: return itertools.chain.from_iterable( [element._from_objects for element in elements] ) -def _select_iterables(elements): +def _select_iterables( + elements: Iterable[roles.ColumnsClauseRole], +) -> _SelectIterable: """expand tables into individual columns in the given list of column expressions. @@ -207,7 +217,7 @@ def _generative(fn: _Fn) -> _Fn: return decorated # type: ignore -def _exclusive_against(*names, **kw): +def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]: msgs = kw.pop("msgs", {}) defaults = kw.pop("defaults", {}) @@ -502,7 +512,7 @@ class DialectKWArgs: util.portable_instancemethod(self._kw_reg_for_dialect_cls) ) - def _validate_dialect_kwargs(self, kwargs): + def _validate_dialect_kwargs(self, kwargs: Any) -> None: # validate remaining kwargs that they all specify DB prefixes if not kwargs: @@ -605,7 +615,9 @@ class CompileState: self.statement = statement @classmethod - def get_plugin_class(cls, statement): + def get_plugin_class( + cls, statement: Executable + ) -> Optional[Type[CompileState]]: plugin_name = statement._propagate_attrs.get( "compile_state_plugin", None ) @@ -634,7 +646,9 @@ class CompileState: return None @classmethod - def plugin_for(cls, plugin_name, visit_name): + def plugin_for( + cls, plugin_name: str, visit_name: str + ) -> Callable[[_Fn], _Fn]: def decorate(cls_to_decorate): cls.plugins[(plugin_name, visit_name)] = cls_to_decorate return cls_to_decorate @@ -957,7 +971,7 @@ class Executable(roles.StatementRole, Generative): ) -> Result: ... - @property + @util.non_memoized_property def _all_selected_columns(self): raise NotImplementedError() @@ -1202,10 +1216,11 @@ class SchemaVisitor(ClauseVisitor): __traverse_options__ = {"schema_visitor": True} -_COL = TypeVar("_COL", bound="ColumnClause[Any]") +_COLKEY = TypeVar("_COLKEY", Union[None, str], str) +_COL = TypeVar("_COL", bound="ColumnElement[Any]") -class ColumnCollection(Generic[_COL]): +class ColumnCollection(Generic[_COLKEY, _COL]): """Collection of :class:`_expression.ColumnElement` instances, typically for :class:`_sql.FromClause` objects. @@ -1316,25 +1331,27 @@ class ColumnCollection(Generic[_COL]): __slots__ = "_collection", "_index", "_colset" - _collection: List[Tuple[str, _COL]] - _index: Dict[Union[str, int], _COL] + _collection: List[Tuple[_COLKEY, _COL]] + _index: Dict[Union[None, str, int], _COL] _colset: Set[_COL] - def __init__(self, columns: Optional[Iterable[Tuple[str, _COL]]] = None): + def __init__( + self, columns: Optional[Iterable[Tuple[_COLKEY, _COL]]] = None + ): object.__setattr__(self, "_colset", set()) object.__setattr__(self, "_index", {}) object.__setattr__(self, "_collection", []) if columns: self._initial_populate(columns) - def _initial_populate(self, iter_: Iterable[Tuple[str, _COL]]) -> None: + def _initial_populate(self, iter_: Iterable[Tuple[_COLKEY, _COL]]) -> None: self._populate_separate_keys(iter_) @property def _all_columns(self) -> List[_COL]: return [col for (k, col) in self._collection] - def keys(self) -> List[str]: + def keys(self) -> List[_COLKEY]: """Return a sequence of string key names for all columns in this collection.""" return [k for (k, col) in self._collection] @@ -1345,7 +1362,7 @@ class ColumnCollection(Generic[_COL]): collection.""" return [col for (k, col) in self._collection] - def items(self) -> List[Tuple[str, _COL]]: + def items(self) -> List[Tuple[_COLKEY, _COL]]: """Return a sequence of (key, column) tuples for all columns in this collection each consisting of a string key name and a :class:`_sql.ColumnClause` or @@ -1389,7 +1406,7 @@ class ColumnCollection(Generic[_COL]): else: return True - def compare(self, other: ColumnCollection[Any]) -> bool: + def compare(self, other: ColumnCollection[Any, Any]) -> bool: """Compare this :class:`_expression.ColumnCollection` to another based on the names of the keys""" @@ -1444,7 +1461,7 @@ class ColumnCollection(Generic[_COL]): __hash__ = None # type: ignore def _populate_separate_keys( - self, iter_: Iterable[Tuple[str, _COL]] + self, iter_: Iterable[Tuple[_COLKEY, _COL]] ) -> None: """populate from an iterator of (key, column)""" cols = list(iter_) @@ -1455,7 +1472,7 @@ class ColumnCollection(Generic[_COL]): ) self._index.update({k: col for k, col in reversed(self._collection)}) - def add(self, column: _COL, key: Optional[str] = None) -> None: + def add(self, column: _COL, key: Optional[_COLKEY] = None) -> None: """Add a column to this :class:`_sql.ColumnCollection`. .. note:: @@ -1467,15 +1484,19 @@ class ColumnCollection(Generic[_COL]): object, use the :meth:`_schema.Table.append_column` method. """ + colkey: _COLKEY + if key is None: - key = column.key + colkey = column.key # type: ignore + else: + colkey = key l = len(self._collection) - self._collection.append((key, column)) + self._collection.append((colkey, column)) self._colset.add(column) self._index[l] = column - if key not in self._index: - self._index[key] = column + if colkey not in self._index: + self._index[colkey] = column def __getstate__(self) -> Dict[str, Any]: return {"_collection": self._collection, "_index": self._index} @@ -1499,11 +1520,11 @@ class ColumnCollection(Generic[_COL]): else: return True - def as_immutable(self) -> ImmutableColumnCollection[_COL]: - """Return an "immutable" form of this + def as_readonly(self) -> ReadOnlyColumnCollection[_COLKEY, _COL]: + """Return a "read only" form of this :class:`_sql.ColumnCollection`.""" - return ImmutableColumnCollection(self) + return ReadOnlyColumnCollection(self) def corresponding_column( self, column: _COL, require_embedded: bool = False @@ -1605,7 +1626,10 @@ class ColumnCollection(Generic[_COL]): return col -class DedupeColumnCollection(ColumnCollection[_COL]): +_NAMEDCOL = TypeVar("_NAMEDCOL", bound="NamedColumn[Any]") + + +class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): """A :class:`_expression.ColumnCollection` that maintains deduplicating behavior. @@ -1618,7 +1642,7 @@ class DedupeColumnCollection(ColumnCollection[_COL]): """ - def add(self, column: _COL, key: Optional[str] = None) -> None: + def add(self, column: _NAMEDCOL, key: Optional[str] = None) -> None: if key is not None and column.key != key: raise exc.ArgumentError( @@ -1653,7 +1677,7 @@ class DedupeColumnCollection(ColumnCollection[_COL]): self._index[key] = column def _populate_separate_keys( - self, iter_: Iterable[Tuple[str, _COL]] + self, iter_: Iterable[Tuple[str, _NAMEDCOL]] ) -> None: """populate from an iterator of (key, column)""" cols = list(iter_) @@ -1679,10 +1703,10 @@ class DedupeColumnCollection(ColumnCollection[_COL]): for col in replace_col: self.replace(col) - def extend(self, iter_: Iterable[_COL]) -> None: - self._populate_separate_keys((col.key, col) for col in iter_) + def extend(self, iter_: Iterable[_NAMEDCOL]) -> None: + self._populate_separate_keys((col.key, col) for col in iter_) # type: ignore # noqa: E501 - def remove(self, column: _COL) -> None: + def remove(self, column: _NAMEDCOL) -> None: if column not in self._colset: raise ValueError( "Can't remove column %r; column is not in this collection" @@ -1699,7 +1723,7 @@ class DedupeColumnCollection(ColumnCollection[_COL]): # delete higher index del self._index[len(self._collection)] - def replace(self, column: _COL) -> None: + def replace(self, column: _NAMEDCOL) -> None: """add the given column to this collection, removing unaliased versions of this column as well as existing columns with the same key. @@ -1726,7 +1750,7 @@ class DedupeColumnCollection(ColumnCollection[_COL]): if column.key in self._index: remove_col.add(self._index[column.key]) - new_cols = [] + new_cols: List[Tuple[str, _NAMEDCOL]] = [] replaced = False for k, col in self._collection: if col in remove_col: @@ -1752,8 +1776,8 @@ class DedupeColumnCollection(ColumnCollection[_COL]): self._index.update(self._collection) -class ImmutableColumnCollection( - util.ImmutableContainer, ColumnCollection[_COL] +class ReadOnlyColumnCollection( + util.ReadOnlyContainer, ColumnCollection[_COLKEY, _COL] ): __slots__ = ("_parent",) @@ -1771,13 +1795,13 @@ class ImmutableColumnCollection( self.__init__(parent) # type: ignore def add(self, column: Any, key: Any = ...) -> Any: - self._immutable() + self._readonly() - def extend(self, elements: Any) -> None: - self._immutable() + def extend(self, elements: Any) -> NoReturn: + self._readonly() - def remove(self, item: Any) -> None: - self._immutable() + def remove(self, item: Any) -> NoReturn: + self._readonly() class ColumnSet(util.OrderedSet["ColumnClause[Any]"]): diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 35cd33a186..ccc8fba8d2 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -46,6 +46,7 @@ if typing.TYPE_CHECKING: from . import schema from . import selectable from . import traversals + from ._typing import _ColumnsClauseArgument from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement @@ -164,6 +165,32 @@ def expect( ... +@overload +def expect( + role: Type[roles.DMLTableRole], + element: Any, + *, + apply_propagate_attrs: Optional[ClauseElement] = None, + argname: Optional[str] = None, + post_inspect: bool = False, + **kw: Any, +) -> roles.DMLTableRole: + ... + + +@overload +def expect( + role: Type[roles.ColumnsClauseRole], + element: Any, + *, + apply_propagate_attrs: Optional[ClauseElement] = None, + argname: Optional[str] = None, + post_inspect: bool = False, + **kw: Any, +) -> roles.ColumnsClauseRole: + ... + + @overload def expect( role: Type[_SR], diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 7fd37e9b1f..a2f731ac91 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -52,7 +52,6 @@ from typing import Type from typing import TYPE_CHECKING from typing import Union -from sqlalchemy.sql.ddl import DDLElement from . import base from . import coercions from . import crud @@ -79,10 +78,12 @@ from ..util.typing import Protocol from ..util.typing import TypedDict if typing.TYPE_CHECKING: + from . import roles from .annotation import _AnnotationDict from .base import _AmbiguousTableNameMap from .base import CompileState from .cache_key import CacheKey + from .ddl import DDLElement from .dml import Insert from .dml import UpdateBase from .dml import ValuesBase @@ -724,7 +725,7 @@ class SQLCompiler(Compiled): """list of columns for which onupdate default values should be evaluated before an UPDATE takes place""" - returning: Optional[List[ColumnClause[Any]]] + returning: Optional[Sequence[roles.ColumnsClauseRole]] """list of columns that will be delivered to cursor.description or dialect equivalent via the RETURNING clause on an INSERT, UPDATE, or DELETE @@ -4099,7 +4100,9 @@ class SQLCompiler(Compiled): return " FOR UPDATE" def returning_clause( - self, stmt: UpdateBase, returning_cols: List[ColumnClause[Any]] + self, + stmt: UpdateBase, + returning_cols: Sequence[roles.ColumnsClauseRole], ) -> str: raise exc.CompileError( "RETURNING is not supported by this " diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 533a2f6cd9..91a3f70c91 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -39,6 +39,7 @@ from ..util.typing import Literal if TYPE_CHECKING: from .compiler import _BindNameForColProtocol from .compiler import SQLCompiler + from .dml import _DMLColumnElement from .dml import DMLState from .dml import Insert from .dml import Update @@ -129,8 +130,10 @@ def _get_crud_params( [], ) - stmt_parameter_tuples: Optional[List[Any]] - spd: Optional[MutableMapping[str, Any]] + stmt_parameter_tuples: Optional[ + List[Tuple[Union[str, ColumnClause[Any]], Any]] + ] + spd: Optional[MutableMapping[_DMLColumnElement, Any]] if compile_state._has_multi_parameters: mp = compile_state._multi_parameters @@ -355,8 +358,8 @@ def _handle_values_anonymous_param(compiler, col, value, name, **kw): def _key_getters_for_crud_column( compiler: SQLCompiler, stmt: ValuesBase, compile_state: DMLState ) -> Tuple[ - Callable[[Union[str, Column[Any]]], Union[str, Tuple[str, str]]], - Callable[[Column[Any]], Union[str, Tuple[str, str]]], + Callable[[Union[str, ColumnClause[Any]]], Union[str, Tuple[str, str]]], + Callable[[ColumnClause[Any]], Union[str, Tuple[str, str]]], _BindNameForColProtocol, ]: if dml.isupdate(compile_state) and compile_state._extra_froms: diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index f5fb6b2f34..0c9056aeeb 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -15,18 +15,29 @@ import collections.abc as collections_abc import operator import typing from typing import Any +from typing import cast +from typing import Dict +from typing import Iterable from typing import List from typing import MutableMapping +from typing import NoReturn from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type from typing import TYPE_CHECKING +from typing import Union from . import coercions from . import roles from . import util as sql_util +from ._typing import is_column_element +from ._typing import is_named_from_clause from .base import _entity_namespace_key from .base import _exclusive_against from .base import _from_objects from .base import _generative +from .base import _select_iterables from .base import ColumnCollection from .base import CompileState from .base import DialectKWArgs @@ -34,7 +45,9 @@ from .base import Executable from .base import HasCompileState from .elements import BooleanClauseList from .elements import ClauseElement +from .elements import ColumnElement from .elements import Null +from .selectable import FromClause from .selectable import HasCTE from .selectable import HasPrefixes from .selectable import ReturnsRows @@ -45,16 +58,25 @@ from .. import exc from .. import util from ..util.typing import TypeGuard - if TYPE_CHECKING: - def isupdate(dml) -> TypeGuard[UpdateDMLState]: + from ._typing import _ColumnsClauseArgument + from ._typing import _DMLColumnArgument + from ._typing import _FromClauseArgument + from ._typing import _HasClauseElement + from ._typing import _SelectIterable + from .base import ReadOnlyColumnCollection + from .compiler import SQLCompiler + from .elements import ColumnClause + from .selectable import Select + + def isupdate(dml: DMLState) -> TypeGuard[UpdateDMLState]: ... - def isdelete(dml) -> TypeGuard[DeleteDMLState]: + def isdelete(dml: DMLState) -> TypeGuard[DeleteDMLState]: ... - def isinsert(dml) -> TypeGuard[InsertDMLState]: + def isinsert(dml: DMLState) -> TypeGuard[InsertDMLState]: ... else: @@ -63,27 +85,43 @@ else: isinsert = operator.attrgetter("isinsert") +_DMLColumnElement = Union[str, "ColumnClause[Any]"] + + class DMLState(CompileState): _no_parameters = True - _dict_parameters: Optional[MutableMapping[str, Any]] = None - _multi_parameters: Optional[List[MutableMapping[str, Any]]] = None - _ordered_values = None - _parameter_ordering = None + _dict_parameters: Optional[MutableMapping[_DMLColumnElement, Any]] = None + _multi_parameters: Optional[ + List[MutableMapping[_DMLColumnElement, Any]] + ] = None + _ordered_values: Optional[List[Tuple[_DMLColumnElement, Any]]] = None + _parameter_ordering: Optional[List[_DMLColumnElement]] = None _has_multi_parameters = False isupdate = False isdelete = False isinsert = False - def __init__(self, statement, compiler, **kw): + statement: UpdateBase + + def __init__( + self, statement: UpdateBase, compiler: SQLCompiler, **kw: Any + ): raise NotImplementedError() @classmethod - def get_entity_description(cls, statement): - return {"name": statement.table.name, "table": statement.table} + def get_entity_description(cls, statement: UpdateBase) -> Dict[str, Any]: + return { + "name": statement.table.name + if is_named_from_clause(statement.table) + else None, + "table": statement.table, + } @classmethod - def get_returning_column_descriptions(cls, statement): + def get_returning_column_descriptions( + cls, statement: UpdateBase + ) -> List[Dict[str, Any]]: return [ { "name": c.key, @@ -94,11 +132,21 @@ class DMLState(CompileState): ] @property - def dml_table(self): + def dml_table(self) -> roles.DMLTableRole: return self.statement.table + if TYPE_CHECKING: + + @classmethod + def get_plugin_class(cls, statement: Executable) -> Type[DMLState]: + ... + @classmethod - def _get_crud_kv_pairs(cls, statement, kv_iterator): + def _get_crud_kv_pairs( + cls, + statement: UpdateBase, + kv_iterator: Iterable[Tuple[_DMLColumnArgument, Any]], + ) -> List[Tuple[_DMLColumnElement, Any]]: return [ ( coercions.expect(roles.DMLColumnRole, k), @@ -112,8 +160,8 @@ class DMLState(CompileState): for k, v in kv_iterator ] - def _make_extra_froms(self, statement): - froms = [] + def _make_extra_froms(self, statement: DMLWhereBase) -> List[FromClause]: + froms: List[FromClause] = [] all_tables = list(sql_util.tables_from_leftmost(statement.table)) seen = {all_tables[0]} @@ -127,7 +175,7 @@ class DMLState(CompileState): froms.extend(all_tables[1:]) return froms - def _process_multi_values(self, statement): + def _process_multi_values(self, statement: ValuesBase) -> None: if not statement._supports_multi_parameters: raise exc.InvalidRequestError( "%s construct does not support " @@ -135,7 +183,7 @@ class DMLState(CompileState): ) for parameters in statement._multi_values: - multi_parameters = [ + multi_parameters: List[MutableMapping[_DMLColumnElement, Any]] = [ { c.key: value for c, value in zip(statement.table.c, parameter_set) @@ -153,9 +201,10 @@ class DMLState(CompileState): elif not self._has_multi_parameters: self._cant_mix_formats_error() else: + assert self._multi_parameters self._multi_parameters.extend(multi_parameters) - def _process_values(self, statement): + def _process_values(self, statement: ValuesBase) -> None: if self._no_parameters: self._has_multi_parameters = False self._dict_parameters = statement._values @@ -163,11 +212,12 @@ class DMLState(CompileState): elif self._has_multi_parameters: self._cant_mix_formats_error() - def _process_ordered_values(self, statement): + def _process_ordered_values(self, statement: ValuesBase) -> None: parameters = statement._ordered_values if self._no_parameters: self._no_parameters = False + assert parameters is not None self._dict_parameters = dict(parameters) self._ordered_values = parameters self._parameter_ordering = [key for key, value in parameters] @@ -179,7 +229,8 @@ class DMLState(CompileState): "with any other values() call" ) - def _process_select_values(self, statement): + def _process_select_values(self, statement: ValuesBase) -> None: + assert statement._select_names is not None parameters = { coercions.expect(roles.DMLColumnRole, name, as_key=True): Null() for name in statement._select_names @@ -193,7 +244,7 @@ class DMLState(CompileState): # does not allow this construction to occur assert False, "This statement already has parameters" - def _cant_mix_formats_error(self): + def _cant_mix_formats_error(self) -> NoReturn: raise exc.InvalidRequestError( "Can't mix single and multiple VALUES " "formats in one INSERT statement; one style appends to a " @@ -208,7 +259,7 @@ class InsertDMLState(DMLState): include_table_with_column_exprs = False - def __init__(self, statement, compiler, **kw): + def __init__(self, statement: Insert, compiler: SQLCompiler, **kw: Any): self.statement = statement self.isinsert = True @@ -226,10 +277,9 @@ class UpdateDMLState(DMLState): include_table_with_column_exprs = False - def __init__(self, statement, compiler, **kw): + def __init__(self, statement: Update, compiler: SQLCompiler, **kw: Any): self.statement = statement self.isupdate = True - self._preserve_parameter_order = statement._preserve_parameter_order if statement._ordered_values is not None: self._process_ordered_values(statement) elif statement._values is not None: @@ -238,7 +288,7 @@ class UpdateDMLState(DMLState): self._process_multi_values(statement) self._extra_froms = ef = self._make_extra_froms(statement) self.is_multitable = mt = ef and self._dict_parameters - self.include_table_with_column_exprs = ( + self.include_table_with_column_exprs = bool( mt and compiler.render_table_with_column_in_update_from ) @@ -247,7 +297,7 @@ class UpdateDMLState(DMLState): class DeleteDMLState(DMLState): isdelete = True - def __init__(self, statement, compiler, **kw): + def __init__(self, statement: Delete, compiler: SQLCompiler, **kw: Any): self.statement = statement self.isdelete = True @@ -271,23 +321,31 @@ class UpdateBase( __visit_name__ = "update_base" - _hints = util.immutabledict() + _hints: util.immutabledict[ + Tuple[roles.DMLTableRole, str], str + ] = util.EMPTY_DICT named_with_column = False - table: TableClause + table: roles.DMLTableRole _return_defaults = False - _return_defaults_columns = None - _returning = () + _return_defaults_columns: Optional[ + Tuple[roles.ColumnsClauseRole, ...] + ] = None + _returning: Tuple[roles.ColumnsClauseRole, ...] = () is_dml = True - def _generate_fromclause_column_proxies(self, fromclause): + def _generate_fromclause_column_proxies( + self, fromclause: FromClause + ) -> None: fromclause._columns._populate_separate_keys( - col._make_proxy(fromclause) for col in self._returning + col._make_proxy(fromclause) + for col in self._all_selected_columns + if is_column_element(col) ) - def params(self, *arg, **kw): + def params(self, *arg: Any, **kw: Any) -> NoReturn: """Set the parameters for the statement. This method raises ``NotImplementedError`` on the base class, @@ -302,7 +360,9 @@ class UpdateBase( ) @_generative - def with_dialect_options(self: SelfUpdateBase, **opt) -> SelfUpdateBase: + def with_dialect_options( + self: SelfUpdateBase, **opt: Any + ) -> SelfUpdateBase: """Add dialect options to this INSERT/UPDATE/DELETE object. e.g.:: @@ -318,7 +378,9 @@ class UpdateBase( return self @_generative - def returning(self: SelfUpdateBase, *cols) -> SelfUpdateBase: + def returning( + self: SelfUpdateBase, *cols: _ColumnsClauseArgument + ) -> SelfUpdateBase: r"""Add a :term:`RETURNING` or equivalent clause to this statement. e.g.: @@ -397,26 +459,32 @@ class UpdateBase( ) return self - @property - def _all_selected_columns(self): - return self._returning + @util.non_memoized_property + def _all_selected_columns(self) -> _SelectIterable: + return [c for c in _select_iterables(self._returning)] @property - def exported_columns(self): + def exported_columns( + self, + ) -> ReadOnlyColumnCollection[Optional[str], ColumnElement[Any]]: """Return the RETURNING columns as a column collection for this statement. .. versionadded:: 1.4 """ - # TODO: no coverage here return ColumnCollection( - (c.key, c) for c in self._all_selected_columns - ).as_immutable() + (c.key, c) + for c in self._all_selected_columns + if is_column_element(c) + ).as_readonly() @_generative def with_hint( - self: SelfUpdateBase, text, selectable=None, dialect_name="*" + self: SelfUpdateBase, + text: str, + selectable: Optional[roles.DMLTableRole] = None, + dialect_name: str = "*", ) -> SelfUpdateBase: """Add a table hint for a single table to this INSERT/UPDATE/DELETE statement. @@ -454,7 +522,7 @@ class UpdateBase( return self @property - def entity_description(self): + def entity_description(self) -> Dict[str, Any]: """Return a :term:`plugin-enabled` description of the table and/or entity which this DML construct is operating against. @@ -490,7 +558,7 @@ class UpdateBase( return meth(self) @property - def returning_column_descriptions(self): + def returning_column_descriptions(self) -> List[Dict[str, Any]]: """Return a :term:`plugin-enabled` description of the columns which this DML construct is RETURNING against, in other words the expressions established as part of :meth:`.UpdateBase.returning`. @@ -547,18 +615,30 @@ class ValuesBase(UpdateBase): __visit_name__ = "values_base" _supports_multi_parameters = False - _preserve_parameter_order = False - select = None - _post_values_clause = None - _values = None - _multi_values = () - _ordered_values = None - _select_names = None + select: Optional[Select] = None + """SELECT statement for INSERT .. FROM SELECT""" + + _post_values_clause: Optional[ClauseElement] = None + """used by extensions to Insert etc. to add additional syntacitcal + constructs, e.g. ON CONFLICT etc.""" + + _values: Optional[util.immutabledict[_DMLColumnElement, Any]] = None + _multi_values: Tuple[ + Union[ + Sequence[Dict[_DMLColumnElement, Any]], + Sequence[Sequence[Any]], + ], + ..., + ] = () + + _ordered_values: Optional[List[Tuple[_DMLColumnElement, Any]]] = None + + _select_names: Optional[List[str]] = None _inline: bool = False - _returning = () + _returning: Tuple[roles.ColumnsClauseRole, ...] = () - def __init__(self, table): + def __init__(self, table: _FromClauseArgument): self.table = coercions.expect( roles.DMLTableRole, table, apply_propagate_attrs=self ) @@ -573,7 +653,14 @@ class ValuesBase(UpdateBase): "values present", }, ) - def values(self: SelfValuesBase, *args, **kwargs) -> SelfValuesBase: + def values( + self: SelfValuesBase, + *args: Union[ + Dict[_DMLColumnArgument, Any], + Sequence[Any], + ], + **kwargs: Any, + ) -> SelfValuesBase: r"""Specify a fixed VALUES clause for an INSERT statement, or the SET clause for an UPDATE. @@ -704,9 +791,7 @@ class ValuesBase(UpdateBase): "dictionaries/tuples is accepted positionally." ) - elif not self._preserve_parameter_order and isinstance( - arg, collections_abc.Sequence - ): + elif isinstance(arg, collections_abc.Sequence): if arg and isinstance(arg[0], (list, dict, tuple)): self._multi_values += (arg,) @@ -714,18 +799,11 @@ class ValuesBase(UpdateBase): # tuple values arg = {c.key: value for c, value in zip(self.table.c, arg)} - elif self._preserve_parameter_order and not isinstance( - arg, collections_abc.Sequence - ): - raise ValueError( - "When preserve_parameter_order is True, " - "values() only accepts a list of 2-tuples" - ) else: # kwarg path. this is the most common path for non-multi-params # so this is fairly quick. - arg = kwargs + arg = cast("Dict[_DMLColumnArgument, Any]", kwargs) if args: raise exc.ArgumentError( "Only a single dictionary/tuple or list of " @@ -739,15 +817,11 @@ class ValuesBase(UpdateBase): # and ensures they get the "crud"-style name when rendered. kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs - - if self._preserve_parameter_order: - self._ordered_values = kv_generator(self, arg) + coerced_arg = {k: v for k, v in kv_generator(self, arg.items())} + if self._values: + self._values = self._values.union(coerced_arg) else: - arg = {k: v for k, v in kv_generator(self, arg.items())} - if self._values: - self._values = self._values.union(arg) - else: - self._values = util.immutabledict(arg) + self._values = util.immutabledict(coerced_arg) return self @_generative @@ -758,7 +832,9 @@ class ValuesBase(UpdateBase): }, defaults={"_returning": _returning}, ) - def return_defaults(self: SelfValuesBase, *cols) -> SelfValuesBase: + def return_defaults( + self: SelfValuesBase, *cols: _DMLColumnArgument + ) -> SelfValuesBase: """Make use of a :term:`RETURNING` clause for the purpose of fetching server-side expressions and defaults. @@ -843,7 +919,9 @@ class ValuesBase(UpdateBase): """ self._return_defaults = True - self._return_defaults_columns = cols + self._return_defaults_columns = tuple( + coercions.expect(roles.ColumnsClauseRole, c) for c in cols + ) return self @@ -867,6 +945,8 @@ class Insert(ValuesBase): is_insert = True + table: TableClause + _traverse_internals = ( [ ("table", InternalTraversal.dp_clauseelement), @@ -890,7 +970,7 @@ class Insert(ValuesBase): + HasCTE._has_ctes_traverse_internals ) - def __init__(self, table): + def __init__(self, table: roles.FromClauseRole): super(Insert, self).__init__(table) @_generative @@ -916,7 +996,10 @@ class Insert(ValuesBase): @_generative def from_select( - self: SelfInsert, names, select, include_defaults=True + self: SelfInsert, + names: List[str], + select: Select, + include_defaults: bool = True, ) -> SelfInsert: """Return a new :class:`_expression.Insert` construct which represents an ``INSERT...FROM SELECT`` statement. @@ -983,10 +1066,13 @@ SelfDMLWhereBase = typing.TypeVar("SelfDMLWhereBase", bound="DMLWhereBase") class DMLWhereBase: - _where_criteria = () + table: roles.DMLTableRole + _where_criteria: Tuple[ColumnElement[Any], ...] = () @_generative - def where(self: SelfDMLWhereBase, *whereclause) -> SelfDMLWhereBase: + def where( + self: SelfDMLWhereBase, *whereclause: roles.ExpressionElementRole[Any] + ) -> SelfDMLWhereBase: """Return a new construct with the given expression(s) added to its WHERE clause, joined to the existing clause via AND, if any. @@ -1022,7 +1108,9 @@ class DMLWhereBase: self._where_criteria += (where_criteria,) return self - def filter(self: SelfDMLWhereBase, *criteria) -> SelfDMLWhereBase: + def filter( + self: SelfDMLWhereBase, *criteria: roles.ExpressionElementRole[Any] + ) -> SelfDMLWhereBase: """A synonym for the :meth:`_dml.DMLWhereBase.where` method. .. versionadded:: 1.4 @@ -1031,10 +1119,10 @@ class DMLWhereBase: return self.where(*criteria) - def _filter_by_zero(self): + def _filter_by_zero(self) -> roles.DMLTableRole: return self.table - def filter_by(self: SelfDMLWhereBase, **kwargs) -> SelfDMLWhereBase: + def filter_by(self: SelfDMLWhereBase, **kwargs: Any) -> SelfDMLWhereBase: r"""apply the given filtering criterion as a WHERE clause to this select. @@ -1048,7 +1136,7 @@ class DMLWhereBase: return self.filter(*clauses) @property - def whereclause(self): + def whereclause(self) -> Optional[ColumnElement[Any]]: """Return the completed WHERE clause for this :class:`.DMLWhereBase` statement. @@ -1079,7 +1167,6 @@ class Update(DMLWhereBase, ValuesBase): __visit_name__ = "update" is_update = True - _preserve_parameter_order = False _traverse_internals = ( [ @@ -1102,11 +1189,13 @@ class Update(DMLWhereBase, ValuesBase): + HasCTE._has_ctes_traverse_internals ) - def __init__(self, table): + def __init__(self, table: roles.FromClauseRole): super(Update, self).__init__(table) @_generative - def ordered_values(self: SelfUpdate, *args) -> SelfUpdate: + def ordered_values( + self: SelfUpdate, *args: Tuple[_DMLColumnArgument, Any] + ) -> SelfUpdate: """Specify the VALUES clause of this UPDATE statement with an explicit parameter ordering that will be maintained in the SET clause of the resulting UPDATE statement. @@ -1190,7 +1279,7 @@ class Delete(DMLWhereBase, UpdateBase): + HasCTE._has_ctes_traverse_internals ) - def __init__(self, table): + def __init__(self, table: roles.FromClauseRole): self.table = coercions.expect( roles.DMLTableRole, table, apply_propagate_attrs=self ) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 691eb10ec4..da1d50a535 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -70,12 +70,14 @@ from .visitors import Visitable from .. import exc from .. import inspection from .. import util -from ..util.langhelpers import TypingOnly +from ..util import HasMemoized_ro_memoized_attribute +from ..util import TypingOnly from ..util.typing import Literal if typing.TYPE_CHECKING: - from ._typing import _ColumnExpression + from ._typing import _ColumnExpressionArgument from ._typing import _PropagateAttrsType + from ._typing import _SelectIterable from ._typing import _TypeEngineArgument from .cache_key import CacheKey from .compiler import Compiled @@ -300,7 +302,7 @@ class ClauseElement( is_clause_element = True is_selectable = False - + _is_column_element = False _is_table = False _is_textual = False _is_from_clause = False @@ -330,7 +332,7 @@ class ClauseElement( ) -> Iterable[ClauseElement]: ... - @util.non_memoized_property + @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: return [] @@ -696,6 +698,9 @@ class CompilerColumnElement( __slots__ = () +# SQLCoreOperations should be suiting the ExpressionElementRole +# and ColumnsClauseRole. however the MRO issues become too elaborate +# at the moment. class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): __slots__ = () @@ -1154,6 +1159,7 @@ class ColumnElement( primary_key: bool = False _is_clone_of: Optional[ColumnElement[_T]] + _is_column_element = True foreign_keys: AbstractSet[ForeignKey] = frozenset() @@ -1396,7 +1402,7 @@ class ColumnElement( return self @property - def _select_iterable(self) -> Iterable[ColumnElement[Any]]: + def _select_iterable(self) -> _SelectIterable: return (self,) @util.memoized_property @@ -2075,7 +2081,7 @@ class TextClause( return and_(self, other) @property - def _select_iterable(self): + def _select_iterable(self) -> _SelectIterable: return (self,) # help in those cases where text() is @@ -2491,9 +2497,11 @@ class ClauseList( ("operator", InternalTraversal.dp_operator), ] + clauses: List[ColumnElement[Any]] + def __init__( self, - *clauses: _ColumnExpression[Any], + *clauses: _ColumnExpressionArgument[Any], operator: OperatorType = operators.comma_op, group: bool = True, group_contents: bool = True, @@ -2541,7 +2549,7 @@ class ClauseList( return len(self.clauses) @property - def _select_iterable(self): + def _select_iterable(self) -> _SelectIterable: return itertools.chain.from_iterable( [elem._select_iterable for elem in self.clauses] ) @@ -2558,7 +2566,7 @@ class ClauseList( coercions.expect(self._text_converter_role, clause) ) - @util.non_memoized_property + @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: return list(itertools.chain(*[c._from_objects for c in self.clauses])) @@ -2580,8 +2588,12 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): @classmethod def _process_clauses_for_boolean( - cls, operator, continue_on, skip_on, clauses - ): + cls, + operator: OperatorType, + continue_on: Any, + skip_on: Any, + clauses: Iterable[ColumnElement[Any]], + ) -> typing_Tuple[int, List[ColumnElement[Any]]]: has_continue_on = None convert_clauses = [] @@ -2623,9 +2635,9 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): operator: OperatorType, continue_on: Any, skip_on: Any, - *clauses: _ColumnExpression[Any], + *clauses: _ColumnExpressionArgument[Any], **kw: Any, - ) -> BooleanClauseList: + ) -> ColumnElement[Any]: lcc, convert_clauses = cls._process_clauses_for_boolean( operator, continue_on, @@ -2639,7 +2651,7 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): if lcc > 1: # multiple elements. Return regular BooleanClauseList # which will link elements against the operator. - return cls._construct_raw(operator, convert_clauses) # type: ignore[no-any-return] # noqa E501 + return cls._construct_raw(operator, convert_clauses) # type: ignore # noqa E501 elif lcc == 1: # just one element. return it as a single boolean element, # not a list and discard the operator. @@ -2663,7 +2675,9 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): return cls._construct_raw(operator) # type: ignore[no-any-return] # noqa E501 @classmethod - def _construct_for_whereclause(cls, clauses): + def _construct_for_whereclause( + cls, clauses: Iterable[ColumnElement[Any]] + ) -> Optional[ColumnElement[bool]]: operator, continue_on, skip_on = ( operators.and_, True_._singleton, @@ -2689,7 +2703,11 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): return None @classmethod - def _construct_raw(cls, operator, clauses=None): + def _construct_raw( + cls, + operator: OperatorType, + clauses: Optional[List[ColumnElement[Any]]] = None, + ) -> BooleanClauseList: self = cls.__new__(cls) self.clauses = clauses if clauses else [] self.group = True @@ -2700,7 +2718,9 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): return self @classmethod - def and_(cls, *clauses: _ColumnExpression[bool]) -> BooleanClauseList: + def and_( + cls, *clauses: _ColumnExpressionArgument[bool] + ) -> ColumnElement[bool]: r"""Produce a conjunction of expressions joined by ``AND``. See :func:`_sql.and_` for full documentation. @@ -2710,7 +2730,9 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): ) @classmethod - def or_(cls, *clauses: _ColumnExpression[bool]) -> BooleanClauseList: + def or_( + cls, *clauses: _ColumnExpressionArgument[bool] + ) -> ColumnElement[bool]: """Produce a conjunction of expressions joined by ``OR``. See :func:`_sql.or_` for full documentation. @@ -2720,7 +2742,7 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): ) @property - def _select_iterable(self): + def _select_iterable(self) -> _SelectIterable: return (self,) def self_group(self, against=None): @@ -2751,7 +2773,7 @@ class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]): @util.preload_module("sqlalchemy.sql.sqltypes") def __init__( self, - *clauses: _ColumnExpression[Any], + *clauses: _ColumnExpressionArgument[Any], types: Optional[Sequence[_TypeEngineArgument[Any]]] = None, ): sqltypes = util.preloaded.sql_sqltypes @@ -2780,7 +2802,7 @@ class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]): super(Tuple, self).__init__(*init_clauses) @property - def _select_iterable(self): + def _select_iterable(self) -> _SelectIterable: return (self,) def _bind_param(self, operator, obj, type_=None, expanding=False): @@ -2856,7 +2878,8 @@ class Case(ColumnElement[_T]): def __init__( self, *whens: Union[ - typing_Tuple[_ColumnExpression[bool], Any], Mapping[Any, Any] + typing_Tuple[_ColumnExpressionArgument[bool], Any], + Mapping[Any, Any], ], value: Optional[Any] = None, else_: Optional[Any] = None, @@ -2900,7 +2923,7 @@ class Case(ColumnElement[_T]): else: self.else_ = None - @util.non_memoized_property + @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: return list( itertools.chain(*[x._from_objects for x in self.get_children()]) @@ -2944,7 +2967,7 @@ class Cast(WrapsColumnExpression[_T]): def __init__( self, - expression: _ColumnExpression[Any], + expression: _ColumnExpressionArgument[Any], type_: _TypeEngineArgument[_T], ): self.type = type_api.to_instance(type_) @@ -2956,7 +2979,7 @@ class Cast(WrapsColumnExpression[_T]): ) self.typeclause = TypeClause(self.type) - @util.non_memoized_property + @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: return self.clause._from_objects @@ -2995,7 +3018,7 @@ class TypeCoerce(WrapsColumnExpression[_T]): def __init__( self, - expression: _ColumnExpression[Any], + expression: _ColumnExpressionArgument[Any], type_: _TypeEngineArgument[_T], ): self.type = type_api.to_instance(type_) @@ -3006,7 +3029,7 @@ class TypeCoerce(WrapsColumnExpression[_T]): apply_propagate_attrs=self, ) - @util.non_memoized_property + @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: return self.clause._from_objects @@ -3044,12 +3067,12 @@ class Extract(ColumnElement[int]): expr: ColumnElement[Any] field: str - def __init__(self, field: str, expr: _ColumnExpression[Any]): + def __init__(self, field: str, expr: _ColumnExpressionArgument[Any]): self.type = type_api.INTEGERTYPE self.field = field self.expr = coercions.expect(roles.ExpressionElementRole, expr) - @util.non_memoized_property + @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: return self.expr._from_objects @@ -3076,7 +3099,7 @@ class _label_reference(ColumnElement[_T]): def __init__(self, element: ColumnElement[_T]): self.element = element - @util.non_memoized_property + @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: return [] @@ -3142,7 +3165,7 @@ class UnaryExpression(ColumnElement[_T]): @classmethod def _create_nulls_first( cls, - column: _ColumnExpression[_T], + column: _ColumnExpressionArgument[_T], ) -> UnaryExpression[_T]: return UnaryExpression( coercions.expect(roles.ByOfRole, column), @@ -3153,7 +3176,7 @@ class UnaryExpression(ColumnElement[_T]): @classmethod def _create_nulls_last( cls, - column: _ColumnExpression[_T], + column: _ColumnExpressionArgument[_T], ) -> UnaryExpression[_T]: return UnaryExpression( coercions.expect(roles.ByOfRole, column), @@ -3163,7 +3186,7 @@ class UnaryExpression(ColumnElement[_T]): @classmethod def _create_desc( - cls, column: _ColumnExpression[_T] + cls, column: _ColumnExpressionArgument[_T] ) -> UnaryExpression[_T]: return UnaryExpression( coercions.expect(roles.ByOfRole, column), @@ -3174,7 +3197,7 @@ class UnaryExpression(ColumnElement[_T]): @classmethod def _create_asc( cls, - column: _ColumnExpression[_T], + column: _ColumnExpressionArgument[_T], ) -> UnaryExpression[_T]: return UnaryExpression( coercions.expect(roles.ByOfRole, column), @@ -3185,7 +3208,7 @@ class UnaryExpression(ColumnElement[_T]): @classmethod def _create_distinct( cls, - expr: _ColumnExpression[_T], + expr: _ColumnExpressionArgument[_T], ) -> UnaryExpression[_T]: col_expr = coercions.expect(roles.ExpressionElementRole, expr) return UnaryExpression( @@ -3202,7 +3225,7 @@ class UnaryExpression(ColumnElement[_T]): else: return None - @util.non_memoized_property + @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: return self.element._from_objects @@ -3238,7 +3261,7 @@ class CollectionAggregate(UnaryExpression[_T]): @classmethod def _create_any( - cls, expr: _ColumnExpression[_T] + cls, expr: _ColumnExpressionArgument[_T] ) -> CollectionAggregate[bool]: col_expr = coercions.expect( roles.ExpressionElementRole, @@ -3254,7 +3277,7 @@ class CollectionAggregate(UnaryExpression[_T]): @classmethod def _create_all( - cls, expr: _ColumnExpression[_T] + cls, expr: _ColumnExpressionArgument[_T] ) -> CollectionAggregate[bool]: col_expr = coercions.expect( roles.ExpressionElementRole, @@ -3431,7 +3454,7 @@ class BinaryExpression(ColumnElement[_T]): def is_comparison(self): return operators.is_comparison(self.operator) - @util.non_memoized_property + @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: return self.left._from_objects + self.right._from_objects @@ -3557,7 +3580,7 @@ class Grouping(GroupedElement, ColumnElement[_T]): else: return [] - @util.non_memoized_property + @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: return self.element._from_objects @@ -3614,10 +3637,16 @@ class Over(ColumnElement[_T]): self, element: ColumnElement[_T], partition_by: Optional[ - Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]] + Union[ + Iterable[_ColumnExpressionArgument[Any]], + _ColumnExpressionArgument[Any], + ] ] = None, order_by: Optional[ - Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]] + Union[ + Iterable[_ColumnExpressionArgument[Any]], + _ColumnExpressionArgument[Any], + ] ] = None, range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, @@ -3697,7 +3726,7 @@ class Over(ColumnElement[_T]): def type(self): return self.element.type - @util.non_memoized_property + @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: return list( itertools.chain( @@ -3737,7 +3766,9 @@ class WithinGroup(ColumnElement[_T]): order_by: Optional[ClauseList] = None def __init__( - self, element: FunctionElement[_T], *order_by: _ColumnExpression[Any] + self, + element: FunctionElement[_T], + *order_by: _ColumnExpressionArgument[Any], ): self.element = element if order_by is not None: @@ -3774,7 +3805,7 @@ class WithinGroup(ColumnElement[_T]): else: return self.element.type - @util.non_memoized_property + @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: return list( itertools.chain( @@ -3817,7 +3848,9 @@ class FunctionFilter(ColumnElement[_T]): criterion: Optional[ColumnElement[bool]] = None def __init__( - self, func: FunctionElement[_T], *criterion: _ColumnExpression[bool] + self, + func: FunctionElement[_T], + *criterion: _ColumnExpressionArgument[bool], ): self.func = func self.filter(*criterion) @@ -3847,10 +3880,16 @@ class FunctionFilter(ColumnElement[_T]): def over( self, partition_by: Optional[ - Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]] + Union[ + Iterable[_ColumnExpressionArgument[Any]], + _ColumnExpressionArgument[Any], + ] ] = None, order_by: Optional[ - Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]] + Union[ + Iterable[_ColumnExpressionArgument[Any]], + _ColumnExpressionArgument[Any], + ] ] = None, range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, @@ -3890,7 +3929,7 @@ class FunctionFilter(ColumnElement[_T]): def type(self): return self.func.type - @util.non_memoized_property + @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: return list( itertools.chain( @@ -3903,7 +3942,97 @@ class FunctionFilter(ColumnElement[_T]): ) -class Label(roles.LabeledColumnExprRole[_T], ColumnElement[_T]): +class NamedColumn(ColumnElement[_T]): + is_literal = False + table: Optional[FromClause] = None + name: str + key: str + + def _compare_name_for_result(self, other): + return (hasattr(other, "name") and self.name == other.name) or ( + hasattr(other, "_label") and self._label == other._label + ) + + @util.ro_memoized_property + def description(self) -> str: + return self.name + + @HasMemoized.memoized_attribute + def _tq_key_label(self): + """table qualified label based on column key. + + for table-bound columns this is _; + + all other expressions it resolves to key/proxy key. + + """ + proxy_key = self._proxy_key + if proxy_key and proxy_key != self.name: + return self._gen_tq_label(proxy_key) + else: + return self._tq_label + + @HasMemoized.memoized_attribute + def _tq_label(self) -> Optional[str]: + """table qualified label based on column name. + + for table-bound columns this is _; all other + expressions it resolves to .name. + + """ + return self._gen_tq_label(self.name) + + @HasMemoized.memoized_attribute + def _render_label_in_columns_clause(self): + return True + + @HasMemoized.memoized_attribute + def _non_anon_label(self): + return self.name + + def _gen_tq_label( + self, name: str, dedupe_on_key: bool = True + ) -> Optional[str]: + return name + + def _bind_param(self, operator, obj, type_=None, expanding=False): + return BindParameter( + self.key, + obj, + _compared_to_operator=operator, + _compared_to_type=self.type, + type_=type_, + unique=True, + expanding=expanding, + ) + + def _make_proxy( + self, + selectable, + name=None, + name_is_truncatable=False, + disallow_is_literal=False, + **kw, + ): + c = ColumnClause( + coercions.expect(roles.TruncatedLabelRole, name or self.name) + if name_is_truncatable + else (name or self.name), + type_=self.type, + _selectable=selectable, + is_literal=False, + ) + + c._propagate_attrs = selectable._propagate_attrs + if name is None: + c.key = self.key + c._proxies = [self] + if selectable._is_clone_of is not None: + c._is_clone_of = selectable._is_clone_of.columns.get(c.key) + return c.key, c + + +class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]): """Represents a column label (AS). Represent a label, as typically applied to any column-level @@ -3925,7 +4054,7 @@ class Label(roles.LabeledColumnExprRole[_T], ColumnElement[_T]): def __init__( self, name: Optional[str], - element: _ColumnExpression[_T], + element: _ColumnExpressionArgument[_T], type_: Optional[_TypeEngineArgument[_T]] = None, ): orig_element = element @@ -3964,6 +4093,21 @@ class Label(roles.LabeledColumnExprRole[_T], ColumnElement[_T]): def __reduce__(self): return self.__class__, (self.name, self._element, self.type) + @HasMemoized.memoized_attribute + def _render_label_in_columns_clause(self): + return True + + def _bind_param(self, operator, obj, type_=None, expanding=False): + return BindParameter( + None, + obj, + _compared_to_operator=operator, + type_=type_, + _compared_to_type=self.type, + unique=True, + expanding=expanding, + ) + @util.memoized_property def _is_implicitly_boolean(self): return self.element._is_implicitly_boolean @@ -4010,7 +4154,7 @@ class Label(roles.LabeledColumnExprRole[_T], ColumnElement[_T]): ) self.key = self._tq_label = self._tq_key_label = self.name - @util.non_memoized_property + @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: return self.element._from_objects @@ -4047,96 +4191,6 @@ class Label(roles.LabeledColumnExprRole[_T], ColumnElement[_T]): return self.key, e -class NamedColumn(ColumnElement[_T]): - is_literal = False - table: Optional[FromClause] = None - name: str - key: str - - def _compare_name_for_result(self, other): - return (hasattr(other, "name") and self.name == other.name) or ( - hasattr(other, "_label") and self._label == other._label - ) - - @util.ro_memoized_property - def description(self) -> str: - return self.name - - @HasMemoized.memoized_attribute - def _tq_key_label(self): - """table qualified label based on column key. - - for table-bound columns this is _; - - all other expressions it resolves to key/proxy key. - - """ - proxy_key = self._proxy_key - if proxy_key and proxy_key != self.name: - return self._gen_tq_label(proxy_key) - else: - return self._tq_label - - @HasMemoized.memoized_attribute - def _tq_label(self) -> Optional[str]: - """table qualified label based on column name. - - for table-bound columns this is _; all other - expressions it resolves to .name. - - """ - return self._gen_tq_label(self.name) - - @HasMemoized.memoized_attribute - def _render_label_in_columns_clause(self): - return True - - @HasMemoized.memoized_attribute - def _non_anon_label(self): - return self.name - - def _gen_tq_label( - self, name: str, dedupe_on_key: bool = True - ) -> Optional[str]: - return name - - def _bind_param(self, operator, obj, type_=None, expanding=False): - return BindParameter( - self.key, - obj, - _compared_to_operator=operator, - _compared_to_type=self.type, - type_=type_, - unique=True, - expanding=expanding, - ) - - def _make_proxy( - self, - selectable, - name=None, - name_is_truncatable=False, - disallow_is_literal=False, - **kw, - ): - c = ColumnClause( - coercions.expect(roles.TruncatedLabelRole, name or self.name) - if name_is_truncatable - else (name or self.name), - type_=self.type, - _selectable=selectable, - is_literal=False, - ) - - c._propagate_attrs = selectable._propagate_attrs - if name is None: - c.key = self.key - c._proxies = [self] - if selectable._is_clone_of is not None: - c._is_clone_of = selectable._is_clone_of.columns.get(c.key) - return c.key, c - - class ColumnClause( roles.DDLReferredColumnRole, roles.LabeledColumnExprRole[_T], @@ -4242,7 +4296,7 @@ class ColumnClause( return super(ColumnClause, self)._clone(**kw) - @HasMemoized.memoized_attribute + @HasMemoized_ro_memoized_attribute def _from_objects(self) -> List[FromClause]: t = self.table if t is not None: @@ -4395,7 +4449,7 @@ class TableValuedColumn(NamedColumn[_T]): self.scalar_alias = clone(self.scalar_alias, **kw) self.key = self.name = self.scalar_alias.name - @util.non_memoized_property + @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: return [self.scalar_alias] @@ -4409,7 +4463,7 @@ class CollationClause(ColumnElement[str]): @classmethod def _create_collation_expression( - cls, expression: _ColumnExpression[str], collation: str + cls, expression: _ColumnExpressionArgument[str], collation: str ) -> BinaryExpression[str]: expr = coercions.expect(roles.ExpressionElementRole, expression) return BinaryExpression( diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 9e801a99f3..3bca8b502f 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -298,7 +298,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): return self.alias(name=name).column - @property + @util.ro_non_memoized_property def columns(self): r"""The set of columns exported by this :class:`.FunctionElement`. @@ -320,6 +320,11 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): SQL function expressions. """ # noqa E501 + return self.c + + @util.ro_memoized_property + def c(self): + """synonym for :attr:`.FunctionElement.columns`.""" return ColumnCollection( columns=[(col.key, col) for col in self._all_selected_columns] diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index beb73c1b50..86725f86f5 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -6,27 +6,31 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php from __future__ import annotations -import typing from typing import Any from typing import Generic from typing import Iterable +from typing import List from typing import Optional +from typing import TYPE_CHECKING from typing import TypeVar from .. import util -from ..util import TypingOnly from ..util.typing import Literal -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from ._typing import _PropagateAttrsType + from ._typing import _SelectIterable + from .base import _EntityNamespace from .base import ColumnCollection + from .base import ReadOnlyColumnCollection from .elements import ClauseElement + from .elements import ColumnClause from .elements import ColumnElement from .elements import Label + from .elements import NamedColumn from .selectable import FromClause from .selectable import Subquery - _T = TypeVar("_T", bound=Any) @@ -109,7 +113,7 @@ class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole): _role_name = "Column expression or FROM clause" @property - def _select_iterable(self) -> Iterable[ColumnsClauseRole]: + def _select_iterable(self) -> _SelectIterable: raise NotImplementedError() @@ -202,32 +206,51 @@ class FromClauseRole(ColumnsClauseRole, JoinTargetRole): _is_subquery = False - @property - def _hide_froms(self) -> Iterable[FromClause]: - raise NotImplementedError() + named_with_column: bool + + if TYPE_CHECKING: + + @util.ro_non_memoized_property + def c(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: + ... + + @util.ro_non_memoized_property + def columns(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: + ... + + @util.ro_non_memoized_property + def entity_namespace(self) -> _EntityNamespace: + ... + + @util.ro_non_memoized_property + def _hide_froms(self) -> Iterable[FromClause]: + ... + + @util.ro_non_memoized_property + def _from_objects(self) -> List[FromClause]: + ... class StrictFromClauseRole(FromClauseRole): __slots__ = () # does not allow text() or select() objects - c: ColumnCollection[Any] + if TYPE_CHECKING: - # this should be ->str , however, working around: - # https://github.com/python/mypy/issues/12440 - @util.ro_non_memoized_property - def description(self) -> str: - raise NotImplementedError() + @util.ro_non_memoized_property + def description(self) -> str: + ... class AnonymizedFromClauseRole(StrictFromClauseRole): __slots__ = () - # calls .alias() as a post processor - def _anonymous_fromclause( - self, name: Optional[str] = None, flat: bool = False - ) -> FromClause: - raise NotImplementedError() + if TYPE_CHECKING: + + def _anonymous_fromclause( + self, name: Optional[str] = None, flat: bool = False + ) -> FromClause: + ... class ReturnsRowsRole(SQLRole): @@ -283,6 +306,16 @@ class DMLTableRole(FromClauseRole): __slots__ = () _role_name = "subject table for an INSERT, UPDATE or DELETE" + if TYPE_CHECKING: + + @util.ro_non_memoized_property + def primary_key(self) -> Iterable[NamedColumn[Any]]: + ... + + @util.ro_non_memoized_property + def columns(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: + ... + class DMLColumnRole(SQLRole): __slots__ = () @@ -315,36 +348,3 @@ class DDLReferredColumnRole(DDLConstraintColumnRole): _role_name = ( "String column name or Column object for DDL foreign key constraint" ) - - -class HasClauseElement(TypingOnly): - """indicates a class that has a __clause_element__() method""" - - __slots__ = () - - if typing.TYPE_CHECKING: - - def __clause_element__(self) -> ClauseElement: - ... - - -class HasColumnElementClauseElement(TypingOnly): - """indicates a class that has a __clause_element__() method""" - - __slots__ = () - - if typing.TYPE_CHECKING: - - def __clause_element__(self) -> ColumnElement[Any]: - ... - - -class HasFromClauseElement(HasClauseElement, TypingOnly): - """indicates a class that has a __clause_element__() method""" - - __slots__ = () - - if typing.TYPE_CHECKING: - - def __clause_element__(self) -> FromClause: - ... diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 7206cfdbab..0e3e24a14c 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -37,6 +37,7 @@ import typing from typing import Any from typing import Callable from typing import Dict +from typing import Iterator from typing import List from typing import MutableMapping from typing import Optional @@ -54,7 +55,6 @@ from . import ddl from . import roles from . import type_api from . import visitors -from .base import ColumnCollection from .base import DedupeColumnCollection from .base import DialectKWArgs from .base import Executable @@ -78,6 +78,7 @@ from ..util.typing import Protocol from ..util.typing import TypeGuard if typing.TYPE_CHECKING: + from .base import ReadOnlyColumnCollection from .type_api import TypeEngine from ..engine import Connection from ..engine import Engine @@ -273,6 +274,16 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): __visit_name__ = "table" + if TYPE_CHECKING: + + @util.ro_non_memoized_property + def primary_key(self) -> PrimaryKeyConstraint: + ... + + @util.ro_non_memoized_property + def foreign_keys(self) -> Set[ForeignKey]: + ... + constraints: Set[Constraint] """A collection of all :class:`_schema.Constraint` objects associated with this :class:`_schema.Table`. @@ -316,12 +327,18 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): ] if TYPE_CHECKING: - - @util.non_memoized_property - def columns(self) -> ColumnCollection[Column[Any]]: + # we are upgrading .c and .columns to return Column, not + # ColumnClause. mypy typically sees this as incompatible because + # the contract of TableClause is that we can put a ColumnClause + # into this collection. does not recognize its immutability + # for the moment. + @util.ro_non_memoized_property + def columns(self) -> ReadOnlyColumnCollection[str, Column[Any]]: # type: ignore # noqa: E501 ... - c: ColumnCollection[Column[Any]] + @util.ro_non_memoized_property + def c(self) -> ReadOnlyColumnCollection[str, Column[Any]]: # type: ignore # noqa: E501 + ... def _gen_cache_key(self, anon_map, bindparams): if self._annotations: @@ -737,7 +754,7 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): PrimaryKeyConstraint( _implicit_generated=True )._set_parent_with_dispatch(self) - self.foreign_keys = set() + self.foreign_keys = set() # type: ignore self._extra_dependencies = set() if self.schema is not None: self.fullname = "%s.%s" % (self.schema, self.name) @@ -3537,7 +3554,7 @@ class ColumnCollectionMixin: """ - columns: ColumnCollection[Column[Any]] + _columns: DedupeColumnCollection[Column[Any]] _allow_multiple_tables = False @@ -3551,7 +3568,7 @@ class ColumnCollectionMixin: def __init__(self, *columns, **kw): _autoattach = kw.pop("_autoattach", True) self._column_flag = kw.pop("_column_flag", False) - self.columns = DedupeColumnCollection() + self._columns = DedupeColumnCollection() processed_expressions = kw.pop("_gather_expressions", None) if processed_expressions is not None: @@ -3624,6 +3641,14 @@ class ColumnCollectionMixin: ) ) + @util.ro_memoized_property + def columns(self) -> ReadOnlyColumnCollection[str, Column[Any]]: + return self._columns.as_readonly() + + @util.ro_memoized_property + def c(self) -> ReadOnlyColumnCollection[str, Column[Any]]: + return self._columns.as_readonly() + def _col_expressions(self, table: Table) -> List[Column[Any]]: return [ table.c[col] if isinstance(col, str) else col @@ -3635,7 +3660,7 @@ class ColumnCollectionMixin: assert isinstance(parent, Table) for col in self._col_expressions(parent): if col is not None: - self.columns.add(col) + self._columns.add(col) class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): @@ -3668,7 +3693,7 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): self, *columns, _autoattach=_autoattach, _column_flag=_column_flag ) - columns: DedupeColumnCollection[Column[Any]] + columns: ReadOnlyColumnCollection[str, Column[Any]] """A :class:`_expression.ColumnCollection` representing the set of columns for this constraint. @@ -3679,7 +3704,7 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): ColumnCollectionMixin._set_parent(self, table) def __contains__(self, x): - return x in self.columns + return x in self._columns @util.deprecated( "1.4", @@ -3708,7 +3733,7 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): initially=self.initially, *[ _copy_expression(expr, self.parent, target_table) - for expr in self.columns + for expr in self._columns ], **constraint_kwargs, ) @@ -3723,13 +3748,13 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): """ - return self.columns.contains_column(col) + return self._columns.contains_column(col) - def __iter__(self): - return iter(self.columns) + def __iter__(self) -> Iterator[Column[Any]]: + return iter(self._columns) - def __len__(self): - return len(self.columns) + def __len__(self) -> int: + return len(self._columns) class CheckConstraint(ColumnCollectionConstraint): @@ -4002,10 +4027,10 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): self._set_parent_with_dispatch(table) def _append_element(self, column: Column[Any], fk: ForeignKey) -> None: - self.columns.add(column) + self._columns.add(column) self.elements.append(fk) - columns: DedupeColumnCollection[Column[Any]] + columns: ReadOnlyColumnCollection[str, Column[Any]] """A :class:`_expression.ColumnCollection` representing the set of columns for this constraint. @@ -4072,7 +4097,7 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): """ if hasattr(self, "parent"): - return self.columns.keys() + return self._columns.keys() else: return [ col.key if isinstance(col, ColumnElement) else str(col) @@ -4095,7 +4120,7 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): "named '%s' is present." % (table.description, ke.args[0]) ) from ke - for col, fk in zip(self.columns, self.elements): + for col, fk in zip(self._columns, self.elements): if not hasattr(fk, "parent") or fk.parent is not col: fk._set_parent_with_dispatch(col) @@ -4226,7 +4251,11 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): table.constraints.add(self) table_pks = [c for c in table.c if c.primary_key] - if self.columns and table_pks and set(table_pks) != set(self.columns): + if ( + self._columns + and table_pks + and set(table_pks) != set(self._columns) + ): util.warn( "Table '%s' specifies columns %s as primary_key=True, " "not matching locally specified columns %s; setting the " @@ -4235,18 +4264,18 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): % ( table.name, ", ".join("'%s'" % c.name for c in table_pks), - ", ".join("'%s'" % c.name for c in self.columns), - ", ".join("'%s'" % c.name for c in self.columns), + ", ".join("'%s'" % c.name for c in self._columns), + ", ".join("'%s'" % c.name for c in self._columns), ) ) table_pks[:] = [] - for c in self.columns: + for c in self._columns: c.primary_key = True if c._user_defined_nullable is NULL_UNSPECIFIED: c.nullable = False if table_pks: - self.columns.extend(table_pks) + self._columns.extend(table_pks) def _reload(self, columns): """repopulate this :class:`.PrimaryKeyConstraint` given @@ -4272,14 +4301,14 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): for col in columns: col.primary_key = True - self.columns.extend(columns) + self._columns.extend(columns) PrimaryKeyConstraint._autoincrement_column._reset(self) self._set_parent_with_dispatch(self.table) def _replace(self, col): PrimaryKeyConstraint._autoincrement_column._reset(self) - self.columns.replace(col) + self._columns.replace(col) self.dispatch._sa_event_column_added_to_pk_constraint(self, col) @@ -4288,9 +4317,9 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): autoinc = self._autoincrement_column if autoinc is not None: - return [autoinc] + [c for c in self.columns if c is not autoinc] + return [autoinc] + [c for c in self._columns if c is not autoinc] else: - return list(self.columns) + return list(self._columns) @util.memoized_property def _autoincrement_column(self): @@ -4323,8 +4352,8 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): return False return True - if len(self.columns) == 1: - col = list(self.columns)[0] + if len(self._columns) == 1: + col = list(self._columns)[0] if col.autoincrement is True: _validate_autoinc(col, True) @@ -4337,7 +4366,7 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): else: autoinc = None - for col in self.columns: + for col in self._columns: if col.autoincrement is True: _validate_autoinc(col, True) if autoinc is not None: diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 2f37317f26..24edc1caef 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -16,13 +16,14 @@ from __future__ import annotations import collections from enum import Enum import itertools -from operator import attrgetter import typing from typing import Any as TODO_Any from typing import Any from typing import Iterable +from typing import List from typing import NamedTuple from typing import Optional +from typing import Sequence from typing import Tuple from typing import TYPE_CHECKING from typing import TypeVar @@ -34,13 +35,15 @@ from . import roles from . import traversals from . import type_api from . import visitors -from ._typing import _ColumnsClauseElement +from ._typing import _ColumnsClauseArgument +from ._typing import is_column_element from .annotation import Annotated from .annotation import SupportsCloneAnnotations from .base import _clone from .base import _cloned_difference from .base import _cloned_intersection from .base import _entity_namespace_key +from .base import _EntityNamespace from .base import _expand_cloned from .base import _from_objects from .base import _generative @@ -78,6 +81,13 @@ and_ = BooleanClauseList.and_ _T = TypeVar("_T", bound=Any) +if TYPE_CHECKING: + from ._typing import _SelectIterable + from .base import ReadOnlyColumnCollection + from .elements import NamedColumn + from .schema import ForeignKey + from .schema import PrimaryKeyConstraint + class _OffsetLimitParam(BindParameter): inherit_cache = True @@ -111,8 +121,8 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement): def selectable(self): return self - @property - def _all_selected_columns(self): + @util.non_memoized_property + def _all_selected_columns(self) -> _SelectIterable: """A sequence of column expression objects that represents the "selected" columns of this :class:`_expression.ReturnsRows`. @@ -457,7 +467,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): __visit_name__ = "fromclause" named_with_column = False - @property + @util.ro_non_memoized_property def _hide_froms(self) -> Iterable[FromClause]: return () @@ -707,10 +717,10 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ - return self.columns + return self.c - @util.memoized_property - def columns(self) -> ColumnCollection[Any]: + @util.ro_non_memoized_property + def columns(self) -> ReadOnlyColumnCollection[str, Any]: """A named-based collection of :class:`_expression.ColumnElement` objects maintained by this :class:`_expression.FromClause`. @@ -723,14 +733,23 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): :return: a :class:`.ColumnCollection` object. """ + return self.c + + @util.ro_memoized_property + def c(self) -> ReadOnlyColumnCollection[str, Any]: + """ + A synonym for :attr:`.FromClause.columns` + + :return: a :class:`.ColumnCollection` + """ if "_columns" not in self.__dict__: self._init_collections() self._populate_column_collection() - return self._columns.as_immutable() + return self._columns.as_readonly() - @property - def entity_namespace(self): + @util.ro_non_memoized_property + def entity_namespace(self) -> _EntityNamespace: """Return a namespace used for name-based access in SQL expressions. This is the namespace that is used to resolve "filter_by()" type @@ -743,10 +762,10 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): alternative results. """ - return self.columns + return self.c - @util.memoized_property - def primary_key(self): + @util.ro_memoized_property + def primary_key(self) -> Iterable[NamedColumn[Any]]: """Return the iterable collection of :class:`_schema.Column` objects which comprise the primary key of this :class:`_selectable.FromClause`. @@ -759,8 +778,8 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): self._populate_column_collection() return self.primary_key - @util.memoized_property - def foreign_keys(self): + @util.ro_memoized_property + def foreign_keys(self) -> Iterable[ForeignKey]: """Return the collection of :class:`_schema.ForeignKey` marker objects which this FromClause references. @@ -791,28 +810,12 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ - for key in ["_columns", "columns", "primary_key", "foreign_keys"]: + for key in ["_columns", "columns", "c", "primary_key", "foreign_keys"]: self.__dict__.pop(key, None) - # this is awkward. maybe there's a better way - if TYPE_CHECKING: - c: ColumnCollection[Any] - else: - c = property( - attrgetter("columns"), - doc=""" - A named-based collection of :class:`_expression.ColumnElement` - objects maintained by this :class:`_expression.FromClause`. - - The :attr:`_sql.FromClause.c` attribute is an alias for the - :attr:`_sql.FromClause.columns` attribute. - - :return: a :class:`.ColumnCollection` - - """, - ) - - _select_iterable = property(attrgetter("columns")) + @util.ro_non_memoized_property + def _select_iterable(self) -> _SelectIterable: + return self.c def _init_collections(self): assert "_columns" not in self.__dict__ @@ -820,8 +823,8 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): assert "foreign_keys" not in self.__dict__ self._columns = ColumnCollection() - self.primary_key = ColumnSet() - self.foreign_keys = set() + self.primary_key = ColumnSet() # type: ignore + self.foreign_keys = set() # type: ignore @property def _cols_populated(self): @@ -1050,9 +1053,7 @@ class Join(roles.DMLTableRole, FromClause): @util.preload_module("sqlalchemy.sql.util") def _populate_column_collection(self): sqlutil = util.preloaded.sql_util - columns = [c for c in self.left.columns] + [ - c for c in self.right.columns - ] + columns = [c for c in self.left.c] + [c for c in self.right.c] self.primary_key.extend( sqlutil.reduce_columns( @@ -1300,14 +1301,14 @@ class Join(roles.DMLTableRole, FromClause): .alias(name) ) - @property + @util.ro_non_memoized_property def _hide_froms(self) -> Iterable[FromClause]: return itertools.chain( *[_from_objects(x.left, x.right) for x in self._cloned_set] ) - @property - def _from_objects(self): + @util.ro_non_memoized_property + def _from_objects(self) -> List[FromClause]: return [self] + self.left._from_objects + self.right._from_objects @@ -1415,7 +1416,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause): self._reset_column_collection() @property - def _from_objects(self): + def _from_objects(self) -> List[FromClause]: return [self] @@ -2329,10 +2330,14 @@ class FromGrouping(GroupedElement, FromClause): def _init_collections(self): pass - @property + @util.ro_non_memoized_property def columns(self): return self.element.columns + @util.ro_non_memoized_property + def c(self): + return self.element.columns + @property def primary_key(self): return self.element.primary_key @@ -2350,12 +2355,12 @@ class FromGrouping(GroupedElement, FromClause): def _anonymous_fromclause(self, **kw): return FromGrouping(self.element._anonymous_fromclause(**kw)) - @property + @util.ro_non_memoized_property def _hide_froms(self) -> Iterable[FromClause]: return self.element._hide_froms - @property - def _from_objects(self): + @util.ro_non_memoized_property + def _from_objects(self) -> List[FromClause]: return self.element._from_objects def __getstate__(self): @@ -2436,6 +2441,16 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): if kw: raise exc.ArgumentError("Unsupported argument(s): %s" % list(kw)) + if TYPE_CHECKING: + + @util.ro_non_memoized_property + def columns(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: + ... + + @util.ro_non_memoized_property + def c(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: + ... + def __str__(self): if self.schema is not None: return self.schema + "." + self.name @@ -2507,8 +2522,8 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): """ return util.preloaded.sql_dml.Delete(self) - @property - def _from_objects(self): + @util.ro_non_memoized_property + def _from_objects(self) -> List[FromClause]: return [self] @@ -2669,11 +2684,14 @@ class Values(Generative, NamedFromClause): self._columns.add(c) c.table = self - @property - def _from_objects(self): + @util.ro_non_memoized_property + def _from_objects(self) -> List[FromClause]: return [self] +SelfSelectBase = TypeVar("SelfSelectBase", bound=Any) + + class SelectBase( roles.SelectStatementRole, roles.DMLSelectRole, @@ -2697,12 +2715,27 @@ class SelectBase( _is_select_statement = True is_select = True - def _generate_fromclause_column_proxies(self, fromclause): + def _generate_fromclause_column_proxies( + self, fromclause: FromClause + ) -> None: raise NotImplementedError() - def _refresh_for_new_column(self, column): + def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: self._reset_memoizations() + def _generate_columns_plus_names( + self, anon_for_dupe_key: bool + ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]: + raise NotImplementedError() + + def set_label_style( + self: SelfSelectBase, label_style: SelectLabelStyle + ) -> SelfSelectBase: + raise NotImplementedError() + + def get_label_style(self) -> SelectLabelStyle: + raise NotImplementedError() + @property def selected_columns(self): """A :class:`_expression.ColumnCollection` @@ -2733,8 +2766,8 @@ class SelectBase( """ raise NotImplementedError() - @property - def _all_selected_columns(self): + @util.non_memoized_property + def _all_selected_columns(self) -> _SelectIterable: """A sequence of expressions that correspond to what is rendered in the columns clause, including :class:`_sql.TextClause` constructs. @@ -2893,8 +2926,8 @@ class SelectBase( """ return Lateral._factory(self, name) - @property - def _from_objects(self): + @util.ro_non_memoized_property + def _from_objects(self) -> List[FromClause]: return [self] def subquery(self, name=None): @@ -2979,6 +3012,8 @@ class SelectStatementGrouping(GroupedElement, SelectBase): _is_select_container = True + element: SelectBase + def __init__(self, element): self.element = coercions.expect(roles.SelectStatementRole, element) @@ -2990,19 +3025,15 @@ class SelectStatementGrouping(GroupedElement, SelectBase): return self def get_label_style(self) -> SelectLabelStyle: - return self._label_style + return self.element.get_label_style() def set_label_style( self, label_style: SelectLabelStyle - ) -> "SelectStatementGrouping": + ) -> SelectStatementGrouping: return SelectStatementGrouping( self.element.set_label_style(label_style) ) - @property - def _label_style(self): - return self.element._label_style - @property def select_statement(self): return self.element @@ -3010,17 +3041,18 @@ class SelectStatementGrouping(GroupedElement, SelectBase): def self_group(self, against=None): return self - def _generate_columns_plus_names(self, anon_for_dupe_key): + def _generate_columns_plus_names( + self, anon_for_dupe_key: bool + ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]: return self.element._generate_columns_plus_names(anon_for_dupe_key) - def _generate_fromclause_column_proxies(self, subquery): + def _generate_fromclause_column_proxies( + self, subquery: FromClause + ) -> None: self.element._generate_fromclause_column_proxies(subquery) - def _generate_proxy_for_new_column(self, column, subquery): - return self.element._generate_proxy_for_new_column(subquery) - - @property - def _all_selected_columns(self): + @util.non_memoized_property + def _all_selected_columns(self) -> _SelectIterable: return self.element._all_selected_columns @property @@ -3039,8 +3071,8 @@ class SelectStatementGrouping(GroupedElement, SelectBase): """ return self.element.selected_columns - @property - def _from_objects(self): + @util.ro_non_memoized_property + def _from_objects(self) -> List[FromClause]: return self.element._from_objects @@ -3612,10 +3644,10 @@ class CompoundSelect(HasCompileState, GenerativeSelect): return True return False - def _set_label_style(self, style): + def set_label_style(self, style): if self._label_style is not style: self = self._generate() - select_0 = self.selects[0]._set_label_style(style) + select_0 = self.selects[0].set_label_style(style) self.selects = [select_0] + self.selects[1:] return self @@ -3665,8 +3697,8 @@ class CompoundSelect(HasCompileState, GenerativeSelect): for select in self.selects: select._refresh_for_new_column(column) - @property - def _all_selected_columns(self): + @util.non_memoized_property + def _all_selected_columns(self) -> _SelectIterable: return self.selects[0]._all_selected_columns @property @@ -3701,8 +3733,18 @@ class SelectState(util.MemoizedSlots, CompileState): "_label_resolve_dict", ) - class default_select_compile_options(CacheableOptions): - _cache_key_traversal = [] + if TYPE_CHECKING: + default_select_compile_options: CacheableOptions + else: + + class default_select_compile_options(CacheableOptions): + _cache_key_traversal = [] + + if TYPE_CHECKING: + + @classmethod + def get_plugin_class(cls, statement: Select) -> SelectState: + ... def __init__(self, statement, compiler, **kw): self.statement = statement @@ -3966,7 +4008,7 @@ class SelectState(util.MemoizedSlots, CompileState): return None @classmethod - def all_selected_columns(cls, statement): + def all_selected_columns(cls, statement: Select) -> _SelectIterable: return [c for c in _select_iterables(statement._raw_columns)] def _setup_joins(self, args, raw_columns): @@ -4205,15 +4247,17 @@ class Select( _memoized_select_entities: Tuple[TODO_Any, ...] = () _distinct = False - _distinct_on: Tuple[ColumnElement, ...] = () + _distinct_on: Tuple[ColumnElement[Any], ...] = () _correlate: Tuple[FromClause, ...] = () _correlate_except: Optional[Tuple[FromClause, ...]] = None - _where_criteria: Tuple[ColumnElement, ...] = () - _having_criteria: Tuple[ColumnElement, ...] = () + _where_criteria: Tuple[ColumnElement[Any], ...] = () + _having_criteria: Tuple[ColumnElement[Any], ...] = () _from_obj: Tuple[FromClause, ...] = () _auto_correlate = True - _compile_options = SelectState.default_select_compile_options + _compile_options: CacheableOptions = ( + SelectState.default_select_compile_options + ) _traverse_internals = ( [ @@ -4264,7 +4308,7 @@ class Select( stmt.__dict__.update(kw) return stmt - def __init__(self, *entities: _ColumnsClauseElement): + def __init__(self, *entities: _ColumnsClauseArgument): r"""Construct a new :class:`_expression.Select`. The public constructor for :class:`_expression.Select` is the @@ -4286,7 +4330,7 @@ class Select( cols = list(elem._select_iterable) return cols[0].type - def filter(self, *criteria): + def filter(self: SelfSelect, *criteria: ColumnElement[Any]) -> SelfSelect: """A synonym for the :meth:`_future.Select.where` method.""" return self.where(*criteria) @@ -4896,7 +4940,7 @@ class Select( return self @property - def whereclause(self): + def whereclause(self) -> Optional[ColumnElement[Any]]: """Return the completed WHERE clause for this :class:`_expression.Select` statement. @@ -5161,12 +5205,12 @@ class Select( [ (conv(c), c) for c in self._all_selected_columns - if not c._is_text_clause + if is_column_element(c) ] - ).as_immutable() + ).as_readonly() @HasMemoized.memoized_attribute - def _all_selected_columns(self): + def _all_selected_columns(self) -> Sequence[ColumnElement[Any]]: meth = SelectState.get_plugin_class(self).all_selected_columns return list(meth(self)) @@ -5175,7 +5219,9 @@ class Select( self = self.set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY) return self - def _generate_columns_plus_names(self, anon_for_dupe_key): + def _generate_columns_plus_names( + self, anon_for_dupe_key: bool + ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]: """Generate column names as rendered in a SELECT statement by the compiler. @@ -5805,13 +5851,13 @@ class TextualSelect(SelectBase): """ return ColumnCollection( (c.key, c) for c in self.column_args - ).as_immutable() + ).as_readonly() - @property - def _all_selected_columns(self): + @util.non_memoized_property + def _all_selected_columns(self) -> _SelectIterable: return self.column_args - def _set_label_style(self, style): + def set_label_style(self, style): return self def _ensure_disambiguated_names(self): diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 5114a2431d..cdce49f7bc 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -15,6 +15,7 @@ from itertools import chain import typing from typing import Any from typing import cast +from typing import Iterator from typing import Optional from . import coercions @@ -33,6 +34,7 @@ from .elements import _find_columns # noqa from .elements import _label_reference from .elements import _textual_label_reference from .elements import BindParameter +from .elements import ClauseElement # noqa from .elements import ColumnClause from .elements import ColumnElement from .elements import Grouping @@ -51,6 +53,7 @@ from .. import exc from .. import util if typing.TYPE_CHECKING: + from .roles import FromClauseRole from ..engine.interfaces import _AnyExecuteParams from ..engine.interfaces import _AnyMultiExecuteParams from ..engine.interfaces import _AnySingleExecuteParams @@ -404,7 +407,7 @@ def clause_is_present(clause, search): return False -def tables_from_leftmost(clause): +def tables_from_leftmost(clause: FromClauseRole) -> Iterator[FromClause]: if isinstance(clause, Join): for t in tables_from_leftmost(clause.left): yield t diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 7e616cd74f..406c8af248 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -21,9 +21,7 @@ from ._collections import flatten_iterator as flatten_iterator from ._collections import has_dupes as has_dupes from ._collections import has_intersection as has_intersection from ._collections import IdentitySet as IdentitySet -from ._collections import ImmutableContainer as ImmutableContainer from ._collections import immutabledict as immutabledict -from ._collections import ImmutableProperties as ImmutableProperties from ._collections import LRUCache as LRUCache from ._collections import merge_lists_w_ordering as merge_lists_w_ordering from ._collections import ordered_column_set as ordered_column_set @@ -33,6 +31,8 @@ from ._collections import OrderedProperties as OrderedProperties from ._collections import OrderedSet as OrderedSet from ._collections import PopulateDict as PopulateDict from ._collections import Properties as Properties +from ._collections import ReadOnlyContainer as ReadOnlyContainer +from ._collections import ReadOnlyProperties as ReadOnlyProperties from ._collections import ScopedRegistry as ScopedRegistry from ._collections import sort_dictionary as sort_dictionary from ._collections import ThreadLocalRegistry as ThreadLocalRegistry @@ -107,6 +107,9 @@ from .langhelpers import get_func_kwargs as get_func_kwargs from .langhelpers import getargspec_init as getargspec_init from .langhelpers import has_compiled_ext as has_compiled_ext from .langhelpers import HasMemoized as HasMemoized +from .langhelpers import ( + HasMemoized_ro_memoized_attribute as HasMemoized_ro_memoized_attribute, +) from .langhelpers import hybridmethod as hybridmethod from .langhelpers import hybridproperty as hybridproperty from .langhelpers import inject_docstring_text as inject_docstring_text diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 2d974b7372..bd73bf7140 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -38,13 +38,13 @@ from .typing import Protocol if typing.TYPE_CHECKING or not HAS_CYEXTENSION: from ._py_collections import immutabledict as immutabledict from ._py_collections import IdentitySet as IdentitySet - from ._py_collections import ImmutableContainer as ImmutableContainer + from ._py_collections import ReadOnlyContainer as ReadOnlyContainer from ._py_collections import ImmutableDictBase as ImmutableDictBase from ._py_collections import OrderedSet as OrderedSet from ._py_collections import unique_list as unique_list else: from sqlalchemy.cyextension.immutabledict import ( - ImmutableContainer as ImmutableContainer, + ReadOnlyContainer as ReadOnlyContainer, ) from sqlalchemy.cyextension.immutabledict import ( ImmutableDictBase as ImmutableDictBase, @@ -213,10 +213,10 @@ class Properties(Generic[_T]): def __contains__(self, key: str) -> bool: return key in self._data - def as_immutable(self) -> "ImmutableProperties[_T]": + def as_readonly(self) -> "ReadOnlyProperties[_T]": """Return an immutable proxy for this :class:`.Properties`.""" - return ImmutableProperties(self._data) + return ReadOnlyProperties(self._data) def update(self, value): self._data.update(value) @@ -263,7 +263,7 @@ class OrderedProperties(Properties[_T]): Properties.__init__(self, OrderedDict()) -class ImmutableProperties(ImmutableContainer, Properties[_T]): +class ReadOnlyProperties(ReadOnlyContainer, Properties[_T]): """Provide immutable dict/object attribute to an underlying dictionary.""" __slots__ = () diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py index d503529303..1016871aa7 100644 --- a/lib/sqlalchemy/util/_py_collections.py +++ b/lib/sqlalchemy/util/_py_collections.py @@ -29,37 +29,45 @@ _KT = TypeVar("_KT", bound=Any) _VT = TypeVar("_VT", bound=Any) -class ImmutableContainer: +class ReadOnlyContainer: __slots__ = () + def _readonly(self, *arg: Any, **kw: Any) -> NoReturn: + raise TypeError( + "%s object is immutable and/or readonly" % self.__class__.__name__ + ) + def _immutable(self, *arg: Any, **kw: Any) -> NoReturn: raise TypeError("%s object is immutable" % self.__class__.__name__) def __delitem__(self, key: Any) -> NoReturn: - self._immutable() + self._readonly() def __setitem__(self, key: Any, value: Any) -> NoReturn: - self._immutable() + self._readonly() def __setattr__(self, key: str, value: Any) -> NoReturn: - self._immutable() + self._readonly() -class ImmutableDictBase(ImmutableContainer, Dict[_KT, _VT]): - def clear(self) -> NoReturn: +class ImmutableDictBase(ReadOnlyContainer, Dict[_KT, _VT]): + def _readonly(self, *arg: Any, **kw: Any) -> NoReturn: self._immutable() + def clear(self) -> NoReturn: + self._readonly() + def pop(self, key: Any, default: Optional[Any] = None) -> NoReturn: - self._immutable() + self._readonly() def popitem(self) -> NoReturn: - self._immutable() + self._readonly() def setdefault(self, key: Any, default: Optional[Any] = None) -> NoReturn: - self._immutable() + self._readonly() def update(self, *arg: Any, **kw: Any) -> NoReturn: - self._immutable() + self._readonly() class immutabledict(ImmutableDictBase[_KT, _VT]): diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 8cf50c724c..9e1194e231 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -1248,6 +1248,7 @@ if TYPE_CHECKING: # of a property, meaning assignment needs to be disallowed ro_memoized_property = property ro_non_memoized_property = property + else: memoized_property = ro_memoized_property = _memoized_property non_memoized_property = ro_non_memoized_property = _non_memoized_property @@ -1348,6 +1349,12 @@ class HasMemoized: return update_wrapper(oneshot, fn) +if TYPE_CHECKING: + HasMemoized_ro_memoized_attribute = property +else: + HasMemoized_ro_memoized_attribute = HasMemoized.memoized_attribute + + class MemoizedSlots: """Apply memoized items to an object using a __getattr__ scheme. diff --git a/pyproject.toml b/pyproject.toml index aa2790b049..cc79e86469 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,7 +81,6 @@ module = [ "sqlalchemy.sql.selectable", # would be nice as strict "sqlalchemy.sql.functions", # would be nice as strict "sqlalchemy.sql.lambdas", - "sqlalchemy.sql.dml", # would be nice as strict "sqlalchemy.sql.util", # not yet classified: diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 67fcc88705..fc61e39b65 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -357,8 +357,8 @@ class ImmutableTest(fixtures.TestBase): with expect_raises_message(TypeError, "object is immutable"): m() - def test_immutable_properties(self): - d = util.ImmutableProperties({3: 4}) + def test_readonly_properties(self): + d = util.ReadOnlyProperties({3: 4}) calls = ( lambda: d.__delitem__(1), lambda: d.__setitem__(2, 3), @@ -563,7 +563,7 @@ class ColumnCollectionCommon(testing.AssertsCompiledSQL): eq_(keys, ["c1", "foo", "c3"]) ne_(id(keys), id(cc.keys())) - ci = cc.as_immutable() + ci = cc.as_readonly() eq_(ci.keys(), ["c1", "foo", "c3"]) def test_values(self): @@ -576,7 +576,7 @@ class ColumnCollectionCommon(testing.AssertsCompiledSQL): eq_(val, [c1, c2, c3]) ne_(id(val), id(cc.values())) - ci = cc.as_immutable() + ci = cc.as_readonly() eq_(ci.values(), [c1, c2, c3]) def test_items(self): @@ -589,7 +589,7 @@ class ColumnCollectionCommon(testing.AssertsCompiledSQL): eq_(items, [("c1", c1), ("foo", c2), ("c3", c3)]) ne_(id(items), id(cc.items())) - ci = cc.as_immutable() + ci = cc.as_readonly() eq_(ci.items(), [("c1", c1), ("foo", c2), ("c3", c3)]) def test_key_index_error(self): @@ -732,7 +732,7 @@ class ColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase): self._assert_collection_integrity(cc) - ci = cc.as_immutable() + ci = cc.as_readonly() eq_(ci._all_columns, [c1, c2a, c3, c2b]) eq_(list(ci), [c1, c2a, c3, c2b]) eq_(ci.keys(), ["c1", "c2", "c3", "c2"]) @@ -763,7 +763,7 @@ class ColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase): self._assert_collection_integrity(cc) - ci = cc.as_immutable() + ci = cc.as_readonly() eq_(ci._all_columns, [c1, c2a, c3, c2b]) eq_(list(ci), [c1, c2a, c3, c2b]) eq_(ci.keys(), ["c1", "c2", "c3", "c2"]) @@ -786,7 +786,7 @@ class ColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase): assert cc.contains_column(c2) self._assert_collection_integrity(cc) - ci = cc.as_immutable() + ci = cc.as_readonly() eq_(ci._all_columns, [c1, c2, c3, c2]) eq_(list(ci), [c1, c2, c3, c2]) @@ -821,7 +821,7 @@ class DedupeColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase): c2.key = "foo" cc = self._column_collection(columns=[("c1", c1), ("foo", c2)]) - ci = cc.as_immutable() + ci = cc.as_readonly() d = {"cc": cc, "ci": ci} @@ -922,7 +922,7 @@ class DedupeColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase): assert cc.contains_column(c2) self._assert_collection_integrity(cc) - ci = cc.as_immutable() + ci = cc.as_readonly() eq_(ci._all_columns, [c1, c2, c3]) eq_(list(ci), [c1, c2, c3]) @@ -944,13 +944,13 @@ class DedupeColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase): assert cc.contains_column(c2) self._assert_collection_integrity(cc) - ci = cc.as_immutable() + ci = cc.as_readonly() eq_(ci._all_columns, [c1, c2, c3]) eq_(list(ci), [c1, c2, c3]) def test_replace(self): cc = DedupeColumnCollection() - ci = cc.as_immutable() + ci = cc.as_readonly() c1, c2a, c3, c2b = ( column("c1"), @@ -979,7 +979,7 @@ class DedupeColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase): def test_replace_key_matches_name_of_another(self): cc = DedupeColumnCollection() - ci = cc.as_immutable() + ci = cc.as_readonly() c1, c2a, c3, c2b = ( column("c1"), @@ -1009,7 +1009,7 @@ class DedupeColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase): def test_replace_key_matches(self): cc = DedupeColumnCollection() - ci = cc.as_immutable() + ci = cc.as_readonly() c1, c2a, c3, c2b = ( column("c1"), @@ -1041,7 +1041,7 @@ class DedupeColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase): def test_replace_name_matches(self): cc = DedupeColumnCollection() - ci = cc.as_immutable() + ci = cc.as_readonly() c1, c2a, c3, c2b = ( column("c1"), @@ -1073,7 +1073,7 @@ class DedupeColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase): def test_replace_no_match(self): cc = DedupeColumnCollection() - ci = cc.as_immutable() + ci = cc.as_readonly() c1, c2, c3, c4 = column("c1"), column("c2"), column("c3"), column("c4") c4.key = "X" @@ -1123,7 +1123,7 @@ class DedupeColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase): cc = DedupeColumnCollection( columns=[("c1", c1), ("c2", c2), ("c3", c3)] ) - ci = cc.as_immutable() + ci = cc.as_readonly() eq_(cc._all_columns, [c1, c2, c3]) eq_(list(cc), [c1, c2, c3]) @@ -1184,7 +1184,7 @@ class DedupeColumnCollectionTest(ColumnCollectionCommon, fixtures.TestBase): def test_dupes_extend(self): cc = DedupeColumnCollection() - ci = cc.as_immutable() + ci = cc.as_readonly() c1, c2a, c3, c2b = ( column("c1"), @@ -3044,7 +3044,7 @@ class TestProperties(fixtures.TestBase): def test_pickle_immuatbleprops(self): data = {"hello": "bla"} - props = util.Properties(data).as_immutable() + props = util.Properties(data).as_readonly() for loader, dumper in picklers(): s = dumper(props) diff --git a/test/profiles.txt b/test/profiles.txt index 074b649f2e..31f72bd166 100644 --- a/test/profiles.txt +++ b/test/profiles.txt @@ -69,17 +69,17 @@ test.aaa_profiling.test_compiler.CompileTest.test_update x86_64_linux_cpython_3. # TEST: test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 174 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 174 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 174 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 174 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 174 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 174 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 174 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 174 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 174 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 170 -test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 173 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 180 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 180 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 180 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 180 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 180 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 180 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 180 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 180 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 180 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 180 +test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 180 # TEST: test.aaa_profiling.test_misc.CacheKeyTest.test_statement_key_is_cached diff --git a/test/sql/test_quote.py b/test/sql/test_quote.py index 9812f84c15..7d90bc67b1 100644 --- a/test/sql/test_quote.py +++ b/test/sql/test_quote.py @@ -252,6 +252,50 @@ class QuoteTest(fixtures.TestBase, AssertsCompiledSQL): eq_(repr(name), repr("姓名")) + def test_literal_column_label_embedded_select_samename_explicit_quote( + self, + ): + col = sql.literal_column("NEEDS QUOTES").label( + quoted_name("NEEDS QUOTES", True) + ) + + self.assert_compile( + select(col).subquery().select(), + 'SELECT anon_1."NEEDS QUOTES" FROM ' + '(SELECT NEEDS QUOTES AS "NEEDS QUOTES") AS anon_1', + ) + + def test_literal_column_label_embedded_select_diffname_explicit_quote( + self, + ): + col = sql.literal_column("NEEDS QUOTES").label( + quoted_name("NEEDS QUOTES_", True) + ) + + self.assert_compile( + select(col).subquery().select(), + 'SELECT anon_1."NEEDS QUOTES_" FROM ' + '(SELECT NEEDS QUOTES AS "NEEDS QUOTES_") AS anon_1', + ) + + def test_literal_column_label_embedded_select_diffname(self): + col = sql.literal_column("NEEDS QUOTES").label("NEEDS QUOTES_") + + self.assert_compile( + select(col).subquery().select(), + 'SELECT anon_1."NEEDS QUOTES_" FROM (SELECT NEEDS QUOTES AS ' + '"NEEDS QUOTES_") AS anon_1', + ) + + def test_literal_column_label_embedded_select_samename(self): + col = sql.literal_column("NEEDS QUOTES").label("NEEDS QUOTES") + + self.assert_compile( + select(col).subquery().select(), + 'SELECT anon_1."NEEDS QUOTES" FROM (SELECT NEEDS QUOTES AS ' + '"NEEDS QUOTES") AS anon_1', + ) + def test_lower_case_names(self): # Create table with quote defaults metadata = MetaData() diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index 138e7a4c6f..ffbab32237 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -1,6 +1,7 @@ import itertools from sqlalchemy import Boolean +from sqlalchemy import column from sqlalchemy import delete from sqlalchemy import exc as sa_exc from sqlalchemy import func @@ -10,9 +11,11 @@ from sqlalchemy import MetaData from sqlalchemy import select from sqlalchemy import Sequence from sqlalchemy import String +from sqlalchemy import table from sqlalchemy import testing from sqlalchemy import type_coerce from sqlalchemy import update +from sqlalchemy.sql.sqltypes import NullType from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import AssertsExecutionResults @@ -88,6 +91,113 @@ class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL): t.c.x, ) + def test_named_expressions_selected_columns(self, table_fixture): + table = table_fixture + stmt = ( + table.insert() + .values(goofy="someOTHERgoofy") + .returning(func.lower(table.c.x).label("goof")) + ) + self.assert_compile( + select(stmt.exported_columns.goof), + "SELECT lower(foo.x) AS goof FROM foo", + ) + + def test_anon_expressions_selected_columns(self, table_fixture): + table = table_fixture + stmt = ( + table.insert() + .values(goofy="someOTHERgoofy") + .returning(func.lower(table.c.x)) + ) + self.assert_compile( + select(stmt.exported_columns[0]), + "SELECT lower(foo.x) AS lower_1 FROM foo", + ) + + def test_returning_fromclause(self): + t = table("t", column("x"), column("y"), column("z")) + stmt = t.update().returning(t) + + self.assert_compile( + stmt, + "UPDATE t SET x=%(x)s, y=%(y)s, z=%(z)s RETURNING t.x, t.y, t.z", + ) + + eq_( + stmt.returning_column_descriptions, + [ + { + "name": "x", + "type": testing.eq_type_affinity(NullType), + "expr": t.c.x, + }, + { + "name": "y", + "type": testing.eq_type_affinity(NullType), + "expr": t.c.y, + }, + { + "name": "z", + "type": testing.eq_type_affinity(NullType), + "expr": t.c.z, + }, + ], + ) + + cte = stmt.cte("c") + + stmt = select(cte.c.z) + self.assert_compile( + stmt, + "WITH c AS (UPDATE t SET x=%(x)s, y=%(y)s, z=%(z)s " + "RETURNING t.x, t.y, t.z) SELECT c.z FROM c", + ) + + def test_returning_inspectable(self): + t = table("t", column("x"), column("y"), column("z")) + + class HasClauseElement: + def __clause_element__(self): + return t + + stmt = update(HasClauseElement()).returning(HasClauseElement()) + + eq_( + stmt.returning_column_descriptions, + [ + { + "name": "x", + "type": testing.eq_type_affinity(NullType), + "expr": t.c.x, + }, + { + "name": "y", + "type": testing.eq_type_affinity(NullType), + "expr": t.c.y, + }, + { + "name": "z", + "type": testing.eq_type_affinity(NullType), + "expr": t.c.z, + }, + ], + ) + + self.assert_compile( + stmt, + "UPDATE t SET x=%(x)s, y=%(y)s, z=%(z)s " + "RETURNING t.x, t.y, t.z", + ) + cte = stmt.cte("c") + + stmt = select(cte.c.z) + self.assert_compile( + stmt, + "WITH c AS (UPDATE t SET x=%(x)s, y=%(y)s, z=%(z)s " + "RETURNING t.x, t.y, t.z) SELECT c.z FROM c", + ) + class ReturningTest(fixtures.TablesTest, AssertsExecutionResults): __requires__ = ("returning",) diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index 4944f2d57c..ca5f43bb6e 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -208,6 +208,24 @@ class SelectableTest( {"name": "table1", "table": table1}, [], ), + ( + table1.alias("some_alias"), + None, + { + "name": "some_alias", + "table": testing.eq_clause_element(table1.alias("some_alias")), + }, + [], + ), + ( + table1.join(table2), + None, + { + "name": None, + "table": testing.eq_clause_element(table1.join(table2)), + }, + [], + ), argnames="entity, cols, expected_entity, expected_returning", ) def test_dml_descriptions(