From: Mike Bayer Date: Sun, 13 Mar 2022 17:37:11 +0000 (-0400) Subject: pep-484 - SQL column operations X-Git-Tag: rel_2_0_0b1~424 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6acf5d2fca4a988a77481b82662174e8015a6b37;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git pep-484 - SQL column operations note we are taking out the ColumnOperartors[SQLCoreOperations] thing; not really clear why that was needed and at the moment it seems I was likely confused. Change-Id: I834b75f9b44f91b97e29f2e1a7b1029bd910e0a1 --- diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index d8009e26c6..5b66c537aa 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -39,7 +39,7 @@ from .. import util from ..sql import compiler from ..sql import util as sql_util -_CompiledCacheType = MutableMapping[Any, Any] +_CompiledCacheType = MutableMapping[Any, "Compiled"] if typing.TYPE_CHECKING: from . import Result @@ -1410,7 +1410,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): "schema_translate_map", None ) - compiled_cache: _CompiledCacheType = execution_options.get( + compiled_cache: Optional[_CompiledCacheType] = execution_options.get( "compiled_cache", self.engine._compiled_cache ) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index c9fb1ebf2c..ba34a0d421 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -33,6 +33,7 @@ from typing import Sequence from typing import Set from typing import Tuple from typing import Type +from typing import TYPE_CHECKING import weakref from . import characteristics @@ -46,11 +47,11 @@ from .interfaces import ExecutionContext from .. import event from .. import exc from .. import pool -from .. import TupleType from .. import types as sqltypes from .. import util from ..sql import compiler from ..sql import expression +from ..sql._typing import is_tuple_type from ..sql.compiler import DDLCompiler from ..sql.compiler import SQLCompiler from ..sql.elements import quoted_name @@ -80,6 +81,7 @@ if typing.TYPE_CHECKING: from ..sql.dml import DMLState from ..sql.elements import BindParameter from ..sql.schema import Column + from ..sql.schema import ColumnDefault from ..sql.type_api import TypeEngine # When we're handed literal SQL, ensure it's a SELECT query @@ -225,12 +227,6 @@ class DefaultDialect(Dialect): is_async = False - CACHE_HIT = CACHE_HIT - CACHE_MISS = CACHE_MISS - CACHING_DISABLED = CACHING_DISABLED - NO_CACHE_KEY = NO_CACHE_KEY - NO_DIALECT_SUPPORT = NO_DIALECT_SUPPORT - # TODO: this is not to be part of 2.0. implement rudimentary binary # literals for SQLite, PostgreSQL, MySQL only within # _Binary.literal_processor @@ -1128,13 +1124,15 @@ class DefaultExecutionContext(ExecutionContext): return self.root_connection.engine @util.memoized_property - def postfetch_cols(self) -> Optional[Sequence[Column[Any]]]: # type: ignore[override] # mypy#4125 # noqa E501 - assert isinstance(self.compiled, SQLCompiler) + def postfetch_cols(self) -> Optional[Sequence[Column[Any]]]: + if TYPE_CHECKING: + assert isinstance(self.compiled, SQLCompiler) return self.compiled.postfetch @util.memoized_property - def prefetch_cols(self) -> Optional[Sequence[Column[Any]]]: # type: ignore[override] # mypy#4125 # noqa E501 - assert isinstance(self.compiled, SQLCompiler) + def prefetch_cols(self) -> Optional[Sequence[Column[Any]]]: + if TYPE_CHECKING: + assert isinstance(self.compiled, SQLCompiler) if self.isinsert: return self.compiled.insert_prefetch elif self.isupdate: @@ -1144,7 +1142,8 @@ class DefaultExecutionContext(ExecutionContext): @util.memoized_property def returning_cols(self) -> Optional[Sequence[Column[Any]]]: - assert isinstance(self.compiled, SQLCompiler) + if TYPE_CHECKING: + assert isinstance(self.compiled, SQLCompiler) return self.compiled.returning @util.memoized_property @@ -1538,9 +1537,8 @@ class DefaultExecutionContext(ExecutionContext): continue if key in self._expanded_parameters: - if bindparam.type._is_tuple_type: - tup_type = cast(TupleType, bindparam.type) - num = len(tup_type.types) + if is_tuple_type(bindparam.type): + num = len(bindparam.type.types) dbtypes = inputsizes[bindparam] generic_inputsizes.extend( ( @@ -1550,7 +1548,7 @@ class DefaultExecutionContext(ExecutionContext): else paramname ), dbtypes[idx % num], - tup_type.types[idx % num], + bindparam.type.types[idx % num], ) for idx, paramname in enumerate( self._expanded_parameters[key] @@ -1758,10 +1756,14 @@ class DefaultExecutionContext(ExecutionContext): # get_update_default() for c in insert_prefetch: if c.default and not c.default.is_sequence and c.default.is_scalar: + if TYPE_CHECKING: + assert isinstance(c.default, ColumnDefault) scalar_defaults[c] = c.default.arg for c in update_prefetch: if c.onupdate and c.onupdate.is_scalar: + if TYPE_CHECKING: + assert isinstance(c.onupdate, ColumnDefault) scalar_defaults[c] = c.onupdate.arg for param in self.compiled_parameters: @@ -1793,6 +1795,8 @@ class DefaultExecutionContext(ExecutionContext): for c in compiled.insert_prefetch: if c.default and not c.default.is_sequence and c.default.is_scalar: + if TYPE_CHECKING: + assert isinstance(c.default, ColumnDefault) val = c.default.arg else: val = self.get_insert_default(c) diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index e13295d6d6..3ca30d1bc9 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -561,6 +561,12 @@ class Dialect(EventTarget): """ + CACHE_HIT = CacheStats.CACHE_HIT + CACHE_MISS = CacheStats.CACHE_MISS + CACHING_DISABLED = CacheStats.CACHING_DISABLED + NO_CACHE_KEY = CacheStats.NO_CACHE_KEY + NO_DIALECT_SUPPORT = CacheStats.NO_DIALECT_SUPPORT + dispatch: dispatcher[Dialect] name: str @@ -2243,11 +2249,11 @@ class ExecutionContext: executemany: bool """True if the parameters have determined this to be an executemany""" - prefetch_cols: Optional[Sequence[Column[Any]]] + prefetch_cols: util.generic_fn_descriptor[Optional[Sequence[Column[Any]]]] """a list of Column objects for which a client-side default was fired off. Applies to inserts and updates.""" - postfetch_cols: Optional[Sequence[Column[Any]]] + postfetch_cols: util.generic_fn_descriptor[Optional[Sequence[Column[Any]]]] """a list of Column objects for which a server-side default or inline SQL expression value was fired off. Applies to inserts and updates.""" diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index e490a4f03d..4d2b1d8b69 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -54,7 +54,6 @@ from ..orm.base import SQLORMOperations from ..sql import operators from ..sql import or_ from ..sql.elements import SQLCoreOperations -from ..sql.operators import ColumnOperators from ..util.typing import Literal from ..util.typing import Protocol from ..util.typing import Self @@ -492,9 +491,7 @@ _SelfAssociationProxyInstance = TypeVar( ) -class AssociationProxyInstance( - SQLORMOperations[_T], ColumnOperators[SQLORMOperations[_T]] -): +class AssociationProxyInstance(SQLORMOperations[_T]): """A per-class object that serves class- and object-specific results. This is used by :class:`.AssociationProxy` when it is invoked diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 8d5fb91d08..c5b0affd2e 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -10,8 +10,8 @@ from __future__ import annotations import typing from typing import Any from typing import Collection +from typing import Dict from typing import List -from typing import Mapping from typing import Optional from typing import overload from typing import Set @@ -711,16 +711,61 @@ def with_loader_criteria( ) +@overload +def relationship( + argument: str, + secondary=..., + *, + uselist: bool = ..., + collection_class: Literal[None] = ..., + primaryjoin=..., + secondaryjoin=..., + back_populates=..., + **kw: Any, +) -> Relationship[Any]: + ... + + +@overload +def relationship( + argument: str, + secondary=..., + *, + uselist: bool = ..., + collection_class: Type[Set] = ..., + primaryjoin=..., + secondaryjoin=..., + back_populates=..., + **kw: Any, +) -> Relationship[Set[Any]]: + ... + + +@overload +def relationship( + argument: str, + secondary=..., + *, + uselist: bool = ..., + collection_class: Type[List] = ..., + primaryjoin=..., + secondaryjoin=..., + back_populates=..., + **kw: Any, +) -> Relationship[List[Any]]: + ... + + @overload def relationship( argument: Optional[_RelationshipArgumentType[_T]], - secondary=None, + secondary=..., *, - uselist: Literal[False] = None, - collection_class: Literal[None] = None, - primaryjoin=None, - secondaryjoin=None, - back_populates=None, + uselist: Literal[False] = ..., + collection_class: Literal[None] = ..., + primaryjoin=..., + secondaryjoin=..., + back_populates=..., **kw: Any, ) -> Relationship[_T]: ... @@ -729,13 +774,13 @@ def relationship( @overload def relationship( argument: Optional[_RelationshipArgumentType[_T]], - secondary=None, + secondary=..., *, - uselist: Literal[True] = None, - collection_class: Literal[None] = None, - primaryjoin=None, - secondaryjoin=None, - back_populates=None, + uselist: Literal[True] = ..., + collection_class: Literal[None] = ..., + primaryjoin=..., + secondaryjoin=..., + back_populates=..., **kw: Any, ) -> Relationship[List[_T]]: ... @@ -744,13 +789,13 @@ def relationship( @overload def relationship( argument: Optional[_RelationshipArgumentType[_T]], - secondary=None, + secondary=..., *, - uselist: Union[Literal[None], Literal[True]] = None, - collection_class: Type[List] = None, - primaryjoin=None, - secondaryjoin=None, - back_populates=None, + uselist: Union[Literal[None], Literal[True]] = ..., + collection_class: Type[List] = ..., + primaryjoin=..., + secondaryjoin=..., + back_populates=..., **kw: Any, ) -> Relationship[List[_T]]: ... @@ -759,13 +804,13 @@ def relationship( @overload def relationship( argument: Optional[_RelationshipArgumentType[_T]], - secondary=None, + secondary=..., *, - uselist: Union[Literal[None], Literal[True]] = None, - collection_class: Type[Set] = None, - primaryjoin=None, - secondaryjoin=None, - back_populates=None, + uselist: Union[Literal[None], Literal[True]] = ..., + collection_class: Type[Set] = ..., + primaryjoin=..., + secondaryjoin=..., + back_populates=..., **kw: Any, ) -> Relationship[Set[_T]]: ... @@ -774,26 +819,26 @@ def relationship( @overload def relationship( argument: Optional[_RelationshipArgumentType[_T]], - secondary=None, + secondary=..., *, - uselist: Union[Literal[None], Literal[True]] = None, - collection_class: Type[Mapping[Any, Any]] = None, - primaryjoin=None, - secondaryjoin=None, - back_populates=None, + uselist: Union[Literal[None], Literal[True]] = ..., + collection_class: Type[Dict[Any, Any]] = ..., + primaryjoin=..., + secondaryjoin=..., + back_populates=..., **kw: Any, -) -> Relationship[Mapping[Any, _T]]: +) -> Relationship[Dict[Any, _T]]: ... @overload def relationship( argument: _RelationshipArgumentType[_T], - secondary=None, + secondary=..., *, - uselist: Literal[None] = None, - collection_class: Literal[None] = None, - primaryjoin=None, + uselist: Literal[None] = ..., + collection_class: Literal[None] = ..., + primaryjoin=..., secondaryjoin=None, back_populates=None, **kw: Any, @@ -803,14 +848,14 @@ def relationship( @overload def relationship( - argument: Optional[_RelationshipArgumentType[_T]] = None, - secondary=None, + argument: Optional[_RelationshipArgumentType[_T]] = ..., + secondary=..., *, - uselist: Literal[True] = None, - collection_class: Any = None, - primaryjoin=None, - secondaryjoin=None, - back_populates=None, + uselist: Literal[True] = ..., + collection_class: Any = ..., + primaryjoin=..., + secondaryjoin=..., + back_populates=..., **kw: Any, ) -> Relationship[Any]: ... @@ -818,14 +863,14 @@ def relationship( @overload def relationship( - argument: Literal[None] = None, - secondary=None, + argument: Literal[None] = ..., + secondary=..., *, - uselist: Optional[bool] = None, - collection_class: Any = None, - primaryjoin=None, - secondaryjoin=None, - back_populates=None, + uselist: Optional[bool] = ..., + collection_class: Any = ..., + primaryjoin=..., + secondaryjoin=..., + back_populates=..., **kw: Any, ) -> Relationship[Any]: ... diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index c4afdb3a9e..ce3a645adb 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -18,6 +18,7 @@ from __future__ import annotations from collections import namedtuple import operator +import typing from typing import Any from typing import List from typing import NamedTuple @@ -65,6 +66,9 @@ from ..sql import roles from ..sql import traversals from ..sql import visitors +if typing.TYPE_CHECKING: + from ..sql.elements import ColumnElement + _T = TypeVar("_T") @@ -84,6 +88,7 @@ class QueryableAttribute( roles.JoinTargetRole, roles.OnClauseRole, roles.ColumnsClauseRole, + roles.ExpressionElementRole[_T], sql_base.Immutable, sql_base.MemoizedHasCacheKey, ): @@ -265,7 +270,7 @@ class QueryableAttribute( def _annotations(self): return self.__clause_element__()._annotations - def __clause_element__(self): + def __clause_element__(self) -> ColumnElement[_T]: return self.expression @property diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 04fc07f61b..00ddbcca72 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -367,9 +367,7 @@ class MapperProperty( @inspection._self_inspects -class PropComparator( - SQLORMOperations[_T], operators.ColumnOperators[SQLORMOperations] -): +class PropComparator(SQLORMOperations[_T]): r"""Defines SQL operations for ORM mapped attributes. SQLAlchemy allows for operators to @@ -519,7 +517,7 @@ class PropComparator( else: return self._adapt_to_entity._adapt_element - @property + @util.non_memoized_property def info(self): return self.property.info diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 9f9ca90cb4..c01825b6d6 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -41,7 +41,6 @@ from .. import log from .. import sql from .. import util from ..sql import coercions -from ..sql import operators from ..sql import roles from ..sql import sqltypes from ..sql.schema import Column @@ -413,7 +412,6 @@ class ColumnProperty( class MappedColumn( SQLCoreOperations[_T], - operators.ColumnOperators[SQLCoreOperations], _IntrospectsAnnotations, _MapsColumns[_T], ): diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index d4faf10e33..980002776e 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -64,6 +64,7 @@ from ..util.typing import is_origin_of if typing.TYPE_CHECKING: from .mapper import Mapper from ..engine import Row + from ..sql._typing import _PropagateAttrsType from ..sql.selectable import Alias _T = TypeVar("_T", bound=Any) @@ -1238,7 +1239,7 @@ class Bundle( is_bundle = True - _propagate_attrs = util.immutabledict() + _propagate_attrs: _PropagateAttrsType = util.immutabledict() def __init__(self, name, *exprs, **kw): r"""Construct a new :class:`.Bundle`. diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index 9d15cdcc3a..770fbe40c9 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -9,17 +9,19 @@ from __future__ import annotations import typing from typing import Any -from typing import cast as _typing_cast +from typing import Callable +from typing import Iterable +from typing import Mapping from typing import Optional from typing import overload -from typing import Type +from typing import Sequence +from typing import Tuple as typing_Tuple from typing import TypeVar from typing import Union from . import coercions -from . import operators from . import roles -from .base import NO_ARG +from .base import _NoArg from .coercions import _document_text_coercion from .elements import BindParameter from .elements import BooleanClauseList @@ -35,18 +37,20 @@ from .elements import FunctionFilter from .elements import Label from .elements import Null from .elements import Over -from .elements import SQLCoreOperations from .elements import TextClause from .elements import True_ from .elements import Tuple from .elements import TypeCoerce from .elements import UnaryExpression from .elements import WithinGroup +from .functions import FunctionElement +from ..util.typing import Literal if typing.TYPE_CHECKING: - from elements import BinaryExpression - from . import sqltypes + from ._typing import _ColumnExpression + from ._typing import _TypeEngineArgument + from .elements import BinaryExpression from .functions import FunctionElement from .selectable import FromClause from .type_api import TypeEngine @@ -54,7 +58,7 @@ if typing.TYPE_CHECKING: _T = TypeVar("_T") -def all_(expr): +def all_(expr: _ColumnExpression[_T]) -> CollectionAggregate[_T]: """Produce an ALL expression. For dialects such as that of PostgreSQL, this operator applies @@ -108,7 +112,7 @@ def all_(expr): return CollectionAggregate._create_all(expr) -def and_(*clauses): +def and_(*clauses: _ColumnExpression[bool]) -> BooleanClauseList: r"""Produce a conjunction of expressions joined by ``AND``. E.g.:: @@ -169,7 +173,7 @@ def and_(*clauses): return BooleanClauseList.and_(*clauses) -def any_(expr): +def any_(expr: _ColumnExpression[_T]) -> CollectionAggregate[_T]: """Produce an ANY expression. For dialects such as that of PostgreSQL, this operator applies @@ -223,7 +227,7 @@ def any_(expr): return CollectionAggregate._create_any(expr) -def asc(column): +def asc(column: _ColumnExpression[_T]) -> UnaryExpression[_T]: """Produce an ascending ``ORDER BY`` clause element. e.g.:: @@ -261,7 +265,9 @@ def asc(column): return UnaryExpression._create_asc(column) -def collate(expression, collation): +def collate( + expression: _ColumnExpression[str], collation: str +) -> BinaryExpression[str]: """Return the clause ``expression COLLATE collation``. e.g.:: @@ -282,7 +288,12 @@ def collate(expression, collation): return CollationClause._create_collation_expression(expression, collation) -def between(expr, lower_bound, upper_bound, symmetric=False): +def between( + expr: _ColumnExpression[_T], + lower_bound: Any, + upper_bound: Any, + symmetric: bool = False, +) -> BinaryExpression[bool]: """Produce a ``BETWEEN`` predicate clause. E.g.:: @@ -338,7 +349,9 @@ def between(expr, lower_bound, upper_bound, symmetric=False): return expr.between(lower_bound, upper_bound, symmetric=symmetric) -def outparam(key, type_=None): +def outparam( + key: str, type_: Optional[TypeEngine[_T]] = None +) -> BindParameter[_T]: """Create an 'OUT' parameter for usage in functions (stored procedures), for databases which support them. @@ -352,16 +365,16 @@ def outparam(key, type_=None): @overload -def not_(clause: "BinaryExpression[_T]") -> "BinaryExpression[_T]": +def not_(clause: BinaryExpression[_T]) -> BinaryExpression[_T]: ... @overload -def not_(clause: "ColumnElement[_T]") -> "UnaryExpression[_T]": +def not_(clause: _ColumnExpression[_T]) -> ColumnElement[_T]: ... -def not_(clause: "ColumnElement[_T]") -> "ColumnElement[_T]": +def not_(clause: _ColumnExpression[_T]) -> ColumnElement[_T]: """Return a negation of the given clause, i.e. ``NOT(clause)``. The ``~`` operator is also overloaded on all @@ -370,29 +383,21 @@ def not_(clause: "ColumnElement[_T]") -> "ColumnElement[_T]": """ - return operators.inv( - _typing_cast( - "ColumnElement[_T]", - coercions.expect(roles.ExpressionElementRole, clause), - ) - ) + return coercions.expect(roles.ExpressionElementRole, clause).__invert__() def bindparam( - key, - value=NO_ARG, - type_: Optional[Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]] = None, - unique=False, - required=NO_ARG, - quote=None, - callable_=None, - expanding=False, - isoutparam=False, - literal_execute=False, - _compared_to_operator=None, - _compared_to_type=None, - _is_crud=False, -) -> "BindParameter[_T]": + key: str, + value: Any = _NoArg.NO_ARG, + type_: Optional[TypeEngine[_T]] = None, + unique: bool = False, + required: Union[bool, Literal[_NoArg.NO_ARG]] = _NoArg.NO_ARG, + quote: Optional[bool] = None, + callable_: Optional[Callable[[], Any]] = None, + expanding: bool = False, + isoutparam: bool = False, + literal_execute: bool = False, +) -> BindParameter[_T]: r"""Produce a "bound expression". The return value is an instance of :class:`.BindParameter`; this @@ -636,13 +641,16 @@ def bindparam( expanding, isoutparam, literal_execute, - _compared_to_operator, - _compared_to_type, - _is_crud, ) -def case(*whens, value=None, else_=None) -> "Case[Any]": +def case( + *whens: Union[ + typing_Tuple[_ColumnExpression[bool], Any], Mapping[Any, Any] + ], + value: Optional[Any] = None, + else_: Optional[Any] = None, +) -> Case[Any]: r"""Produce a ``CASE`` expression. The ``CASE`` construct in SQL is a conditional object that @@ -767,9 +775,9 @@ def case(*whens, value=None, else_=None) -> "Case[Any]": def cast( - expression: ColumnElement, - type_: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], -) -> "Cast[_T]": + expression: _ColumnExpression[Any], + type_: _TypeEngineArgument[_T], +) -> Cast[_T]: r"""Produce a ``CAST`` expression. :func:`.cast` returns an instance of :class:`.Cast`. @@ -826,10 +834,10 @@ def cast( def column( text: str, - type_: Optional[Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]] = None, + type_: Optional[_TypeEngineArgument[_T]] = None, is_literal: bool = False, - _selectable: Optional["FromClause"] = None, -) -> "ColumnClause[_T]": + _selectable: Optional[FromClause] = None, +) -> ColumnClause[_T]: """Produce a :class:`.ColumnClause` object. The :class:`.ColumnClause` is a lightweight analogue to the @@ -921,12 +929,10 @@ def column( :ref:`sqlexpression_literal_column` """ - self = ColumnClause.__new__(ColumnClause) - self.__init__(text, type_, is_literal, _selectable) - return self + return ColumnClause(text, type_, is_literal, _selectable) -def desc(column): +def desc(column: _ColumnExpression[_T]) -> UnaryExpression[_T]: """Produce a descending ``ORDER BY`` clause element. e.g.:: @@ -965,7 +971,7 @@ def desc(column): return UnaryExpression._create_desc(column) -def distinct(expr): +def distinct(expr: _ColumnExpression[_T]) -> UnaryExpression[_T]: """Produce an column-expression-level unary ``DISTINCT`` clause. This applies the ``DISTINCT`` keyword to an individual column @@ -1004,7 +1010,7 @@ def distinct(expr): return UnaryExpression._create_distinct(expr) -def extract(field: str, expr: ColumnElement) -> "Extract[sqltypes.Integer]": +def extract(field: str, expr: _ColumnExpression[Any]) -> Extract: """Return a :class:`.Extract` construct. This is typically available as :func:`.extract` @@ -1045,7 +1051,7 @@ def extract(field: str, expr: ColumnElement) -> "Extract[sqltypes.Integer]": return Extract(field, expr) -def false(): +def false() -> False_: """Return a :class:`.False_` construct. E.g.:: @@ -1083,7 +1089,9 @@ def false(): return False_._instance() -def funcfilter(func, *criterion) -> "FunctionFilter": +def funcfilter( + func: FunctionElement[_T], *criterion: _ColumnExpression[bool] +) -> FunctionFilter[_T]: """Produce a :class:`.FunctionFilter` object against a function. Used against aggregate and window functions, @@ -1114,8 +1122,8 @@ def funcfilter(func, *criterion) -> "FunctionFilter": def label( name: str, - element: ColumnElement[_T], - type_: Optional[Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]] = None, + element: _ColumnExpression[_T], + type_: Optional[_TypeEngineArgument[_T]] = None, ) -> "Label[_T]": """Return a :class:`Label` object for the given :class:`_expression.ColumnElement`. @@ -1135,13 +1143,13 @@ def label( return Label(name, element, type_) -def null(): +def null() -> Null: """Return a constant :class:`.Null` construct.""" return Null._instance() -def nulls_first(column): +def nulls_first(column: _ColumnExpression[_T]) -> UnaryExpression[_T]: """Produce the ``NULLS FIRST`` modifier for an ``ORDER BY`` expression. :func:`.nulls_first` is intended to modify the expression produced @@ -1185,7 +1193,7 @@ def nulls_first(column): return UnaryExpression._create_nulls_first(column) -def nulls_last(column): +def nulls_last(column: _ColumnExpression[_T]) -> UnaryExpression[_T]: """Produce the ``NULLS LAST`` modifier for an ``ORDER BY`` expression. :func:`.nulls_last` is intended to modify the expression produced @@ -1229,7 +1237,7 @@ def nulls_last(column): return UnaryExpression._create_nulls_last(column) -def or_(*clauses: SQLCoreOperations) -> BooleanClauseList: +def or_(*clauses: _ColumnExpression[bool]) -> BooleanClauseList: """Produce a conjunction of expressions joined by ``OR``. E.g.:: @@ -1281,12 +1289,16 @@ def or_(*clauses: SQLCoreOperations) -> BooleanClauseList: def over( - element: "FunctionElement[_T]", - partition_by=None, - order_by=None, - range_=None, - rows=None, -) -> "Over[_T]": + element: FunctionElement[_T], + partition_by: Optional[ + Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]] + ] = None, + order_by: Optional[ + Union[Iterable[_ColumnExpression[Any]], _ColumnExpression[Any]] + ] = None, + range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, +) -> Over[_T]: r"""Produce an :class:`.Over` object against a function. Used against aggregate or so-called "window" functions, @@ -1373,7 +1385,7 @@ def over( @_document_text_coercion("text", ":func:`.text`", ":paramref:`.text.text`") -def text(text): +def text(text: str) -> TextClause: r"""Construct a new :class:`_expression.TextClause` clause, representing a textual SQL string directly. @@ -1451,7 +1463,7 @@ def text(text): return TextClause(text) -def true(): +def true() -> True_: """Return a constant :class:`.True_` construct. E.g.:: @@ -1489,7 +1501,10 @@ def true(): return True_._instance() -def tuple_(*clauses: roles.ExpressionElementRole, types=None) -> "Tuple": +def tuple_( + *clauses: _ColumnExpression[Any], + types: Optional[Sequence[_TypeEngineArgument[Any]]] = None, +) -> Tuple: """Return a :class:`.Tuple`. Main usage is to produce a composite IN construct using @@ -1516,9 +1531,9 @@ def tuple_(*clauses: roles.ExpressionElementRole, types=None) -> "Tuple": def type_coerce( - expression: "ColumnElement", - type_: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], -) -> "TypeCoerce[_T]": + expression: _ColumnExpression[Any], + type_: _TypeEngineArgument[_T], +) -> TypeCoerce[_T]: r"""Associate a SQL expression with a particular type, without rendering ``CAST``. @@ -1597,8 +1612,8 @@ def type_coerce( def within_group( - element: "FunctionElement[_T]", *order_by: roles.OrderByRole -) -> "WithinGroup[_T]": + element: FunctionElement[_T], *order_by: _ColumnExpression[Any] +) -> WithinGroup[_T]: r"""Produce a :class:`.WithinGroup` object against a function. Used against so-called "ordered set aggregate" and "hypothetical diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 69e4645fa6..389f7e8d00 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -1,14 +1,58 @@ from __future__ import annotations +from typing import Any from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from . import roles +from .. import util from ..inspection import Inspectable +if TYPE_CHECKING: + from .elements import quoted_name + from .schema import DefaultGenerator + from .schema import Sequence + from .selectable import FromClause + from .selectable import NamedFromClause + from .selectable import TableClause + from .sqltypes import TupleType + from .type_api import TypeEngine + from ..util.typing import TypeGuard + +_T = TypeVar("_T", bound=Any) + _ColumnsClauseElement = Union[ - roles.ColumnsClauseRole, Type, Inspectable[roles.HasClauseElement] + roles.ColumnsClauseRole, + Type, + Inspectable[roles.HasColumnElementClauseElement], ] _FromClauseElement = Union[ roles.FromClauseRole, Type, Inspectable[roles.HasFromClauseElement] ] + +_ColumnExpression = Union[ + roles.ExpressionElementRole[_T], + Inspectable[roles.HasColumnElementClauseElement], +] + +_PropagateAttrsType = util.immutabledict[str, Any] + +_TypeEngineArgument = Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] + + +def is_named_from_clause(t: FromClause) -> TypeGuard[NamedFromClause]: + return t.named_with_column + + +def has_schema_attr(t: FromClause) -> TypeGuard[TableClause]: + return hasattr(t, "schema") + + +def is_quoted_name(s: str) -> TypeGuard[quoted_name]: + return hasattr(s, "quote") + + +def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]: + return t._is_tuple_type diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 7afc2de977..f37ae9a60d 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -18,11 +18,11 @@ from typing import Any from typing import Callable from typing import cast from typing import Dict +from typing import FrozenSet from typing import Mapping from typing import Optional from typing import overload from typing import Sequence -from typing import Set from typing import Tuple from typing import Type from typing import TypeVar @@ -53,7 +53,9 @@ class SupportsAnnotations(ExternallyTraversible): __slots__ = () _annotations: util.immutabledict[str, Any] = EMPTY_ANNOTATIONS - proxy_set: Set[SupportsAnnotations] + + proxy_set: util.generic_fn_descriptor[FrozenSet[Any]] + _is_immutable: bool def _annotate(self, values: _AnnotationDict) -> SupportsAnnotations: diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index a408a010a0..29f9028c8b 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -13,16 +13,20 @@ from __future__ import annotations import collections.abc as collections_abc +from enum import Enum from functools import reduce import itertools from itertools import zip_longest import operator import re import typing +from typing import Any +from typing import Iterable +from typing import List from typing import MutableMapping from typing import Optional from typing import Sequence -from typing import Set +from typing import Tuple from typing import TypeVar from . import roles @@ -38,23 +42,33 @@ from .. import util from ..util import HasMemoized as HasMemoized from ..util import hybridmethod from ..util import typing as compat_typing +from ..util.typing import Self if typing.TYPE_CHECKING: + from .elements import BindParameter from .elements import ColumnElement from ..engine import Connection from ..engine import Result + from ..engine.base import _CompiledCacheType from ..engine.interfaces import _CoreMultiExecuteParams from ..engine.interfaces import _ExecuteOptions from ..engine.interfaces import _ExecuteOptionsParameter from ..engine.interfaces import _ImmutableExecuteOptions + from ..engine.interfaces import _SchemaTranslateMapType from ..engine.interfaces import CacheStats - + from ..engine.interfaces import Compiled + from ..engine.interfaces import Dialect coercions = None elements = None type_api = None -NO_ARG = util.symbol("NO_ARG") + +class _NoArg(Enum): + NO_ARG = 0 + + +NO_ARG = _NoArg.NO_ARG # if I use sqlalchemy.util.typing, which has the exact same # symbols, mypy reports: "error: _Fn? not callable" @@ -74,10 +88,12 @@ class Immutable: def params(self, *optionaldict, **kwargs): raise NotImplementedError("Immutable objects do not support copying") - def _clone(self, **kw): + def _clone(self: Self, **kw: Any) -> Self: return self - def _copy_internals(self, **kw): + def _copy_internals( + self, omit_attrs: Iterable[str] = (), **kw: Any + ) -> None: pass @@ -88,8 +104,6 @@ class SingletonConstant(Immutable): _singleton: SingletonConstant - proxy_set: Set[ColumnElement] - def __new__(cls, *arg, **kw): return cls._singleton @@ -877,12 +891,15 @@ class Executable(roles.StatementRole, Generative): def _compile_w_cache( self, dialect: Dialect, - compiled_cache: Optional[_CompiledCacheType] = None, - column_keys: Optional[Sequence[str]] = None, + *, + compiled_cache: Optional[_CompiledCacheType], + column_keys: List[str], for_executemany: bool = False, schema_translate_map: Optional[_SchemaTranslateMapType] = None, **kw: Any, - ) -> Tuple[Compiled, _SingleExecuteParams, CacheStats]: + ) -> Tuple[ + Compiled, Optional[Sequence[BindParameter[Any]]], CacheStats + ]: ... def _execute_on_connection( diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py index fca58f98e5..19a232c563 100644 --- a/lib/sqlalchemy/sql/cache_key.py +++ b/lib/sqlalchemy/sql/cache_key.py @@ -11,15 +11,14 @@ import enum from itertools import zip_longest import typing from typing import Any -from typing import cast from typing import Dict from typing import Iterator from typing import List +from typing import MutableMapping from typing import NamedTuple from typing import Optional from typing import Sequence from typing import Tuple -from typing import Type from typing import Union from .visitors import anon_map @@ -91,7 +90,7 @@ class HasCacheKey: __slots__ = () _cache_key_traversal: Union[ - _TraverseInternalsType, Literal[CacheConst.NO_CACHE] + _TraverseInternalsType, Literal[CacheConst.NO_CACHE], Literal[None] ] = NO_CACHE _is_has_cache_key = True @@ -147,11 +146,8 @@ class HasCacheKey: _cache_key_traversal = getattr(cls, "_cache_key_traversal", None) if _cache_key_traversal is None: try: - # check for _traverse_internals, which is part of - # HasTraverseInternals - _cache_key_traversal = cast( - "Type[HasTraverseInternals]", cls - )._traverse_internals + assert issubclass(cls, HasTraverseInternals) + _cache_key_traversal = cls._traverse_internals except AttributeError: cls._generated_cache_key_traversal = NO_CACHE return NO_CACHE @@ -417,7 +413,7 @@ class CacheKey(NamedTuple): def to_offline_string( self, - statement_cache: _CompiledCacheType, + statement_cache: MutableMapping[Any, str], statement: ClauseElement, parameters: _CoreSingleExecuteParams, ) -> str: diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 834bfb75d4..ea17b8e037 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -13,10 +13,12 @@ import re import typing from typing import Any from typing import Any as TODO_Any +from typing import Callable from typing import Dict from typing import List from typing import NoReturn from typing import Optional +from typing import overload from typing import Type from typing import TypeVar @@ -46,9 +48,14 @@ if typing.TYPE_CHECKING: from . import traversals from .elements import ClauseElement from .elements import ColumnClause + from .elements import ColumnElement + from .elements import SQLCoreOperations + _SR = TypeVar("_SR", bound=roles.SQLRole) +_F = TypeVar("_F", bound=Callable[..., Any]) _StringOnlyR = TypeVar("_StringOnlyR", bound=roles.StringRole) +_T = TypeVar("_T", bound=Any) def _is_literal(element): @@ -104,7 +111,9 @@ def _deep_is_literal(element): ) -def _document_text_coercion(paramname, meth_rst, param_rst): +def _document_text_coercion( + paramname: str, meth_rst: str, param_rst: str +) -> Callable[[_F], _F]: return util.add_parameter_text( paramname, ( @@ -132,15 +141,50 @@ def _expression_collection_was_a_list(attrname, fnname, args): return args -# TODO; would like to have overloads here, however mypy is being extremely -# pedantic about them. not sure why pylance is OK with them. +@overload +def expect( + role: Type[roles.TruncatedLabelRole], + element: Any, + *, + apply_propagate_attrs: Optional[ClauseElement] = None, + argname: Optional[str] = None, + post_inspect: bool = False, + **kw: Any, +) -> str: + ... + + +@overload +def expect( + role: Type[roles.ExpressionElementRole[_T]], + element: Any, + *, + apply_propagate_attrs: Optional[ClauseElement] = None, + argname: Optional[str] = None, + post_inspect: bool = False, + **kw: Any, +) -> ColumnElement[_T]: + ... +@overload def expect( role: Type[_SR], element: Any, *, - apply_propagate_attrs: Optional["ClauseElement"] = None, + apply_propagate_attrs: Optional[ClauseElement] = None, + argname: Optional[str] = None, + post_inspect: bool = False, + **kw: Any, +) -> TODO_Any: + ... + + +def expect( + role: Type[_SR], + element: Any, + *, + apply_propagate_attrs: Optional[ClauseElement] = None, argname: Optional[str] = None, post_inspect: bool = False, **kw: Any, @@ -220,12 +264,16 @@ def expect( resolved = element else: resolved = element - if ( - apply_propagate_attrs is not None - and not apply_propagate_attrs._propagate_attrs - and resolved._propagate_attrs - ): - apply_propagate_attrs._propagate_attrs = resolved._propagate_attrs + + if apply_propagate_attrs is not None: + if typing.TYPE_CHECKING: + assert isinstance(resolved, (SQLCoreOperations, ClauseElement)) + + if ( + not apply_propagate_attrs._propagate_attrs + and resolved._propagate_attrs + ): + apply_propagate_attrs._propagate_attrs = resolved._propagate_attrs if impl._role_class in resolved.__class__.__mro__: if impl._post_coercion: @@ -620,8 +668,8 @@ class InElementImpl(RoleImpl): element, str ): non_literal_expressions: Dict[ - Optional[operators.ColumnOperators[Any]], - operators.ColumnOperators[Any], + Optional[operators.ColumnOperators], + operators.ColumnOperators, ] = {} element = list(element) for o in element: diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 001710d7bb..91bb0a5c58 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -42,13 +42,14 @@ _T = typing.TypeVar("_T", bound=Any) if typing.TYPE_CHECKING: from .elements import ColumnElement + from .operators import custom_op from .sqltypes import TypeEngine def _boolean_compare( - expr: "ColumnElement", + expr: ColumnElement[Any], op: OperatorType, - obj: roles.BinaryElementRole, + obj: Any, *, negate_op: Optional[OperatorType] = None, reverse: bool = False, @@ -59,7 +60,6 @@ def _boolean_compare( ] = None, **kwargs: Any, ) -> BinaryExpression[bool]: - if result_type is None: result_type = type_api.BOOLEANTYPE @@ -143,7 +143,14 @@ def _boolean_compare( ) -def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, **kw): +def _custom_op_operate( + expr: ColumnElement[Any], + op: custom_op[Any], + obj: Any, + reverse: bool = False, + result_type: Optional[TypeEngine[Any]] = None, + **kw: Any, +) -> ColumnElement[Any]: if result_type is None: if op.return_type: result_type = op.return_type @@ -156,11 +163,11 @@ def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, **kw): def _binary_operate( - expr: "ColumnElement", + expr: ColumnElement[Any], op: OperatorType, obj: roles.BinaryElementRole, *, - reverse=False, + reverse: bool = False, result_type: Optional[ Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] ] = None, @@ -184,7 +191,9 @@ def _binary_operate( return BinaryExpression(left, right, op, type_=result_type, modifiers=kw) -def _conjunction_operate(expr, op, other, **kw) -> "ColumnElement": +def _conjunction_operate( + expr: ColumnElement[Any], op: OperatorType, other, **kw +) -> ColumnElement[Any]: if op is operators.and_: return and_(expr, other) elif op is operators.or_: @@ -193,11 +202,19 @@ def _conjunction_operate(expr, op, other, **kw) -> "ColumnElement": raise NotImplementedError() -def _scalar(expr, op, fn, **kw) -> "ColumnElement": +def _scalar( + expr: ColumnElement[Any], op: OperatorType, fn, **kw +) -> ColumnElement[Any]: return fn(expr) -def _in_impl(expr, op, seq_or_selectable, negate_op, **kw) -> "ColumnElement": +def _in_impl( + expr: ColumnElement[Any], + op: OperatorType, + seq_or_selectable, + negate_op: OperatorType, + **kw, +) -> ColumnElement[Any]: seq_or_selectable = coercions.expect( roles.InElementRole, seq_or_selectable, expr=expr, operator=op ) @@ -209,7 +226,9 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw) -> "ColumnElement": ) -def _getitem_impl(expr, op, other, **kw) -> "ColumnElement": +def _getitem_impl( + expr: ColumnElement[Any], op: OperatorType, other, **kw +) -> ColumnElement[Any]: if isinstance(expr.type, type_api.INDEXABLE): other = coercions.expect( roles.BinaryElementRole, other, expr=expr, operator=op @@ -219,13 +238,17 @@ def _getitem_impl(expr, op, other, **kw) -> "ColumnElement": _unsupported_impl(expr, op, other, **kw) -def _unsupported_impl(expr, op, *arg, **kw) -> NoReturn: +def _unsupported_impl( + expr: ColumnElement[Any], op: OperatorType, *arg, **kw +) -> NoReturn: raise NotImplementedError( "Operator '%s' is not supported on " "this expression" % op.__name__ ) -def _inv_impl(expr, op, **kw) -> "ColumnElement": +def _inv_impl( + expr: ColumnElement[Any], op: OperatorType, **kw +) -> ColumnElement[Any]: """See :meth:`.ColumnOperators.__inv__`.""" # undocumented element currently used by the ORM for @@ -236,12 +259,16 @@ def _inv_impl(expr, op, **kw) -> "ColumnElement": return expr._negate() -def _neg_impl(expr, op, **kw) -> "ColumnElement": +def _neg_impl( + expr: ColumnElement[Any], op: OperatorType, **kw +) -> ColumnElement[Any]: """See :meth:`.ColumnOperators.__neg__`.""" return UnaryExpression(expr, operator=operators.neg, type_=expr.type) -def _match_impl(expr, op, other, **kw) -> "ColumnElement": +def _match_impl( + expr: ColumnElement[Any], op: OperatorType, other, **kw +) -> ColumnElement[Any]: """See :meth:`.ColumnOperators.match`.""" return _boolean_compare( @@ -261,14 +288,18 @@ def _match_impl(expr, op, other, **kw) -> "ColumnElement": ) -def _distinct_impl(expr, op, **kw) -> "ColumnElement": +def _distinct_impl( + expr: ColumnElement[Any], op: OperatorType, **kw +) -> ColumnElement[Any]: """See :meth:`.ColumnOperators.distinct`.""" return UnaryExpression( expr, operator=operators.distinct_op, type_=expr.type ) -def _between_impl(expr, op, cleft, cright, **kw) -> "ColumnElement": +def _between_impl( + expr: ColumnElement[Any], op: OperatorType, cleft, cright, **kw +) -> ColumnElement[Any]: """See :meth:`.ColumnOperators.between`.""" return BinaryExpression( expr, @@ -297,11 +328,15 @@ def _between_impl(expr, op, cleft, cright, **kw) -> "ColumnElement": ) -def _collate_impl(expr, op, collation, **kw) -> "ColumnElement": +def _collate_impl( + expr: ColumnElement[Any], op: OperatorType, collation, **kw +) -> ColumnElement[Any]: return CollationClause._create_collation_expression(expr, collation) -def _regexp_match_impl(expr, op, pattern, flags, **kw) -> "ColumnElement": +def _regexp_match_impl( + expr: ColumnElement[Any], op: OperatorType, pattern, flags, **kw +) -> ColumnElement[Any]: if flags is not None: flags = coercions.expect( roles.BinaryElementRole, @@ -322,8 +357,13 @@ def _regexp_match_impl(expr, op, pattern, flags, **kw) -> "ColumnElement": def _regexp_replace_impl( - expr, op, pattern, replacement, flags, **kw -) -> "ColumnElement": + expr: ColumnElement[Any], + op: OperatorType, + pattern, + replacement, + flags, + **kw, +) -> ColumnElement[Any]: replacement = coercions.expect( roles.BinaryElementRole, replacement, @@ -345,7 +385,7 @@ def _regexp_replace_impl( # a mapping of operators with the method they use, along with # additional keyword arguments to be passed operator_lookup: Dict[ - str, Tuple[Callable[..., "ColumnElement"], util.immutabledict] + str, Tuple[Callable[..., ColumnElement[Any]], util.immutabledict] ] = { "and_": (_conjunction_operate, util.EMPTY_DICT), "or_": (_conjunction_operate, util.EMPTY_DICT), diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 08d632afd9..fdb3fc8bbf 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -12,20 +12,28 @@ from __future__ import annotations +from decimal import Decimal +from enum import IntEnum import itertools import operator import re import typing from typing import Any from typing import Callable +from typing import cast from typing import Dict +from typing import FrozenSet from typing import Generic +from typing import Iterable from typing import List +from typing import Mapping from typing import Optional from typing import overload from typing import Sequence -from typing import Text as typing_Text +from typing import Set +from typing import Tuple as typing_Tuple from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union @@ -34,10 +42,15 @@ from . import operators from . import roles from . import traversals from . import type_api +from ._typing import has_schema_attr +from ._typing import is_named_from_clause +from ._typing import is_quoted_name +from ._typing import is_tuple_type from .annotation import Annotated from .annotation import SupportsWrappingAnnotations from .base import _clone from .base import _generative +from .base import _NoArg from .base import Executable from .base import HasMemoized from .base import Immutable @@ -57,30 +70,47 @@ from .. import exc from .. import inspection from .. import util from ..util.langhelpers import TypingOnly +from ..util.typing import Literal if typing.TYPE_CHECKING: - from decimal import Decimal - + from ._typing import _ColumnExpression + from ._typing import _PropagateAttrsType + from ._typing import _TypeEngineArgument + from .cache_key import CacheKey from .compiler import Compiled from .compiler import SQLCompiler + from .functions import FunctionElement from .operators import OperatorType + from .schema import Column + from .schema import DefaultGenerator + from .schema import ForeignKey from .selectable import FromClause + from .selectable import NamedFromClause + from .selectable import ReturnsRows from .selectable import Select - from .sqltypes import Boolean # noqa + from .selectable import TableClause + from .sqltypes import Boolean + from .sqltypes import TupleType from .type_api import TypeEngine + from .visitors import _TraverseInternalsType from ..engine import Connection from ..engine import Dialect from ..engine import Engine from ..engine.base import _CompiledCacheType - from ..engine.base import _SchemaTranslateMapType - + from ..engine.interfaces import _CoreMultiExecuteParams + from ..engine.interfaces import _ExecuteOptions + from ..engine.interfaces import _SchemaTranslateMapType + from ..engine.interfaces import CacheStats + from ..engine.result import Result -_NUMERIC = Union[complex, "Decimal"] +_NUMERIC = Union[complex, Decimal] +_NUMBER = Union[complex, int, Decimal] _T = TypeVar("_T", bound="Any") _OPT = TypeVar("_OPT", bound="Any") _NT = TypeVar("_NT", bound="_NUMERIC") -_ST = TypeVar("_ST", bound="typing_Text") + +_NMT = TypeVar("_NMT", bound="_NUMBER") def literal(value, type_=None): @@ -210,28 +240,27 @@ class CompilerElement(Visitable): """ - if not dialect: + if dialect is None: if bind: dialect = bind.dialect + elif self.stringify_dialect == "default": + default = util.preloaded.engine_default + dialect = default.StrCompileDialect() else: - if self.stringify_dialect == "default": - default = util.preloaded.engine_default - dialect = default.StrCompileDialect() - else: - url = util.preloaded.engine_url - dialect = url.URL.create( - self.stringify_dialect - ).get_dialect()() + url = util.preloaded.engine_url + dialect = url.URL.create( + self.stringify_dialect + ).get_dialect()() return self._compiler(dialect, **kw) - def _compiler(self, dialect, **kw): + def _compiler(self, dialect: Dialect, **kw: Any) -> Compiled: """Return a compiler appropriate for this ClauseElement, given a Dialect.""" return dialect.statement_compiler(dialect, self, **kw) - def __str__(self): + def __str__(self) -> str: return str(self.compile()) @@ -253,16 +282,17 @@ class ClauseElement( __visit_name__ = "clause" - _propagate_attrs = util.immutabledict() + _propagate_attrs: _PropagateAttrsType = util.immutabledict() """like annotations, however these propagate outwards liberally as SQL constructs are built, and are set up at construction time. """ - _from_objects = [] - bind = None - description = None - _is_clone_of = None + @util.memoized_property + def description(self) -> Optional[str]: + return None + + _is_clone_of: Optional[ClauseElement] = None is_clause_element = True is_selectable = False @@ -281,10 +311,25 @@ class ClauseElement( _is_singleton_constant = False _is_immutable = False - _order_by_label_element = None + @property + def _order_by_label_element(self) -> Optional[Label[Any]]: + return None _cache_key_traversal = None + negation_clause: ClauseElement + + if typing.TYPE_CHECKING: + + def get_children( + self, omit_attrs: typing_Tuple[str, ...] = ..., **kw: Any + ) -> Iterable[ClauseElement]: + ... + + @util.non_memoized_property + def _from_objects(self) -> List[FromClause]: + return [] + def _set_propagate_attrs(self, values): # usually, self._propagate_attrs is empty here. one case where it's # not is a subquery against ORM select, that is then pulled as a @@ -295,7 +340,7 @@ class ClauseElement( self._propagate_attrs = util.immutabledict(values) return self - def _clone(self: SelfClauseElement, **kw) -> SelfClauseElement: + def _clone(self: SelfClauseElement, **kw: Any) -> SelfClauseElement: """Create a shallow copy of this ClauseElement. This method may be used by a generative API. Its also used as @@ -357,7 +402,7 @@ class ClauseElement( """ s = util.column_set() - f = self + f: Optional[ClauseElement] = self # note this creates a cycle, asserted in test_memusage. however, # turning this into a plain @property adds tends of thousands of method @@ -383,16 +428,26 @@ class ClauseElement( return d def _execute_on_connection( - self, connection, distilled_params, execution_options, _force=False - ): + self, + connection: Connection, + distilled_params: _CoreMultiExecuteParams, + execution_options: _ExecuteOptions, + _force: bool = False, + ) -> Result: if _force or self.supports_execution: + if TYPE_CHECKING: + assert isinstance(self, Executable) return connection._execute_clauseelement( self, distilled_params, execution_options ) else: raise exc.ObjectNotExecutableError(self) - def unique_params(self, *optionaldict, **kwargs): + def unique_params( + self: SelfClauseElement, + __optionaldict: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> SelfClauseElement: """Return a copy with :func:`_expression.bindparam` elements replaced. @@ -402,11 +457,13 @@ class ClauseElement( used. """ - return self._replace_params(True, optionaldict, kwargs) + return self._replace_params(True, __optionaldict, kwargs) def params( - self, *optionaldict: Dict[str, Any], **kwargs: Any - ) -> ClauseElement: + self: SelfClauseElement, + __optionaldict: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> SelfClauseElement: """Return a copy with :func:`_expression.bindparam` elements replaced. @@ -421,33 +478,32 @@ class ClauseElement( {'foo':7} """ - return self._replace_params(False, optionaldict, kwargs) + return self._replace_params(False, __optionaldict, kwargs) def _replace_params( - self, + self: SelfClauseElement, unique: bool, optionaldict: Optional[Dict[str, Any]], kwargs: Dict[str, Any], - ) -> ClauseElement: + ) -> SelfClauseElement: - if len(optionaldict) == 1: - kwargs.update(optionaldict[0]) - elif len(optionaldict) > 1: - raise exc.ArgumentError( - "params() takes zero or one positional dictionary argument" - ) + if optionaldict: + kwargs.update(optionaldict) - def visit_bindparam(bind): + def visit_bindparam(bind: BindParameter[Any]) -> None: if bind.key in kwargs: bind.value = kwargs[bind.key] bind.required = False if unique: bind._convert_to_unique() - return cloned_traverse( - self, - {"maintain_key": True, "detect_subquery_cols": True}, - {"bindparam": visit_bindparam}, + return cast( + SelfClauseElement, + cloned_traverse( + self, + {"maintain_key": True, "detect_subquery_cols": True}, + {"bindparam": visit_bindparam}, + ), ) def compare(self, other, **kw): @@ -501,18 +557,26 @@ class ClauseElement( def _compile_w_cache( self, dialect: Dialect, - compiled_cache: Optional[_CompiledCacheType] = None, - column_keys: Optional[List[str]] = None, + *, + compiled_cache: Optional[_CompiledCacheType], + column_keys: List[str], for_executemany: bool = False, schema_translate_map: Optional[_SchemaTranslateMapType] = None, **kw: Any, - ): + ) -> typing_Tuple[ + Compiled, Optional[Sequence[BindParameter[Any]]], CacheStats + ]: + elem_cache_key: Optional[CacheKey] + if compiled_cache is not None and dialect._supports_statement_cache: elem_cache_key = self._generate_cache_key() else: elem_cache_key = None - if elem_cache_key: + if elem_cache_key is not None: + if TYPE_CHECKING: + assert compiled_cache is not None + cache_key, extracted_params = elem_cache_key key = ( dialect, @@ -564,7 +628,7 @@ class ClauseElement( else: return self._negate() - def _negate(self): + def _negate(self) -> ClauseElement: return UnaryExpression( self.self_group(against=operators.inv), operator=operators.inv ) @@ -605,6 +669,9 @@ class DQLDMLClauseElement(ClauseElement): ) -> SQLCompiler: ... + def _compiler(self, dialect: Dialect, **kw: Any) -> SQLCompiler: + ... + class CompilerColumnElement( roles.DMLColumnRole, @@ -621,9 +688,7 @@ class CompilerColumnElement( __slots__ = () -class SQLCoreOperations( - Generic[_T], ColumnOperators["SQLCoreOperations"], TypingOnly -): +class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): __slots__ = () # annotations for comparison methods @@ -631,173 +696,186 @@ class SQLCoreOperations( # redefined with the specific types returned by ColumnElement hierarchies if typing.TYPE_CHECKING: + _propagate_attrs: _PropagateAttrsType + def operate( self, op: OperatorType, *other: Any, **kwargs: Any - ) -> ColumnElement: + ) -> ColumnElement[Any]: ... def reverse_operate( self, op: OperatorType, other: Any, **kwargs: Any - ) -> ColumnElement: + ) -> ColumnElement[Any]: ... def op( self, - opstring: Any, + opstring: str, precedence: int = 0, is_comparison: bool = False, - return_type: Optional[ - Union[Type["TypeEngine[_OPT]"], "TypeEngine[_OPT]"] - ] = None, - python_impl=None, - ) -> Callable[[Any], "BinaryExpression[_OPT]"]: + return_type: Optional[_TypeEngineArgument[_OPT]] = None, + python_impl: Optional[Callable[..., Any]] = None, + ) -> Callable[[Any], BinaryExpression[_OPT]]: ... def bool_op( - self, opstring: Any, precedence: int = 0, python_impl=None - ) -> Callable[[Any], "BinaryExpression[bool]"]: + self, + opstring: str, + precedence: int = 0, + python_impl: Optional[Callable[..., Any]] = None, + ) -> Callable[[Any], BinaryExpression[bool]]: ... - def __and__(self, other: Any) -> "BooleanClauseList": + def __and__(self, other: Any) -> BooleanClauseList: ... - def __or__(self, other: Any) -> "BooleanClauseList": + def __or__(self, other: Any) -> BooleanClauseList: ... - def __invert__(self) -> "UnaryExpression[_T]": + def __invert__(self) -> ColumnElement[_T]: ... - def __lt__(self, other: Any) -> "ColumnElement[bool]": + def __lt__(self, other: Any) -> ColumnElement[bool]: ... - def __le__(self, other: Any) -> "ColumnElement[bool]": + def __le__(self, other: Any) -> ColumnElement[bool]: ... - def __eq__(self, other: Any) -> "ColumnElement[bool]": # type: ignore[override] # noqa: E501 + def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 ... - def __ne__(self, other: Any) -> "ColumnElement[bool]": # type: ignore[override] # noqa: E501 + def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 ... - def is_distinct_from(self, other: Any) -> "ColumnElement[bool]": + def is_distinct_from(self, other: Any) -> ColumnElement[bool]: ... - def is_not_distinct_from(self, other: Any) -> "ColumnElement[bool]": + def is_not_distinct_from(self, other: Any) -> ColumnElement[bool]: ... - def __gt__(self, other: Any) -> "ColumnElement[bool]": + def __gt__(self, other: Any) -> ColumnElement[bool]: ... - def __ge__(self, other: Any) -> "ColumnElement[bool]": + def __ge__(self, other: Any) -> ColumnElement[bool]: ... - def __neg__(self) -> "UnaryExpression[_T]": + def __neg__(self) -> UnaryExpression[_T]: ... - def __contains__(self, other: Any) -> "ColumnElement[bool]": + def __contains__(self, other: Any) -> ColumnElement[bool]: ... - def __getitem__(self, index: Any) -> "ColumnElement": + def __getitem__(self, index: Any) -> ColumnElement[Any]: ... @overload - def concat( - self: "SQLCoreOperations[_ST]", other: Any - ) -> "ColumnElement[_ST]": + def concat(self: _SQO[str], other: Any) -> ColumnElement[str]: ... @overload - def concat(self, other: Any) -> "ColumnElement": + def concat(self, other: Any) -> ColumnElement[Any]: ... - def concat(self, other: Any) -> "ColumnElement": + def concat(self, other: Any) -> ColumnElement[Any]: ... - def like(self, other: Any, escape=None) -> "BinaryExpression[bool]": + def like( + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: ... - def ilike(self, other: Any, escape=None) -> "BinaryExpression[bool]": + def ilike( + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: ... def in_( self, - other: Union[Sequence[Any], "BindParameter", "Select"], - ) -> "BinaryExpression[bool]": + other: Union[Sequence[Any], BindParameter[Any], Select], + ) -> BinaryExpression[bool]: ... def not_in( self, - other: Union[Sequence[Any], "BindParameter", "Select"], - ) -> "BinaryExpression[bool]": + other: Union[Sequence[Any], BindParameter[Any], Select], + ) -> BinaryExpression[bool]: ... def not_like( - self, other: Any, escape=None - ) -> "BinaryExpression[bool]": + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: ... def not_ilike( - self, other: Any, escape=None - ) -> "BinaryExpression[bool]": + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: ... - def is_(self, other: Any) -> "BinaryExpression[bool]": + def is_(self, other: Any) -> BinaryExpression[bool]: ... - def is_not(self, other: Any) -> "BinaryExpression[bool]": + def is_not(self, other: Any) -> BinaryExpression[bool]: ... def startswith( - self, other: Any, escape=None, autoescape=False - ) -> "ColumnElement[bool]": + self, + other: Any, + escape: Optional[str] = None, + autoescape: bool = False, + ) -> ColumnElement[bool]: ... def endswith( - self, other: Any, escape=None, autoescape=False - ) -> "ColumnElement[bool]": + self, + other: Any, + escape: Optional[str] = None, + autoescape: bool = False, + ) -> ColumnElement[bool]: ... - def contains(self, other: Any, **kw: Any) -> "ColumnElement[bool]": + def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]: ... - def match(self, other: Any, **kwargs) -> "ColumnElement[bool]": + def match(self, other: Any, **kwargs: Any) -> ColumnElement[bool]: ... - def regexp_match(self, pattern, flags=None) -> "ColumnElement[bool]": + def regexp_match( + self, pattern: Any, flags: Optional[str] = None + ) -> ColumnElement[bool]: ... def regexp_replace( - self, pattern, replacement, flags=None - ) -> "ColumnElement": + self, pattern: Any, replacement: Any, flags: Optional[str] = None + ) -> ColumnElement[str]: ... - def desc(self) -> "UnaryExpression[_T]": + def desc(self) -> UnaryExpression[_T]: ... - def asc(self) -> "UnaryExpression[_T]": + def asc(self) -> UnaryExpression[_T]: ... - def nulls_first(self) -> "UnaryExpression[_T]": + def nulls_first(self) -> UnaryExpression[_T]: ... - def nulls_last(self) -> "UnaryExpression[_T]": + def nulls_last(self) -> UnaryExpression[_T]: ... - def collate(self, collation) -> "CollationClause": + def collate(self, collation: str) -> CollationClause: ... def between( - self, cleft, cright, symmetric=False - ) -> "ColumnElement[bool]": + self, cleft: Any, cright: Any, symmetric: bool = False + ) -> BinaryExpression[bool]: ... - def distinct(self: "SQLCoreOperations[_T]") -> "UnaryExpression[_T]": + def distinct(self: _SQO[_T]) -> UnaryExpression[_T]: ... - def any_(self) -> "CollectionAggregate": + def any_(self) -> CollectionAggregate[Any]: ... - def all_(self) -> "CollectionAggregate": + def all_(self) -> CollectionAggregate[Any]: ... # numeric overloads. These need more tweaking @@ -807,179 +885,173 @@ class SQLCoreOperations( @overload def __add__( - self: "Union[_SQO[_NT], _SQO[Optional[_NT]]]", - other: "Union[_SQO[Optional[_NT]], _SQO[_NT], _NT]", - ) -> "ColumnElement[_NT]": + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: ... @overload def __add__( - self: "Union[_SQO[_NT], _SQO[Optional[_NT]]]", + self: _SQO[str], other: Any, - ) -> "ColumnElement[_NUMERIC]": + ) -> ColumnElement[str]: ... - @overload - def __add__( - self: "Union[_SQO[_ST], _SQO[Optional[_ST]]]", - other: Any, - ) -> "ColumnElement[_ST]": + def __add__(self, other: Any) -> ColumnElement[Any]: ... - def __add__(self, other: Any) -> "ColumnElement": + @overload + def __radd__(self: _SQO[_NT], other: Any) -> ColumnElement[_NT]: ... @overload - def __radd__(self, other: Any) -> "ColumnElement[_NUMERIC]": + def __radd__(self: _SQO[int], other: Any) -> ColumnElement[int]: ... @overload - def __radd__(self, other: Any) -> "ColumnElement": + def __radd__(self: _SQO[str], other: Any) -> ColumnElement[str]: ... - def __radd__(self, other: Any) -> "ColumnElement": + def __radd__(self, other: Any) -> ColumnElement[Any]: ... @overload def __sub__( - self: "SQLCoreOperations[_NT]", - other: "Union[SQLCoreOperations[_NT], _NT]", - ) -> "ColumnElement[_NT]": + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: ... @overload - def __sub__(self, other: Any) -> "ColumnElement": + def __sub__(self, other: Any) -> ColumnElement[Any]: ... - def __sub__(self, other: Any) -> "ColumnElement": + def __sub__(self, other: Any) -> ColumnElement[Any]: ... @overload def __rsub__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: ... @overload - def __rsub__(self, other: Any) -> "ColumnElement": + def __rsub__(self, other: Any) -> ColumnElement[Any]: ... - def __rsub__(self, other: Any) -> "ColumnElement": + def __rsub__(self, other: Any) -> ColumnElement[Any]: ... @overload def __mul__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: ... @overload - def __mul__(self, other: Any) -> "ColumnElement": + def __mul__(self, other: Any) -> ColumnElement[Any]: ... - def __mul__(self, other: Any) -> "ColumnElement": + def __mul__(self, other: Any) -> ColumnElement[Any]: ... @overload def __rmul__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: ... @overload - def __rmul__(self, other: Any) -> "ColumnElement": + def __rmul__(self, other: Any) -> ColumnElement[Any]: ... - def __rmul__(self, other: Any) -> "ColumnElement": + def __rmul__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __mod__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + def __mod__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: ... @overload - def __mod__(self, other: Any) -> "ColumnElement": + def __mod__(self, other: Any) -> ColumnElement[Any]: ... - def __mod__(self, other: Any) -> "ColumnElement": + def __mod__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __rmod__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + def __rmod__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: ... @overload - def __rmod__(self, other: Any) -> "ColumnElement": + def __rmod__(self, other: Any) -> ColumnElement[Any]: ... - def __rmod__(self, other: Any) -> "ColumnElement": + def __rmod__(self, other: Any) -> ColumnElement[Any]: ... @overload def __truediv__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + self: _SQO[_NMT], other: Any + ) -> ColumnElement[_NUMERIC]: ... @overload - def __truediv__(self, other: Any) -> "ColumnElement": + def __truediv__(self, other: Any) -> ColumnElement[Any]: ... - def __truediv__(self, other: Any) -> "ColumnElement": + def __truediv__(self, other: Any) -> ColumnElement[Any]: ... @overload def __rtruediv__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + self: _SQO[_NMT], other: Any + ) -> ColumnElement[_NUMERIC]: ... @overload - def __rtruediv__(self, other: Any) -> "ColumnElement": + def __rtruediv__(self, other: Any) -> ColumnElement[Any]: ... - def __rtruediv__(self, other: Any) -> "ColumnElement": + def __rtruediv__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __floordiv__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + def __floordiv__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: ... @overload - def __floordiv__(self, other: Any) -> "ColumnElement": + def __floordiv__(self, other: Any) -> ColumnElement[Any]: ... - def __floordiv__(self, other: Any) -> "ColumnElement": + def __floordiv__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __rfloordiv__( - self: "SQLCoreOperations[_NT]", other: Any - ) -> "ColumnElement[_NUMERIC]": + def __rfloordiv__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: ... @overload - def __rfloordiv__(self, other: Any) -> "ColumnElement": + def __rfloordiv__(self, other: Any) -> ColumnElement[Any]: ... - def __rfloordiv__(self, other: Any) -> "ColumnElement": + def __rfloordiv__(self, other: Any) -> ColumnElement[Any]: ... _SQO = SQLCoreOperations +SelfColumnElement = TypeVar("SelfColumnElement", bound="ColumnElement[Any]") + class ColumnElement( roles.ColumnArgumentOrKeyRole, roles.StatementOptionRole, roles.WhereHavingRole, - roles.BinaryElementRole, + roles.BinaryElementRole[_T], roles.OrderByRole, roles.ColumnsClauseRole, roles.LimitOffsetRole, @@ -987,7 +1059,6 @@ class ColumnElement( roles.DDLConstraintColumnRole, roles.DDLExpressionRole, SQLCoreOperations[_T], - operators.ColumnOperators[SQLCoreOperations], DQLDMLClauseElement, ): """Represent a column-oriented SQL expression suitable for usage in the @@ -1069,28 +1140,37 @@ class ColumnElement( __visit_name__ = "column_element" - primary_key = False - foreign_keys = [] - _proxies = () + primary_key: bool = False + _is_clone_of: Optional[ColumnElement[_T]] - _tq_label = None - """The named label that can be used to target - this column in a result set in a "table qualified" context. + @util.memoized_property + def foreign_keys(self) -> Iterable[ForeignKey]: + return [] - This label is almost always the label used when - rendering AS