From ad11c482e2233f44e8747d4d5a2b17a995fff1fa Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 19 Apr 2022 21:06:41 -0400 Subject: [PATCH] pep484 ORM / SQL result support after some experimentation it seems mypy is more amenable to the generic types being fully integrated rather than having separate spin-off types. so key structures like Result, Row, Select become generic. For DML Insert, Update, Delete, these are spun into type-specific subclasses ReturningInsert, ReturningUpdate, ReturningDelete, which is fine since the "row-ness" of these constructs doesn't happen until returning() is called in any case. a Tuple based model is then integrated so that these objects can carry along information about their return types. Overloads at the .execute() level carry through the Tuple from the invoked object to the result. To suit the issue of AliasedClass generating attributes that are dynamic, experimented with a custom subclass AsAliased, but then just settled on having aliased() lie to the type checker and return `Type[_O]`, essentially. will need some type-related accessors for with_polymorphic() also. Additionally, identified an issue in Update when used "mysql style" against a join(), it basically doesn't work if asked to UPDATE two tables on the same column name. added an error message to the specific condition where it happens with a very non-specific error message that we hit a thing we can't do right now, suggest multi-table update as a possible cause. Change-Id: I5eff7eefe1d6166ee74160b2785c5e6a81fa8b95 --- lib/sqlalchemy/engine/__init__.py | 1 + lib/sqlalchemy/engine/base.py | 89 ++- lib/sqlalchemy/engine/cursor.py | 24 +- lib/sqlalchemy/engine/default.py | 14 +- lib/sqlalchemy/engine/events.py | 3 +- lib/sqlalchemy/engine/interfaces.py | 2 +- lib/sqlalchemy/engine/result.py | 422 ++++++++++--- lib/sqlalchemy/engine/row.py | 51 +- lib/sqlalchemy/ext/asyncio/engine.py | 132 ++++- lib/sqlalchemy/ext/asyncio/result.py | 339 ++++++++++- lib/sqlalchemy/ext/asyncio/scoping.py | 135 ++++- lib/sqlalchemy/ext/asyncio/session.py | 136 ++++- lib/sqlalchemy/ext/instrumentation.py | 3 + lib/sqlalchemy/orm/_orm_constructors.py | 30 +- lib/sqlalchemy/orm/attributes.py | 5 +- lib/sqlalchemy/orm/base.py | 16 +- lib/sqlalchemy/orm/context.py | 7 +- lib/sqlalchemy/orm/interfaces.py | 13 +- lib/sqlalchemy/orm/mapper.py | 13 +- lib/sqlalchemy/orm/properties.py | 5 +- lib/sqlalchemy/orm/query.py | 296 +++++++++- lib/sqlalchemy/orm/scoping.py | 219 ++++++- lib/sqlalchemy/orm/session.py | 211 ++++++- lib/sqlalchemy/orm/state.py | 6 +- lib/sqlalchemy/orm/util.py | 53 +- lib/sqlalchemy/sql/__init__.py | 1 - .../sql/_selectable_constructors.py | 166 +++++- lib/sqlalchemy/sql/_typing.py | 78 ++- lib/sqlalchemy/sql/base.py | 15 +- lib/sqlalchemy/sql/coercions.py | 25 +- lib/sqlalchemy/sql/compiler.py | 40 +- lib/sqlalchemy/sql/crud.py | 32 +- lib/sqlalchemy/sql/dml.py | 376 +++++++++++- lib/sqlalchemy/sql/elements.py | 36 +- lib/sqlalchemy/sql/functions.py | 8 +- lib/sqlalchemy/sql/roles.py | 57 +- lib/sqlalchemy/sql/schema.py | 9 +- lib/sqlalchemy/sql/selectable.py | 287 +++++++-- lib/sqlalchemy/sql/util.py | 9 +- lib/sqlalchemy/sql/visitors.py | 4 +- lib/sqlalchemy/util/langhelpers.py | 44 ++ lib/sqlalchemy/util/typing.py | 2 +- pyproject.toml | 1 - test/base/test_result.py | 7 + .../mypy/plain_files/association_proxy_one.py | 4 +- .../ext/mypy/plain_files/engine_inspection.py | 6 +- .../plain_files/experimental_relationship.py | 12 +- test/ext/mypy/plain_files/hybrid_one.py | 10 +- test/ext/mypy/plain_files/hybrid_two.py | 18 +- test/ext/mypy/plain_files/session.py | 22 +- test/ext/mypy/plain_files/sql_operations.py | 20 +- .../plain_files/trad_relationship_uselist.py | 26 +- .../plain_files/traditional_relationship.py | 18 +- test/ext/mypy/plain_files/typed_queries.py | 433 ++++++++++++++ test/ext/mypy/plain_files/typed_results.py | 554 ++++++++++++++++++ .../plugin_files/dataclasses_workaround.py | 4 +- test/ext/mypy/test_mypy_plugin_py3k.py | 37 +- test/orm/test_froms.py | 2 +- test/orm/test_query.py | 16 + test/sql/test_metadata.py | 15 + test/sql/test_resultset.py | 44 ++ test/sql/test_select.py | 2 +- test/sql/test_update.py | 41 ++ tools/generate_proxy_methods.py | 93 +-- tools/generate_tuple_map_overloads.py | 174 ++++++ 65 files changed, 4392 insertions(+), 581 deletions(-) create mode 100644 test/ext/mypy/plain_files/typed_queries.py create mode 100644 test/ext/mypy/plain_files/typed_results.py create mode 100644 tools/generate_tuple_map_overloads.py diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index 29dd6aff90..afba170759 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -46,6 +46,7 @@ from .result import MergedResult as MergedResult from .result import Result as Result from .result import result_tuple as result_tuple from .result import ScalarResult as ScalarResult +from .result import TupleResult as TupleResult from .row import BaseRow as BaseRow from .row import Row as Row from .row import RowMapping as RowMapping diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index a325da929b..fe3bfa1adf 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -18,8 +18,10 @@ from typing import Mapping from typing import MutableMapping from typing import NoReturn from typing import Optional +from typing import overload from typing import Tuple from typing import Type +from typing import TypeVar from typing import Union from .interfaces import _IsolationLevel @@ -45,12 +47,10 @@ if typing.TYPE_CHECKING: from . import ScalarResult from .interfaces import _AnyExecuteParams from .interfaces import _AnyMultiExecuteParams - from .interfaces import _AnySingleExecuteParams from .interfaces import _CoreAnyExecuteParams from .interfaces import _CoreMultiExecuteParams from .interfaces import _CoreSingleExecuteParams from .interfaces import _DBAPIAnyExecuteParams - from .interfaces import _DBAPIMultiExecuteParams from .interfaces import _DBAPISingleExecuteParams from .interfaces import _ExecuteOptions from .interfaces import _ExecuteOptionsParameter @@ -65,21 +65,21 @@ if typing.TYPE_CHECKING: from ..pool import PoolProxiedConnection from ..sql import Executable from ..sql._typing import _InfoType - from ..sql.base import SchemaVisitor from ..sql.compiler import Compiled from ..sql.ddl import ExecutableDDLElement from ..sql.ddl import SchemaDropper from ..sql.ddl import SchemaGenerator from ..sql.functions import FunctionElement - from ..sql.schema import ColumnDefault from ..sql.schema import DefaultGenerator from ..sql.schema import HasSchemaAttr from ..sql.schema import SchemaItem + from ..sql.selectable import TypedReturnsRows """Defines :class:`_engine.Connection` and :class:`_engine.Engine`. """ +_T = TypeVar("_T", bound=Any) _EMPTY_EXECUTION_OPTS: _ExecuteOptions = util.EMPTY_DICT NO_OPTIONS: Mapping[str, Any] = util.EMPTY_DICT @@ -1142,10 +1142,31 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self._dbapi_connection = None self.__can_reconnect = False + @overload + def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> Optional[_T]: + ... + + @overload + def scalar( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> Any: + ... + def scalar( self, statement: Executable, parameters: Optional[_CoreSingleExecuteParams] = None, + *, execution_options: Optional[_ExecuteOptionsParameter] = None, ) -> Any: r"""Executes a SQL statement construct and returns a scalar object. @@ -1170,10 +1191,31 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): execution_options or NO_OPTIONS, ) + @overload + def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> ScalarResult[_T]: + ... + + @overload def scalars( self, statement: Executable, parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> ScalarResult[Any]: + ... + + def scalars( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, execution_options: Optional[_ExecuteOptionsParameter] = None, ) -> ScalarResult[Any]: """Executes and returns a scalar result set, which yields scalar values @@ -1190,14 +1232,37 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): """ - return self.execute(statement, parameters, execution_options).scalars() + return self.execute( + statement, parameters, execution_options=execution_options + ).scalars() + + @overload + def execute( + self, + statement: TypedReturnsRows[_T], + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> CursorResult[_T]: + ... + + @overload + def execute( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> CursorResult[Any]: + ... def execute( self, statement: Executable, parameters: Optional[_CoreAnyExecuteParams] = None, + *, execution_options: Optional[_ExecuteOptionsParameter] = None, - ) -> CursorResult: + ) -> CursorResult[Any]: r"""Executes a SQL statement construct and returns a :class:`_engine.CursorResult`. @@ -1246,7 +1311,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): func: FunctionElement[Any], distilled_parameters: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter, - ) -> CursorResult: + ) -> CursorResult[Any]: """Execute a sql.FunctionElement object.""" return self._execute_clauseelement( @@ -1317,7 +1382,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): ddl: ExecutableDDLElement, distilled_parameters: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter, - ) -> CursorResult: + ) -> CursorResult[Any]: """Execute a schema.DDL object.""" execution_options = ddl._execution_options.merge_with( @@ -1414,7 +1479,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): elem: Executable, distilled_parameters: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter, - ) -> CursorResult: + ) -> CursorResult[Any]: """Execute a sql.ClauseElement object.""" execution_options = elem._execution_options.merge_with( @@ -1487,7 +1552,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): compiled: Compiled, distilled_parameters: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter = _EMPTY_EXECUTION_OPTS, - ) -> CursorResult: + ) -> CursorResult[Any]: """Execute a sql.Compiled object. TODO: why do we have this? likely deprecate or remove @@ -1537,7 +1602,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): statement: str, parameters: Optional[_DBAPIAnyExecuteParams] = None, execution_options: Optional[_ExecuteOptionsParameter] = None, - ) -> CursorResult: + ) -> CursorResult[Any]: r"""Executes a SQL statement construct and returns a :class:`_engine.CursorResult`. @@ -1614,7 +1679,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): execution_options: _ExecuteOptions, *args: Any, **kw: Any, - ) -> CursorResult: + ) -> CursorResult[Any]: """Create an :class:`.ExecutionContext` and execute, returning a :class:`_engine.CursorResult`.""" diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index ccf5736756..ff69666b71 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -24,6 +24,7 @@ from typing import Optional from typing import Sequence from typing import Tuple from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from .result import MergedResult @@ -55,11 +56,12 @@ if typing.TYPE_CHECKING: from .interfaces import ExecutionContext from .result import _KeyIndexType from .result import _KeyMapRecType - from .result import _KeyMapType from .result import _KeyType from .result import _ProcessorsType from ..sql.type_api import _ResultProcessorType +_T = TypeVar("_T", bound=Any) + # metadata entry tuple indexes. # using raw tuple is faster than namedtuple. MD_INDEX: Literal[0] = 0 # integer index in cursor.description @@ -214,7 +216,9 @@ class CursorResultMetaData(ResultMetaData): return md def __init__( - self, parent: CursorResult, cursor_description: _DBAPICursorDescription + self, + parent: CursorResult[Any], + cursor_description: _DBAPICursorDescription, ): context = parent.context self._tuplefilter = None @@ -1158,7 +1162,7 @@ class _NoResultMetaData(ResultMetaData): _NO_RESULT_METADATA = _NoResultMetaData() -class CursorResult(Result): +class CursorResult(Result[_T]): """A Result that is representing state from a DBAPI cursor. .. versionchanged:: 1.4 The :class:`.CursorResult`` @@ -1179,6 +1183,15 @@ class CursorResult(Result): """ + __slots__ = ( + "context", + "dialect", + "cursor", + "cursor_strategy", + "_echo", + "connection", + ) + _metadata: Union[CursorResultMetaData, _NoResultMetaData] _no_result_metadata = _NO_RESULT_METADATA _soft_closed: bool = False @@ -1231,7 +1244,6 @@ class CursorResult(Result): make_row = _make_row_2 else: make_row = _make_row - self._set_memoized_attribute("_row_getter", make_row) else: @@ -1726,12 +1738,12 @@ class CursorResult(Result): def _raw_row_iterator(self): return self._fetchiter_impl() - def merge(self, *others: Result) -> MergedResult: + def merge(self, *others: Result[Any]) -> MergedResult[Any]: merged_result = super().merge(*others) setup_rowcounts = not self._metadata.returns_rows if setup_rowcounts: merged_result.rowcount = sum( - cast(CursorResult, result).rowcount + cast("CursorResult[Any]", result).rowcount for result in (self,) + others ) return merged_result diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index c6571f68bb..9c6ff758fc 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -62,13 +62,9 @@ if typing.TYPE_CHECKING: from .base import Connection from .base import Engine - from .characteristics import ConnectionCharacteristic - from .interfaces import _AnyMultiExecuteParams from .interfaces import _CoreMultiExecuteParams from .interfaces import _CoreSingleExecuteParams - from .interfaces import _DBAPIAnyExecuteParams from .interfaces import _DBAPIMultiExecuteParams - from .interfaces import _DBAPISingleExecuteParams from .interfaces import _ExecuteOptions from .interfaces import _IsolationLevel from .interfaces import _MutableCoreSingleExecuteParams @@ -83,15 +79,11 @@ if typing.TYPE_CHECKING: from ..sql.compiler import Compiled from ..sql.compiler import Linting from ..sql.compiler import ResultColumnsEntry - from ..sql.compiler import TypeCompiler from ..sql.dml import DMLState from ..sql.dml import UpdateBase from ..sql.elements import BindParameter - from ..sql.roles import ColumnsClauseRole from ..sql.schema import Column - from ..sql.schema import ColumnDefault from ..sql.type_api import _BindProcessorType - from ..sql.type_api import _ResultProcessorType from ..sql.type_api import TypeEngine # When we're handed literal SQL, ensure it's a SELECT query @@ -781,7 +773,7 @@ class DefaultExecutionContext(ExecutionContext): result_column_struct: Optional[ Tuple[List[ResultColumnsEntry], bool, bool, bool] ] = None - returned_default_rows: Optional[List[Row]] = None + returned_default_rows: Optional[Sequence[Row[Any]]] = None execution_options: _ExecuteOptions = util.EMPTY_DICT @@ -1385,7 +1377,9 @@ class DefaultExecutionContext(ExecutionContext): if cursor_description is None: strategy = _cursor._NO_CURSOR_DML - result = _cursor.CursorResult(self, strategy, cursor_description) + result: _cursor.CursorResult[Any] = _cursor.CursorResult( + self, strategy, cursor_description + ) if self.isinsert: if self._is_implicit_returning: diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py index ef10946a86..4093d3e0e7 100644 --- a/lib/sqlalchemy/engine/events.py +++ b/lib/sqlalchemy/engine/events.py @@ -28,7 +28,6 @@ from ..util.typing import Literal if typing.TYPE_CHECKING: from .base import Connection - from .interfaces import _CoreAnyExecuteParams from .interfaces import _CoreMultiExecuteParams from .interfaces import _CoreSingleExecuteParams from .interfaces import _DBAPIAnyExecuteParams @@ -273,7 +272,7 @@ class ConnectionEvents(event.Events[ConnectionEventsTarget]): multiparams: _CoreMultiExecuteParams, params: _CoreSingleExecuteParams, execution_options: _ExecuteOptions, - result: Result, + result: Result[Any], ) -> None: """Intercept high level execute() events after execute. diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 54fe21d747..6410246039 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -2422,7 +2422,7 @@ class ExecutionContext: def _get_cache_stats(self) -> str: raise NotImplementedError() - def _setup_result_proxy(self) -> CursorResult: + def _setup_result_proxy(self) -> CursorResult[Any]: raise NotImplementedError() def fire_sequence(self, seq: Sequence_SchemaItem, type_: Integer) -> int: diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 71320a583d..55d36a1d5b 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -28,6 +28,7 @@ from typing import overload from typing import Sequence from typing import Set from typing import Tuple +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union @@ -70,6 +71,8 @@ _RawRowType = Tuple[Any, ...] """represents the kind of row we get from a DBAPI cursor""" _R = TypeVar("_R", bound=_RowData) +_T = TypeVar("_T", bound=Any) +_TP = TypeVar("_TP", bound=Tuple[Any, ...]) _InterimRowType = Union[_R, _RawRowType] """a catchall "anything" kind of return type that can be applied @@ -141,7 +144,7 @@ class ResultMetaData: def _getter( self, key: Any, raiseerr: bool = True - ) -> Optional[Callable[[Row], Any]]: + ) -> Optional[Callable[[Row[Any]], Any]]: index = self._index_for_key(key, raiseerr) @@ -270,7 +273,7 @@ class SimpleResultMetaData(ResultMetaData): _tuplefilter=_tuplefilter, ) - def _contains(self, value: Any, row: Row) -> bool: + def _contains(self, value: Any, row: Row[Any]) -> bool: return value in row._data def _index_for_key(self, key: Any, raiseerr: bool = True) -> int: @@ -335,7 +338,7 @@ class SimpleResultMetaData(ResultMetaData): def result_tuple( fields: Sequence[str], extra: Optional[Any] = None -) -> Callable[[Iterable[Any]], Row]: +) -> Callable[[Iterable[Any]], Row[Any]]: parent = SimpleResultMetaData(fields, extra) return functools.partial( Row, parent, parent._processors, parent._keymap, Row._default_key_style @@ -355,7 +358,9 @@ SelfResultInternal = TypeVar("SelfResultInternal", bound="ResultInternal[Any]") class ResultInternal(InPlaceGenerative, Generic[_R]): - _real_result: Optional[Result] = None + __slots__ = () + + _real_result: Optional[Result[Any]] = None _generate_rows: bool = True _row_logging_fn: Optional[Callable[[Any], Any]] @@ -367,20 +372,20 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): _source_supports_scalars: bool - def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row]]: + def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row[Any]]]: raise NotImplementedError() def _fetchone_impl( self, hard_close: bool = False - ) -> Optional[_InterimRowType[Row]]: + ) -> Optional[_InterimRowType[Row[Any]]]: raise NotImplementedError() def _fetchmany_impl( self, size: Optional[int] = None - ) -> List[_InterimRowType[Row]]: + ) -> List[_InterimRowType[Row[Any]]]: raise NotImplementedError() - def _fetchall_impl(self) -> List[_InterimRowType[Row]]: + def _fetchall_impl(self) -> List[_InterimRowType[Row[Any]]]: raise NotImplementedError() def _soft_close(self, hard: bool = False) -> None: @@ -388,8 +393,10 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): @HasMemoized_ro_memoized_attribute def _row_getter(self) -> Optional[Callable[..., _R]]: - real_result: Result = ( - self._real_result if self._real_result else cast(Result, self) + real_result: Result[Any] = ( + self._real_result + if self._real_result + else cast("Result[Any]", self) ) if real_result._source_supports_scalars: @@ -404,7 +411,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): keymap: _KeyMapType, key_style: Any, scalar_obj: Any, - ) -> Row: + ) -> Row[Any]: return _proc( metadata, processors, keymap, key_style, (scalar_obj,) ) @@ -429,7 +436,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): fixed_tf = tf - def make_row(row: _InterimRowType[Row]) -> _R: + def make_row(row: _InterimRowType[Row[Any]]) -> _R: return _make_row_orig(fixed_tf(row)) else: @@ -447,7 +454,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): if fns: _make_row = make_row - def make_row(row: _InterimRowType[Row]) -> _R: + def make_row(row: _InterimRowType[Row[Any]]) -> _R: interim_row = _make_row(row) for fn in fns: interim_row = fn(interim_row) @@ -465,7 +472,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): if self._unique_filter_state: uniques, strategy = self._unique_strategy - def iterrows(self: Result) -> Iterator[_R]: + def iterrows(self: Result[Any]) -> Iterator[_R]: for raw_row in self._fetchiter_impl(): obj: _InterimRowType[Any] = ( make_row(raw_row) if make_row else raw_row @@ -480,7 +487,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): else: - def iterrows(self: Result) -> Iterator[_R]: + def iterrows(self: Result[Any]) -> Iterator[_R]: for raw_row in self._fetchiter_impl(): row: _InterimRowType[Any] = ( make_row(raw_row) if make_row else raw_row @@ -546,7 +553,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): if self._unique_filter_state: uniques, strategy = self._unique_strategy - def onerow(self: Result) -> Union[_NoRow, _R]: + def onerow(self: Result[Any]) -> Union[_NoRow, _R]: _onerow = self._fetchone_impl while True: row = _onerow() @@ -567,7 +574,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): else: - def onerow(self: Result) -> Union[_NoRow, _R]: + def onerow(self: Result[Any]) -> Union[_NoRow, _R]: row = self._fetchone_impl() if row is None: return _NO_ROW @@ -627,7 +634,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): real_result = ( self._real_result if self._real_result - else cast(Result, self) + else cast("Result[Any]", self) ) if real_result._yield_per: num_required = num = real_result._yield_per @@ -667,7 +674,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): real_result = ( self._real_result if self._real_result - else cast(Result, self) + else cast("Result[Any]", self) ) num = real_result._yield_per @@ -799,7 +806,9 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): self: SelfResultInternal, indexes: Sequence[_KeyIndexType] ) -> SelfResultInternal: real_result = ( - self._real_result if self._real_result else cast(Result, self) + self._real_result + if self._real_result + else cast("Result[Any]", self) ) if not real_result._source_supports_scalars or len(indexes) != 1: @@ -817,7 +826,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): real_result = ( self._real_result if self._real_result is not None - else cast(Result, self) + else cast("Result[Any]", self) ) if not strategy and self._metadata._unique_filters: @@ -836,6 +845,8 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): class _WithKeys: + __slots__ = () + _metadata: ResultMetaData # used mainly to share documentation on the keys method. @@ -859,10 +870,10 @@ class _WithKeys: return self._metadata.keys -SelfResult = TypeVar("SelfResult", bound="Result") +SelfResult = TypeVar("SelfResult", bound="Result[Any]") -class Result(_WithKeys, ResultInternal[Row]): +class Result(_WithKeys, ResultInternal[Row[_TP]]): """Represent a set of database results. .. versionadded:: 1.4 The :class:`.Result` object provides a completely @@ -887,7 +898,9 @@ class Result(_WithKeys, ResultInternal[Row]): """ - _row_logging_fn: Optional[Callable[[Row], Row]] = None + __slots__ = ("_metadata", "__dict__") + + _row_logging_fn: Optional[Callable[[Row[Any]], Row[Any]]] = None _source_supports_scalars: bool = False @@ -1011,6 +1024,15 @@ class Result(_WithKeys, ResultInternal[Row]): appropriate :class:`.ColumnElement` objects which correspond to a given statement construct. + .. versionchanged:: 2.0 Due to a bug in 1.4, the + :meth:`.Result.columns` method had an incorrect behavior where + calling upon the method with just one index would cause the + :class:`.Result` object to yield scalar values rather than + :class:`.Row` objects. In version 2.0, this behavior has been + corrected such that calling upon :meth:`.Result.columns` with + a single index will produce a :class:`.Result` object that continues + to yield :class:`.Row` objects, which include only a single column. + E.g.:: statement = select(table.c.x, table.c.y, table.c.z) @@ -1040,6 +1062,20 @@ class Result(_WithKeys, ResultInternal[Row]): """ return self._column_slices(col_expressions) + @overload + def scalars(self: Result[Tuple[_T]]) -> ScalarResult[_T]: + ... + + @overload + def scalars( + self: Result[Tuple[_T]], index: Literal[0] + ) -> ScalarResult[_T]: + ... + + @overload + def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]: + ... + def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]: """Return a :class:`_result.ScalarResult` filtering object which will return single elements rather than :class:`_row.Row` objects. @@ -1067,7 +1103,7 @@ class Result(_WithKeys, ResultInternal[Row]): def _getter( self, key: _KeyIndexType, raiseerr: bool = True - ) -> Optional[Callable[[Row], Any]]: + ) -> Optional[Callable[[Row[Any]], Any]]: """return a callable that will retrieve the given key from a :class:`.Row`. @@ -1105,6 +1141,43 @@ class Result(_WithKeys, ResultInternal[Row]): return MappingResult(self) + @property + def t(self) -> TupleResult[_TP]: + """Apply a "typed tuple" typing filter to returned rows. + + The :attr:`.Result.t` attribute is a synonym for calling the + :meth:`.Result.tuples` method. + + .. versionadded:: 2.0 + + """ + return self # type: ignore + + def tuples(self) -> TupleResult[_TP]: + """Apply a "typed tuple" typing filter to returned rows. + + This method returns the same :class:`.Result` object at runtime, + however annotates as returning a :class:`.TupleResult` object + that will indicate to :pep:`484` typing tools that plain typed + ``Tuple`` instances are returned rather than rows. This allows + tuple unpacking and ``__getitem__`` access of :class:`.Row` objects + to by typed, for those cases where the statement invoked itself + included typing information. + + .. versionadded:: 2.0 + + :return: the :class:`_result.TupleResult` type at typing time. + + .. seealso:: + + :attr:`.Result.t` - shorter synonym + + :attr:`.Row.t` - :class:`.Row` version + + """ + + return self # type: ignore + def _raw_row_iterator(self) -> Iterator[_RowData]: """Return a safe iterator that yields raw row data. @@ -1114,13 +1187,15 @@ class Result(_WithKeys, ResultInternal[Row]): """ raise NotImplementedError() - def __iter__(self) -> Iterator[Row]: + def __iter__(self) -> Iterator[Row[_TP]]: return self._iter_impl() - def __next__(self) -> Row: + def __next__(self) -> Row[_TP]: return self._next_impl() - def partitions(self, size: Optional[int] = None) -> Iterator[List[Row]]: + def partitions( + self, size: Optional[int] = None + ) -> Iterator[Sequence[Row[_TP]]]: """Iterate through sub-lists of rows of the size given. Each list will be of the size given, excluding the last list to @@ -1158,12 +1233,12 @@ class Result(_WithKeys, ResultInternal[Row]): else: break - def fetchall(self) -> List[Row]: + def fetchall(self) -> Sequence[Row[_TP]]: """A synonym for the :meth:`_engine.Result.all` method.""" return self._allrows() - def fetchone(self) -> Optional[Row]: + def fetchone(self) -> Optional[Row[_TP]]: """Fetch one row. When all rows are exhausted, returns None. @@ -1185,7 +1260,7 @@ class Result(_WithKeys, ResultInternal[Row]): else: return row - def fetchmany(self, size: Optional[int] = None) -> List[Row]: + def fetchmany(self, size: Optional[int] = None) -> Sequence[Row[_TP]]: """Fetch many rows. When all rows are exhausted, returns an empty list. @@ -1202,7 +1277,7 @@ class Result(_WithKeys, ResultInternal[Row]): return self._manyrow_getter(self, size) - def all(self) -> List[Row]: + def all(self) -> Sequence[Row[_TP]]: """Return all rows in a list. Closes the result set after invocation. Subsequent invocations @@ -1216,7 +1291,7 @@ class Result(_WithKeys, ResultInternal[Row]): return self._allrows() - def first(self) -> Optional[Row]: + def first(self) -> Optional[Row[_TP]]: """Fetch the first row or None if no row is present. Closes the result set and discards remaining rows. @@ -1252,7 +1327,7 @@ class Result(_WithKeys, ResultInternal[Row]): raise_for_second_row=False, raise_for_none=False, scalar=False ) - def one_or_none(self) -> Optional[Row]: + def one_or_none(self) -> Optional[Row[_TP]]: """Return at most one result or raise an exception. Returns ``None`` if the result has no rows. @@ -1276,6 +1351,14 @@ class Result(_WithKeys, ResultInternal[Row]): raise_for_second_row=True, raise_for_none=False, scalar=False ) + @overload + def scalar_one(self: Result[Tuple[_T]]) -> _T: + ... + + @overload + def scalar_one(self) -> Any: + ... + def scalar_one(self) -> Any: """Return exactly one scalar result or raise an exception. @@ -1293,6 +1376,14 @@ class Result(_WithKeys, ResultInternal[Row]): raise_for_second_row=True, raise_for_none=True, scalar=True ) + @overload + def scalar_one_or_none(self: Result[Tuple[_T]]) -> Optional[_T]: + ... + + @overload + def scalar_one_or_none(self) -> Optional[Any]: + ... + def scalar_one_or_none(self) -> Optional[Any]: """Return exactly one or no scalar result. @@ -1310,7 +1401,7 @@ class Result(_WithKeys, ResultInternal[Row]): raise_for_second_row=True, raise_for_none=False, scalar=True ) - def one(self) -> Row: + def one(self) -> Row[_TP]: """Return exactly one row or raise an exception. Raises :class:`.NoResultFound` if the result returns no @@ -1341,6 +1432,14 @@ class Result(_WithKeys, ResultInternal[Row]): raise_for_second_row=True, raise_for_none=True, scalar=False ) + @overload + def scalar(self: Result[Tuple[_T]]) -> Optional[_T]: + ... + + @overload + def scalar(self) -> Any: + ... + def scalar(self) -> Any: """Fetch the first column of the first row, and close the result set. @@ -1359,7 +1458,7 @@ class Result(_WithKeys, ResultInternal[Row]): raise_for_second_row=False, raise_for_none=False, scalar=True ) - def freeze(self) -> FrozenResult: + def freeze(self) -> FrozenResult[_TP]: """Return a callable object that will produce copies of this :class:`.Result` when invoked. @@ -1382,7 +1481,7 @@ class Result(_WithKeys, ResultInternal[Row]): return FrozenResult(self) - def merge(self, *others: Result) -> MergedResult: + def merge(self, *others: Result[Any]) -> MergedResult[_TP]: """Merge this :class:`.Result` with other compatible result objects. @@ -1405,9 +1504,17 @@ class FilterResult(ResultInternal[_R]): """ - _post_creational_filter: Optional[Callable[[Any], Any]] = None + __slots__ = ( + "_real_result", + "_post_creational_filter", + "_metadata", + "_unique_filter_state", + "__dict__", + ) + + _post_creational_filter: Optional[Callable[[Any], Any]] - _real_result: Result + _real_result: Result[Any] def _soft_close(self, hard: bool = False) -> None: self._real_result._soft_close(hard=hard) @@ -1416,20 +1523,20 @@ class FilterResult(ResultInternal[_R]): def _attributes(self) -> Dict[Any, Any]: return self._real_result._attributes - def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row]]: + def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row[Any]]]: return self._real_result._fetchiter_impl() def _fetchone_impl( self, hard_close: bool = False - ) -> Optional[_InterimRowType[Row]]: + ) -> Optional[_InterimRowType[Row[Any]]]: return self._real_result._fetchone_impl(hard_close=hard_close) - def _fetchall_impl(self) -> List[_InterimRowType[Row]]: + def _fetchall_impl(self) -> List[_InterimRowType[Row[Any]]]: return self._real_result._fetchall_impl() def _fetchmany_impl( self, size: Optional[int] = None - ) -> List[_InterimRowType[Row]]: + ) -> List[_InterimRowType[Row[Any]]]: return self._real_result._fetchmany_impl(size=size) @@ -1452,11 +1559,13 @@ class ScalarResult(FilterResult[_R]): """ + __slots__ = () + _generate_rows = False _post_creational_filter: Optional[Callable[[Any], Any]] - def __init__(self, real_result: Result, index: _KeyIndexType): + def __init__(self, real_result: Result[Any], index: _KeyIndexType): self._real_result = real_result if real_result._source_supports_scalars: @@ -1480,7 +1589,7 @@ class ScalarResult(FilterResult[_R]): self._unique_filter_state = (set(), strategy) return self - def partitions(self, size: Optional[int] = None) -> Iterator[List[_R]]: + def partitions(self, size: Optional[int] = None) -> Iterator[Sequence[_R]]: """Iterate through sub-lists of elements of the size given. Equivalent to :meth:`_result.Result.partitions` except that @@ -1498,12 +1607,12 @@ class ScalarResult(FilterResult[_R]): else: break - def fetchall(self) -> List[_R]: + def fetchall(self) -> Sequence[_R]: """A synonym for the :meth:`_engine.ScalarResult.all` method.""" return self._allrows() - def fetchmany(self, size: Optional[int] = None) -> List[_R]: + def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]: """Fetch many objects. Equivalent to :meth:`_result.Result.fetchmany` except that @@ -1513,7 +1622,7 @@ class ScalarResult(FilterResult[_R]): """ return self._manyrow_getter(self, size) - def all(self) -> List[_R]: + def all(self) -> Sequence[_R]: """Return all scalar values in a list. Equivalent to :meth:`_result.Result.all` except that @@ -1567,6 +1676,177 @@ class ScalarResult(FilterResult[_R]): ) +SelfTupleResult = TypeVar("SelfTupleResult", bound="TupleResult[Any]") + + +class TupleResult(FilterResult[_R], util.TypingOnly): + """a :class:`.Result` that's typed as returning plain Python tuples + instead of rows. + + Since :class:`.Row` acts like a tuple in every way already, + this class is a typing only class, regular :class:`.Result` is still + used at runtime. + + """ + + __slots__ = () + + if TYPE_CHECKING: + + def partitions( + self, size: Optional[int] = None + ) -> Iterator[Sequence[_R]]: + """Iterate through sub-lists of elements of the size given. + + Equivalent to :meth:`_result.Result.partitions` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + """ + ... + + def fetchone(self) -> Optional[_R]: + """Fetch one tuple. + + Equivalent to :meth:`_result.Result.fetchone` except that + tuple values, rather than :class:`_result.Row` + objects, are returned. + + """ + ... + + def fetchall(self) -> Sequence[_R]: + """A synonym for the :meth:`_engine.ScalarResult.all` method.""" + ... + + def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]: + """Fetch many objects. + + Equivalent to :meth:`_result.Result.fetchmany` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + """ + ... + + def all(self) -> Sequence[_R]: # noqa: A001 + """Return all scalar values in a list. + + Equivalent to :meth:`_result.Result.all` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + """ + ... + + def __iter__(self) -> Iterator[_R]: + ... + + def __next__(self) -> _R: + ... + + def first(self) -> Optional[_R]: + """Fetch the first object or None if no object is present. + + Equivalent to :meth:`_result.Result.first` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + + """ + ... + + def one_or_none(self) -> Optional[_R]: + """Return at most one object or raise an exception. + + Equivalent to :meth:`_result.Result.one_or_none` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + """ + ... + + def one(self) -> _R: + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_result.Result.one` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + """ + ... + + @overload + def scalar_one(self: TupleResult[Tuple[_T]]) -> _T: + ... + + @overload + def scalar_one(self) -> Any: + ... + + def scalar_one(self) -> Any: + """Return exactly one scalar result or raise an exception. + + This is equivalent to calling :meth:`.Result.scalars` and then + :meth:`.Result.one`. + + .. seealso:: + + :meth:`.Result.one` + + :meth:`.Result.scalars` + + """ + ... + + @overload + def scalar_one_or_none(self: TupleResult[Tuple[_T]]) -> Optional[_T]: + ... + + @overload + def scalar_one_or_none(self) -> Optional[Any]: + ... + + def scalar_one_or_none(self) -> Optional[Any]: + """Return exactly one or no scalar result. + + This is equivalent to calling :meth:`.Result.scalars` and then + :meth:`.Result.one_or_none`. + + .. seealso:: + + :meth:`.Result.one_or_none` + + :meth:`.Result.scalars` + + """ + ... + + @overload + def scalar(self: TupleResult[Tuple[_T]]) -> Optional[_T]: + ... + + @overload + def scalar(self) -> Any: + ... + + def scalar(self) -> Any: + """Fetch the first column of the first row, and close the result set. + + Returns None if there are no rows to fetch. + + No validation is performed to test if additional rows remain. + + After calling this method, the object is fully closed, + e.g. the :meth:`_engine.CursorResult.close` + method will have been called. + + :return: a Python scalar value , or None if no rows remain. + + """ + ... + + SelfMappingResult = TypeVar("SelfMappingResult", bound="MappingResult") @@ -1579,11 +1859,13 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]): """ + __slots__ = () + _generate_rows = True _post_creational_filter = operator.attrgetter("_mapping") - def __init__(self, result: Result): + def __init__(self, result: Result[Any]): self._real_result = result self._unique_filter_state = result._unique_filter_state self._metadata = result._metadata @@ -1610,7 +1892,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]): def partitions( self, size: Optional[int] = None - ) -> Iterator[List[RowMapping]]: + ) -> Iterator[Sequence[RowMapping]]: """Iterate through sub-lists of elements of the size given. Equivalent to :meth:`_result.Result.partitions` except that @@ -1628,7 +1910,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]): else: break - def fetchall(self) -> List[RowMapping]: + def fetchall(self) -> Sequence[RowMapping]: """A synonym for the :meth:`_engine.MappingResult.all` method.""" return self._allrows() @@ -1648,7 +1930,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]): else: return row - def fetchmany(self, size: Optional[int] = None) -> List[RowMapping]: + def fetchmany(self, size: Optional[int] = None) -> Sequence[RowMapping]: """Fetch many objects. Equivalent to :meth:`_result.Result.fetchmany` except that @@ -1659,7 +1941,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]): return self._manyrow_getter(self, size) - def all(self) -> List[RowMapping]: + def all(self) -> Sequence[RowMapping]: """Return all scalar values in a list. Equivalent to :meth:`_result.Result.all` except that @@ -1714,7 +1996,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]): ) -class FrozenResult: +class FrozenResult(Generic[_TP]): """Represents a :class:`.Result` object in a "frozen" state suitable for caching. @@ -1755,7 +2037,7 @@ class FrozenResult: data: Sequence[Any] - def __init__(self, result: Result): + def __init__(self, result: Result[_TP]): self.metadata = result._metadata._for_freeze() self._source_supports_scalars = result._source_supports_scalars self._attributes = result._attributes @@ -1771,7 +2053,9 @@ class FrozenResult: else: return [list(row) for row in self.data] - def with_new_rows(self, tuple_data: Sequence[Row]) -> FrozenResult: + def with_new_rows( + self, tuple_data: Sequence[Row[_TP]] + ) -> FrozenResult[_TP]: fr = FrozenResult.__new__(FrozenResult) fr.metadata = self.metadata fr._attributes = self._attributes @@ -1783,14 +2067,16 @@ class FrozenResult: fr.data = tuple_data return fr - def __call__(self) -> Result: - result = IteratorResult(self.metadata, iter(self.data)) + def __call__(self) -> Result[_TP]: + result: IteratorResult[_TP] = IteratorResult( + self.metadata, iter(self.data) + ) result._attributes = self._attributes result._source_supports_scalars = self._source_supports_scalars return result -class IteratorResult(Result): +class IteratorResult(Result[_TP]): """A :class:`.Result` that gets data from a Python iterator of :class:`.Row` objects or similar row-like data. @@ -1833,7 +2119,7 @@ class IteratorResult(Result): def _fetchone_impl( self, hard_close: bool = False - ) -> Optional[_InterimRowType[Row]]: + ) -> Optional[_InterimRowType[Row[Any]]]: if self._hard_closed: self._raise_hard_closed() @@ -1844,7 +2130,7 @@ class IteratorResult(Result): else: return row - def _fetchall_impl(self) -> List[_InterimRowType[Row]]: + def _fetchall_impl(self) -> List[_InterimRowType[Row[Any]]]: if self._hard_closed: self._raise_hard_closed() try: @@ -1854,23 +2140,23 @@ class IteratorResult(Result): def _fetchmany_impl( self, size: Optional[int] = None - ) -> List[_InterimRowType[Row]]: + ) -> List[_InterimRowType[Row[Any]]]: if self._hard_closed: self._raise_hard_closed() return list(itertools.islice(self.iterator, 0, size)) -def null_result() -> IteratorResult: +def null_result() -> IteratorResult[Any]: return IteratorResult(SimpleResultMetaData([]), iter([])) SelfChunkedIteratorResult = TypeVar( - "SelfChunkedIteratorResult", bound="ChunkedIteratorResult" + "SelfChunkedIteratorResult", bound="ChunkedIteratorResult[Any]" ) -class ChunkedIteratorResult(IteratorResult): +class ChunkedIteratorResult(IteratorResult[_TP]): """An :class:`.IteratorResult` that works from an iterator-producing callable. The given ``chunks`` argument is a function that is given a number of rows @@ -1922,13 +2208,13 @@ class ChunkedIteratorResult(IteratorResult): def _fetchmany_impl( self, size: Optional[int] = None - ) -> List[_InterimRowType[Row]]: + ) -> List[_InterimRowType[Row[Any]]]: if self.dynamic_yield_per: self.iterator = itertools.chain.from_iterable(self.chunks(size)) return super()._fetchmany_impl(size=size) -class MergedResult(IteratorResult): +class MergedResult(IteratorResult[_TP]): """A :class:`_engine.Result` that is merged from any number of :class:`_engine.Result` objects. @@ -1942,7 +2228,7 @@ class MergedResult(IteratorResult): rowcount: Optional[int] def __init__( - self, cursor_metadata: ResultMetaData, results: Sequence[Result] + self, cursor_metadata: ResultMetaData, results: Sequence[Result[_TP]] ): self._results = results super(MergedResult, self).__init__( diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py index 4ba39b55d6..7c9eacb78c 100644 --- a/lib/sqlalchemy/engine/row.py +++ b/lib/sqlalchemy/engine/row.py @@ -16,6 +16,7 @@ import typing from typing import Any from typing import Callable from typing import Dict +from typing import Generic from typing import Iterator from typing import List from typing import Mapping @@ -24,12 +25,14 @@ from typing import Optional from typing import overload from typing import Sequence from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from ..sql import util as sql_util from ..util._has_cy import HAS_CYEXTENSION -if typing.TYPE_CHECKING or not HAS_CYEXTENSION: +if TYPE_CHECKING or not HAS_CYEXTENSION: from ._py_row import BaseRow as BaseRow from ._py_row import KEY_INTEGER_ONLY from ._py_row import KEY_OBJECTS_ONLY @@ -38,13 +41,16 @@ else: from sqlalchemy.cyextension.resultproxy import KEY_INTEGER_ONLY from sqlalchemy.cyextension.resultproxy import KEY_OBJECTS_ONLY -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from .result import _KeyType from .result import RMKeyView from ..sql.type_api import _ResultProcessorType +_T = TypeVar("_T", bound=Any) +_TP = TypeVar("_TP", bound=Tuple[Any, ...]) -class Row(BaseRow, typing.Sequence[Any]): + +class Row(BaseRow, Sequence[Any], Generic[_TP]): """Represent a single result row. The :class:`.Row` object represents a row of a database result. It is @@ -82,6 +88,37 @@ class Row(BaseRow, typing.Sequence[Any]): def __delattr__(self, name: str) -> NoReturn: raise AttributeError("can't delete attribute") + def tuple(self) -> _TP: + """Return a 'tuple' form of this :class:`.Row`. + + At runtime, this method returns "self"; the :class:`.Row` object is + already a named tuple. However, at the typing level, if this + :class:`.Row` is typed, the "tuple" return type will be a :pep:`484` + ``Tuple`` datatype that contains typing information about individual + elements, supporting typed unpacking and attribute access. + + .. versionadded:: 2.0 + + .. seealso:: + + :meth:`.Result.tuples` + + """ + return self # type: ignore + + @property + def t(self) -> _TP: + """a synonym for :attr:`.Row.tuple` + + .. versionadded:: 2.0 + + .. seealso:: + + :meth:`.Result.t` + + """ + return self # type: ignore + @property def _mapping(self) -> RowMapping: """Return a :class:`.RowMapping` for this :class:`.Row`. @@ -107,7 +144,7 @@ class Row(BaseRow, typing.Sequence[Any]): def _filter_on_values( self, filters: Optional[Sequence[Optional[_ResultProcessorType[Any]]]] - ) -> Row: + ) -> Row[Any]: return Row( self._parent, filters, @@ -116,7 +153,7 @@ class Row(BaseRow, typing.Sequence[Any]): self._data, ) - if not typing.TYPE_CHECKING: + if not TYPE_CHECKING: def _special_name_accessor(name: str) -> Any: """Handle ambiguous names such as "count" and "index" """ @@ -151,7 +188,7 @@ class Row(BaseRow, typing.Sequence[Any]): __hash__ = BaseRow.__hash__ - if typing.TYPE_CHECKING: + if TYPE_CHECKING: @overload def __getitem__(self, index: int) -> Any: @@ -299,7 +336,7 @@ class RowMapping(BaseRow, typing.Mapping[str, Any]): _default_key_style = KEY_OBJECTS_ONLY - if typing.TYPE_CHECKING: + if TYPE_CHECKING: def __getitem__(self, key: _KeyType) -> Any: ... diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index fb05f512e4..95549ada69 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -12,8 +12,10 @@ from typing import Generator from typing import NoReturn from typing import Optional from typing import overload +from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from . import exc as async_exc @@ -50,6 +52,9 @@ if TYPE_CHECKING: from ...pool import PoolProxiedConnection from ...sql._typing import _InfoType from ...sql.base import Executable + from ...sql.selectable import TypedReturnsRows + +_T = TypeVar("_T", bound=Any) class _SyncConnectionCallable(Protocol): @@ -407,7 +412,7 @@ class AsyncConnection( statement: str, parameters: Optional[_DBAPIAnyExecuteParams] = None, execution_options: Optional[_ExecuteOptionsParameter] = None, - ) -> CursorResult: + ) -> CursorResult[Any]: r"""Executes a driver-level SQL string and return buffered :class:`_engine.Result`. @@ -423,12 +428,33 @@ class AsyncConnection( return await _ensure_sync_result(result, self.exec_driver_sql) + @overload + async def stream( + self, + statement: TypedReturnsRows[_T], + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> AsyncResult[_T]: + ... + + @overload async def stream( self, statement: Executable, parameters: Optional[_CoreAnyExecuteParams] = None, + *, execution_options: Optional[_ExecuteOptionsParameter] = None, - ) -> AsyncResult: + ) -> AsyncResult[Any]: + ... + + async def stream( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> AsyncResult[Any]: """Execute a statement and return a streaming :class:`_asyncio.AsyncResult` object.""" @@ -436,7 +462,7 @@ class AsyncConnection( self._proxied.execute, statement, parameters, - util.EMPTY_DICT.merge_with( + execution_options=util.EMPTY_DICT.merge_with( execution_options, {"stream_results": True} ), _require_await=True, @@ -446,12 +472,33 @@ class AsyncConnection( assert False, "server side result expected" return AsyncResult(result) + @overload + async def execute( + self, + statement: TypedReturnsRows[_T], + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> CursorResult[_T]: + ... + + @overload async def execute( self, statement: Executable, parameters: Optional[_CoreAnyExecuteParams] = None, + *, execution_options: Optional[_ExecuteOptionsParameter] = None, - ) -> CursorResult: + ) -> CursorResult[Any]: + ... + + async def execute( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> CursorResult[Any]: r"""Executes a SQL statement construct and return a buffered :class:`_engine.Result`. @@ -487,15 +534,36 @@ class AsyncConnection( self._proxied.execute, statement, parameters, - execution_options, + execution_options=execution_options, _require_await=True, ) return await _ensure_sync_result(result, self.execute) + @overload + async def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> Optional[_T]: + ... + + @overload async def scalar( self, statement: Executable, parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> Any: + ... + + async def scalar( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, execution_options: Optional[_ExecuteOptionsParameter] = None, ) -> Any: r"""Executes a SQL statement construct and returns a scalar object. @@ -508,13 +576,36 @@ class AsyncConnection( first row returned. """ - result = await self.execute(statement, parameters, execution_options) + result = await self.execute( + statement, parameters, execution_options=execution_options + ) return result.scalar() + @overload + async def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> ScalarResult[_T]: + ... + + @overload + async def scalars( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> ScalarResult[Any]: + ... + async def scalars( self, statement: Executable, parameters: Optional[_CoreSingleExecuteParams] = None, + *, execution_options: Optional[_ExecuteOptionsParameter] = None, ) -> ScalarResult[Any]: r"""Executes a SQL statement construct and returns a scalar objects. @@ -528,13 +619,36 @@ class AsyncConnection( .. versionadded:: 1.4.24 """ - result = await self.execute(statement, parameters, execution_options) + result = await self.execute( + statement, parameters, execution_options=execution_options + ) return result.scalars() + @overload + async def stream_scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> AsyncScalarResult[_T]: + ... + + @overload async def stream_scalars( self, statement: Executable, parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> AsyncScalarResult[Any]: + ... + + async def stream_scalars( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, execution_options: Optional[_ExecuteOptionsParameter] = None, ) -> AsyncScalarResult[Any]: r"""Executes a SQL statement and returns a streaming scalar result @@ -549,7 +663,9 @@ class AsyncConnection( .. versionadded:: 1.4.24 """ - result = await self.stream(statement, parameters, execution_options) + result = await self.stream( + statement, parameters, execution_options=execution_options + ) return result.scalars() async def run_sync( diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py index d0337554cf..ff3dcf4174 100644 --- a/lib/sqlalchemy/ext/asyncio/result.py +++ b/lib/sqlalchemy/ext/asyncio/result.py @@ -9,12 +9,15 @@ from __future__ import annotations import operator from typing import Any from typing import AsyncIterator -from typing import List from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple from typing import TYPE_CHECKING from typing import TypeVar from . import exc as async_exc +from ... import util from ...engine.result import _NO_ROW from ...engine.result import _R from ...engine.result import FilterResult @@ -24,6 +27,7 @@ from ...engine.result import ResultMetaData from ...engine.row import Row from ...engine.row import RowMapping from ...util.concurrency import greenlet_spawn +from ...util.typing import Literal if TYPE_CHECKING: from ...engine import CursorResult @@ -32,9 +36,14 @@ if TYPE_CHECKING: from ...engine.result import _UniqueFilterType from ...engine.result import RMKeyView +_T = TypeVar("_T", bound=Any) +_TP = TypeVar("_TP", bound=Tuple[Any, ...]) + class AsyncCommon(FilterResult[_R]): - _real_result: Result + __slots__ = () + + _real_result: Result[Any] _metadata: ResultMetaData async def close(self) -> None: @@ -43,10 +52,10 @@ class AsyncCommon(FilterResult[_R]): await greenlet_spawn(self._real_result.close) -SelfAsyncResult = TypeVar("SelfAsyncResult", bound="AsyncResult") +SelfAsyncResult = TypeVar("SelfAsyncResult", bound="AsyncResult[Any]") -class AsyncResult(AsyncCommon[Row]): +class AsyncResult(AsyncCommon[Row[_TP]]): """An asyncio wrapper around a :class:`_result.Result` object. The :class:`_asyncio.AsyncResult` only applies to statement executions that @@ -67,11 +76,16 @@ class AsyncResult(AsyncCommon[Row]): """ - def __init__(self, real_result: Result): + __slots__ = () + + _real_result: Result[_TP] + + def __init__(self, real_result: Result[_TP]): self._real_result = real_result self._metadata = real_result._metadata self._unique_filter_state = real_result._unique_filter_state + self._post_creational_filter = None # BaseCursorResult pre-generates the "_row_getter". Use that # if available rather than building a second one @@ -80,6 +94,43 @@ class AsyncResult(AsyncCommon[Row]): "_row_getter", real_result.__dict__["_row_getter"] ) + @property + def t(self) -> AsyncTupleResult[_TP]: + """Apply a "typed tuple" typing filter to returned rows. + + The :attr:`.AsyncResult.t` attribute is a synonym for calling the + :meth:`.AsyncResult.tuples` method. + + .. versionadded:: 2.0 + + """ + return self # type: ignore + + def tuples(self) -> AsyncTupleResult[_TP]: + """Apply a "typed tuple" typing filter to returned rows. + + This method returns the same :class:`.AsyncResult` object at runtime, + however annotates as returning a :class:`.AsyncTupleResult` object + that will indicate to :pep:`484` typing tools that plain typed + ``Tuple`` instances are returned rather than rows. This allows + tuple unpacking and ``__getitem__`` access of :class:`.Row` objects + to by typed, for those cases where the statement invoked itself + included typing information. + + .. versionadded:: 2.0 + + :return: the :class:`_result.AsyncTupleResult` type at typing time. + + .. seealso:: + + :attr:`.AsyncResult.t` - shorter synonym + + :attr:`.Row.t` - :class:`.Row` version + + """ + + return self # type: ignore + def keys(self) -> RMKeyView: """Return the :meth:`_engine.Result.keys` collection from the underlying :class:`_engine.Result`. @@ -115,7 +166,7 @@ class AsyncResult(AsyncCommon[Row]): async def partitions( self, size: Optional[int] = None - ) -> AsyncIterator[List[Row]]: + ) -> AsyncIterator[Sequence[Row[_TP]]]: """Iterate through sub-lists of rows of the size given. An async iterator is returned:: @@ -141,7 +192,16 @@ class AsyncResult(AsyncCommon[Row]): else: break - async def fetchone(self) -> Optional[Row]: + async def fetchall(self) -> Sequence[Row[_TP]]: + """A synonym for the :meth:`.AsyncResult.all` method. + + .. versionadded:: 2.0 + + """ + + return await greenlet_spawn(self._allrows) + + async def fetchone(self) -> Optional[Row[_TP]]: """Fetch one row. When all rows are exhausted, returns None. @@ -163,7 +223,9 @@ class AsyncResult(AsyncCommon[Row]): else: return row - async def fetchmany(self, size: Optional[int] = None) -> List[Row]: + async def fetchmany( + self, size: Optional[int] = None + ) -> Sequence[Row[_TP]]: """Fetch many rows. When all rows are exhausted, returns an empty list. @@ -184,7 +246,7 @@ class AsyncResult(AsyncCommon[Row]): return await greenlet_spawn(self._manyrow_getter, self, size) - async def all(self) -> List[Row]: + async def all(self) -> Sequence[Row[_TP]]: """Return all rows in a list. Closes the result set after invocation. Subsequent invocations @@ -196,17 +258,17 @@ class AsyncResult(AsyncCommon[Row]): return await greenlet_spawn(self._allrows) - def __aiter__(self) -> AsyncResult: + def __aiter__(self) -> AsyncResult[_TP]: return self - async def __anext__(self) -> Row: + async def __anext__(self) -> Row[_TP]: row = await greenlet_spawn(self._onerow_getter, self) if row is _NO_ROW: raise StopAsyncIteration() else: return row - async def first(self) -> Optional[Row]: + async def first(self) -> Optional[Row[_TP]]: """Fetch the first row or None if no row is present. Closes the result set and discards remaining rows. @@ -229,7 +291,7 @@ class AsyncResult(AsyncCommon[Row]): """ return await greenlet_spawn(self._only_one_row, False, False, False) - async def one_or_none(self) -> Optional[Row]: + async def one_or_none(self) -> Optional[Row[_TP]]: """Return at most one result or raise an exception. Returns ``None`` if the result has no rows. @@ -251,6 +313,14 @@ class AsyncResult(AsyncCommon[Row]): """ return await greenlet_spawn(self._only_one_row, True, False, False) + @overload + async def scalar_one(self: AsyncResult[Tuple[_T]]) -> _T: + ... + + @overload + async def scalar_one(self) -> Any: + ... + async def scalar_one(self) -> Any: """Return exactly one scalar result or raise an exception. @@ -266,6 +336,16 @@ class AsyncResult(AsyncCommon[Row]): """ return await greenlet_spawn(self._only_one_row, True, True, True) + @overload + async def scalar_one_or_none( + self: AsyncResult[Tuple[_T]], + ) -> Optional[_T]: + ... + + @overload + async def scalar_one_or_none(self) -> Optional[Any]: + ... + async def scalar_one_or_none(self) -> Optional[Any]: """Return exactly one or no scalar result. @@ -281,7 +361,7 @@ class AsyncResult(AsyncCommon[Row]): """ return await greenlet_spawn(self._only_one_row, True, False, True) - async def one(self) -> Row: + async def one(self) -> Row[_TP]: """Return exactly one row or raise an exception. Raises :class:`.NoResultFound` if the result returns no @@ -312,6 +392,14 @@ class AsyncResult(AsyncCommon[Row]): """ return await greenlet_spawn(self._only_one_row, True, True, False) + @overload + async def scalar(self: AsyncResult[Tuple[_T]]) -> Optional[_T]: + ... + + @overload + async def scalar(self) -> Any: + ... + async def scalar(self) -> Any: """Fetch the first column of the first row, and close the result set. @@ -328,7 +416,7 @@ class AsyncResult(AsyncCommon[Row]): """ return await greenlet_spawn(self._only_one_row, False, False, True) - async def freeze(self) -> FrozenResult: + async def freeze(self) -> FrozenResult[_TP]: """Return a callable object that will produce copies of this :class:`_asyncio.AsyncResult` when invoked. @@ -351,7 +439,7 @@ class AsyncResult(AsyncCommon[Row]): return await greenlet_spawn(FrozenResult, self) - def merge(self, *others: AsyncResult) -> MergedResult: + def merge(self, *others: AsyncResult[_TP]) -> MergedResult[_TP]: """Merge this :class:`_asyncio.AsyncResult` with other compatible result objects. @@ -370,6 +458,20 @@ class AsyncResult(AsyncCommon[Row]): (self._real_result,) + tuple(o._real_result for o in others), ) + @overload + def scalars( + self: AsyncResult[Tuple[_T]], index: Literal[0] + ) -> AsyncScalarResult[_T]: + ... + + @overload + def scalars(self: AsyncResult[Tuple[_T]]) -> AsyncScalarResult[_T]: + ... + + @overload + def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: + ... + def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: """Return an :class:`_asyncio.AsyncScalarResult` filtering object which will return single elements rather than :class:`_row.Row` objects. @@ -423,9 +525,11 @@ class AsyncScalarResult(AsyncCommon[_R]): """ + __slots__ = () + _generate_rows = False - def __init__(self, real_result: Result, index: _KeyIndexType): + def __init__(self, real_result: Result[Any], index: _KeyIndexType): self._real_result = real_result if real_result._source_supports_scalars: @@ -452,7 +556,7 @@ class AsyncScalarResult(AsyncCommon[_R]): async def partitions( self, size: Optional[int] = None - ) -> AsyncIterator[List[_R]]: + ) -> AsyncIterator[Sequence[_R]]: """Iterate through sub-lists of elements of the size given. Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that @@ -470,12 +574,12 @@ class AsyncScalarResult(AsyncCommon[_R]): else: break - async def fetchall(self) -> List[_R]: + async def fetchall(self) -> Sequence[_R]: """A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method.""" return await greenlet_spawn(self._allrows) - async def fetchmany(self, size: Optional[int] = None) -> List[_R]: + async def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]: """Fetch many objects. Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that @@ -485,7 +589,7 @@ class AsyncScalarResult(AsyncCommon[_R]): """ return await greenlet_spawn(self._manyrow_getter, self, size) - async def all(self) -> List[_R]: + async def all(self) -> Sequence[_R]: """Return all scalar values in a list. Equivalent to :meth:`_asyncio.AsyncResult.all` except that @@ -555,11 +659,13 @@ class AsyncMappingResult(AsyncCommon[RowMapping]): """ + __slots__ = () + _generate_rows = True _post_creational_filter = operator.attrgetter("_mapping") - def __init__(self, result: Result): + def __init__(self, result: Result[Any]): self._real_result = result self._unique_filter_state = result._unique_filter_state self._metadata = result._metadata @@ -602,7 +708,7 @@ class AsyncMappingResult(AsyncCommon[RowMapping]): async def partitions( self, size: Optional[int] = None - ) -> AsyncIterator[List[RowMapping]]: + ) -> AsyncIterator[Sequence[RowMapping]]: """Iterate through sub-lists of elements of the size given. @@ -621,7 +727,7 @@ class AsyncMappingResult(AsyncCommon[RowMapping]): else: break - async def fetchall(self) -> List[RowMapping]: + async def fetchall(self) -> Sequence[RowMapping]: """A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method.""" return await greenlet_spawn(self._allrows) @@ -641,7 +747,9 @@ class AsyncMappingResult(AsyncCommon[RowMapping]): else: return row - async def fetchmany(self, size: Optional[int] = None) -> List[RowMapping]: + async def fetchmany( + self, size: Optional[int] = None + ) -> Sequence[RowMapping]: """Fetch many rows. Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that @@ -652,7 +760,7 @@ class AsyncMappingResult(AsyncCommon[RowMapping]): return await greenlet_spawn(self._manyrow_getter, self, size) - async def all(self) -> List[RowMapping]: + async def all(self) -> Sequence[RowMapping]: """Return all rows in a list. Equivalent to :meth:`_asyncio.AsyncResult.all` except that @@ -705,11 +813,186 @@ class AsyncMappingResult(AsyncCommon[RowMapping]): return await greenlet_spawn(self._only_one_row, True, True, False) -_RT = TypeVar("_RT", bound="Result") +SelfAsyncTupleResult = TypeVar( + "SelfAsyncTupleResult", bound="AsyncTupleResult[Any]" +) + + +class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly): + """a :class:`.AsyncResult` that's typed as returning plain Python tuples + instead of rows. + + Since :class:`.Row` acts like a tuple in every way already, + this class is a typing only class, regular :class:`.AsyncResult` is + still used at runtime. + + """ + + __slots__ = () + + if TYPE_CHECKING: + + async def partitions( + self, size: Optional[int] = None + ) -> AsyncIterator[Sequence[_R]]: + """Iterate through sub-lists of elements of the size given. + + Equivalent to :meth:`_result.Result.partitions` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + """ + ... + + async def fetchone(self) -> Optional[_R]: + """Fetch one tuple. + + Equivalent to :meth:`_result.Result.fetchone` except that + tuple values, rather than :class:`_result.Row` + objects, are returned. + + """ + ... + + async def fetchall(self) -> Sequence[_R]: + """A synonym for the :meth:`_engine.ScalarResult.all` method.""" + ... + + async def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]: + """Fetch many objects. + + Equivalent to :meth:`_result.Result.fetchmany` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + """ + ... + + async def all(self) -> Sequence[_R]: # noqa: A001 + """Return all scalar values in a list. + + Equivalent to :meth:`_result.Result.all` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + """ + ... + + async def __aiter__(self) -> AsyncIterator[_R]: + ... + + async def __anext__(self) -> _R: + ... + + async def first(self) -> Optional[_R]: + """Fetch the first object or None if no object is present. + + Equivalent to :meth:`_result.Result.first` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + + """ + ... + + async def one_or_none(self) -> Optional[_R]: + """Return at most one object or raise an exception. + + Equivalent to :meth:`_result.Result.one_or_none` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + """ + ... + + async def one(self) -> _R: + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_result.Result.one` except that + tuple values, rather than :class:`_result.Row` objects, + are returned. + + """ + ... + + @overload + async def scalar_one(self: AsyncTupleResult[Tuple[_T]]) -> _T: + ... + + @overload + async def scalar_one(self) -> Any: + ... + + async def scalar_one(self) -> Any: + """Return exactly one scalar result or raise an exception. + + This is equivalent to calling :meth:`.Result.scalars` and then + :meth:`.Result.one`. + + .. seealso:: + + :meth:`.Result.one` + + :meth:`.Result.scalars` + + """ + ... + + @overload + async def scalar_one_or_none( + self: AsyncTupleResult[Tuple[_T]], + ) -> Optional[_T]: + ... + + @overload + async def scalar_one_or_none(self) -> Optional[Any]: + ... + + async def scalar_one_or_none(self) -> Optional[Any]: + """Return exactly one or no scalar result. + + This is equivalent to calling :meth:`.Result.scalars` and then + :meth:`.Result.one_or_none`. + + .. seealso:: + + :meth:`.Result.one_or_none` + + :meth:`.Result.scalars` + + """ + ... + + @overload + async def scalar(self: AsyncTupleResult[Tuple[_T]]) -> Optional[_T]: + ... + + @overload + async def scalar(self) -> Any: + ... + + async def scalar(self) -> Any: + """Fetch the first column of the first row, and close the result set. + + Returns None if there are no rows to fetch. + + No validation is performed to test if additional rows remain. + + After calling this method, the object is fully closed, + e.g. the :meth:`_engine.CursorResult.close` + method will have been called. + + :return: a Python scalar value , or None if no rows remain. + + """ + ... + + +_RT = TypeVar("_RT", bound="Result[Any]") async def _ensure_sync_result(result: _RT, calling_method: Any) -> _RT: - cursor_result: CursorResult + cursor_result: CursorResult[Any] try: is_cursor = result._is_cursor diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py index c7a6e2ca01..22a060a0d4 100644 --- a/lib/sqlalchemy/ext/asyncio/scoping.py +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -12,10 +12,12 @@ from typing import Callable from typing import Iterable from typing import Iterator from typing import Optional +from typing import overload from typing import Sequence from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from .session import async_sessionmaker @@ -37,9 +39,9 @@ if TYPE_CHECKING: from ...engine import Engine from ...engine import Result from ...engine import Row + from ...engine import RowMapping from ...engine.interfaces import _CoreAnyExecuteParams from ...engine.interfaces import _CoreSingleExecuteParams - from ...engine.interfaces import _ExecuteOptions from ...engine.interfaces import _ExecuteOptionsParameter from ...engine.result import ScalarResult from ...orm._typing import _IdentityKeyType @@ -52,6 +54,9 @@ if TYPE_CHECKING: from ...sql.base import Executable from ...sql.elements import ClauseElement from ...sql.selectable import ForUpdateArg + from ...sql.selectable import TypedReturnsRows + +_T = TypeVar("_T", bound=Any) @create_proxy_methods( @@ -480,6 +485,32 @@ class async_scoped_session: return await self._proxied.delete(instance) + @overload + async def execute( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[_T]: + ... + + @overload + async def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: + ... + async def execute( self, statement: Executable, @@ -488,7 +519,7 @@ class async_scoped_session: execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Result: + ) -> Result[Any]: r"""Execute a statement and return a buffered :class:`_engine.Result` object. @@ -916,6 +947,30 @@ class async_scoped_session: return await self._proxied.rollback() + @overload + async def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[_T]: + ... + + @overload + async def scalar( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + ... + async def scalar( self, statement: Executable, @@ -947,6 +1002,30 @@ class async_scoped_session: **kw, ) + @overload + async def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[_T]: + ... + + @overload + async def scalars( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + ... + async def scalars( self, statement: Executable, @@ -984,6 +1063,19 @@ class async_scoped_session: **kw, ) + @overload + async def stream( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[_T]: + ... + + @overload async def stream( self, statement: Executable, @@ -992,7 +1084,18 @@ class async_scoped_session: execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncResult: + ) -> AsyncResult[Any]: + ... + + async def stream( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[Any]: r"""Execute a statement and return a streaming :class:`_asyncio.AsyncResult` object. @@ -1012,6 +1115,30 @@ class async_scoped_session: **kw, ) + @overload + async def stream_scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[_T]: + ... + + @overload + async def stream_scalars( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[Any]: + ... + async def stream_scalars( self, statement: Executable, @@ -1323,7 +1450,7 @@ class async_scoped_session: ident: Union[Any, Tuple[Any, ...]] = None, *, instance: Optional[Any] = None, - row: Optional[Row] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, identity_token: Optional[Any] = None, ) -> _IdentityKeyType[Any]: r"""Return an identity key. diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 1422f99a39..f2a69e9cd9 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -12,10 +12,12 @@ from typing import Iterable from typing import Iterator from typing import NoReturn from typing import Optional +from typing import overload from typing import Sequence from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from . import engine @@ -39,11 +41,10 @@ if TYPE_CHECKING: from ...engine import Engine from ...engine import Result from ...engine import Row + from ...engine import RowMapping from ...engine import ScalarResult - from ...engine import Transaction from ...engine.interfaces import _CoreAnyExecuteParams from ...engine.interfaces import _CoreSingleExecuteParams - from ...engine.interfaces import _ExecuteOptions from ...engine.interfaces import _ExecuteOptionsParameter from ...event import dispatcher from ...orm._typing import _IdentityKeyType @@ -59,9 +60,12 @@ if TYPE_CHECKING: from ...sql.base import Executable from ...sql.elements import ClauseElement from ...sql.selectable import ForUpdateArg + from ...sql.selectable import TypedReturnsRows _AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"] +_T = TypeVar("_T", bound=Any) + class _SyncSessionCallable(Protocol): def __call__(self, session: Session, *arg: Any, **kw: Any) -> Any: @@ -257,6 +261,32 @@ class AsyncSession(ReversibleProxy[Session]): return await greenlet_spawn(fn, self.sync_session, *arg, **kw) + @overload + async def execute( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[_T]: + ... + + @overload + async def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: + ... + async def execute( self, statement: Executable, @@ -265,7 +295,7 @@ class AsyncSession(ReversibleProxy[Session]): execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Result: + ) -> Result[Any]: """Execute a statement and return a buffered :class:`_engine.Result` object. @@ -292,6 +322,30 @@ class AsyncSession(ReversibleProxy[Session]): ) return await _ensure_sync_result(result, self.execute) + @overload + async def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[_T]: + ... + + @overload + async def scalar( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + ... + async def scalar( self, statement: Executable, @@ -326,6 +380,30 @@ class AsyncSession(ReversibleProxy[Session]): ) return result + @overload + async def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[_T]: + ... + + @overload + async def scalars( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + ... + async def scalars( self, statement: Executable, @@ -391,6 +469,30 @@ class AsyncSession(ReversibleProxy[Session]): ) return result_obj + @overload + async def stream( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[_T]: + ... + + @overload + async def stream( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[Any]: + ... + async def stream( self, statement: Executable, @@ -399,7 +501,7 @@ class AsyncSession(ReversibleProxy[Session]): execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncResult: + ) -> AsyncResult[Any]: """Execute a statement and return a streaming :class:`_asyncio.AsyncResult` object. @@ -423,6 +525,30 @@ class AsyncSession(ReversibleProxy[Session]): ) return AsyncResult(result) + @overload + async def stream_scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[_T]: + ... + + @overload + async def stream_scalars( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[Any]: + ... + async def stream_scalars( self, statement: Executable, @@ -1215,7 +1341,7 @@ class AsyncSession(ReversibleProxy[Session]): ident: Union[Any, Tuple[Any, ...]] = None, *, instance: Optional[Any] = None, - row: Optional[Row] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, identity_token: Optional[Any] = None, ) -> _IdentityKeyType[Any]: r"""Return an identity key. diff --git a/lib/sqlalchemy/ext/instrumentation.py b/lib/sqlalchemy/ext/instrumentation.py index b1138a4ad8..c14b466ebd 100644 --- a/lib/sqlalchemy/ext/instrumentation.py +++ b/lib/sqlalchemy/ext/instrumentation.py @@ -23,6 +23,7 @@ from ..orm import base as orm_base from ..orm import collections from ..orm import exc as orm_exc from ..orm import instrumentation as orm_instrumentation +from ..orm import util as orm_util from ..orm.instrumentation import _default_dict_getter from ..orm.instrumentation import _default_manager_getter from ..orm.instrumentation import _default_opt_manager_getter @@ -437,5 +438,7 @@ def _install_lookups(lookups): attributes.manager_of_class ) = orm_instrumentation.manager_of_class = manager_of_class orm_base.opt_manager_of_class = ( + orm_util.opt_manager_of_class + ) = ( attributes.opt_manager_of_class ) = orm_instrumentation.opt_manager_of_class = opt_manager_of_class diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 457ad5c5a6..48615b174b 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -38,6 +38,7 @@ from ..exc import InvalidRequestError from ..sql.base import SchemaEventTarget from ..sql.schema import SchemaConst from ..sql.selectable import FromClause +from ..util.typing import Annotated from ..util.typing import Literal if TYPE_CHECKING: @@ -45,6 +46,7 @@ if TYPE_CHECKING: from ._typing import _ORMColumnExprArgument from .descriptor_props import _CompositeAttrType from .interfaces import PropComparator + from .mapper import Mapper from .query import Query from .relationships import _LazyLoadArgumentType from .relationships import _ORMBackrefArgument @@ -1849,9 +1851,27 @@ def clear_mappers(): mapperlib._dispose_registries(mapperlib._all_registries(), False) +# I would really like a way to get the Type[] here that shows up +# in a different way in typing tools, however there is no current method +# that is accepted by mypy (subclass of Type[_O] works in pylance, rejected +# by mypy). +AliasedType = Annotated[Type[_O], "aliased"] + + +@overload +def aliased( + element: Type[_O], + alias: Optional[Union[Alias, Subquery]] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, +) -> AliasedType[_O]: + ... + + @overload def aliased( - element: _EntityType[_O], + element: Union[AliasedClass[_O], Mapper[_O], AliasedInsp[_O]], alias: Optional[Union[Alias, Subquery]] = None, name: Optional[str] = None, flat: bool = False, @@ -1877,7 +1897,7 @@ def aliased( name: Optional[str] = None, flat: bool = False, adapt_on_names: bool = False, -) -> Union[AliasedClass[_O], FromClause]: +) -> Union[AliasedClass[_O], FromClause, AliasedType[_O]]: """Produce an alias of the given element, usually an :class:`.AliasedClass` instance. @@ -1885,7 +1905,8 @@ def aliased( my_alias = aliased(MyClass) - session.query(MyClass, my_alias).filter(MyClass.id > my_alias.id) + stmt = select(MyClass, my_alias).filter(MyClass.id > my_alias.id) + result = session.execute(stmt) The :func:`.aliased` function is used to create an ad-hoc mapping of a mapped class to a new selectable. By default, a selectable is generated @@ -1911,6 +1932,9 @@ def aliased( .. seealso:: + :class:`.AsAliased` - a :pep:`484` typed version of + :func:`_orm.aliased` + :ref:`tutorial_orm_entity_aliases` - in the :ref:`unified_tutorial` :ref:`orm_queryguide_orm_aliases` - in the :ref:`queryguide_toplevel` diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 41d944c57d..619af65104 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -70,6 +70,7 @@ from .. import exc from .. import inspection from .. import util from ..sql import base as sql_base +from ..sql import cache_key from ..sql import roles from ..sql import traversals from ..sql import visitors @@ -99,10 +100,8 @@ class QueryableAttribute( traversals.HasCopyInternals, roles.JoinTargetRole, roles.OnClauseRole, - roles.ColumnsClauseRole, - roles.ExpressionElementRole[_T], sql_base.Immutable, - sql_base.MemoizedHasCacheKey, + cache_key.MemoizedHasCacheKey, ): """Base class for :term:`descriptor` objects that intercept attribute events on behalf of a :class:`.MapperProperty` diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 054d52d83b..367a5332de 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -30,6 +30,7 @@ from ._typing import insp_is_mapper from .. import exc as sa_exc from .. import inspection from .. import util +from ..sql import roles from ..sql.elements import SQLCoreOperations from ..util import FastIntFlag from ..util.langhelpers import TypingOnly @@ -483,19 +484,6 @@ def _inspect_mapped_class( return mapper -@inspection._inspects(type) -def _inspect_mc(class_: Type[_O]) -> Optional[Mapper[_O]]: - try: - class_manager = opt_manager_of_class(class_) - if class_manager is None or not class_manager.is_mapped: - return None - mapper = class_manager.mapper - except exc.NO_STATE: - return None - else: - return mapper - - def _parse_mapper_argument(arg: Union[Mapper[_O], Type[_O]]) -> Mapper[_O]: insp = inspection.inspect(arg, raiseerr=False) if insp_is_mapper(insp): @@ -691,7 +679,7 @@ class ORMDescriptor(Generic[_T], TypingOnly): ... -class Mapped(ORMDescriptor[_T], TypingOnly): +class Mapped(ORMDescriptor[_T], roles.TypedColumnsClauseRole[_T], TypingOnly): """Represent an ORM mapped attribute on a mapped class. This class represents the complete descriptor interface for any class diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 4fee2d383d..05287cbcfd 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -17,6 +17,7 @@ from typing import Set from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from . import attributes @@ -48,14 +49,15 @@ from ..sql.base import _select_iterables from ..sql.base import CacheableOptions from ..sql.base import CompileState from ..sql.base import Executable +from ..sql.base import Generative from ..sql.base import Options from ..sql.dml import UpdateBase from ..sql.elements import GroupedElement from ..sql.elements import TextClause +from ..sql.selectable import ExecutableReturnsRows from ..sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY from ..sql.selectable import LABEL_STYLE_NONE from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL -from ..sql.selectable import ReturnsRows from ..sql.selectable import Select from ..sql.selectable import SelectLabelStyle from ..sql.selectable import SelectState @@ -72,6 +74,7 @@ if TYPE_CHECKING: from ..sql.selectable import SelectBase from ..sql.type_api import TypeEngine +_T = TypeVar("_T", bound=Any) _path_registry = PathRegistry.root _EMPTY_DICT = util.immutabledict() @@ -574,7 +577,7 @@ class ORMFromStatementCompileState(ORMCompileState): return None -class FromStatement(GroupedElement, ReturnsRows, Executable): +class FromStatement(GroupedElement, Generative, ExecutableReturnsRows): """Core construct that represents a load of ORM objects from various :class:`.ReturnsRows` and other classes including: diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 0ca62b7e35..6a5690be24 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -61,7 +61,6 @@ from ..sql.schema import Column from ..sql.type_api import TypeEngine from ..util.typing import TypedDict - if typing.TYPE_CHECKING: from ._typing import _EntityType from ._typing import _IdentityKeyType @@ -106,12 +105,12 @@ class ORMStatementRole(roles.StatementRole): ) -class ORMColumnsClauseRole(roles.ColumnsClauseRole): +class ORMColumnsClauseRole(roles.TypedColumnsClauseRole[_T]): __slots__ = () _role_name = "ORM mapped entity, aliased entity, or Column expression" -class ORMEntityColumnsClauseRole(ORMColumnsClauseRole): +class ORMEntityColumnsClauseRole(ORMColumnsClauseRole[_T]): __slots__ = () _role_name = "ORM mapped or aliased entity" @@ -127,8 +126,8 @@ class ORMColumnDescription(TypedDict): # into "type" is a bad idea type: Union[Type[Any], TypeEngine[Any]] aliased: bool - expr: _ColumnsClauseArgument - entity: Optional[_ColumnsClauseArgument] + expr: _ColumnsClauseArgument[Any] + entity: Optional[_ColumnsClauseArgument[Any]] class _IntrospectsAnnotations: @@ -282,7 +281,7 @@ class MapperProperty( query_entity: _MapperEntity, path: PathRegistry, mapper: Mapper[Any], - result: Result, + result: Result[Any], adapter: Optional[ColumnAdapter], populators: _PopulatorDict, ) -> None: @@ -1170,7 +1169,7 @@ class LoaderStrategy: path: AbstractEntityRegistry, loadopt: Optional[_LoadElement], mapper: Mapper[Any], - result: Result, + result: Result[Any], adapter: Optional[ORMAdapter], populators: _PopulatorDict, ) -> None: diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index b37c080eaf..0830350936 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -96,7 +96,6 @@ if TYPE_CHECKING: from .descriptor_props import Synonym from .events import MapperEvents from .instrumentation import ClassManager - from .path_registry import AbstractEntityRegistry from .path_registry import CachingEntityRegistry from .properties import ColumnProperty from .relationships import Relationship @@ -108,10 +107,10 @@ if TYPE_CHECKING: from ..sql.base import ReadOnlyColumnCollection from ..sql.elements import ColumnClause from ..sql.elements import ColumnElement + from ..sql.elements import KeyedColumnElement from ..sql.schema import Column from ..sql.schema import Table from ..sql.selectable import FromClause - from ..sql.selectable import TableClause from ..sql.util import ColumnAdapter from ..util import OrderedSet @@ -161,7 +160,7 @@ _CONFIGURE_MUTEX = threading.RLock() @log.class_logger class Mapper( ORMFromClauseRole, - ORMEntityColumnsClauseRole, + ORMEntityColumnsClauseRole[_O], MemoizedHasCacheKey, InspectionAttr, log.Identified, @@ -1006,7 +1005,7 @@ class Mapper( """ - polymorphic_on: Optional[ColumnElement[Any]] + polymorphic_on: Optional[KeyedColumnElement[Any]] """The :class:`_schema.Column` or SQL expression specified as the ``polymorphic_on`` argument for this :class:`_orm.Mapper`, within an inheritance scenario. @@ -1699,10 +1698,10 @@ class Mapper( instrument = True key = getattr(col, "key", None) if key: - if self._should_exclude(col.key, col.key, False, col): + if self._should_exclude(key, key, False, col): raise sa_exc.InvalidRequestError( "Cannot exclude or override the " - "discriminator column %r" % col.key + "discriminator column %r" % key ) else: self.polymorphic_on = col = col.label("_sa_polymorphic_on") @@ -2948,7 +2947,7 @@ class Mapper( def identity_key_from_row( self, - row: Optional[Union[Row, RowMapping]], + row: Optional[Union[Row[Any], RowMapping]], identity_token: Optional[Any] = None, adapter: Optional[ColumnAdapter] = None, ) -> _IdentityKeyType[_O]: diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 9f37e84571..0ca0559b45 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -54,7 +54,7 @@ from ..util.typing import NoneType if TYPE_CHECKING: from ._typing import _ORMColumnExprArgument from ..sql._typing import _InfoType - from ..sql.elements import ColumnElement + from ..sql.elements import KeyedColumnElement _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) @@ -85,7 +85,8 @@ class ColumnProperty( inherit_cache = True _links_to_entity = False - columns: List[ColumnElement[Any]] + columns: List[KeyedColumnElement[Any]] + _orig_columns: List[KeyedColumnElement[Any]] _is_polymorphic_discriminator: bool diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 395d01a1ea..5bd302b21d 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -27,6 +27,8 @@ from typing import Generic from typing import Iterable from typing import List from typing import Optional +from typing import overload +from typing import Sequence from typing import Tuple from typing import TYPE_CHECKING from typing import TypeVar @@ -36,6 +38,7 @@ from . import exc as orm_exc from . import interfaces from . import loading from . import util as orm_util +from ._typing import _O from .base import _assertions from .context import _column_descriptions from .context import _determine_last_joined_entity @@ -56,6 +59,7 @@ from .. import log from .. import sql from .. import util from ..engine import Result +from ..engine import Row from ..sql import coercions from ..sql import expression from ..sql import roles @@ -63,10 +67,12 @@ from ..sql import Select from ..sql import util as sql_util from ..sql import visitors from ..sql._typing import _FromClauseArgument +from ..sql._typing import _TP from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import _entity_namespace_key from ..sql.base import _generative from ..sql.base import Executable +from ..sql.base import Generative from ..sql.expression import Exists from ..sql.selectable import _MemoizedSelectEntities from ..sql.selectable import _SelectFromElements @@ -75,10 +81,33 @@ from ..sql.selectable import HasHints from ..sql.selectable import HasPrefixes from ..sql.selectable import HasSuffixes from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..util.typing import Literal if TYPE_CHECKING: + from ._typing import _EntityType + from .session import Session + from ..engine.result import ScalarResult + from ..engine.row import Row + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _ColumnsClauseArgument + from ..sql._typing import _MAYBE_ENTITY + from ..sql._typing import _no_kw + from ..sql._typing import _NOT_ENTITY + from ..sql._typing import _PropagateAttrsType + from ..sql._typing import _T0 + from ..sql._typing import _T1 + from ..sql._typing import _T2 + from ..sql._typing import _T3 + from ..sql._typing import _T4 + from ..sql._typing import _T5 + from ..sql._typing import _T6 + from ..sql._typing import _T7 + from ..sql._typing import _TypedColumnClauseArgument as _TCCA + from ..sql.roles import TypedColumnsClauseRole from ..sql.selectable import _SetupJoinsElement from ..sql.selectable import Alias + from ..sql.selectable import ExecutableReturnsRows + from ..sql.selectable import ScalarSelect from ..sql.selectable import Subquery __all__ = ["Query", "QueryContext"] @@ -97,6 +126,7 @@ class Query( HasSuffixes, HasHints, log.Identified, + Generative, Executable, Generic[_T], ): @@ -159,9 +189,15 @@ class Query( # mirrors that of ClauseElement, used to propagate the "orm" # plugin as well as the "subject" of the plugin, e.g. the mapper # we are querying against. - _propagate_attrs = util.immutabledict() + @util.memoized_property + def _propagate_attrs(self) -> _PropagateAttrsType: + return util.EMPTY_DICT - def __init__(self, entities, session=None): + def __init__( + self, + entities: Sequence[_ColumnsClauseArgument[Any]], + session: Optional[Session] = None, + ): """Construct a :class:`_query.Query` directly. E.g.:: @@ -207,6 +243,36 @@ class Query( for ent in util.to_list(entities) ] + @overload + def tuples(self: Query[Row[_TP]]) -> Query[_TP]: + ... + + @overload + def tuples(self: Query[_O]) -> Query[Tuple[_O]]: + ... + + def tuples(self) -> Query[Any]: + """return a tuple-typed form of this :class:`.Query`. + + This method invokes the :meth:`.Query.only_return_tuples` + method with a value of ``True``, which by itself ensures that this + :class:`.Query` will always return :class:`.Row` objects, even + if the query is made against a single entity. It then also + at the typing level will return a "typed" query, if possible, + that will type result rows as ``Tuple`` objects with typed + elements. + + This method can be compared to the :meth:`.Result.tuples` method, + which returns "self", but from a typing perspective returns an object + that will yield typed ``Tuple`` objects for results. Typing + takes effect only if this :class:`.Query` object is a typed + query object already. + + .. versionadded:: 2.0 + + """ + return self.only_return_tuples(True) + def _entity_from_pre_ent_zero(self): if not self._raw_columns: return None @@ -582,20 +648,52 @@ class Query( return self.enable_eagerloads(False).statement.label(name) + @overload + def as_scalar( + self: Query[Tuple[_MAYBE_ENTITY]], + ) -> ScalarSelect[_MAYBE_ENTITY]: + ... + + @overload + def as_scalar( + self: Query[Tuple[_NOT_ENTITY]], + ) -> ScalarSelect[_NOT_ENTITY]: + ... + + @overload + def as_scalar(self) -> ScalarSelect[Any]: + ... + @util.deprecated( "1.4", "The :meth:`_query.Query.as_scalar` method is deprecated and will be " "removed in a future release. Please refer to " ":meth:`_query.Query.scalar_subquery`.", ) - def as_scalar(self): + def as_scalar(self) -> ScalarSelect[Any]: """Return the full SELECT statement represented by this :class:`_query.Query`, converted to a scalar subquery. """ return self.scalar_subquery() - def scalar_subquery(self): + @overload + def scalar_subquery( + self: Query[Tuple[_MAYBE_ENTITY]], + ) -> ScalarSelect[Any]: + ... + + @overload + def scalar_subquery( + self: Query[Tuple[_NOT_ENTITY]], + ) -> ScalarSelect[_NOT_ENTITY]: + ... + + @overload + def scalar_subquery(self) -> ScalarSelect[Any]: + ... + + def scalar_subquery(self) -> ScalarSelect[Any]: """Return the full SELECT statement represented by this :class:`_query.Query`, converted to a scalar subquery. @@ -630,16 +728,31 @@ class Query( .statement ) - @_generative - def only_return_tuples(self: SelfQuery, value) -> SelfQuery: - """When set to True, the query results will always be a tuple. + @overload + def only_return_tuples( + self: Query[_O], value: Literal[True] + ) -> RowReturningQuery[Tuple[_O]]: + ... - This is specifically for single element queries. The default is False. + @overload + def only_return_tuples( + self: Query[_O], value: Literal[False] + ) -> Query[_O]: + ... - .. versionadded:: 1.2.5 + @_generative + def only_return_tuples(self, value: bool) -> Query[Any]: + """When set to True, the query results will always be a + :class:`.Row` object. + + This can change a query that normally returns a single entity + as a scalar to return a :class:`.Row` result in all cases. .. seealso:: + :meth:`.Query.tuples` - returns tuples, but also at the typing + level will type results as ``Tuple``. + :meth:`_query.Query.is_single_entity` """ @@ -1077,7 +1190,11 @@ class Query( return self.filter(with_parent(instance, property, entity_zero.entity)) @_generative - def add_entity(self: SelfQuery, entity, alias=None) -> SelfQuery: + def add_entity( + self, + entity: _EntityType[Any], + alias: Optional[Union[Alias, Subquery]] = None, + ) -> Query[Any]: """add a mapped entity to the list of result columns to be returned.""" @@ -1209,8 +1326,107 @@ class Query( except StopIteration: return None + @overload + def with_entities( + self, _entity: _EntityType[_O], **kwargs: Any + ) -> ScalarInstanceQuery[_O]: + ... + + @overload + def with_entities( + self, _colexpr: TypedColumnsClauseRole[_T] + ) -> RowReturningQuery[Tuple[_T]]: + ... + + # START OVERLOADED FUNCTIONS self.with_entities RowReturningQuery 2-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def with_entities( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> RowReturningQuery[Tuple[_T0, _T1]]: + ... + + @overload + def with_entities( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def with_entities( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def with_entities( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def with_entities( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def with_entities( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def with_entities( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.with_entities + + @overload + def with_entities( + self: SelfQuery, *entities: _ColumnsClauseArgument[Any] + ) -> SelfQuery: + ... + @_generative - def with_entities(self: SelfQuery, *entities) -> SelfQuery: + def with_entities( + self: SelfQuery, *entities: _ColumnsClauseArgument[Any], **__kw: Any + ) -> SelfQuery: r"""Return a new :class:`_query.Query` replacing the SELECT list with the given entities. @@ -1234,12 +1450,14 @@ class Query( limit(1) """ + if __kw: + raise _no_kw() _MemoizedSelectEntities._generate_for_statement(self) self._set_entities(entities) return self @_generative - def add_columns(self: SelfQuery, *column) -> SelfQuery: + def add_columns(self, *column: _ColumnExpressionArgument) -> Query[Any]: """Add one or more column expressions to the list of result columns to be returned.""" @@ -1262,7 +1480,7 @@ class Query( "is deprecated and will be removed in a " "future release. Please use :meth:`_query.Query.add_columns`", ) - def add_column(self, column): + def add_column(self, column) -> Query[Any]: """Add a column expression to the list of result columns to be returned. @@ -1472,7 +1690,9 @@ class Query( @_generative @_assertions(_no_statement_condition, _no_limit_offset) - def filter(self: SelfQuery, *criterion) -> SelfQuery: + def filter( + self: SelfQuery, *criterion: _ColumnExpressionArgument[bool] + ) -> SelfQuery: r"""Apply the given filtering criterion to a copy of this :class:`_query.Query`, using SQL expressions. @@ -1556,7 +1776,7 @@ class Query( return self._raw_columns[0] - def filter_by(self, **kwargs): + def filter_by(self: SelfQuery, **kwargs: Any) -> SelfQuery: r"""Apply the given filtering criterion to a copy of this :class:`_query.Query`, using keyword expressions. @@ -1597,7 +1817,9 @@ class Query( @_generative @_assertions(_no_statement_condition, _no_limit_offset) - def order_by(self: SelfQuery, *clauses) -> SelfQuery: + def order_by( + self: SelfQuery, *clauses: _ColumnExpressionArgument[Any] + ) -> SelfQuery: """Apply one or more ORDER BY criteria to the query and return the newly resulting :class:`_query.Query`. @@ -1635,7 +1857,9 @@ class Query( @_generative @_assertions(_no_statement_condition, _no_limit_offset) - def group_by(self: SelfQuery, *clauses) -> SelfQuery: + def group_by( + self: SelfQuery, *clauses: _ColumnExpressionArgument[Any] + ) -> SelfQuery: """Apply one or more GROUP BY criterion to the query and return the newly resulting :class:`_query.Query`. @@ -1667,7 +1891,9 @@ class Query( @_generative @_assertions(_no_statement_condition, _no_limit_offset) - def having(self: SelfQuery, criterion) -> SelfQuery: + def having( + self: SelfQuery, *having: _ColumnExpressionArgument[bool] + ) -> SelfQuery: r"""Apply a HAVING criterion to the query and return the newly resulting :class:`_query.Query`. @@ -1684,17 +1910,17 @@ class Query( """ - self._having_criteria += ( - coercions.expect( - roles.WhereHavingRole, criterion, apply_propagate_attrs=self - ), - ) + for criterion in having: + having_criteria = coercions.expect( + roles.WhereHavingRole, criterion + ) + self._having_criteria += (having_criteria,) return self def _set_op(self, expr_fn, *q): return self._from_selectable(expr_fn(*([self] + list(q))).subquery()) - def union(self, *q): + def union(self: SelfQuery, *q: Query[Any]) -> SelfQuery: """Produce a UNION of this Query against one or more queries. e.g.:: @@ -1733,7 +1959,7 @@ class Query( """ return self._set_op(expression.union, *q) - def union_all(self, *q): + def union_all(self: SelfQuery, *q: Query[Any]) -> SelfQuery: """Produce a UNION ALL of this Query against one or more queries. Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See @@ -1742,7 +1968,7 @@ class Query( """ return self._set_op(expression.union_all, *q) - def intersect(self, *q): + def intersect(self: SelfQuery, *q: Query[Any]) -> SelfQuery: """Produce an INTERSECT of this Query against one or more queries. Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See @@ -1751,7 +1977,7 @@ class Query( """ return self._set_op(expression.intersect, *q) - def intersect_all(self, *q): + def intersect_all(self: SelfQuery, *q: Query[Any]) -> SelfQuery: """Produce an INTERSECT ALL of this Query against one or more queries. Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See @@ -1760,7 +1986,7 @@ class Query( """ return self._set_op(expression.intersect_all, *q) - def except_(self, *q): + def except_(self: SelfQuery, *q: Query[Any]) -> SelfQuery: """Produce an EXCEPT of this Query against one or more queries. Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See @@ -1769,7 +1995,7 @@ class Query( """ return self._set_op(expression.except_, *q) - def except_all(self, *q): + def except_all(self: SelfQuery, *q: Query[Any]) -> SelfQuery: """Produce an EXCEPT ALL of this Query against one or more queries. Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See @@ -2194,7 +2420,9 @@ class Query( @_generative @_assertions(_no_clauseelement_condition) - def from_statement(self: SelfQuery, statement) -> SelfQuery: + def from_statement( + self: SelfQuery, statement: ExecutableReturnsRows + ) -> SelfQuery: """Execute the given SELECT statement and return results. This method bypasses all internal statement compilation, and the @@ -2283,7 +2511,7 @@ class Query( :meth:`_query.Query.one_or_none` """ - return self._iter().one() + return self._iter().one() # type: ignore def scalar(self) -> Any: """Return the first element of the first result or None @@ -2316,7 +2544,7 @@ class Query( def __iter__(self) -> Iterable[_T]: return self._iter().__iter__() - def _iter(self): + def _iter(self) -> Union[ScalarResult[_T], Result[_T]]: # new style execution. params = self._params @@ -2837,3 +3065,7 @@ class BulkUpdate(BulkUD): class BulkDelete(BulkUD): """BulkUD which handles DELETEs.""" + + +class RowReturningQuery(Query[Row[_TP]]): + pass diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 93d18b8d79..9220c44c7f 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -13,6 +13,7 @@ from typing import Dict from typing import Iterable from typing import Iterator from typing import Optional +from typing import overload from typing import Sequence from typing import Tuple from typing import Type @@ -20,8 +21,6 @@ from typing import TYPE_CHECKING from typing import TypeVar from typing import Union -from . import exc as orm_exc -from .base import class_mapper from .session import Session from .. import exc as sa_exc from .. import util @@ -33,11 +32,13 @@ from ..util import warn_deprecated from ..util.typing import Protocol if TYPE_CHECKING: + from ._typing import _EntityType from ._typing import _IdentityKeyType from .identity import IdentityMap from .interfaces import ORMOption from .mapper import Mapper from .query import Query + from .query import RowReturningQuery from .session import _BindArguments from .session import _EntityBindKey from .session import _PKIdentityArgument @@ -48,19 +49,33 @@ if TYPE_CHECKING: from ..engine import Engine from ..engine import Result from ..engine import Row + from ..engine import RowMapping from ..engine.interfaces import _CoreAnyExecuteParams from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.interfaces import _ExecuteOptions from ..engine.interfaces import _ExecuteOptionsParameter from ..engine.result import ScalarResult from ..sql._typing import _ColumnsClauseArgument + from ..sql._typing import _T0 + from ..sql._typing import _T1 + from ..sql._typing import _T2 + from ..sql._typing import _T3 + from ..sql._typing import _T4 + from ..sql._typing import _T5 + from ..sql._typing import _T6 + from ..sql._typing import _T7 + from ..sql._typing import _TypedColumnClauseArgument as _TCCA from ..sql.base import Executable from ..sql.elements import ClauseElement + from ..sql.roles import TypedColumnsClauseRole from ..sql.selectable import ForUpdateArg + from ..sql.selectable import TypedReturnsRows + +_T = TypeVar("_T", bound=Any) class _QueryDescriptorType(Protocol): - def __get__(self, instance: Any, owner: Type[Any]) -> Optional[Query[Any]]: + def __get__(self, instance: Any, owner: Type[_T]) -> Query[_T]: ... @@ -236,7 +251,7 @@ class scoped_session: self.registry.clear() def query_property( - self, query_cls: Optional[Type[Query[Any]]] = None + self, query_cls: Optional[Type[Query[_T]]] = None ) -> _QueryDescriptorType: """return a class property which produces a :class:`_query.Query` object @@ -264,20 +279,13 @@ class scoped_session: """ class query: - def __get__( - s, instance: Any, owner: Type[Any] - ) -> Optional[Query[Any]]: - try: - mapper = class_mapper(owner) - assert mapper is not None - if query_cls: - # custom query class - return query_cls(mapper, session=self.registry()) - else: - # session's configured query class - return self.registry().query(mapper) - except orm_exc.UnmappedClassError: - return None + def __get__(s, instance: Any, owner: Type[_O]) -> Query[_O]: + if query_cls: + # custom query class + return query_cls(owner, session=self.registry()) # type: ignore # noqa: E501 + else: + # session's configured query class + return self.registry().query(owner) return query() @@ -548,6 +556,32 @@ class scoped_session: return self._proxied.delete(instance) + @overload + def execute( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[_T]: + ... + + @overload + def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: + ... + def execute( self, statement: Executable, @@ -557,7 +591,7 @@ class scoped_session: bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result: + ) -> Result[Any]: r"""Execute a SQL expression construct. .. container:: class_bases @@ -1430,8 +1464,103 @@ class scoped_session: return self._proxied.merge(instance, load=load, options=options) + @overload + def query(self, _entity: _EntityType[_O]) -> Query[_O]: + ... + + @overload def query( - self, *entities: _ColumnsClauseArgument, **kwargs: Any + self, _colexpr: TypedColumnsClauseRole[_T] + ) -> RowReturningQuery[Tuple[_T]]: + ... + + # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def query( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> RowReturningQuery[Tuple[_T0, _T1]]: + ... + + @overload + def query( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.query + + @overload + def query( + self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any + ) -> Query[Any]: + ... + + def query( + self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any ) -> Query[Any]: r"""Return a new :class:`_query.Query` object corresponding to this :class:`_orm.Session`. @@ -1559,6 +1688,30 @@ class scoped_session: return self._proxied.rollback() + @overload + def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[_T]: + ... + + @overload + def scalar( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + ... + def scalar( self, statement: Executable, @@ -1590,6 +1743,30 @@ class scoped_session: **kw, ) + @overload + def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[_T]: + ... + + @overload + def scalars( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + ... + def scalars( self, statement: Executable, @@ -1848,7 +2025,7 @@ class scoped_session: ident: Union[Any, Tuple[Any, ...]] = None, *, instance: Optional[Any] = None, - row: Optional[Row] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, identity_token: Optional[Any] = None, ) -> _IdentityKeyType[Any]: r"""Return an identity key. diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 74035ec0aa..263d561019 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -27,6 +27,7 @@ from typing import Set from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union import weakref @@ -85,12 +86,16 @@ from ..util.typing import Literal from ..util.typing import Protocol if typing.TYPE_CHECKING: + from ._typing import _EntityType from ._typing import _IdentityKeyType from ._typing import _InstanceDict + from ._typing import _O + from .context import FromStatement from .interfaces import ORMOption from .interfaces import UserDefinedOption from .mapper import Mapper from .path_registry import PathRegistry + from .query import RowReturningQuery from ..engine import Result from ..engine import Row from ..engine import RowMapping @@ -104,10 +109,23 @@ if typing.TYPE_CHECKING: from ..event import _InstanceLevelDispatch from ..sql._typing import _ColumnsClauseArgument from ..sql._typing import _InfoType + from ..sql._typing import _T0 + from ..sql._typing import _T1 + from ..sql._typing import _T2 + from ..sql._typing import _T3 + from ..sql._typing import _T4 + from ..sql._typing import _T5 + from ..sql._typing import _T6 + from ..sql._typing import _T7 + from ..sql._typing import _TypedColumnClauseArgument as _TCCA from ..sql.base import Executable from ..sql.elements import ClauseElement + from ..sql.roles import TypedColumnsClauseRole from ..sql.schema import Table - from ..sql.selectable import TableClause + from ..sql.selectable import Select + from ..sql.selectable import TypedReturnsRows + +_T = TypeVar("_T", bound=Any) __all__ = [ "Session", @@ -189,7 +207,7 @@ class _SessionClassMethods: ident: Union[Any, Tuple[Any, ...]] = None, *, instance: Optional[Any] = None, - row: Optional[Union[Row, RowMapping]] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, identity_token: Optional[Any] = None, ) -> _IdentityKeyType[Any]: """Return an identity key. @@ -295,7 +313,7 @@ class ORMExecuteState(util.MemoizedSlots): params: Optional[_CoreAnyExecuteParams] = None, execution_options: Optional[_ExecuteOptionsParameter] = None, bind_arguments: Optional[_BindArguments] = None, - ) -> Result: + ) -> Result[Any]: """Execute the statement represented by this :class:`.ORMExecuteState`, without re-invoking events that have already proceeded. @@ -1718,7 +1736,7 @@ class Session(_SessionClassMethods, EventTarget): _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, _scalar_result: bool = ..., - ) -> Result: + ) -> Result[Any]: ... def _execute_internal( @@ -1789,7 +1807,7 @@ class Session(_SessionClassMethods, EventTarget): ) for idx, fn in enumerate(events_todo): orm_exec_state._starting_event_idx = idx - fn_result: Optional[Result] = fn(orm_exec_state) + fn_result: Optional[Result[Any]] = fn(orm_exec_state) if fn_result: if _scalar_result: return fn_result.scalar() @@ -1806,10 +1824,12 @@ class Session(_SessionClassMethods, EventTarget): if _scalar_result and not compile_state_cls: if TYPE_CHECKING: params = cast(_CoreSingleExecuteParams, params) - return conn.scalar(statement, params or {}, execution_options) + return conn.scalar( + statement, params or {}, execution_options=execution_options + ) - result: Result = conn.execute( - statement, params or {}, execution_options + result: Result[Any] = conn.execute( + statement, params or {}, execution_options=execution_options ) if compile_state_cls: @@ -1827,6 +1847,32 @@ class Session(_SessionClassMethods, EventTarget): else: return result + @overload + def execute( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[_T]: + ... + + @overload + def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: + ... + def execute( self, statement: Executable, @@ -1836,7 +1882,7 @@ class Session(_SessionClassMethods, EventTarget): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result: + ) -> Result[Any]: r"""Execute a SQL expression construct. Returns a :class:`_engine.Result` object representing @@ -1897,6 +1943,30 @@ class Session(_SessionClassMethods, EventTarget): _add_event=_add_event, ) + @overload + def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[_T]: + ... + + @overload + def scalar( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + ... + def scalar( self, statement: Executable, @@ -1923,6 +1993,30 @@ class Session(_SessionClassMethods, EventTarget): **kw, ) + @overload + def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[_T]: + ... + + @overload + def scalars( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + ... + def scalars( self, statement: Executable, @@ -2284,8 +2378,103 @@ class Session(_SessionClassMethods, EventTarget): f'{", ".join(context)} or this Session.' ) + @overload + def query(self, _entity: _EntityType[_O]) -> Query[_O]: + ... + + @overload + def query( + self, _colexpr: TypedColumnsClauseRole[_T] + ) -> RowReturningQuery[Tuple[_T]]: + ... + + # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def query( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> RowReturningQuery[Tuple[_T0, _T1]]: + ... + + @overload + def query( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.query + + @overload + def query( + self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any + ) -> Query[Any]: + ... + def query( - self, *entities: _ColumnsClauseArgument, **kwargs: Any + self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any ) -> Query[Any]: """Return a new :class:`_query.Query` object corresponding to this :class:`_orm.Session`. @@ -2486,7 +2675,7 @@ class Session(_SessionClassMethods, EventTarget): with_for_update = ForUpdateArg._from_argument(with_for_update) - stmt = sql.select(object_mapper(instance)) + stmt: Select[Any] = sql.select(object_mapper(instance)) if ( loading.load_on_ident( self, diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index 58f141997e..ab32a3981a 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -656,13 +656,13 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]): @classmethod def _instance_level_callable_processor( cls, manager: ClassManager[_O], fn: _LoaderCallable, key: Any - ) -> Callable[[InstanceState[_O], _InstanceDict, Row], None]: + ) -> Callable[[InstanceState[_O], _InstanceDict, Row[Any]], None]: impl = manager[key].impl if is_collection_impl(impl): fixed_impl = impl def _set_callable( - state: InstanceState[_O], dict_: _InstanceDict, row: Row + state: InstanceState[_O], dict_: _InstanceDict, row: Row[Any] ) -> None: if "callables" not in state.__dict__: state.callables = {} @@ -674,7 +674,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]): else: def _set_callable( - state: InstanceState[_O], dict_: _InstanceDict, row: Row + state: InstanceState[_O], dict_: _InstanceDict, row: Row[Any] ) -> None: if "callables" not in state.__dict__: state.callables = {} diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 3934de5355..8148793b12 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -28,6 +28,7 @@ from typing import Union import weakref from . import attributes # noqa +from . import exc from ._typing import _O from ._typing import insp_is_aliased_class from ._typing import insp_is_mapper @@ -41,6 +42,7 @@ from .base import InspectionAttr as InspectionAttr from .base import instance_str as instance_str from .base import object_mapper as object_mapper from .base import object_state as object_state +from .base import opt_manager_of_class from .base import state_attribute_str as state_attribute_str from .base import state_class_str as state_class_str from .base import state_str as state_str @@ -68,6 +70,7 @@ from ..sql.base import ColumnCollection from ..sql.cache_key import HasCacheKey from ..sql.cache_key import MemoizedHasCacheKey from ..sql.elements import ColumnElement +from ..sql.elements import KeyedColumnElement from ..sql.selectable import FromClause from ..util.langhelpers import MemoizedSlots from ..util.typing import de_stringify_annotation @@ -95,9 +98,7 @@ if typing.TYPE_CHECKING: from ..sql.selectable import _ColumnsClauseElement from ..sql.selectable import Alias from ..sql.selectable import Subquery - from ..sql.visitors import _ET from ..sql.visitors import anon_map - from ..sql.visitors import ExternallyTraversible _T = TypeVar("_T", bound=Any) @@ -341,7 +342,7 @@ def identity_key( ident: Union[Any, Tuple[Any, ...]] = None, *, instance: Optional[_T] = None, - row: Optional[Union[Row, RowMapping]] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, identity_token: Optional[Any] = None, ) -> _IdentityKeyType[_T]: r"""Generate "identity key" tuples, as are used as keys in the @@ -468,7 +469,9 @@ class ORMAdapter(sql_util.ColumnAdapter): return not entity or entity.isa(self.mapper) -class AliasedClass(inspection.Inspectable["AliasedInsp[_O]"], Generic[_O]): +class AliasedClass( + inspection.Inspectable["AliasedInsp[_O]"], ORMColumnsClauseRole[_O] +): r"""Represents an "aliased" form of a mapped class for usage with Query. The ORM equivalent of a :func:`~sqlalchemy.sql.expression.alias` @@ -663,7 +666,7 @@ class AliasedClass(inspection.Inspectable["AliasedInsp[_O]"], Generic[_O]): @inspection._self_inspects class AliasedInsp( - ORMEntityColumnsClauseRole, + ORMEntityColumnsClauseRole[_O], ORMFromClauseRole, HasCacheKey, InspectionAttr, @@ -1276,12 +1279,29 @@ class LoaderCriteriaOption(CriteriaOption): inspection._inspects(AliasedClass)(lambda target: target._aliased_insp) +@inspection._inspects(type) +def _inspect_mc( + class_: Type[_O], +) -> Optional[Mapper[_O]]: + + try: + class_manager = opt_manager_of_class(class_) + if class_manager is None or not class_manager.is_mapped: + return None + mapper = class_manager.mapper + except exc.NO_STATE: + + return None + else: + return mapper + + @inspection._self_inspects class Bundle( - ORMColumnsClauseRole, + ORMColumnsClauseRole[_T], SupportsCloneAnnotations, MemoizedHasCacheKey, - inspection.Inspectable["Bundle"], + inspection.Inspectable["Bundle[_T]"], InspectionAttr, ): """A grouping of SQL expressions that are returned by a :class:`.Query` @@ -1373,10 +1393,10 @@ class Bundle( @property def entity_namespace( self, - ) -> ReadOnlyColumnCollection[str, ColumnElement[Any]]: + ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: return self.c - columns: ReadOnlyColumnCollection[str, ColumnElement[Any]] + columns: ReadOnlyColumnCollection[str, KeyedColumnElement[Any]] """A namespace of SQL expressions referred to by this :class:`.Bundle`. @@ -1402,7 +1422,7 @@ class Bundle( """ - c: ReadOnlyColumnCollection[str, ColumnElement[Any]] + c: ReadOnlyColumnCollection[str, KeyedColumnElement[Any]] """An alias for :attr:`.Bundle.columns`.""" def _clone(self): @@ -1908,9 +1928,10 @@ def _extract_mapped_subtype( raw_annotation: Union[type, str], cls: type, key: str, - attr_cls: type, + attr_cls: Type[Any], required: bool, is_dataclass_field: bool, + superclasses: Optional[Tuple[Type[Any], ...]] = None, ) -> Optional[Union[type, str]]: if raw_annotation is None: @@ -1930,9 +1951,13 @@ def _extract_mapped_subtype( if is_dataclass_field: return annotated else: - if ( - not hasattr(annotated, "__origin__") - or not issubclass(annotated.__origin__, attr_cls) # type: ignore + # TODO: there don't seem to be tests for the failure + # conditions here + if not hasattr(annotated, "__origin__") or ( + not issubclass( + annotated.__origin__, # type: ignore + superclasses if superclasses else attr_cls, + ) and not issubclass(attr_cls, annotated.__origin__) # type: ignore ): our_annotated_str = ( diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index 84913225d7..c3ebb45960 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -121,7 +121,6 @@ def __go(lcls: Any) -> None: coercions.lambdas = lambdas coercions.schema = schema coercions.selectable = selectable - coercions.traversals = traversals from .annotation import _prepare_annotations from .annotation import Annotated diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py index 37d44976a2..f89e8f578d 100644 --- a/lib/sqlalchemy/sql/_selectable_constructors.py +++ b/lib/sqlalchemy/sql/_selectable_constructors.py @@ -9,12 +9,16 @@ from __future__ import annotations from typing import Any from typing import Optional +from typing import overload +from typing import Tuple from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from . import coercions from . import roles from ._typing import _ColumnsClauseArgument +from ._typing import _no_kw from .elements import ColumnClause from .selectable import Alias from .selectable import CompoundSelect @@ -34,6 +38,17 @@ if TYPE_CHECKING: from ._typing import _FromClauseArgument from ._typing import _OnClauseArgument from ._typing import _SelectStatementForCompoundArgument + from ._typing import _T0 + from ._typing import _T1 + from ._typing import _T2 + from ._typing import _T3 + from ._typing import _T4 + from ._typing import _T5 + from ._typing import _T6 + from ._typing import _T7 + from ._typing import _T8 + from ._typing import _T9 + from ._typing import _TypedColumnClauseArgument as _TCCA from .functions import Function from .selectable import CTE from .selectable import HasCTE @@ -41,6 +56,9 @@ if TYPE_CHECKING: from .selectable import SelectBase +_T = TypeVar("_T", bound=Any) + + def alias( selectable: FromClause, name: Optional[str] = None, flat: bool = False ) -> NamedFromClause: @@ -89,7 +107,9 @@ def cte( ) -def except_(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: +def except_( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return an ``EXCEPT`` of multiple selectables. The returned object is an instance of @@ -119,7 +139,7 @@ def except_all( def exists( __argument: Optional[ - Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]] + Union[_ColumnsClauseArgument[Any], SelectBase, ScalarSelect[Any]] ] = None, ) -> Exists: """Construct a new :class:`_expression.Exists` construct. @@ -162,7 +182,9 @@ def exists( return Exists(__argument) -def intersect(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: +def intersect( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return an ``INTERSECT`` of multiple selectables. The returned object is an instance of @@ -306,7 +328,129 @@ def outerjoin( return Join(left, right, onclause, isouter=True, full=full) -def select(*entities: _ColumnsClauseArgument) -> Select: +# START OVERLOADED FUNCTIONS select Select 1-10 + +# code within this block is **programmatically, +# statically generated** by tools/generate_tuple_map_overloads.py + + +@overload +def select(__ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]: + ... + + +@overload +def select(__ent0: _TCCA[_T0], __ent1: _TCCA[_T1]) -> Select[Tuple[_T0, _T1]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] +) -> Select[Tuple[_T0, _T1, _T2]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], +) -> Select[Tuple[_T0, _T1, _T2, _T3]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + __ent8: _TCCA[_T8], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8]]: + ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + __ent8: _TCCA[_T8], + __ent9: _TCCA[_T9], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9]]: + ... + + +# END OVERLOADED FUNCTIONS select + + +@overload +def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]: + ... + + +def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]: r"""Construct a new :class:`_expression.Select`. @@ -343,7 +487,11 @@ def select(*entities: _ColumnsClauseArgument) -> Select: given, as well as ORM-mapped classes. """ - + # the keyword args are a necessary element in order for the typing + # to work out w/ the varargs vs. having named "keyword" arguments that + # aren't always present. + if __kw: + raise _no_kw() return Select(*entities) @@ -425,7 +573,9 @@ def tablesample( return TableSample._factory(selectable, sampling, name=name, seed=seed) -def union(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: +def union( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return a ``UNION`` of multiple selectables. The returned object is an instance of @@ -445,7 +595,9 @@ def union(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: return CompoundSelect._create_union(*selects) -def union_all(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect: +def union_all( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return a ``UNION ALL`` of multiple selectables. The returned object is an instance of diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 53d29b628f..1df530dbd6 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -5,18 +5,27 @@ from typing import Any from typing import Callable from typing import Dict from typing import Set +from typing import Tuple from typing import Type from typing import TYPE_CHECKING from typing import TypeVar from typing import Union from . import roles +from .. import exc from .. import util from ..inspection import Inspectable from ..util.typing import Literal from ..util.typing import Protocol if TYPE_CHECKING: + from datetime import date + from datetime import datetime + from datetime import time + from datetime import timedelta + from decimal import Decimal + from uuid import UUID + from .base import Executable from .compiler import Compiled from .compiler import DDLCompiler @@ -26,17 +35,15 @@ if TYPE_CHECKING: from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement + from .elements import KeyedColumnElement from .elements import quoted_name - from .elements import SQLCoreOperations from .elements import TextClause from .lambdas import LambdaElement from .roles import ColumnsClauseRole from .roles import FromClauseRole from .schema import Column - from .schema import DefaultGenerator - from .schema import Sequence - from .schema import Table from .selectable import Alias + from .selectable import CTE from .selectable import FromClause from .selectable import Join from .selectable import NamedFromClause @@ -61,6 +68,30 @@ class _HasClauseElement(Protocol): ... +# match column types that are not ORM entities +_NOT_ENTITY = TypeVar( + "_NOT_ENTITY", + int, + str, + "datetime", + "date", + "time", + "timedelta", + "UUID", + float, + "Decimal", +) + +_MAYBE_ENTITY = TypeVar( + "_MAYBE_ENTITY", + roles.ColumnsClauseRole, + Literal["*", 1], + Type[Any], + Inspectable[_HasClauseElement], + _HasClauseElement, +) + + # convention: # XYZArgument - something that the end user is passing to a public API method # XYZElement - the internal representation that we use for the thing. @@ -76,9 +107,10 @@ _TextCoercedExpressionArgument = Union[ ] _ColumnsClauseArgument = Union[ - Literal["*", 1], + roles.TypedColumnsClauseRole[_T], roles.ColumnsClauseRole, - Type[Any], + Literal["*", 1], + Type[_T], Inspectable[_HasClauseElement], _HasClauseElement, ] @@ -92,6 +124,24 @@ sets; select(...), insert().returning(...), etc. """ +_TypedColumnClauseArgument = Union[ + roles.TypedColumnsClauseRole[_T], roles.ExpressionElementRole[_T], Type[_T] +] + +_TP = TypeVar("_TP", bound=Tuple[Any, ...]) + +_T0 = TypeVar("_T0", bound=Any) +_T1 = TypeVar("_T1", bound=Any) +_T2 = TypeVar("_T2", bound=Any) +_T3 = TypeVar("_T3", bound=Any) +_T4 = TypeVar("_T4", bound=Any) +_T5 = TypeVar("_T5", bound=Any) +_T6 = TypeVar("_T6", bound=Any) +_T7 = TypeVar("_T7", bound=Any) +_T8 = TypeVar("_T8", bound=Any) +_T9 = TypeVar("_T9", bound=Any) + + _ColumnExpressionArgument = Union[ "ColumnElement[_T]", _HasClauseElement, @@ -169,6 +219,7 @@ _DMLTableArgument = Union[ "TableClause", "Join", "Alias", + "CTE", Type[Any], Inspectable[_HasClauseElement], _HasClauseElement, @@ -194,6 +245,11 @@ if TYPE_CHECKING: def is_column_element(c: ClauseElement) -> TypeGuard[ColumnElement[Any]]: ... + def is_keyed_column_element( + c: ClauseElement, + ) -> TypeGuard[KeyedColumnElement[Any]]: + ... + def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]: ... @@ -216,7 +272,7 @@ if TYPE_CHECKING: def is_select_statement( t: Union[Executable, ReturnsRows] - ) -> TypeGuard[Select]: + ) -> TypeGuard[Select[Any]]: ... def is_table(t: FromClause) -> TypeGuard[TableClause]: @@ -234,6 +290,7 @@ else: is_ddl_compiler = operator.attrgetter("is_ddl") is_named_from_clause = operator.attrgetter("named_with_column") is_column_element = operator.attrgetter("_is_column_element") + is_keyed_column_element = operator.attrgetter("_is_keyed_column_element") is_text_clause = operator.attrgetter("_is_text_clause") is_from_clause = operator.attrgetter("_is_from_clause") is_tuple_type = operator.attrgetter("_is_tuple_type") @@ -260,3 +317,10 @@ def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement]: def is_insert_update(c: ClauseElement) -> TypeGuard[ValuesBase]: return c.is_dml and (c.is_insert or c.is_update) # type: ignore + + +def _no_kw() -> exc.ArgumentError: + return exc.ArgumentError( + "Additional keyword arguments are not accepted by this " + "function/method. The presence of **kw is for pep-484 typing purposes" + ) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index f81878d55d..790edefc6e 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -62,10 +62,10 @@ if TYPE_CHECKING: from . import coercions from . import elements from . import type_api - from ._typing import _ColumnsClauseArgument from .elements import BindParameter - from .elements import ColumnClause + from .elements import ColumnClause # noqa from .elements import ColumnElement + from .elements import KeyedColumnElement from .elements import NamedColumn from .elements import SQLCoreOperations from .elements import TextClause @@ -74,7 +74,6 @@ if TYPE_CHECKING: from .selectable import FromClause from ..engine import Connection from ..engine import CursorResult - from ..engine import Result from ..engine.base import _CompiledCacheType from ..engine.interfaces import _CoreMultiExecuteParams from ..engine.interfaces import _ExecuteOptions @@ -704,8 +703,11 @@ class InPlaceGenerative(HasMemoized): """Provide a method-chaining pattern in conjunction with the @_generative decorator that mutates in place.""" + __slots__ = () + def _generate(self): skip = self._memoized_keys + # note __dict__ needs to be in __slots__ if this is used for k in skip: self.__dict__.pop(k, None) return self @@ -937,7 +939,7 @@ class ExecutableOption(HasCopyInternals): SelfExecutable = TypeVar("SelfExecutable", bound="Executable") -class Executable(roles.StatementRole, Generative): +class Executable(roles.StatementRole): """Mark a :class:`_expression.ClauseElement` as supporting execution. :class:`.Executable` is a superclass for all "statement" types @@ -994,7 +996,7 @@ class Executable(roles.StatementRole, Generative): connection: Connection, distilled_params: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter, - ) -> CursorResult: + ) -> CursorResult[Any]: ... def _execute_on_scalar( @@ -1253,7 +1255,7 @@ class SchemaVisitor(ClauseVisitor): _COLKEY = TypeVar("_COLKEY", Union[None, str], str) _COL_co = TypeVar("_COL_co", bound="ColumnElement[Any]", covariant=True) -_COL = TypeVar("_COL", bound="ColumnElement[Any]") +_COL = TypeVar("_COL", bound="KeyedColumnElement[Any]") class ColumnCollection(Generic[_COLKEY, _COL_co]): @@ -1505,6 +1507,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): ) -> None: """populate from an iterator of (key, column)""" cols = list(iter_) + self._collection[:] = cols self._colset.update(c for k, c in self._collection) self._index.update( diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 0659709ab4..9b7231360e 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -29,6 +29,7 @@ from typing import Union from . import operators from . import roles from . import visitors +from ._typing import is_from_clause from .base import ExecutableOption from .base import Options from .cache_key import HasCacheKey @@ -38,25 +39,18 @@ from .. import inspection from .. import util from ..util.typing import Literal -if not typing.TYPE_CHECKING: - elements = None - lambdas = None - schema = None - selectable = None - traversals = None - if typing.TYPE_CHECKING: from . import elements from . import lambdas from . import schema from . import selectable - from . import traversals from ._typing import _ColumnExpressionArgument from ._typing import _ColumnsClauseArgument from ._typing import _DDLColumnArgument from ._typing import _DMLTableArgument from ._typing import _FromClauseArgument from .dml import _DMLTableElement + from .elements import BindParameter from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement @@ -64,9 +58,7 @@ if typing.TYPE_CHECKING: from .elements import SQLCoreOperations from .schema import Column from .selectable import _ColumnsClauseElement - from .selectable import _JoinTargetElement from .selectable import _JoinTargetProtocol - from .selectable import _OnClauseElement from .selectable import FromClause from .selectable import HasCTE from .selectable import SelectBase @@ -168,6 +160,15 @@ def expect( ... +@overload +def expect( + role: Type[roles.LiteralValueRole], + element: Any, + **kw: Any, +) -> BindParameter[Any]: + ... + + @overload def expect( role: Type[roles.DDLReferredColumnRole], @@ -272,7 +273,7 @@ def expect( @overload def expect( role: Type[roles.ColumnsClauseRole], - element: _ColumnsClauseArgument, + element: _ColumnsClauseArgument[Any], **kw: Any, ) -> _ColumnsClauseElement: ... @@ -933,7 +934,7 @@ class GroupByImpl(ByOfImpl, RoleImpl): argname: Optional[str] = None, **kw: Any, ) -> Any: - if isinstance(resolved, roles.StrictFromClauseRole): + if is_from_clause(resolved): return elements.ClauseList(*resolved.c) else: return resolved diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index c524a2602c..a1b25b8a6b 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -80,7 +80,6 @@ from ..util.typing import Protocol from ..util.typing import TypedDict if typing.TYPE_CHECKING: - from . import roles from .annotation import _AnnotationDict from .base import _AmbiguousTableNameMap from .base import CompileState @@ -95,7 +94,6 @@ if typing.TYPE_CHECKING: from .elements import ColumnElement from .elements import Label from .functions import Function - from .selectable import Alias from .selectable import AliasedReturnsRows from .selectable import CompoundSelectState from .selectable import CTE @@ -386,7 +384,7 @@ class _CompilerStackEntry(_BaseCompilerStackEntry, total=False): need_result_map_for_nested: bool need_result_map_for_compound: bool select_0: ReturnsRows - insert_from_select: Select + insert_from_select: Select[Any] class ExpandedState(NamedTuple): @@ -2834,15 +2832,31 @@ class SQLCompiler(Compiled): "unique bind parameter of the same name" % name ) elif existing._is_crud or bindparam._is_crud: - raise exc.CompileError( - "bindparam() name '%s' is reserved " - "for automatic usage in the VALUES or SET " - "clause of this " - "insert/update statement. Please use a " - "name other than column name when using bindparam() " - "with insert() or update() (for example, 'b_%s')." - % (bindparam.key, bindparam.key) - ) + if existing._is_crud and bindparam._is_crud: + # TODO: this condition is not well understood. + # see tests in test/sql/test_update.py + raise exc.CompileError( + "Encountered unsupported case when compiling an " + "INSERT or UPDATE statement. If this is a " + "multi-table " + "UPDATE statement, please provide string-named " + "arguments to the " + "values() method with distinct names; support for " + "multi-table UPDATE statements that " + "target multiple tables for UPDATE is very " + "limited", + ) + else: + raise exc.CompileError( + f"bindparam() name '{bindparam.key}' is reserved " + "for automatic usage in the VALUES or SET " + "clause of this " + "insert/update statement. Please use a " + "name other than column name when using " + "bindparam() " + "with insert() or update() (for example, " + f"'b_{bindparam.key}')." + ) self.binds[bindparam.key] = self.binds[name] = bindparam @@ -3881,7 +3895,7 @@ class SQLCompiler(Compiled): return text def _setup_select_hints( - self, select: Select + self, select: Select[Any] ) -> Tuple[str, _FromHintsType]: byfrom = dict( [ diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index e4408cd316..29d7b45d7a 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -22,6 +22,7 @@ from typing import MutableMapping from typing import NamedTuple from typing import Optional from typing import overload +from typing import Sequence from typing import Tuple from typing import TYPE_CHECKING from typing import Union @@ -30,8 +31,10 @@ from . import coercions from . import dml from . import elements from . import roles +from .elements import ColumnClause from .schema import default_is_clause_element from .schema import default_is_sequence +from .selectable import TableClause from .. import exc from .. import util from ..util.typing import Literal @@ -41,16 +44,9 @@ if TYPE_CHECKING: from .compiler import SQLCompiler from .dml import _DMLColumnElement from .dml import DMLState - from .dml import Insert - from .dml import Update - from .dml import UpdateDMLState from .dml import ValuesBase - from .elements import ClauseElement - from .elements import ColumnClause from .elements import ColumnElement - from .elements import TextClause from .schema import _SQLExprDefault - from .schema import Column from .selectable import TableClause REQUIRED = util.symbol( @@ -68,12 +64,20 @@ values present. ) +def _as_dml_column(c: ColumnElement[Any]) -> ColumnClause[Any]: + if not isinstance(c, ColumnClause): + raise exc.CompileError( + f"Can't create DML statement against column expression {c!r}" + ) + return c + + class _CrudParams(NamedTuple): - single_params: List[ - Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]] + single_params: Sequence[ + Tuple[ColumnElement[Any], str, Optional[Union[str, _SQLExprDefault]]] ] all_multi_params: List[ - List[ + Sequence[ Tuple[ ColumnClause[Any], str, @@ -274,7 +278,7 @@ def _get_crud_params( compiler, stmt, compile_state, - cast("List[Tuple[ColumnClause[Any], str, str]]", values), + cast("Sequence[Tuple[ColumnClause[Any], str, str]]", values), cast("Callable[..., str]", _column_as_key), kw, ) @@ -290,7 +294,7 @@ def _get_crud_params( # insert_executemany_returning mode :) values = [ ( - stmt.table.columns[0], + _as_dml_column(stmt.table.columns[0]), compiler.preparer.format_column(stmt.table.columns[0]), "DEFAULT", ) @@ -1135,10 +1139,10 @@ def _extend_values_for_multiparams( compiler: SQLCompiler, stmt: ValuesBase, compile_state: DMLState, - initial_values: List[Tuple[ColumnClause[Any], str, str]], + initial_values: Sequence[Tuple[ColumnClause[Any], str, str]], _column_as_key: Callable[..., str], kw: Dict[str, Any], -) -> List[List[Tuple[ColumnClause[Any], str, str]]]: +) -> List[Sequence[Tuple[ColumnClause[Any], str, str]]]: values_0 = initial_values values = [initial_values] diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 8307f64003..e0f162fc86 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -22,15 +22,19 @@ from typing import List from typing import MutableMapping from typing import NoReturn from typing import Optional +from typing import overload from typing import Sequence from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from . import coercions from . import roles from . import util as sql_util +from ._typing import _no_kw +from ._typing import _TP from ._typing import is_column_element from ._typing import is_named_from_clause from .base import _entity_namespace_key @@ -42,6 +46,7 @@ from .base import ColumnCollection from .base import CompileState from .base import DialectKWArgs from .base import Executable +from .base import Generative from .base import HasCompileState from .elements import BooleanClauseList from .elements import ClauseElement @@ -49,12 +54,13 @@ from .elements import ColumnClause from .elements import ColumnElement from .elements import Null from .selectable import Alias +from .selectable import ExecutableReturnsRows from .selectable import FromClause from .selectable import HasCTE from .selectable import HasPrefixes from .selectable import Join -from .selectable import ReturnsRows from .selectable import TableClause +from .selectable import TypedReturnsRows from .sqltypes import NullType from .visitors import InternalTraversal from .. import exc @@ -66,9 +72,19 @@ if TYPE_CHECKING: from ._typing import _ColumnsClauseArgument from ._typing import _DMLColumnArgument from ._typing import _DMLTableArgument - from ._typing import _FromClauseArgument + from ._typing import _T0 # noqa + from ._typing import _T1 # noqa + from ._typing import _T2 # noqa + from ._typing import _T3 # noqa + from ._typing import _T4 # noqa + from ._typing import _T5 # noqa + from ._typing import _T6 # noqa + from ._typing import _T7 # noqa + from ._typing import _TypedColumnClauseArgument as _TCCA # noqa from .base import ReadOnlyColumnCollection from .compiler import SQLCompiler + from .elements import ColumnElement + from .elements import KeyedColumnElement from .selectable import _ColumnsClauseElement from .selectable import _SelectIterable from .selectable import Select @@ -88,6 +104,8 @@ else: isinsert = operator.attrgetter("isinsert") +_T = TypeVar("_T", bound=Any) + _DMLColumnElement = Union[str, ColumnClause[Any]] _DMLTableElement = Union[TableClause, Alias, Join] @@ -185,6 +203,11 @@ class DMLState(CompileState): "%s construct does not support " "multiple parameter sets." % statement.__visit_name__.upper() ) + else: + assert isinstance(statement, Insert) + + # which implies... + # assert isinstance(statement.table, TableClause) for parameters in statement._multi_values: multi_parameters: List[MutableMapping[_DMLColumnElement, Any]] = [ @@ -291,7 +314,9 @@ class UpdateDMLState(DMLState): elif statement._multi_values: self._process_multi_values(statement) self._extra_froms = ef = self._make_extra_froms(statement) - self.is_multitable = mt = ef and self._dict_parameters + + self.is_multitable = mt = ef + self.include_table_with_column_exprs = bool( mt and compiler.render_table_with_column_in_update_from ) @@ -317,8 +342,8 @@ class UpdateBase( HasCompileState, DialectKWArgs, HasPrefixes, - ReturnsRows, - Executable, + Generative, + ExecutableReturnsRows, ClauseElement, ): """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.""" @@ -383,8 +408,8 @@ class UpdateBase( @_generative def returning( - self: SelfUpdateBase, *cols: _ColumnsClauseArgument - ) -> SelfUpdateBase: + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> UpdateBase: r"""Add a :term:`RETURNING` or equivalent clause to this statement. e.g.: @@ -454,6 +479,8 @@ class UpdateBase( :ref:`tutorial_insert_returning` - in the :ref:`unified_tutorial` """ # noqa: E501 + if __kw: + raise _no_kw() if self._return_defaults: raise exc.InvalidRequestError( "return_defaults() is already configured on this statement" @@ -464,7 +491,7 @@ class UpdateBase( return self def corresponding_column( - self, column: ColumnElement[Any], require_embedded: bool = False + self, column: KeyedColumnElement[Any], require_embedded: bool = False ) -> Optional[ColumnElement[Any]]: return self.exported_columns.corresponding_column( column, require_embedded=require_embedded @@ -628,7 +655,7 @@ class ValuesBase(UpdateBase): _supports_multi_parameters = False - select: Optional[Select] = None + select: Optional[Select[Any]] = None """SELECT statement for INSERT .. FROM SELECT""" _post_values_clause: Optional[ClauseElement] = None @@ -804,11 +831,15 @@ class ValuesBase(UpdateBase): ) elif isinstance(arg, collections_abc.Sequence): - if arg and isinstance(arg[0], (list, dict, tuple)): self._multi_values += (arg,) return self + if TYPE_CHECKING: + # crud.py raises during compilation if this is not the + # case + assert isinstance(self, Insert) + # tuple values arg = {c.key: value for c, value in zip(self.table.c, arg)} @@ -1010,7 +1041,7 @@ class Insert(ValuesBase): def from_select( self: SelfInsert, names: List[str], - select: Select, + select: Select[Any], include_defaults: bool = True, ) -> SelfInsert: """Return a new :class:`_expression.Insert` construct which represents @@ -1073,6 +1104,114 @@ class Insert(ValuesBase): self.select = coercions.expect(roles.DMLSelectRole, select) return self + if TYPE_CHECKING: + + # START OVERLOADED FUNCTIONS self.returning ReturningInsert 1-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def returning(self, __ent0: _TCCA[_T0]) -> ReturningInsert[Tuple[_T0]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> ReturningInsert[Tuple[_T0, _T1]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> ReturningInsert[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.returning + + @overload + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningInsert[Any]: + ... + + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningInsert[Any]: + ... + + +class ReturningInsert(Insert, TypedReturnsRows[_TP]): + """Typing-only class that establishes a generic type form of + :class:`.Insert` which tracks returned column types. + + This datatype is delivered when calling the + :meth:`.Insert.returning` method. + + .. versionadded:: 2.0 + + """ + SelfDMLWhereBase = typing.TypeVar("SelfDMLWhereBase", bound="DMLWhereBase") @@ -1264,6 +1403,113 @@ class Update(DMLWhereBase, ValuesBase): self._inline = True return self + if TYPE_CHECKING: + # START OVERLOADED FUNCTIONS self.returning ReturningUpdate 1-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def returning(self, __ent0: _TCCA[_T0]) -> ReturningUpdate[Tuple[_T0]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> ReturningUpdate[Tuple[_T0, _T1]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.returning + + @overload + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningUpdate[Any]: + ... + + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningUpdate[Any]: + ... + + +class ReturningUpdate(Update, TypedReturnsRows[_TP]): + """Typing-only class that establishes a generic type form of + :class:`.Update` which tracks returned column types. + + This datatype is delivered when calling the + :meth:`.Update.returning` method. + + .. versionadded:: 2.0 + + """ + SelfDelete = typing.TypeVar("SelfDelete", bound="Delete") @@ -1297,3 +1543,111 @@ class Delete(DMLWhereBase, UpdateBase): self.table = coercions.expect( roles.DMLTableRole, table, apply_propagate_attrs=self ) + + if TYPE_CHECKING: + + # START OVERLOADED FUNCTIONS self.returning ReturningDelete 1-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def returning(self, __ent0: _TCCA[_T0]) -> ReturningDelete[Tuple[_T0]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> ReturningDelete[Tuple[_T0, _T1]]: + ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> ReturningDelete[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.returning + + @overload + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningDelete[Any]: + ... + + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningDelete[Any]: + ... + + +class ReturningDelete(Update, TypedReturnsRows[_TP]): + """Typing-only class that establishes a generic type form of + :class:`.Delete` which tracks returned column types. + + This datatype is delivered when calling the + :meth:`.Delete.returning` method. + + .. versionadded:: 2.0 + + """ diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 34d5127ab7..a295612918 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -54,6 +54,7 @@ from .base import _clone from .base import _generative from .base import _NoArg from .base import Executable +from .base import Generative from .base import HasMemoized from .base import Immutable from .base import NO_ARG @@ -94,10 +95,7 @@ if typing.TYPE_CHECKING: from .selectable import _SelectIterable from .selectable import FromClause from .selectable import NamedFromClause - from .selectable import ReturnsRows from .selectable import Select - from .selectable import TableClause - from .sqltypes import Boolean from .sqltypes import TupleType from .type_api import TypeEngine from .visitors import _CloneCallableType @@ -122,7 +120,9 @@ _NT = TypeVar("_NT", bound="_NUMERIC") _NMT = TypeVar("_NMT", bound="_NUMBER") -def literal(value, type_=None): +def literal( + value: Any, type_: Optional[_TypeEngineArgument[_T]] = None +) -> BindParameter[_T]: r"""Return a literal clause, bound to a bind parameter. Literal clauses are created automatically when non- @@ -144,7 +144,9 @@ def literal(value, type_=None): return coercions.expect(roles.LiteralValueRole, value, type_=type_) -def literal_column(text, type_=None): +def literal_column( + text: str, type_: Optional[_TypeEngineArgument[_T]] = None +) -> ColumnClause[_T]: r"""Produce a :class:`.ColumnClause` object that has the :paramref:`_expression.column.is_literal` flag set to True. @@ -316,6 +318,7 @@ class ClauseElement( is_selectable = False is_dml = False _is_column_element = False + _is_keyed_column_element = False _is_table = False _is_textual = False _is_from_clause = False @@ -342,7 +345,7 @@ class ClauseElement( if typing.TYPE_CHECKING: def get_children( - self, omit_attrs: typing_Tuple[str, ...] = ..., **kw: Any + self, *, omit_attrs: typing_Tuple[str, ...] = ..., **kw: Any ) -> Iterable[ClauseElement]: ... @@ -455,7 +458,7 @@ class ClauseElement( connection: Connection, distilled_params: _CoreMultiExecuteParams, execution_options: _ExecuteOptions, - ) -> Result: + ) -> Result[Any]: if self.supports_execution: if TYPE_CHECKING: assert isinstance(self, Executable) @@ -833,13 +836,13 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): def in_( self, - other: Union[Sequence[Any], BindParameter[Any], Select], + other: Union[Sequence[Any], BindParameter[Any], Select[Any]], ) -> BinaryExpression[bool]: ... def not_in( self, - other: Union[Sequence[Any], BindParameter[Any], Select], + other: Union[Sequence[Any], BindParameter[Any], Select[Any]], ) -> BinaryExpression[bool]: ... @@ -1699,6 +1702,14 @@ class ColumnElement( return self._anon_label(label, add_hash=idx) +class KeyedColumnElement(ColumnElement[_T]): + """ColumnElement where ``.key`` is non-None.""" + + _is_keyed_column_element = True + + key: str + + class WrapsColumnExpression(ColumnElement[_T]): """Mixin that defines a :class:`_expression.ColumnElement` as a wrapper with special @@ -1760,7 +1771,7 @@ class WrapsColumnExpression(ColumnElement[_T]): SelfBindParameter = TypeVar("SelfBindParameter", bound="BindParameter[Any]") -class BindParameter(roles.InElementRole, ColumnElement[_T]): +class BindParameter(roles.InElementRole, KeyedColumnElement[_T]): r"""Represent a "bound expression". :class:`.BindParameter` is invoked explicitly using the @@ -2073,6 +2084,7 @@ class TextClause( roles.FromClauseRole, roles.SelectStatementRole, roles.InElementRole, + Generative, Executable, DQLDMLClauseElement, roles.BinaryElementRole[Any], @@ -4160,7 +4172,7 @@ class FunctionFilter(ColumnElement[_T]): ) -class NamedColumn(ColumnElement[_T]): +class NamedColumn(KeyedColumnElement[_T]): is_literal = False table: Optional[FromClause] = None name: str @@ -4502,7 +4514,7 @@ class ColumnClause( self.is_literal = is_literal - def get_children(self, column_tables=False, **kw): + def get_children(self, *, column_tables=False, **kw): # override base get_children() to not return the Table # or selectable that is parent to this column. Traversals # expect the columns of tables and subqueries to be leaf nodes. diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 6481682355..b827df3df8 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -175,7 +175,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): connection: Connection, distilled_params: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter, - ) -> CursorResult: + ) -> CursorResult[Any]: return connection._execute_function( self, distilled_params, execution_options ) @@ -623,7 +623,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): joins_implicitly=joins_implicitly, ) - def select(self) -> "Select": + def select(self) -> Select[Any]: """Produce a :func:`_expression.select` construct against this :class:`.FunctionElement`. @@ -632,7 +632,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): s = select(function_element) """ - s = Select(self) + s: Select[Any] = Select(self) if self._execution_options: s = s.execution_options(**self._execution_options) return s @@ -846,7 +846,7 @@ class _FunctionGenerator: @overload def __call__( - self, *c: Any, type_: TypeEngine[_T], **kwargs: Any + self, *c: Any, type_: _TypeEngineArgument[_T], **kwargs: Any ) -> Function[_T]: ... diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 231c70a5ba..09d4b35ad0 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -8,8 +8,6 @@ from __future__ import annotations from typing import Any from typing import Generic -from typing import Iterable -from typing import List from typing import Optional from typing import TYPE_CHECKING from typing import TypeVar @@ -19,12 +17,7 @@ from ..util.typing import Literal if TYPE_CHECKING: from ._typing import _PropagateAttrsType - from .base import _EntityNamespace - from .base import ColumnCollection - from .base import ReadOnlyColumnCollection - from .elements import ColumnClause from .elements import Label - from .elements import NamedColumn from .selectable import _SelectIterable from .selectable import FromClause from .selectable import Subquery @@ -108,13 +101,21 @@ class TruncatedLabelRole(StringRole, SQLRole): class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole): __slots__ = () - _role_name = "Column expression or FROM clause" + _role_name = ( + "Column expression, FROM clause, or other columns clause element" + ) @property def _select_iterable(self) -> _SelectIterable: raise NotImplementedError() +class TypedColumnsClauseRole(Generic[_T], SQLRole): + """element-typed form of ColumnsClauseRole""" + + __slots__ = () + + class LimitOffsetRole(SQLRole): __slots__ = () _role_name = "LIMIT / OFFSET expression" @@ -161,7 +162,7 @@ class WhereHavingRole(OnClauseRole): _role_name = "SQL expression for WHERE/HAVING role" -class ExpressionElementRole(Generic[_T], SQLRole): +class ExpressionElementRole(TypedColumnsClauseRole[_T]): # note when using generics for ExpressionElementRole, # the generic type needs to be in # sqlalchemy.sql.coercions._impl_lookup mapping also. @@ -212,39 +213,11 @@ class FromClauseRole(ColumnsClauseRole, JoinTargetRole): named_with_column: bool - if TYPE_CHECKING: - - @util.ro_non_memoized_property - def c(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: - ... - - @util.ro_non_memoized_property - def columns(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: - ... - - @util.ro_non_memoized_property - def entity_namespace(self) -> _EntityNamespace: - ... - - @util.ro_non_memoized_property - def _hide_froms(self) -> Iterable[FromClause]: - ... - - @util.ro_non_memoized_property - def _from_objects(self) -> List[FromClause]: - ... - class StrictFromClauseRole(FromClauseRole): __slots__ = () # does not allow text() or select() objects - if TYPE_CHECKING: - - @util.ro_non_memoized_property - def description(self) -> str: - ... - class AnonymizedFromClauseRole(StrictFromClauseRole): __slots__ = () @@ -317,16 +290,6 @@ class DMLTableRole(FromClauseRole): __slots__ = () _role_name = "subject table for an INSERT, UPDATE or DELETE" - if TYPE_CHECKING: - - @util.ro_non_memoized_property - def primary_key(self) -> Iterable[NamedColumn[Any]]: - ... - - @util.ro_non_memoized_property - def columns(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: - ... - class DMLColumnRole(SQLRole): __slots__ = () diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 52ba60a62c..27456d2be4 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -86,7 +86,6 @@ if typing.TYPE_CHECKING: from ._typing import _InfoType from ._typing import _TextCoercedExpressionArgument from ._typing import _TypeEngineArgument - from .base import ColumnCollection from .base import DedupeColumnCollection from .base import ReadOnlyColumnCollection from .compiler import DDLCompiler @@ -97,9 +96,7 @@ if typing.TYPE_CHECKING: from .visitors import anon_map from ..engine import Connection from ..engine import Engine - from ..engine.cursor import CursorResult from ..engine.interfaces import _CoreMultiExecuteParams - from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.interfaces import _ExecuteOptionsParameter from ..engine.interfaces import ExecutionContext from ..engine.mock import MockConnection @@ -2609,8 +2606,10 @@ class ForeignKey(DialectKWArgs, SchemaItem): :class:`_schema.Table`. """ - - return table.columns.corresponding_column(self.column) + # our column is a Column, and any subquery etc. proxying us + # would be doing so via another Column, so that's what would + # be returned here + return table.columns.corresponding_column(self.column) # type: ignore @util.memoized_property def _column_tokens(self) -> Tuple[Optional[str], str, Optional[str]]: diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 9d4d1d6c79..b08f13f993 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -23,6 +23,7 @@ from typing import Any from typing import Callable from typing import cast from typing import Dict +from typing import Generic from typing import Iterable from typing import Iterator from typing import List @@ -46,6 +47,8 @@ from . import traversals from . import type_api from . import visitors from ._typing import _ColumnsClauseArgument +from ._typing import _no_kw +from ._typing import _TP from ._typing import is_column_element from ._typing import is_select_statement from ._typing import is_subquery @@ -103,9 +106,20 @@ if TYPE_CHECKING: from ._typing import _ColumnExpressionArgument from ._typing import _FromClauseArgument from ._typing import _JoinTargetArgument + from ._typing import _MAYBE_ENTITY + from ._typing import _NOT_ENTITY from ._typing import _OnClauseArgument from ._typing import _SelectStatementForCompoundArgument + from ._typing import _T0 + from ._typing import _T1 + from ._typing import _T2 + from ._typing import _T3 + from ._typing import _T4 + from ._typing import _T5 + from ._typing import _T6 + from ._typing import _T7 from ._typing import _TextCoercedExpressionArgument + from ._typing import _TypedColumnClauseArgument as _TCCA from ._typing import _TypeEngineArgument from .base import _AmbiguousTableNameMap from .base import ExecutableOption @@ -115,14 +129,13 @@ if TYPE_CHECKING: from .dml import Delete from .dml import Insert from .dml import Update + from .elements import KeyedColumnElement from .elements import NamedColumn from .elements import TextClause from .functions import Function - from .schema import Column from .schema import ForeignKey from .schema import ForeignKeyConstraint from .type_api import TypeEngine - from .util import ClauseAdapter from .visitors import _CloneCallableType @@ -245,6 +258,14 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement): raise NotImplementedError() +class ExecutableReturnsRows(Executable, ReturnsRows): + """base for executable statements that return rows.""" + + +class TypedReturnsRows(ExecutableReturnsRows, Generic[_TP]): + """base for executable statements that return rows.""" + + SelfSelectable = TypeVar("SelfSelectable", bound="Selectable") @@ -293,8 +314,8 @@ class Selectable(ReturnsRows): ) def corresponding_column( - self, column: ColumnElement[Any], require_embedded: bool = False - ) -> Optional[ColumnElement[Any]]: + self, column: KeyedColumnElement[Any], require_embedded: bool = False + ) -> Optional[KeyedColumnElement[Any]]: """Given a :class:`_expression.ColumnElement`, return the exported :class:`_expression.ColumnElement` object from the :attr:`_expression.Selectable.exported_columns` @@ -593,7 +614,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): _use_schema_map = False - def select(self) -> Select: + def select(self) -> Select[Any]: r"""Return a SELECT of this :class:`_expression.FromClause`. @@ -795,7 +816,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): ) @util.ro_non_memoized_property - def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]: + def exported_columns( + self, + ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: """A :class:`_expression.ColumnCollection` that represents the "exported" columns of this :class:`_expression.Selectable`. @@ -817,7 +840,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return self.c @util.ro_non_memoized_property - def columns(self) -> ReadOnlyColumnCollection[str, Any]: + def columns( + self, + ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: """A named-based collection of :class:`_expression.ColumnElement` objects maintained by this :class:`_expression.FromClause`. @@ -833,7 +858,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return self.c @util.ro_memoized_property - def c(self) -> ReadOnlyColumnCollection[str, Any]: + def c(self) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: """ A synonym for :attr:`.FromClause.columns` @@ -1223,7 +1248,7 @@ class Join(roles.DMLTableRole, FromClause): @util.preload_module("sqlalchemy.sql.util") def _populate_column_collection(self): sqlutil = util.preloaded.sql_util - columns: List[ColumnClause[Any]] = [c for c in self.left.c] + [ + columns: List[KeyedColumnElement[Any]] = [c for c in self.left.c] + [ c for c in self.right.c ] @@ -1458,7 +1483,7 @@ class Join(roles.DMLTableRole, FromClause): "join explicitly." % (a.description, b.description) ) - def select(self) -> "Select": + def select(self) -> Select[Any]: r"""Create a :class:`_expression.Select` from this :class:`_expression.Join`. @@ -2764,6 +2789,7 @@ class Subquery(AliasedReturnsRows): cls, selectable: SelectBase, name: Optional[str] = None ) -> Subquery: """Return a :class:`.Subquery` object.""" + return coercions.expect( roles.SelectStatementRole, selectable ).subquery(name=name) @@ -3216,7 +3242,6 @@ class SelectBase( roles.CompoundElementRole, roles.InElementRole, HasCTE, - Executable, SupportsCloneAnnotations, Selectable, ): @@ -3239,7 +3264,9 @@ class SelectBase( self._reset_memoizations() @util.ro_non_memoized_property - def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: + def selected_columns( + self, + ) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set. @@ -3284,7 +3311,9 @@ class SelectBase( raise NotImplementedError() @property - def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]: + def exported_columns( + self, + ) -> ReadOnlyColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` that represents the "exported" columns of this :class:`_expression.Selectable`, not including @@ -3377,7 +3406,7 @@ class SelectBase( def as_scalar(self): return self.scalar_subquery() - def exists(self): + def exists(self) -> Exists: """Return an :class:`_sql.Exists` representation of this selectable, which can be used as a column expression. @@ -3394,7 +3423,7 @@ class SelectBase( """ return Exists(self) - def scalar_subquery(self): + def scalar_subquery(self) -> ScalarSelect[Any]: """Return a 'scalar' representation of this selectable, which can be used as a column expression. @@ -3607,7 +3636,7 @@ SelfGenerativeSelect = typing.TypeVar( ) -class GenerativeSelect(SelectBase): +class GenerativeSelect(SelectBase, Generative): """Base class for SELECT statements where additional elements can be added. @@ -4128,7 +4157,7 @@ class _CompoundSelectKeyword(Enum): INTERSECT_ALL = "INTERSECT ALL" -class CompoundSelect(HasCompileState, GenerativeSelect): +class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): """Forms the basis of ``UNION``, ``UNION ALL``, and other SELECT-based set operations. @@ -4293,7 +4322,9 @@ class CompoundSelect(HasCompileState, GenerativeSelect): return self.selects[0]._all_selected_columns @util.ro_non_memoized_property - def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: + def selected_columns( + self, + ) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set, @@ -4343,7 +4374,10 @@ class SelectState(util.MemoizedSlots, CompileState): ... def __init__( - self, statement: Select, compiler: Optional[SQLCompiler], **kw: Any + self, + statement: Select[Any], + compiler: Optional[SQLCompiler], + **kw: Any, ): self.statement = statement self.from_clauses = statement._from_obj @@ -4369,7 +4403,7 @@ class SelectState(util.MemoizedSlots, CompileState): @classmethod def get_column_descriptions( - cls, statement: Select + cls, statement: Select[Any] ) -> List[Dict[str, Any]]: return [ { @@ -4384,12 +4418,14 @@ class SelectState(util.MemoizedSlots, CompileState): @classmethod def from_statement( - cls, statement: Select, from_statement: ReturnsRows - ) -> Any: + cls, statement: Select[Any], from_statement: ExecutableReturnsRows + ) -> ExecutableReturnsRows: cls._plugin_not_implemented() @classmethod - def get_columns_clause_froms(cls, statement: Select) -> List[FromClause]: + def get_columns_clause_froms( + cls, statement: Select[Any] + ) -> List[FromClause]: return cls._normalize_froms( itertools.chain.from_iterable( element._from_objects for element in statement._raw_columns @@ -4439,7 +4475,7 @@ class SelectState(util.MemoizedSlots, CompileState): return go - def _get_froms(self, statement: Select) -> List[FromClause]: + def _get_froms(self, statement: Select[Any]) -> List[FromClause]: ambiguous_table_name_map: _AmbiguousTableNameMap self._ambiguous_table_name_map = ambiguous_table_name_map = {} @@ -4467,7 +4503,7 @@ class SelectState(util.MemoizedSlots, CompileState): def _normalize_froms( cls, iterable_of_froms: Iterable[FromClause], - check_statement: Optional[Select] = None, + check_statement: Optional[Select[Any]] = None, ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None, ) -> List[FromClause]: """given an iterable of things to select FROM, reduce them to what @@ -4615,7 +4651,7 @@ class SelectState(util.MemoizedSlots, CompileState): @classmethod def determine_last_joined_entity( - cls, stmt: Select + cls, stmt: Select[Any] ) -> Optional[_JoinTargetElement]: if stmt._setup_joins: return stmt._setup_joins[-1][0] @@ -4623,7 +4659,7 @@ class SelectState(util.MemoizedSlots, CompileState): return None @classmethod - def all_selected_columns(cls, statement: Select) -> _SelectIterable: + def all_selected_columns(cls, statement: Select[Any]) -> _SelectIterable: return [c for c in _select_iterables(statement._raw_columns)] def _setup_joins( @@ -4876,7 +4912,7 @@ class _MemoizedSelectEntities( return c # type: ignore @classmethod - def _generate_for_statement(cls, select_stmt: Select) -> None: + def _generate_for_statement(cls, select_stmt: Select[Any]) -> None: if select_stmt._setup_joins or select_stmt._with_options: self = _MemoizedSelectEntities() self._raw_columns = select_stmt._raw_columns @@ -4888,7 +4924,7 @@ class _MemoizedSelectEntities( select_stmt._setup_joins = select_stmt._with_options = () -SelfSelect = typing.TypeVar("SelfSelect", bound="Select") +SelfSelect = typing.TypeVar("SelfSelect", bound="Select[Any]") class Select( @@ -4898,6 +4934,7 @@ class Select( HasCompileState, _SelectFromElements, GenerativeSelect, + TypedReturnsRows[_TP], ): """Represents a ``SELECT`` statement. @@ -4973,7 +5010,7 @@ class Select( _compile_state_factory: Type[SelectState] @classmethod - def _create_raw_select(cls, **kw: Any) -> Select: + def _create_raw_select(cls, **kw: Any) -> Select[Any]: """Create a :class:`.Select` using raw ``__new__`` with no coercions. Used internally to build up :class:`.Select` constructs with @@ -4985,7 +5022,7 @@ class Select( stmt.__dict__.update(kw) return stmt - def __init__(self, *entities: _ColumnsClauseArgument): + def __init__(self, *entities: _ColumnsClauseArgument[Any]): r"""Construct a new :class:`_expression.Select`. The public constructor for :class:`_expression.Select` is the @@ -5013,7 +5050,9 @@ class Select( cols = list(elem._select_iterable) return cols[0].type - def filter(self: SelfSelect, *criteria: ColumnElement[Any]) -> SelfSelect: + def filter( + self: SelfSelect, *criteria: _ColumnExpressionArgument[bool] + ) -> SelfSelect: """A synonym for the :meth:`_future.Select.where` method.""" return self.where(*criteria) @@ -5032,7 +5071,28 @@ class Select( return self._raw_columns[0] - def filter_by(self, **kwargs): + if TYPE_CHECKING: + + @overload + def scalar_subquery( + self: Select[Tuple[_MAYBE_ENTITY]], + ) -> ScalarSelect[Any]: + ... + + @overload + def scalar_subquery( + self: Select[Tuple[_NOT_ENTITY]], + ) -> ScalarSelect[_NOT_ENTITY]: + ... + + @overload + def scalar_subquery(self) -> ScalarSelect[Any]: + ... + + def scalar_subquery(self) -> ScalarSelect[Any]: + ... + + def filter_by(self: SelfSelect, **kwargs: Any) -> SelfSelect: r"""apply the given filtering criterion as a WHERE clause to this select. @@ -5046,7 +5106,7 @@ class Select( return self.filter(*clauses) @property - def column_descriptions(self): + def column_descriptions(self) -> Any: """Return a :term:`plugin-enabled` 'column descriptions' structure referring to the columns which are SELECTed by this statement. @@ -5089,7 +5149,9 @@ class Select( meth = SelectState.get_plugin_class(self).get_column_descriptions return meth(self) - def from_statement(self, statement): + def from_statement( + self, statement: ExecutableReturnsRows + ) -> ExecutableReturnsRows: """Apply the columns which this :class:`.Select` would select onto another statement. @@ -5410,7 +5472,7 @@ class Select( ) @property - def inner_columns(self): + def inner_columns(self) -> _SelectIterable: """An iterator of all :class:`_expression.ColumnElement` expressions which would be rendered into the columns clause of the resulting SELECT statement. @@ -5487,18 +5549,19 @@ class Select( self._reset_memoizations() - def get_children(self, **kwargs): + def get_children(self, **kw: Any) -> Iterable[ClauseElement]: return itertools.chain( super(Select, self).get_children( - omit_attrs=("_from_obj", "_correlate", "_correlate_except") + omit_attrs=("_from_obj", "_correlate", "_correlate_except"), + **kw, ), self._iterate_from_elements(), ) @_generative def add_columns( - self: SelfSelect, *columns: _ColumnsClauseArgument - ) -> SelfSelect: + self, *columns: _ColumnsClauseArgument[Any] + ) -> Select[Any]: """Return a new :func:`_expression.select` construct with the given column expressions added to its columns clause. @@ -5523,7 +5586,7 @@ class Select( return self def _set_entities( - self, entities: Iterable[_ColumnsClauseArgument] + self, entities: Iterable[_ColumnsClauseArgument[Any]] ) -> None: self._raw_columns = [ coercions.expect( @@ -5538,7 +5601,7 @@ class Select( "be removed in a future release. Please use " ":meth:`_expression.Select.add_columns`", ) - def column(self: SelfSelect, column: _ColumnsClauseArgument) -> SelfSelect: + def column(self, column: _ColumnsClauseArgument[Any]) -> Select[Any]: """Return a new :func:`_expression.select` construct with the given column expression added to its columns clause. @@ -5555,9 +5618,7 @@ class Select( return self.add_columns(column) @util.preload_module("sqlalchemy.sql.util") - def reduce_columns( - self: SelfSelect, only_synonyms: bool = True - ) -> SelfSelect: + def reduce_columns(self, only_synonyms: bool = True) -> Select[Any]: """Return a new :func:`_expression.select` construct with redundantly named, equivalently-valued columns removed from the columns clause. @@ -5580,20 +5641,115 @@ class Select( all columns that are equivalent to another are removed. """ - return self.with_only_columns( + woc: Select[Any] + woc = self.with_only_columns( *util.preloaded.sql_util.reduce_columns( self._all_selected_columns, only_synonyms=only_synonyms, *(self._where_criteria + self._from_obj), ) ) + return woc + + # START OVERLOADED FUNCTIONS self.with_only_columns Select 8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_sel_v1_overloads.py + + @overload + def with_only_columns(self, __ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]: + ... + + @overload + def with_only_columns( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> Select[Tuple[_T0, _T1]]: + ... + + @overload + def with_only_columns( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> Select[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def with_only_columns( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> Select[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def with_only_columns( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def with_only_columns( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def with_only_columns( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def with_only_columns( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.with_only_columns + + @overload + def with_only_columns( + self, + *columns: _ColumnsClauseArgument[Any], + maintain_column_froms: bool = False, + **__kw: Any, + ) -> Select[Any]: + ... @_generative def with_only_columns( - self: SelfSelect, - *columns: _ColumnsClauseArgument, + self, + *columns: _ColumnsClauseArgument[Any], maintain_column_froms: bool = False, - ) -> SelfSelect: + **__kw: Any, + ) -> Select[Any]: r"""Return a new :func:`_expression.select` construct with its columns clause replaced with the given columns. @@ -5647,6 +5803,9 @@ class Select( """ # noqa: E501 + if __kw: + raise _no_kw() + # memoizations should be cleared here as of # I95c560ffcbfa30b26644999412fb6a385125f663 , asserting this # is the case for now. @@ -5915,7 +6074,9 @@ class Select( return self @HasMemoized_ro_memoized_attribute - def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: + def selected_columns( + self, + ) -> ColumnCollection[str, ColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set, @@ -6215,7 +6376,7 @@ class ScalarSelect( by this :class:`_expression.ScalarSelect`. """ - self.element = cast(Select, self.element).where(crit) + self.element = cast("Select[Any]", self.element).where(crit) return self @overload @@ -6269,7 +6430,9 @@ class ScalarSelect( """ - self.element = cast(Select, self.element).correlate(*fromclauses) + self.element = cast("Select[Any]", self.element).correlate( + *fromclauses + ) return self @_generative @@ -6307,7 +6470,7 @@ class ScalarSelect( """ - self.element = cast(Select, self.element).correlate_except( + self.element = cast("Select[Any]", self.element).correlate_except( *fromclauses ) return self @@ -6331,12 +6494,18 @@ class Exists(UnaryExpression[bool]): def __init__( self, __argument: Optional[ - Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]] + Union[_ColumnsClauseArgument[Any], SelectBase, ScalarSelect[Any]] ] = None, ): + s: ScalarSelect[Any] + + # TODO: this seems like we should be using coercions for this if __argument is None: s = Select(literal_column("*")).scalar_subquery() - elif isinstance(__argument, (SelectBase, ScalarSelect)): + elif isinstance(__argument, SelectBase): + s = __argument.scalar_subquery() + s._propagate_attrs = __argument._propagate_attrs + elif isinstance(__argument, ScalarSelect): s = __argument else: s = Select(__argument).scalar_subquery() @@ -6358,7 +6527,7 @@ class Exists(UnaryExpression[bool]): element = fn(element) return element.self_group(against=operators.exists) - def select(self) -> Select: + def select(self) -> Select[Any]: r"""Return a SELECT of this :class:`_expression.Exists`. e.g.:: @@ -6452,7 +6621,7 @@ class Exists(UnaryExpression[bool]): SelfTextualSelect = typing.TypeVar("SelfTextualSelect", bound="TextualSelect") -class TextualSelect(SelectBase): +class TextualSelect(SelectBase, Executable, Generative): """Wrap a :class:`_expression.TextClause` construct within a :class:`_expression.SelectBase` interface. @@ -6503,7 +6672,9 @@ class TextualSelect(SelectBase): self.positional = positional @HasMemoized_ro_memoized_attribute - def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]: + def selected_columns( + self, + ) -> ColumnCollection[str, KeyedColumnElement[Any]]: """A :class:`_expression.ColumnCollection` representing the columns that this SELECT statement or similar construct returns in its result set, diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index d08fef60a9..8c45ba4101 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -50,6 +50,7 @@ from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement from .elements import Grouping +from .elements import KeyedColumnElement from .elements import Label from .elements import Null from .elements import UnaryExpression @@ -72,9 +73,7 @@ if typing.TYPE_CHECKING: from ._typing import _EquivalentColumnMap from ._typing import _TypeEngineArgument from .elements import TextClause - from .roles import FromClauseRole from .selectable import _JoinTargetElement - from .selectable import _OnClauseElement from .selectable import _SelectIterable from .selectable import Selectable from .visitors import _TraverseCallableType @@ -569,7 +568,7 @@ class _repr_row(_repr_base): __slots__ = ("row",) - def __init__(self, row: "Row", max_chars: int = 300): + def __init__(self, row: "Row[Any]", max_chars: int = 300): self.row = row self.max_chars = max_chars @@ -1068,7 +1067,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): col = col._annotations["adapt_column"] if TYPE_CHECKING: - assert isinstance(col, ColumnElement) + assert isinstance(col, KeyedColumnElement) if self.adapt_from_selectables and col not in self.equivalents: for adp in self.adapt_from_selectables: @@ -1078,7 +1077,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): return None if TYPE_CHECKING: - assert isinstance(col, ColumnElement) + assert isinstance(col, KeyedColumnElement) if self.include_fn and not self.include_fn(col): return None diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index e0a66fbcf4..88586d834a 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -450,7 +450,7 @@ class HasTraverseInternals: @util.preload_module("sqlalchemy.sql.traversals") def get_children( - self, omit_attrs: Tuple[str, ...] = (), **kw: Any + self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any ) -> Iterable[HasTraverseInternals]: r"""Return immediate child :class:`.visitors.HasTraverseInternals` elements of this :class:`.visitors.HasTraverseInternals`. @@ -594,7 +594,7 @@ class ExternallyTraversible(HasTraverseInternals, Visitable): if typing.TYPE_CHECKING: def get_children( - self, omit_attrs: Tuple[str, ...] = (), **kw: Any + self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any ) -> Iterable[ExternallyTraversible]: ... diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 49c5d693af..da3fbc718a 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -18,6 +18,7 @@ import hashlib import inspect import itertools import operator +import os import re import sys import textwrap @@ -32,6 +33,7 @@ from typing import Generic from typing import Iterator from typing import List from typing import Mapping +from typing import no_type_check from typing import NoReturn from typing import Optional from typing import overload @@ -2106,3 +2108,45 @@ def has_compiled_ext(raise_=False): ) else: return False + + +@no_type_check +def console_scripts( + path: str, options: dict, ignore_output: bool = False +) -> None: + + import subprocess + import shlex + from pathlib import Path + + is_posix = os.name == "posix" + + entrypoint_name = options["entrypoint"] + + for entry in compat.importlib_metadata_get("console_scripts"): + if entry.name == entrypoint_name: + impl = entry + break + else: + raise Exception( + f"Could not find entrypoint console_scripts.{entrypoint_name}" + ) + cmdline_options_str = options.get("options", "") + cmdline_options_list = shlex.split(cmdline_options_str, posix=is_posix) + [ + path + ] + + kw = {} + if ignore_output: + kw["stdout"] = kw["stderr"] = subprocess.DEVNULL + + subprocess.run( + [ + sys.executable, + "-c", + "import %s; %s.%s()" % (impl.module, impl.module, impl.attr), + ] + + cmdline_options_list, + cwd=Path(__file__).parent.parent, + **kw, + ) diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index d192dc06bf..2a215c4f1a 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -14,7 +14,7 @@ from typing import Type from typing import TypeVar from typing import Union -from typing_extensions import NotRequired as NotRequired # noqa +from typing_extensions import NotRequired as NotRequired from . import compat diff --git a/pyproject.toml b/pyproject.toml index d16f03c032..516831bca5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,6 @@ target-version = ['py37'] [tool.zimports] black-line-length = 79 -keep-unused-type-checking = true [tool.slotscheck] exclude-modules = '^sqlalchemy\.testing' diff --git a/test/base/test_result.py b/test/base/test_result.py index bc7bfefa49..90938263f5 100644 --- a/test/base/test_result.py +++ b/test/base/test_result.py @@ -276,6 +276,13 @@ class ResultTest(fixtures.TestBase): eq_(m1.fetchone(), {"a": 1, "b": 1, "c": 1}) eq_(r1.fetchone(), (2, 1, 2)) + def test_tuples_plus_base(self): + r1 = self._fixture() + + t1 = r1.tuples() + eq_(t1.fetchone(), (1, 1, 1)) + eq_(r1.fetchone(), (2, 1, 2)) + def test_scalar_plus_base(self): r1 = self._fixture() diff --git a/test/ext/mypy/plain_files/association_proxy_one.py b/test/ext/mypy/plain_files/association_proxy_one.py index e8b57a0c02..cb9f0b85d7 100644 --- a/test/ext/mypy/plain_files/association_proxy_one.py +++ b/test/ext/mypy/plain_files/association_proxy_one.py @@ -40,8 +40,8 @@ class Address(Base): u1 = User() if typing.TYPE_CHECKING: - # EXPECTED_TYPE: sqlalchemy.*.associationproxy.AssociationProxyInstance\[builtins.set\*?\[builtins.str\]\] + # EXPECTED_RE_TYPE: sqlalchemy.*.associationproxy.AssociationProxyInstance\[builtins.set\*?\[builtins.str\]\] reveal_type(User.email_addresses) - # EXPECTED_TYPE: builtins.set\*?\[builtins.str\] + # EXPECTED_RE_TYPE: builtins.set\*?\[builtins.str\] reveal_type(u1.email_addresses) diff --git a/test/ext/mypy/plain_files/engine_inspection.py b/test/ext/mypy/plain_files/engine_inspection.py index 1a1649e4ec..20e252cddc 100644 --- a/test/ext/mypy/plain_files/engine_inspection.py +++ b/test/ext/mypy/plain_files/engine_inspection.py @@ -14,11 +14,11 @@ c1 = cols[0] if typing.TYPE_CHECKING: - # EXPECTED_TYPE: sqlalchemy.engine.base.Engine + # EXPECTED_RE_TYPE: sqlalchemy.engine.base.Engine reveal_type(e) - # EXPECTED_TYPE: sqlalchemy.engine.reflection.Inspector.* + # EXPECTED_RE_TYPE: sqlalchemy.engine.reflection.Inspector.* reveal_type(insp) - # EXPECTED_TYPE: .*list.*TypedDict.*ReflectedColumn.* + # EXPECTED_RE_TYPE: .*list.*TypedDict.*ReflectedColumn.* reveal_type(cols) diff --git a/test/ext/mypy/plain_files/experimental_relationship.py b/test/ext/mypy/plain_files/experimental_relationship.py index fe2742072c..a8d81426e7 100644 --- a/test/ext/mypy/plain_files/experimental_relationship.py +++ b/test/ext/mypy/plain_files/experimental_relationship.py @@ -49,20 +49,20 @@ class Address(Base): if typing.TYPE_CHECKING: - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Union\[builtins.str, None\]\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Union\[builtins.str, None\]\] reveal_type(User.extra) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Union\[builtins.str, None\]\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Union\[builtins.str, None\]\] reveal_type(User.extra_name) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*?\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*?\] reveal_type(Address.email) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*?\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*?\] reveal_type(Address.email_name) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[experimental_relationship.Address\]\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[experimental_relationship.Address\]\] reveal_type(User.addresses_style_one) - # EXPECTED_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*?\[experimental_relationship.Address\]\] + # EXPECTED_RE_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*?\[experimental_relationship.Address\]\] reveal_type(User.addresses_style_two) diff --git a/test/ext/mypy/plain_files/hybrid_one.py b/test/ext/mypy/plain_files/hybrid_one.py index d9f97ebcff..12c7c204c5 100644 --- a/test/ext/mypy/plain_files/hybrid_one.py +++ b/test/ext/mypy/plain_files/hybrid_one.py @@ -47,17 +47,17 @@ expr2 = Interval.contains(7) expr3 = Interval.intersects(i2) if typing.TYPE_CHECKING: - # EXPECTED_TYPE: builtins.int\*? + # EXPECTED_RE_TYPE: builtins.int\*? reveal_type(i1.length) - # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\] + # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\] reveal_type(Interval.length) - # EXPECTED_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\] + # EXPECTED_RE_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\] reveal_type(expr1) - # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\] + # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\] reveal_type(expr2) - # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\] + # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\] reveal_type(expr3) diff --git a/test/ext/mypy/plain_files/hybrid_two.py b/test/ext/mypy/plain_files/hybrid_two.py index ab2970656e..430d796c60 100644 --- a/test/ext/mypy/plain_files/hybrid_two.py +++ b/test/ext/mypy/plain_files/hybrid_two.py @@ -47,10 +47,10 @@ class Interval(Base): # while we are here, check some Float[] / div type stuff if typing.TYPE_CHECKING: - # EXPECTED_TYPE: sqlalchemy.*Function\[builtins.float\*?\] + # EXPECTED_RE_TYPE: sqlalchemy.*Function\[builtins.float\*?\] reveal_type(f1) - # EXPECTED_TYPE: sqlalchemy.*ColumnElement\[builtins.float\*?\] + # EXPECTED_RE_TYPE: sqlalchemy.*ColumnElement\[builtins.float\*?\] reveal_type(expr) return expr @@ -69,23 +69,23 @@ expr3 = Interval.radius.in_([0.5, 5.2]) if typing.TYPE_CHECKING: - # EXPECTED_TYPE: builtins.int\*? + # EXPECTED_RE_TYPE: builtins.int\*? reveal_type(i1.length) - # EXPECTED_TYPE: builtins.float\*? + # EXPECTED_RE_TYPE: builtins.float\*? reveal_type(i2.radius) - # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\] + # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\] reveal_type(Interval.length) - # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.float\*?\] + # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.float\*?\] reveal_type(Interval.radius) - # EXPECTED_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\] + # EXPECTED_RE_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\] reveal_type(expr1) - # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.float\*?\] + # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.float\*?\] reveal_type(expr2) - # EXPECTED_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\] + # EXPECTED_RE_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\] reveal_type(expr3) diff --git a/test/ext/mypy/plain_files/session.py b/test/ext/mypy/plain_files/session.py index 199d3a804f..0dfa0a7520 100644 --- a/test/ext/mypy/plain_files/session.py +++ b/test/ext/mypy/plain_files/session.py @@ -4,7 +4,6 @@ from typing import List from sqlalchemy import create_engine from sqlalchemy import ForeignKey -from sqlalchemy import select from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column @@ -43,7 +42,20 @@ with Session(e) as sess: sess.add_all([Address(user=u1, email="e1"), Address(user=u1, email="e2")]) sess.commit() -with Session(e) as sess: - users: List[User] = sess.scalars( - select(User), execution_options={"stream_results": False} - ).all() + q = sess.query(User).filter_by(id=7) + + # EXPECTED_TYPE: Query[User] + reveal_type(q) + + rows1 = q.all() + + # EXPECTED_RE_TYPE: builtins.[Ll]ist\[.*User\*?\] + reveal_type(rows1) + + q2 = sess.query(User.id).filter_by(id=7) + rows2 = q2.all() + + # EXPECTED_TYPE: List[Row[Tuple[int]]] + reveal_type(rows2) + +# more result tests in typed_results.py diff --git a/test/ext/mypy/plain_files/sql_operations.py b/test/ext/mypy/plain_files/sql_operations.py index f9b9b2ffe5..6b06535bf1 100644 --- a/test/ext/mypy/plain_files/sql_operations.py +++ b/test/ext/mypy/plain_files/sql_operations.py @@ -40,32 +40,32 @@ if typing.TYPE_CHECKING: # as far as if this is ColumnElement, BinaryElement, SQLCoreOperations, # that might change. main thing is it's SomeSQLColThing[bool] and # not 'bool' or 'Any'. - # EXPECTED_TYPE: sqlalchemy..*ColumnElement\[builtins.bool\] + # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[builtins.bool\] reveal_type(expr1) - # EXPECTED_TYPE: sqlalchemy..*ColumnClause\[builtins.str.?\] + # EXPECTED_RE_TYPE: sqlalchemy..*ColumnClause\[builtins.str.?\] reveal_type(c1) - # EXPECTED_TYPE: sqlalchemy..*ColumnClause\[builtins.int.?\] + # EXPECTED_RE_TYPE: sqlalchemy..*ColumnClause\[builtins.int.?\] reveal_type(c2) - # EXPECTED_TYPE: sqlalchemy..*BinaryExpression\[builtins.bool\] + # EXPECTED_RE_TYPE: sqlalchemy..*BinaryExpression\[builtins.bool\] reveal_type(expr2) - # EXPECTED_TYPE: sqlalchemy..*ColumnElement\[Union\[builtins.float, decimal.Decimal\]\] + # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[Union\[builtins.float, decimal.Decimal\]\] reveal_type(expr3) - # EXPECTED_TYPE: sqlalchemy..*UnaryExpression\[builtins.int.?\] + # EXPECTED_RE_TYPE: sqlalchemy..*UnaryExpression\[builtins.int.?\] reveal_type(expr4) - # EXPECTED_TYPE: sqlalchemy..*ColumnElement\[builtins.bool.?\] + # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[builtins.bool.?\] reveal_type(expr5) - # EXPECTED_TYPE: sqlalchemy..*ColumnElement\[builtins.bool.?\] + # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[builtins.bool.?\] reveal_type(expr6) - # EXPECTED_TYPE: sqlalchemy..*ColumnElement\[builtins.str\] + # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[builtins.str\] reveal_type(expr7) - # EXPECTED_TYPE: sqlalchemy..*ColumnElement\[builtins.int.?\] + # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[builtins.int.?\] reveal_type(expr8) diff --git a/test/ext/mypy/plain_files/trad_relationship_uselist.py b/test/ext/mypy/plain_files/trad_relationship_uselist.py index af7d292be7..4d17dab78d 100644 --- a/test/ext/mypy/plain_files/trad_relationship_uselist.py +++ b/test/ext/mypy/plain_files/trad_relationship_uselist.py @@ -101,45 +101,45 @@ class Address(Base): if typing.TYPE_CHECKING: - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[trad_relationship_uselist.Address\]\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[trad_relationship_uselist.Address\]\] reveal_type(User.addresses_style_one) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*?\[trad_relationship_uselist.Address\]\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*?\[trad_relationship_uselist.Address\]\] reveal_type(User.addresses_style_two) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(User.addresses_style_three) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(User.addresses_style_three_cast) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(User.addresses_style_four) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_one) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*?\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*?\] reveal_type(Address.user_style_one_typed) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_two) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*?\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*?\] reveal_type(Address.user_style_two_typed) # reveal_type(Address.user_style_six) # reveal_type(Address.user_style_seven) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_eight) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_nine) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_ten) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.dict\*?\[builtins.str, trad_relationship_uselist.User\]\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.dict\*?\[builtins.str, trad_relationship_uselist.User\]\] reveal_type(Address.user_style_ten_typed) diff --git a/test/ext/mypy/plain_files/traditional_relationship.py b/test/ext/mypy/plain_files/traditional_relationship.py index ce131dd004..0edeb694ad 100644 --- a/test/ext/mypy/plain_files/traditional_relationship.py +++ b/test/ext/mypy/plain_files/traditional_relationship.py @@ -60,29 +60,29 @@ class Address(Base): if typing.TYPE_CHECKING: - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.Address\]\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.Address\]\] reveal_type(User.addresses_style_one) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*?\[traditional_relationship.Address\]\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*?\[traditional_relationship.Address\]\] reveal_type(User.addresses_style_two) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_one) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*?\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*?\] reveal_type(Address.user_style_one_typed) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_two) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*?\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*?\] reveal_type(Address.user_style_two_typed) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.User\]\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.User\]\] reveal_type(Address.user_style_three) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.User\]\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.User\]\] reveal_type(Address.user_style_four) - # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] reveal_type(Address.user_style_five) diff --git a/test/ext/mypy/plain_files/typed_queries.py b/test/ext/mypy/plain_files/typed_queries.py new file mode 100644 index 0000000000..234c4da1d5 --- /dev/null +++ b/test/ext/mypy/plain_files/typed_queries.py @@ -0,0 +1,433 @@ +from __future__ import annotations + +from typing import Tuple + +from sqlalchemy import column +from sqlalchemy import create_engine +from sqlalchemy import delete +from sqlalchemy import func +from sqlalchemy import insert +from sqlalchemy import Select +from sqlalchemy import select +from sqlalchemy import update +from sqlalchemy.orm import aliased +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import Session + + +class Base(DeclarativeBase): + pass + + +class User(Base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + data: Mapped[str] + + +session = Session() + +e = create_engine("sqlite://") +connection = e.connect() + + +def t_select_1() -> None: + stmt = select(User.id, User.name).filter(User.id == 5) + + # EXPECTED_TYPE: Select[Tuple[int, str]] + reveal_type(stmt) + + result = session.execute(stmt) + + # EXPECTED_TYPE: Result[Tuple[int, str]] + reveal_type(result) + + +def t_select_2() -> None: + stmt = select(User).filter(User.id == 5) + + # EXPECTED_TYPE: Select[Tuple[User]] + reveal_type(stmt) + + result = session.execute(stmt) + + # EXPECTED_TYPE: Result[Tuple[User]] + reveal_type(result) + + +def t_select_3() -> None: + ua = aliased(User) + + # this will fail at runtime, but as we at the moment see aliased(_T) + # as _T, typing tools see the constructor as fine. + # this line would ideally have a typing error but we'd need the ability + # for aliased() to return some namespace of User that's not User. + # AsAliased superclass type was tested for this but it had its own + # awkwardnesses that aren't really worth it + x = ua(id=1, name="foo") + + # EXPECTED_TYPE: Type[User] + reveal_type(ua) + + stmt = select(ua.id, ua.name).filter(User.id == 5) + + # EXPECTED_TYPE: Select[Tuple[int, str]] + reveal_type(stmt) + + result = session.execute(stmt) + + # EXPECTED_TYPE: Result[Tuple[int, str]] + reveal_type(result) + + +def t_select_4() -> None: + ua = aliased(User) + stmt = select(ua, User).filter(User.id == 5) + + # EXPECTED_TYPE: Select[Tuple[User, User]] + reveal_type(stmt) + + result = session.execute(stmt) + + # EXPECTED_TYPE: Result[Tuple[User, User]] + reveal_type(result) + + +def t_legacy_query_single_entity() -> None: + q1 = session.query(User).filter(User.id == 5) + + # EXPECTED_TYPE: Query[User] + reveal_type(q1) + + # EXPECTED_TYPE: User + reveal_type(q1.one()) + + # EXPECTED_TYPE: List[User] + reveal_type(q1.all()) + + # mypy switches to builtins.list for some reason here + # EXPECTED_RE_TYPE: .*\.[Ll]ist\[.*Row\*?\[Tuple\[.*User\]\]\] + reveal_type(q1.only_return_tuples(True).all()) + + # EXPECTED_TYPE: List[Tuple[User]] + reveal_type(q1.tuples().all()) + + +def t_legacy_query_cols_1() -> None: + q1 = session.query(User.id, User.name).filter(User.id == 5) + + # EXPECTED_TYPE: RowReturningQuery[Tuple[int, str]] + reveal_type(q1) + + # EXPECTED_TYPE: Row[Tuple[int, str]] + reveal_type(q1.one()) + + r1 = q1.one() + + x, y = r1.t + + # EXPECTED_TYPE: int + reveal_type(x) + + # EXPECTED_TYPE: str + reveal_type(y) + + +def t_legacy_query_cols_tupleq_1() -> None: + q1 = session.query(User.id, User.name).filter(User.id == 5) + + # EXPECTED_TYPE: RowReturningQuery[Tuple[int, str]] + reveal_type(q1) + + q2 = q1.tuples() + + # EXPECTED_TYPE: Tuple[int, str] + reveal_type(q2.one()) + + r1 = q2.one() + + x, y = r1 + + # EXPECTED_TYPE: int + reveal_type(x) + + # EXPECTED_TYPE: str + reveal_type(y) + + +def t_legacy_query_cols_1_with_entities() -> None: + q1 = session.query(User).filter(User.id == 5) + + # EXPECTED_TYPE: Query[User] + reveal_type(q1) + + q2 = q1.with_entities(User.id, User.name) + + # EXPECTED_TYPE: RowReturningQuery[Tuple[int, str]] + reveal_type(q2) + + # EXPECTED_TYPE: Row[Tuple[int, str]] + reveal_type(q2.one()) + + r1 = q2.one() + + x, y = r1.t + + # EXPECTED_TYPE: int + reveal_type(x) + + # EXPECTED_TYPE: str + reveal_type(y) + + +def t_select_with_only_cols() -> None: + q1 = select(User).where(User.id == 5) + + # EXPECTED_TYPE: Select[Tuple[User]] + reveal_type(q1) + + q2 = q1.with_only_columns(User.id, User.name) + + # EXPECTED_TYPE: Select[Tuple[int, str]] + reveal_type(q2) + + row = connection.execute(q2).one() + + # EXPECTED_TYPE: Row[Tuple[int, str]] + reveal_type(row) + + x, y = row.t + + # EXPECTED_TYPE: int + reveal_type(x) + + # EXPECTED_TYPE: str + reveal_type(y) + + +def t_legacy_query_cols_2() -> None: + a1 = aliased(User) + q1 = session.query(User, a1, User.name).filter(User.id == 5) + + # EXPECTED_TYPE: RowReturningQuery[Tuple[User, User, str]] + reveal_type(q1) + + # EXPECTED_TYPE: Row[Tuple[User, User, str]] + reveal_type(q1.one()) + + r1 = q1.one() + + x, y, z = r1.t + + # EXPECTED_TYPE: User + reveal_type(x) + + # EXPECTED_TYPE: User + reveal_type(y) + + # EXPECTED_TYPE: str + reveal_type(z) + + +def t_legacy_query_cols_2_with_entities() -> None: + + q1 = session.query(User) + + # EXPECTED_TYPE: Query[User] + reveal_type(q1) + + a1 = aliased(User) + q2 = q1.with_entities(User, a1, User.name).filter(User.id == 5) + + # EXPECTED_TYPE: RowReturningQuery[Tuple[User, User, str]] + reveal_type(q2) + + # EXPECTED_TYPE: Row[Tuple[User, User, str]] + reveal_type(q2.one()) + + r1 = q2.one() + + x, y, z = r1.t + + # EXPECTED_TYPE: User + reveal_type(x) + + # EXPECTED_TYPE: User + reveal_type(y) + + # EXPECTED_TYPE: str + reveal_type(z) + + +def t_select_add_col_loses_type() -> None: + q1 = select(User.id, User.name).filter(User.id == 5) + + q2 = q1.add_columns(User.data) + + # note this should not match Select + # EXPECTED_TYPE: Select[Any] + reveal_type(q2) + + +def t_legacy_query_add_col_loses_type() -> None: + q1 = session.query(User.id, User.name).filter(User.id == 5) + + q2 = q1.add_columns(User.data) + + # this should match only Any + # EXPECTED_TYPE: Query[Any] + reveal_type(q2) + + ua = aliased(User) + q3 = q1.add_entity(ua) + + # EXPECTED_TYPE: Query[Any] + reveal_type(q3) + + +def t_legacy_query_scalar_subquery() -> None: + """scalar subquery should receive the type if first element is a + column only""" + q1 = session.query(User.id) + + q2 = q1.scalar_subquery() + + # this should be int but mypy can't see it due to the + # overload that tries to match an entity. + # EXPECTED_RE_TYPE: .*ScalarSelect\[(?:int|Any)\] + reveal_type(q2) + + q3 = session.query(User) + + q4 = q3.scalar_subquery() + + # EXPECTED_TYPE: ScalarSelect[Any] + reveal_type(q4) + + q5 = session.query(User, User.name) + + q6 = q5.scalar_subquery() + + # EXPECTED_TYPE: ScalarSelect[Any] + reveal_type(q6) + + # try to simulate the problem with select() + q7 = session.query(User).only_return_tuples(True) + q8 = q7.scalar_subquery() + + # EXPECTED_TYPE: ScalarSelect[Any] + reveal_type(q8) + + +def t_select_scalar_subquery() -> None: + """scalar subquery should receive the type if first element is a + column only""" + s1 = select(User.id) + s2 = s1.scalar_subquery() + + # this should be int but mypy can't see it due to the + # overload that tries to match an entity. + # EXPECTED_TYPE: ScalarSelect[Any] + reveal_type(s2) + + s3 = select(User) + s4 = s3.scalar_subquery() + + # it's more important that mypy doesn't get a false positive of + # 'User' here + # EXPECTED_TYPE: ScalarSelect[Any] + reveal_type(s4) + + +def t_select_w_core_selectables() -> None: + """things that come from .c. or are FromClause objects currently are not + typed. Make sure we are still getting Select at least. + + """ + s1 = select(User.id, User.name).subquery() + + # EXPECTED_TYPE: KeyedColumnElement[Any] + reveal_type(s1.c.name) + + s2 = select(User.id, s1.c.name) + + # this one unfortunately is not working in mypy. + # pylance gets the correct type + # EXPECTED_TYPE: Select[Tuple[int, Any]] + # when experimenting with having a separate TypedSelect class for typing, + # mypy would downgrade to Any rather than picking the basemost type. + # with typing integrated into Select etc. we can at least get a Select + # object back. + # EXPECTED_TYPE: Select[Any] + reveal_type(s2) + + # so a fully explicit type may be given + s2_typed: Select[Tuple[int, str]] = select(User.id, s1.c.name) + + # EXPECTED_TYPE: Select[Tuple[int, str]] + reveal_type(s2_typed) + + # plain FromClause etc we at least get Select + s3 = select(s1) + + # EXPECTED_TYPE: Select[Any] + reveal_type(s3) + + t1 = User.__table__ + assert t1 is not None + + # EXPECTED_TYPE: FromClause + reveal_type(t1) + + s4 = select(t1) + + # EXPECTED_TYPE: Select[Any] + reveal_type(s4) + + +def t_dml_insert() -> None: + s1 = insert(User).returning(User.id, User.name) + + r1 = session.execute(s1) + + # EXPECTED_TYPE: Result[Tuple[int, str]] + reveal_type(r1) + + s2 = insert(User).returning(User) + + r2 = session.execute(s2) + + # EXPECTED_TYPE: Result[Tuple[User]] + reveal_type(r2) + + s3 = insert(User).returning(func.foo(), column("q")) + + # EXPECTED_TYPE: ReturningInsert[Any] + reveal_type(s3) + + r3 = session.execute(s3) + + # EXPECTED_TYPE: Result[Any] + reveal_type(r3) + + +def t_dml_update() -> None: + s1 = update(User).returning(User.id, User.name) + + r1 = session.execute(s1) + + # EXPECTED_TYPE: Result[Tuple[int, str]] + reveal_type(r1) + + +def t_dml_delete() -> None: + s1 = delete(User).returning(User.id, User.name) + + r1 = session.execute(s1) + + # EXPECTED_TYPE: Result[Tuple[int, str]] + reveal_type(r1) diff --git a/test/ext/mypy/plain_files/typed_results.py b/test/ext/mypy/plain_files/typed_results.py new file mode 100644 index 0000000000..1eecb4c75f --- /dev/null +++ b/test/ext/mypy/plain_files/typed_results.py @@ -0,0 +1,554 @@ +from __future__ import annotations + +import asyncio +from typing import cast + +from sqlalchemy import Column +from sqlalchemy import column +from sqlalchemy import create_engine +from sqlalchemy import Integer +from sqlalchemy import select +from sqlalchemy import table +from sqlalchemy.ext.asyncio import AsyncConnection +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.orm import aliased +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import Session + + +class Base(DeclarativeBase): + pass + + +class User(Base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + + +e = create_engine("sqlite://") +ae = create_async_engine("sqlite+aiosqlite://") + + +connection = e.connect() +session = Session(connection) + + +async def async_connect() -> AsyncConnection: + return await ae.connect() + + +# the thing with the \*? seems like it could go away +# as of mypy 0.950 + +async_connection = asyncio.run(async_connect()) + +# EXPECTED_RE_TYPE: sqlalchemy..*AsyncConnection\*? +reveal_type(async_connection) + +async_session = AsyncSession(async_connection) + + +# EXPECTED_RE_TYPE: sqlalchemy..*AsyncSession\*? +reveal_type(async_session) + + +single_stmt = select(User.name).where(User.name == "foo") + +# EXPECTED_RE_TYPE: sqlalchemy..*Select\*?\[Tuple\[builtins.str\*?\]\] +reveal_type(single_stmt) + +multi_stmt = select(User.id, User.name).where(User.name == "foo") + +# EXPECTED_RE_TYPE: sqlalchemy..*Select\*?\[Tuple\[builtins.int\*?, builtins.str\*?\]\] +reveal_type(multi_stmt) + + +def t_entity_varieties() -> None: + + a1 = aliased(User) + + s1 = select(User.id, User, User.name).where(User.name == "foo") + + r1 = session.execute(s1) + + # EXPECTED_RE_TYPE: sqlalchemy..*.Result\[Tuple\[builtins.int\*?, typed_results.User\*?, builtins.str\*?\]\] + reveal_type(r1) + + s2 = select(User, a1).where(User.name == "foo") + + r2 = session.execute(s2) + + # EXPECTED_RE_TYPE: sqlalchemy.*Result\[Tuple\[typed_results.User\*?, typed_results.User\*?\]\] + reveal_type(r2) + + row = r2.t.one() + + # EXPECTED_RE_TYPE: .*typed_results.User\*? + reveal_type(row[0]) + # EXPECTED_RE_TYPE: .*typed_results.User\*? + reveal_type(row[1]) + + # testing that plain Mapped[x] gets picked up as well as + # aliased class + # there is unfortunately no way for attributes on an AliasedClass to be + # automatically typed since they are dynamically generated + a1_id = cast(Mapped[int], a1.id) + s3 = select(User.id, a1_id, a1, User).where(User.name == "foo") + # EXPECTED_RE_TYPE: sqlalchemy.*Select\*?\[Tuple\[builtins.int\*?, builtins.int\*?, typed_results.User\*?, typed_results.User\*?\]\] + reveal_type(s3) + + # testing Mapped[entity] + some_mp = cast(Mapped[User], object()) + s4 = select(some_mp, a1, User).where(User.name == "foo") + + # NOTEXPECTED_RE_TYPE: sqlalchemy..*Select\*?\[Tuple\[typed_results.User\*?, typed_results.User\*?, typed_results.User\*?\]\] + + # sqlalchemy.sql._gen_overloads.Select[Tuple[typed_results.User, typed_results.User, typed_results.User]] + + # EXPECTED_TYPE: Select[Tuple[User, User, User]] + reveal_type(s4) + + # test plain core expressions + x = Column("x", Integer) + y = x + 5 + + s5 = select(x, y, User.name + "hi") + + # EXPECTED_RE_TYPE: sqlalchemy..*Select\*?\[Tuple\[builtins.int\*?, builtins.int\*?\, builtins.str\*?]\] + reveal_type(s5) + + +def t_ambiguous_result_type_one() -> None: + stmt = select(column("q", Integer), table("x", column("y"))) + + # EXPECTED_TYPE: Select[Any] + reveal_type(stmt) + + result = session.execute(stmt) + + # EXPECTED_TYPE: Result[Any] + reveal_type(result) + + +def t_ambiguous_result_type_two() -> None: + + stmt = select(column("q")) + + # EXPECTED_TYPE: Select[Tuple[Any]] + reveal_type(stmt) + result = session.execute(stmt) + + # EXPECTED_TYPE: Result[Any] + reveal_type(result) + + +def t_aliased() -> None: + + a1 = aliased(User) + + s1 = select(a1) + # EXPECTED_TYPE: Select[Tuple[User]] + reveal_type(s1) + + s4 = select(a1.name, a1, a1, User).where(User.name == "foo") + # EXPECTED_TYPE: Select[Tuple[str, User, User, User]] + reveal_type(s4) + + +def t_result_scalar_accessors() -> None: + result = connection.execute(single_stmt) + + r1 = result.scalar() + + # EXPECTED_RE_TYPE: Union\[builtins.str\*?, None\] + reveal_type(r1) + + r2 = result.scalar_one() + + # EXPECTED_RE_TYPE: builtins.str\*? + reveal_type(r2) + + r3 = result.scalar_one_or_none() + + # EXPECTED_RE_TYPE: Union\[builtins.str\*?, None\] + reveal_type(r3) + + r4 = result.scalars() + + # EXPECTED_RE_TYPE: sqlalchemy..*ScalarResult\[builtins.str.*?\] + reveal_type(r4) + + r5 = result.scalars(0) + + # EXPECTED_RE_TYPE: sqlalchemy..*ScalarResult\[builtins.str.*?\] + reveal_type(r5) + + +async def t_async_result_scalar_accessors() -> None: + result = await async_connection.stream(single_stmt) + + r1 = await result.scalar() + + # EXPECTED_RE_TYPE: Union\[builtins.str\*?, None\] + reveal_type(r1) + + r2 = await result.scalar_one() + + # EXPECTED_RE_TYPE: builtins.str\*? + reveal_type(r2) + + r3 = await result.scalar_one_or_none() + + # EXPECTED_RE_TYPE: Union\[builtins.str\*?, None\] + reveal_type(r3) + + r4 = result.scalars() + + # EXPECTED_RE_TYPE: sqlalchemy..*ScalarResult\[builtins.str.*?\] + reveal_type(r4) + + r5 = result.scalars(0) + + # EXPECTED_RE_TYPE: sqlalchemy..*ScalarResult\[builtins.str.*?\] + reveal_type(r5) + + +def t_connection_execute_multi_row_t() -> None: + result = connection.execute(multi_stmt) + + # EXPECTED_RE_TYPE: sqlalchemy.*CursorResult\[Tuple\[builtins.int\*?, builtins.str\*?\]\] + reveal_type(result) + row = result.one() + + # EXPECTED_RE_TYPE: sqlalchemy.*Row\[Tuple\[builtins.int\*?, builtins.str\*?\]\] + reveal_type(row) + + x, y = row.t + + # EXPECTED_RE_TYPE: builtins.int\*? + reveal_type(x) + + # EXPECTED_RE_TYPE: builtins.str\*? + reveal_type(y) + + +def t_connection_execute_multi() -> None: + result = connection.execute(multi_stmt).t + + # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[Tuple\[builtins.int\*?, builtins.str\*?\]\] + reveal_type(result) + row = result.one() + + # EXPECTED_RE_TYPE: Tuple\[builtins.int\*?, builtins.str\*?\] + reveal_type(row) + + x, y = row + + # EXPECTED_RE_TYPE: builtins.int\*? + reveal_type(x) + + # EXPECTED_RE_TYPE: builtins.str\*? + reveal_type(y) + + +def t_connection_execute_single() -> None: + result = connection.execute(single_stmt).t + + # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[Tuple\[builtins.str\*?\]\] + reveal_type(result) + row = result.one() + + # EXPECTED_RE_TYPE: Tuple\[builtins.str\*?\] + reveal_type(row) + + (x,) = row + + # EXPECTED_RE_TYPE: builtins.str\*? + reveal_type(x) + + +def t_connection_execute_single_row_scalar() -> None: + result = connection.execute(single_stmt).t + + # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[Tuple\[builtins.str\*?\]\] + reveal_type(result) + + x = result.scalar() + + # EXPECTED_RE_TYPE: Union\[builtins.str\*?, None\] + reveal_type(x) + + +def t_connection_scalar() -> None: + obj = connection.scalar(single_stmt) + + # EXPECTED_RE_TYPE: Union\[builtins.str\*?, None\] + reveal_type(obj) + + +def t_connection_scalars() -> None: + result = connection.scalars(single_stmt) + + # EXPECTED_RE_TYPE: sqlalchemy.*ScalarResult\[builtins.str\*?\] + reveal_type(result) + data = result.all() + + # EXPECTED_RE_TYPE: typing.Sequence\[builtins.str\*?\] + reveal_type(data) + + +def t_session_execute_multi() -> None: + result = session.execute(multi_stmt).t + + # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[Tuple\[builtins.int\*?, builtins.str\*?\]\] + reveal_type(result) + row = result.one() + + # EXPECTED_RE_TYPE: Tuple\[builtins.int\*?, builtins.str\*?\] + reveal_type(row) + + x, y = row + + # EXPECTED_RE_TYPE: builtins.int\*? + reveal_type(x) + + # EXPECTED_RE_TYPE: builtins.str\*? + reveal_type(y) + + +def t_session_execute_single() -> None: + result = session.execute(single_stmt).t + + # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[Tuple\[builtins.str\*?\]\] + reveal_type(result) + row = result.one() + + # EXPECTED_RE_TYPE: Tuple\[builtins.str\*?\] + reveal_type(row) + + (x,) = row + + # EXPECTED_RE_TYPE: builtins.str\*? + reveal_type(x) + + +def t_session_scalar() -> None: + obj = session.scalar(single_stmt) + + # EXPECTED_RE_TYPE: Union\[builtins.str\*?, None\] + reveal_type(obj) + + +def t_session_scalars() -> None: + result = session.scalars(single_stmt) + + # EXPECTED_RE_TYPE: sqlalchemy.*ScalarResult\[builtins.str\*?\] + reveal_type(result) + data = result.all() + + # EXPECTED_RE_TYPE: typing.Sequence\[builtins.str\*?\] + reveal_type(data) + + +async def t_async_connection_execute_multi() -> None: + result = (await async_connection.execute(multi_stmt)).t + + # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[Tuple\[builtins.int\*?, builtins.str\*?\]\] + reveal_type(result) + row = result.one() + + # EXPECTED_RE_TYPE: Tuple\[builtins.int\*?, builtins.str\*?\] + reveal_type(row) + + x, y = row + + # EXPECTED_RE_TYPE: builtins.int\*? + reveal_type(x) + + # EXPECTED_RE_TYPE: builtins.str\*? + reveal_type(y) + + +async def t_async_connection_execute_single() -> None: + result = (await async_connection.execute(single_stmt)).t + + # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[Tuple\[builtins.str\*?\]\] + reveal_type(result) + + row = result.one() + + # EXPECTED_RE_TYPE: Tuple\[builtins.str\*?\] + reveal_type(row) + + (x,) = row + + # EXPECTED_RE_TYPE: builtins.str\*? + reveal_type(x) + + +async def t_async_connection_scalar() -> None: + obj = await async_connection.scalar(single_stmt) + + # EXPECTED_RE_TYPE: Union\[builtins.str\*?, None\] + reveal_type(obj) + + +async def t_async_connection_scalars() -> None: + result = await async_connection.scalars(single_stmt) + + # EXPECTED_RE_TYPE: sqlalchemy.*ScalarResult\*?\[builtins.str\*?\] + reveal_type(result) + data = result.all() + + # EXPECTED_RE_TYPE: typing.Sequence\[builtins.str\*?\] + reveal_type(data) + + +async def t_async_session_execute_multi() -> None: + result = (await async_session.execute(multi_stmt)).t + + # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[Tuple\[builtins.int\*?, builtins.str\*?\]\] + reveal_type(result) + row = result.one() + + # EXPECTED_RE_TYPE: Tuple\[builtins.int\*?, builtins.str\*?\] + reveal_type(row) + + x, y = row + + # EXPECTED_RE_TYPE: builtins.int\*? + reveal_type(x) + + # EXPECTED_RE_TYPE: builtins.str\*? + reveal_type(y) + + +async def t_async_session_execute_single() -> None: + result = (await async_session.execute(single_stmt)).t + + # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[Tuple\[builtins.str\*?\]\] + reveal_type(result) + row = result.one() + + # EXPECTED_RE_TYPE: Tuple\[builtins.str\*?\] + reveal_type(row) + + (x,) = row + + # EXPECTED_RE_TYPE: builtins.str\*? + reveal_type(x) + + +async def t_async_session_scalar() -> None: + obj = await async_session.scalar(single_stmt) + + # EXPECTED_RE_TYPE: Union\[builtins.str\*?, None\] + reveal_type(obj) + + +async def t_async_session_scalars() -> None: + result = await async_session.scalars(single_stmt) + + # EXPECTED_RE_TYPE: sqlalchemy.*ScalarResult\*?\[builtins.str\*?\] + reveal_type(result) + data = result.all() + + # EXPECTED_RE_TYPE: typing.Sequence\[builtins.str\*?\] + reveal_type(data) + + +async def t_async_connection_stream_multi() -> None: + result = (await async_connection.stream(multi_stmt)).t + + # EXPECTED_RE_TYPE: sqlalchemy.*AsyncTupleResult\[Tuple\[builtins.int\*?, builtins.str\*?\]\] + reveal_type(result) + row = await result.one() + + # EXPECTED_RE_TYPE: Tuple\[builtins.int\*?, builtins.str\*?\] + reveal_type(row) + + x, y = row + + # EXPECTED_RE_TYPE: builtins.int\*? + reveal_type(x) + + # EXPECTED_RE_TYPE: builtins.str\*? + reveal_type(y) + + +async def t_async_connection_stream_single() -> None: + result = (await async_connection.stream(single_stmt)).t + + # EXPECTED_RE_TYPE: sqlalchemy.*AsyncTupleResult\[Tuple\[builtins.str\*?\]\] + reveal_type(result) + row = await result.one() + + # EXPECTED_RE_TYPE: Tuple\[builtins.str\*?\] + reveal_type(row) + + (x,) = row + + # EXPECTED_RE_TYPE: builtins.str\*? + reveal_type(x) + + +async def t_async_connection_stream_scalars() -> None: + result = await async_connection.stream_scalars(single_stmt) + + # EXPECTED_RE_TYPE: sqlalchemy.*AsyncScalarResult\*?\[builtins.str\*?\] + reveal_type(result) + data = await result.all() + + # EXPECTED_RE_TYPE: typing.Sequence\*?\[builtins.str\*?\] + reveal_type(data) + + +async def t_async_session_stream_multi() -> None: + result = (await async_session.stream(multi_stmt)).t + + # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[Tuple\[builtins.int\*?, builtins.str\*?\]\] + reveal_type(result) + row = await result.one() + + # EXPECTED_RE_TYPE: Tuple\[builtins.int\*?, builtins.str\*?\] + reveal_type(row) + + x, y = row + + # EXPECTED_RE_TYPE: builtins.int\*? + reveal_type(x) + + # EXPECTED_RE_TYPE: builtins.str\*? + reveal_type(y) + + +async def t_async_session_stream_single() -> None: + result = (await async_session.stream(single_stmt)).t + + # EXPECTED_RE_TYPE: sqlalchemy.*AsyncTupleResult\[Tuple\[builtins.str\*?\]\] + reveal_type(result) + row = await result.one() + + # EXPECTED_RE_TYPE: Tuple\[builtins.str\*?\] + reveal_type(row) + + (x,) = row + + # EXPECTED_RE_TYPE: builtins.str\*? + reveal_type(x) + + +async def t_async_session_stream_scalars() -> None: + result = await async_session.stream_scalars(single_stmt) + + # EXPECTED_RE_TYPE: sqlalchemy.*AsyncScalarResult\*?\[builtins.str\*?\] + reveal_type(result) + data = await result.all() + + # EXPECTED_RE_TYPE: typing.Sequence\*?\[builtins.str\*?\] + reveal_type(data) diff --git a/test/ext/mypy/plugin_files/dataclasses_workaround.py b/test/ext/mypy/plugin_files/dataclasses_workaround.py index 9928b5a335..8ad69dbd0f 100644 --- a/test/ext/mypy/plugin_files/dataclasses_workaround.py +++ b/test/ext/mypy/plugin_files/dataclasses_workaround.py @@ -66,5 +66,5 @@ class Address: _mypy_mapped_attrs = [id, user_id, email_address] -stmt = select(User.name).where(User.id.in_([1, 2, 3])) -stmt = select(Address).where(Address.email_address.contains(["foo"])) +stmt1 = select(User.name).where(User.id.in_([1, 2, 3])) +stmt2 = select(Address).where(Address.email_address.contains(["foo"])) diff --git a/test/ext/mypy/test_mypy_plugin_py3k.py b/test/ext/mypy/test_mypy_plugin_py3k.py index 3a932021de..1086f187af 100644 --- a/test/ext/mypy/test_mypy_plugin_py3k.py +++ b/test/ext/mypy/test_mypy_plugin_py3k.py @@ -233,6 +233,41 @@ class MypyPluginTest(fixtures.TestBase): expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(4)) if is_type: + if not is_re: + # the goal here is that we can cut-and-paste + # from vscode -> pylance into the + # EXPECTED_TYPE: line, then the test suite will + # validate that line against what mypy produces + expected_msg = re.sub( + r"([\[\]])", + lambda m: rf"\{m.group(0)}", + expected_msg, + ) + + # note making sure preceding text matches + # with a dot, so that an expect for "Select" + # does not match "TypedSelect" + expected_msg = re.sub( + r"([\w_]+)", + lambda m: rf"(?:.*\.)?{m.group(1)}\*?", + expected_msg, + ) + + expected_msg = re.sub( + "List", "builtins.list", expected_msg + ) + + expected_msg = re.sub( + r"(int|str|float|bool)", + lambda m: rf"builtins.{m.group(0)}\*?", + expected_msg, + ) + # expected_msg = re.sub( + # r"(Sequence|Tuple|List|Union)", + # lambda m: fr"typing.{m.group(0)}\*?", + # expected_msg, + # ) + is_mypy = is_re = True expected_msg = f'Revealed type is "{expected_msg}"' current_assert_messages.append( @@ -295,7 +330,7 @@ class MypyPluginTest(fixtures.TestBase): del output[idx] if output: - print("messages from mypy that were not consumed:") + print(f"{len(output)} messages from mypy were not consumed:") print("\n".join(msg for _, msg in output)) assert False, "errors and/or notes remain, see stdout" diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py index 9585da125b..01a8698a4a 100644 --- a/test/orm/test_froms.py +++ b/test/orm/test_froms.py @@ -492,7 +492,7 @@ class EntityFromSubqueryTest(QueryTest, AssertsCompiledSQL): assert_raises_message( sa_exc.ArgumentError, - "Column expression or FROM clause expected, got " + "Column expression, FROM clause, or other .* expected, got " " object resolved from " " object. To create a FROM clause from " "a object", diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 55414364c5..8374d05d4d 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -186,6 +186,14 @@ class OnlyReturnTuplesTest(QueryTest): assert isinstance(row, collections_abc.Sequence) assert isinstance(row._mapping, collections_abc.Mapping) + def test_single_entity_tuples(self): + User = self.classes.User + query = fixture_session().query(User).tuples() + is_false(query.is_single_entity) + row = query.first() + assert isinstance(row, collections_abc.Sequence) + assert isinstance(row._mapping, collections_abc.Mapping) + def test_multiple_entity_false(self): User = self.classes.User query = ( @@ -204,6 +212,14 @@ class OnlyReturnTuplesTest(QueryTest): assert isinstance(row, collections_abc.Sequence) assert isinstance(row._mapping, collections_abc.Mapping) + def test_multiple_entity_true(self): + User = self.classes.User + query = fixture_session().query(User.id, User).tuples() + is_false(query.is_single_entity) + row = query.first() + assert isinstance(row, collections_abc.Sequence) + assert isinstance(row._mapping, collections_abc.Mapping) + class RowTupleTest(QueryTest, AssertsCompiledSQL): run_setup_mappers = None diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index b175f96633..1a070fcf52 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -22,6 +22,7 @@ from sqlalchemy import Integer from sqlalchemy import MetaData from sqlalchemy import PrimaryKeyConstraint from sqlalchemy import schema +from sqlalchemy import select from sqlalchemy import Sequence from sqlalchemy import String from sqlalchemy import Table @@ -1337,6 +1338,20 @@ class ToMetaDataTest(fixtures.TestBase, AssertsCompiledSQL, ComparesTables): self._assert_fk(t2, None, "p.t1.x", referred_schema_fn=ref_fn) + def test_fk_get_referent_is_always_a_column(self): + """test the annotation on ForeignKey.get_referent() in that it does + in fact return Column even if given a labeled expr in a subquery""" + + m = MetaData() + a = Table("a", m, Column("id", Integer, primary_key=True)) + b = Table("b", m, Column("aid", Integer, ForeignKey("a.id"))) + + stmt = select(a.c.id.label("somelabel")).subquery() + + referent = list(b.c.aid.foreign_keys)[0].get_referent(stmt) + is_(referent, stmt.c.somelabel) + assert isinstance(referent, Column) + def test_copy_info(self): m = MetaData() fk = ForeignKey("t2.id") diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index ff70fc184a..cb9f930180 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -156,6 +156,50 @@ class CursorResultTest(fixtures.TablesTest): rows.append(row) eq_(len(rows), 3) + def test_scalars(self, connection): + users = self.tables.users + + connection.execute( + users.insert(), + [ + {"user_id": 7, "user_name": "jack"}, + {"user_id": 8, "user_name": "ed"}, + {"user_id": 9, "user_name": "fred"}, + ], + ) + r = connection.scalars(users.select().order_by(users.c.user_id)) + eq_(r.all(), [7, 8, 9]) + + def test_result_tuples(self, connection): + users = self.tables.users + + connection.execute( + users.insert(), + [ + {"user_id": 7, "user_name": "jack"}, + {"user_id": 8, "user_name": "ed"}, + {"user_id": 9, "user_name": "fred"}, + ], + ) + r = connection.execute( + users.select().order_by(users.c.user_id) + ).tuples() + eq_(r.all(), [(7, "jack"), (8, "ed"), (9, "fred")]) + + def test_row_tuple(self, connection): + users = self.tables.users + + connection.execute( + users.insert(), + [ + {"user_id": 7, "user_name": "jack"}, + {"user_id": 8, "user_name": "ed"}, + {"user_id": 9, "user_name": "fred"}, + ], + ) + r = connection.execute(users.select().order_by(users.c.user_id)) + eq_([row.t for row in r], [(7, "jack"), (8, "ed"), (9, "fred")]) + def test_row_next(self, connection): users = self.tables.users diff --git a/test/sql/test_select.py b/test/sql/test_select.py index be64e205e4..d91e50e637 100644 --- a/test/sql/test_select.py +++ b/test/sql/test_select.py @@ -61,7 +61,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): def test_old_bracket_style_fail(self): with expect_raises_message( exc.ArgumentError, - r"Column expression or FROM clause expected, " + r"Column expression, FROM clause, or other columns clause .*" r".*Did you mean to say", ): select([table1.c.myid]) diff --git a/test/sql/test_update.py b/test/sql/test_update.py index 619cbd863b..e93900bbda 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -25,6 +25,7 @@ from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock from sqlalchemy.testing.schema import Column @@ -1016,6 +1017,46 @@ class UpdateFromCompileTest( dialect="mysql", ) + def test_update_from_join_unsupported_cases(self): + """ + found_during_typing + + It's unclear how to cleanly guard against this case without producing + false positives, particularly due to the support for UPDATE + of a CTE. I'm also not sure of the nature of the failure and why + it happens this way. + + """ + users, addresses = self.tables.users, self.tables.addresses + + j = users.join(addresses) + + with expect_raises_message( + exc.CompileError, + r"Encountered unsupported case when compiling an INSERT or UPDATE " + r"statement. If this is a multi-table " + r"UPDATE statement, please provide string-named arguments to the " + r"values\(\) method with distinct names; support for multi-table " + r"UPDATE statements that " + r"target multiple tables for UPDATE is very limited", + ): + update(j).where(addresses.c.email_address == "e1").values( + {users.c.id: 10, addresses.c.email_address: "asdf"} + ).compile(dialect=mysql.dialect()) + + with expect_raises_message( + exc.CompileError, + r"Encountered unsupported case when compiling an INSERT or UPDATE " + r"statement. If this is a multi-table " + r"UPDATE statement, please provide string-named arguments to the " + r"values\(\) method with distinct names; support for multi-table " + r"UPDATE statements that " + r"target multiple tables for UPDATE is very limited", + ): + update(j).where(addresses.c.email_address == "e1").compile( + dialect=mysql.dialect() + ) + def test_update_from_join_mysql_whereclause(self): users, addresses = self.tables.users, self.tables.addresses diff --git a/tools/generate_proxy_methods.py b/tools/generate_proxy_methods.py index ffc470972f..91a8918824 100644 --- a/tools/generate_proxy_methods.py +++ b/tools/generate_proxy_methods.py @@ -31,6 +31,12 @@ A similar approach is used in Alembic where a dynamic approach towards creating alembic "ops" was enhanced to generate a .pyi stubs file statically for consumption by typing tools. +Note that the usual OO approach of having a common interface class with +concrete subtypes doesn't really solve any problems here; the concrete subtypes +must still list out all methods, arguments, typing annotations, and docstrings, +all of which is copied by this script rather than requiring it all be +typed by hand. + .. versionadded:: 2.0 """ @@ -43,9 +49,7 @@ import inspect import os from pathlib import Path import re -import shlex import shutil -import subprocess import sys from tempfile import NamedTemporaryFile import textwrap @@ -61,6 +65,7 @@ from typing import TypeVar from sqlalchemy import util from sqlalchemy.util import compat from sqlalchemy.util import langhelpers +from sqlalchemy.util.langhelpers import console_scripts from sqlalchemy.util.langhelpers import format_argspec_plus from sqlalchemy.util.langhelpers import inject_docstring_text @@ -122,6 +127,47 @@ def create_proxy_methods( return decorate +def _grab_overloads(fn): + """grab @overload entries for a function, assuming black-formatted + code ;) so that we can do a simple regex + + """ + + # functions that use @util.deprecated and whatnot will have a string + # generated fn. we can look at __wrapped__ but these functions don't + # have any overloads in any case right now so skip + if fn.__code__.co_filename == "": + return [] + + with open(fn.__code__.co_filename) as f: + lines = [l for i, l in zip(range(fn.__code__.co_firstlineno), f)] + + lines.reverse() + + output = [] + + current_ov = [] + for line in lines[1:]: + current_ov.append(line) + outside_block_match = re.match(r"^\w", line) + if outside_block_match: + current_ov[:] = [] + break + + fn_match = re.match(rf"^ (?:async )?def (.*)\($", line) + if fn_match and fn_match.group(1) != fn.__name__: + current_ov[:] = [] + break + + ov_match = re.match(r"^ @overload$", line) + if ov_match: + output.append("".join(reversed(current_ov))) + current_ov[:] = [] + + output.reverse() + return output + + def process_class( buf: TextIO, target_cls: Type[Any], @@ -145,6 +191,12 @@ def process_class( def instrument(buf: TextIO, name: str, clslevel: bool = False) -> None: fn = getattr(target_cls, name) + + overloads = _grab_overloads(fn) + + for overload in overloads: + buf.write(overload) + spec = compat.inspect_getfullargspec(fn) iscoroutine = inspect.iscoroutinefunction(fn) @@ -311,7 +363,7 @@ def process_module(modname: str, filename: str) -> str: "\n # code within this block is " "**programmatically, \n" " # statically generated** by" - " tools/generate_proxy_methods.py\n\n" + f" tools/{os.path.basename(__file__)}\n\n" ) process_class(buf, *args) @@ -323,41 +375,6 @@ def process_module(modname: str, filename: str) -> str: return buf.name -def console_scripts( - path: str, options: dict, ignore_output: bool = False -) -> None: - - entrypoint_name = options["entrypoint"] - - for entry in compat.importlib_metadata_get("console_scripts"): - if entry.name == entrypoint_name: - impl = entry - break - else: - raise Exception( - f"Could not find entrypoint console_scripts.{entrypoint_name}" - ) - cmdline_options_str = options.get("options", "") - cmdline_options_list = shlex.split(cmdline_options_str, posix=is_posix) + [ - path - ] - - kw = {} - if ignore_output: - kw["stdout"] = kw["stderr"] = subprocess.DEVNULL - - subprocess.run( - [ - sys.executable, - "-c", - "import %s; %s.%s()" % (impl.module, impl.module, impl.attr), - ] - + cmdline_options_list, - cwd=Path(__file__).parent.parent, - **kw, - ) - - def run_module(modname, stdout): sys.stderr.write(f"importing module {modname}\n") diff --git a/tools/generate_tuple_map_overloads.py b/tools/generate_tuple_map_overloads.py new file mode 100644 index 0000000000..ff8f37840a --- /dev/null +++ b/tools/generate_tuple_map_overloads.py @@ -0,0 +1,174 @@ +r"""Generate tuple mapping overloads. + +the problem solved by this script is that of there's no way in current +pep-484 typing to unpack \*args: _T into Tuple[_T]. pep-646 is the first +pep to provide this, but it doesn't work for the actual Tuple class +and also mypy does not have support for pep-646 as of yet. Better pep-646 +support would allow us to use a TypeVarTuple with Unpack, but TypeVarTuple +does not have support for sequence operations like ``__getitem__`` and +iteration; there's also no way for TypeVarTuple to be translated back to a +Tuple which does have those things without a combinatoric hardcoding approach +to each length of tuple. + +So here, the script creates a map from `*args` to a Tuple directly using a +combinatoric generated code approach. + +.. versionadded:: 2.0 + +""" +from __future__ import annotations + +from argparse import ArgumentParser +import importlib +import os +from pathlib import Path +import re +import shutil +import sys +from tempfile import NamedTemporaryFile +import textwrap + +from sqlalchemy.util.langhelpers import console_scripts + +is_posix = os.name == "posix" + + +sys.path.append(str(Path(__file__).parent.parent)) + + +def process_module(modname: str, filename: str) -> str: + + # use tempfile in same path as the module, or at least in the + # current working directory, so that black / zimports use + # local pyproject.toml + with NamedTemporaryFile( + mode="w", delete=False, suffix=".py", dir=Path(filename).parent + ) as buf, open(filename) as orig_py: + indent = "" + in_block = False + current_fnname = given_fnname = None + for line in orig_py: + m = re.match( + r"^( *)# START OVERLOADED FUNCTIONS ([\.\w_]+) ([\w_]+) (\d+)-(\d+)$", # noqa: E501 + line, + ) + if m: + indent = m.group(1) + given_fnname = current_fnname = m.group(2) + if current_fnname.startswith("self."): + use_self = True + current_fnname = current_fnname.split(".")[1] + else: + use_self = False + return_type = m.group(3) + start_index = int(m.group(4)) + end_index = int(m.group(5)) + + sys.stderr.write( + f"Generating {start_index}-{end_index} overloads " + f"attributes for " + f"class {'self.' if use_self else ''}{current_fnname} " + f"-> {return_type}\n" + ) + in_block = True + buf.write(line) + buf.write( + "\n # code within this block is " + "**programmatically, \n" + " # statically generated** by" + f" tools/{os.path.basename(__file__)}\n\n" + ) + + for num_args in range(start_index, end_index + 1): + combinations = [ + [ + f"__ent{arg}: _TCCA[_T{arg}]" + for arg in range(num_args) + ] + ] + for combination in combinations: + buf.write( + textwrap.indent( + f""" +@overload +def {current_fnname}( + {'self, ' if use_self else ''}{", ".join(combination)} +) -> {return_type}[Tuple[{', '.join(f'_T{i}' for i in range(num_args))}]]: + ... + +""", # noqa: E501 + indent, + ) + ) + + if in_block and line.startswith( + f"{indent}# END OVERLOADED FUNCTIONS {given_fnname}" + ): + in_block = False + + if not in_block: + buf.write(line) + return buf.name + + +def run_module(modname, stdout): + + sys.stderr.write(f"importing module {modname}\n") + mod = importlib.import_module(modname) + filename = destination_path = mod.__file__ + assert filename is not None + + tempfile = process_module(modname, filename) + + ignore_output = stdout + + console_scripts( + str(tempfile), + {"entrypoint": "zimports"}, + ignore_output=ignore_output, + ) + + console_scripts( + str(tempfile), + {"entrypoint": "black"}, + ignore_output=ignore_output, + ) + + if stdout: + with open(tempfile) as tf: + print(tf.read()) + os.unlink(tempfile) + else: + sys.stderr.write(f"Writing {destination_path}...\n") + shutil.move(tempfile, destination_path) + + +def main(args): + for modname in entries: + if args.module in {"all", modname}: + run_module(modname, args.stdout) + + +entries = [ + "sqlalchemy.sql._selectable_constructors", + "sqlalchemy.orm.session", + "sqlalchemy.orm.query", + "sqlalchemy.sql.selectable", + "sqlalchemy.sql.dml", +] + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument( + "--module", + choices=entries + ["all"], + default="all", + help="Which file to generate. Default is to regenerate all files", + ) + parser.add_argument( + "--stdout", + action="store_true", + help="Write to stdout instead of saving to file", + ) + args = parser.parse_args() + main(args) -- 2.47.2