From: Mike Bayer Date: Wed, 30 Mar 2022 22:01:58 +0000 (-0400) Subject: pep484 - sql.selectable X-Git-Tag: rel_2_0_0b1~376 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3b4d62f4f72e8dfad7f38db192a6a90a8551608c;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git pep484 - sql.selectable the pep484 task becomes more intense as there is mounting pressure to come up with a consistency in how data moves from end-user to instance variable. current thinking is coming into: 1. there are _typing._XYZArgument objects that represent "what the user sent" 2. there's the roles, which represent a kind of "filter" for different kinds of objects. These are mostly important as the argument we pass to coerce(). 3. there's the thing that coerce() returns, which should be what the construct uses as its internal representation of the thing. This is _typing._XYZElement. but there's some controversy over whether or not we should pass actual ClauseElements around by their role or not. I think we shouldn't at the moment, but this makes the "role-ness" of something a little less portable. Like, we have to set DMLTableRole for TableClause, Join, and Alias, but then also we have to repeat those three types in order to set up _DMLTableElement. Other change introduced here, there was a deannotate=True for the left/right of a sql.join(). All tests pass without that. I'd rather not have that there as if we have a join(A, B) where A, B are mapped classes, we want them inside of the _annotations. The rationale seems to be performance, but this performance can be illustrated to be on the compile side which we hope is cached in the normal case. CTEs now accommodate for text selects including recursive. Get typing to accommodate "util.preloaded" cleanly; add "preloaded" as a real module. This seemed like we would have needed pep562 `__getattr__()` but we don't, just set names in globals() as we import them. References: #6810 Change-Id: I34d17f617de2fe2c086fc556bd55748dc782faf0 --- diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index 8f4b963eba..9db6f3f527 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -23,8 +23,8 @@ from typing import Tuple from typing import Type from typing import Union -from .util import _preloaded from .util import compat +from .util import preloaded as _preloaded if typing.TYPE_CHECKING: from .engine.interfaces import _AnyExecuteParams @@ -345,6 +345,8 @@ class MultipleResultsFound(InvalidRequestError): class NoReferenceError(InvalidRequestError): """Raised by ``ForeignKey`` to indicate a reference cannot be resolved.""" + table_name: str + class AwaitRequired(InvalidRequestError): """Error raised by the async greenlet spawn if no async operation @@ -501,10 +503,7 @@ class StatementError(SQLAlchemyError): @_preloaded.preload_module("sqlalchemy.sql.util") def _sql_message(self) -> str: - if typing.TYPE_CHECKING: - from .sql import util - else: - util = _preloaded.preloaded.sql_util + util = _preloaded.sql_util details = [self._message()] if self.statement: diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 18a14012f8..b9ced44d53 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -28,6 +28,8 @@ from typing import Generic from typing import Iterable from typing import List from typing import Optional +from typing import Tuple +from typing import TYPE_CHECKING from typing import TypeVar from . import exc as orm_exc @@ -78,6 +80,8 @@ from ..sql.selectable import SelectBase from ..sql.selectable import SelectStatementGrouping from ..sql.visitors import InternalTraversal +if TYPE_CHECKING: + from ..sql.selectable import _SetupJoinsElement __all__ = ["Query", "QueryContext"] @@ -134,7 +138,8 @@ class Query( _correlate = () _auto_correlate = True _from_obj = () - _setup_joins = () + _setup_joins: Tuple[_SetupJoinsElement, ...] = () + _label_style = LABEL_STYLE_LEGACY_ORM _memoized_select_entities = () diff --git a/lib/sqlalchemy/sql/_dml_constructors.py b/lib/sqlalchemy/sql/_dml_constructors.py index 835819bacb..926e5257ba 100644 --- a/lib/sqlalchemy/sql/_dml_constructors.py +++ b/lib/sqlalchemy/sql/_dml_constructors.py @@ -7,12 +7,17 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from .dml import Delete from .dml import Insert from .dml import Update +if TYPE_CHECKING: + from ._typing import _DMLTableArgument + -def insert(table): +def insert(table: _DMLTableArgument) -> Insert: """Construct an :class:`_expression.Insert` object. E.g.:: @@ -82,7 +87,7 @@ def insert(table): return Insert(table) -def update(table): +def update(table: _DMLTableArgument) -> Update: r"""Construct an :class:`_expression.Update` object. E.g.:: @@ -122,7 +127,7 @@ def update(table): return Update(table) -def delete(table): +def delete(table: _DMLTableArgument) -> Delete: r"""Construct :class:`_expression.Delete` object. E.g.:: diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index fc925a8b32..ea21e01c66 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -345,8 +345,8 @@ def between( :meth:`_expression.ColumnElement.between` """ - expr = coercions.expect(roles.ExpressionElementRole, expr) - return expr.between(lower_bound, upper_bound, symmetric=symmetric) + col_expr = coercions.expect(roles.ExpressionElementRole, expr) + return col_expr.between(lower_bound, upper_bound, symmetric=symmetric) def outparam( diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py index a17ee4ce86..7896c02c24 100644 --- a/lib/sqlalchemy/sql/_selectable_constructors.py +++ b/lib/sqlalchemy/sql/_selectable_constructors.py @@ -9,6 +9,8 @@ from __future__ import annotations from typing import Any from typing import Optional +from typing import TYPE_CHECKING +from typing import Union from . import coercions from . import roles @@ -17,64 +19,65 @@ from .elements import ColumnClause from .selectable import Alias from .selectable import CompoundSelect from .selectable import Exists +from .selectable import FromClause from .selectable import Join from .selectable import Lateral +from .selectable import LateralFromClause +from .selectable import NamedFromClause from .selectable import Select from .selectable import TableClause from .selectable import TableSample from .selectable import Values +if TYPE_CHECKING: + from ._typing import _ColumnsClauseArgument + from ._typing import _FromClauseArgument + from ._typing import _OnClauseArgument + from ._typing import _SelectStatementForCompoundArgument + from .functions import Function + from .selectable import CTE + from .selectable import HasCTE + from .selectable import ScalarSelect + from .selectable import SelectBase -def alias(selectable, name=None, flat=False): - """Return an :class:`_expression.Alias` object. - An :class:`_expression.Alias` represents any - :class:`_expression.FromClause` - with an alternate name assigned within SQL, typically using the ``AS`` - clause when generated, e.g. ``SELECT * FROM table AS aliasname``. +def alias( + selectable: FromClause, name: Optional[str] = None, flat: bool = False +) -> NamedFromClause: + """Return a named alias of the given :class:`.FromClause`. + + For :class:`.Table` and :class:`.Join` objects, the return type is the + :class:`_expression.Alias` object. Other kinds of :class:`.NamedFromClause` + objects may be returned for other kinds of :class:`.FromClause` objects. + + The named alias represents any :class:`_expression.FromClause` with an + alternate name assigned within SQL, typically using the ``AS`` clause when + generated, e.g. ``SELECT * FROM table AS aliasname``. - Similar functionality is available via the + Equivalent functionality is available via the :meth:`_expression.FromClause.alias` - method available on all :class:`_expression.FromClause` subclasses. - In terms of - a SELECT object as generated from the :func:`_expression.select` - function, the :meth:`_expression.SelectBase.alias` method returns an - :class:`_expression.Alias` or similar object which represents a named, - parenthesized subquery. - - When an :class:`_expression.Alias` is created from a - :class:`_schema.Table` object, - this has the effect of the table being rendered - as ``tablename AS aliasname`` in a SELECT statement. - - For :func:`_expression.select` objects, the effect is that of - creating a named subquery, i.e. ``(select ...) AS aliasname``. - - The ``name`` parameter is optional, and provides the name - to use in the rendered SQL. If blank, an "anonymous" name - will be deterministically generated at compile time. - Deterministic means the name is guaranteed to be unique against - other constructs used in the same statement, and will also be the - same name for each successive compilation of the same statement - object. + method available on all :class:`_expression.FromClause` objects. :param selectable: any :class:`_expression.FromClause` subclass, such as a table, select statement, etc. :param name: string name to be assigned as the alias. - If ``None``, a name will be deterministically generated - at compile time. + If ``None``, a name will be deterministically generated at compile + time. Deterministic means the name is guaranteed to be unique against + other constructs used in the same statement, and will also be the same + name for each successive compilation of the same statement object. :param flat: Will be passed through to if the given selectable is an instance of :class:`_expression.Join` - see - :meth:`_expression.Join.alias` - for details. + :meth:`_expression.Join.alias` for details. """ return Alias._factory(selectable, name=name, flat=flat) -def cte(selectable, name=None, recursive=False): +def cte( + selectable: HasCTE, name: Optional[str] = None, recursive: bool = False +) -> CTE: r"""Return a new :class:`_expression.CTE`, or Common Table Expression instance. @@ -86,7 +89,7 @@ def cte(selectable, name=None, recursive=False): ) -def except_(*selects): +def except_(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: r"""Return an ``EXCEPT`` of multiple selectables. The returned object is an instance of @@ -99,7 +102,9 @@ def except_(*selects): return CompoundSelect._create_except(*selects) -def except_all(*selects): +def except_all( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return an ``EXCEPT ALL`` of multiple selectables. The returned object is an instance of @@ -112,7 +117,11 @@ def except_all(*selects): return CompoundSelect._create_except_all(*selects) -def exists(__argument=None): +def exists( + __argument: Optional[ + Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]] + ] = None, +) -> Exists: """Construct a new :class:`_expression.Exists` construct. The :func:`_sql.exists` can be invoked by itself to produce an @@ -153,7 +162,7 @@ def exists(__argument=None): return Exists(__argument) -def intersect(*selects): +def intersect(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: r"""Return an ``INTERSECT`` of multiple selectables. The returned object is an instance of @@ -166,7 +175,9 @@ def intersect(*selects): return CompoundSelect._create_intersect(*selects) -def intersect_all(*selects): +def intersect_all( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return an ``INTERSECT ALL`` of multiple selectables. The returned object is an instance of @@ -180,7 +191,13 @@ def intersect_all(*selects): return CompoundSelect._create_intersect_all(*selects) -def join(left, right, onclause=None, isouter=False, full=False): +def join( + left: _FromClauseArgument, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + isouter: bool = False, + full: bool = False, +) -> Join: """Produce a :class:`_expression.Join` object, given two :class:`_expression.FromClause` expressions. @@ -232,7 +249,10 @@ def join(left, right, onclause=None, isouter=False, full=False): return Join(left, right, onclause, isouter, full) -def lateral(selectable, name=None): +def lateral( + selectable: Union[SelectBase, _FromClauseArgument], + name: Optional[str] = None, +) -> LateralFromClause: """Return a :class:`_expression.Lateral` object. :class:`_expression.Lateral` is an :class:`_expression.Alias` @@ -255,7 +275,12 @@ def lateral(selectable, name=None): return Lateral._factory(selectable, name=name) -def outerjoin(left, right, onclause=None, full=False): +def outerjoin( + left: _FromClauseArgument, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + full: bool = False, +) -> Join: """Return an ``OUTER JOIN`` clause element. The returned object is an instance of :class:`_expression.Join`. @@ -349,7 +374,12 @@ def table(name: str, *columns: ColumnClause[Any], **kw: Any) -> TableClause: return TableClause(name, *columns, **kw) -def tablesample(selectable, sampling, name=None, seed=None): +def tablesample( + selectable: _FromClauseArgument, + sampling: Union[float, Function[Any]], + name: Optional[str] = None, + seed: Optional[roles.ExpressionElementRole[Any]] = None, +) -> TableSample: """Return a :class:`_expression.TableSample` object. :class:`_expression.TableSample` is an :class:`_expression.Alias` @@ -395,7 +425,7 @@ def tablesample(selectable, sampling, name=None, seed=None): return TableSample._factory(selectable, sampling, name=name, seed=seed) -def union(*selects, **kwargs): +def union(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: r"""Return a ``UNION`` of multiple selectables. The returned object is an instance of @@ -412,10 +442,10 @@ def union(*selects, **kwargs): :func:`select`. """ - return CompoundSelect._create_union(*selects, **kwargs) + return CompoundSelect._create_union(*selects) -def union_all(*selects): +def union_all(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: r"""Return a ``UNION ALL`` of multiple selectables. The returned object is an instance of diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index a5da878027..0a72a93c5c 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -1,7 +1,7 @@ from __future__ import annotations +import operator from typing import Any -from typing import Iterable from typing import Type from typing import TYPE_CHECKING from typing import TypeVar @@ -24,9 +24,16 @@ if TYPE_CHECKING: from .roles import FromClauseRole from .schema import DefaultGenerator from .schema import Sequence + from .selectable import Alias from .selectable import FromClause + from .selectable import Join from .selectable import NamedFromClause + from .selectable import ReturnsRows + from .selectable import Select + from .selectable import SelectBase + from .selectable import Subquery from .selectable import TableClause + from .sqltypes import TableValueType from .sqltypes import TupleType from .type_api import TypeEngine from ..util.typing import TypeGuard @@ -47,6 +54,14 @@ class _HasClauseElement(Protocol): # the coercions system is responsible for converting from XYZArgument to # XYZElement. +_TextCoercedExpressionArgument = Union[ + str, + "TextClause", + "ColumnElement[_T]", + _HasClauseElement, + roles.ExpressionElementRole[_T], +] + _ColumnsClauseArgument = Union[ Literal["*", 1], roles.ColumnsClauseRole, @@ -54,8 +69,31 @@ _ColumnsClauseArgument = Union[ Inspectable[_HasClauseElement], _HasClauseElement, ] +"""open-ended SELECT columns clause argument. + +Includes column expressions, tables, ORM mapped entities, a few literal values. + +This type is used for lists of columns / entities to be returned in result +sets; select(...), insert().returning(...), etc. + + +""" + +_ColumnExpressionArgument = Union[ + "ColumnElement[_T]", _HasClauseElement, roles.ExpressionElementRole[_T] +] +"""narrower "column expression" argument. + +This type is used for all the other "column" kinds of expressions that +typically represent a single SQL column expression, not a set of columns the +way a table or ORM entity does. + +This includes ColumnElement, or ORM-mapped attributes that will have a +`__clause_element__()` method, it also has the ExpressionElementRole +overall which brings in the TextClause object also. + +""" -_SelectIterable = Iterable[Union["ColumnElement[Any]", "TextClause"]] _FromClauseArgument = Union[ roles.FromClauseRole, @@ -63,28 +101,99 @@ _FromClauseArgument = Union[ Inspectable[_HasClauseElement], _HasClauseElement, ] +"""A FROM clause, like we would send to select().select_from(). -_ColumnExpressionArgument = Union[ - "ColumnElement[_T]", _HasClauseElement, roles.ExpressionElementRole[_T] +Also accommodates ORM entities and related constructs. + +""" + +_JoinTargetArgument = Union[_FromClauseArgument, roles.JoinTargetRole] +"""target for join() builds on _FromClauseArgument to include additional +join target roles such as those which come from the ORM. + +""" + +_OnClauseArgument = Union[_ColumnExpressionArgument[Any], roles.OnClauseRole] +"""target for an ON clause, includes additional roles such as those which +come from the ORM. + +""" + +_SelectStatementForCompoundArgument = Union[ + "SelectBase", roles.CompoundElementRole +] +"""SELECT statement acceptable by ``union()`` and other SQL set operations""" + +_DMLColumnArgument = Union[ + str, "ColumnClause[Any]", _HasClauseElement, roles.DMLColumnRole ] +"""A DML column expression. This is a "key" inside of insert().values(), +update().values(), and related. + +These are usually strings or SQL table columns. + +There's also edge cases like JSON expression assignment, which we would want +the DMLColumnRole to be able to accommodate. -_DMLColumnArgument = Union[str, "ColumnClause[Any]", _HasClauseElement] +""" + + +_DMLTableArgument = Union[ + "TableClause", + "Join", + "Alias", + Type[Any], + Inspectable[_HasClauseElement], + _HasClauseElement, +] _PropagateAttrsType = util.immutabledict[str, Any] _TypeEngineArgument = Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] +if TYPE_CHECKING: -def is_named_from_clause(t: FromClauseRole) -> TypeGuard[NamedFromClause]: - return t.named_with_column + def is_named_from_clause(t: FromClauseRole) -> TypeGuard[NamedFromClause]: + ... + def is_column_element(c: ClauseElement) -> TypeGuard[ColumnElement[Any]]: + ... -def is_column_element(c: ClauseElement) -> TypeGuard[ColumnElement[Any]]: - return c._is_column_element + def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]: + ... + def is_from_clause(c: ClauseElement) -> TypeGuard[FromClause]: + ... -def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]: - return c._is_text_clause + def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]: + ... + + def is_table_value_type(t: TypeEngine[Any]) -> TypeGuard[TableValueType]: + ... + + def is_select_base(t: ReturnsRows) -> TypeGuard[SelectBase]: + ... + + def is_select_statement(t: ReturnsRows) -> TypeGuard[Select]: + ... + + def is_table(t: FromClause) -> TypeGuard[TableClause]: + ... + + def is_subquery(t: FromClause) -> TypeGuard[Subquery]: + ... + +else: + is_named_from_clause = operator.attrgetter("named_with_column") + is_column_element = operator.attrgetter("_is_column_element") + is_text_clause = operator.attrgetter("_is_text_clause") + is_from_clause = operator.attrgetter("_is_from_clause") + is_tuple_type = operator.attrgetter("_is_tuple_type") + is_table_value_type = operator.attrgetter("_is_table_value") + is_select_base = operator.attrgetter("_is_select_base") + is_select_statement = operator.attrgetter("_is_select_statement") + is_table = operator.attrgetter("_is_table") + is_subquery = operator.attrgetter("_is_subquery") def has_schema_attr(t: FromClauseRole) -> TypeGuard[TableClause]: @@ -95,9 +204,5 @@ def is_quoted_name(s: str) -> TypeGuard[quoted_name]: return hasattr(s, "quote") -def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]: - return t._is_tuple_type - - def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement]: return hasattr(s, "__clause_element__") diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index f1919d1d39..fa36c09fcf 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -59,7 +59,9 @@ class SupportsAnnotations(ExternallyTraversible): _is_immutable: bool - def _annotate(self, values: _AnnotationDict) -> SupportsAnnotations: + def _annotate( + self: SelfSupportsAnnotations, values: _AnnotationDict + ) -> SelfSupportsAnnotations: raise NotImplementedError() @overload @@ -105,11 +107,6 @@ class SupportsAnnotations(ExternallyTraversible): ) -SelfSupportsCloneAnnotations = TypeVar( - "SelfSupportsCloneAnnotations", bound="SupportsCloneAnnotations" -) - - class SupportsCloneAnnotations(SupportsAnnotations): if not typing.TYPE_CHECKING: __slots__ = () @@ -119,8 +116,8 @@ class SupportsCloneAnnotations(SupportsAnnotations): ] def _annotate( - self: SelfSupportsCloneAnnotations, values: _AnnotationDict - ) -> SelfSupportsCloneAnnotations: + self: SelfSupportsAnnotations, values: _AnnotationDict + ) -> SelfSupportsAnnotations: """return a copy of this ClauseElement with annotations updated by the given dictionary. @@ -132,8 +129,8 @@ class SupportsCloneAnnotations(SupportsAnnotations): return new def _with_annotations( - self: SelfSupportsCloneAnnotations, values: _AnnotationDict - ) -> SelfSupportsCloneAnnotations: + self: SelfSupportsAnnotations, values: _AnnotationDict + ) -> SelfSupportsAnnotations: """return a copy of this ClauseElement with annotations replaced by the given dictionary. @@ -184,11 +181,6 @@ class SupportsCloneAnnotations(SupportsAnnotations): return self -SelfSupportsWrappingAnnotations = TypeVar( - "SelfSupportsWrappingAnnotations", bound="SupportsWrappingAnnotations" -) - - class SupportsWrappingAnnotations(SupportsAnnotations): __slots__ = () @@ -200,19 +192,23 @@ class SupportsWrappingAnnotations(SupportsAnnotations): def entity_namespace(self) -> _EntityNamespace: ... - def _annotate(self, values: _AnnotationDict) -> Annotated: + def _annotate( + self: SelfSupportsAnnotations, values: _AnnotationDict + ) -> SelfSupportsAnnotations: """return a copy of this ClauseElement with annotations updated by the given dictionary. """ - return Annotated._as_annotated_instance(self, values) + return Annotated._as_annotated_instance(self, values) # type: ignore - def _with_annotations(self, values: _AnnotationDict) -> Annotated: + def _with_annotations( + self: SelfSupportsAnnotations, values: _AnnotationDict + ) -> SelfSupportsAnnotations: """return a copy of this ClauseElement with annotations replaced by the given dictionary. """ - return Annotated._as_annotated_instance(self, values) + return Annotated._as_annotated_instance(self, values) # type: ignore @overload def _deannotate( @@ -306,16 +302,17 @@ class Annotated(SupportsAnnotations): self: SelfAnnotated, values: _AnnotationDict ) -> SelfAnnotated: _values = self._annotations.union(values) - return self._with_annotations(_values) + new: SelfAnnotated = self._with_annotations(_values) # type: ignore + return new def _with_annotations( - self: SelfAnnotated, values: util.immutabledict[str, Any] - ) -> SelfAnnotated: + self: SelfAnnotated, values: _AnnotationDict + ) -> SupportsAnnotations: clone = self.__class__.__new__(self.__class__) clone.__dict__ = self.__dict__.copy() clone.__dict__.pop("_annotations_cache_key", None) clone.__dict__.pop("_generate_cache_key", None) - clone._annotations = values + clone._annotations = util.immutabledict(values) return clone @overload diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 19e4c13d22..6b25d8fcd5 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -63,12 +63,14 @@ if TYPE_CHECKING: from . import elements from . import type_api from ._typing import _ColumnsClauseArgument - from ._typing import _SelectIterable from .elements import BindParameter from .elements import ColumnClause from .elements import ColumnElement from .elements import NamedColumn from .elements import SQLCoreOperations + from .elements import TextClause + from .selectable import _JoinTargetElement + from .selectable import _SelectIterable from .selectable import FromClause from ..engine import Connection from ..engine import Result @@ -167,7 +169,11 @@ class SingletonConstant(Immutable): cls._singleton = obj -def _from_objects(*elements: ColumnElement[Any]) -> Iterator[FromClause]: +def _from_objects( + *elements: Union[ + ColumnElement[Any], FromClause, TextClause, _JoinTargetElement + ] +) -> Iterator[FromClause]: return itertools.chain.from_iterable( [element._from_objects for element in elements] ) @@ -255,6 +261,11 @@ def _expand_cloned(elements): predecessors. """ + # TODO: cython candidate + # and/or change approach: in + # https://gerrit.sqlalchemy.org/c/sqlalchemy/sqlalchemy/+/3712 we propose + # getting rid of _cloned_set. + # turning this into chain.from_iterable adds all kinds of callcount return itertools.chain(*[x._cloned_set for x in elements]) @@ -1559,6 +1570,11 @@ class ColumnCollection(Generic[_COLKEY, _COL]): was moved onto the :class:`_expression.ColumnCollection` itself. """ + # TODO: cython candidate + + # don't dig around if the column is locally present + if column in self._colset: + return column def embedded(expanded_proxy_set, target_set): for t in target_set.difference(expanded_proxy_set): @@ -1568,9 +1584,6 @@ class ColumnCollection(Generic[_COLKEY, _COL]): return False return True - # don't dig around if the column is locally present - if column in self._colset: - return column col, intersect = None, None target_set = column.proxy_set cols = [c for (k, c) in self._collection] diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py index 19a232c563..1f8b9c19e8 100644 --- a/lib/sqlalchemy/sql/cache_key.py +++ b/lib/sqlalchemy/sql/cache_key.py @@ -54,6 +54,11 @@ class CacheConst(enum.Enum): NO_CACHE = CacheConst.NO_CACHE +_CacheKeyTraversalType = Union[ + "_TraverseInternalsType", Literal[CacheConst.NO_CACHE], Literal[None] +] + + class CacheTraverseTarget(enum.Enum): CACHE_IN_PLACE = 0 CALL_GEN_CACHE_KEY = 1 @@ -89,9 +94,7 @@ class HasCacheKey: __slots__ = () - _cache_key_traversal: Union[ - _TraverseInternalsType, Literal[CacheConst.NO_CACHE], Literal[None] - ] = NO_CACHE + _cache_key_traversal: _CacheKeyTraversalType = NO_CACHE _is_has_cache_key = True diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index ccc8fba8d2..4c71ca38b1 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -12,7 +12,6 @@ import numbers import re import typing from typing import Any -from typing import Any as TODO_Any from typing import Callable from typing import Dict from typing import List @@ -20,7 +19,9 @@ from typing import NoReturn from typing import Optional from typing import overload from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar +from typing import Union from . import operators from . import roles @@ -32,6 +33,7 @@ from .visitors import Visitable from .. import exc from .. import inspection from .. import util +from ..util.typing import Literal if not typing.TYPE_CHECKING: elements = None @@ -46,12 +48,26 @@ if typing.TYPE_CHECKING: from . import schema from . import selectable from . import traversals + from ._typing import _ColumnExpressionArgument from ._typing import _ColumnsClauseArgument + from ._typing import _DMLTableArgument + from ._typing import _FromClauseArgument + from .dml import _DMLTableElement from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement + from .elements import DQLDMLClauseElement from .elements import SQLCoreOperations - + from .schema import Column + from .selectable import _ColumnsClauseElement + from .selectable import _JoinTargetElement + from .selectable import _JoinTargetProtocol + from .selectable import _OnClauseElement + from .selectable import FromClause + from .selectable import HasCTE + from .selectable import SelectBase + from .selectable import Subquery + from .visitors import _TraverseCallableType _SR = TypeVar("_SR", bound=roles.SQLRole) _F = TypeVar("_F", bound=Callable[..., Any]) @@ -143,10 +159,6 @@ def _expression_collection_was_a_list(attrname, fnname, args): def expect( role: Type[roles.TruncatedLabelRole], element: Any, - *, - apply_propagate_attrs: Optional[ClauseElement] = None, - argname: Optional[str] = None, - post_inspect: bool = False, **kw: Any, ) -> str: ... @@ -154,12 +166,30 @@ def expect( @overload def expect( - role: Type[roles.ExpressionElementRole[_T]], + role: Type[roles.StatementOptionRole], element: Any, - *, - apply_propagate_attrs: Optional[ClauseElement] = None, - argname: Optional[str] = None, - post_inspect: bool = False, + **kw: Any, +) -> DQLDMLClauseElement: + ... + + +@overload +def expect( + role: Type[roles.DDLReferredColumnRole], + element: Any, + **kw: Any, +) -> Column[Any]: + ... + + +@overload +def expect( + role: Union[ + Type[roles.ExpressionElementRole[Any]], + Type[roles.LimitOffsetRole], + Type[roles.WhereHavingRole], + ], + element: _ColumnExpressionArgument[_T], **kw: Any, ) -> ColumnElement[_T]: ... @@ -167,40 +197,89 @@ def expect( @overload def expect( - role: Type[roles.DMLTableRole], + role: Union[ + Type[roles.ExpressionElementRole[Any]], + Type[roles.LimitOffsetRole], + Type[roles.WhereHavingRole], + ], element: Any, + **kw: Any, +) -> ColumnElement[Any]: + ... + + +@overload +def expect( + role: Type[roles.DMLTableRole], + element: _DMLTableArgument, + **kw: Any, +) -> _DMLTableElement: + ... + + +@overload +def expect( + role: Type[roles.HasCTERole], + element: HasCTE, + **kw: Any, +) -> HasCTE: + ... + + +@overload +def expect( + role: Type[roles.SelectStatementRole], + element: SelectBase, + **kw: Any, +) -> SelectBase: + ... + + +@overload +def expect( + role: Type[roles.FromClauseRole], + element: _FromClauseArgument, + **kw: Any, +) -> FromClause: + ... + + +@overload +def expect( + role: Type[roles.FromClauseRole], + element: SelectBase, *, - apply_propagate_attrs: Optional[ClauseElement] = None, - argname: Optional[str] = None, - post_inspect: bool = False, + explicit_subquery: Literal[True] = ..., **kw: Any, -) -> roles.DMLTableRole: +) -> Subquery: ... @overload def expect( role: Type[roles.ColumnsClauseRole], - element: Any, - *, - apply_propagate_attrs: Optional[ClauseElement] = None, - argname: Optional[str] = None, - post_inspect: bool = False, + element: _ColumnsClauseArgument, + **kw: Any, +) -> _ColumnsClauseElement: + ... + + +@overload +def expect( + role: Union[Type[roles.JoinTargetRole], Type[roles.OnClauseRole]], + element: _JoinTargetProtocol, **kw: Any, -) -> roles.ColumnsClauseRole: +) -> _JoinTargetProtocol: ... +# catchall for not-yet-implemented overloads @overload def expect( role: Type[_SR], element: Any, - *, - apply_propagate_attrs: Optional[ClauseElement] = None, - argname: Optional[str] = None, - post_inspect: bool = False, **kw: Any, -) -> TODO_Any: +) -> Any: ... @@ -212,7 +291,7 @@ def expect( argname: Optional[str] = None, post_inspect: bool = False, **kw: Any, -) -> TODO_Any: +) -> Any: if ( role.allows_lambda # note callable() will not invoke a __getattr__() method, whereas @@ -329,7 +408,8 @@ def expect_col_expression_collection(role, expressions): strname = resolved = expr else: cols: List[ColumnClause[Any]] = [] - visitors.traverse(resolved, {}, {"column": cols.append}) + col_append: _TraverseCallableType[ColumnClause[Any]] = cols.append + visitors.traverse(resolved, {}, {"column": col_append}) if cols: column = cols[0] add_element = column if column is not None else strname @@ -432,7 +512,7 @@ class _ColumnCoercions(RoleImpl): original_element = element if not getattr(resolved, "is_clause_element", False): self._raise_for_expected(original_element, argname, resolved) - elif resolved._is_select_statement: + elif resolved._is_select_base: self._warn_for_scalar_subquery_coercion() return resolved.scalar_subquery() elif resolved._is_from_clause and isinstance( @@ -670,7 +750,7 @@ class InElementImpl(RoleImpl): if resolved._is_from_clause: if ( isinstance(resolved, selectable.Alias) - and resolved.element._is_select_statement + and resolved.element._is_select_base ): self._warn_for_implicit_coercion(resolved) return self._post_coercion(resolved.element, **kw) @@ -722,7 +802,7 @@ class InElementImpl(RoleImpl): self._raise_for_expected(element, **kw) def _post_coercion(self, element, expr, operator, **kw): - if element._is_select_statement: + if element._is_select_base: # for IN, we are doing scalar_subquery() coercion without # a warning return element.scalar_subquery() @@ -1085,7 +1165,7 @@ class JoinTargetImpl(RoleImpl): # #6550, unless JoinTargetImpl._skip_clauseelement_for_target_match # were set to False. return element - elif legacy and resolved._is_select_statement: + elif legacy and resolved._is_select_base: util.warn_deprecated( "Implicit coercion of SELECT and textual SELECT " "constructs into FROM clauses is deprecated; please call " @@ -1114,7 +1194,7 @@ class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl): allow_select: bool = True, **kw: Any, ) -> Any: - if resolved._is_select_statement: + if resolved._is_select_base: if explicit_subquery: return resolved.subquery() elif allow_select: @@ -1150,7 +1230,7 @@ class StrictFromClauseImpl(FromClauseImpl): allow_select: bool = False, **kw: Any, ) -> Any: - if resolved._is_select_statement and allow_select: + if resolved._is_select_base and allow_select: util.warn_deprecated( "Implicit coercion of SELECT and textual SELECT constructs " "into FROM clauses is deprecated; please call .subquery() " @@ -1195,7 +1275,7 @@ class DMLSelectImpl(_NoTextCoercion, RoleImpl): if resolved._is_from_clause: if ( isinstance(resolved, selectable.Alias) - and resolved.element._is_select_statement + and resolved.element._is_select_base ): return resolved.element else: @@ -1235,3 +1315,9 @@ for name in dir(roles): if name in globals(): impl = globals()[name](cls) _impl_lookup[cls] = impl + +if not TYPE_CHECKING: + ee_impl = _impl_lookup[roles.ExpressionElementRole] + + for py_type in (int, bool, str, float): + _impl_lookup[roles.ExpressionElementRole[py_type]] = ee_impl diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index b7f6d11f64..6ecfbf9866 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -313,12 +313,12 @@ EXTRACT_MAP = { } COMPOUND_KEYWORDS = { - selectable.CompoundSelect.UNION: "UNION", - selectable.CompoundSelect.UNION_ALL: "UNION ALL", - selectable.CompoundSelect.EXCEPT: "EXCEPT", - selectable.CompoundSelect.EXCEPT_ALL: "EXCEPT ALL", - selectable.CompoundSelect.INTERSECT: "INTERSECT", - selectable.CompoundSelect.INTERSECT_ALL: "INTERSECT ALL", + selectable._CompoundSelectKeyword.UNION: "UNION", + selectable._CompoundSelectKeyword.UNION_ALL: "UNION ALL", + selectable._CompoundSelectKeyword.EXCEPT: "EXCEPT", + selectable._CompoundSelectKeyword.EXCEPT_ALL: "EXCEPT ALL", + selectable._CompoundSelectKeyword.INTERSECT: "INTERSECT", + selectable._CompoundSelectKeyword.INTERSECT_ALL: "INTERSECT ALL", } @@ -1468,6 +1468,10 @@ class SQLCompiler(Compiled): self.post_compile_params = frozenset() for key in expanded_state.parameter_expansion: bind = self.binds.pop(key) + + if TYPE_CHECKING: + assert bind.value is not None + self.bind_names.pop(bind) for value, expanded_key in zip( bind.value, expanded_state.parameter_expansion[key] @@ -3089,12 +3093,7 @@ class SQLCompiler(Compiled): self.ctes_recursive = True text = self.preparer.format_alias(cte, cte_name) if cte.recursive: - if isinstance(cte.element, selectable.Select): - col_source = cte.element - elif isinstance(cte.element, selectable.CompoundSelect): - col_source = cte.element.selects[0] - else: - assert False, "cte should only be against SelectBase" + col_source = cte.element # TODO: can we get at the .columns_plus_names collection # that is already (or will be?) generated for the SELECT @@ -3315,7 +3314,9 @@ class SQLCompiler(Compiled): for elem in chunk ) - if isinstance(element.name, elements._truncated_label): + if element._unnamed: + name = None + elif isinstance(element.name, elements._truncated_label): name = self._truncated_identifier("values", element.name) else: name = element.name @@ -3980,7 +3981,7 @@ class SQLCompiler(Compiled): clause = " ".join( prefix._compiler_dispatch(self, **kw) for prefix, dialect_name in prefixes - if dialect_name is None or dialect_name == self.dialect.name + if dialect_name in (None, "*") or dialect_name == self.dialect.name ) if clause: clause += " " diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 0c9056aeeb..8a3a1b38f0 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -45,11 +45,14 @@ from .base import Executable from .base import HasCompileState from .elements import BooleanClauseList from .elements import ClauseElement +from .elements import ColumnClause from .elements import ColumnElement from .elements import Null +from .selectable import Alias from .selectable import FromClause from .selectable import HasCTE from .selectable import HasPrefixes +from .selectable import Join from .selectable import ReturnsRows from .selectable import TableClause from .sqltypes import NullType @@ -59,15 +62,15 @@ from .. import util from ..util.typing import TypeGuard if TYPE_CHECKING: - + from ._typing import _ColumnExpressionArgument from ._typing import _ColumnsClauseArgument from ._typing import _DMLColumnArgument + from ._typing import _DMLTableArgument from ._typing import _FromClauseArgument - from ._typing import _HasClauseElement - from ._typing import _SelectIterable from .base import ReadOnlyColumnCollection from .compiler import SQLCompiler - from .elements import ColumnClause + from .selectable import _ColumnsClauseElement + from .selectable import _SelectIterable from .selectable import Select def isupdate(dml: DMLState) -> TypeGuard[UpdateDMLState]: @@ -85,7 +88,8 @@ else: isinsert = operator.attrgetter("isinsert") -_DMLColumnElement = Union[str, "ColumnClause[Any]"] +_DMLColumnElement = Union[str, ColumnClause[Any]] +_DMLTableElement = Union[TableClause, Alias, Join] class DMLState(CompileState): @@ -132,7 +136,7 @@ class DMLState(CompileState): ] @property - def dml_table(self) -> roles.DMLTableRole: + def dml_table(self) -> _DMLTableElement: return self.statement.table if TYPE_CHECKING: @@ -322,17 +326,17 @@ class UpdateBase( __visit_name__ = "update_base" _hints: util.immutabledict[ - Tuple[roles.DMLTableRole, str], str + Tuple[_DMLTableElement, str], str ] = util.EMPTY_DICT named_with_column = False - table: roles.DMLTableRole + table: _DMLTableElement _return_defaults = False _return_defaults_columns: Optional[ - Tuple[roles.ColumnsClauseRole, ...] + Tuple[_ColumnsClauseElement, ...] ] = None - _returning: Tuple[roles.ColumnsClauseRole, ...] = () + _returning: Tuple[_ColumnsClauseElement, ...] = () is_dml = True @@ -483,7 +487,7 @@ class UpdateBase( def with_hint( self: SelfUpdateBase, text: str, - selectable: Optional[roles.DMLTableRole] = None, + selectable: Optional[_DMLTableArgument] = None, dialect_name: str = "*", ) -> SelfUpdateBase: """Add a table hint for a single table to this @@ -517,7 +521,8 @@ class UpdateBase( """ if selectable is None: selectable = self.table - + else: + selectable = coercions.expect(roles.DMLTableRole, selectable) self._hints = self._hints.union({(selectable, dialect_name): text}) return self @@ -636,9 +641,9 @@ class ValuesBase(UpdateBase): _select_names: Optional[List[str]] = None _inline: bool = False - _returning: Tuple[roles.ColumnsClauseRole, ...] = () + _returning: Tuple[_ColumnsClauseElement, ...] = () - def __init__(self, table: _FromClauseArgument): + def __init__(self, table: _DMLTableArgument): self.table = coercions.expect( roles.DMLTableRole, table, apply_propagate_attrs=self ) @@ -970,7 +975,7 @@ class Insert(ValuesBase): + HasCTE._has_ctes_traverse_internals ) - def __init__(self, table: roles.FromClauseRole): + def __init__(self, table: _DMLTableArgument): super(Insert, self).__init__(table) @_generative @@ -1066,12 +1071,12 @@ SelfDMLWhereBase = typing.TypeVar("SelfDMLWhereBase", bound="DMLWhereBase") class DMLWhereBase: - table: roles.DMLTableRole + table: _DMLTableElement _where_criteria: Tuple[ColumnElement[Any], ...] = () @_generative def where( - self: SelfDMLWhereBase, *whereclause: roles.ExpressionElementRole[Any] + self: SelfDMLWhereBase, *whereclause: _ColumnExpressionArgument[bool] ) -> SelfDMLWhereBase: """Return a new construct with the given expression(s) added to its WHERE clause, joined to the existing clause via AND, if any. @@ -1104,7 +1109,9 @@ class DMLWhereBase: """ for criterion in whereclause: - where_criteria = coercions.expect(roles.WhereHavingRole, criterion) + where_criteria: ColumnElement[Any] = coercions.expect( + roles.WhereHavingRole, criterion + ) self._where_criteria += (where_criteria,) return self @@ -1119,7 +1126,7 @@ class DMLWhereBase: return self.where(*criteria) - def _filter_by_zero(self) -> roles.DMLTableRole: + def _filter_by_zero(self) -> _DMLTableElement: return self.table def filter_by(self: SelfDMLWhereBase, **kwargs: Any) -> SelfDMLWhereBase: @@ -1189,7 +1196,7 @@ class Update(DMLWhereBase, ValuesBase): + HasCTE._has_ctes_traverse_internals ) - def __init__(self, table: roles.FromClauseRole): + def __init__(self, table: _DMLTableArgument): super(Update, self).__init__(table) @_generative @@ -1279,7 +1286,7 @@ class Delete(DMLWhereBase, UpdateBase): + HasCTE._has_ctes_traverse_internals ) - def __init__(self, table: roles.FromClauseRole): + def __init__(self, table: _DMLTableArgument): self.table = coercions.expect( roles.DMLTableRole, table, apply_propagate_attrs=self ) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index c735085f83..aec29d1b2e 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -26,6 +26,7 @@ from typing import Dict from typing import FrozenSet from typing import Generic from typing import Iterable +from typing import Iterator from typing import List from typing import Mapping from typing import Optional @@ -77,8 +78,8 @@ from ..util.typing import Literal if typing.TYPE_CHECKING: from ._typing import _ColumnExpressionArgument from ._typing import _PropagateAttrsType - from ._typing import _SelectIterable from ._typing import _TypeEngineArgument + from .cache_key import _CacheKeyTraversalType from .cache_key import CacheKey from .compiler import Compiled from .compiler import SQLCompiler @@ -88,6 +89,7 @@ if typing.TYPE_CHECKING: from .schema import DefaultGenerator from .schema import FetchedValue from .schema import ForeignKey + from .selectable import _SelectIterable from .selectable import FromClause from .selectable import NamedFromClause from .selectable import ReturnsRows @@ -96,6 +98,7 @@ if typing.TYPE_CHECKING: from .sqltypes import Boolean from .sqltypes import TupleType from .type_api import TypeEngine + from .visitors import _CloneCallableType from .visitors import _TraverseInternalsType from ..engine import Connection from ..engine import Dialect @@ -310,6 +313,7 @@ class ClauseElement( _is_text_clause = False _is_from_container = False _is_select_container = False + _is_select_base = False _is_select_statement = False _is_bind_parameter = False _is_clause_list = False @@ -321,7 +325,7 @@ class ClauseElement( def _order_by_label_element(self) -> Optional[Label[Any]]: return None - _cache_key_traversal = None + _cache_key_traversal: _CacheKeyTraversalType = None negation_clause: ColumnElement[bool] @@ -528,7 +532,7 @@ class ClauseElement( """ return traversals.compare(self, other, **kw) - def self_group(self, against=None): + def self_group(self, against: Optional[OperatorType] = None) -> Any: """Apply a 'grouping' to this :class:`_expression.ClauseElement`. This method is overridden by subclasses to return a "grouping" @@ -637,9 +641,9 @@ class ClauseElement( return self._negate() def _negate(self) -> ClauseElement: - return UnaryExpression( - self.self_group(against=operators.inv), operator=operators.inv - ) + grouped = self.self_group(against=operators.inv) + assert isinstance(grouped, ColumnElement) + return UnaryExpression(grouped, operator=operators.inv) def __bool__(self): raise TypeError("Boolean value of this clause is not defined") @@ -1288,12 +1292,6 @@ class ColumnElement( ) -> ColumnElement[_T]: ... - @overload - def self_group( - self: ColumnElement[bool], against: Optional[OperatorType] = None - ) -> ColumnElement[bool]: - ... - @overload def self_group( self: ColumnElement[Any], against: Optional[OperatorType] = None @@ -1764,6 +1762,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): key: str type: TypeEngine[_T] + value: Optional[_T] _is_crud = False _is_bind_parameter = True @@ -1883,7 +1882,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): return cloned @property - def effective_value(self): + def effective_value(self) -> Optional[_T]: """Return the value of this bound parameter, taking into account if the ``callable`` parameter was set. @@ -1893,11 +1892,12 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): """ if self.callable: - return self.callable() + # TODO: set up protocol for bind parameter callable + return self.callable() # type: ignore else: return self.value - def render_literal_execute(self): + def render_literal_execute(self) -> BindParameter[_T]: """Produce a copy of this bound parameter that will enable the :paramref:`_sql.BindParameter.literal_execute` flag. @@ -2513,8 +2513,10 @@ class ClauseList( self.operator = operator self.group = group self.group_contents = group_contents + clauses_iterator: Iterable[_ColumnExpressionArgument[Any]] = clauses if _flatten_sub_clauses: - clauses = util.flatten_iterator(clauses) + clauses_iterator = util.flatten_iterator(clauses_iterator) + self._text_converter_role: Type[roles.SQLRole] = _literal_as_text_role text_converter_role: Type[roles.SQLRole] = _literal_as_text_role @@ -2523,31 +2525,35 @@ class ClauseList( coercions.expect( text_converter_role, clause, apply_propagate_attrs=self ).self_group(against=self.operator) - for clause in clauses + for clause in clauses_iterator ] else: self.clauses = [ coercions.expect( text_converter_role, clause, apply_propagate_attrs=self ) - for clause in clauses + for clause in clauses_iterator ] self._is_implicitly_boolean = operators.is_boolean(self.operator) @classmethod - def _construct_raw(cls, operator, clauses=None): + def _construct_raw( + cls, + operator: OperatorType, + clauses: Optional[Sequence[ColumnElement[Any]]] = None, + ) -> ClauseList: self = cls.__new__(cls) - self.clauses = clauses if clauses else [] + self.clauses = list(clauses) if clauses else [] self.group = True self.operator = operator self.group_contents = True self._is_implicitly_boolean = False return self - def __iter__(self): + def __iter__(self) -> Iterator[ColumnElement[Any]]: return iter(self.clauses) - def __len__(self): + def __len__(self) -> int: return len(self.clauses) @property @@ -2708,10 +2714,10 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): def _construct_raw( cls, operator: OperatorType, - clauses: Optional[List[ColumnElement[Any]]] = None, + clauses: Optional[Sequence[ColumnElement[Any]]] = None, ) -> BooleanClauseList: self = cls.__new__(cls) - self.clauses = clauses if clauses else [] + self.clauses = list(clauses) if clauses else [] self.group = True self.operator = operator self.group_contents = True @@ -2781,7 +2787,7 @@ class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]): sqltypes = util.preloaded.sql_sqltypes if types is None: - init_clauses = [ + init_clauses: List[ColumnElement[Any]] = [ coercions.expect(roles.ExpressionElementRole, c) for c in clauses ] @@ -2908,7 +2914,7 @@ class Case(ColumnElement[_T]): ] if whenlist: - type_ = list(whenlist[-1])[-1].type + type_ = whenlist[-1][-1].type else: type_ = None @@ -3098,6 +3104,8 @@ class _label_reference(ColumnElement[_T]): ("element", InternalTraversal.dp_clauseelement) ] + element: ColumnElement[_T] + def __init__(self, element: ColumnElement[_T]): self.element = element @@ -3212,7 +3220,9 @@ class UnaryExpression(ColumnElement[_T]): cls, expr: _ColumnExpressionArgument[_T], ) -> UnaryExpression[_T]: - col_expr = coercions.expect(roles.ExpressionElementRole, expr) + col_expr: ColumnElement[_T] = coercions.expect( + roles.ExpressionElementRole, expr + ) return UnaryExpression( col_expr, operator=operators.distinct_op, @@ -3265,7 +3275,7 @@ class CollectionAggregate(UnaryExpression[_T]): def _create_any( cls, expr: _ColumnExpressionArgument[_T] ) -> CollectionAggregate[bool]: - col_expr = coercions.expect( + col_expr: ColumnElement[_T] = coercions.expect( roles.ExpressionElementRole, expr, ) @@ -3281,7 +3291,7 @@ class CollectionAggregate(UnaryExpression[_T]): def _create_all( cls, expr: _ColumnExpressionArgument[_T] ) -> CollectionAggregate[bool]: - col_expr = coercions.expect( + col_expr: ColumnElement[_T] = coercions.expect( roles.ExpressionElementRole, expr, ) @@ -3374,6 +3384,9 @@ class BinaryExpression(ColumnElement[_T]): modifiers: Optional[Mapping[str, Any]] + left: ColumnElement[Any] + right: Union[ColumnElement[Any], ClauseList] + def __init__( self, left: ColumnElement[Any], @@ -4147,7 +4160,13 @@ class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]): def foreign_keys(self): return self.element.foreign_keys - def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw): + def _copy_internals( + self, + *, + clone: _CloneCallableType = _clone, + anonymize_labels: bool = False, + **kw: Any, + ) -> None: self._reset_memoizations() self._element = clone(self._element, **kw) if anonymize_labels: @@ -4447,7 +4466,9 @@ class TableValuedColumn(NamedColumn[_T]): self.key = self.name = scalar_alias.name self.type = type_ - def _copy_internals(self, clone=_clone, **kw): + def _copy_internals( + self, clone: _CloneCallableType = _clone, **kw: Any + ) -> None: self.scalar_alias = clone(self.scalar_alias, **kw) self.key = self.name = self.scalar_alias.name @@ -4467,7 +4488,7 @@ class CollationClause(ColumnElement[str]): def _create_collation_expression( cls, expression: _ColumnExpressionArgument[str], collation: str ) -> BinaryExpression[str]: - expr = coercions.expect(roles.ExpressionElementRole, expression) + expr = coercions.expect(roles.ExpressionElementRole[str], expression) return BinaryExpression( expr, CollationClause(collation), diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 3bca8b502f..db4bb58373 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -11,10 +11,15 @@ from __future__ import annotations +import datetime from typing import Any +from typing import cast +from typing import Dict +from typing import Mapping from typing import Optional from typing import overload -from typing import Sequence +from typing import Tuple +from typing import Type from typing import TYPE_CHECKING from typing import TypeVar @@ -24,7 +29,9 @@ from . import operators from . import roles from . import schema from . import sqltypes +from . import type_api from . import util as sqlutil +from ._typing import is_table_value_type from .base import _entity_namespace from .base import ColumnCollection from .base import Executable @@ -46,16 +53,21 @@ from .elements import WithinGroup from .selectable import FromClause from .selectable import Select from .selectable import TableValuedAlias +from .sqltypes import _N +from .sqltypes import TableValueType from .type_api import TypeEngine from .visitors import InternalTraversal from .. import util + if TYPE_CHECKING: from ._typing import _TypeEngineArgument _T = TypeVar("_T", bound=Any) -_registry = util.defaultdict(dict) +_registry: util.defaultdict[ + str, Dict[str, Type[Function[Any]]] +] = util.defaultdict(dict) def register_function(identifier, fn, package="_default"): @@ -103,11 +115,18 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): ("_table_value_type", InternalTraversal.dp_has_cache_key), ] - packagenames = () + packagenames: Tuple[str, ...] = () _has_args = False _with_ordinality = False - _table_value_type = None + _table_value_type: Optional[TableValueType] = None + + # some attributes that are defined between both ColumnElement and + # FromClause are set to Any here to avoid typing errors + primary_key: Any + _is_clone_of: Any + + clause_expr: Grouping[Any] def __init__(self, *clauses: Any): r"""Construct a :class:`.FunctionElement`. @@ -135,9 +154,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): for c in clauses ] self._has_args = self._has_args or bool(args) - self.clause_expr = ClauseList( - operator=operators.comma_op, group_contents=True, *args - ).self_group() + self.clause_expr = Grouping( + ClauseList(operator=operators.comma_op, group_contents=True, *args) + ) _non_anon_label = None @@ -263,9 +282,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): expr += (with_ordinality,) new_func._with_ordinality = True - new_func.type = new_func._table_value_type = sqltypes.TableValueType( - *expr - ) + new_func.type = new_func._table_value_type = TableValueType(*expr) return new_func.alias(name=name, joins_implicitly=joins_implicitly) @@ -332,7 +349,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): @property def _all_selected_columns(self): - if self.type._is_table_value: + if is_table_value_type(self.type): cols = self.type._elements else: cols = [self.label(None)] @@ -344,12 +361,12 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): return self.columns @HasMemoized.memoized_attribute - def clauses(self): + def clauses(self) -> ClauseList: """Return the underlying :class:`.ClauseList` which contains the arguments for this :class:`.FunctionElement`. """ - return self.clause_expr.element + return cast(ClauseList, self.clause_expr.element) def over(self, partition_by=None, order_by=None, rows=None, range_=None): """Produce an OVER clause against this function. @@ -647,7 +664,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): return _entity_namespace(self.clause_expr) -class FunctionAsBinary(BinaryExpression): +class FunctionAsBinary(BinaryExpression[Any]): _traverse_internals = [ ("sql_function", InternalTraversal.dp_clauseelement), ("left_index", InternalTraversal.dp_plain_obj), @@ -655,10 +672,16 @@ class FunctionAsBinary(BinaryExpression): ("modifiers", InternalTraversal.dp_plain_dict), ] + sql_function: FunctionElement[Any] + left_index: int + right_index: int + def _gen_cache_key(self, anon_map, bindparams): return ColumnElement._gen_cache_key(self, anon_map, bindparams) - def __init__(self, fn, left_index, right_index): + def __init__( + self, fn: FunctionElement[Any], left_index: int, right_index: int + ): self.sql_function = fn self.left_index = left_index self.right_index = right_index @@ -670,23 +693,30 @@ class FunctionAsBinary(BinaryExpression): self.modifiers = {} @property - def left(self): + def left_expr(self) -> ColumnElement[Any]: return self.sql_function.clauses.clauses[self.left_index - 1] - @left.setter - def left(self, value): + @left_expr.setter + def left_expr(self, value: ColumnElement[Any]) -> None: self.sql_function.clauses.clauses[self.left_index - 1] = value @property - def right(self): + def right_expr(self) -> ColumnElement[Any]: return self.sql_function.clauses.clauses[self.right_index - 1] - @right.setter - def right(self, value): + @right_expr.setter + def right_expr(self, value: ColumnElement[Any]) -> None: self.sql_function.clauses.clauses[self.right_index - 1] = value + if not TYPE_CHECKING: + # mypy can't accommodate @property to replace an instance + # variable + + left = left_expr + right = right_expr + -class ScalarFunctionColumn(NamedColumn): +class ScalarFunctionColumn(NamedColumn[_T]): __visit_name__ = "scalar_function_column" _traverse_internals = [ @@ -698,10 +728,18 @@ class ScalarFunctionColumn(NamedColumn): is_literal = False table = None - def __init__(self, fn, name, type_=None): + def __init__( + self, + fn: FunctionElement[_T], + name: str, + type_: Optional[_TypeEngineArgument[_T]] = None, + ): self.fn = fn self.name = name - self.type = sqltypes.to_instance(type_) + + # if type is None, we get NULLTYPE, which is our _T. But I don't + # know how to get the overloads to express that correctly + self.type = type_api.to_instance(type_) # type: ignore class _FunctionGenerator: @@ -789,7 +827,7 @@ class _FunctionGenerator: # passthru __ attributes; fixes pydoc if name.startswith("__"): try: - return self.__dict__[name] + return self.__dict__[name] # type: ignore except KeyError: raise AttributeError(name) @@ -883,8 +921,6 @@ class Function(FunctionElement[_T]): identifier: str - packagenames: Sequence[str] - type: TypeEngine[_T] """A :class:`_types.TypeEngine` object which refers to the SQL return type represented by this SQL function. @@ -907,7 +943,7 @@ class Function(FunctionElement[_T]): name: str, *clauses: Any, type_: Optional[_TypeEngineArgument[_T]] = None, - packagenames: Optional[Sequence[str]] = None, + packagenames: Optional[Tuple[str, ...]] = None, ): """Construct a :class:`.Function`. @@ -918,7 +954,9 @@ class Function(FunctionElement[_T]): self.packagenames = packagenames or () self.name = name - self.type = sqltypes.to_instance(type_) + # if type is None, we get NULLTYPE, which is our _T. But I don't + # know how to get the overloads to express that correctly + self.type = type_api.to_instance(type_) # type: ignore FunctionElement.__init__(self, *clauses) @@ -934,7 +972,7 @@ class Function(FunctionElement[_T]): ) -class GenericFunction(Function): +class GenericFunction(Function[_T]): """Define a 'generic' function. A generic function is a pre-established :class:`.Function` @@ -957,7 +995,7 @@ class GenericFunction(Function): from sqlalchemy.types import DateTime class as_utc(GenericFunction): - type = DateTime + type = DateTime() inherit_cache = True print(select(func.as_utc())) @@ -971,7 +1009,7 @@ class GenericFunction(Function): "time":: class as_utc(GenericFunction): - type = DateTime + type = DateTime() package = "time" inherit_cache = True @@ -987,7 +1025,7 @@ class GenericFunction(Function): the usage of ``name`` as the rendered name:: class GeoBuffer(GenericFunction): - type = Geometry + type = Geometry() package = "geo" name = "ST_Buffer" identifier = "buffer" @@ -1006,7 +1044,7 @@ class GenericFunction(Function): from sqlalchemy.sql import quoted_name class GeoBuffer(GenericFunction): - type = Geometry + type = Geometry() package = "geo" name = quoted_name("ST_Buffer", True) identifier = "buffer" @@ -1028,6 +1066,8 @@ class GenericFunction(Function): coerce_arguments = True inherit_cache = True + _register: bool + name = "GenericFunction" def __init_subclass__(cls) -> None: @@ -1036,7 +1076,9 @@ class GenericFunction(Function): super().__init_subclass__() @classmethod - def _register_generic_function(cls, clsname, clsdict): + def _register_generic_function( + cls, clsname: str, clsdict: Mapping[str, Any] + ) -> None: cls.name = name = clsdict.get("name", clsname) cls.identifier = identifier = clsdict.get("identifier", name) package = clsdict.get("package", "_default") @@ -1068,11 +1110,14 @@ class GenericFunction(Function): ] self._has_args = self._has_args or bool(parsed_args) self.packagenames = () - self.clause_expr = ClauseList( - operator=operators.comma_op, group_contents=True, *parsed_args - ).self_group() - self.type = sqltypes.to_instance( + self.clause_expr = Grouping( + ClauseList( + operator=operators.comma_op, group_contents=True, *parsed_args + ) + ) + + self.type = type_api.to_instance( # type: ignore kwargs.pop("type_", None) or getattr(self, "type", None) ) @@ -1081,7 +1126,7 @@ register_function("cast", Cast) register_function("extract", Extract) -class next_value(GenericFunction): +class next_value(GenericFunction[int]): """Represent the 'next value', given a :class:`.Sequence` as its single argument. @@ -1103,7 +1148,7 @@ class next_value(GenericFunction): seq, schema.Sequence ), "next_value() accepts a Sequence object as input." self.sequence = seq - self.type = sqltypes.to_instance( + self.type = sqltypes.to_instance( # type: ignore seq.data_type or getattr(self, "type", None) ) @@ -1118,7 +1163,7 @@ class next_value(GenericFunction): return [] -class AnsiFunction(GenericFunction): +class AnsiFunction(GenericFunction[_T]): """Define a function in "ansi" format, which doesn't render parenthesis.""" inherit_cache = True @@ -1127,13 +1172,13 @@ class AnsiFunction(GenericFunction): GenericFunction.__init__(self, *args, **kwargs) -class ReturnTypeFromArgs(GenericFunction): +class ReturnTypeFromArgs(GenericFunction[_T]): """Define a function whose return type is the same as its arguments.""" inherit_cache = True def __init__(self, *args, **kwargs): - args = [ + fn_args = [ coercions.expect( roles.ExpressionElementRole, c, @@ -1142,35 +1187,35 @@ class ReturnTypeFromArgs(GenericFunction): ) for c in args ] - kwargs.setdefault("type_", _type_from_args(args)) - kwargs["_parsed_args"] = args - super(ReturnTypeFromArgs, self).__init__(*args, **kwargs) + kwargs.setdefault("type_", _type_from_args(fn_args)) + kwargs["_parsed_args"] = fn_args + super(ReturnTypeFromArgs, self).__init__(*fn_args, **kwargs) -class coalesce(ReturnTypeFromArgs): +class coalesce(ReturnTypeFromArgs[_T]): _has_args = True inherit_cache = True -class max(ReturnTypeFromArgs): # noqa A001 +class max(ReturnTypeFromArgs[_T]): # noqa A001 """The SQL MAX() aggregate function.""" inherit_cache = True -class min(ReturnTypeFromArgs): # noqa A001 +class min(ReturnTypeFromArgs[_T]): # noqa A001 """The SQL MIN() aggregate function.""" inherit_cache = True -class sum(ReturnTypeFromArgs): # noqa A001 +class sum(ReturnTypeFromArgs[_T]): # noqa A001 """The SQL SUM() aggregate function.""" inherit_cache = True -class now(GenericFunction): +class now(GenericFunction[datetime.datetime]): """The SQL now() datetime function. SQLAlchemy dialects will usually render this particular function @@ -1178,11 +1223,11 @@ class now(GenericFunction): """ - type = sqltypes.DateTime + type = sqltypes.DateTime() inherit_cache = True -class concat(GenericFunction): +class concat(GenericFunction[str]): """The SQL CONCAT() function, which concatenates strings. E.g.:: @@ -1200,28 +1245,30 @@ class concat(GenericFunction): """ - type = sqltypes.String + type = sqltypes.String() inherit_cache = True -class char_length(GenericFunction): +class char_length(GenericFunction[int]): """The CHAR_LENGTH() SQL function.""" - type = sqltypes.Integer + type = sqltypes.Integer() inherit_cache = True - def __init__(self, arg, **kwargs): - GenericFunction.__init__(self, arg, **kwargs) + def __init__(self, arg, **kw): + # slight hack to limit to just one positional argument + # not sure why this one function has this special treatment + super().__init__(arg, **kw) -class random(GenericFunction): +class random(GenericFunction[float]): """The RANDOM() SQL function.""" _has_args = True inherit_cache = True -class count(GenericFunction): +class count(GenericFunction[int]): r"""The ANSI COUNT aggregate function. With no arguments, emits COUNT \*. @@ -1242,7 +1289,7 @@ class count(GenericFunction): """ - type = sqltypes.Integer + type = sqltypes.Integer() inherit_cache = True def __init__(self, expression=None, **kwargs): @@ -1251,70 +1298,70 @@ class count(GenericFunction): super(count, self).__init__(expression, **kwargs) -class current_date(AnsiFunction): +class current_date(AnsiFunction[datetime.date]): """The CURRENT_DATE() SQL function.""" - type = sqltypes.Date + type = sqltypes.Date() inherit_cache = True -class current_time(AnsiFunction): +class current_time(AnsiFunction[datetime.time]): """The CURRENT_TIME() SQL function.""" - type = sqltypes.Time + type = sqltypes.Time() inherit_cache = True -class current_timestamp(AnsiFunction): +class current_timestamp(AnsiFunction[datetime.datetime]): """The CURRENT_TIMESTAMP() SQL function.""" - type = sqltypes.DateTime + type = sqltypes.DateTime() inherit_cache = True -class current_user(AnsiFunction): +class current_user(AnsiFunction[str]): """The CURRENT_USER() SQL function.""" - type = sqltypes.String + type = sqltypes.String() inherit_cache = True -class localtime(AnsiFunction): +class localtime(AnsiFunction[datetime.datetime]): """The localtime() SQL function.""" - type = sqltypes.DateTime + type = sqltypes.DateTime() inherit_cache = True -class localtimestamp(AnsiFunction): +class localtimestamp(AnsiFunction[datetime.datetime]): """The localtimestamp() SQL function.""" - type = sqltypes.DateTime + type = sqltypes.DateTime() inherit_cache = True -class session_user(AnsiFunction): +class session_user(AnsiFunction[str]): """The SESSION_USER() SQL function.""" - type = sqltypes.String + type = sqltypes.String() inherit_cache = True -class sysdate(AnsiFunction): +class sysdate(AnsiFunction[datetime.datetime]): """The SYSDATE() SQL function.""" - type = sqltypes.DateTime + type = sqltypes.DateTime() inherit_cache = True -class user(AnsiFunction): +class user(AnsiFunction[str]): """The USER() SQL function.""" - type = sqltypes.String + type = sqltypes.String() inherit_cache = True -class array_agg(GenericFunction): +class array_agg(GenericFunction[_T]): """Support for the ARRAY_AGG function. The ``func.array_agg(expr)`` construct returns an expression of @@ -1334,11 +1381,10 @@ class array_agg(GenericFunction): """ - type = sqltypes.ARRAY inherit_cache = True def __init__(self, *args, **kwargs): - args = [ + fn_args = [ coercions.expect( roles.ExpressionElementRole, c, apply_propagate_attrs=self ) @@ -1348,16 +1394,16 @@ class array_agg(GenericFunction): default_array_type = kwargs.pop("_default_array_type", sqltypes.ARRAY) if "type_" not in kwargs: - type_from_args = _type_from_args(args) + type_from_args = _type_from_args(fn_args) if isinstance(type_from_args, sqltypes.ARRAY): kwargs["type_"] = type_from_args else: kwargs["type_"] = default_array_type(type_from_args) - kwargs["_parsed_args"] = args - super(array_agg, self).__init__(*args, **kwargs) + kwargs["_parsed_args"] = fn_args + super(array_agg, self).__init__(*fn_args, **kwargs) -class OrderedSetAgg(GenericFunction): +class OrderedSetAgg(GenericFunction[_T]): """Define a function where the return type is based on the sort expression type as defined by the expression passed to the :meth:`.FunctionElement.within_group` method.""" @@ -1366,7 +1412,7 @@ class OrderedSetAgg(GenericFunction): inherit_cache = True def within_group_type(self, within_group): - func_clauses = self.clause_expr.element + func_clauses = cast(ClauseList, self.clause_expr.element) order_by = sqlutil.unwrap_order_by(within_group.order_by) if self.array_for_multi_clause and len(func_clauses.clauses) > 1: return sqltypes.ARRAY(order_by[0].type) @@ -1374,7 +1420,7 @@ class OrderedSetAgg(GenericFunction): return order_by[0].type -class mode(OrderedSetAgg): +class mode(OrderedSetAgg[_T]): """Implement the ``mode`` ordered-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1389,7 +1435,7 @@ class mode(OrderedSetAgg): inherit_cache = True -class percentile_cont(OrderedSetAgg): +class percentile_cont(OrderedSetAgg[_T]): """Implement the ``percentile_cont`` ordered-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1407,7 +1453,7 @@ class percentile_cont(OrderedSetAgg): inherit_cache = True -class percentile_disc(OrderedSetAgg): +class percentile_disc(OrderedSetAgg[_T]): """Implement the ``percentile_disc`` ordered-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1425,7 +1471,7 @@ class percentile_disc(OrderedSetAgg): inherit_cache = True -class rank(GenericFunction): +class rank(GenericFunction[int]): """Implement the ``rank`` hypothetical-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1441,7 +1487,7 @@ class rank(GenericFunction): inherit_cache = True -class dense_rank(GenericFunction): +class dense_rank(GenericFunction[int]): """Implement the ``dense_rank`` hypothetical-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1457,7 +1503,7 @@ class dense_rank(GenericFunction): inherit_cache = True -class percent_rank(GenericFunction): +class percent_rank(GenericFunction[_N]): """Implement the ``percent_rank`` hypothetical-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1469,11 +1515,11 @@ class percent_rank(GenericFunction): """ - type = sqltypes.Numeric() + type: sqltypes.Numeric[_N] = sqltypes.Numeric() inherit_cache = True -class cume_dist(GenericFunction): +class cume_dist(GenericFunction[_N]): """Implement the ``cume_dist`` hypothetical-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1485,11 +1531,11 @@ class cume_dist(GenericFunction): """ - type = sqltypes.Numeric() + type: sqltypes.Numeric[_N] = sqltypes.Numeric() inherit_cache = True -class cube(GenericFunction): +class cube(GenericFunction[_T]): r"""Implement the ``CUBE`` grouping operation. This function is used as part of the GROUP BY of a statement, @@ -1506,7 +1552,7 @@ class cube(GenericFunction): inherit_cache = True -class rollup(GenericFunction): +class rollup(GenericFunction[_T]): r"""Implement the ``ROLLUP`` grouping operation. This function is used as part of the GROUP BY of a statement, @@ -1523,7 +1569,7 @@ class rollup(GenericFunction): inherit_cache = True -class grouping_sets(GenericFunction): +class grouping_sets(GenericFunction[_T]): r"""Implement the ``GROUPING SETS`` grouping operation. This function is used as part of the GROUP BY of a statement, diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index 9d011ef539..da15c305fc 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -12,6 +12,18 @@ import inspect import itertools import operator import types +from types import CodeType +from typing import Any +from typing import Callable +from typing import cast +from typing import Iterable +from typing import List +from typing import MutableMapping +from typing import Optional +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import Union import weakref from . import cache_key as _cache_key @@ -19,37 +31,62 @@ from . import coercions from . import elements from . import roles from . import schema -from . import traversals from . import type_api from . import visitors from .base import _clone +from .base import Executable from .base import Options +from .cache_key import CacheConst from .operators import ColumnOperators from .. import exc from .. import inspection from .. import util +from ..util.typing import Literal +from ..util.typing import Protocol +from ..util.typing import Self -_closure_per_cache_key = util.LRUCache(1000) +if TYPE_CHECKING: + from .cache_key import CacheConst + from .cache_key import NO_CACHE + from .elements import BindParameter + from .elements import ClauseElement + from .roles import SQLRole + from .visitors import _CloneCallableType + +_LambdaCacheType = MutableMapping[ + Tuple[Any, ...], Union["NonAnalyzedFunction", "AnalyzedFunction"] +] +_BoundParameterGetter = Callable[..., Any] + +_closure_per_cache_key: _LambdaCacheType = util.LRUCache(1000) + + +class _LambdaType(Protocol): + __code__: CodeType + __closure__: Iterable[Tuple[Any, Any]] + + def __call__(self, *arg: Any, **kw: Any) -> ClauseElement: + ... class LambdaOptions(Options): enable_tracking = True track_closure_variables = True - track_on = None + track_on: Optional[object] = None global_track_bound_values = True track_bound_values = True - lambda_cache = None + lambda_cache: Optional[_LambdaCacheType] = None def lambda_stmt( - lmb, - enable_tracking=True, - track_closure_variables=True, - track_on=None, - global_track_bound_values=True, - track_bound_values=True, - lambda_cache=None, -): + lmb: _LambdaType, + enable_tracking: bool = True, + track_closure_variables: bool = True, + track_on: Optional[object] = None, + global_track_bound_values: bool = True, + track_bound_values: bool = True, + lambda_cache: Optional[_LambdaCacheType] = None, +) -> StatementLambdaElement: """Produce a SQL statement that is cached as a lambda. The Python code object within the lambda is scanned for both Python @@ -142,15 +179,28 @@ class LambdaElement(elements.ClauseElement): ("_resolved", visitors.InternalTraversal.dp_clauseelement) ] - _transforms = () + _transforms: Tuple[_CloneCallableType, ...] = () - parent_lambda = None + _resolved_bindparams: List[BindParameter[Any]] + parent_lambda: Optional[StatementLambdaElement] = None + closure_cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]] + role: Type[SQLRole] + _rec: Union[AnalyzedFunction, NonAnalyzedFunction] + fn: _LambdaType + tracker_key: Tuple[CodeType, ...] def __repr__(self): - return "%s(%r)" % (self.__class__.__name__, self.fn.__code__) + return "%s(%r)" % ( + self.__class__.__name__, + self.fn.__code__, + ) def __init__( - self, fn, role, opts=LambdaOptions, apply_propagate_attrs=None + self, + fn: _LambdaType, + role: Type[SQLRole], + opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions, + apply_propagate_attrs: Optional[ClauseElement] = None, ): self.fn = fn self.role = role @@ -182,6 +232,7 @@ class LambdaElement(elements.ClauseElement): opts, ) + bindparams: List[BindParameter[Any]] self._resolved_bindparams = bindparams = [] if self.parent_lambda is not None: @@ -189,8 +240,10 @@ class LambdaElement(elements.ClauseElement): else: parent_closure_cache_key = () + cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]] + if parent_closure_cache_key is not _cache_key.NO_CACHE: - anon_map = traversals.anon_map() + anon_map = visitors.anon_map() cache_key = tuple( [ getter(closure, opts, anon_map, bindparams) @@ -241,7 +294,7 @@ class LambdaElement(elements.ClauseElement): if self.parent_lambda is not None: bindparams[:0] = self.parent_lambda._resolved_bindparams - lambda_element = self + lambda_element: Optional[LambdaElement] = self while lambda_element is not None: rec = lambda_element._rec if rec.bindparam_trackers: @@ -289,17 +342,21 @@ class LambdaElement(elements.ClauseElement): def _setup_binds_for_tracked_expr(self, expr): bindparam_lookup = {b.key: b for b in self._resolved_bindparams} - def replace(thing): - if isinstance(thing, elements.BindParameter): + def replace( + element: Optional[visitors.ExternallyTraversible], **kw: Any + ) -> Optional[visitors.ExternallyTraversible]: + if isinstance(element, elements.BindParameter): - if thing.key in bindparam_lookup: - bind = bindparam_lookup[thing.key] - if thing.expanding: + if element.key in bindparam_lookup: + bind = bindparam_lookup[element.key] + if element.expanding: bind.expanding = True - bind.expand_op = thing.expand_op - bind.type = thing.type + bind.expand_op = element.expand_op + bind.type = element.type return bind + return None + if self._rec.is_sequence: expr = [ visitors.replacement_traverse(sub_expr, {}, replace) @@ -311,8 +368,11 @@ class LambdaElement(elements.ClauseElement): return expr def _copy_internals( - self, clone=_clone, deferred_copy_internals=None, **kw - ): + self: Self, + clone: _CloneCallableType = _clone, + deferred_copy_internals: Optional[_CloneCallableType] = None, + **kw: Any, + ) -> None: # TODO: this needs A LOT of tests self._resolved = clone( self._resolved, @@ -340,9 +400,15 @@ class LambdaElement(elements.ClauseElement): ) + self.closure_cache_key parent = self.parent_lambda + while parent is not None: + assert parent.closure_cache_key is not CacheConst.NO_CACHE + parent_closure_cache_key: Tuple[ + Any, ... + ] = parent.closure_cache_key + cache_key = ( - (parent.fn.__code__,) + parent.closure_cache_key + cache_key + (parent.fn.__code__,) + parent_closure_cache_key + cache_key ) parent = parent.parent_lambda @@ -351,7 +417,7 @@ class LambdaElement(elements.ClauseElement): bindparams.extend(self._resolved_bindparams) return cache_key - def _invoke_user_fn(self, fn, *arg): + def _invoke_user_fn(self, fn: _LambdaType, *arg: Any) -> ClauseElement: return fn() @@ -365,7 +431,13 @@ class DeferredLambdaElement(LambdaElement): """ - def __init__(self, fn, role, opts=LambdaOptions, lambda_args=()): + def __init__( + self, + fn: _LambdaType, + role: Type[roles.SQLRole], + opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions, + lambda_args: Tuple[Any, ...] = (), + ): self.lambda_args = lambda_args super(DeferredLambdaElement, self).__init__(fn, role, opts) @@ -373,6 +445,7 @@ class DeferredLambdaElement(LambdaElement): return fn(*self.lambda_args) def _resolve_with_args(self, *lambda_args): + assert isinstance(self._rec, AnalyzedFunction) tracker_fn = self._rec.tracker_instrumented_fn expr = tracker_fn(*lambda_args) @@ -506,6 +579,8 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): def _execute_on_connection( self, connection, distilled_params, execution_options ): + if TYPE_CHECKING: + assert isinstance(self._rec.expected_expr, ClauseElement) if self._rec.expected_expr.supports_execution: return connection._execute_clauseelement( self, distilled_params, execution_options @@ -515,14 +590,20 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): @property def _with_options(self): + if TYPE_CHECKING: + assert isinstance(self._rec.expected_expr, Executable) return self._rec.expected_expr._with_options @property def _effective_plugin_target(self): + if TYPE_CHECKING: + assert isinstance(self._rec.expected_expr, Executable) return self._rec.expected_expr._effective_plugin_target @property def _execution_options(self): + if TYPE_CHECKING: + assert isinstance(self._rec.expected_expr, Executable) return self._rec.expected_expr._execution_options def spoil(self): @@ -583,9 +664,14 @@ class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement): class LinkedLambdaElement(StatementLambdaElement): """Represent subsequent links of a :class:`.StatementLambdaElement`.""" - role = None + parent_lambda: StatementLambdaElement - def __init__(self, fn, parent_lambda, opts): + def __init__( + self, + fn: _LambdaType, + parent_lambda: StatementLambdaElement, + opts: Union[Type[LambdaOptions], LambdaOptions], + ): self.opts = opts self.fn = fn self.parent_lambda = parent_lambda @@ -606,7 +692,9 @@ class AnalyzedCode: "closure_trackers", "build_py_wrappers", ) - _fns = weakref.WeakKeyDictionary() + _fns: weakref.WeakKeyDictionary[ + CodeType, AnalyzedCode + ] = weakref.WeakKeyDictionary() @classmethod def get(cls, fn, lambda_element, lambda_kw, **kw): @@ -615,6 +703,8 @@ class AnalyzedCode: return cls._fns[fn.__code__] except KeyError: pass + + analyzed: AnalyzedCode cls._fns[fn.__code__] = analyzed = AnalyzedCode( fn, lambda_element, lambda_kw, **kw ) @@ -947,14 +1037,18 @@ class AnalyzedCode: class NonAnalyzedFunction: __slots__ = ("expr",) - closure_bindparams = None - bindparam_trackers = None + closure_bindparams: Optional[List[BindParameter[Any]]] = None + bindparam_trackers: Optional[List[_BoundParameterGetter]] = None + + is_sequence = False + + expr: ClauseElement - def __init__(self, expr): + def __init__(self, expr: ClauseElement): self.expr = expr @property - def expected_expr(self): + def expected_expr(self) -> ClauseElement: return self.expr @@ -972,6 +1066,10 @@ class AnalyzedFunction: "closure_bindparams", ) + closure_bindparams: Optional[List[BindParameter[Any]]] + expected_expr: Union[ClauseElement, List[ClauseElement]] + bindparam_trackers: Optional[List[_BoundParameterGetter]] + def __init__( self, analyzed_code, @@ -1071,19 +1169,25 @@ class AnalyzedFunction: if parent_lambda is None: if isinstance(expr, collections_abc.Sequence): self.expected_expr = [ - coercions.expect( - lambda_element.role, - sub_expr, - apply_propagate_attrs=apply_propagate_attrs, + cast( + "ClauseElement", + coercions.expect( + lambda_element.role, + sub_expr, + apply_propagate_attrs=apply_propagate_attrs, + ), ) for sub_expr in expr ] self.is_sequence = True else: - self.expected_expr = coercions.expect( - lambda_element.role, - expr, - apply_propagate_attrs=apply_propagate_attrs, + self.expected_expr = cast( + "ClauseElement", + coercions.expect( + lambda_element.role, + expr, + apply_propagate_attrs=apply_propagate_attrs, + ), ) self.is_sequence = False else: diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 86725f86f5..577d868fdb 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -19,7 +19,6 @@ from ..util.typing import Literal if TYPE_CHECKING: from ._typing import _PropagateAttrsType - from ._typing import _SelectIterable from .base import _EntityNamespace from .base import ColumnCollection from .base import ReadOnlyColumnCollection @@ -28,6 +27,7 @@ if TYPE_CHECKING: from .elements import ColumnElement from .elements import Label from .elements import NamedColumn + from .selectable import _SelectIterable from .selectable import FromClause from .selectable import Subquery @@ -164,6 +164,12 @@ class WhereHavingRole(OnClauseRole): class ExpressionElementRole(Generic[_T], SQLRole): + # note when using generics for ExpressionElementRole, + # the generic type needs to be in + # sqlalchemy.sql.coercions._impl_lookup mapping also. + # these are set up for basic types like int, bool, str, float + # right now + __slots__ = () _role_name = "SQL expression element" diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index cbd0c77f45..883439ca53 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -36,6 +36,7 @@ import operator import typing from typing import Any from typing import Callable +from typing import cast from typing import Dict from typing import Iterator from typing import List @@ -68,6 +69,7 @@ from .elements import SQLCoreOperations from .elements import TextClause from .selectable import TableClause from .type_api import to_instance +from .visitors import ExternallyTraversible from .visitors import InternalTraversal from .. import event from .. import exc @@ -131,21 +133,33 @@ def _get_table_key(name: str, schema: Optional[str]) -> str: # this should really be in sql/util.py but we'd have to # break an import cycle -def _copy_expression(expression, source_table, target_table): +def _copy_expression( + expression: ColumnElement[Any], + source_table: Optional[Table], + target_table: Optional[Table], +) -> ColumnElement[Any]: if source_table is None or target_table is None: return expression - def replace(col): + fixed_source_table = source_table + fixed_target_table = target_table + + def replace( + element: ExternallyTraversible, **kw: Any + ) -> Optional[ExternallyTraversible]: if ( - isinstance(col, Column) - and col.table is source_table - and col.key in source_table.c + isinstance(element, Column) + and element.table is fixed_source_table + and element.key in fixed_source_table.c ): - return target_table.c[col.key] + return fixed_target_table.c[element.key] else: return None - return visitors.replacement_traverse(expression, {}, replace) + return cast( + ColumnElement[Any], + visitors.replacement_traverse(expression, {}, replace), + ) @inspection._self_inspects @@ -911,8 +925,8 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): def _reset_exported(self): pass - @property - def _autoincrement_column(self): + @util.ro_non_memoized_property + def _autoincrement_column(self) -> Optional[Column[Any]]: return self.primary_key._autoincrement_column @property @@ -2308,6 +2322,8 @@ class ForeignKey(DialectKWArgs, SchemaItem): parent: Column[Any] + _table_column: Optional[Column[Any]] + def __init__( self, column: Union[str, Column[Any], SQLCoreOperations[Any]], @@ -4290,11 +4306,11 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): self._columns.extend(columns) - PrimaryKeyConstraint._autoincrement_column._reset(self) + PrimaryKeyConstraint._autoincrement_column._reset(self) # type: ignore self._set_parent_with_dispatch(self.table) def _replace(self, col): - PrimaryKeyConstraint._autoincrement_column._reset(self) + PrimaryKeyConstraint._autoincrement_column._reset(self) # type: ignore self._columns.replace(col) self.dispatch._sa_event_column_added_to_pk_constraint(self, col) @@ -4308,8 +4324,8 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): else: return list(self._columns) - @util.memoized_property - def _autoincrement_column(self): + @util.ro_memoized_property + def _autoincrement_column(self) -> Optional[Column[Any]]: def _validate_autoinc(col, autoinc_true): if col.type._type_affinity is None or not issubclass( col.type._type_affinity, type_api.INTEGERTYPE._type_affinity @@ -4350,6 +4366,8 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): "ignore_fk", ) and _validate_autoinc(col, False): return col + else: + return None else: autoinc = None diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 4f6e3795e1..6504449f1d 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -17,16 +17,26 @@ import collections from enum import Enum import itertools import typing +from typing import AbstractSet from typing import Any as TODO_Any from typing import Any +from typing import Callable +from typing import cast +from typing import Dict from typing import Iterable +from typing import Iterator from typing import List from typing import NamedTuple +from typing import NoReturn from typing import Optional +from typing import overload from typing import Sequence +from typing import Set from typing import Tuple +from typing import Type from typing import TYPE_CHECKING from typing import TypeVar +from typing import Union from . import cache_key from . import coercions @@ -37,6 +47,9 @@ from . import type_api from . import visitors from ._typing import _ColumnsClauseArgument from ._typing import is_column_element +from ._typing import is_select_statement +from ._typing import is_subquery +from ._typing import is_table from .annotation import Annotated from .annotation import SupportsCloneAnnotations from .base import _clone @@ -68,32 +81,80 @@ from .elements import ColumnClause from .elements import ColumnElement from .elements import DQLDMLClauseElement from .elements import GroupedElement -from .elements import Grouping from .elements import literal_column from .elements import TableValuedColumn from .elements import UnaryExpression +from .operators import OperatorType +from .visitors import _TraverseInternalsType from .visitors import InternalTraversal from .visitors import prefix_anon_map from .. import exc from .. import util +from ..util import HasMemoized_ro_memoized_attribute +from ..util.typing import Literal +from ..util.typing import Protocol +from ..util.typing import Self and_ = BooleanClauseList.and_ _T = TypeVar("_T", bound=Any) if TYPE_CHECKING: - from ._typing import _SelectIterable + from ._typing import _ColumnExpressionArgument + from ._typing import _FromClauseArgument + from ._typing import _JoinTargetArgument + from ._typing import _OnClauseArgument + from ._typing import _SelectStatementForCompoundArgument + from ._typing import _TextCoercedExpressionArgument + from ._typing import _TypeEngineArgument + from .base import _AmbiguousTableNameMap + from .base import ExecutableOption from .base import ReadOnlyColumnCollection + from .cache_key import _CacheKeyTraversalType + from .compiler import SQLCompiler + from .dml import Delete + from .dml import Insert + from .dml import Update from .elements import NamedColumn + from .elements import TextClause + from .functions import Function + from .schema import Column from .schema import ForeignKey - from .schema import PrimaryKeyConstraint + from .schema import ForeignKeyConstraint + from .type_api import TypeEngine + from .util import ClauseAdapter + from .visitors import _CloneCallableType -class _OffsetLimitParam(BindParameter): +_ColumnsClauseElement = Union["FromClause", ColumnElement[Any], "TextClause"] + + +class _JoinTargetProtocol(Protocol): + @util.ro_non_memoized_property + def _from_objects(self) -> List[FromClause]: + ... + + +_JoinTargetElement = Union["FromClause", _JoinTargetProtocol] +_OnClauseElement = Union["ColumnElement[bool]", _JoinTargetProtocol] + + +_SetupJoinsElement = Tuple[ + _JoinTargetElement, + Optional[_OnClauseElement], + Optional["FromClause"], + Dict[str, Any], +] + + +_SelectIterable = Iterable[Union["ColumnElement[Any]", "TextClause"]] + + +class _OffsetLimitParam(BindParameter[int]): inherit_cache = True @property - def _limit_offset_value(self): + def _limit_offset_value(self) -> Optional[int]: return self.effective_value @@ -114,11 +175,12 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement): # sub-elements of returns_rows _is_from_clause = False + _is_select_base = False _is_select_statement = False _is_lateral = False @property - def selectable(self): + def selectable(self) -> ReturnsRows: return self @util.non_memoized_property @@ -133,8 +195,28 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement): """ raise NotImplementedError() + def is_derived_from(self, fromclause: FromClause) -> bool: + """Return ``True`` if this :class:`.ReturnsRows` is + 'derived' from the given :class:`.FromClause`. + + An example would be an Alias of a Table is derived from that Table. + + """ + raise NotImplementedError() + + def _generate_fromclause_column_proxies( + self, fromclause: FromClause + ) -> None: + """Populate columns into an :class:`.AliasedReturnsRows` object.""" + + raise NotImplementedError() + + def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: + """reset internal collections for an incoming column being added.""" + raise NotImplementedError() + @property - def exported_columns(self): + def exported_columns(self) -> ReadOnlyColumnCollection[Any, Any]: """A :class:`_expression.ColumnCollection` that represents the "exported" columns of this :class:`_expression.ReturnsRows`. @@ -160,6 +242,9 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement): raise NotImplementedError() +SelfSelectable = TypeVar("SelfSelectable", bound="Selectable") + + class Selectable(ReturnsRows): """Mark a class as being selectable.""" @@ -167,10 +252,10 @@ class Selectable(ReturnsRows): is_selectable = True - def _refresh_for_new_column(self, column): + def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: raise NotImplementedError() - def lateral(self, name=None): + def lateral(self, name: Optional[str] = None) -> LateralFromClause: """Return a LATERAL alias of this :class:`_expression.Selectable`. The return value is the :class:`_expression.Lateral` construct also @@ -192,15 +277,21 @@ class Selectable(ReturnsRows): "functionality is available via the sqlalchemy.sql.visitors module.", ) @util.preload_module("sqlalchemy.sql.util") - def replace_selectable(self, old, alias): + def replace_selectable( + self: SelfSelectable, old: FromClause, alias: Alias + ) -> SelfSelectable: """Replace all occurrences of :class:`_expression.FromClause` 'old' with the given :class:`_expression.Alias` object, returning a copy of this :class:`_expression.FromClause`. """ - return util.preloaded.sql_util.ClauseAdapter(alias).traverse(self) + return util.preloaded.sql_util.ClauseAdapter(alias).traverse( # type: ignore # noqa E501 + self + ) - def corresponding_column(self, column, require_embedded=False): + def corresponding_column( + self, column: ColumnElement[Any], require_embedded: bool = False + ) -> Optional[ColumnElement[Any]]: """Given a :class:`_expression.ColumnElement`, return the exported :class:`_expression.ColumnElement` object from the :attr:`_expression.Selectable.exported_columns` @@ -242,19 +333,23 @@ SelfHasPrefixes = typing.TypeVar("SelfHasPrefixes", bound="HasPrefixes") class HasPrefixes: - _prefixes = () + _prefixes: Tuple[Tuple[DQLDMLClauseElement, str], ...] = () - _has_prefixes_traverse_internals = [ + _has_prefixes_traverse_internals: _TraverseInternalsType = [ ("_prefixes", InternalTraversal.dp_prefix_sequence) ] @_generative @_document_text_coercion( - "expr", + "prefixes", ":meth:`_expression.HasPrefixes.prefix_with`", - ":paramref:`.HasPrefixes.prefix_with.*expr`", + ":paramref:`.HasPrefixes.prefix_with.*prefixes`", ) - def prefix_with(self: SelfHasPrefixes, *expr, **kw) -> SelfHasPrefixes: + def prefix_with( + self: SelfHasPrefixes, + *prefixes: _TextCoercedExpressionArgument[Any], + dialect: str = "*", + ) -> SelfHasPrefixes: r"""Add one or more expressions following the statement keyword, i.e. SELECT, INSERT, UPDATE, or DELETE. Generative. @@ -272,49 +367,44 @@ class HasPrefixes: Multiple prefixes can be specified by multiple calls to :meth:`_expression.HasPrefixes.prefix_with`. - :param \*expr: textual or :class:`_expression.ClauseElement` + :param \*prefixes: textual or :class:`_expression.ClauseElement` construct which will be rendered following the INSERT, UPDATE, or DELETE keyword. - :param \**kw: A single keyword 'dialect' is accepted. This is an - optional string dialect name which will + :param dialect: optional string dialect name which will limit rendering of this prefix to only that dialect. """ - dialect = kw.pop("dialect", None) - if kw: - raise exc.ArgumentError( - "Unsupported argument(s): %s" % ",".join(kw) - ) - self._setup_prefixes(expr, dialect) - return self - - def _setup_prefixes(self, prefixes, dialect=None): self._prefixes = self._prefixes + tuple( [ (coercions.expect(roles.StatementOptionRole, p), dialect) for p in prefixes ] ) + return self SelfHasSuffixes = typing.TypeVar("SelfHasSuffixes", bound="HasSuffixes") class HasSuffixes: - _suffixes = () + _suffixes: Tuple[Tuple[DQLDMLClauseElement, str], ...] = () - _has_suffixes_traverse_internals = [ + _has_suffixes_traverse_internals: _TraverseInternalsType = [ ("_suffixes", InternalTraversal.dp_prefix_sequence) ] @_generative @_document_text_coercion( - "expr", + "suffixes", ":meth:`_expression.HasSuffixes.suffix_with`", - ":paramref:`.HasSuffixes.suffix_with.*expr`", + ":paramref:`.HasSuffixes.suffix_with.*suffixes`", ) - def suffix_with(self: SelfHasSuffixes, *expr, **kw) -> SelfHasSuffixes: + def suffix_with( + self: SelfHasSuffixes, + *suffixes: _TextCoercedExpressionArgument[Any], + dialect: str = "*", + ) -> SelfHasSuffixes: r"""Add one or more expressions following the statement as a whole. This is used to support backend-specific suffix keywords on @@ -328,44 +418,39 @@ class HasSuffixes: Multiple suffixes can be specified by multiple calls to :meth:`_expression.HasSuffixes.suffix_with`. - :param \*expr: textual or :class:`_expression.ClauseElement` + :param \*suffixes: textual or :class:`_expression.ClauseElement` construct which will be rendered following the target clause. - :param \**kw: A single keyword 'dialect' is accepted. This is an - optional string dialect name which will + :param dialect: Optional string dialect name which will limit rendering of this suffix to only that dialect. """ - dialect = kw.pop("dialect", None) - if kw: - raise exc.ArgumentError( - "Unsupported argument(s): %s" % ",".join(kw) - ) - self._setup_suffixes(expr, dialect) - return self - - def _setup_suffixes(self, suffixes, dialect=None): self._suffixes = self._suffixes + tuple( [ (coercions.expect(roles.StatementOptionRole, p), dialect) for p in suffixes ] ) + return self SelfHasHints = typing.TypeVar("SelfHasHints", bound="HasHints") class HasHints: - _hints = util.immutabledict() - _statement_hints = () + _hints: util.immutabledict[ + Tuple[FromClause, str], str + ] = util.immutabledict() + _statement_hints: Tuple[Tuple[str, str], ...] = () - _has_hints_traverse_internals = [ + _has_hints_traverse_internals: _TraverseInternalsType = [ ("_statement_hints", InternalTraversal.dp_statement_hint_list), ("_hints", InternalTraversal.dp_table_hint_list), ] - def with_statement_hint(self, text, dialect_name="*"): + def with_statement_hint( + self: SelfHasHints, text: str, dialect_name: str = "*" + ) -> SelfHasHints: """Add a statement hint to this :class:`_expression.Select` or other selectable object. @@ -389,11 +474,14 @@ class HasHints: MySQL optimizer hints """ - return self.with_hint(None, text, dialect_name) + return self._with_hint(None, text, dialect_name) @_generative def with_hint( - self: SelfHasHints, selectable, text, dialect_name="*" + self: SelfHasHints, + selectable: _FromClauseArgument, + text: str, + dialect_name: str = "*", ) -> SelfHasHints: r"""Add an indexing or other executional context hint for the given selectable to this :class:`_expression.Select` or other selectable @@ -429,6 +517,15 @@ class HasHints: :meth:`_expression.Select.with_statement_hint` """ + + return self._with_hint(selectable, text, dialect_name) + + def _with_hint( + self: SelfHasHints, + selectable: Optional[_FromClauseArgument], + text: str, + dialect_name: str, + ) -> SelfHasHints: if selectable is None: self._statement_hints += ((dialect_name, text),) else: @@ -443,6 +540,9 @@ class HasHints: return self +SelfFromClause = TypeVar("SelfFromClause", bound="FromClause") + + class FromClause(roles.AnonymizedFromClauseRole, Selectable): """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement. @@ -473,6 +573,8 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): _is_clone_of: Optional[FromClause] + _columns: ColumnCollection[Any, Any] + schema: Optional[str] = None """Define the 'schema' attribute for this :class:`_expression.FromClause`. @@ -488,7 +590,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): _use_schema_map = False - def select(self) -> "Select": + def select(self) -> Select: r"""Return a SELECT of this :class:`_expression.FromClause`. @@ -504,7 +606,13 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ return Select(self) - def join(self, right, onclause=None, isouter=False, full=False): + def join( + self, + right: _FromClauseArgument, + onclause: Optional[_ColumnExpressionArgument[bool]] = None, + isouter: bool = False, + full: bool = False, + ) -> Join: """Return a :class:`_expression.Join` from this :class:`_expression.FromClause` to another :class:`FromClause`. @@ -550,7 +658,12 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return Join(self, right, onclause, isouter, full) - def outerjoin(self, right, onclause=None, full=False): + def outerjoin( + self, + right: _FromClauseArgument, + onclause: Optional[_ColumnExpressionArgument[bool]] = None, + full: bool = False, + ) -> Join: """Return a :class:`_expression.Join` from this :class:`_expression.FromClause` to another :class:`FromClause`, with the "isouter" flag set to @@ -596,7 +709,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return Join(self, right, onclause, True, full) - def alias(self, name=None, flat=False): + def alias( + self, name: Optional[str] = None, flat: bool = False + ) -> NamedFromClause: """Return an alias of this :class:`_expression.FromClause`. E.g.:: @@ -617,35 +732,12 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return Alias._construct(self, name) - @util.preload_module("sqlalchemy.sql.sqltypes") - def table_valued(self): - """Return a :class:`_sql.TableValuedColumn` object for this - :class:`_expression.FromClause`. - - A :class:`_sql.TableValuedColumn` is a :class:`_sql.ColumnElement` that - represents a complete row in a table. Support for this construct is - backend dependent, and is supported in various forms by backends - such as PostgreSQL, Oracle and SQL Server. - - E.g.:: - - >>> from sqlalchemy import select, column, func, table - >>> a = table("a", column("id"), column("x"), column("y")) - >>> stmt = select(func.row_to_json(a.table_valued())) - >>> print(stmt) - SELECT row_to_json(a) AS row_to_json_1 - FROM a - - .. versionadded:: 1.4.0b2 - - .. seealso:: - - :ref:`tutorial_functions` - in the :ref:`unified_tutorial` - - """ - return TableValuedColumn(self, type_api.TABLEVALUE) - - def tablesample(self, sampling, name=None, seed=None): + def tablesample( + self, + sampling: Union[float, Function[Any]], + name: Optional[str] = None, + seed: Optional[roles.ExpressionElementRole[Any]] = None, + ) -> TableSample: """Return a TABLESAMPLE alias of this :class:`_expression.FromClause`. The return value is the :class:`_expression.TableSample` @@ -661,7 +753,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ return TableSample._construct(self, sampling, name, seed) - def is_derived_from(self, fromclause): + def is_derived_from(self, fromclause: FromClause) -> bool: """Return ``True`` if this :class:`_expression.FromClause` is 'derived' from the given ``FromClause``. @@ -673,7 +765,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): # contained elements. return fromclause in self._cloned_set - def _is_lexical_equivalent(self, other): + def _is_lexical_equivalent(self, other: FromClause) -> bool: """Return ``True`` if this :class:`_expression.FromClause` and the other represent the same lexical identity. @@ -681,9 +773,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): if they are the same via annotation identity. """ - return self._cloned_set.intersection(other._cloned_set) + return bool(self._cloned_set.intersection(other._cloned_set)) - @util.non_memoized_property + @util.ro_non_memoized_property def description(self) -> str: """A brief description of this :class:`_expression.FromClause`. @@ -692,13 +784,15 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ return getattr(self, "name", self.__class__.__name__ + " object") - def _generate_fromclause_column_proxies(self, fromclause): + def _generate_fromclause_column_proxies( + self, fromclause: FromClause + ) -> None: fromclause._columns._populate_separate_keys( col._make_proxy(fromclause) for col in self.c ) @property - def exported_columns(self): + def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]: """A :class:`_expression.ColumnCollection` that represents the "exported" columns of this :class:`_expression.Selectable`. @@ -796,7 +890,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): self._populate_column_collection() return self.foreign_keys - def _reset_column_collection(self): + def _reset_column_collection(self) -> None: """Reset the attributes linked to the ``FromClause.c`` attribute. This collection is separate from all the other memoized things @@ -817,7 +911,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): def _select_iterable(self) -> _SelectIterable: return self.c - def _init_collections(self): + def _init_collections(self) -> None: assert "_columns" not in self.__dict__ assert "primary_key" not in self.__dict__ assert "foreign_keys" not in self.__dict__ @@ -827,10 +921,10 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): self.foreign_keys = set() # type: ignore @property - def _cols_populated(self): + def _cols_populated(self) -> bool: return "_columns" in self.__dict__ - def _populate_column_collection(self): + def _populate_column_collection(self) -> None: """Called on subclasses to establish the .c collection. Each implementation has a different way of establishing @@ -838,7 +932,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ - def _refresh_for_new_column(self, column): + def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: """Given a column added to the .c collection of an underlying selectable, produce the local version of that column, assuming this selectable ultimately should proxy this column. @@ -865,15 +959,60 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ self._reset_column_collection() - def _anonymous_fromclause(self, name=None, flat=False): + def _anonymous_fromclause( + self, name: Optional[str] = None, flat: bool = False + ) -> NamedFromClause: return self.alias(name=name) + if TYPE_CHECKING: + + def self_group( + self: Self, against: Optional[OperatorType] = None + ) -> Union[FromGrouping, Self]: + ... + class NamedFromClause(FromClause): + """A :class:`.FromClause` that has a name. + + Examples include tables, subqueries, CTEs, aliased tables. + + .. versionadded:: 2.0 + + """ + named_with_column = True name: str + @util.preload_module("sqlalchemy.sql.sqltypes") + def table_valued(self) -> TableValuedColumn[Any]: + """Return a :class:`_sql.TableValuedColumn` object for this + :class:`_expression.FromClause`. + + A :class:`_sql.TableValuedColumn` is a :class:`_sql.ColumnElement` that + represents a complete row in a table. Support for this construct is + backend dependent, and is supported in various forms by backends + such as PostgreSQL, Oracle and SQL Server. + + E.g.:: + + >>> from sqlalchemy import select, column, func, table + >>> a = table("a", column("id"), column("x"), column("y")) + >>> stmt = select(func.row_to_json(a.table_valued())) + >>> print(stmt) + SELECT row_to_json(a) AS row_to_json_1 + FROM a + + .. versionadded:: 1.4.0b2 + + .. seealso:: + + :ref:`tutorial_functions` - in the :ref:`unified_tutorial` + + """ + return TableValuedColumn(self, type_api.TABLEVALUE) + class SelectLabelStyle(Enum): """Label style constants that may be passed to @@ -992,7 +1131,7 @@ class Join(roles.DMLTableRole, FromClause): __visit_name__ = "join" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("left", InternalTraversal.dp_clauseelement), ("right", InternalTraversal.dp_clauseelement), ("onclause", InternalTraversal.dp_clauseelement), @@ -1002,7 +1141,20 @@ class Join(roles.DMLTableRole, FromClause): _is_join = True - def __init__(self, left, right, onclause=None, isouter=False, full=False): + left: FromClause + right: FromClause + onclause: Optional[ColumnElement[bool]] + isouter: bool + full: bool + + def __init__( + self, + left: _FromClauseArgument, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + isouter: bool = False, + full: bool = False, + ): """Construct a new :class:`_expression.Join`. The usual entrypoint here is the :func:`_expression.join` @@ -1010,11 +1162,23 @@ class Join(roles.DMLTableRole, FromClause): :class:`_expression.FromClause` object. """ + + # when deannotate was removed here, callcounts went up for ORM + # compilation of eager joins, since there were more comparisons of + # annotated objects. test_orm.py -> test_fetch_results + # was therefore changed to show a more real-world use case, where the + # compilation is cached; there's no change in post-cache callcounts. + # callcounts for a single compilation in that particular test + # that includes about eight joins about 1100 extra fn calls, from + # 29200 -> 30373 + self.left = coercions.expect( - roles.FromClauseRole, left, deannotate=True + roles.FromClauseRole, + left, ) self.right = coercions.expect( - roles.FromClauseRole, right, deannotate=True + roles.FromClauseRole, + right, ).self_group() if onclause is None: @@ -1029,7 +1193,7 @@ class Join(roles.DMLTableRole, FromClause): self.isouter = isouter self.full = full - @property + @util.ro_non_memoized_property def description(self) -> str: return "Join object on %s(%d) and %s(%d)" % ( self.left.description, @@ -1038,7 +1202,7 @@ class Join(roles.DMLTableRole, FromClause): id(self.right), ) - def is_derived_from(self, fromclause): + def is_derived_from(self, fromclause: FromClause) -> bool: return ( # use hash() to ensure direct comparison to annotated works # as well @@ -1047,7 +1211,10 @@ class Join(roles.DMLTableRole, FromClause): or self.right.is_derived_from(fromclause) ) - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> FromGrouping: + ... return FromGrouping(self) @util.preload_module("sqlalchemy.sql.util") @@ -1055,7 +1222,7 @@ class Join(roles.DMLTableRole, FromClause): sqlutil = util.preloaded.sql_util columns = [c for c in self.left.c] + [c for c in self.right.c] - self.primary_key.extend( + self.primary_key.extend( # type: ignore sqlutil.reduce_columns( (c for c in columns if c.primary_key), self.onclause ) @@ -1063,11 +1230,13 @@ class Join(roles.DMLTableRole, FromClause): self._columns._populate_separate_keys( (col._tq_key_label, col) for col in columns ) - self.foreign_keys.update( + self.foreign_keys.update( # type: ignore itertools.chain(*[col.foreign_keys for col in columns]) ) - def _copy_internals(self, clone=_clone, **kw): + def _copy_internals( + self, clone: _CloneCallableType = _clone, **kw: Any + ) -> None: # see Select._copy_internals() for similar concept # here we pre-clone "left" and "right" so that we can @@ -1100,12 +1269,14 @@ class Join(roles.DMLTableRole, FromClause): self._reset_memoizations() - def _refresh_for_new_column(self, column): + def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: super(Join, self)._refresh_for_new_column(column) self.left._refresh_for_new_column(column) self.right._refresh_for_new_column(column) - def _match_primaries(self, left, right): + def _match_primaries( + self, left: FromClause, right: FromClause + ) -> ColumnElement[bool]: if isinstance(left, Join): left_right = left.right else: @@ -1114,8 +1285,15 @@ class Join(roles.DMLTableRole, FromClause): @classmethod def _join_condition( - cls, a, b, a_subset=None, consider_as_foreign_keys=None - ): + cls, + a: FromClause, + b: FromClause, + *, + a_subset: Optional[FromClause] = None, + consider_as_foreign_keys: Optional[ + AbstractSet[ColumnClause[Any]] + ] = None, + ) -> ColumnElement[bool]: """Create a join condition between two tables or selectables. See sqlalchemy.sql.util.join_condition() for full docs. @@ -1151,7 +1329,15 @@ class Join(roles.DMLTableRole, FromClause): return and_(*crit) @classmethod - def _can_join(cls, left, right, consider_as_foreign_keys=None): + def _can_join( + cls, + left: FromClause, + right: FromClause, + *, + consider_as_foreign_keys: Optional[ + AbstractSet[ColumnClause[Any]] + ] = None, + ) -> bool: if isinstance(left, Join): left_right = left.right else: @@ -1169,20 +1355,31 @@ class Join(roles.DMLTableRole, FromClause): @classmethod @util.preload_module("sqlalchemy.sql.util") def _joincond_scan_left_right( - cls, a, a_subset, b, consider_as_foreign_keys - ): + cls, + a: FromClause, + a_subset: Optional[FromClause], + b: FromClause, + consider_as_foreign_keys: Optional[AbstractSet[ColumnClause[Any]]], + ) -> collections.defaultdict[ + Optional[ForeignKeyConstraint], + List[Tuple[ColumnClause[Any], ColumnClause[Any]]], + ]: sql_util = util.preloaded.sql_util a = coercions.expect(roles.FromClauseRole, a) b = coercions.expect(roles.FromClauseRole, b) - constraints = collections.defaultdict(list) + constraints: collections.defaultdict[ + Optional[ForeignKeyConstraint], + List[Tuple[ColumnClause[Any], ColumnClause[Any]]], + ] = collections.defaultdict(list) for left in (a_subset, a): if left is None: continue for fk in sorted( - b.foreign_keys, key=lambda fk: fk.parent._creation_order + b.foreign_keys, + key=lambda fk: fk.parent._creation_order, # type: ignore ): if ( consider_as_foreign_keys is not None @@ -1202,7 +1399,8 @@ class Join(roles.DMLTableRole, FromClause): constraints[fk.constraint].append((col, fk.parent)) if left is not b: for fk in sorted( - left.foreign_keys, key=lambda fk: fk.parent._creation_order + left.foreign_keys, + key=lambda fk: fk.parent._creation_order, # type: ignore ): if ( consider_as_foreign_keys is not None @@ -1309,7 +1507,8 @@ class Join(roles.DMLTableRole, FromClause): @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: - return [self] + self.left._from_objects + self.right._from_objects + self_list: List[FromClause] = [self] + return self_list + self.left._from_objects + self.right._from_objects class NoInit: @@ -1327,6 +1526,14 @@ class NoInit: ) +class LateralFromClause(NamedFromClause): + """mark a FROM clause as being able to render directly as LATERAL""" + + +_SelfAliasedReturnsRows = TypeVar( + "_SelfAliasedReturnsRows", bound="AliasedReturnsRows" +) + # FromClause -> # AliasedReturnsRows # -> Alias only for FromClause @@ -1335,6 +1542,8 @@ class NoInit: # -> Lateral -> FromClause, but we accept SelectBase # w/ non-deprecated coercion # -> TableSample -> only for FromClause + + class AliasedReturnsRows(NoInit, NamedFromClause): """Base class of aliases against tables, subqueries, and other selectables.""" @@ -1343,24 +1552,21 @@ class AliasedReturnsRows(NoInit, NamedFromClause): _supports_derived_columns = False - element: ClauseElement + element: ReturnsRows - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("element", InternalTraversal.dp_clauseelement), ("name", InternalTraversal.dp_anon_name), ] @classmethod - def _construct(cls, *arg, **kw): + def _construct( + cls: Type[_SelfAliasedReturnsRows], *arg: Any, **kw: Any + ) -> _SelfAliasedReturnsRows: obj = cls.__new__(cls) obj._init(*arg, **kw) return obj - @classmethod - def _factory(cls, returnsrows, name=None): - """Base factory method. Subclasses need to provide this.""" - raise NotImplementedError() - def _init(self, selectable, name=None): self.element = coercions.expect( roles.ReturnsRowsRole, selectable, apply_propagate_attrs=self @@ -1378,11 +1584,14 @@ class AliasedReturnsRows(NoInit, NamedFromClause): name = _anonymous_label.safe_construct(id(self), name or "anon") self.name = name - def _refresh_for_new_column(self, column): + def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: super(AliasedReturnsRows, self)._refresh_for_new_column(column) self.element._refresh_for_new_column(column) - @property + def _populate_column_collection(self): + self.element._generate_fromclause_column_proxies(self) + + @util.ro_non_memoized_property def description(self) -> str: name = self.name if isinstance(name, _anonymous_label): @@ -1395,15 +1604,14 @@ class AliasedReturnsRows(NoInit, NamedFromClause): """Legacy for dialects that are referring to Alias.original.""" return self.element - def is_derived_from(self, fromclause): + def is_derived_from(self, fromclause: FromClause) -> bool: if fromclause in self._cloned_set: return True return self.element.is_derived_from(fromclause) - def _populate_column_collection(self): - self.element._generate_fromclause_column_proxies(self) - - def _copy_internals(self, clone=_clone, **kw): + def _copy_internals( + self, clone: _CloneCallableType = _clone, **kw: Any + ) -> None: existing_element = self.element super(AliasedReturnsRows, self)._copy_internals(clone=clone, **kw) @@ -1420,7 +1628,11 @@ class AliasedReturnsRows(NoInit, NamedFromClause): return [self] -class Alias(roles.DMLTableRole, AliasedReturnsRows): +class FromClauseAlias(AliasedReturnsRows): + element: FromClause + + +class Alias(roles.DMLTableRole, FromClauseAlias): """Represents an table or selectable alias (AS). Represents an alias, as typically applied to any table or @@ -1445,13 +1657,18 @@ class Alias(roles.DMLTableRole, AliasedReturnsRows): element: FromClause @classmethod - def _factory(cls, selectable, name=None, flat=False): + def _factory( + cls, + selectable: FromClause, + name: Optional[str] = None, + flat: bool = False, + ) -> NamedFromClause: return coercions.expect( roles.FromClauseRole, selectable, allow_select=True ).alias(name=name, flat=flat) -class TableValuedAlias(Alias): +class TableValuedAlias(LateralFromClause, Alias): """An alias against a "table valued" SQL function. This construct provides for a SQL function that returns columns @@ -1480,7 +1697,7 @@ class TableValuedAlias(Alias): _render_derived_w_types = False joins_implicitly = False - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("element", InternalTraversal.dp_clauseelement), ("name", InternalTraversal.dp_anon_name), ("_tableval_type", InternalTraversal.dp_type), @@ -1526,7 +1743,9 @@ class TableValuedAlias(Alias): return TableValuedColumn(self, self._tableval_type) - def alias(self, name=None): + def alias( + self, name: Optional[str] = None, flat: bool = False + ) -> TableValuedAlias: """Return a new alias of this :class:`_sql.TableValuedAlias`. This creates a distinct FROM object that will be distinguished @@ -1547,7 +1766,7 @@ class TableValuedAlias(Alias): return tva - def lateral(self, name=None): + def lateral(self, name: Optional[str] = None) -> LateralFromClause: """Return a new :class:`_sql.TableValuedAlias` with the lateral flag set, so that it renders as LATERAL. @@ -1619,7 +1838,7 @@ class TableValuedAlias(Alias): return new_alias -class Lateral(AliasedReturnsRows): +class Lateral(FromClauseAlias, LateralFromClause): """Represent a LATERAL subquery. This object is constructed from the :func:`_expression.lateral` module @@ -1644,13 +1863,17 @@ class Lateral(AliasedReturnsRows): inherit_cache = True @classmethod - def _factory(cls, selectable, name=None): + def _factory( + cls, + selectable: Union[SelectBase, _FromClauseArgument], + name: Optional[str] = None, + ) -> LateralFromClause: return coercions.expect( roles.FromClauseRole, selectable, explicit_subquery=True ).lateral(name=name) -class TableSample(AliasedReturnsRows): +class TableSample(FromClauseAlias): """Represent a TABLESAMPLE clause. This object is constructed from the :func:`_expression.tablesample` module @@ -1668,13 +1891,22 @@ class TableSample(AliasedReturnsRows): __visit_name__ = "tablesample" - _traverse_internals = AliasedReturnsRows._traverse_internals + [ - ("sampling", InternalTraversal.dp_clauseelement), - ("seed", InternalTraversal.dp_clauseelement), - ] + _traverse_internals: _TraverseInternalsType = ( + AliasedReturnsRows._traverse_internals + + [ + ("sampling", InternalTraversal.dp_clauseelement), + ("seed", InternalTraversal.dp_clauseelement), + ] + ) @classmethod - def _factory(cls, selectable, sampling, name=None, seed=None): + def _factory( + cls, + selectable: _FromClauseArgument, + sampling: Union[float, Function[Any]], + name: Optional[str] = None, + seed: Optional[roles.ExpressionElementRole[Any]] = None, + ) -> TableSample: return coercions.expect(roles.FromClauseRole, selectable).tablesample( sampling, name=name, seed=seed ) @@ -1721,7 +1953,7 @@ class CTE( __visit_name__ = "cte" - _traverse_internals = ( + _traverse_internals: _TraverseInternalsType = ( AliasedReturnsRows._traverse_internals + [ ("_cte_alias", InternalTraversal.dp_clauseelement), @@ -1736,7 +1968,12 @@ class CTE( element: HasCTE @classmethod - def _factory(cls, selectable, name=None, recursive=False): + def _factory( + cls, + selectable: HasCTE, + name: Optional[str] = None, + recursive: bool = False, + ) -> CTE: r"""Return a new :class:`_expression.CTE`, or Common Table Expression instance. @@ -1775,7 +2012,9 @@ class CTE( else: self.element._generate_fromclause_column_proxies(self) - def alias(self, name=None, flat=False): + def alias( + self, name: Optional[str] = None, flat: bool = False + ) -> NamedFromClause: """Return an :class:`_expression.Alias` of this :class:`_expression.CTE`. @@ -1814,6 +2053,10 @@ class CTE( :meth:`_sql.HasCTE.cte` - examples of calling styles """ + assert is_select_statement( + self.element + ), f"CTE element f{self.element} does not support union()" + return CTE._construct( self.element.union(*other), name=self.name, @@ -1839,6 +2082,11 @@ class CTE( :meth:`_sql.HasCTE.cte` - examples of calling styles """ + + assert is_select_statement( + self.element + ), f"CTE element f{self.element} does not support union_all()" + return CTE._construct( self.element.union_all(*other), name=self.name, @@ -1865,67 +2113,273 @@ class _CTEOpts(NamedTuple): nesting: bool -class HasCTE(roles.HasCTERole, ClauseElement): - """Mixin that declares a class to include CTE support. +class _ColumnsPlusNames(NamedTuple): + required_label_name: Optional[str] + """ + string label name, if non-None, must be rendered as a + label, i.e. "AS " + """ - .. versionadded:: 1.1 + proxy_key: Optional[str] + """ + proxy_key that is to be part of the result map for this + col. this is also the key in a fromclause.c or + select.selected_columns collection + """ + fallback_label_name: Optional[str] + """ + name that can be used to render an "AS " when + we have to render a label even though + required_label_name was not given """ - _has_ctes_traverse_internals = [ - ("_independent_ctes", InternalTraversal.dp_clauseelement_list), - ("_independent_ctes_opts", InternalTraversal.dp_plain_obj), - ] + column: Union[ColumnElement[Any], TextClause] + """ + the ColumnElement itself + """ - _independent_ctes = () - _independent_ctes_opts = () + repeated: bool + """ + True if this is a duplicate of a previous column + in the list of columns + """ - @_generative - def add_cte(self: SelfHasCTE, *ctes, nest_here=False) -> SelfHasCTE: - r"""Add one or more :class:`_sql.CTE` constructs to this statement. - This method will associate the given :class:`_sql.CTE` constructs with - the parent statement such that they will each be unconditionally - rendered in the WITH clause of the final statement, even if not - referenced elsewhere within the statement or any sub-selects. +class SelectsRows(ReturnsRows): + """Sub-base of ReturnsRows for elements that deliver rows + directly, namely SELECT and INSERT/UPDATE/DELETE..RETURNING""" - The optional :paramref:`.HasCTE.add_cte.nest_here` parameter when set - to True will have the effect that each given :class:`_sql.CTE` will - render in a WITH clause rendered directly along with this statement, - rather than being moved to the top of the ultimate rendered statement, - even if this statement is rendered as a subquery within a larger - statement. + _label_style: SelectLabelStyle = LABEL_STYLE_NONE - This method has two general uses. One is to embed CTE statements that - serve some purpose without being referenced explicitly, such as the use - case of embedding a DML statement such as an INSERT or UPDATE as a CTE - inline with a primary statement that may draw from its results - indirectly. The other is to provide control over the exact placement - of a particular series of CTE constructs that should remain rendered - directly in terms of a particular statement that may be nested in a - larger statement. + def _generate_columns_plus_names( + self, anon_for_dupe_key: bool + ) -> List[_ColumnsPlusNames]: + """Generate column names as rendered in a SELECT statement by + the compiler. - E.g.:: + This is distinct from the _column_naming_convention generator that's + intended for population of .c collections and similar, which has + different rules. the collection returned here calls upon the + _column_naming_convention as well. - from sqlalchemy import table, column, select - t = table('t', column('c1'), column('c2')) + """ + cols = self._all_selected_columns - ins = t.insert().values({"c1": "x", "c2": "y"}).cte() + key_naming_convention = SelectState._column_naming_convention( + self._label_style + ) - stmt = select(t).add_cte(ins) + names = {} - Would render:: + result: List[_ColumnsPlusNames] = [] + result_append = result.append - WITH anon_1 AS - (INSERT INTO t (c1, c2) VALUES (:param_1, :param_2)) - SELECT t.c1, t.c2 - FROM t + table_qualified = self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL + label_style_none = self._label_style is LABEL_STYLE_NONE - Above, the "anon_1" CTE is not referred towards in the SELECT - statement, however still accomplishes the task of running an INSERT - statement. + # a counter used for "dedupe" labels, which have double underscores + # in them and are never referred by name; they only act + # as positional placeholders. they need only be unique within + # the single columns clause they're rendered within (required by + # some dbs such as mysql). So their anon identity is tracked against + # a fixed counter rather than hash() identity. + dedupe_hash = 1 - Similarly in a DML-related context, using the PostgreSQL + for c in cols: + repeated = False + + if not c._render_label_in_columns_clause: + effective_name = ( + required_label_name + ) = fallback_label_name = None + elif label_style_none: + if TYPE_CHECKING: + assert is_column_element(c) + + effective_name = required_label_name = None + fallback_label_name = c._non_anon_label or c._anon_name_label + else: + if TYPE_CHECKING: + assert is_column_element(c) + + if table_qualified: + required_label_name = ( + effective_name + ) = fallback_label_name = c._tq_label + else: + effective_name = fallback_label_name = c._non_anon_label + required_label_name = None + + if effective_name is None: + # it seems like this could be _proxy_key and we would + # not need _expression_label but it isn't + # giving us a clue when to use anon_label instead + expr_label = c._expression_label + if expr_label is None: + repeated = c._anon_name_label in names + names[c._anon_name_label] = c + effective_name = required_label_name = None + + if repeated: + # here, "required_label_name" is sent as + # "None" and "fallback_label_name" is sent. + if table_qualified: + fallback_label_name = ( + c._dedupe_anon_tq_label_idx(dedupe_hash) + ) + dedupe_hash += 1 + else: + fallback_label_name = c._dedupe_anon_label_idx( + dedupe_hash + ) + dedupe_hash += 1 + else: + fallback_label_name = c._anon_name_label + else: + required_label_name = ( + effective_name + ) = fallback_label_name = expr_label + + if effective_name is not None: + if TYPE_CHECKING: + assert is_column_element(c) + + if effective_name in names: + # when looking to see if names[name] is the same column as + # c, use hash(), so that an annotated version of the column + # is seen as the same as the non-annotated + if hash(names[effective_name]) != hash(c): + + # different column under the same name. apply + # disambiguating label + if table_qualified: + required_label_name = ( + fallback_label_name + ) = c._anon_tq_label + else: + required_label_name = ( + fallback_label_name + ) = c._anon_name_label + + if anon_for_dupe_key and required_label_name in names: + # here, c._anon_tq_label is definitely unique to + # that column identity (or annotated version), so + # this should always be true. + # this is also an infrequent codepath because + # you need two levels of duplication to be here + assert hash(names[required_label_name]) == hash(c) + + # the column under the disambiguating label is + # already present. apply the "dedupe" label to + # subsequent occurrences of the column so that the + # original stays non-ambiguous + if table_qualified: + required_label_name = ( + fallback_label_name + ) = c._dedupe_anon_tq_label_idx(dedupe_hash) + dedupe_hash += 1 + else: + required_label_name = ( + fallback_label_name + ) = c._dedupe_anon_label_idx(dedupe_hash) + dedupe_hash += 1 + repeated = True + else: + names[required_label_name] = c + elif anon_for_dupe_key: + # same column under the same name. apply the "dedupe" + # label so that the original stays non-ambiguous + if table_qualified: + required_label_name = ( + fallback_label_name + ) = c._dedupe_anon_tq_label_idx(dedupe_hash) + dedupe_hash += 1 + else: + required_label_name = ( + fallback_label_name + ) = c._dedupe_anon_label_idx(dedupe_hash) + dedupe_hash += 1 + repeated = True + else: + names[effective_name] = c + + result_append( + _ColumnsPlusNames( + required_label_name, + key_naming_convention(c), + fallback_label_name, + c, + repeated, + ) + ) + + return result + + +class HasCTE(roles.HasCTERole, SelectsRows): + """Mixin that declares a class to include CTE support. + + .. versionadded:: 1.1 + + """ + + _has_ctes_traverse_internals: _TraverseInternalsType = [ + ("_independent_ctes", InternalTraversal.dp_clauseelement_list), + ("_independent_ctes_opts", InternalTraversal.dp_plain_obj), + ] + + _independent_ctes: Tuple[CTE, ...] = () + _independent_ctes_opts: Tuple[_CTEOpts, ...] = () + + @_generative + def add_cte( + self: SelfHasCTE, *ctes: CTE, nest_here: bool = False + ) -> SelfHasCTE: + r"""Add one or more :class:`_sql.CTE` constructs to this statement. + + This method will associate the given :class:`_sql.CTE` constructs with + the parent statement such that they will each be unconditionally + rendered in the WITH clause of the final statement, even if not + referenced elsewhere within the statement or any sub-selects. + + The optional :paramref:`.HasCTE.add_cte.nest_here` parameter when set + to True will have the effect that each given :class:`_sql.CTE` will + render in a WITH clause rendered directly along with this statement, + rather than being moved to the top of the ultimate rendered statement, + even if this statement is rendered as a subquery within a larger + statement. + + This method has two general uses. One is to embed CTE statements that + serve some purpose without being referenced explicitly, such as the use + case of embedding a DML statement such as an INSERT or UPDATE as a CTE + inline with a primary statement that may draw from its results + indirectly. The other is to provide control over the exact placement + of a particular series of CTE constructs that should remain rendered + directly in terms of a particular statement that may be nested in a + larger statement. + + E.g.:: + + from sqlalchemy import table, column, select + t = table('t', column('c1'), column('c2')) + + ins = t.insert().values({"c1": "x", "c2": "y"}).cte() + + stmt = select(t).add_cte(ins) + + Would render:: + + WITH anon_1 AS + (INSERT INTO t (c1, c2) VALUES (:param_1, :param_2)) + SELECT t.c1, t.c2 + FROM t + + Above, the "anon_1" CTE is not referred towards in the SELECT + statement, however still accomplishes the task of running an INSERT + statement. + + Similarly in a DML-related context, using the PostgreSQL :class:`_postgresql.Insert` construct to generate an "upsert":: from sqlalchemy import table, column @@ -1985,7 +2439,12 @@ class HasCTE(roles.HasCTERole, ClauseElement): self._independent_ctes_opts += (opt,) return self - def cte(self, name=None, recursive=False, nesting=False): + def cte( + self, + name: Optional[str] = None, + recursive: bool = False, + nesting: bool = False, + ) -> CTE: r"""Return a new :class:`_expression.CTE`, or Common Table Expression instance. @@ -2293,10 +2752,12 @@ class Subquery(AliasedReturnsRows): inherit_cache = True - element: Select + element: SelectBase @classmethod - def _factory(cls, selectable, name=None): + def _factory( + cls, selectable: SelectBase, name: Optional[str] = None + ) -> Subquery: """Return a :class:`.Subquery` object.""" return coercions.expect( roles.SelectStatementRole, selectable @@ -2335,11 +2796,13 @@ class Subquery(AliasedReturnsRows): class FromGrouping(GroupedElement, FromClause): """Represent a grouping of a FROM clause""" - _traverse_internals = [("element", InternalTraversal.dp_clauseelement)] + _traverse_internals: _TraverseInternalsType = [ + ("element", InternalTraversal.dp_clauseelement) + ] element: FromClause - def __init__(self, element): + def __init__(self, element: FromClause): self.element = coercions.expect(roles.FromClauseRole, element) def _init_collections(self): @@ -2361,11 +2824,13 @@ class FromGrouping(GroupedElement, FromClause): def foreign_keys(self): return self.element.foreign_keys - def is_derived_from(self, element): - return self.element.is_derived_from(element) + def is_derived_from(self, fromclause: FromClause) -> bool: + return self.element.is_derived_from(fromclause) - def alias(self, **kw): - return FromGrouping(self.element.alias(**kw)) + def alias( + self, name: Optional[str] = None, flat: bool = False + ) -> NamedFromGrouping: + return NamedFromGrouping(self.element.alias(name=name, flat=flat)) def _anonymous_fromclause(self, **kw): return FromGrouping(self.element._anonymous_fromclause(**kw)) @@ -2385,6 +2850,16 @@ class FromGrouping(GroupedElement, FromClause): self.element = state["element"] +class NamedFromGrouping(FromGrouping, NamedFromClause): + """represent a grouping of a named FROM clause + + .. versionadded:: 2.0 + + """ + + inherit_cache = True + + class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): """Represents a minimal "table" construct. @@ -2417,7 +2892,7 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): __visit_name__ = "table" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ( "columns", InternalTraversal.dp_fromclause_canonical_column_collection, @@ -2434,15 +2909,17 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): doesn't support having a primary key or column -level defaults, so implicit returning doesn't apply.""" - _autoincrement_column = None - """No PK or default support so no autoincrement column.""" + @util.ro_memoized_property + def _autoincrement_column(self) -> Optional[ColumnClause[Any]]: + """No PK or default support so no autoincrement column.""" + return None - def __init__(self, name, *columns, **kw): + def __init__(self, name: str, *columns: ColumnClause[Any], **kw: Any): super(TableClause, self).__init__() self.name = name self._columns = DedupeColumnCollection() - self.primary_key = ColumnSet() - self.foreign_keys = set() + self.primary_key = ColumnSet() # type: ignore + self.foreign_keys = set() # type: ignore for c in columns: self.append_column(c) @@ -2466,23 +2943,23 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): def c(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: ... - def __str__(self): + def __str__(self) -> str: if self.schema is not None: return self.schema + "." + self.name else: return self.name - def _refresh_for_new_column(self, column): + def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: pass - def _init_collections(self): + def _init_collections(self) -> None: pass - @util.memoized_property + @util.ro_memoized_property def description(self) -> str: return self.name - def append_column(self, c, **kw): + def append_column(self, c: ColumnClause[Any]) -> None: existing = c.table if existing is not None and existing is not self: raise exc.ArgumentError( @@ -2494,7 +2971,7 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): c.table = self @util.preload_module("sqlalchemy.sql.dml") - def insert(self): + def insert(self) -> Insert: """Generate an :func:`_expression.insert` construct against this :class:`_expression.TableClause`. @@ -2505,10 +2982,11 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): See :func:`_expression.insert` for argument and usage information. """ + return util.preloaded.sql_dml.Insert(self) @util.preload_module("sqlalchemy.sql.dml") - def update(self): + def update(self) -> Update: """Generate an :func:`_expression.update` construct against this :class:`_expression.TableClause`. @@ -2524,7 +3002,7 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): ) @util.preload_module("sqlalchemy.sql.dml") - def delete(self): + def delete(self) -> Delete: """Generate a :func:`_expression.delete` construct against this :class:`_expression.TableClause`. @@ -2543,13 +3021,18 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): class ForUpdateArg(ClauseElement): - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("of", InternalTraversal.dp_clauseelement_list), ("nowait", InternalTraversal.dp_boolean), ("read", InternalTraversal.dp_boolean), ("skip_locked", InternalTraversal.dp_boolean), ] + of: Optional[Sequence[ClauseElement]] + nowait: bool + read: bool + skip_locked: bool + @classmethod def _from_argument(cls, with_for_update): if isinstance(with_for_update, ForUpdateArg): @@ -2606,7 +3089,7 @@ class ForUpdateArg(ClauseElement): SelfValues = typing.TypeVar("SelfValues", bound="Values") -class Values(Generative, NamedFromClause): +class Values(Generative, LateralFromClause): """Represent a ``VALUES`` construct that can be used as a FROM element in a statement. @@ -2619,28 +3102,42 @@ class Values(Generative, NamedFromClause): __visit_name__ = "values" - _data = () + _data: Tuple[List[Tuple[Any, ...]], ...] = () - _traverse_internals = [ + _unnamed: bool + _traverse_internals: _TraverseInternalsType = [ ("_column_args", InternalTraversal.dp_clauseelement_list), ("_data", InternalTraversal.dp_dml_multi_values), ("name", InternalTraversal.dp_string), ("literal_binds", InternalTraversal.dp_boolean), ] - def __init__(self, *columns, name=None, literal_binds=False): + def __init__( + self, + *columns: ColumnClause[Any], + name: Optional[str] = None, + literal_binds: bool = False, + ): super(Values, self).__init__() self._column_args = columns - self.name = name + if name is None: + self._unnamed = True + self.name = _anonymous_label.safe_construct(id(self), "anon") + else: + self._unnamed = False + self.name = name self.literal_binds = literal_binds - self.named_with_column = self.name is not None + self.named_with_column = not self._unnamed @property def _column_types(self): return [col.type for col in self._column_args] @_generative - def alias(self: SelfValues, name, **kw) -> SelfValues: + def alias( + self: SelfValues, name: Optional[str] = None, flat: bool = False + ) -> SelfValues: + """Return a new :class:`_expression.Values` construct that is a copy of this one with the given name. @@ -2655,12 +3152,20 @@ class Values(Generative, NamedFromClause): :func:`_expression.alias` """ - self.name = name - self.named_with_column = self.name is not None + non_none_name: str + + if name is None: + non_none_name = _anonymous_label.safe_construct(id(self), "anon") + else: + non_none_name = name + + self.name = non_none_name + self.named_with_column = True + self._unnamed = False return self @_generative - def lateral(self: SelfValues, name=None) -> SelfValues: + def lateral(self, name: Optional[str] = None) -> LateralFromClause: """Return a new :class:`_expression.Values` with the lateral flag set, so that it renders as LATERAL. @@ -2670,13 +3175,20 @@ class Values(Generative, NamedFromClause): :func:`_expression.lateral` """ + non_none_name: str + + if name is None: + non_none_name = self.name + else: + non_none_name = name + self._is_lateral = True - if name is not None: - self.name = name + self.name = non_none_name + self._unnamed = False return self @_generative - def data(self: SelfValues, values) -> SelfValues: + def data(self: SelfValues, values: List[Tuple[Any, ...]]) -> SelfValues: """Return a new :class:`_expression.Values` construct, adding the given data to the data list. @@ -2694,7 +3206,7 @@ class Values(Generative, NamedFromClause): self._data += (values,) return self - def _populate_column_collection(self): + def _populate_column_collection(self) -> None: for c in self._column_args: self._columns.add(c) c.table = self @@ -2727,32 +3239,16 @@ class SelectBase( """ - _is_select_statement = True + _is_select_base = True is_select = True - def _generate_fromclause_column_proxies( - self, fromclause: FromClause - ) -> None: - raise NotImplementedError() + _label_style: SelectLabelStyle = LABEL_STYLE_NONE def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: self._reset_memoizations() - def _generate_columns_plus_names( - self, anon_for_dupe_key: bool - ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]: - raise NotImplementedError() - - def set_label_style( - self: SelfSelectBase, label_style: SelectLabelStyle - ) -> SelfSelectBase: - raise NotImplementedError() - - def get_label_style(self) -> SelectLabelStyle: - raise NotImplementedError() - - @property - def selected_columns(self): + @util.ro_non_memoized_property + def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set. @@ -2797,7 +3293,7 @@ class SelectBase( raise NotImplementedError() @property - def exported_columns(self): + def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]: """A :class:`_expression.ColumnCollection` that represents the "exported" columns of this :class:`_expression.Selectable`, not including @@ -2819,7 +3315,7 @@ class SelectBase( """ - return self.selected_columns + return self.selected_columns.as_readonly() @util.deprecated_property( "1.4", @@ -2841,6 +3337,26 @@ class SelectBase( def columns(self): return self.c + def get_label_style(self) -> SelectLabelStyle: + """ + Retrieve the current label style. + + Implemented by subclasses. + + """ + raise NotImplementedError() + + def set_label_style( + self: SelfSelectBase, style: SelectLabelStyle + ) -> SelfSelectBase: + """Return a new selectable with the specified label style. + + Implemented by subclasses. + + """ + + raise NotImplementedError() + @util.deprecated( "1.4", "The :meth:`_expression.SelectBase.select` method is deprecated " @@ -2857,6 +3373,9 @@ class SelectBase( def _implicit_subquery(self): return self.subquery() + def _scalar_type(self) -> TypeEngine[Any]: + raise NotImplementedError() + @util.deprecated( "1.4", "The :meth:`_expression.SelectBase.as_scalar` " @@ -2926,7 +3445,7 @@ class SelectBase( """ return self.scalar_subquery().label(name) - def lateral(self, name=None): + def lateral(self, name: Optional[str] = None) -> LateralFromClause: """Return a LATERAL alias of this :class:`_expression.Selectable`. The return value is the :class:`_expression.Lateral` construct also @@ -2941,11 +3460,7 @@ class SelectBase( """ return Lateral._factory(self, name) - @util.ro_non_memoized_property - def _from_objects(self) -> List[FromClause]: - return [self] - - def subquery(self, name=None): + def subquery(self, name: Optional[str] = None) -> Subquery: """Return a subquery of this :class:`_expression.SelectBase`. A subquery is from a SQL perspective a parenthesized, named @@ -2995,7 +3510,9 @@ class SelectBase( raise NotImplementedError() - def alias(self, name=None, flat=False): + def alias( + self, name: Optional[str] = None, flat: bool = False + ) -> Subquery: """Return a named subquery against this :class:`_expression.SelectBase`. @@ -3023,7 +3540,9 @@ class SelectStatementGrouping(GroupedElement, SelectBase): """ __visit_name__ = "select_statement_grouping" - _traverse_internals = [("element", InternalTraversal.dp_clauseelement)] + _traverse_internals: _TraverseInternalsType = [ + ("element", InternalTraversal.dp_clauseelement) + ] _is_select_container = True @@ -3053,13 +3572,14 @@ class SelectStatementGrouping(GroupedElement, SelectBase): def select_statement(self): return self.element - def self_group(self, against=None): + def self_group(self: Self, against: Optional[OperatorType] = None) -> Self: + ... return self - def _generate_columns_plus_names( - self, anon_for_dupe_key: bool - ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]: - return self.element._generate_columns_plus_names(anon_for_dupe_key) + # def _generate_columns_plus_names( + # self, anon_for_dupe_key: bool + # ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]: + # return self.element._generate_columns_plus_names(anon_for_dupe_key) def _generate_fromclause_column_proxies( self, subquery: FromClause @@ -3070,8 +3590,8 @@ class SelectStatementGrouping(GroupedElement, SelectBase): def _all_selected_columns(self) -> _SelectIterable: return self.element._all_selected_columns - @property - def selected_columns(self): + @util.ro_non_memoized_property + def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that the embedded SELECT statement returns in its result set, not including @@ -3112,25 +3632,30 @@ class GenerativeSelect(SelectBase): """ - _order_by_clauses = () - _group_by_clauses = () - _limit_clause = None - _offset_clause = None - _fetch_clause = None - _fetch_clause_options = None - _for_update_arg = None + _order_by_clauses: Tuple[ColumnElement[Any], ...] = () + _group_by_clauses: Tuple[ColumnElement[Any], ...] = () + _limit_clause: Optional[ColumnElement[Any]] = None + _offset_clause: Optional[ColumnElement[Any]] = None + _fetch_clause: Optional[ColumnElement[Any]] = None + _fetch_clause_options: Optional[Dict[str, bool]] = None + _for_update_arg: Optional[ForUpdateArg] = None - def __init__(self, _label_style=LABEL_STYLE_DEFAULT): + def __init__(self, _label_style: SelectLabelStyle = LABEL_STYLE_DEFAULT): self._label_style = _label_style @_generative def with_for_update( self: SelfGenerativeSelect, - nowait=False, - read=False, - of=None, - skip_locked=False, - key_share=False, + nowait: bool = False, + read: bool = False, + of: Optional[ + Union[ + _ColumnExpressionArgument[Any], + Sequence[_ColumnExpressionArgument[Any]], + ] + ] = None, + skip_locked: bool = False, + key_share: bool = False, ) -> SelfGenerativeSelect: """Specify a ``FOR UPDATE`` clause for this :class:`_expression.GenerativeSelect`. @@ -3241,20 +3766,25 @@ class GenerativeSelect(SelectBase): return self @property - def _group_by_clause(self): + def _group_by_clause(self) -> ClauseList: """ClauseList access to group_by_clauses for legacy dialects""" return ClauseList._construct_raw( operators.comma_op, self._group_by_clauses ) @property - def _order_by_clause(self): + def _order_by_clause(self) -> ClauseList: """ClauseList access to order_by_clauses for legacy dialects""" return ClauseList._construct_raw( operators.comma_op, self._order_by_clauses ) - def _offset_or_limit_clause(self, element, name=None, type_=None): + def _offset_or_limit_clause( + self, + element: Union[int, _ColumnExpressionArgument[Any]], + name: Optional[str] = None, + type_: Optional[_TypeEngineArgument[int]] = None, + ) -> ColumnElement[Any]: """Convert the given value to an "offset or limit" clause. This handles incoming integers and converts to an expression; if @@ -3265,7 +3795,21 @@ class GenerativeSelect(SelectBase): roles.LimitOffsetRole, element, name=name, type_=type_ ) - def _offset_or_limit_clause_asint(self, clause, attrname): + @overload + def _offset_or_limit_clause_asint( + self, clause: ColumnElement[Any], attrname: str + ) -> NoReturn: + ... + + @overload + def _offset_or_limit_clause_asint( + self, clause: Optional[_OffsetLimitParam], attrname: str + ) -> Optional[int]: + ... + + def _offset_or_limit_clause_asint( + self, clause: Optional[ColumnElement[Any]], attrname: str + ) -> Union[NoReturn, Optional[int]]: """Convert the "offset or limit" clause of a select construct to an integer. @@ -3286,7 +3830,7 @@ class GenerativeSelect(SelectBase): return util.asint(value) @property - def _limit(self): + def _limit(self) -> Optional[int]: """Get an integer value for the limit. This should only be used by code that cannot support a limit as a BindParameter or other custom clause as it will throw an exception if the limit @@ -3295,14 +3839,14 @@ class GenerativeSelect(SelectBase): """ return self._offset_or_limit_clause_asint(self._limit_clause, "limit") - def _simple_int_clause(self, clause): + def _simple_int_clause(self, clause: ClauseElement) -> bool: """True if the clause is a simple integer, False if it is not present or is a SQL expression. """ return isinstance(clause, _OffsetLimitParam) @property - def _offset(self): + def _offset(self) -> Optional[int]: """Get an integer value for the offset. This should only be used by code that cannot support an offset as a BindParameter or other custom clause as it will throw an exception if the @@ -3314,7 +3858,7 @@ class GenerativeSelect(SelectBase): ) @property - def _has_row_limiting_clause(self): + def _has_row_limiting_clause(self) -> bool: return ( self._limit_clause is not None or self._offset_clause is not None @@ -3322,7 +3866,10 @@ class GenerativeSelect(SelectBase): ) @_generative - def limit(self: SelfGenerativeSelect, limit) -> SelfGenerativeSelect: + def limit( + self: SelfGenerativeSelect, + limit: Union[int, _ColumnExpressionArgument[int]], + ) -> SelfGenerativeSelect: """Return a new selectable with the given LIMIT criterion applied. @@ -3356,7 +3903,10 @@ class GenerativeSelect(SelectBase): @_generative def fetch( - self: SelfGenerativeSelect, count, with_ties=False, percent=False + self: SelfGenerativeSelect, + count: Union[int, _ColumnExpressionArgument[int]], + with_ties: bool = False, + percent: bool = False, ) -> SelfGenerativeSelect: """Return a new selectable with the given FETCH FIRST criterion applied. @@ -3408,7 +3958,10 @@ class GenerativeSelect(SelectBase): return self @_generative - def offset(self: SelfGenerativeSelect, offset) -> SelfGenerativeSelect: + def offset( + self: SelfGenerativeSelect, + offset: Union[int, _ColumnExpressionArgument[int]], + ) -> SelfGenerativeSelect: """Return a new selectable with the given OFFSET criterion applied. @@ -3438,7 +3991,11 @@ class GenerativeSelect(SelectBase): @_generative @util.preload_module("sqlalchemy.sql.util") - def slice(self: SelfGenerativeSelect, start, stop) -> SelfGenerativeSelect: + def slice( + self: SelfGenerativeSelect, + start: int, + stop: int, + ) -> SelfGenerativeSelect: """Apply LIMIT / OFFSET to this statement based on a slice. The start and stop indices behave like the argument to Python's @@ -3485,7 +4042,9 @@ class GenerativeSelect(SelectBase): return self @_generative - def order_by(self: SelfGenerativeSelect, *clauses) -> SelfGenerativeSelect: + def order_by( + self: SelfGenerativeSelect, *clauses: _ColumnExpressionArgument[Any] + ) -> SelfGenerativeSelect: r"""Return a new selectable with the given list of ORDER BY criteria applied. @@ -3522,7 +4081,9 @@ class GenerativeSelect(SelectBase): return self @_generative - def group_by(self: SelfGenerativeSelect, *clauses) -> SelfGenerativeSelect: + def group_by( + self: SelfGenerativeSelect, *clauses: _ColumnExpressionArgument[Any] + ) -> SelfGenerativeSelect: r"""Return a new selectable with the given list of GROUP BY criterion applied. @@ -3567,6 +4128,15 @@ class CompoundSelectState(CompileState): return d, d, d +class _CompoundSelectKeyword(Enum): + UNION = "UNION" + UNION_ALL = "UNION ALL" + EXCEPT = "EXCEPT" + EXCEPT_ALL = "EXCEPT ALL" + INTERSECT = "INTERSECT" + INTERSECT_ALL = "INTERSECT ALL" + + class CompoundSelect(HasCompileState, GenerativeSelect): """Forms the basis of ``UNION``, ``UNION ALL``, and other SELECT-based set operations. @@ -3590,7 +4160,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect): __visit_name__ = "compound_select" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("selects", InternalTraversal.dp_clauseelement_list), ("_limit_clause", InternalTraversal.dp_clauseelement), ("_offset_clause", InternalTraversal.dp_clauseelement), @@ -3602,17 +4172,16 @@ class CompoundSelect(HasCompileState, GenerativeSelect): ("keyword", InternalTraversal.dp_string), ] + SupportsCloneAnnotations._clone_annotations_traverse_internals - UNION = util.symbol("UNION") - UNION_ALL = util.symbol("UNION ALL") - EXCEPT = util.symbol("EXCEPT") - EXCEPT_ALL = util.symbol("EXCEPT ALL") - INTERSECT = util.symbol("INTERSECT") - INTERSECT_ALL = util.symbol("INTERSECT ALL") + selects: List[SelectBase] _is_from_container = True _auto_correlate = False - def __init__(self, keyword, *selects): + def __init__( + self, + keyword: _CompoundSelectKeyword, + *selects: _SelectStatementForCompoundArgument, + ): self.keyword = keyword self.selects = [ coercions.expect(roles.CompoundElementRole, s).self_group( @@ -3624,36 +4193,50 @@ class CompoundSelect(HasCompileState, GenerativeSelect): GenerativeSelect.__init__(self) @classmethod - def _create_union(cls, *selects, **kwargs): - return CompoundSelect(CompoundSelect.UNION, *selects, **kwargs) + def _create_union( + cls, *selects: _SelectStatementForCompoundArgument + ) -> CompoundSelect: + return CompoundSelect(_CompoundSelectKeyword.UNION, *selects) @classmethod - def _create_union_all(cls, *selects): - return CompoundSelect(CompoundSelect.UNION_ALL, *selects) + def _create_union_all( + cls, *selects: _SelectStatementForCompoundArgument + ) -> CompoundSelect: + return CompoundSelect(_CompoundSelectKeyword.UNION_ALL, *selects) @classmethod - def _create_except(cls, *selects): - return CompoundSelect(CompoundSelect.EXCEPT, *selects) + def _create_except( + cls, *selects: _SelectStatementForCompoundArgument + ) -> CompoundSelect: + return CompoundSelect(_CompoundSelectKeyword.EXCEPT, *selects) @classmethod - def _create_except_all(cls, *selects): - return CompoundSelect(CompoundSelect.EXCEPT_ALL, *selects) + def _create_except_all( + cls, *selects: _SelectStatementForCompoundArgument + ) -> CompoundSelect: + return CompoundSelect(_CompoundSelectKeyword.EXCEPT_ALL, *selects) @classmethod - def _create_intersect(cls, *selects): - return CompoundSelect(CompoundSelect.INTERSECT, *selects) + def _create_intersect( + cls, *selects: _SelectStatementForCompoundArgument + ) -> CompoundSelect: + return CompoundSelect(_CompoundSelectKeyword.INTERSECT, *selects) @classmethod - def _create_intersect_all(cls, *selects): - return CompoundSelect(CompoundSelect.INTERSECT_ALL, *selects) + def _create_intersect_all( + cls, *selects: _SelectStatementForCompoundArgument + ) -> CompoundSelect: + return CompoundSelect(_CompoundSelectKeyword.INTERSECT_ALL, *selects) - def _scalar_type(self): + def _scalar_type(self) -> TypeEngine[Any]: return self.selects[0]._scalar_type() - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> GroupedElement: return SelectStatementGrouping(self) - def is_derived_from(self, fromclause): + def is_derived_from(self, fromclause: FromClause) -> bool: for s in self.selects: if s.is_derived_from(fromclause): return True @@ -3675,7 +4258,9 @@ class CompoundSelect(HasCompileState, GenerativeSelect): return self - def _generate_fromclause_column_proxies(self, subquery): + def _generate_fromclause_column_proxies( + self, subquery: FromClause + ) -> None: # this is a slightly hacky thing - the union exports a # column that resembles just that of the *first* selectable. @@ -3716,8 +4301,8 @@ class CompoundSelect(HasCompileState, GenerativeSelect): def _all_selected_columns(self) -> _SelectIterable: return self.selects[0]._all_selected_columns - @property - def selected_columns(self): + @util.ro_non_memoized_property + def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set, @@ -3739,6 +4324,11 @@ class CompoundSelect(HasCompileState, GenerativeSelect): return self.selects[0].selected_columns +# backwards compat +for elem in _CompoundSelectKeyword: + setattr(CompoundSelect, elem.name, elem) + + @CompileState.plugin_for("default", "select") class SelectState(util.MemoizedSlots, CompileState): __slots__ = ( @@ -3758,10 +4348,12 @@ class SelectState(util.MemoizedSlots, CompileState): if TYPE_CHECKING: @classmethod - def get_plugin_class(cls, statement: Select) -> SelectState: + def get_plugin_class(cls, statement: Executable) -> Type[SelectState]: ... - def __init__(self, statement, compiler, **kw): + def __init__( + self, statement: Select, compiler: Optional[SQLCompiler], **kw: Any + ): self.statement = statement self.from_clauses = statement._from_obj @@ -3778,14 +4370,16 @@ class SelectState(util.MemoizedSlots, CompileState): self.columns_plus_names = statement._generate_columns_plus_names(True) @classmethod - def _plugin_not_implemented(cls): + def _plugin_not_implemented(cls) -> NoReturn: raise NotImplementedError( "The default SELECT construct without plugins does not " "implement this method." ) @classmethod - def get_column_descriptions(cls, statement): + def get_column_descriptions( + cls, statement: Select + ) -> List[Dict[str, Any]]: return [ { "name": name, @@ -3798,11 +4392,13 @@ class SelectState(util.MemoizedSlots, CompileState): ] @classmethod - def from_statement(cls, statement, from_statement): + def from_statement( + cls, statement: Select, from_statement: ReturnsRows + ) -> Any: cls._plugin_not_implemented() @classmethod - def get_columns_clause_froms(cls, statement): + def get_columns_clause_froms(cls, statement: Select) -> List[FromClause]: return cls._normalize_froms( itertools.chain.from_iterable( element._from_objects for element in statement._raw_columns @@ -3810,7 +4406,9 @@ class SelectState(util.MemoizedSlots, CompileState): ) @classmethod - def _column_naming_convention(cls, label_style): + def _column_naming_convention( + cls, label_style: SelectLabelStyle + ) -> Callable[[Union[ColumnElement[Any], TextClause]], Optional[str]]: table_qualified = label_style is LABEL_STYLE_TABLENAME_PLUS_COL dedupe = label_style is not LABEL_STYLE_NONE @@ -3850,7 +4448,8 @@ class SelectState(util.MemoizedSlots, CompileState): return go - def _get_froms(self, statement): + def _get_froms(self, statement: Select) -> List[FromClause]: + ambiguous_table_name_map: _AmbiguousTableNameMap self._ambiguous_table_name_map = ambiguous_table_name_map = {} return self._normalize_froms( @@ -3876,10 +4475,10 @@ class SelectState(util.MemoizedSlots, CompileState): @classmethod def _normalize_froms( cls, - iterable_of_froms, - check_statement=None, - ambiguous_table_name_map=None, - ): + iterable_of_froms: Iterable[FromClause], + check_statement: Optional[Select] = None, + ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None, + ) -> List[FromClause]: """given an iterable of things to select FROM, reduce them to what would actually render in the FROM clause of a SELECT. @@ -3888,12 +4487,12 @@ class SelectState(util.MemoizedSlots, CompileState): etc. """ - seen = set() - froms = [] + seen: Set[FromClause] = set() + froms: List[FromClause] = [] for item in iterable_of_froms: - if item._is_subquery and item.element is check_statement: + if is_subquery(item) and item.element is check_statement: raise exc.InvalidRequestError( "select() construct refers to itself as a FROM" ) @@ -3923,7 +4522,7 @@ class SelectState(util.MemoizedSlots, CompileState): ) for item in froms for fr in item._from_objects - if fr._is_table + if is_table(fr) and fr.schema and fr.name not in ambiguous_table_name_map ) @@ -3931,8 +4530,10 @@ class SelectState(util.MemoizedSlots, CompileState): return froms def _get_display_froms( - self, explicit_correlate_froms=None, implicit_correlate_froms=None - ): + self, + explicit_correlate_froms: Optional[Sequence[FromClause]] = None, + implicit_correlate_froms: Optional[Sequence[FromClause]] = None, + ) -> List[FromClause]: """Return the full list of 'from' clauses to be displayed. Takes into account a set of existing froms which may be @@ -3998,25 +4599,33 @@ class SelectState(util.MemoizedSlots, CompileState): return froms - def _memoized_attr__label_resolve_dict(self): - with_cols = dict( - (c._tq_label or c.key, c) + def _memoized_attr__label_resolve_dict( + self, + ) -> Tuple[ + Dict[str, ColumnElement[Any]], + Dict[str, ColumnElement[Any]], + Dict[str, ColumnElement[Any]], + ]: + with_cols: Dict[str, ColumnElement[Any]] = dict( + (c._tq_label or c.key, c) # type: ignore for c in self.statement._all_selected_columns if c._allow_label_resolve ) - only_froms = dict( - (c.key, c) + only_froms: Dict[str, ColumnElement[Any]] = dict( + (c.key, c) # type: ignore for c in _select_iterables(self.froms) if c._allow_label_resolve ) - only_cols = with_cols.copy() + only_cols: Dict[str, ColumnElement[Any]] = with_cols.copy() for key, value in only_froms.items(): with_cols.setdefault(key, value) return with_cols, only_froms, only_cols @classmethod - def determine_last_joined_entity(cls, stmt): + def determine_last_joined_entity( + cls, stmt: Select + ) -> Optional[_JoinTargetElement]: if stmt._setup_joins: return stmt._setup_joins[-1][0] else: @@ -4026,8 +4635,16 @@ class SelectState(util.MemoizedSlots, CompileState): def all_selected_columns(cls, statement: Select) -> _SelectIterable: return [c for c in _select_iterables(statement._raw_columns)] - def _setup_joins(self, args, raw_columns): + def _setup_joins( + self, + args: Tuple[_SetupJoinsElement, ...], + raw_columns: List[_ColumnsClauseElement], + ) -> None: for (right, onclause, left, flags) in args: + if TYPE_CHECKING: + if onclause is not None: + assert isinstance(onclause, ColumnElement) + isouter = flags["isouter"] full = flags["full"] @@ -4043,6 +4660,16 @@ class SelectState(util.MemoizedSlots, CompileState): left ) + # these assertions can be made here, as if the right/onclause + # contained ORM elements, the select() statement would have been + # upgraded to an ORM select, and this method would not be called; + # orm.context.ORMSelectCompileState._join() would be + # used instead. + if TYPE_CHECKING: + assert isinstance(right, FromClause) + if onclause is not None: + assert isinstance(onclause, ColumnElement) + if replace_from_obj_index is not None: # splice into an existing element in the # self._from_obj list @@ -4062,15 +4689,19 @@ class SelectState(util.MemoizedSlots, CompileState): + self.from_clauses[replace_from_obj_index + 1 :] ) else: - + assert left is not None self.from_clauses = self.from_clauses + ( Join(left, right, onclause, isouter=isouter, full=full), ) @util.preload_module("sqlalchemy.sql.util") def _join_determine_implicit_left_side( - self, raw_columns, left, right, onclause - ): + self, + raw_columns: List[_ColumnsClauseElement], + left: Optional[FromClause], + right: _JoinTargetElement, + onclause: Optional[ColumnElement[Any]], + ) -> Tuple[Optional[FromClause], Optional[int]]: """When join conditions don't express the left side explicitly, determine if an existing FROM or entity in this query can serve as the left hand side. @@ -4079,13 +4710,13 @@ class SelectState(util.MemoizedSlots, CompileState): sql_util = util.preloaded.sql_util - replace_from_obj_index = None + replace_from_obj_index: Optional[int] = None from_clauses = self.from_clauses if from_clauses: - indexes = sql_util.find_left_clause_to_join_from( + indexes: List[int] = sql_util.find_left_clause_to_join_from( from_clauses, right, onclause ) @@ -4138,15 +4769,17 @@ class SelectState(util.MemoizedSlots, CompileState): return left, replace_from_obj_index @util.preload_module("sqlalchemy.sql.util") - def _join_place_explicit_left_side(self, left): - replace_from_obj_index = None + def _join_place_explicit_left_side( + self, left: FromClause + ) -> Optional[int]: + replace_from_obj_index: Optional[int] = None sql_util = util.preloaded.sql_util from_clauses = list(self.statement._iterate_from_elements()) if from_clauses: - indexes = sql_util.find_left_clause_that_matches_given( + indexes: List[int] = sql_util.find_left_clause_that_matches_given( self.from_clauses, left ) else: @@ -4171,7 +4804,13 @@ class SelectState(util.MemoizedSlots, CompileState): class _SelectFromElements: - def _iterate_from_elements(self): + __slots__ = () + + _raw_columns: List[_ColumnsClauseElement] + _where_criteria: Tuple[ColumnElement[Any], ...] + _from_obj: Tuple[FromClause, ...] + + def _iterate_from_elements(self) -> Iterator[FromClause]: # note this does not include elements # in _setup_joins @@ -4195,28 +4834,58 @@ class _SelectFromElements: yield element +Self_MemoizedSelectEntities = TypeVar("Self_MemoizedSelectEntities", bound=Any) + + class _MemoizedSelectEntities( cache_key.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible ): + """represents partial state from a Select object, for the case + where Select.columns() has redefined the set of columns/entities the + statement will be SELECTing from. This object represents + the entities from the SELECT before that transformation was applied, + so that transformations that were made in terms of the SELECT at that + time, such as join() as well as options(), can access the correct context. + + In previous SQLAlchemy versions, this wasn't needed because these + constructs calculated everything up front, like when you called join() + or options(), it did everything to figure out how that would translate + into specific SQL constructs that would be ready to send directly to the + SQL compiler when needed. But as of + 1.4, all of that stuff is done in the compilation phase, during the + "compile state" portion of the process, so that the work can all be + cached. So it needs to be able to resolve joins/options2 based on what + the list of entities was when those methods were called. + + + """ + __visit_name__ = "memoized_select_entities" - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("_raw_columns", InternalTraversal.dp_clauseelement_list), ("_setup_joins", InternalTraversal.dp_setup_join_tuple), ("_with_options", InternalTraversal.dp_executable_options), ] + _is_clone_of: Optional[ClauseElement] + _raw_columns: List[_ColumnsClauseElement] + _setup_joins: Tuple[_SetupJoinsElement, ...] + _with_options: Tuple[ExecutableOption, ...] + _annotations = util.EMPTY_DICT - def _clone(self, **kw): + def _clone( + self: Self_MemoizedSelectEntities, **kw: Any + ) -> Self_MemoizedSelectEntities: c = self.__class__.__new__(self.__class__) c.__dict__ = {k: v for k, v in self.__dict__.items()} c._is_clone_of = self.__dict__.get("_is_clone_of", self) - return c + return c # type: ignore @classmethod - def _generate_for_statement(cls, select_stmt): + def _generate_for_statement(cls, select_stmt: Select) -> None: if select_stmt._setup_joins or select_stmt._with_options: self = _MemoizedSelectEntities() self._raw_columns = select_stmt._raw_columns @@ -4224,12 +4893,10 @@ class _MemoizedSelectEntities( self._with_options = select_stmt._with_options select_stmt._memoized_select_entities += (self,) - select_stmt._raw_columns = ( - select_stmt._setup_joins - ) = select_stmt._with_options = () + select_stmt._raw_columns = [] + select_stmt._setup_joins = select_stmt._with_options = () -# TODO: use pep-673 when feasible SelfSelect = typing.TypeVar("SelfSelect", bound="Select") @@ -4258,9 +4925,11 @@ class Select( __visit_name__ = "select" - _setup_joins: Tuple[TODO_Any, ...] = () + _setup_joins: Tuple[_SetupJoinsElement, ...] = () _memoized_select_entities: Tuple[TODO_Any, ...] = () + _raw_columns: List[_ColumnsClauseElement] + _distinct = False _distinct_on: Tuple[ColumnElement[Any], ...] = () _correlate: Tuple[FromClause, ...] = () @@ -4269,12 +4938,12 @@ class Select( _having_criteria: Tuple[ColumnElement[Any], ...] = () _from_obj: Tuple[FromClause, ...] = () _auto_correlate = True - + _is_select_statement = True _compile_options: CacheableOptions = ( SelectState.default_select_compile_options ) - _traverse_internals = ( + _traverse_internals: _TraverseInternalsType = ( [ ("_raw_columns", InternalTraversal.dp_clauseelement_list), ( @@ -4306,12 +4975,14 @@ class Select( + Executable._executable_traverse_internals ) - _cache_key_traversal = _traverse_internals + [ + _cache_key_traversal: _CacheKeyTraversalType = _traverse_internals + [ ("_compile_options", InternalTraversal.dp_has_cache_key) ] + _compile_state_factory: Type[SelectState] + @classmethod - def _create_raw_select(cls, **kw) -> "Select": + def _create_raw_select(cls, **kw: Any) -> Select: """Create a :class:`.Select` using raw ``__new__`` with no coercions. Used internally to build up :class:`.Select` constructs with @@ -4330,6 +5001,12 @@ class Select( :func:`_sql.select` function. """ + things = [ + coercions.expect( + roles.ColumnsClauseRole, ent, apply_propagate_attrs=self + ) + for ent in entities + ] self._raw_columns = [ coercions.expect( @@ -4340,7 +5017,7 @@ class Select( GenerativeSelect.__init__(self) - def _scalar_type(self): + def _scalar_type(self) -> TypeEngine[Any]: elem = self._raw_columns[0] cols = list(elem._select_iterable) return cols[0].type @@ -4446,7 +5123,12 @@ class Select( @_generative def join( - self: SelfSelect, target, onclause=None, *, isouter=False, full=False + self: SelfSelect, + target: _JoinTargetArgument, + onclause: Optional[_OnClauseArgument] = None, + *, + isouter: bool = False, + full: bool = False, ) -> SelfSelect: r"""Create a SQL JOIN against this :class:`_expression.Select` object's criterion @@ -4505,17 +5187,32 @@ class Select( :meth:`_expression.Select.outerjoin` """ # noqa: E501 - target = coercions.expect( + join_target = coercions.expect( roles.JoinTargetRole, target, apply_propagate_attrs=self ) if onclause is not None: - onclause = coercions.expect(roles.OnClauseRole, onclause) + onclause_element = coercions.expect(roles.OnClauseRole, onclause) + else: + onclause_element = None + self._setup_joins += ( - (target, onclause, None, {"isouter": isouter, "full": full}), + ( + join_target, + onclause_element, + None, + {"isouter": isouter, "full": full}, + ), ) return self - def outerjoin_from(self, from_, target, onclause=None, *, full=False): + def outerjoin_from( + self: SelfSelect, + from_: _FromClauseArgument, + target: _JoinTargetArgument, + onclause: Optional[_OnClauseArgument] = None, + *, + full: bool = False, + ) -> SelfSelect: r"""Create a SQL LEFT OUTER JOIN against this :class:`_expression.Select` object's criterion and apply generatively, returning the newly resulting @@ -4531,12 +5228,12 @@ class Select( @_generative def join_from( self: SelfSelect, - from_, - target, - onclause=None, + from_: _FromClauseArgument, + target: _JoinTargetArgument, + onclause: Optional[_OnClauseArgument] = None, *, - isouter=False, - full=False, + isouter: bool = False, + full: bool = False, ) -> SelfSelect: r"""Create a SQL JOIN against this :class:`_expression.Select` object's criterion @@ -4586,18 +5283,31 @@ class Select( from_ = coercions.expect( roles.FromClauseRole, from_, apply_propagate_attrs=self ) - target = coercions.expect( + join_target = coercions.expect( roles.JoinTargetRole, target, apply_propagate_attrs=self ) if onclause is not None: - onclause = coercions.expect(roles.OnClauseRole, onclause) + onclause_element = coercions.expect(roles.OnClauseRole, onclause) + else: + onclause_element = None self._setup_joins += ( - (target, onclause, from_, {"isouter": isouter, "full": full}), + ( + join_target, + onclause_element, + from_, + {"isouter": isouter, "full": full}, + ), ) return self - def outerjoin(self, target, onclause=None, *, full=False): + def outerjoin( + self: SelfSelect, + target: _JoinTargetArgument, + onclause: Optional[_OnClauseArgument] = None, + *, + full: bool = False, + ) -> SelfSelect: """Create a left outer join. Parameters are the same as that of :meth:`_expression.Select.join`. @@ -4634,7 +5344,7 @@ class Select( """ return self.join(target, onclause=onclause, isouter=True, full=full) - def get_final_froms(self): + def get_final_froms(self) -> Sequence[FromClause]: """Compute the final displayed list of :class:`_expression.FromClause` elements. @@ -4671,6 +5381,7 @@ class Select( :attr:`_sql.Select.columns_clause_froms` """ + return self._compile_state_factory(self, None)._get_display_froms() @util.deprecated_property( @@ -4678,7 +5389,7 @@ class Select( "The :attr:`_expression.Select.froms` attribute is moved to " "the :meth:`_expression.Select.get_final_froms` method.", ) - def froms(self): + def froms(self) -> Sequence[FromClause]: """Return the displayed list of :class:`_expression.FromClause` elements. @@ -4687,7 +5398,7 @@ class Select( return self.get_final_froms() @property - def columns_clause_froms(self): + def columns_clause_froms(self) -> List[FromClause]: """Return the set of :class:`_expression.FromClause` objects implied by the columns clause of this SELECT statement. @@ -4720,7 +5431,7 @@ class Select( return iter(self._all_selected_columns) - def is_derived_from(self, fromclause): + def is_derived_from(self, fromclause: FromClause) -> bool: if self in fromclause._cloned_set: return True @@ -4729,7 +5440,9 @@ class Select( return True return False - def _copy_internals(self, clone=_clone, **kw): + def _copy_internals( + self, clone: _CloneCallableType = _clone, **kw: Any + ) -> None: # Select() object has been cloned and probably adapted by the # given clone function. Apply the cloning function to internal # objects @@ -4786,13 +5499,15 @@ class Select( def get_children(self, **kwargs): return itertools.chain( super(Select, self).get_children( - omit_attrs=["_from_obj", "_correlate", "_correlate_except"] + omit_attrs=("_from_obj", "_correlate", "_correlate_except") ), self._iterate_from_elements(), ) @_generative - def add_columns(self: SelfSelect, *columns) -> SelfSelect: + def add_columns( + self: SelfSelect, *columns: _ColumnsClauseArgument + ) -> SelfSelect: """Return a new :func:`_expression.select` construct with the given column expressions added to its columns clause. @@ -4816,7 +5531,9 @@ class Select( ] return self - def _set_entities(self, entities): + def _set_entities( + self, entities: Iterable[_ColumnsClauseArgument] + ) -> None: self._raw_columns = [ coercions.expect( roles.ColumnsClauseRole, ent, apply_propagate_attrs=self @@ -4830,7 +5547,7 @@ class Select( "be removed in a future release. Please use " ":meth:`_expression.Select.add_columns`", ) - def column(self, column): + def column(self: SelfSelect, column: _ColumnsClauseArgument) -> SelfSelect: """Return a new :func:`_expression.select` construct with the given column expression added to its columns clause. @@ -4847,7 +5564,9 @@ class Select( return self.add_columns(column) @util.preload_module("sqlalchemy.sql.util") - def reduce_columns(self, only_synonyms=True): + def reduce_columns( + self: SelfSelect, only_synonyms: bool = True + ) -> SelfSelect: """Return a new :func:`_expression.select` construct with redundantly named, equivalently-valued columns removed from the columns clause. @@ -4880,7 +5599,9 @@ class Select( @_generative def with_only_columns( - self: SelfSelect, *columns, maintain_column_froms=False + self: SelfSelect, + *columns: _ColumnsClauseArgument, + maintain_column_froms: bool = False, ) -> SelfSelect: r"""Return a new :func:`_expression.select` construct with its columns clause replaced with the given columns. @@ -4941,7 +5662,9 @@ class Select( self._assert_no_memoizations() if maintain_column_froms: - self.select_from.non_generative(self, *self.columns_clause_froms) + self.select_from.non_generative( # type: ignore + self, *self.columns_clause_froms + ) # then memoize the FROMs etc. _MemoizedSelectEntities._generate_for_statement(self) @@ -4974,7 +5697,9 @@ class Select( _whereclause = whereclause @_generative - def where(self: SelfSelect, *whereclause) -> SelfSelect: + def where( + self: SelfSelect, *whereclause: _ColumnExpressionArgument[bool] + ) -> SelfSelect: """Return a new :func:`_expression.select` construct with the given expression added to its WHERE clause, joined to the existing clause via AND, if any. @@ -4984,24 +5709,33 @@ class Select( assert isinstance(self._where_criteria, tuple) for criterion in whereclause: - where_criteria = coercions.expect(roles.WhereHavingRole, criterion) + where_criteria: ColumnElement[Any] = coercions.expect( + roles.WhereHavingRole, criterion + ) self._where_criteria += (where_criteria,) return self @_generative - def having(self: SelfSelect, having) -> SelfSelect: + def having( + self: SelfSelect, *having: _ColumnExpressionArgument[bool] + ) -> SelfSelect: """Return a new :func:`_expression.select` construct with the given expression added to its HAVING clause, joined to the existing clause via AND, if any. """ - self._having_criteria += ( - coercions.expect(roles.WhereHavingRole, having), - ) + + for criterion in having: + having_criteria = coercions.expect( + roles.WhereHavingRole, criterion + ) + self._having_criteria += (having_criteria,) return self @_generative - def distinct(self: SelfSelect, *expr) -> SelfSelect: + def distinct( + self: SelfSelect, *expr: _ColumnExpressionArgument[Any] + ) -> SelfSelect: r"""Return a new :func:`_expression.select` construct which will apply DISTINCT to its columns clause. @@ -5023,7 +5757,9 @@ class Select( return self @_generative - def select_from(self: SelfSelect, *froms) -> SelfSelect: + def select_from( + self: SelfSelect, *froms: _FromClauseArgument + ) -> SelfSelect: r"""Return a new :func:`_expression.select` construct with the given FROM expression(s) merged into its list of FROM objects. @@ -5067,7 +5803,10 @@ class Select( return self @_generative - def correlate(self: SelfSelect, *fromclauses) -> SelfSelect: + def correlate( + self: SelfSelect, + *fromclauses: Union[Literal[None, False], _FromClauseArgument], + ) -> SelfSelect: r"""Return a new :class:`_expression.Select` which will correlate the given FROM clauses to that of an enclosing :class:`_expression.Select`. @@ -5106,10 +5845,10 @@ class Select( none of its FROM entries, and all will render unconditionally in the local FROM clause. - :param \*fromclauses: a list of one or more - :class:`_expression.FromClause` - constructs, or other compatible constructs (i.e. ORM-mapped - classes) to become part of the correlate collection. + :param \*fromclauses: one or more :class:`.FromClause` or other + FROM-compatible construct such as an ORM mapped entity to become part + of the correlate collection; alternatively pass a single value + ``None`` to remove all existing correlations. .. seealso:: @@ -5119,8 +5858,16 @@ class Select( """ + # tests failing when we try to change how these + # arguments are passed + self._auto_correlate = False - if fromclauses and fromclauses[0] in {None, False}: + if not fromclauses or fromclauses[0] in {None, False}: + if len(fromclauses) > 1: + raise exc.ArgumentError( + "additional FROM objects not accepted when " + "passing None/False to correlate()" + ) self._correlate = () else: self._correlate = self._correlate + tuple( @@ -5129,7 +5876,10 @@ class Select( return self @_generative - def correlate_except(self: SelfSelect, *fromclauses) -> SelfSelect: + def correlate_except( + self: SelfSelect, + *fromclauses: Union[Literal[None, False], _FromClauseArgument], + ) -> SelfSelect: r"""Return a new :class:`_expression.Select` which will omit the given FROM clauses from the auto-correlation process. @@ -5141,9 +5891,9 @@ class Select( all other FROM elements remain subject to normal auto-correlation behaviors. - If ``None`` is passed, the :class:`_expression.Select` - object will correlate - all of its FROM entries. + If ``None`` is passed, or no arguments are passed, + the :class:`_expression.Select` object will correlate all of its + FROM entries. :param \*fromclauses: a list of one or more :class:`_expression.FromClause` @@ -5159,16 +5909,22 @@ class Select( """ self._auto_correlate = False - if fromclauses and fromclauses[0] in {None, False}: + if not fromclauses or fromclauses[0] in {None, False}: + if len(fromclauses) > 1: + raise exc.ArgumentError( + "additional FROM objects not accepted when " + "passing None/False to correlate_except()" + ) self._correlate_except = () else: self._correlate_except = (self._correlate_except or ()) + tuple( coercions.expect(roles.FromClauseRole, f) for f in fromclauses ) + return self - @HasMemoized.memoized_attribute - def selected_columns(self): + @HasMemoized_ro_memoized_attribute + def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set, @@ -5214,18 +5970,22 @@ class Select( # generates the actual names used in the SELECT string. that # method is more complex because it also renders columns that are # fully ambiguous, e.g. same column more than once. - conv = SelectState._column_naming_convention(self._label_style) + conv = cast( + "Callable[[Any], str]", + SelectState._column_naming_convention(self._label_style), + ) - return ColumnCollection( + cc: ColumnCollection[str, ColumnElement[Any]] = ColumnCollection( [ (conv(c), c) for c in self._all_selected_columns if is_column_element(c) ] - ).as_readonly() + ) + return cc.as_readonly() @HasMemoized.memoized_attribute - def _all_selected_columns(self) -> Sequence[ColumnElement[Any]]: + def _all_selected_columns(self) -> _SelectIterable: meth = SelectState.get_plugin_class(self).all_selected_columns return list(meth(self)) @@ -5234,173 +5994,9 @@ class Select( self = self.set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY) return self - def _generate_columns_plus_names( - self, anon_for_dupe_key: bool - ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]: - """Generate column names as rendered in a SELECT statement by - the compiler. - - This is distinct from the _column_naming_convention generator that's - intended for population of .c collections and similar, which has - different rules. the collection returned here calls upon the - _column_naming_convention as well. - - """ - cols = self._all_selected_columns - - key_naming_convention = SelectState._column_naming_convention( - self._label_style - ) - - names = {} - - result = [] - result_append = result.append - - table_qualified = self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL - label_style_none = self._label_style is LABEL_STYLE_NONE - - # a counter used for "dedupe" labels, which have double underscores - # in them and are never referred by name; they only act - # as positional placeholders. they need only be unique within - # the single columns clause they're rendered within (required by - # some dbs such as mysql). So their anon identity is tracked against - # a fixed counter rather than hash() identity. - dedupe_hash = 1 - - for c in cols: - repeated = False - - if not c._render_label_in_columns_clause: - effective_name = ( - required_label_name - ) = fallback_label_name = None - elif label_style_none: - effective_name = required_label_name = None - fallback_label_name = c._non_anon_label or c._anon_name_label - else: - if table_qualified: - required_label_name = ( - effective_name - ) = fallback_label_name = c._tq_label - else: - effective_name = fallback_label_name = c._non_anon_label - required_label_name = None - - if effective_name is None: - # it seems like this could be _proxy_key and we would - # not need _expression_label but it isn't - # giving us a clue when to use anon_label instead - expr_label = c._expression_label - if expr_label is None: - repeated = c._anon_name_label in names - names[c._anon_name_label] = c - effective_name = required_label_name = None - - if repeated: - # here, "required_label_name" is sent as - # "None" and "fallback_label_name" is sent. - if table_qualified: - fallback_label_name = ( - c._dedupe_anon_tq_label_idx(dedupe_hash) - ) - dedupe_hash += 1 - else: - fallback_label_name = c._dedupe_anon_label_idx( - dedupe_hash - ) - dedupe_hash += 1 - else: - fallback_label_name = c._anon_name_label - else: - required_label_name = ( - effective_name - ) = fallback_label_name = expr_label - - if effective_name is not None: - if effective_name in names: - # when looking to see if names[name] is the same column as - # c, use hash(), so that an annotated version of the column - # is seen as the same as the non-annotated - if hash(names[effective_name]) != hash(c): - - # different column under the same name. apply - # disambiguating label - if table_qualified: - required_label_name = ( - fallback_label_name - ) = c._anon_tq_label - else: - required_label_name = ( - fallback_label_name - ) = c._anon_name_label - - if anon_for_dupe_key and required_label_name in names: - # here, c._anon_tq_label is definitely unique to - # that column identity (or annotated version), so - # this should always be true. - # this is also an infrequent codepath because - # you need two levels of duplication to be here - assert hash(names[required_label_name]) == hash(c) - - # the column under the disambiguating label is - # already present. apply the "dedupe" label to - # subsequent occurrences of the column so that the - # original stays non-ambiguous - if table_qualified: - required_label_name = ( - fallback_label_name - ) = c._dedupe_anon_tq_label_idx(dedupe_hash) - dedupe_hash += 1 - else: - required_label_name = ( - fallback_label_name - ) = c._dedupe_anon_label_idx(dedupe_hash) - dedupe_hash += 1 - repeated = True - else: - names[required_label_name] = c - elif anon_for_dupe_key: - # same column under the same name. apply the "dedupe" - # label so that the original stays non-ambiguous - if table_qualified: - required_label_name = ( - fallback_label_name - ) = c._dedupe_anon_tq_label_idx(dedupe_hash) - dedupe_hash += 1 - else: - required_label_name = ( - fallback_label_name - ) = c._dedupe_anon_label_idx(dedupe_hash) - dedupe_hash += 1 - repeated = True - else: - names[effective_name] = c - - result_append( - ( - # string label name, if non-None, must be rendered as a - # label, i.e. "AS " - required_label_name, - # proxy_key that is to be part of the result map for this - # col. this is also the key in a fromclause.c or - # select.selected_columns collection - key_naming_convention(c), - # name that can be used to render an "AS " when - # we have to render a label even though - # required_label_name was not given - fallback_label_name, - # the ColumnElement itself - c, - # True if this is a duplicate of a previous column - # in the list of columns - repeated, - ) - ) - - return result - - def _generate_fromclause_column_proxies(self, subquery): + def _generate_fromclause_column_proxies( + self, subquery: FromClause + ) -> None: """Generate column proxies to place in the exported ``.c`` collection of a subquery.""" @@ -5418,7 +6014,7 @@ class Select( c, repeated, ) in (self._generate_columns_plus_names(False)) - if not c._is_text_clause + if is_column_element(c) ] subquery._columns._populate_separate_keys(prox) @@ -5428,7 +6024,10 @@ class Select( self._order_by_clause.clauses ) - def self_group(self, against=None): + def self_group( + self: Self, against: Optional[OperatorType] = None + ) -> Union[SelectStatementGrouping, Self]: + ... """Return a 'grouping' construct as per the :class:`_expression.ClauseElement` specification. @@ -5445,7 +6044,9 @@ class Select( else: return SelectStatementGrouping(self) - def union(self, *other, **kwargs): + def union( + self, *other: _SelectStatementForCompoundArgument + ) -> CompoundSelect: r"""Return a SQL ``UNION`` of this select() construct against the given selectables provided as positional arguments. @@ -5460,9 +6061,11 @@ class Select( for the newly created :class:`_sql.CompoundSelect` object. """ - return CompoundSelect._create_union(self, *other, **kwargs) + return CompoundSelect._create_union(self, *other) - def union_all(self, *other, **kwargs): + def union_all( + self, *other: _SelectStatementForCompoundArgument + ) -> CompoundSelect: r"""Return a SQL ``UNION ALL`` of this select() construct against the given selectables provided as positional arguments. @@ -5477,9 +6080,11 @@ class Select( for the newly created :class:`_sql.CompoundSelect` object. """ - return CompoundSelect._create_union_all(self, *other, **kwargs) + return CompoundSelect._create_union_all(self, *other) - def except_(self, *other, **kwargs): + def except_( + self, *other: _SelectStatementForCompoundArgument + ) -> CompoundSelect: r"""Return a SQL ``EXCEPT`` of this select() construct against the given selectable provided as positional arguments. @@ -5490,13 +6095,12 @@ class Select( multiple elements are now accepted. - :param \**kwargs: keyword arguments are forwarded to the constructor - for the newly created :class:`_sql.CompoundSelect` object. - """ - return CompoundSelect._create_except(self, *other, **kwargs) + return CompoundSelect._create_except(self, *other) - def except_all(self, *other, **kwargs): + def except_all( + self, *other: _SelectStatementForCompoundArgument + ) -> CompoundSelect: r"""Return a SQL ``EXCEPT ALL`` of this select() construct against the given selectables provided as positional arguments. @@ -5507,13 +6111,12 @@ class Select( multiple elements are now accepted. - :param \**kwargs: keyword arguments are forwarded to the constructor - for the newly created :class:`_sql.CompoundSelect` object. - """ - return CompoundSelect._create_except_all(self, *other, **kwargs) + return CompoundSelect._create_except_all(self, *other) - def intersect(self, *other, **kwargs): + def intersect( + self, *other: _SelectStatementForCompoundArgument + ) -> CompoundSelect: r"""Return a SQL ``INTERSECT`` of this select() construct against the given selectables provided as positional arguments. @@ -5528,9 +6131,11 @@ class Select( for the newly created :class:`_sql.CompoundSelect` object. """ - return CompoundSelect._create_intersect(self, *other, **kwargs) + return CompoundSelect._create_intersect(self, *other) - def intersect_all(self, *other, **kwargs): + def intersect_all( + self, *other: _SelectStatementForCompoundArgument + ) -> CompoundSelect: r"""Return a SQL ``INTERSECT ALL`` of this select() construct against the given selectables provided as positional arguments. @@ -5545,13 +6150,17 @@ class Select( for the newly created :class:`_sql.CompoundSelect` object. """ - return CompoundSelect._create_intersect_all(self, *other, **kwargs) + return CompoundSelect._create_intersect_all(self, *other) -SelfScalarSelect = typing.TypeVar("SelfScalarSelect", bound="ScalarSelect") +SelfScalarSelect = typing.TypeVar( + "SelfScalarSelect", bound="ScalarSelect[Any]" +) -class ScalarSelect(roles.InElementRole, Generative, Grouping): +class ScalarSelect( + roles.InElementRole, Generative, GroupedElement, ColumnElement[_T] +): """Represent a scalar subquery. @@ -5570,15 +6179,33 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping): """ - _from_objects = [] + _traverse_internals: _TraverseInternalsType = [ + ("element", InternalTraversal.dp_clauseelement), + ("type", InternalTraversal.dp_type), + ] + + _from_objects: List[FromClause] = [] _is_from_container = True - _is_implicitly_boolean = False + if not TYPE_CHECKING: + _is_implicitly_boolean = False inherit_cache = True + element: SelectBase + def __init__(self, element): self.element = element self.type = element._scalar_type() + def __getattr__(self, attr): + return getattr(self.element, attr) + + def __getstate__(self): + return {"element": self.element, "type": self.type} + + def __setstate__(self, state): + self.element = state["element"] + self.type = state["type"] + @property def columns(self): raise exc.InvalidRequestError( @@ -5590,19 +6217,39 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping): c = columns @_generative - def where(self: SelfScalarSelect, crit) -> SelfScalarSelect: + def where( + self: SelfScalarSelect, crit: _ColumnExpressionArgument[bool] + ) -> SelfScalarSelect: """Apply a WHERE clause to the SELECT statement referred to by this :class:`_expression.ScalarSelect`. """ - self.element = self.element.where(crit) + self.element = cast(Select, self.element).where(crit) return self - def self_group(self, **kwargs): + @overload + def self_group( + self: ScalarSelect[Any], against: Optional[OperatorType] = None + ) -> ScalarSelect[Any]: + ... + + @overload + def self_group( + self: ColumnElement[Any], against: Optional[OperatorType] = None + ) -> ColumnElement[Any]: + ... + + def self_group( + self, against: Optional[OperatorType] = None + ) -> ColumnElement[Any]: + return self @_generative - def correlate(self: SelfScalarSelect, *fromclauses) -> SelfScalarSelect: + def correlate( + self: SelfScalarSelect, + *fromclauses: Union[Literal[None, False], _FromClauseArgument], + ) -> SelfScalarSelect: r"""Return a new :class:`_expression.ScalarSelect` which will correlate the given FROM clauses to that of an enclosing :class:`_expression.Select`. @@ -5631,12 +6278,13 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping): """ - self.element = self.element.correlate(*fromclauses) + self.element = cast(Select, self.element).correlate(*fromclauses) return self @_generative def correlate_except( - self: SelfScalarSelect, *fromclauses + self: SelfScalarSelect, + *fromclauses: Union[Literal[None, False], _FromClauseArgument], ) -> SelfScalarSelect: r"""Return a new :class:`_expression.ScalarSelect` which will omit the given FROM @@ -5668,11 +6316,16 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping): """ - self.element = self.element.correlate_except(*fromclauses) + self.element = cast(Select, self.element).correlate_except( + *fromclauses + ) return self -class Exists(UnaryExpression[_T]): +SelfExists = TypeVar("SelfExists", bound="Exists") + + +class Exists(UnaryExpression[bool]): """Represent an ``EXISTS`` clause. See :func:`_sql.exists` for a description of usage. @@ -5682,10 +6335,14 @@ class Exists(UnaryExpression[_T]): """ - _from_objects = () inherit_cache = True - def __init__(self, __argument=None): + def __init__( + self, + __argument: Optional[ + Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]] + ] = None, + ): if __argument is None: s = Select(literal_column("*")).scalar_subquery() elif isinstance(__argument, (SelectBase, ScalarSelect)): @@ -5701,12 +6358,16 @@ class Exists(UnaryExpression[_T]): wraps_column_expression=True, ) + @util.ro_non_memoized_property + def _from_objects(self) -> List[FromClause]: + return [] + def _regroup(self, fn): element = self.element._ungroup() element = fn(element) return element.self_group(against=operators.exists) - def select(self) -> "Select": + def select(self) -> Select: r"""Return a SELECT of this :class:`_expression.Exists`. e.g.:: @@ -5726,7 +6387,10 @@ class Exists(UnaryExpression[_T]): return Select(self) - def correlate(self, *fromclause): + def correlate( + self: SelfExists, + *fromclauses: Union[Literal[None, False], _FromClauseArgument], + ) -> SelfExists: """Apply correlation to the subquery noted by this :class:`_sql.Exists`. .. seealso:: @@ -5736,11 +6400,14 @@ class Exists(UnaryExpression[_T]): """ e = self._clone() e.element = self._regroup( - lambda element: element.correlate(*fromclause) + lambda element: element.correlate(*fromclauses) ) return e - def correlate_except(self, *fromclause): + def correlate_except( + self: SelfExists, + *fromclauses: Union[Literal[None, False], _FromClauseArgument], + ) -> SelfExists: """Apply correlation to the subquery noted by this :class:`_sql.Exists`. .. seealso:: @@ -5751,11 +6418,11 @@ class Exists(UnaryExpression[_T]): e = self._clone() e.element = self._regroup( - lambda element: element.correlate_except(*fromclause) + lambda element: element.correlate_except(*fromclauses) ) return e - def select_from(self, *froms): + def select_from(self: SelfExists, *froms: FromClause) -> SelfExists: """Return a new :class:`_expression.Exists` construct, applying the given expression to the :meth:`_expression.Select.select_from` @@ -5772,7 +6439,9 @@ class Exists(UnaryExpression[_T]): e.element = self._regroup(lambda element: element.select_from(*froms)) return e - def where(self, *clause): + def where( + self: SelfExists, *clause: _ColumnExpressionArgument[bool] + ) -> SelfExists: """Return a new :func:`_expression.exists` construct with the given expression added to its WHERE clause, joined to the existing clause via AND, if any. @@ -5824,7 +6493,7 @@ class TextualSelect(SelectBase): _label_style = LABEL_STYLE_NONE - _traverse_internals = [ + _traverse_internals: _TraverseInternalsType = [ ("element", InternalTraversal.dp_clauseelement), ("column_args", InternalTraversal.dp_clauseelement_list), ] + SupportsCloneAnnotations._clone_annotations_traverse_internals @@ -5842,8 +6511,8 @@ class TextualSelect(SelectBase): ] self.positional = positional - @HasMemoized.memoized_attribute - def selected_columns(self): + @HasMemoized_ro_memoized_attribute + def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set, @@ -5868,6 +6537,13 @@ class TextualSelect(SelectBase): (c.key, c) for c in self.column_args ).as_readonly() + # def _generate_columns_plus_names( + # self, anon_for_dupe_key: bool + # ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]: + # return Select._generate_columns_plus_names( + # self, anon_for_dupe_key=anon_for_dupe_key + # ) + @util.non_memoized_property def _all_selected_columns(self) -> _SelectIterable: return self.column_args @@ -5880,7 +6556,9 @@ class TextualSelect(SelectBase): @_generative def bindparams( - self: SelfTextualSelect, *binds, **bind_as_values + self: SelfTextualSelect, + *binds: BindParameter[Any], + **bind_as_values: Any, ) -> SelfTextualSelect: self.element = self.element.bindparams(*binds, **bind_as_values) return self diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 1f3d508769..c3653c2647 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -228,7 +228,7 @@ class HasCopyInternals(HasTraverseInternals): raise NotImplementedError() def _copy_internals( - self, omit_attrs: Iterable[str] = (), **kw: Any + self, *, omit_attrs: Iterable[str] = (), **kw: Any ) -> None: """Reassign internal elements to be clones of themselves. diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 9a934a50bc..82adf4a4f9 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -1008,10 +1008,7 @@ class TypeEngine(Visitable, Generic[_T]): @util.preload_module("sqlalchemy.engine.default") def _default_dialect(self) -> Dialect: - if TYPE_CHECKING: - from ..engine import default - else: - default = util.preloaded.engine_default + default = util.preloaded.engine_default # dmypy / mypy seems to sporadically keep thinking this line is # returning Any, which seems to be caused by the @deprecated_params diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index cdce49f7bc..80711c4b57 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -13,10 +13,20 @@ from __future__ import annotations from collections import deque from itertools import chain import typing +from typing import AbstractSet from typing import Any +from typing import Callable from typing import cast +from typing import Dict from typing import Iterator +from typing import List from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union from . import coercions from . import operators @@ -49,11 +59,22 @@ from .selectable import Join from .selectable import ScalarSelect from .selectable import SelectBase from .selectable import TableClause +from .visitors import _ET from .. import exc from .. import util +from ..util.typing import Literal +from ..util.typing import Protocol if typing.TYPE_CHECKING: + from ._typing import _ColumnExpressionArgument + from ._typing import _TypeEngineArgument from .roles import FromClauseRole + from .selectable import _JoinTargetElement + from .selectable import _OnClauseElement + from .selectable import Selectable + from .visitors import _TraverseCallableType + from .visitors import ExternallyTraversible + from .visitors import ExternalTraversal from ..engine.interfaces import _AnyExecuteParams from ..engine.interfaces import _AnyMultiExecuteParams from ..engine.interfaces import _AnySingleExecuteParams @@ -160,7 +181,11 @@ def find_left_clause_that_matches_given(clauses, join_from): return liberal_idx -def find_left_clause_to_join_from(clauses, join_to, onclause): +def find_left_clause_to_join_from( + clauses: Sequence[FromClause], + join_to: _JoinTargetElement, + onclause: Optional[ColumnElement[Any]], +) -> List[int]: """Given a list of FROM clauses, a selectable, and optional ON clause, return a list of integer indexes from the clauses list indicating the clauses that can be joined from. @@ -189,6 +214,7 @@ def find_left_clause_to_join_from(clauses, join_to, onclause): for i, f in enumerate(clauses): for s in selectables.difference([f]): if resolve_ambiguity: + assert cols_in_onclause is not None if set(f.c).union(s.c).issuperset(cols_in_onclause): idx.append(i) break @@ -207,7 +233,7 @@ def find_left_clause_to_join_from(clauses, join_to, onclause): # onclause was given and none of them resolved, so assume # all indexes can match if not idx and onclause is not None: - return range(len(clauses)) + return list(range(len(clauses))) else: return idx @@ -247,7 +273,7 @@ def visit_binary_product(fn, expr): a binary comparison is passed as pairs. """ - stack = [] + stack: List[ClauseElement] = [] def visit(element): if isinstance(element, ScalarSelect): @@ -272,21 +298,22 @@ def visit_binary_product(fn, expr): yield e list(visit(expr)) - visit = None # remove gc cycles + visit = None # type: ignore # remove gc cycles def find_tables( - clause, - check_columns=False, - include_aliases=False, - include_joins=False, - include_selects=False, - include_crud=False, -): + clause: ClauseElement, + *, + check_columns: bool = False, + include_aliases: bool = False, + include_joins: bool = False, + include_selects: bool = False, + include_crud: bool = False, +) -> List[TableClause]: """locate Table objects within the given expression.""" - tables = [] - _visitors = {} + tables: List[TableClause] = [] + _visitors: Dict[str, _TraverseCallableType[Any]] = {} if include_selects: _visitors["select"] = _visitors["compound_select"] = tables.append @@ -335,7 +362,7 @@ def unwrap_order_by(clause): t = stack.popleft() if isinstance(t, ColumnElement) and ( not isinstance(t, UnaryExpression) - or not operators.is_ordering_modifier(t.modifier) + or not operators.is_ordering_modifier(t.modifier) # type: ignore ): if isinstance(t, Label) and not isinstance( t.element, ScalarSelect @@ -365,9 +392,14 @@ def unwrap_order_by(clause): def unwrap_label_reference(element): - def replace(elem): - if isinstance(elem, (_label_reference, _textual_label_reference)): - return elem.element + def replace( + element: ExternallyTraversible, **kw: Any + ) -> Optional[ExternallyTraversible]: + if isinstance(element, _label_reference): + return element.element + elif isinstance(element, _textual_label_reference): + assert False, "can't unwrap a textual label reference" + return None return visitors.replacement_traverse(element, {}, replace) @@ -407,7 +439,7 @@ def clause_is_present(clause, search): return False -def tables_from_leftmost(clause: FromClauseRole) -> Iterator[FromClause]: +def tables_from_leftmost(clause: FromClause) -> Iterator[FromClause]: if isinstance(clause, Join): for t in tables_from_leftmost(clause.left): yield t @@ -509,6 +541,8 @@ class _repr_base: __slots__ = ("max_chars",) + max_chars: int + def trunc(self, value: Any) -> str: rep = repr(value) lenrep = len(rep) @@ -612,7 +646,7 @@ class _repr_params(_repr_base): def _repr_multi( self, multi_params: _AnyMultiExecuteParams, - typ, + typ: int, ) -> str: if multi_params: if isinstance(multi_params[0], list): @@ -639,7 +673,7 @@ class _repr_params(_repr_base): def _repr_params( self, - params: Optional[_AnySingleExecuteParams], + params: _AnySingleExecuteParams, typ: int, ) -> str: trunc = self.trunc @@ -653,9 +687,10 @@ class _repr_params(_repr_base): ) ) elif typ is self._TUPLE: + seq_params = cast("Sequence[Any]", params) return "(%s%s)" % ( - ", ".join(trunc(value) for value in params), - "," if len(params) == 1 else "", + ", ".join(trunc(value) for value in seq_params), + "," if len(seq_params) == 1 else "", ) else: return "[%s]" % (", ".join(trunc(value) for value in params)) @@ -688,11 +723,15 @@ def adapt_criterion_to_null(crit, nulls): return visitors.cloned_traverse(crit, {}, {"binary": visit_binary}) -def splice_joins(left, right, stop_on=None): +def splice_joins( + left: Optional[FromClause], + right: Optional[FromClause], + stop_on: Optional[FromClause] = None, +) -> Optional[FromClause]: if left is None: return right - stack = [(right, None)] + stack: List[Tuple[Optional[FromClause], Optional[Join]]] = [(right, None)] adapter = ClauseAdapter(left) ret = None @@ -705,6 +744,7 @@ def splice_joins(left, right, stop_on=None): else: right = adapter.traverse(right) if prevright is not None: + assert right is not None prevright.left = right if ret is None: ret = right @@ -845,11 +885,14 @@ def criterion_as_pairs( elif binary.right.references(binary.left): pairs.append((binary.left, binary.right)) - pairs = [] + pairs: List[Tuple[ColumnElement[Any], ColumnElement[Any]]] = [] visitors.traverse(expression, {}, {"binary": visit_binary}) return pairs +_CE = TypeVar("_CE", bound="ClauseElement") + + class ClauseAdapter(visitors.ReplacingExternalTraversal): """Clones and modifies clauses based on column correspondence. @@ -879,13 +922,15 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): def __init__( self, - selectable, - equivalents=None, - include_fn=None, - exclude_fn=None, - adapt_on_names=False, - anonymize_labels=False, - adapt_from_selectables=None, + selectable: Selectable, + equivalents: Optional[ + Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]] + ] = None, + include_fn: Optional[Callable[[ClauseElement], bool]] = None, + exclude_fn: Optional[Callable[[ClauseElement], bool]] = None, + adapt_on_names: bool = False, + anonymize_labels: bool = False, + adapt_from_selectables: Optional[AbstractSet[FromClause]] = None, ): self.__traverse_options__ = { "stop_on": [selectable], @@ -898,6 +943,29 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): self.adapt_on_names = adapt_on_names self.adapt_from_selectables = adapt_from_selectables + if TYPE_CHECKING: + + @overload + def traverse(self, obj: Literal[None]) -> None: + ... + + # note this specializes the ReplacingExternalTraversal.traverse() + # method to state + # that we will return the same kind of ExternalTraversal object as + # we were given. This is probably not 100% true, such as it's + # possible for us to swap out Alias for Table at the top level. + # Ideally there could be overloads specific to ColumnElement and + # FromClause but Mypy is not accepting those as compatible with + # the base ReplacingExternalTraversal + @overload + def traverse(self, obj: _ET) -> _ET: + ... + + def traverse( + self, obj: Optional[ExternallyTraversible] + ) -> Optional[ExternallyTraversible]: + ... + def _corresponding_column( self, col, require_embedded, _seen=util.EMPTY_SET ): @@ -919,9 +987,13 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): return newcol @util.preload_module("sqlalchemy.sql.functions") - def replace(self, col, _include_singleton_constants=False): + def replace( + self, col: _ET, _include_singleton_constants: bool = False + ) -> Optional[_ET]: functions = util.preloaded.sql_functions + # TODO: cython candidate + if isinstance(col, FromClause) and not isinstance( col, functions.FunctionElement ): @@ -933,7 +1005,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): break else: return None - return self.selectable + return self.selectable # type: ignore elif isinstance(col, Alias) and isinstance( col.element, TableClause ): @@ -944,7 +1016,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): # we are an alias of a table and we are not derived from an # alias of a table (which nonetheless may be the same table # as ours) so, same thing - return col + return col # type: ignore else: # other cases where we are a selectable and the element # is another join or selectable that contains a table which our @@ -972,12 +1044,22 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): else: return None + if TYPE_CHECKING: + assert isinstance(col, ColumnElement) + if self.include_fn and not self.include_fn(col): return None elif self.exclude_fn and self.exclude_fn(col): return None else: - return self._corresponding_column(col, True) + return self._corresponding_column(col, True) # type: ignore + + +class _ColumnLookup(Protocol): + def __getitem__( + self, key: ColumnElement[Any] + ) -> Optional[ColumnElement[Any]]: + ... class ColumnAdapter(ClauseAdapter): @@ -1011,17 +1093,21 @@ class ColumnAdapter(ClauseAdapter): """ + columns: _ColumnLookup + def __init__( self, - selectable, - equivalents=None, - adapt_required=False, - include_fn=None, - exclude_fn=None, - adapt_on_names=False, - allow_label_resolve=True, - anonymize_labels=False, - adapt_from_selectables=None, + selectable: Selectable, + equivalents: Optional[ + Dict[ColumnElement[Any], AbstractSet[ColumnElement[Any]]] + ] = None, + adapt_required: bool = False, + include_fn: Optional[Callable[[ClauseElement], bool]] = None, + exclude_fn: Optional[Callable[[ClauseElement], bool]] = None, + adapt_on_names: bool = False, + allow_label_resolve: bool = True, + anonymize_labels: bool = False, + adapt_from_selectables: Optional[AbstractSet[FromClause]] = None, ): ClauseAdapter.__init__( self, @@ -1034,7 +1120,7 @@ class ColumnAdapter(ClauseAdapter): adapt_from_selectables=adapt_from_selectables, ) - self.columns = util.WeakPopulateDict(self._locate_col) + self.columns = util.WeakPopulateDict(self._locate_col) # type: ignore if self.include_fn or self.exclude_fn: self.columns = self._IncludeExcludeMapping(self, self.columns) self.adapt_required = adapt_required @@ -1060,7 +1146,7 @@ class ColumnAdapter(ClauseAdapter): ac = self.__class__.__new__(self.__class__) ac.__dict__.update(self.__dict__) ac._wrap = adapter - ac.columns = util.WeakPopulateDict(ac._locate_col) + ac.columns = util.WeakPopulateDict(ac._locate_col) # type: ignore if ac.include_fn or ac.exclude_fn: ac.columns = self._IncludeExcludeMapping(ac, ac.columns) @@ -1069,6 +1155,17 @@ class ColumnAdapter(ClauseAdapter): def traverse(self, obj): return self.columns[obj] + def chain(self, visitor: ExternalTraversal) -> ColumnAdapter: + assert isinstance(visitor, ColumnAdapter) + + return super().chain(visitor) + + if TYPE_CHECKING: + + @property + def visitor_iterator(self) -> Iterator[ColumnAdapter]: + ... + adapt_clause = traverse adapt_list = ClauseAdapter.copy_and_process @@ -1080,7 +1177,9 @@ class ColumnAdapter(ClauseAdapter): return newcol - def _locate_col(self, col): + def _locate_col( + self, col: ColumnElement[Any] + ) -> Optional[ColumnElement[Any]]: # both replace and traverse() are overly complicated for what # we are doing here and we would do better to have an inlined # version that doesn't build up as much overhead. the issue is that @@ -1120,10 +1219,14 @@ class ColumnAdapter(ClauseAdapter): def __setstate__(self, state): self.__dict__.update(state) - self.columns = util.WeakPopulateDict(self._locate_col) + self.columns = util.WeakPopulateDict(self._locate_col) # type: ignore -def _offset_or_limit_clause(element, name=None, type_=None): +def _offset_or_limit_clause( + element: Union[int, _ColumnExpressionArgument[int]], + name: Optional[str] = None, + type_: Optional[_TypeEngineArgument[int]] = None, +) -> ColumnElement[int]: """Convert the given value to an "offset or limit" clause. This handles incoming integers and converts to an expression; if @@ -1135,7 +1238,9 @@ def _offset_or_limit_clause(element, name=None, type_=None): ) -def _offset_or_limit_clause_asint_if_possible(clause): +def _offset_or_limit_clause_asint_if_possible( + clause: Optional[Union[int, _ColumnExpressionArgument[int]]] +) -> Optional[Union[int, _ColumnExpressionArgument[int]]]: """Return the offset or limit clause as a simple integer if possible, else return the clause. @@ -1143,18 +1248,27 @@ def _offset_or_limit_clause_asint_if_possible(clause): if clause is None: return None if hasattr(clause, "_limit_offset_value"): - value = clause._limit_offset_value + value = clause._limit_offset_value # type: ignore return util.asint(value) else: return clause -def _make_slice(limit_clause, offset_clause, start, stop): +def _make_slice( + limit_clause: Optional[Union[int, _ColumnExpressionArgument[int]]], + offset_clause: Optional[Union[int, _ColumnExpressionArgument[int]]], + start: int, + stop: int, +) -> Tuple[Optional[ColumnElement[int]], Optional[ColumnElement[int]]]: """Compute LIMIT/OFFSET in terms of slice start/end""" # for calculated limit/offset, try to do the addition of # values to offset in Python, however if a SQL clause is present # then the addition has to be on the SQL side. + + # TODO: typing is finding a few gaps in here, see if they can be + # closed up + if start is not None and stop is not None: offset_clause = _offset_or_limit_clause_asint_if_possible( offset_clause @@ -1163,11 +1277,12 @@ def _make_slice(limit_clause, offset_clause, start, stop): offset_clause = 0 if start != 0: - offset_clause = offset_clause + start + offset_clause = offset_clause + start # type: ignore if offset_clause == 0: offset_clause = None else: + assert offset_clause is not None offset_clause = _offset_or_limit_clause(offset_clause) limit_clause = _offset_or_limit_clause(stop - start) @@ -1182,11 +1297,13 @@ def _make_slice(limit_clause, offset_clause, start, stop): offset_clause = 0 if start != 0: - offset_clause = offset_clause + start + offset_clause = offset_clause + start # type: ignore if offset_clause == 0: offset_clause = None else: - offset_clause = _offset_or_limit_clause(offset_clause) + offset_clause = _offset_or_limit_clause( + offset_clause # type: ignore + ) - return limit_clause, offset_clause + return limit_clause, offset_clause # type: ignore diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 903aae6483..081faf1e9f 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -28,6 +28,7 @@ from typing import Iterator from typing import List from typing import Mapping from typing import Optional +from typing import overload from typing import Tuple from typing import Type from typing import TypeVar @@ -37,6 +38,7 @@ from .. import exc from .. import util from ..util import langhelpers from ..util._has_cy import HAS_CYEXTENSION +from ..util.typing import Literal from ..util.typing import Protocol from ..util.typing import Self @@ -599,8 +601,8 @@ class ExternallyTraversible(HasTraverseInternals, Visitable): raise NotImplementedError() def _copy_internals( - self: Self, omit_attrs: Tuple[str, ...] = (), **kw: Any - ) -> Self: + self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any + ) -> None: """Reassign internal elements to be clones of themselves. Called during a copy-and-traverse operation on newly @@ -615,10 +617,24 @@ class ExternallyTraversible(HasTraverseInternals, Visitable): _ET = TypeVar("_ET", bound=ExternallyTraversible) + + _TraverseCallableType = Callable[[_ET], None] -_TraverseTransformCallableType = Callable[ - [ExternallyTraversible], Optional[ExternallyTraversible] -] + + +class _CloneCallableType(Protocol): + def __call__(self, element: _ET, **kw: Any) -> _ET: + ... + + +class _TraverseTransformCallableType(Protocol): + def __call__( + self, element: ExternallyTraversible, **kw: Any + ) -> Optional[ExternallyTraversible]: + ... + + +_ExtT = TypeVar("_ExtT", bound="ExternalTraversal") class ExternalTraversal: @@ -640,7 +656,7 @@ class ExternalTraversal: return meth(obj, **kw) def iterate( - self, obj: ExternallyTraversible + self, obj: Optional[ExternallyTraversible] ) -> Iterator[ExternallyTraversible]: """Traverse the given expression structure, returning an iterator of all elements. @@ -648,7 +664,17 @@ class ExternalTraversal: """ return iterate(obj, self.__traverse_options__) + @overload + def traverse(self, obj: Literal[None]) -> None: + ... + + @overload def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible: + ... + + def traverse( + self, obj: Optional[ExternallyTraversible] + ) -> Optional[ExternallyTraversible]: """Traverse and visit the given expression structure.""" return traverse(obj, self.__traverse_options__, self._visitor_dict) @@ -671,7 +697,7 @@ class ExternalTraversal: yield v v = getattr(v, "_next", None) - def chain(self, visitor: ExternalTraversal) -> ExternalTraversal: + def chain(self: _ExtT, visitor: ExternalTraversal) -> _ExtT: """'Chain' an additional ExternalTraversal onto this ExternalTraversal The chained visitor will receive all visit events after this one. @@ -701,7 +727,17 @@ class CloningExternalTraversal(ExternalTraversal): """ return [self.traverse(x) for x in list_] + @overload + def traverse(self, obj: Literal[None]) -> None: + ... + + @overload def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible: + ... + + def traverse( + self, obj: Optional[ExternallyTraversible] + ) -> Optional[ExternallyTraversible]: """Traverse and visit the given expression structure.""" return cloned_traverse( @@ -729,14 +765,25 @@ class ReplacingExternalTraversal(CloningExternalTraversal): """ return None + @overload + def traverse(self, obj: Literal[None]) -> None: + ... + + @overload def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible: + ... + + def traverse( + self, obj: Optional[ExternallyTraversible] + ) -> Optional[ExternallyTraversible]: """Traverse and visit the given expression structure.""" def replace( - elem: ExternallyTraversible, + element: ExternallyTraversible, + **kw: Any, ) -> Optional[ExternallyTraversible]: for v in self.visitor_iterator: - e = cast(ReplacingExternalTraversal, v).replace(elem) + e = cast(ReplacingExternalTraversal, v).replace(element) if e is not None: return e @@ -754,7 +801,8 @@ ReplacingCloningVisitor = ReplacingExternalTraversal def iterate( - obj: ExternallyTraversible, opts: Mapping[str, Any] = util.EMPTY_DICT + obj: Optional[ExternallyTraversible], + opts: Mapping[str, Any] = util.EMPTY_DICT, ) -> Iterator[ExternallyTraversible]: r"""Traverse the given expression structure, returning an iterator. @@ -776,6 +824,9 @@ def iterate( empty in modern usage. """ + if obj is None: + return + yield obj children = obj.get_children(**opts) @@ -790,11 +841,29 @@ def iterate( stack.append(t.get_children(**opts)) +@overload +def traverse_using( + iterator: Iterable[ExternallyTraversible], + obj: Literal[None], + visitors: Mapping[str, _TraverseCallableType[Any]], +) -> None: + ... + + +@overload def traverse_using( iterator: Iterable[ExternallyTraversible], obj: ExternallyTraversible, visitors: Mapping[str, _TraverseCallableType[Any]], ) -> ExternallyTraversible: + ... + + +def traverse_using( + iterator: Iterable[ExternallyTraversible], + obj: Optional[ExternallyTraversible], + visitors: Mapping[str, _TraverseCallableType[Any]], +) -> Optional[ExternallyTraversible]: """Visit the given expression structure using the given iterator of objects. @@ -826,11 +895,29 @@ def traverse_using( return obj +@overload +def traverse( + obj: Literal[None], + opts: Mapping[str, Any], + visitors: Mapping[str, _TraverseCallableType[Any]], +) -> None: + ... + + +@overload def traverse( obj: ExternallyTraversible, opts: Mapping[str, Any], visitors: Mapping[str, _TraverseCallableType[Any]], ) -> ExternallyTraversible: + ... + + +def traverse( + obj: Optional[ExternallyTraversible], + opts: Mapping[str, Any], + visitors: Mapping[str, _TraverseCallableType[Any]], +) -> Optional[ExternallyTraversible]: """Traverse and visit the given expression structure using the default iterator. @@ -863,11 +950,29 @@ def traverse( return traverse_using(iterate(obj, opts), obj, visitors) +@overload +def cloned_traverse( + obj: Literal[None], + opts: Mapping[str, Any], + visitors: Mapping[str, _TraverseCallableType[Any]], +) -> None: + ... + + +@overload def cloned_traverse( obj: ExternallyTraversible, opts: Mapping[str, Any], visitors: Mapping[str, _TraverseCallableType[Any]], ) -> ExternallyTraversible: + ... + + +def cloned_traverse( + obj: Optional[ExternallyTraversible], + opts: Mapping[str, Any], + visitors: Mapping[str, _TraverseCallableType[Any]], +) -> Optional[ExternallyTraversible]: """Clone the given expression structure, allowing modifications by visitors. @@ -931,11 +1036,29 @@ def cloned_traverse( return obj +@overload +def replacement_traverse( + obj: Literal[None], + opts: Mapping[str, Any], + replace: _TraverseTransformCallableType, +) -> None: + ... + + +@overload def replacement_traverse( obj: ExternallyTraversible, opts: Mapping[str, Any], replace: _TraverseTransformCallableType, ) -> ExternallyTraversible: + ... + + +def replacement_traverse( + obj: Optional[ExternallyTraversible], + opts: Mapping[str, Any], + replace: _TraverseTransformCallableType, +) -> Optional[ExternallyTraversible]: """Clone the given expression structure, allowing element replacement by a given replacement function. diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index 4496b8dede..f1bf5c0c4e 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -306,6 +306,7 @@ def testing_engine( options=None, asyncio=False, transfer_staticpool=False, + share_pool=False, _sqlite_savepoint=False, ): if asyncio: @@ -356,6 +357,8 @@ def testing_engine( if config.db is not None and isinstance(config.db.pool, StaticPool): use_reaper = False engine.pool._transfer_from(config.db.pool) + elif share_pool: + engine.pool = config.db.pool if scope == "global": if asyncio: diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 6c6b21fcec..8c0120bcc4 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -121,6 +121,7 @@ class TestBase: future=None, asyncio=False, transfer_staticpool=False, + share_pool=False, ): if options is None: options = {} @@ -130,6 +131,7 @@ class TestBase: options=options, asyncio=asyncio, transfer_staticpool=transfer_staticpool, + share_pool=share_pool, ) yield gen_testing_engine diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 406c8af248..c0c2e7dfb7 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -9,7 +9,9 @@ from collections import defaultdict as defaultdict from functools import partial as partial from functools import update_wrapper as update_wrapper +from typing import TYPE_CHECKING +from . import preloaded as preloaded from ._collections import coerce_generator_arg as coerce_generator_arg from ._collections import coerce_to_immutabledict as coerce_to_immutabledict from ._collections import column_dict as column_dict @@ -44,8 +46,6 @@ from ._collections import UniqueAppender as UniqueAppender from ._collections import update_copy as update_copy from ._collections import WeakPopulateDict as WeakPopulateDict from ._collections import WeakSequence as WeakSequence -from ._preloaded import preload_module as preload_module -from ._preloaded import preloaded as preloaded from .compat import arm as arm from .compat import b as b from .compat import b64decode as b64decode @@ -148,3 +148,4 @@ from .langhelpers import warn as warn from .langhelpers import warn_exception as warn_exception from .langhelpers import warn_limited as warn_limited from .langhelpers import wrap_callable as wrap_callable +from .preloaded import preload_module as preload_module diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index bd73bf7140..bcb2ad4230 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -463,11 +463,12 @@ def update_copy(d, _new=None, **kw): return d -def flatten_iterator(x): +def flatten_iterator(x: Iterable[_T]) -> Iterator[_T]: """Given an iterator of which further sub-elements may also be iterators, flatten the sub-elements into a single iterator. """ + elem: _T for elem in x: if not isinstance(elem, str) and hasattr(elem, "__iter__"): for y in flatten_iterator(elem): diff --git a/lib/sqlalchemy/util/_preloaded.py b/lib/sqlalchemy/util/preloaded.py similarity index 85% rename from lib/sqlalchemy/util/_preloaded.py rename to lib/sqlalchemy/util/preloaded.py index 511b93351d..c861c83b3f 100644 --- a/lib/sqlalchemy/util/_preloaded.py +++ b/lib/sqlalchemy/util/preloaded.py @@ -16,10 +16,16 @@ from types import ModuleType import typing from typing import Any from typing import Callable +from typing import TYPE_CHECKING from typing import TypeVar _FN = TypeVar("_FN", bound=Callable[..., Any]) +if TYPE_CHECKING: + from sqlalchemy.engine import default as engine_default + from sqlalchemy.sql import dml as sql_dml + from sqlalchemy.sql import util as sql_util + class _ModuleRegistry: """Registry of modules to load in a package init file. @@ -67,7 +73,7 @@ class _ModuleRegistry: not path or module.startswith(path) ) and key not in self.__dict__: __import__(module, globals(), locals()) - self.__dict__[key] = sys.modules[module] + self.__dict__[key] = globals()[key] = sys.modules[module] if typing.TYPE_CHECKING: @@ -75,5 +81,11 @@ class _ModuleRegistry: ... -preloaded = _ModuleRegistry() -preload_module = preloaded.preload_module +_reg = _ModuleRegistry() +preload_module = _reg.preload_module +import_prefix = _reg.import_prefix + +if TYPE_CHECKING: + + def __getattr__(key: str) -> ModuleType: + ... diff --git a/pyproject.toml b/pyproject.toml index cc79e86469..012f1bffa9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,9 +68,6 @@ module = [ "sqlalchemy.ext.mutable", "sqlalchemy.ext.horizontal_shard", - "sqlalchemy.sql._selectable_constructors", - "sqlalchemy.sql._dml_constructors", - # TODO for non-strict: "sqlalchemy.ext.baked", "sqlalchemy.ext.instrumentation", @@ -78,11 +75,6 @@ module = [ "sqlalchemy.ext.orderinglist", "sqlalchemy.ext.serializer", - "sqlalchemy.sql.selectable", # would be nice as strict - "sqlalchemy.sql.functions", # would be nice as strict - "sqlalchemy.sql.lambdas", - "sqlalchemy.sql.util", - # not yet classified: "sqlalchemy.orm.*", "sqlalchemy.dialects.*", @@ -132,10 +124,14 @@ module = [ "sqlalchemy.sql.crud", "sqlalchemy.sql.ddl", # would be nice as strict "sqlalchemy.sql.elements", # would be nice as strict + "sqlalchemy.sql.functions", # would be nice as strict, requires sqltypes + "sqlalchemy.sql.lambdas", "sqlalchemy.sql.naming", + "sqlalchemy.sql.selectable", # would be nice as strict "sqlalchemy.sql.schema", # would be nice as strict "sqlalchemy.sql.sqltypes", # would be nice as strict "sqlalchemy.sql.traversals", + "sqlalchemy.sql.util", "sqlalchemy.util.*", ] diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py index 06be9fbd8e..e03a8415d0 100644 --- a/test/aaa_profiling/test_orm.py +++ b/test/aaa_profiling/test_orm.py @@ -835,27 +835,14 @@ class JoinedEagerLoadTest(NoCache, fixtures.MappedTest): ) s.commit() - def test_build_query(self): + def test_fetch_results_integrated(self, testing_engine): A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G") - sess = fixture_session() + # this test has been reworked to use the compiled cache again, + # as a real-world scenario. - @profiling.function_call_count() - def go(): - for i in range(100): - q = sess.query(A).options( - joinedload(A.bs).joinedload(B.cs).joinedload(C.ds), - joinedload(A.es).joinedload(E.fs), - defaultload(A.es).joinedload(E.gs), - ) - q._compile_context() - - go() - - def test_fetch_results(self): - A, B, C, D, E, F, G = self.classes("A", "B", "C", "D", "E", "F", "G") - - sess = Session(testing.db) + eng = testing_engine(share_pool=True) + sess = Session(eng) q = sess.query(A).options( joinedload(A.bs).joinedload(B.cs).joinedload(C.ds), @@ -863,47 +850,27 @@ class JoinedEagerLoadTest(NoCache, fixtures.MappedTest): defaultload(A.es).joinedload(E.gs), ) - compile_state = q._compile_state() + @profiling.function_call_count() + def initial_run(): + list(q.all()) - from sqlalchemy.orm.context import ORMCompileState + initial_run() + sess.close() - @profiling.function_call_count(warmup=1) - def go(): - for i in range(100): - # NOTE: this test was broken in - # 77f1b7d236dba6b1c859bb428ef32d118ec372e6 because we started - # clearing out the attributes after the first iteration. make - # sure the attributes are there every time. - assert compile_state.attributes - exec_opts = {} - bind_arguments = {} - ORMCompileState.orm_pre_session_exec( - sess, - compile_state.select_statement, - {}, - exec_opts, - bind_arguments, - is_reentrant_invoke=False, - ) + @profiling.function_call_count() + def subsequent_run(): + list(q.all()) - r = sess.connection().execute( - compile_state.statement, - execution_options=exec_opts, - ) + subsequent_run() + sess.close() - r.context.compiled.compile_state = compile_state - obj = ORMCompileState.orm_setup_cursor_result( - sess, - compile_state.statement, - {}, - exec_opts, - {}, - r, - ) - list(obj.unique()) - sess.close() + @profiling.function_call_count() + def more_runs(): + for i in range(100): + list(q.all()) - go() + more_runs() + sess.close() class JoinConditionTest(NoCache, fixtures.DeclarativeMappedTest): diff --git a/test/base/test_utils.py b/test/base/test_utils.py index fc61e39b65..e22340da68 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -25,11 +25,11 @@ from sqlalchemy.testing import mock from sqlalchemy.testing import ne_ from sqlalchemy.testing.util import gc_collect from sqlalchemy.testing.util import picklers -from sqlalchemy.util import _preloaded from sqlalchemy.util import classproperty from sqlalchemy.util import compat from sqlalchemy.util import get_callable_argspec from sqlalchemy.util import langhelpers +from sqlalchemy.util import preloaded from sqlalchemy.util import WeakSequence from sqlalchemy.util._collections import merge_lists_w_ordering @@ -3187,7 +3187,7 @@ class TestModuleRegistry(fixtures.TestBase): for m in ("xml.dom", "wsgiref.simple_server"): to_restore.append((m, sys.modules.pop(m, None))) try: - mr = _preloaded._ModuleRegistry() + mr = preloaded._ModuleRegistry() ret = mr.preload_module( "xml.dom", "wsgiref.simple_server", "sqlalchemy.sql.util" diff --git a/test/profiles.txt b/test/profiles.txt index 31f72bd166..7b4f377349 100644 --- a/test/profiles.txt +++ b/test/profiles.txt @@ -196,16 +196,16 @@ test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3 test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 96844 test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 102344 -# TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query - -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 520615 -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 522475 # TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 440705 test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 458805 +# TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results_integrated + +test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results_integrated x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 30373,1014,96450 + # TEST: test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_identity test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_identity x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 22984 diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index dd073d2a59..8d6dc75534 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -208,6 +208,8 @@ class CoreFixtures: column("q") == column("x"), column("q") == column("y"), column("z") == column("x"), + (column("z") == column("x")).self_group(), + (column("q") == column("x")).self_group(), column("z") + column("x"), column("z") - column("x"), column("x") - column("z"), diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index d20037e921..6ca06dc0e1 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -82,7 +82,6 @@ from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import CompilerColumnElement from sqlalchemy.sql.expression import ClauseElement from sqlalchemy.sql.expression import ClauseList -from sqlalchemy.sql.expression import HasPrefixes from sqlalchemy.sql.selectable import LABEL_STYLE_NONE from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from sqlalchemy.testing import assert_raises @@ -270,18 +269,6 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "columns", ) - def test_prefix_constructor(self): - class Pref(HasPrefixes): - def _generate(self): - return self - - assert_raises( - exc.ArgumentError, - Pref().prefix_with, - "some prefix", - not_a_dialect=True, - ) - def test_table_select(self): self.assert_compile( table1.select(), diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 686d4928d6..d1d01a5c74 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1186,6 +1186,37 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): dialect="postgresql", ) + def test_recursive_dml_syntax(self): + orders = table( + "orders", + column("region"), + column("amount"), + column("product"), + column("quantity"), + ) + + upsert = ( + orders.update() + .where(orders.c.region == "Region1") + .values(amount=1.0, product="Product1", quantity=1) + .returning(*(orders.c._all_columns)) + .cte("upsert", recursive=True) + ) + stmt = select(upsert) + + # This statement probably makes no sense, just want to see that the + # column generation aspect needed by RECURSIVE works (new in 2.0) + self.assert_compile( + stmt, + "WITH RECURSIVE upsert(region, amount, product, quantity) " + "AS (UPDATE orders SET amount=:param_1, product=:param_2, " + "quantity=:param_3 WHERE orders.region = :region_1 " + "RETURNING orders.region, orders.amount, orders.product, " + "orders.quantity) " + "SELECT upsert.region, upsert.amount, upsert.product, " + "upsert.quantity FROM upsert", + ) + def test_upsert_from_select(self): orders = table( "orders", diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index ca5f43bb6e..9fdc519389 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -432,6 +432,24 @@ class SelectableTest( ): select(stmt.subquery()).compile() + def test_correlate_none_arg_error(self): + stmt = select(table1) + with expect_raises_message( + exc.ArgumentError, + "additional FROM objects not accepted when passing " + "None/False to correlate", + ): + stmt.correlate(None, table2) + + def test_correlate_except_none_arg_error(self): + stmt = select(table1) + with expect_raises_message( + exc.ArgumentError, + "additional FROM objects not accepted when passing " + "None/False to correlate_except", + ): + stmt.correlate_except(None, table2) + def test_select_label_grouped_still_corresponds(self): label = select(table1.c.col1).label("foo") label2 = label.self_group() diff --git a/test/sql/test_text.py b/test/sql/test_text.py index 81b20f86f0..0f645a2d23 100644 --- a/test/sql/test_text.py +++ b/test/sql/test_text.py @@ -688,6 +688,19 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL): mapping = self._mapping(s) assert x not in mapping + def test_subquery_accessors(self): + t = self._xy_table_fixture() + + s = text("SELECT x from t").columns(t.c.x) + + self.assert_compile( + select(s.scalar_subquery()), "SELECT (SELECT x from t) AS anon_1" + ) + self.assert_compile( + select(s.subquery()), + "SELECT anon_1.x FROM (SELECT x from t) AS anon_1", + ) + def test_select_label_alt_name_table_alias_column(self): t = self._xy_table_fixture() x = t.c.x @@ -716,6 +729,36 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL): "FROM mytable, t WHERE mytable.myid = t.id", ) + def test_cte_recursive(self): + t = ( + text("select id, name from user") + .columns(id=Integer, name=String) + .cte("t", recursive=True) + ) + + s = select(table1).where(table1.c.myid == t.c.id) + self.assert_compile( + s, + "WITH RECURSIVE t(id, name) AS (select id, name from user) " + "SELECT mytable.myid, mytable.name, mytable.description " + "FROM mytable, t WHERE mytable.myid = t.id", + ) + + def test_unions(self): + s1 = text("select id, name from user where id > 5").columns( + id=Integer, name=String + ) + s2 = text("select id, name from user where id < 15").columns( + id=Integer, name=String + ) + stmt = union(s1, s2) + eq_(stmt.selected_columns.keys(), ["id", "name"]) + self.assert_compile( + stmt, + "select id, name from user where id > 5 UNION " + "select id, name from user where id < 15", + ) + def test_subquery(self): t = ( text("select id, name from user") diff --git a/test/sql/test_values.py b/test/sql/test_values.py index f5ae9ea53d..d14de9aeed 100644 --- a/test/sql/test_values.py +++ b/test/sql/test_values.py @@ -294,6 +294,31 @@ class ValuesTest(fixtures.TablesTest, AssertsCompiledSQL): checkparams={}, ) + def test_anon_alias(self): + people = self.tables.people + values = ( + Values( + column("bookcase_id", Integer), + column("bookcase_owner_id", Integer), + ) + .data([(1, 1), (2, 1), (3, 2), (3, 3)]) + .alias() + ) + stmt = select(people, values).select_from( + people.join( + values, values.c.bookcase_owner_id == people.c.people_id + ) + ) + self.assert_compile( + stmt, + "SELECT people.people_id, people.age, people.name, " + "anon_1.bookcase_id, anon_1.bookcase_owner_id FROM people " + "JOIN (VALUES (:param_1, :param_2), (:param_3, :param_4), " + "(:param_5, :param_6), (:param_7, :param_8)) AS anon_1 " + "(bookcase_id, bookcase_owner_id) " + "ON people.people_id = anon_1.bookcase_owner_id", + ) + def test_with_join_unnamed(self): people = self.tables.people values = Values(