git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
pep484 ORM / SQL result support
author    Mike Bayer <mike_mp@zzzcomputing.com>
          Wed, 20 Apr 2022 01:06:41 +0000 (21:06 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
          Wed, 27 Apr 2022 18:46:36 +0000 (14:46 -0400)
After some experimentation it seems mypy is more amenable
to the generic types being fully integrated rather than
having separate spin-off types, so key structures
like Result, Row, and Select become generic.  For DML,
Insert, Update, and Delete are spun into type-specific
subclasses ReturningInsert, ReturningUpdate, and
ReturningDelete, which is fine since the "row-ness" of
these constructs doesn't come into play until returning()
is called in any case.
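
As a rough illustration (not part of this patch), the intended
typing behavior, assuming a declaratively mapped `User` class
with typed `id: Mapped[int]` and `name: Mapped[str]` attributes
(class definition elided here):

    from sqlalchemy import insert, select, update

    # select() is now generic over the row type
    stmt = select(User.id, User.name)
    reveal_type(stmt)  # expected: Select[Tuple[int, str]]

    # DML constructs gain row typing only once returning() is
    # called, at which point the type-specific subclasses apply
    ins = insert(User).returning(User.id, User.name)
    reveal_type(ins)   # expected: ReturningInsert[Tuple[int, str]]

    upd = update(User).returning(User.id)
    reveal_type(upd)   # expected: ReturningUpdate[Tuple[int]]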

A Tuple-based model is then integrated so that these
objects can carry along information about their return
types.  Overloads at the .execute() level carry the Tuple
through from the invoked object to the result.
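
For example (a sketch under the same assumed `User` mapping,
with an `engine` assumed as well), the row type flows from the
statement through Connection.execute() into the Result and Row:

    with engine.connect() as conn:
        result = conn.execute(select(User.id, User.name))
        # reveal_type(result)  # expected: CursorResult[Tuple[int, str]]

        for row in result:
            # reveal_type(row)  # expected: Row[Tuple[int, str]]
            uid, uname = row.t  # typed unpacking via the Tuple form

        # scalar-returning overloads narrow as well
        name = conn.scalar(select(User.name))  # expected: Optional[str]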

To address the issue of AliasedClass generating attributes
dynamically, experimented with a custom subclass AsAliased,
but then settled on having aliased() essentially lie to the
type checker and return `Type[_O]`.  Some type-related
accessors will also be needed for with_polymorphic().
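
E.g. (sketch only), an alias participates in typed select()
the same way the class itself does, as far as the type
checker can tell:

    from sqlalchemy.orm import aliased

    ua = aliased(User)
    stmt = select(ua.id, ua.name)  # typed as Select[Tuple[int, str]],
                                   # same as select(User.id, User.name)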

Additionally, identified an issue in Update when used
"MySQL style" against a join(): it basically doesn't work
if asked to UPDATE two tables on the same column name.
Added an error message for the specific condition where
this happens; the message is deliberately non-specific,
saying only that we've hit something we can't do right now
and suggesting the multi-table UPDATE as a possible cause.
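
Roughly, the condition looks like the following (table names
hypothetical): both tables carry an "id" column and both are
assigned in values(), which can't currently be handled:

    j = user_table.join(
        address_table, user_table.c.id == address_table.c.user_id
    )
    stmt = update(j).values(
        {user_table.c.id: 5, address_table.c.id: 10}
    )
    # compiling a statement like this is now expected to raise
    # the new (deliberately broad) error rather than emitting a
    # wrong UPDATE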

Change-Id: I5eff7eefe1d6166ee74160b2785c5e6a81fa8b95

65 files changed:
lib/sqlalchemy/engine/__init__.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/events.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/engine/result.py
lib/sqlalchemy/engine/row.py
lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/ext/asyncio/result.py
lib/sqlalchemy/ext/asyncio/scoping.py
lib/sqlalchemy/ext/asyncio/session.py
lib/sqlalchemy/ext/instrumentation.py
lib/sqlalchemy/orm/_orm_constructors.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/base.py
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/scoping.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/state.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/__init__.py
lib/sqlalchemy/sql/_selectable_constructors.py
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/crud.py
lib/sqlalchemy/sql/dml.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/sql/roles.py
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/sql/visitors.py
lib/sqlalchemy/util/langhelpers.py
lib/sqlalchemy/util/typing.py
pyproject.toml
test/base/test_result.py
test/ext/mypy/plain_files/association_proxy_one.py
test/ext/mypy/plain_files/engine_inspection.py
test/ext/mypy/plain_files/experimental_relationship.py
test/ext/mypy/plain_files/hybrid_one.py
test/ext/mypy/plain_files/hybrid_two.py
test/ext/mypy/plain_files/session.py
test/ext/mypy/plain_files/sql_operations.py
test/ext/mypy/plain_files/trad_relationship_uselist.py
test/ext/mypy/plain_files/traditional_relationship.py
test/ext/mypy/plain_files/typed_queries.py [new file with mode: 0644]
test/ext/mypy/plain_files/typed_results.py [new file with mode: 0644]
test/ext/mypy/plugin_files/dataclasses_workaround.py
test/ext/mypy/test_mypy_plugin_py3k.py
test/orm/test_froms.py
test/orm/test_query.py
test/sql/test_metadata.py
test/sql/test_resultset.py
test/sql/test_select.py
test/sql/test_update.py
tools/generate_proxy_methods.py
tools/generate_tuple_map_overloads.py [new file with mode: 0644]

index 29dd6aff90489e1cad37541af1803dd01d81405b..afba1707595f1fe136db5e858ac3a8247910fd4b 100644 (file)
@@ -46,6 +46,7 @@ from .result import MergedResult as MergedResult
 from .result import Result as Result
 from .result import result_tuple as result_tuple
 from .result import ScalarResult as ScalarResult
+from .result import TupleResult as TupleResult
 from .row import BaseRow as BaseRow
 from .row import Row as Row
 from .row import RowMapping as RowMapping
index a325da929b73ef6e163f24c5b04ded22d5b6fea8..fe3bfa1adfcc1c9ffd62b51afa665d7bb1679b46 100644 (file)
@@ -18,8 +18,10 @@ from typing import Mapping
 from typing import MutableMapping
 from typing import NoReturn
 from typing import Optional
+from typing import overload
 from typing import Tuple
 from typing import Type
+from typing import TypeVar
 from typing import Union
 
 from .interfaces import _IsolationLevel
@@ -45,12 +47,10 @@ if typing.TYPE_CHECKING:
     from . import ScalarResult
     from .interfaces import _AnyExecuteParams
     from .interfaces import _AnyMultiExecuteParams
-    from .interfaces import _AnySingleExecuteParams
     from .interfaces import _CoreAnyExecuteParams
     from .interfaces import _CoreMultiExecuteParams
     from .interfaces import _CoreSingleExecuteParams
     from .interfaces import _DBAPIAnyExecuteParams
-    from .interfaces import _DBAPIMultiExecuteParams
     from .interfaces import _DBAPISingleExecuteParams
     from .interfaces import _ExecuteOptions
     from .interfaces import _ExecuteOptionsParameter
@@ -65,21 +65,21 @@ if typing.TYPE_CHECKING:
     from ..pool import PoolProxiedConnection
     from ..sql import Executable
     from ..sql._typing import _InfoType
-    from ..sql.base import SchemaVisitor
     from ..sql.compiler import Compiled
     from ..sql.ddl import ExecutableDDLElement
     from ..sql.ddl import SchemaDropper
     from ..sql.ddl import SchemaGenerator
     from ..sql.functions import FunctionElement
-    from ..sql.schema import ColumnDefault
     from ..sql.schema import DefaultGenerator
     from ..sql.schema import HasSchemaAttr
     from ..sql.schema import SchemaItem
+    from ..sql.selectable import TypedReturnsRows
 
 """Defines :class:`_engine.Connection` and :class:`_engine.Engine`.
 
 """
 
+_T = TypeVar("_T", bound=Any)
 _EMPTY_EXECUTION_OPTS: _ExecuteOptions = util.EMPTY_DICT
 NO_OPTIONS: Mapping[str, Any] = util.EMPTY_DICT
 
@@ -1142,10 +1142,31 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
             self._dbapi_connection = None
         self.__can_reconnect = False
 
+    @overload
+    def scalar(
+        self,
+        statement: TypedReturnsRows[Tuple[_T]],
+        parameters: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> Optional[_T]:
+        ...
+
+    @overload
+    def scalar(
+        self,
+        statement: Executable,
+        parameters: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> Any:
+        ...
+
     def scalar(
         self,
         statement: Executable,
         parameters: Optional[_CoreSingleExecuteParams] = None,
+        *,
         execution_options: Optional[_ExecuteOptionsParameter] = None,
     ) -> Any:
         r"""Executes a SQL statement construct and returns a scalar object.
@@ -1170,10 +1191,31 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
                 execution_options or NO_OPTIONS,
             )
 
+    @overload
+    def scalars(
+        self,
+        statement: TypedReturnsRows[Tuple[_T]],
+        parameters: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> ScalarResult[_T]:
+        ...
+
+    @overload
     def scalars(
         self,
         statement: Executable,
         parameters: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> ScalarResult[Any]:
+        ...
+
+    def scalars(
+        self,
+        statement: Executable,
+        parameters: Optional[_CoreSingleExecuteParams] = None,
+        *,
         execution_options: Optional[_ExecuteOptionsParameter] = None,
     ) -> ScalarResult[Any]:
         """Executes and returns a scalar result set, which yields scalar values
@@ -1190,14 +1232,37 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
 
         """
 
-        return self.execute(statement, parameters, execution_options).scalars()
+        return self.execute(
+            statement, parameters, execution_options=execution_options
+        ).scalars()
+
+    @overload
+    def execute(
+        self,
+        statement: TypedReturnsRows[_T],
+        parameters: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> CursorResult[_T]:
+        ...
+
+    @overload
+    def execute(
+        self,
+        statement: Executable,
+        parameters: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> CursorResult[Any]:
+        ...
 
     def execute(
         self,
         statement: Executable,
         parameters: Optional[_CoreAnyExecuteParams] = None,
+        *,
         execution_options: Optional[_ExecuteOptionsParameter] = None,
-    ) -> CursorResult:
+    ) -> CursorResult[Any]:
         r"""Executes a SQL statement construct and returns a
         :class:`_engine.CursorResult`.
 
@@ -1246,7 +1311,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         func: FunctionElement[Any],
         distilled_parameters: _CoreMultiExecuteParams,
         execution_options: _ExecuteOptionsParameter,
-    ) -> CursorResult:
+    ) -> CursorResult[Any]:
         """Execute a sql.FunctionElement object."""
 
         return self._execute_clauseelement(
@@ -1317,7 +1382,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         ddl: ExecutableDDLElement,
         distilled_parameters: _CoreMultiExecuteParams,
         execution_options: _ExecuteOptionsParameter,
-    ) -> CursorResult:
+    ) -> CursorResult[Any]:
         """Execute a schema.DDL object."""
 
         execution_options = ddl._execution_options.merge_with(
@@ -1414,7 +1479,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         elem: Executable,
         distilled_parameters: _CoreMultiExecuteParams,
         execution_options: _ExecuteOptionsParameter,
-    ) -> CursorResult:
+    ) -> CursorResult[Any]:
         """Execute a sql.ClauseElement object."""
 
         execution_options = elem._execution_options.merge_with(
@@ -1487,7 +1552,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         compiled: Compiled,
         distilled_parameters: _CoreMultiExecuteParams,
         execution_options: _ExecuteOptionsParameter = _EMPTY_EXECUTION_OPTS,
-    ) -> CursorResult:
+    ) -> CursorResult[Any]:
         """Execute a sql.Compiled object.
 
         TODO: why do we have this?   likely deprecate or remove
@@ -1537,7 +1602,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         statement: str,
         parameters: Optional[_DBAPIAnyExecuteParams] = None,
         execution_options: Optional[_ExecuteOptionsParameter] = None,
-    ) -> CursorResult:
+    ) -> CursorResult[Any]:
         r"""Executes a SQL statement construct and returns a
         :class:`_engine.CursorResult`.
 
@@ -1614,7 +1679,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         execution_options: _ExecuteOptions,
         *args: Any,
         **kw: Any,
-    ) -> CursorResult:
+    ) -> CursorResult[Any]:
         """Create an :class:`.ExecutionContext` and execute, returning
         a :class:`_engine.CursorResult`."""
 
index ccf5736756d7c289e065852386fd5cf701d9d60d..ff69666b71066f0ecd32e8a04f308083a2976492 100644 (file)
@@ -24,6 +24,7 @@ from typing import Optional
 from typing import Sequence
 from typing import Tuple
 from typing import TYPE_CHECKING
+from typing import TypeVar
 from typing import Union
 
 from .result import MergedResult
@@ -55,11 +56,12 @@ if typing.TYPE_CHECKING:
     from .interfaces import ExecutionContext
     from .result import _KeyIndexType
     from .result import _KeyMapRecType
-    from .result import _KeyMapType
     from .result import _KeyType
     from .result import _ProcessorsType
     from ..sql.type_api import _ResultProcessorType
 
+_T = TypeVar("_T", bound=Any)
+
 # metadata entry tuple indexes.
 # using raw tuple is faster than namedtuple.
 MD_INDEX: Literal[0] = 0  # integer index in cursor.description
@@ -214,7 +216,9 @@ class CursorResultMetaData(ResultMetaData):
         return md
 
     def __init__(
-        self, parent: CursorResult, cursor_description: _DBAPICursorDescription
+        self,
+        parent: CursorResult[Any],
+        cursor_description: _DBAPICursorDescription,
     ):
         context = parent.context
         self._tuplefilter = None
@@ -1158,7 +1162,7 @@ class _NoResultMetaData(ResultMetaData):
 _NO_RESULT_METADATA = _NoResultMetaData()
 
 
-class CursorResult(Result):
+class CursorResult(Result[_T]):
     """A Result that is representing state from a DBAPI cursor.
 
     .. versionchanged:: 1.4  The :class:`.CursorResult``
@@ -1179,6 +1183,15 @@ class CursorResult(Result):
 
     """
 
+    __slots__ = (
+        "context",
+        "dialect",
+        "cursor",
+        "cursor_strategy",
+        "_echo",
+        "connection",
+    )
+
     _metadata: Union[CursorResultMetaData, _NoResultMetaData]
     _no_result_metadata = _NO_RESULT_METADATA
     _soft_closed: bool = False
@@ -1231,7 +1244,6 @@ class CursorResult(Result):
                 make_row = _make_row_2
             else:
                 make_row = _make_row
-
             self._set_memoized_attribute("_row_getter", make_row)
 
         else:
@@ -1726,12 +1738,12 @@ class CursorResult(Result):
     def _raw_row_iterator(self):
         return self._fetchiter_impl()
 
-    def merge(self, *others: Result) -> MergedResult:
+    def merge(self, *others: Result[Any]) -> MergedResult[Any]:
         merged_result = super().merge(*others)
         setup_rowcounts = not self._metadata.returns_rows
         if setup_rowcounts:
             merged_result.rowcount = sum(
-                cast(CursorResult, result).rowcount
+                cast("CursorResult[Any]", result).rowcount
                 for result in (self,) + others
             )
         return merged_result
index c6571f68bb92e1b6608e176119476abb48c7a37e..9c6ff758fc6c2e1879d206fbbbaeb1a62390a2b5 100644 (file)
@@ -62,13 +62,9 @@ if typing.TYPE_CHECKING:
 
     from .base import Connection
     from .base import Engine
-    from .characteristics import ConnectionCharacteristic
-    from .interfaces import _AnyMultiExecuteParams
     from .interfaces import _CoreMultiExecuteParams
     from .interfaces import _CoreSingleExecuteParams
-    from .interfaces import _DBAPIAnyExecuteParams
     from .interfaces import _DBAPIMultiExecuteParams
-    from .interfaces import _DBAPISingleExecuteParams
     from .interfaces import _ExecuteOptions
     from .interfaces import _IsolationLevel
     from .interfaces import _MutableCoreSingleExecuteParams
@@ -83,15 +79,11 @@ if typing.TYPE_CHECKING:
     from ..sql.compiler import Compiled
     from ..sql.compiler import Linting
     from ..sql.compiler import ResultColumnsEntry
-    from ..sql.compiler import TypeCompiler
     from ..sql.dml import DMLState
     from ..sql.dml import UpdateBase
     from ..sql.elements import BindParameter
-    from ..sql.roles import ColumnsClauseRole
     from ..sql.schema import Column
-    from ..sql.schema import ColumnDefault
     from ..sql.type_api import _BindProcessorType
-    from ..sql.type_api import _ResultProcessorType
     from ..sql.type_api import TypeEngine
 
 # When we're handed literal SQL, ensure it's a SELECT query
@@ -781,7 +773,7 @@ class DefaultExecutionContext(ExecutionContext):
     result_column_struct: Optional[
         Tuple[List[ResultColumnsEntry], bool, bool, bool]
     ] = None
-    returned_default_rows: Optional[List[Row]] = None
+    returned_default_rows: Optional[Sequence[Row[Any]]] = None
 
     execution_options: _ExecuteOptions = util.EMPTY_DICT
 
@@ -1385,7 +1377,9 @@ class DefaultExecutionContext(ExecutionContext):
         if cursor_description is None:
             strategy = _cursor._NO_CURSOR_DML
 
-        result = _cursor.CursorResult(self, strategy, cursor_description)
+        result: _cursor.CursorResult[Any] = _cursor.CursorResult(
+            self, strategy, cursor_description
+        )
 
         if self.isinsert:
             if self._is_implicit_returning:
index ef10946a86c077c5d5733783c0b412ddd2f1a126..4093d3e0e79112b9b54767299b2ab28dd3d087a5 100644 (file)
@@ -28,7 +28,6 @@ from ..util.typing import Literal
 
 if typing.TYPE_CHECKING:
     from .base import Connection
-    from .interfaces import _CoreAnyExecuteParams
     from .interfaces import _CoreMultiExecuteParams
     from .interfaces import _CoreSingleExecuteParams
     from .interfaces import _DBAPIAnyExecuteParams
@@ -273,7 +272,7 @@ class ConnectionEvents(event.Events[ConnectionEventsTarget]):
         multiparams: _CoreMultiExecuteParams,
         params: _CoreSingleExecuteParams,
         execution_options: _ExecuteOptions,
-        result: Result,
+        result: Result[Any],
     ) -> None:
         """Intercept high level execute() events after execute.
 
index 54fe21d747afed32d64a0636748b59d948fe5ce9..6410246039fd157f074cf8b1ee1188621c6ea9bc 100644 (file)
@@ -2422,7 +2422,7 @@ class ExecutionContext:
     def _get_cache_stats(self) -> str:
         raise NotImplementedError()
 
-    def _setup_result_proxy(self) -> CursorResult:
+    def _setup_result_proxy(self) -> CursorResult[Any]:
         raise NotImplementedError()
 
     def fire_sequence(self, seq: Sequence_SchemaItem, type_: Integer) -> int:
index 71320a583dc92741df71eedf950c4b850aac82be..55d36a1d5bd6a448e074764367b7e8960d46d952 100644 (file)
@@ -28,6 +28,7 @@ from typing import overload
 from typing import Sequence
 from typing import Set
 from typing import Tuple
+from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
@@ -70,6 +71,8 @@ _RawRowType = Tuple[Any, ...]
 """represents the kind of row we get from a DBAPI cursor"""
 
 _R = TypeVar("_R", bound=_RowData)
+_T = TypeVar("_T", bound=Any)
+_TP = TypeVar("_TP", bound=Tuple[Any, ...])
 
 _InterimRowType = Union[_R, _RawRowType]
 """a catchall "anything" kind of return type that can be applied
@@ -141,7 +144,7 @@ class ResultMetaData:
 
     def _getter(
         self, key: Any, raiseerr: bool = True
-    ) -> Optional[Callable[[Row], Any]]:
+    ) -> Optional[Callable[[Row[Any]], Any]]:
 
         index = self._index_for_key(key, raiseerr)
 
@@ -270,7 +273,7 @@ class SimpleResultMetaData(ResultMetaData):
             _tuplefilter=_tuplefilter,
         )
 
-    def _contains(self, value: Any, row: Row) -> bool:
+    def _contains(self, value: Any, row: Row[Any]) -> bool:
         return value in row._data
 
     def _index_for_key(self, key: Any, raiseerr: bool = True) -> int:
@@ -335,7 +338,7 @@ class SimpleResultMetaData(ResultMetaData):
 
 def result_tuple(
     fields: Sequence[str], extra: Optional[Any] = None
-) -> Callable[[Iterable[Any]], Row]:
+) -> Callable[[Iterable[Any]], Row[Any]]:
     parent = SimpleResultMetaData(fields, extra)
     return functools.partial(
         Row, parent, parent._processors, parent._keymap, Row._default_key_style
@@ -355,7 +358,9 @@ SelfResultInternal = TypeVar("SelfResultInternal", bound="ResultInternal[Any]")
 
 
 class ResultInternal(InPlaceGenerative, Generic[_R]):
-    _real_result: Optional[Result] = None
+    __slots__ = ()
+
+    _real_result: Optional[Result[Any]] = None
     _generate_rows: bool = True
     _row_logging_fn: Optional[Callable[[Any], Any]]
 
@@ -367,20 +372,20 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
 
     _source_supports_scalars: bool
 
-    def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row]]:
+    def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row[Any]]]:
         raise NotImplementedError()
 
     def _fetchone_impl(
         self, hard_close: bool = False
-    ) -> Optional[_InterimRowType[Row]]:
+    ) -> Optional[_InterimRowType[Row[Any]]]:
         raise NotImplementedError()
 
     def _fetchmany_impl(
         self, size: Optional[int] = None
-    ) -> List[_InterimRowType[Row]]:
+    ) -> List[_InterimRowType[Row[Any]]]:
         raise NotImplementedError()
 
-    def _fetchall_impl(self) -> List[_InterimRowType[Row]]:
+    def _fetchall_impl(self) -> List[_InterimRowType[Row[Any]]]:
         raise NotImplementedError()
 
     def _soft_close(self, hard: bool = False) -> None:
@@ -388,8 +393,10 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
 
     @HasMemoized_ro_memoized_attribute
     def _row_getter(self) -> Optional[Callable[..., _R]]:
-        real_result: Result = (
-            self._real_result if self._real_result else cast(Result, self)
+        real_result: Result[Any] = (
+            self._real_result
+            if self._real_result
+            else cast("Result[Any]", self)
         )
 
         if real_result._source_supports_scalars:
@@ -404,7 +411,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
                     keymap: _KeyMapType,
                     key_style: Any,
                     scalar_obj: Any,
-                ) -> Row:
+                ) -> Row[Any]:
                     return _proc(
                         metadata, processors, keymap, key_style, (scalar_obj,)
                     )
@@ -429,7 +436,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
 
             fixed_tf = tf
 
-            def make_row(row: _InterimRowType[Row]) -> _R:
+            def make_row(row: _InterimRowType[Row[Any]]) -> _R:
                 return _make_row_orig(fixed_tf(row))
 
         else:
@@ -447,7 +454,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
         if fns:
             _make_row = make_row
 
-            def make_row(row: _InterimRowType[Row]) -> _R:
+            def make_row(row: _InterimRowType[Row[Any]]) -> _R:
                 interim_row = _make_row(row)
                 for fn in fns:
                     interim_row = fn(interim_row)
@@ -465,7 +472,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
         if self._unique_filter_state:
             uniques, strategy = self._unique_strategy
 
-            def iterrows(self: Result) -> Iterator[_R]:
+            def iterrows(self: Result[Any]) -> Iterator[_R]:
                 for raw_row in self._fetchiter_impl():
                     obj: _InterimRowType[Any] = (
                         make_row(raw_row) if make_row else raw_row
@@ -480,7 +487,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
 
         else:
 
-            def iterrows(self: Result) -> Iterator[_R]:
+            def iterrows(self: Result[Any]) -> Iterator[_R]:
                 for raw_row in self._fetchiter_impl():
                     row: _InterimRowType[Any] = (
                         make_row(raw_row) if make_row else raw_row
@@ -546,7 +553,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
         if self._unique_filter_state:
             uniques, strategy = self._unique_strategy
 
-            def onerow(self: Result) -> Union[_NoRow, _R]:
+            def onerow(self: Result[Any]) -> Union[_NoRow, _R]:
                 _onerow = self._fetchone_impl
                 while True:
                     row = _onerow()
@@ -567,7 +574,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
 
         else:
 
-            def onerow(self: Result) -> Union[_NoRow, _R]:
+            def onerow(self: Result[Any]) -> Union[_NoRow, _R]:
                 row = self._fetchone_impl()
                 if row is None:
                     return _NO_ROW
@@ -627,7 +634,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
                     real_result = (
                         self._real_result
                         if self._real_result
-                        else cast(Result, self)
+                        else cast("Result[Any]", self)
                     )
                     if real_result._yield_per:
                         num_required = num = real_result._yield_per
@@ -667,7 +674,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
                     real_result = (
                         self._real_result
                         if self._real_result
-                        else cast(Result, self)
+                        else cast("Result[Any]", self)
                     )
                     num = real_result._yield_per
 
@@ -799,7 +806,9 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
         self: SelfResultInternal, indexes: Sequence[_KeyIndexType]
     ) -> SelfResultInternal:
         real_result = (
-            self._real_result if self._real_result else cast(Result, self)
+            self._real_result
+            if self._real_result
+            else cast("Result[Any]", self)
         )
 
         if not real_result._source_supports_scalars or len(indexes) != 1:
@@ -817,7 +826,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
         real_result = (
             self._real_result
             if self._real_result is not None
-            else cast(Result, self)
+            else cast("Result[Any]", self)
         )
 
         if not strategy and self._metadata._unique_filters:
@@ -836,6 +845,8 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
 
 
 class _WithKeys:
+    __slots__ = ()
+
     _metadata: ResultMetaData
 
     # used mainly to share documentation on the keys method.
@@ -859,10 +870,10 @@ class _WithKeys:
         return self._metadata.keys
 
 
-SelfResult = TypeVar("SelfResult", bound="Result")
+SelfResult = TypeVar("SelfResult", bound="Result[Any]")
 
 
-class Result(_WithKeys, ResultInternal[Row]):
+class Result(_WithKeys, ResultInternal[Row[_TP]]):
     """Represent a set of database results.
 
     .. versionadded:: 1.4  The :class:`.Result` object provides a completely
@@ -887,7 +898,9 @@ class Result(_WithKeys, ResultInternal[Row]):
 
     """
 
-    _row_logging_fn: Optional[Callable[[Row], Row]] = None
+    __slots__ = ("_metadata", "__dict__")
+
+    _row_logging_fn: Optional[Callable[[Row[Any]], Row[Any]]] = None
 
     _source_supports_scalars: bool = False
 
@@ -1011,6 +1024,15 @@ class Result(_WithKeys, ResultInternal[Row]):
         appropriate :class:`.ColumnElement` objects which correspond to
         a given statement construct.
 
+        .. versionchanged:: 2.0  Due to a bug in 1.4, the
+           :meth:`.Result.columns` method had an incorrect behavior where
+           calling upon the method with just one index would cause the
+           :class:`.Result` object to yield scalar values rather than
+           :class:`.Row` objects.   In version 2.0, this behavior has been
+           corrected such that calling upon :meth:`.Result.columns` with
+           a single index will produce a :class:`.Result` object that continues
+           to yield :class:`.Row` objects, which include only a single column.
+
         E.g.::
 
             statement = select(table.c.x, table.c.y, table.c.z)
@@ -1040,6 +1062,20 @@ class Result(_WithKeys, ResultInternal[Row]):
         """
         return self._column_slices(col_expressions)
 
+    @overload
+    def scalars(self: Result[Tuple[_T]]) -> ScalarResult[_T]:
+        ...
+
+    @overload
+    def scalars(
+        self: Result[Tuple[_T]], index: Literal[0]
+    ) -> ScalarResult[_T]:
+        ...
+
+    @overload
+    def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]:
+        ...
+
     def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]:
         """Return a :class:`_result.ScalarResult` filtering object which
         will return single elements rather than :class:`_row.Row` objects.
@@ -1067,7 +1103,7 @@ class Result(_WithKeys, ResultInternal[Row]):
 
     def _getter(
         self, key: _KeyIndexType, raiseerr: bool = True
-    ) -> Optional[Callable[[Row], Any]]:
+    ) -> Optional[Callable[[Row[Any]], Any]]:
         """return a callable that will retrieve the given key from a
         :class:`.Row`.
 
@@ -1105,6 +1141,43 @@ class Result(_WithKeys, ResultInternal[Row]):
 
         return MappingResult(self)
 
+    @property
+    def t(self) -> TupleResult[_TP]:
+        """Apply a "typed tuple" typing filter to returned rows.
+
+        The :attr:`.Result.t` attribute is a synonym for calling the
+        :meth:`.Result.tuples` method.
+
+        .. versionadded:: 2.0
+
+        """
+        return self  # type: ignore
+
+    def tuples(self) -> TupleResult[_TP]:
+        """Apply a "typed tuple" typing filter to returned rows.
+
+        This method returns the same :class:`.Result` object at runtime,
+        however annotates as returning a :class:`.TupleResult` object
+        that will indicate to :pep:`484` typing tools that plain typed
+        ``Tuple`` instances are returned rather than rows.  This allows
+        tuple unpacking and ``__getitem__`` access of :class:`.Row` objects
+        to by typed, for those cases where the statement invoked itself
+        included typing information.
+
+        .. versionadded:: 2.0
+
+        :return: the :class:`_result.TupleResult` type at typing time.
+
+        .. seealso::
+
+            :attr:`.Result.t` - shorter synonym
+
+            :attr:`.Row.t` - :class:`.Row` version
+
+        """
+
+        return self  # type: ignore
+
     def _raw_row_iterator(self) -> Iterator[_RowData]:
         """Return a safe iterator that yields raw row data.
 
@@ -1114,13 +1187,15 @@ class Result(_WithKeys, ResultInternal[Row]):
         """
         raise NotImplementedError()
 
-    def __iter__(self) -> Iterator[Row]:
+    def __iter__(self) -> Iterator[Row[_TP]]:
         return self._iter_impl()
 
-    def __next__(self) -> Row:
+    def __next__(self) -> Row[_TP]:
         return self._next_impl()
 
-    def partitions(self, size: Optional[int] = None) -> Iterator[List[Row]]:
+    def partitions(
+        self, size: Optional[int] = None
+    ) -> Iterator[Sequence[Row[_TP]]]:
         """Iterate through sub-lists of rows of the size given.
 
         Each list will be of the size given, excluding the last list to
@@ -1158,12 +1233,12 @@ class Result(_WithKeys, ResultInternal[Row]):
             else:
                 break
 
-    def fetchall(self) -> List[Row]:
+    def fetchall(self) -> Sequence[Row[_TP]]:
         """A synonym for the :meth:`_engine.Result.all` method."""
 
         return self._allrows()
 
-    def fetchone(self) -> Optional[Row]:
+    def fetchone(self) -> Optional[Row[_TP]]:
         """Fetch one row.
 
         When all rows are exhausted, returns None.
@@ -1185,7 +1260,7 @@ class Result(_WithKeys, ResultInternal[Row]):
         else:
             return row
 
-    def fetchmany(self, size: Optional[int] = None) -> List[Row]:
+    def fetchmany(self, size: Optional[int] = None) -> Sequence[Row[_TP]]:
         """Fetch many rows.
 
         When all rows are exhausted, returns an empty list.
@@ -1202,7 +1277,7 @@ class Result(_WithKeys, ResultInternal[Row]):
 
         return self._manyrow_getter(self, size)
 
-    def all(self) -> List[Row]:
+    def all(self) -> Sequence[Row[_TP]]:
         """Return all rows in a list.
 
         Closes the result set after invocation.   Subsequent invocations
@@ -1216,7 +1291,7 @@ class Result(_WithKeys, ResultInternal[Row]):
 
         return self._allrows()
 
-    def first(self) -> Optional[Row]:
+    def first(self) -> Optional[Row[_TP]]:
         """Fetch the first row or None if no row is present.
 
         Closes the result set and discards remaining rows.
@@ -1252,7 +1327,7 @@ class Result(_WithKeys, ResultInternal[Row]):
             raise_for_second_row=False, raise_for_none=False, scalar=False
         )
 
-    def one_or_none(self) -> Optional[Row]:
+    def one_or_none(self) -> Optional[Row[_TP]]:
         """Return at most one result or raise an exception.
 
         Returns ``None`` if the result has no rows.
@@ -1276,6 +1351,14 @@ class Result(_WithKeys, ResultInternal[Row]):
             raise_for_second_row=True, raise_for_none=False, scalar=False
         )
 
+    @overload
+    def scalar_one(self: Result[Tuple[_T]]) -> _T:
+        ...
+
+    @overload
+    def scalar_one(self) -> Any:
+        ...
+
     def scalar_one(self) -> Any:
         """Return exactly one scalar result or raise an exception.
 
@@ -1293,6 +1376,14 @@ class Result(_WithKeys, ResultInternal[Row]):
             raise_for_second_row=True, raise_for_none=True, scalar=True
         )
 
+    @overload
+    def scalar_one_or_none(self: Result[Tuple[_T]]) -> Optional[_T]:
+        ...
+
+    @overload
+    def scalar_one_or_none(self) -> Optional[Any]:
+        ...
+
     def scalar_one_or_none(self) -> Optional[Any]:
         """Return exactly one or no scalar result.
 
@@ -1310,7 +1401,7 @@ class Result(_WithKeys, ResultInternal[Row]):
             raise_for_second_row=True, raise_for_none=False, scalar=True
         )
 
-    def one(self) -> Row:
+    def one(self) -> Row[_TP]:
         """Return exactly one row or raise an exception.
 
         Raises :class:`.NoResultFound` if the result returns no
@@ -1341,6 +1432,14 @@ class Result(_WithKeys, ResultInternal[Row]):
             raise_for_second_row=True, raise_for_none=True, scalar=False
         )
 
+    @overload
+    def scalar(self: Result[Tuple[_T]]) -> Optional[_T]:
+        ...
+
+    @overload
+    def scalar(self) -> Any:
+        ...
+
     def scalar(self) -> Any:
         """Fetch the first column of the first row, and close the result set.
 
@@ -1359,7 +1458,7 @@ class Result(_WithKeys, ResultInternal[Row]):
             raise_for_second_row=False, raise_for_none=False, scalar=True
         )
 
-    def freeze(self) -> FrozenResult:
+    def freeze(self) -> FrozenResult[_TP]:
         """Return a callable object that will produce copies of this
         :class:`.Result` when invoked.
 
@@ -1382,7 +1481,7 @@ class Result(_WithKeys, ResultInternal[Row]):
 
         return FrozenResult(self)
 
-    def merge(self, *others: Result) -> MergedResult:
+    def merge(self, *others: Result[Any]) -> MergedResult[_TP]:
         """Merge this :class:`.Result` with other compatible result
         objects.
 
@@ -1405,9 +1504,17 @@ class FilterResult(ResultInternal[_R]):
 
     """
 
-    _post_creational_filter: Optional[Callable[[Any], Any]] = None
+    __slots__ = (
+        "_real_result",
+        "_post_creational_filter",
+        "_metadata",
+        "_unique_filter_state",
+        "__dict__",
+    )
+
+    _post_creational_filter: Optional[Callable[[Any], Any]]
 
-    _real_result: Result
+    _real_result: Result[Any]
 
     def _soft_close(self, hard: bool = False) -> None:
         self._real_result._soft_close(hard=hard)
@@ -1416,20 +1523,20 @@ class FilterResult(ResultInternal[_R]):
     def _attributes(self) -> Dict[Any, Any]:
         return self._real_result._attributes
 
-    def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row]]:
+    def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row[Any]]]:
         return self._real_result._fetchiter_impl()
 
     def _fetchone_impl(
         self, hard_close: bool = False
-    ) -> Optional[_InterimRowType[Row]]:
+    ) -> Optional[_InterimRowType[Row[Any]]]:
         return self._real_result._fetchone_impl(hard_close=hard_close)
 
-    def _fetchall_impl(self) -> List[_InterimRowType[Row]]:
+    def _fetchall_impl(self) -> List[_InterimRowType[Row[Any]]]:
         return self._real_result._fetchall_impl()
 
     def _fetchmany_impl(
         self, size: Optional[int] = None
-    ) -> List[_InterimRowType[Row]]:
+    ) -> List[_InterimRowType[Row[Any]]]:
         return self._real_result._fetchmany_impl(size=size)
 
 
@@ -1452,11 +1559,13 @@ class ScalarResult(FilterResult[_R]):
 
     """
 
+    __slots__ = ()
+
     _generate_rows = False
 
     _post_creational_filter: Optional[Callable[[Any], Any]]
 
-    def __init__(self, real_result: Result, index: _KeyIndexType):
+    def __init__(self, real_result: Result[Any], index: _KeyIndexType):
         self._real_result = real_result
 
         if real_result._source_supports_scalars:
@@ -1480,7 +1589,7 @@ class ScalarResult(FilterResult[_R]):
         self._unique_filter_state = (set(), strategy)
         return self
 
-    def partitions(self, size: Optional[int] = None) -> Iterator[List[_R]]:
+    def partitions(self, size: Optional[int] = None) -> Iterator[Sequence[_R]]:
         """Iterate through sub-lists of elements of the size given.
 
         Equivalent to :meth:`_result.Result.partitions` except that
@@ -1498,12 +1607,12 @@ class ScalarResult(FilterResult[_R]):
             else:
                 break
 
-    def fetchall(self) -> List[_R]:
+    def fetchall(self) -> Sequence[_R]:
         """A synonym for the :meth:`_engine.ScalarResult.all` method."""
 
         return self._allrows()
 
-    def fetchmany(self, size: Optional[int] = None) -> List[_R]:
+    def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]:
         """Fetch many objects.
 
         Equivalent to :meth:`_result.Result.fetchmany` except that
@@ -1513,7 +1622,7 @@ class ScalarResult(FilterResult[_R]):
         """
         return self._manyrow_getter(self, size)
 
-    def all(self) -> List[_R]:
+    def all(self) -> Sequence[_R]:
         """Return all scalar values in a list.
 
         Equivalent to :meth:`_result.Result.all` except that
@@ -1567,6 +1676,177 @@ class ScalarResult(FilterResult[_R]):
         )
 
 
+SelfTupleResult = TypeVar("SelfTupleResult", bound="TupleResult[Any]")
+
+
+class TupleResult(FilterResult[_R], util.TypingOnly):
+    """a :class:`.Result` that's typed as returning plain Python tuples
+    instead of rows.
+
+    Since :class:`.Row` acts like a tuple in every way already,
+    this class is a typing only class, regular :class:`.Result` is still
+    used at runtime.
+
+    """
+
+    __slots__ = ()
+
+    if TYPE_CHECKING:
+
+        def partitions(
+            self, size: Optional[int] = None
+        ) -> Iterator[Sequence[_R]]:
+            """Iterate through sub-lists of elements of the size given.
+
+            Equivalent to :meth:`_result.Result.partitions` except that
+            tuple values, rather than :class:`_result.Row` objects,
+            are returned.
+
+            """
+            ...
+
+        def fetchone(self) -> Optional[_R]:
+            """Fetch one tuple.
+
+            Equivalent to :meth:`_result.Result.fetchone` except that
+            tuple values, rather than :class:`_result.Row`
+            objects, are returned.
+
+            """
+            ...
+
+        def fetchall(self) -> Sequence[_R]:
+            """A synonym for the :meth:`_engine.ScalarResult.all` method."""
+            ...
+
+        def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]:
+            """Fetch many objects.
+
+            Equivalent to :meth:`_result.Result.fetchmany` except that
+            tuple values, rather than :class:`_result.Row` objects,
+            are returned.
+
+            """
+            ...
+
+        def all(self) -> Sequence[_R]:  # noqa: A001
+            """Return all scalar values in a list.
+
+            Equivalent to :meth:`_result.Result.all` except that
+            tuple values, rather than :class:`_result.Row` objects,
+            are returned.
+
+            """
+            ...
+
+        def __iter__(self) -> Iterator[_R]:
+            ...
+
+        def __next__(self) -> _R:
+            ...
+
+        def first(self) -> Optional[_R]:
+            """Fetch the first object or None if no object is present.
+
+            Equivalent to :meth:`_result.Result.first` except that
+            tuple values, rather than :class:`_result.Row` objects,
+            are returned.
+
+
+            """
+            ...
+
+        def one_or_none(self) -> Optional[_R]:
+            """Return at most one object or raise an exception.
+
+            Equivalent to :meth:`_result.Result.one_or_none` except that
+            tuple values, rather than :class:`_result.Row` objects,
+            are returned.
+
+            """
+            ...
+
+        def one(self) -> _R:
+            """Return exactly one object or raise an exception.
+
+            Equivalent to :meth:`_result.Result.one` except that
+            tuple values, rather than :class:`_result.Row` objects,
+            are returned.
+
+            """
+            ...
+
+        @overload
+        def scalar_one(self: TupleResult[Tuple[_T]]) -> _T:
+            ...
+
+        @overload
+        def scalar_one(self) -> Any:
+            ...
+
+        def scalar_one(self) -> Any:
+            """Return exactly one scalar result or raise an exception.
+
+            This is equivalent to calling :meth:`.Result.scalars` and then
+            :meth:`.Result.one`.
+
+            .. seealso::
+
+                :meth:`.Result.one`
+
+                :meth:`.Result.scalars`
+
+            """
+            ...
+
+        @overload
+        def scalar_one_or_none(self: TupleResult[Tuple[_T]]) -> Optional[_T]:
+            ...
+
+        @overload
+        def scalar_one_or_none(self) -> Optional[Any]:
+            ...
+
+        def scalar_one_or_none(self) -> Optional[Any]:
+            """Return exactly one or no scalar result.
+
+            This is equivalent to calling :meth:`.Result.scalars` and then
+            :meth:`.Result.one_or_none`.
+
+            .. seealso::
+
+                :meth:`.Result.one_or_none`
+
+                :meth:`.Result.scalars`
+
+            """
+            ...
+
+        @overload
+        def scalar(self: TupleResult[Tuple[_T]]) -> Optional[_T]:
+            ...
+
+        @overload
+        def scalar(self) -> Any:
+            ...
+
+        def scalar(self) -> Any:
+            """Fetch the first column of the first row, and close the result set.
+
+            Returns None if there are no rows to fetch.
+
+            No validation is performed to test if additional rows remain.
+
+            After calling this method, the object is fully closed,
+            e.g. the :meth:`_engine.CursorResult.close`
+            method will have been called.
+
+            :return: a Python scalar value , or None if no rows remain.
+
+            """
+            ...
+
+
 SelfMappingResult = TypeVar("SelfMappingResult", bound="MappingResult")
 
 
@@ -1579,11 +1859,13 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]):
 
     """
 
+    __slots__ = ()
+
     _generate_rows = True
 
     _post_creational_filter = operator.attrgetter("_mapping")
 
-    def __init__(self, result: Result):
+    def __init__(self, result: Result[Any]):
         self._real_result = result
         self._unique_filter_state = result._unique_filter_state
         self._metadata = result._metadata
@@ -1610,7 +1892,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]):
 
     def partitions(
         self, size: Optional[int] = None
-    ) -> Iterator[List[RowMapping]]:
+    ) -> Iterator[Sequence[RowMapping]]:
         """Iterate through sub-lists of elements of the size given.
 
         Equivalent to :meth:`_result.Result.partitions` except that
@@ -1628,7 +1910,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]):
             else:
                 break
 
-    def fetchall(self) -> List[RowMapping]:
+    def fetchall(self) -> Sequence[RowMapping]:
         """A synonym for the :meth:`_engine.MappingResult.all` method."""
 
         return self._allrows()
@@ -1648,7 +1930,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]):
         else:
             return row
 
-    def fetchmany(self, size: Optional[int] = None) -> List[RowMapping]:
+    def fetchmany(self, size: Optional[int] = None) -> Sequence[RowMapping]:
         """Fetch many objects.
 
         Equivalent to :meth:`_result.Result.fetchmany` except that
@@ -1659,7 +1941,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]):
 
         return self._manyrow_getter(self, size)
 
-    def all(self) -> List[RowMapping]:
+    def all(self) -> Sequence[RowMapping]:
         """Return all scalar values in a list.
 
         Equivalent to :meth:`_result.Result.all` except that
@@ -1714,7 +1996,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]):
         )
 
 
-class FrozenResult:
+class FrozenResult(Generic[_TP]):
     """Represents a :class:`.Result` object in a "frozen" state suitable
     for caching.
 
@@ -1755,7 +2037,7 @@ class FrozenResult:
 
     data: Sequence[Any]
 
-    def __init__(self, result: Result):
+    def __init__(self, result: Result[_TP]):
         self.metadata = result._metadata._for_freeze()
         self._source_supports_scalars = result._source_supports_scalars
         self._attributes = result._attributes
@@ -1771,7 +2053,9 @@ class FrozenResult:
         else:
             return [list(row) for row in self.data]
 
-    def with_new_rows(self, tuple_data: Sequence[Row]) -> FrozenResult:
+    def with_new_rows(
+        self, tuple_data: Sequence[Row[_TP]]
+    ) -> FrozenResult[_TP]:
         fr = FrozenResult.__new__(FrozenResult)
         fr.metadata = self.metadata
         fr._attributes = self._attributes
@@ -1783,14 +2067,16 @@ class FrozenResult:
             fr.data = tuple_data
         return fr
 
-    def __call__(self) -> Result:
-        result = IteratorResult(self.metadata, iter(self.data))
+    def __call__(self) -> Result[_TP]:
+        result: IteratorResult[_TP] = IteratorResult(
+            self.metadata, iter(self.data)
+        )
         result._attributes = self._attributes
         result._source_supports_scalars = self._source_supports_scalars
         return result
 
 
-class IteratorResult(Result):
+class IteratorResult(Result[_TP]):
     """A :class:`.Result` that gets data from a Python iterator of
     :class:`.Row` objects or similar row-like data.
 
@@ -1833,7 +2119,7 @@ class IteratorResult(Result):
 
     def _fetchone_impl(
         self, hard_close: bool = False
-    ) -> Optional[_InterimRowType[Row]]:
+    ) -> Optional[_InterimRowType[Row[Any]]]:
         if self._hard_closed:
             self._raise_hard_closed()
 
@@ -1844,7 +2130,7 @@ class IteratorResult(Result):
         else:
             return row
 
-    def _fetchall_impl(self) -> List[_InterimRowType[Row]]:
+    def _fetchall_impl(self) -> List[_InterimRowType[Row[Any]]]:
         if self._hard_closed:
             self._raise_hard_closed()
         try:
@@ -1854,23 +2140,23 @@ class IteratorResult(Result):
 
     def _fetchmany_impl(
         self, size: Optional[int] = None
-    ) -> List[_InterimRowType[Row]]:
+    ) -> List[_InterimRowType[Row[Any]]]:
         if self._hard_closed:
             self._raise_hard_closed()
 
         return list(itertools.islice(self.iterator, 0, size))
 
 
-def null_result() -> IteratorResult:
+def null_result() -> IteratorResult[Any]:
     return IteratorResult(SimpleResultMetaData([]), iter([]))
 
 
 SelfChunkedIteratorResult = TypeVar(
-    "SelfChunkedIteratorResult", bound="ChunkedIteratorResult"
+    "SelfChunkedIteratorResult", bound="ChunkedIteratorResult[Any]"
 )
 
 
-class ChunkedIteratorResult(IteratorResult):
+class ChunkedIteratorResult(IteratorResult[_TP]):
     """An :class:`.IteratorResult` that works from an iterator-producing callable.
 
     The given ``chunks`` argument is a function that is given a number of rows
@@ -1922,13 +2208,13 @@ class ChunkedIteratorResult(IteratorResult):
 
     def _fetchmany_impl(
         self, size: Optional[int] = None
-    ) -> List[_InterimRowType[Row]]:
+    ) -> List[_InterimRowType[Row[Any]]]:
         if self.dynamic_yield_per:
             self.iterator = itertools.chain.from_iterable(self.chunks(size))
         return super()._fetchmany_impl(size=size)
 
 
-class MergedResult(IteratorResult):
+class MergedResult(IteratorResult[_TP]):
     """A :class:`_engine.Result` that is merged from any number of
     :class:`_engine.Result` objects.
 
@@ -1942,7 +2228,7 @@ class MergedResult(IteratorResult):
     rowcount: Optional[int]
 
     def __init__(
-        self, cursor_metadata: ResultMetaData, results: Sequence[Result]
+        self, cursor_metadata: ResultMetaData, results: Sequence[Result[_TP]]
     ):
         self._results = results
         super(MergedResult, self).__init__(
index 4ba39b55d6700f1e08b94f133a4aa34dc2edf478..7c9eacb78c11108cf77a36526850191f80a670a1 100644 (file)
@@ -16,6 +16,7 @@ import typing
 from typing import Any
 from typing import Callable
 from typing import Dict
+from typing import Generic
 from typing import Iterator
 from typing import List
 from typing import Mapping
@@ -24,12 +25,14 @@ from typing import Optional
 from typing import overload
 from typing import Sequence
 from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import TypeVar
 from typing import Union
 
 from ..sql import util as sql_util
 from ..util._has_cy import HAS_CYEXTENSION
 
-if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
+if TYPE_CHECKING or not HAS_CYEXTENSION:
     from ._py_row import BaseRow as BaseRow
     from ._py_row import KEY_INTEGER_ONLY
     from ._py_row import KEY_OBJECTS_ONLY
@@ -38,13 +41,16 @@ else:
     from sqlalchemy.cyextension.resultproxy import KEY_INTEGER_ONLY
     from sqlalchemy.cyextension.resultproxy import KEY_OBJECTS_ONLY
 
-if typing.TYPE_CHECKING:
+if TYPE_CHECKING:
     from .result import _KeyType
     from .result import RMKeyView
     from ..sql.type_api import _ResultProcessorType
 
+_T = TypeVar("_T", bound=Any)
+_TP = TypeVar("_TP", bound=Tuple[Any, ...])
 
-class Row(BaseRow, typing.Sequence[Any]):
+
+class Row(BaseRow, Sequence[Any], Generic[_TP]):
     """Represent a single result row.
 
     The :class:`.Row` object represents a row of a database result.  It is
@@ -82,6 +88,37 @@ class Row(BaseRow, typing.Sequence[Any]):
     def __delattr__(self, name: str) -> NoReturn:
         raise AttributeError("can't delete attribute")
 
+    def tuple(self) -> _TP:
+        """Return a 'tuple' form of this :class:`.Row`.
+
+        At runtime, this method returns "self"; the :class:`.Row` object is
+        already a named tuple. However, at the typing level, if this
+        :class:`.Row` is typed, the "tuple" return type will be a :pep:`484`
+        ``Tuple`` datatype that contains typing information about individual
+        elements, supporting typed unpacking and attribute access.
+
+        .. versionadded:: 2.0
+
+        .. seealso::
+
+            :meth:`.Result.tuples`
+
+        """
+        return self  # type: ignore
+
+    @property
+    def t(self) -> _TP:
+        """a synonym for :attr:`.Row.tuple`
+
+        .. versionadded:: 2.0
+
+        .. seealso::
+
+            :meth:`.Result.t`
+
+        """
+        return self  # type: ignore
+
     @property
     def _mapping(self) -> RowMapping:
         """Return a :class:`.RowMapping` for this :class:`.Row`.
@@ -107,7 +144,7 @@ class Row(BaseRow, typing.Sequence[Any]):
 
     def _filter_on_values(
         self, filters: Optional[Sequence[Optional[_ResultProcessorType[Any]]]]
-    ) -> Row:
+    ) -> Row[Any]:
         return Row(
             self._parent,
             filters,
@@ -116,7 +153,7 @@ class Row(BaseRow, typing.Sequence[Any]):
             self._data,
         )
 
-    if not typing.TYPE_CHECKING:
+    if not TYPE_CHECKING:
 
         def _special_name_accessor(name: str) -> Any:
             """Handle ambiguous names such as "count" and "index" """
@@ -151,7 +188,7 @@ class Row(BaseRow, typing.Sequence[Any]):
 
     __hash__ = BaseRow.__hash__
 
-    if typing.TYPE_CHECKING:
+    if TYPE_CHECKING:
 
         @overload
         def __getitem__(self, index: int) -> Any:
@@ -299,7 +336,7 @@ class RowMapping(BaseRow, typing.Mapping[str, Any]):
 
     _default_key_style = KEY_OBJECTS_ONLY
 
-    if typing.TYPE_CHECKING:
+    if TYPE_CHECKING:
 
         def __getitem__(self, key: _KeyType) -> Any:
             ...
index fb05f512e41713c403af609548acc672c34bdb91..95549ada69f9a9ac48de14c5e17c65bfd04a0a9f 100644 (file)
@@ -12,8 +12,10 @@ from typing import Generator
 from typing import NoReturn
 from typing import Optional
 from typing import overload
+from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
+from typing import TypeVar
 from typing import Union
 
 from . import exc as async_exc
@@ -50,6 +52,9 @@ if TYPE_CHECKING:
     from ...pool import PoolProxiedConnection
     from ...sql._typing import _InfoType
     from ...sql.base import Executable
+    from ...sql.selectable import TypedReturnsRows
+
+_T = TypeVar("_T", bound=Any)
 
 
 class _SyncConnectionCallable(Protocol):
@@ -407,7 +412,7 @@ class AsyncConnection(
         statement: str,
         parameters: Optional[_DBAPIAnyExecuteParams] = None,
         execution_options: Optional[_ExecuteOptionsParameter] = None,
-    ) -> CursorResult:
+    ) -> CursorResult[Any]:
         r"""Executes a driver-level SQL string and return buffered
         :class:`_engine.Result`.
 
@@ -423,12 +428,33 @@ class AsyncConnection(
 
         return await _ensure_sync_result(result, self.exec_driver_sql)
 
+    @overload
+    async def stream(
+        self,
+        statement: TypedReturnsRows[_T],
+        parameters: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> AsyncResult[_T]:
+        ...
+
+    @overload
     async def stream(
         self,
         statement: Executable,
         parameters: Optional[_CoreAnyExecuteParams] = None,
+        *,
         execution_options: Optional[_ExecuteOptionsParameter] = None,
-    ) -> AsyncResult:
+    ) -> AsyncResult[Any]:
+        ...
+
+    async def stream(
+        self,
+        statement: Executable,
+        parameters: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> AsyncResult[Any]:
         """Execute a statement and return a streaming
         :class:`_asyncio.AsyncResult` object."""
 
@@ -436,7 +462,7 @@ class AsyncConnection(
             self._proxied.execute,
             statement,
             parameters,
-            util.EMPTY_DICT.merge_with(
+            execution_options=util.EMPTY_DICT.merge_with(
                 execution_options, {"stream_results": True}
             ),
             _require_await=True,
@@ -446,12 +472,33 @@ class AsyncConnection(
             assert False, "server side result expected"
         return AsyncResult(result)
 
+    @overload
+    async def execute(
+        self,
+        statement: TypedReturnsRows[_T],
+        parameters: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> CursorResult[_T]:
+        ...
+
+    @overload
     async def execute(
         self,
         statement: Executable,
         parameters: Optional[_CoreAnyExecuteParams] = None,
+        *,
         execution_options: Optional[_ExecuteOptionsParameter] = None,
-    ) -> CursorResult:
+    ) -> CursorResult[Any]:
+        ...
+
+    async def execute(
+        self,
+        statement: Executable,
+        parameters: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> CursorResult[Any]:
         r"""Executes a SQL statement construct and return a buffered
         :class:`_engine.Result`.
 
@@ -487,15 +534,36 @@ class AsyncConnection(
             self._proxied.execute,
             statement,
             parameters,
-            execution_options,
+            execution_options=execution_options,
             _require_await=True,
         )
         return await _ensure_sync_result(result, self.execute)
 
+    @overload
+    async def scalar(
+        self,
+        statement: TypedReturnsRows[Tuple[_T]],
+        parameters: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> Optional[_T]:
+        ...
+
+    @overload
     async def scalar(
         self,
         statement: Executable,
         parameters: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> Any:
+        ...
+
+    async def scalar(
+        self,
+        statement: Executable,
+        parameters: Optional[_CoreSingleExecuteParams] = None,
+        *,
         execution_options: Optional[_ExecuteOptionsParameter] = None,
     ) -> Any:
         r"""Executes a SQL statement construct and returns a scalar object.
@@ -508,13 +576,36 @@ class AsyncConnection(
          first row returned.
 
         """
-        result = await self.execute(statement, parameters, execution_options)
+        result = await self.execute(
+            statement, parameters, execution_options=execution_options
+        )
         return result.scalar()
 
+    @overload
+    async def scalars(
+        self,
+        statement: TypedReturnsRows[Tuple[_T]],
+        parameters: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> ScalarResult[_T]:
+        ...
+
+    @overload
+    async def scalars(
+        self,
+        statement: Executable,
+        parameters: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> ScalarResult[Any]:
+        ...
+
     async def scalars(
         self,
         statement: Executable,
         parameters: Optional[_CoreSingleExecuteParams] = None,
+        *,
         execution_options: Optional[_ExecuteOptionsParameter] = None,
     ) -> ScalarResult[Any]:
         r"""Executes a SQL statement construct and returns a scalar objects.
@@ -528,13 +619,36 @@ class AsyncConnection(
         .. versionadded:: 1.4.24
 
         """
-        result = await self.execute(statement, parameters, execution_options)
+        result = await self.execute(
+            statement, parameters, execution_options=execution_options
+        )
         return result.scalars()
 
+    @overload
+    async def stream_scalars(
+        self,
+        statement: TypedReturnsRows[Tuple[_T]],
+        parameters: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> AsyncScalarResult[_T]:
+        ...
+
+    @overload
     async def stream_scalars(
         self,
         statement: Executable,
         parameters: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> AsyncScalarResult[Any]:
+        ...
+
+    async def stream_scalars(
+        self,
+        statement: Executable,
+        parameters: Optional[_CoreSingleExecuteParams] = None,
+        *,
         execution_options: Optional[_ExecuteOptionsParameter] = None,
     ) -> AsyncScalarResult[Any]:
         r"""Executes a SQL statement and returns a streaming scalar result
@@ -549,7 +663,9 @@ class AsyncConnection(
         .. versionadded:: 1.4.24
 
         """
-        result = await self.stream(statement, parameters, execution_options)
+        result = await self.stream(
+            statement, parameters, execution_options=execution_options
+        )
         return result.scalars()
 
     async def run_sync(
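
As a rough sketch of what these overloads provide to a type checker (illustrative only; the ``user`` table and the ``aiosqlite`` driver below are assumptions, not part of this change), a statement that carries typing information now flows through :meth:`.AsyncConnection.execute` and :meth:`.AsyncConnection.scalars` to a typed result::

    import asyncio

    from sqlalchemy import Column, Integer, MetaData, String, Table, select
    from sqlalchemy.ext.asyncio import create_async_engine

    metadata = MetaData()
    user = Table(  # illustrative table, not part of the patch
        "user",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )

    async def main() -> None:
        engine = create_async_engine("sqlite+aiosqlite://")
        async with engine.begin() as conn:
            await conn.run_sync(metadata.create_all)
            await conn.execute(user.insert().values(id=1, name="spongebob"))

            # typed overload: CursorResult carries the row type, to the
            # extent the column expressions carry typing information
            result = await conn.execute(select(user.c.id, user.c.name))
            for id_, name in result:
                print(id_, name)

            # the scalars() overload narrows to a ScalarResult of the
            # first selected column
            names = await conn.scalars(select(user.c.name))
            print(names.all())

    asyncio.run(main())
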
index d0337554cf981715ead6aab22d4fb861604f8831..ff3dcf41744cee54234fe14b0476fd76ab8c0636 100644 (file)
@@ -9,12 +9,15 @@ from __future__ import annotations
 import operator
 from typing import Any
 from typing import AsyncIterator
-from typing import List
 from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import Tuple
 from typing import TYPE_CHECKING
 from typing import TypeVar
 
 from . import exc as async_exc
+from ... import util
 from ...engine.result import _NO_ROW
 from ...engine.result import _R
 from ...engine.result import FilterResult
@@ -24,6 +27,7 @@ from ...engine.result import ResultMetaData
 from ...engine.row import Row
 from ...engine.row import RowMapping
 from ...util.concurrency import greenlet_spawn
+from ...util.typing import Literal
 
 if TYPE_CHECKING:
     from ...engine import CursorResult
@@ -32,9 +36,14 @@ if TYPE_CHECKING:
     from ...engine.result import _UniqueFilterType
     from ...engine.result import RMKeyView
 
+_T = TypeVar("_T", bound=Any)
+_TP = TypeVar("_TP", bound=Tuple[Any, ...])
+
 
 class AsyncCommon(FilterResult[_R]):
-    _real_result: Result
+    __slots__ = ()
+
+    _real_result: Result[Any]
     _metadata: ResultMetaData
 
     async def close(self) -> None:
@@ -43,10 +52,10 @@ class AsyncCommon(FilterResult[_R]):
         await greenlet_spawn(self._real_result.close)
 
 
-SelfAsyncResult = TypeVar("SelfAsyncResult", bound="AsyncResult")
+SelfAsyncResult = TypeVar("SelfAsyncResult", bound="AsyncResult[Any]")
 
 
-class AsyncResult(AsyncCommon[Row]):
+class AsyncResult(AsyncCommon[Row[_TP]]):
     """An asyncio wrapper around a :class:`_result.Result` object.
 
     The :class:`_asyncio.AsyncResult` only applies to statement executions that
@@ -67,11 +76,16 @@ class AsyncResult(AsyncCommon[Row]):
 
     """
 
-    def __init__(self, real_result: Result):
+    __slots__ = ()
+
+    _real_result: Result[_TP]
+
+    def __init__(self, real_result: Result[_TP]):
         self._real_result = real_result
 
         self._metadata = real_result._metadata
         self._unique_filter_state = real_result._unique_filter_state
+        self._post_creational_filter = None
 
         # BaseCursorResult pre-generates the "_row_getter".  Use that
         # if available rather than building a second one
@@ -80,6 +94,43 @@ class AsyncResult(AsyncCommon[Row]):
                 "_row_getter", real_result.__dict__["_row_getter"]
             )
 
+    @property
+    def t(self) -> AsyncTupleResult[_TP]:
+        """Apply a "typed tuple" typing filter to returned rows.
+
+        The :attr:`.AsyncResult.t` attribute is a synonym for calling the
+        :meth:`.AsyncResult.tuples` method.
+
+        .. versionadded:: 2.0
+
+        """
+        return self  # type: ignore
+
+    def tuples(self) -> AsyncTupleResult[_TP]:
+        """Apply a "typed tuple" typing filter to returned rows.
+
+        This method returns the same :class:`.AsyncResult` object at runtime;
+        however, it is annotated as returning a :class:`.AsyncTupleResult`
+        object, which indicates to :pep:`484` typing tools that plain typed
+        ``Tuple`` instances are returned rather than rows.  This allows
+        tuple unpacking and ``__getitem__`` access of :class:`.Row` objects
+        to be typed, for those cases where the statement invoked itself
+        included typing information.
+
+        .. versionadded:: 2.0
+
+        :return: the :class:`_result.AsyncTupleResult` type at typing time.
+
+        .. seealso::
+
+            :attr:`.AsyncResult.t` - shorter synonym
+
+            :attr:`.Row.t` - :class:`.Row` version
+
+        """
+
+        return self  # type: ignore
+
     def keys(self) -> RMKeyView:
         """Return the :meth:`_engine.Result.keys` collection from the
         underlying :class:`_engine.Result`.
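
A sketch of how the ``.t`` / ``.tuples()`` filter is intended to read at a call site; ``User`` here is an assumed 2.0-style typed mapping, and nothing changes at runtime::

    from sqlalchemy import select
    from sqlalchemy.ext.asyncio import AsyncSession

    async def show_names(session: AsyncSession) -> None:
        # "User" is an assumed mapping with id: Mapped[int] and
        # name: Mapped[str]
        stmt = select(User.id, User.name)
        result = await session.stream(stmt)  # AsyncResult[Tuple[int, str]]
        async for user_id, user_name in result.t:
            # the typed-tuple filter types these as int and str
            print(user_id, user_name)
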
@@ -115,7 +166,7 @@ class AsyncResult(AsyncCommon[Row]):
 
     async def partitions(
         self, size: Optional[int] = None
-    ) -> AsyncIterator[List[Row]]:
+    ) -> AsyncIterator[Sequence[Row[_TP]]]:
         """Iterate through sub-lists of rows of the size given.
 
         An async iterator is returned::
@@ -141,7 +192,16 @@ class AsyncResult(AsyncCommon[Row]):
             else:
                 break
 
-    async def fetchone(self) -> Optional[Row]:
+    async def fetchall(self) -> Sequence[Row[_TP]]:
+        """A synonym for the :meth:`.AsyncResult.all` method.
+
+        .. versionadded:: 2.0
+
+        """
+
+        return await greenlet_spawn(self._allrows)
+
+    async def fetchone(self) -> Optional[Row[_TP]]:
         """Fetch one row.
 
         When all rows are exhausted, returns None.
@@ -163,7 +223,9 @@ class AsyncResult(AsyncCommon[Row]):
         else:
             return row
 
-    async def fetchmany(self, size: Optional[int] = None) -> List[Row]:
+    async def fetchmany(
+        self, size: Optional[int] = None
+    ) -> Sequence[Row[_TP]]:
         """Fetch many rows.
 
         When all rows are exhausted, returns an empty list.
@@ -184,7 +246,7 @@ class AsyncResult(AsyncCommon[Row]):
 
         return await greenlet_spawn(self._manyrow_getter, self, size)
 
-    async def all(self) -> List[Row]:
+    async def all(self) -> Sequence[Row[_TP]]:
         """Return all rows in a list.
 
         Closes the result set after invocation.   Subsequent invocations
@@ -196,17 +258,17 @@ class AsyncResult(AsyncCommon[Row]):
 
         return await greenlet_spawn(self._allrows)
 
-    def __aiter__(self) -> AsyncResult:
+    def __aiter__(self) -> AsyncResult[_TP]:
         return self
 
-    async def __anext__(self) -> Row:
+    async def __anext__(self) -> Row[_TP]:
         row = await greenlet_spawn(self._onerow_getter, self)
         if row is _NO_ROW:
             raise StopAsyncIteration()
         else:
             return row
 
-    async def first(self) -> Optional[Row]:
+    async def first(self) -> Optional[Row[_TP]]:
         """Fetch the first row or None if no row is present.
 
         Closes the result set and discards remaining rows.
@@ -229,7 +291,7 @@ class AsyncResult(AsyncCommon[Row]):
         """
         return await greenlet_spawn(self._only_one_row, False, False, False)
 
-    async def one_or_none(self) -> Optional[Row]:
+    async def one_or_none(self) -> Optional[Row[_TP]]:
         """Return at most one result or raise an exception.
 
         Returns ``None`` if the result has no rows.
@@ -251,6 +313,14 @@ class AsyncResult(AsyncCommon[Row]):
         """
         return await greenlet_spawn(self._only_one_row, True, False, False)
 
+    @overload
+    async def scalar_one(self: AsyncResult[Tuple[_T]]) -> _T:
+        ...
+
+    @overload
+    async def scalar_one(self) -> Any:
+        ...
+
     async def scalar_one(self) -> Any:
         """Return exactly one scalar result or raise an exception.
 
@@ -266,6 +336,16 @@ class AsyncResult(AsyncCommon[Row]):
         """
         return await greenlet_spawn(self._only_one_row, True, True, True)
 
+    @overload
+    async def scalar_one_or_none(
+        self: AsyncResult[Tuple[_T]],
+    ) -> Optional[_T]:
+        ...
+
+    @overload
+    async def scalar_one_or_none(self) -> Optional[Any]:
+        ...
+
     async def scalar_one_or_none(self) -> Optional[Any]:
         """Return exactly one or no scalar result.
 
@@ -281,7 +361,7 @@ class AsyncResult(AsyncCommon[Row]):
         """
         return await greenlet_spawn(self._only_one_row, True, False, True)
 
-    async def one(self) -> Row:
+    async def one(self) -> Row[_TP]:
         """Return exactly one row or raise an exception.
 
         Raises :class:`.NoResultFound` if the result returns no
@@ -312,6 +392,14 @@ class AsyncResult(AsyncCommon[Row]):
         """
         return await greenlet_spawn(self._only_one_row, True, True, False)
 
+    @overload
+    async def scalar(self: AsyncResult[Tuple[_T]]) -> Optional[_T]:
+        ...
+
+    @overload
+    async def scalar(self) -> Any:
+        ...
+
     async def scalar(self) -> Any:
         """Fetch the first column of the first row, and close the result set.
 
@@ -328,7 +416,7 @@ class AsyncResult(AsyncCommon[Row]):
         """
         return await greenlet_spawn(self._only_one_row, False, False, True)
 
-    async def freeze(self) -> FrozenResult:
+    async def freeze(self) -> FrozenResult[_TP]:
         """Return a callable object that will produce copies of this
         :class:`_asyncio.AsyncResult` when invoked.
 
@@ -351,7 +439,7 @@ class AsyncResult(AsyncCommon[Row]):
 
         return await greenlet_spawn(FrozenResult, self)
 
-    def merge(self, *others: AsyncResult) -> MergedResult:
+    def merge(self, *others: AsyncResult[_TP]) -> MergedResult[_TP]:
         """Merge this :class:`_asyncio.AsyncResult` with other compatible result
         objects.
 
@@ -370,6 +458,20 @@ class AsyncResult(AsyncCommon[Row]):
             (self._real_result,) + tuple(o._real_result for o in others),
         )
 
+    @overload
+    def scalars(
+        self: AsyncResult[Tuple[_T]], index: Literal[0]
+    ) -> AsyncScalarResult[_T]:
+        ...
+
+    @overload
+    def scalars(self: AsyncResult[Tuple[_T]]) -> AsyncScalarResult[_T]:
+        ...
+
+    @overload
+    def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]:
+        ...
+
     def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]:
         """Return an :class:`_asyncio.AsyncScalarResult` filtering object which
         will return single elements rather than :class:`_row.Row` objects.
@@ -423,9 +525,11 @@ class AsyncScalarResult(AsyncCommon[_R]):
 
     """
 
+    __slots__ = ()
+
     _generate_rows = False
 
-    def __init__(self, real_result: Result, index: _KeyIndexType):
+    def __init__(self, real_result: Result[Any], index: _KeyIndexType):
         self._real_result = real_result
 
         if real_result._source_supports_scalars:
@@ -452,7 +556,7 @@ class AsyncScalarResult(AsyncCommon[_R]):
 
     async def partitions(
         self, size: Optional[int] = None
-    ) -> AsyncIterator[List[_R]]:
+    ) -> AsyncIterator[Sequence[_R]]:
         """Iterate through sub-lists of elements of the size given.
 
         Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
@@ -470,12 +574,12 @@ class AsyncScalarResult(AsyncCommon[_R]):
             else:
                 break
 
-    async def fetchall(self) -> List[_R]:
+    async def fetchall(self) -> Sequence[_R]:
         """A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method."""
 
         return await greenlet_spawn(self._allrows)
 
-    async def fetchmany(self, size: Optional[int] = None) -> List[_R]:
+    async def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]:
         """Fetch many objects.
 
         Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
@@ -485,7 +589,7 @@ class AsyncScalarResult(AsyncCommon[_R]):
         """
         return await greenlet_spawn(self._manyrow_getter, self, size)
 
-    async def all(self) -> List[_R]:
+    async def all(self) -> Sequence[_R]:
         """Return all scalar values in a list.
 
         Equivalent to :meth:`_asyncio.AsyncResult.all` except that
@@ -555,11 +659,13 @@ class AsyncMappingResult(AsyncCommon[RowMapping]):
 
     """
 
+    __slots__ = ()
+
     _generate_rows = True
 
     _post_creational_filter = operator.attrgetter("_mapping")
 
-    def __init__(self, result: Result):
+    def __init__(self, result: Result[Any]):
         self._real_result = result
         self._unique_filter_state = result._unique_filter_state
         self._metadata = result._metadata
@@ -602,7 +708,7 @@ class AsyncMappingResult(AsyncCommon[RowMapping]):
 
     async def partitions(
         self, size: Optional[int] = None
-    ) -> AsyncIterator[List[RowMapping]]:
+    ) -> AsyncIterator[Sequence[RowMapping]]:
 
         """Iterate through sub-lists of elements of the size given.
 
@@ -621,7 +727,7 @@ class AsyncMappingResult(AsyncCommon[RowMapping]):
             else:
                 break
 
-    async def fetchall(self) -> List[RowMapping]:
+    async def fetchall(self) -> Sequence[RowMapping]:
         """A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method."""
 
         return await greenlet_spawn(self._allrows)
@@ -641,7 +747,9 @@ class AsyncMappingResult(AsyncCommon[RowMapping]):
         else:
             return row
 
-    async def fetchmany(self, size: Optional[int] = None) -> List[RowMapping]:
+    async def fetchmany(
+        self, size: Optional[int] = None
+    ) -> Sequence[RowMapping]:
         """Fetch many rows.
 
         Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
@@ -652,7 +760,7 @@ class AsyncMappingResult(AsyncCommon[RowMapping]):
 
         return await greenlet_spawn(self._manyrow_getter, self, size)
 
-    async def all(self) -> List[RowMapping]:
+    async def all(self) -> Sequence[RowMapping]:
         """Return all rows in a list.
 
         Equivalent to :meth:`_asyncio.AsyncResult.all` except that
@@ -705,11 +813,186 @@ class AsyncMappingResult(AsyncCommon[RowMapping]):
         return await greenlet_spawn(self._only_one_row, True, True, False)
 
 
-_RT = TypeVar("_RT", bound="Result")
+SelfAsyncTupleResult = TypeVar(
+    "SelfAsyncTupleResult", bound="AsyncTupleResult[Any]"
+)
+
+
+class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly):
+    """a :class:`.AsyncResult` that's typed as returning plain Python tuples
+    instead of rows.
+
+    Since :class:`.Row` acts like a tuple in every way already, this class
+    is a typing-only class; the regular :class:`.AsyncResult` is still
+    used at runtime.
+
+    """
+
+    __slots__ = ()
+
+    if TYPE_CHECKING:
+
+        async def partitions(
+            self, size: Optional[int] = None
+        ) -> AsyncIterator[Sequence[_R]]:
+            """Iterate through sub-lists of elements of the size given.
+
+            Equivalent to :meth:`_result.Result.partitions` except that
+            tuple values, rather than :class:`_result.Row` objects,
+            are returned.
+
+            """
+            ...
+
+        async def fetchone(self) -> Optional[_R]:
+            """Fetch one tuple.
+
+            Equivalent to :meth:`_result.Result.fetchone` except that
+            tuple values, rather than :class:`_result.Row`
+            objects, are returned.
+
+            """
+            ...
+
+        async def fetchall(self) -> Sequence[_R]:
+            """A synonym for the :meth:`_engine.ScalarResult.all` method."""
+            ...
+
+        async def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]:
+            """Fetch many objects.
+
+            Equivalent to :meth:`_result.Result.fetchmany` except that
+            tuple values, rather than :class:`_result.Row` objects,
+            are returned.
+
+            """
+            ...
+
+        async def all(self) -> Sequence[_R]:  # noqa: A001
+            """Return all scalar values in a list.
+
+            Equivalent to :meth:`_result.Result.all` except that
+            tuple values, rather than :class:`_result.Row` objects,
+            are returned.
+
+            """
+            ...
+
+        async def __aiter__(self) -> AsyncIterator[_R]:
+            ...
+
+        async def __anext__(self) -> _R:
+            ...
+
+        async def first(self) -> Optional[_R]:
+            """Fetch the first object or None if no object is present.
+
+            Equivalent to :meth:`_result.Result.first` except that
+            tuple values, rather than :class:`_result.Row` objects,
+            are returned.
+
+
+            """
+            ...
+
+        async def one_or_none(self) -> Optional[_R]:
+            """Return at most one object or raise an exception.
+
+            Equivalent to :meth:`_result.Result.one_or_none` except that
+            tuple values, rather than :class:`_result.Row` objects,
+            are returned.
+
+            """
+            ...
+
+        async def one(self) -> _R:
+            """Return exactly one object or raise an exception.
+
+            Equivalent to :meth:`_result.Result.one` except that
+            tuple values, rather than :class:`_result.Row` objects,
+            are returned.
+
+            """
+            ...
+
+        @overload
+        async def scalar_one(self: AsyncTupleResult[Tuple[_T]]) -> _T:
+            ...
+
+        @overload
+        async def scalar_one(self) -> Any:
+            ...
+
+        async def scalar_one(self) -> Any:
+            """Return exactly one scalar result or raise an exception.
+
+            This is equivalent to calling :meth:`.Result.scalars` and then
+            :meth:`.Result.one`.
+
+            .. seealso::
+
+                :meth:`.Result.one`
+
+                :meth:`.Result.scalars`
+
+            """
+            ...
+
+        @overload
+        async def scalar_one_or_none(
+            self: AsyncTupleResult[Tuple[_T]],
+        ) -> Optional[_T]:
+            ...
+
+        @overload
+        async def scalar_one_or_none(self) -> Optional[Any]:
+            ...
+
+        async def scalar_one_or_none(self) -> Optional[Any]:
+            """Return exactly one or no scalar result.
+
+            This is equivalent to calling :meth:`.Result.scalars` and then
+            :meth:`.Result.one_or_none`.
+
+            .. seealso::
+
+                :meth:`.Result.one_or_none`
+
+                :meth:`.Result.scalars`
+
+            """
+            ...
+
+        @overload
+        async def scalar(self: AsyncTupleResult[Tuple[_T]]) -> Optional[_T]:
+            ...
+
+        @overload
+        async def scalar(self) -> Any:
+            ...
+
+        async def scalar(self) -> Any:
+            """Fetch the first column of the first row, and close the result set.
+
+            Returns None if there are no rows to fetch.
+
+            No validation is performed to test if additional rows remain.
+
+            After calling this method, the object is fully closed,
+            e.g. the :meth:`_engine.CursorResult.close`
+            method will have been called.
+
+            :return: a Python scalar value, or None if no rows remain.
+
+            """
+            ...
+
+
+_RT = TypeVar("_RT", bound="Result[Any]")
 
 
 async def _ensure_sync_result(result: _RT, calling_method: Any) -> _RT:
-    cursor_result: CursorResult
+    cursor_result: CursorResult[Any]
 
     try:
         is_cursor = result._is_cursor
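
Building on the ``scalar_one()`` overloads above, a one-column typed statement lets the scalar come back typed without a cast; a sketch, again with ``User`` as an assumed typed mapping::

    from sqlalchemy import select
    from sqlalchemy.ext.asyncio import AsyncSession

    async def first_user_id(session: AsyncSession) -> int:
        # with User.id typed as Mapped[int], the statement is a
        # TypedReturnsRows[Tuple[int]], so the AsyncResult[Tuple[int]]
        # overload of scalar_one() returns int
        result = await session.stream(select(User.id))
        return await result.scalar_one()
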
index c7a6e2ca0157866c5afcd2c3789695acd9139a2e..22a060a0d42f2ef9e201e166609339b87ac5d5f7 100644 (file)
@@ -12,10 +12,12 @@ from typing import Callable
 from typing import Iterable
 from typing import Iterator
 from typing import Optional
+from typing import overload
 from typing import Sequence
 from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
+from typing import TypeVar
 from typing import Union
 
 from .session import async_sessionmaker
@@ -37,9 +39,9 @@ if TYPE_CHECKING:
     from ...engine import Engine
     from ...engine import Result
     from ...engine import Row
+    from ...engine import RowMapping
     from ...engine.interfaces import _CoreAnyExecuteParams
     from ...engine.interfaces import _CoreSingleExecuteParams
-    from ...engine.interfaces import _ExecuteOptions
     from ...engine.interfaces import _ExecuteOptionsParameter
     from ...engine.result import ScalarResult
     from ...orm._typing import _IdentityKeyType
@@ -52,6 +54,9 @@ if TYPE_CHECKING:
     from ...sql.base import Executable
     from ...sql.elements import ClauseElement
     from ...sql.selectable import ForUpdateArg
+    from ...sql.selectable import TypedReturnsRows
+
+_T = TypeVar("_T", bound=Any)
 
 
 @create_proxy_methods(
@@ -480,6 +485,32 @@ class async_scoped_session:
 
         return await self._proxied.delete(instance)
 
+    @overload
+    async def execute(
+        self,
+        statement: TypedReturnsRows[_T],
+        params: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        _parent_execute_state: Optional[Any] = None,
+        _add_event: Optional[Any] = None,
+    ) -> Result[_T]:
+        ...
+
+    @overload
+    async def execute(
+        self,
+        statement: Executable,
+        params: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        _parent_execute_state: Optional[Any] = None,
+        _add_event: Optional[Any] = None,
+    ) -> Result[Any]:
+        ...
+
     async def execute(
         self,
         statement: Executable,
@@ -488,7 +519,7 @@ class async_scoped_session:
         execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
         **kw: Any,
-    ) -> Result:
+    ) -> Result[Any]:
         r"""Execute a statement and return a buffered
         :class:`_engine.Result` object.
 
@@ -916,6 +947,30 @@ class async_scoped_session:
 
         return await self._proxied.rollback()
 
+    @overload
+    async def scalar(
+        self,
+        statement: TypedReturnsRows[Tuple[_T]],
+        params: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> Optional[_T]:
+        ...
+
+    @overload
+    async def scalar(
+        self,
+        statement: Executable,
+        params: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> Any:
+        ...
+
     async def scalar(
         self,
         statement: Executable,
@@ -947,6 +1002,30 @@ class async_scoped_session:
             **kw,
         )
 
+    @overload
+    async def scalars(
+        self,
+        statement: TypedReturnsRows[Tuple[_T]],
+        params: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> ScalarResult[_T]:
+        ...
+
+    @overload
+    async def scalars(
+        self,
+        statement: Executable,
+        params: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> ScalarResult[Any]:
+        ...
+
     async def scalars(
         self,
         statement: Executable,
@@ -984,6 +1063,19 @@ class async_scoped_session:
             **kw,
         )
 
+    @overload
+    async def stream(
+        self,
+        statement: TypedReturnsRows[_T],
+        params: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> AsyncResult[_T]:
+        ...
+
+    @overload
     async def stream(
         self,
         statement: Executable,
@@ -992,7 +1084,18 @@ class async_scoped_session:
         execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
         **kw: Any,
-    ) -> AsyncResult:
+    ) -> AsyncResult[Any]:
+        ...
+
+    async def stream(
+        self,
+        statement: Executable,
+        params: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> AsyncResult[Any]:
         r"""Execute a statement and return a streaming
         :class:`_asyncio.AsyncResult` object.
 
@@ -1012,6 +1115,30 @@ class async_scoped_session:
             **kw,
         )
 
+    @overload
+    async def stream_scalars(
+        self,
+        statement: TypedReturnsRows[Tuple[_T]],
+        params: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> AsyncScalarResult[_T]:
+        ...
+
+    @overload
+    async def stream_scalars(
+        self,
+        statement: Executable,
+        params: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> AsyncScalarResult[Any]:
+        ...
+
     async def stream_scalars(
         self,
         statement: Executable,
@@ -1323,7 +1450,7 @@ class async_scoped_session:
         ident: Union[Any, Tuple[Any, ...]] = None,
         *,
         instance: Optional[Any] = None,
-        row: Optional[Row] = None,
+        row: Optional[Union[Row[Any], RowMapping]] = None,
         identity_token: Optional[Any] = None,
     ) -> _IdentityKeyType[Any]:
         r"""Return an identity key.
index 1422f99a393e70282be3bc9ff64af045734ba0cc..f2a69e9cd9f40be2dfdb40bcdbf82e39426af277 100644 (file)
@@ -12,10 +12,12 @@ from typing import Iterable
 from typing import Iterator
 from typing import NoReturn
 from typing import Optional
+from typing import overload
 from typing import Sequence
 from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
+from typing import TypeVar
 from typing import Union
 
 from . import engine
@@ -39,11 +41,10 @@ if TYPE_CHECKING:
     from ...engine import Engine
     from ...engine import Result
     from ...engine import Row
+    from ...engine import RowMapping
     from ...engine import ScalarResult
-    from ...engine import Transaction
     from ...engine.interfaces import _CoreAnyExecuteParams
     from ...engine.interfaces import _CoreSingleExecuteParams
-    from ...engine.interfaces import _ExecuteOptions
     from ...engine.interfaces import _ExecuteOptionsParameter
     from ...event import dispatcher
     from ...orm._typing import _IdentityKeyType
@@ -59,9 +60,12 @@ if TYPE_CHECKING:
     from ...sql.base import Executable
     from ...sql.elements import ClauseElement
     from ...sql.selectable import ForUpdateArg
+    from ...sql.selectable import TypedReturnsRows
 
 _AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"]
 
+_T = TypeVar("_T", bound=Any)
+
 
 class _SyncSessionCallable(Protocol):
     def __call__(self, session: Session, *arg: Any, **kw: Any) -> Any:
@@ -257,6 +261,32 @@ class AsyncSession(ReversibleProxy[Session]):
 
         return await greenlet_spawn(fn, self.sync_session, *arg, **kw)
 
+    @overload
+    async def execute(
+        self,
+        statement: TypedReturnsRows[_T],
+        params: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        _parent_execute_state: Optional[Any] = None,
+        _add_event: Optional[Any] = None,
+    ) -> Result[_T]:
+        ...
+
+    @overload
+    async def execute(
+        self,
+        statement: Executable,
+        params: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        _parent_execute_state: Optional[Any] = None,
+        _add_event: Optional[Any] = None,
+    ) -> Result[Any]:
+        ...
+
     async def execute(
         self,
         statement: Executable,
@@ -265,7 +295,7 @@ class AsyncSession(ReversibleProxy[Session]):
         execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
         **kw: Any,
-    ) -> Result:
+    ) -> Result[Any]:
         """Execute a statement and return a buffered
         :class:`_engine.Result` object.
 
@@ -292,6 +322,30 @@ class AsyncSession(ReversibleProxy[Session]):
         )
         return await _ensure_sync_result(result, self.execute)
 
+    @overload
+    async def scalar(
+        self,
+        statement: TypedReturnsRows[Tuple[_T]],
+        params: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> Optional[_T]:
+        ...
+
+    @overload
+    async def scalar(
+        self,
+        statement: Executable,
+        params: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> Any:
+        ...
+
     async def scalar(
         self,
         statement: Executable,
@@ -326,6 +380,30 @@ class AsyncSession(ReversibleProxy[Session]):
         )
         return result
 
+    @overload
+    async def scalars(
+        self,
+        statement: TypedReturnsRows[Tuple[_T]],
+        params: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> ScalarResult[_T]:
+        ...
+
+    @overload
+    async def scalars(
+        self,
+        statement: Executable,
+        params: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> ScalarResult[Any]:
+        ...
+
     async def scalars(
         self,
         statement: Executable,
@@ -391,6 +469,30 @@ class AsyncSession(ReversibleProxy[Session]):
         )
         return result_obj
 
+    @overload
+    async def stream(
+        self,
+        statement: TypedReturnsRows[_T],
+        params: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> AsyncResult[_T]:
+        ...
+
+    @overload
+    async def stream(
+        self,
+        statement: Executable,
+        params: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> AsyncResult[Any]:
+        ...
+
     async def stream(
         self,
         statement: Executable,
@@ -399,7 +501,7 @@ class AsyncSession(ReversibleProxy[Session]):
         execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
         **kw: Any,
-    ) -> AsyncResult:
+    ) -> AsyncResult[Any]:
 
         """Execute a statement and return a streaming
         :class:`_asyncio.AsyncResult` object.
@@ -423,6 +525,30 @@ class AsyncSession(ReversibleProxy[Session]):
         )
         return AsyncResult(result)
 
+    @overload
+    async def stream_scalars(
+        self,
+        statement: TypedReturnsRows[Tuple[_T]],
+        params: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> AsyncScalarResult[_T]:
+        ...
+
+    @overload
+    async def stream_scalars(
+        self,
+        statement: Executable,
+        params: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> AsyncScalarResult[Any]:
+        ...
+
     async def stream_scalars(
         self,
         statement: Executable,
@@ -1215,7 +1341,7 @@ class AsyncSession(ReversibleProxy[Session]):
         ident: Union[Any, Tuple[Any, ...]] = None,
         *,
         instance: Optional[Any] = None,
-        row: Optional[Row] = None,
+        row: Optional[Union[Row[Any], RowMapping]] = None,
         identity_token: Optional[Any] = None,
     ) -> _IdentityKeyType[Any]:
         r"""Return an identity key.
index b1138a4ad8ec8930b454dea63bc12f42a388c9f4..c14b466ebdc7e797cb7eb02046aa1a8aca605854 100644 (file)
@@ -23,6 +23,7 @@ from ..orm import base as orm_base
 from ..orm import collections
 from ..orm import exc as orm_exc
 from ..orm import instrumentation as orm_instrumentation
+from ..orm import util as orm_util
 from ..orm.instrumentation import _default_dict_getter
 from ..orm.instrumentation import _default_manager_getter
 from ..orm.instrumentation import _default_opt_manager_getter
@@ -437,5 +438,7 @@ def _install_lookups(lookups):
         attributes.manager_of_class
     ) = orm_instrumentation.manager_of_class = manager_of_class
     orm_base.opt_manager_of_class = (
+        orm_util.opt_manager_of_class
+    ) = (
         attributes.opt_manager_of_class
     ) = orm_instrumentation.opt_manager_of_class = opt_manager_of_class
index 457ad5c5a6b1354cc1136d8e974980b3c98ecdcd..48615b174b1aa8069ee4e68acbe3319439231483 100644 (file)
@@ -38,6 +38,7 @@ from ..exc import InvalidRequestError
 from ..sql.base import SchemaEventTarget
 from ..sql.schema import SchemaConst
 from ..sql.selectable import FromClause
+from ..util.typing import Annotated
 from ..util.typing import Literal
 
 if TYPE_CHECKING:
@@ -45,6 +46,7 @@ if TYPE_CHECKING:
     from ._typing import _ORMColumnExprArgument
     from .descriptor_props import _CompositeAttrType
     from .interfaces import PropComparator
+    from .mapper import Mapper
     from .query import Query
     from .relationships import _LazyLoadArgumentType
     from .relationships import _ORMBackrefArgument
@@ -1849,9 +1851,27 @@ def clear_mappers():
     mapperlib._dispose_registries(mapperlib._all_registries(), False)
 
 
+# I would really like a way to get the Type[] here that shows up
+# in a different way in typing tools, however there is no current method
+# that is accepted by mypy (subclass of Type[_O] works in pylance, rejected
+# by mypy).
+AliasedType = Annotated[Type[_O], "aliased"]
+
+
+@overload
+def aliased(
+    element: Type[_O],
+    alias: Optional[Union[Alias, Subquery]] = None,
+    name: Optional[str] = None,
+    flat: bool = False,
+    adapt_on_names: bool = False,
+) -> AliasedType[_O]:
+    ...
+
+
 @overload
 def aliased(
-    element: _EntityType[_O],
+    element: Union[AliasedClass[_O], Mapper[_O], AliasedInsp[_O]],
     alias: Optional[Union[Alias, Subquery]] = None,
     name: Optional[str] = None,
     flat: bool = False,
@@ -1877,7 +1897,7 @@ def aliased(
     name: Optional[str] = None,
     flat: bool = False,
     adapt_on_names: bool = False,
-) -> Union[AliasedClass[_O], FromClause]:
+) -> Union[AliasedClass[_O], FromClause, AliasedType[_O]]:
     """Produce an alias of the given element, usually an :class:`.AliasedClass`
     instance.
 
@@ -1885,7 +1905,8 @@ def aliased(
 
         my_alias = aliased(MyClass)
 
-        session.query(MyClass, my_alias).filter(MyClass.id > my_alias.id)
+        stmt = select(MyClass, my_alias).filter(MyClass.id > my_alias.id)
+        result = session.execute(stmt)
 
     The :func:`.aliased` function is used to create an ad-hoc mapping of a
     mapped class to a new selectable.  By default, a selectable is generated
@@ -1911,6 +1932,9 @@ def aliased(
 
     .. seealso::
 
+        :class:`.AsAliased` - a :pep:`484` typed version of
+        :func:`_orm.aliased`
+
         :ref:`tutorial_orm_entity_aliases` - in the :ref:`unified_tutorial`
 
         :ref:`orm_queryguide_orm_aliases` - in the :ref:`queryguide_toplevel`
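
The practical effect of annotating :func:`_orm.aliased` as returning ``AliasedType[_O]`` (i.e. ``Type[_O]`` to the checker) is that attribute access on the alias type-checks like the class itself; a sketch, where ``User`` and its ``manager_id`` column are assumptions::

    from sqlalchemy import select
    from sqlalchemy.orm import aliased

    # "User" is an assumed mapping; because aliased() is annotated as
    # returning Type[User], attribute access on "manager" type-checks
    # the same way as on User itself
    manager = aliased(User, name="manager")

    stmt = select(User.name, manager.name).join(
        manager, User.manager_id == manager.id
    )
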
index 41d944c57d3bdd0ee3d3c628a0846ed267124661..619af65104427aeb04d9a96fbe8be0e68441b7bc 100644 (file)
@@ -70,6 +70,7 @@ from .. import exc
 from .. import inspection
 from .. import util
 from ..sql import base as sql_base
+from ..sql import cache_key
 from ..sql import roles
 from ..sql import traversals
 from ..sql import visitors
@@ -99,10 +100,8 @@ class QueryableAttribute(
     traversals.HasCopyInternals,
     roles.JoinTargetRole,
     roles.OnClauseRole,
-    roles.ColumnsClauseRole,
-    roles.ExpressionElementRole[_T],
     sql_base.Immutable,
-    sql_base.MemoizedHasCacheKey,
+    cache_key.MemoizedHasCacheKey,
 ):
     """Base class for :term:`descriptor` objects that intercept
     attribute events on behalf of a :class:`.MapperProperty`
index 054d52d83bd8f5a5f59c1db6c968c81792da8568..367a5332dee061c87da1b3d6fec685eba6e2526b 100644 (file)
@@ -30,6 +30,7 @@ from ._typing import insp_is_mapper
 from .. import exc as sa_exc
 from .. import inspection
 from .. import util
+from ..sql import roles
 from ..sql.elements import SQLCoreOperations
 from ..util import FastIntFlag
 from ..util.langhelpers import TypingOnly
@@ -483,19 +484,6 @@ def _inspect_mapped_class(
         return mapper
 
 
-@inspection._inspects(type)
-def _inspect_mc(class_: Type[_O]) -> Optional[Mapper[_O]]:
-    try:
-        class_manager = opt_manager_of_class(class_)
-        if class_manager is None or not class_manager.is_mapped:
-            return None
-        mapper = class_manager.mapper
-    except exc.NO_STATE:
-        return None
-    else:
-        return mapper
-
-
 def _parse_mapper_argument(arg: Union[Mapper[_O], Type[_O]]) -> Mapper[_O]:
     insp = inspection.inspect(arg, raiseerr=False)
     if insp_is_mapper(insp):
@@ -691,7 +679,7 @@ class ORMDescriptor(Generic[_T], TypingOnly):
             ...
 
 
-class Mapped(ORMDescriptor[_T], TypingOnly):
+class Mapped(ORMDescriptor[_T], roles.TypedColumnsClauseRole[_T], TypingOnly):
     """Represent an ORM mapped attribute on a mapped class.
 
     This class represents the complete descriptor interface for any class
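
Because :class:`.Mapped` now participates in ``TypedColumnsClauseRole[_T]``, ORM attributes can drive the typed ``select()`` overloads directly; a minimal sketch with an assumed typed ``User`` mapping::

    from sqlalchemy import select

    # "User" is an assumed mapping whose attributes are typed as
    # Mapped[int] / Mapped[str]; the tuple-map overloads of select()
    # carry those types through to the statement
    stmt = select(User.id, User.name)  # typed as Select[Tuple[int, str]]
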
index 4fee2d383d6558bffbaef58425950935948a7202..05287cbcfd3f41a192cd4e3156fbd7ef700311cc 100644 (file)
@@ -17,6 +17,7 @@ from typing import Set
 from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
+from typing import TypeVar
 from typing import Union
 
 from . import attributes
@@ -48,14 +49,15 @@ from ..sql.base import _select_iterables
 from ..sql.base import CacheableOptions
 from ..sql.base import CompileState
 from ..sql.base import Executable
+from ..sql.base import Generative
 from ..sql.base import Options
 from ..sql.dml import UpdateBase
 from ..sql.elements import GroupedElement
 from ..sql.elements import TextClause
+from ..sql.selectable import ExecutableReturnsRows
 from ..sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY
 from ..sql.selectable import LABEL_STYLE_NONE
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
-from ..sql.selectable import ReturnsRows
 from ..sql.selectable import Select
 from ..sql.selectable import SelectLabelStyle
 from ..sql.selectable import SelectState
@@ -72,6 +74,7 @@ if TYPE_CHECKING:
     from ..sql.selectable import SelectBase
     from ..sql.type_api import TypeEngine
 
+_T = TypeVar("_T", bound=Any)
 _path_registry = PathRegistry.root
 
 _EMPTY_DICT = util.immutabledict()
@@ -574,7 +577,7 @@ class ORMFromStatementCompileState(ORMCompileState):
         return None
 
 
-class FromStatement(GroupedElement, ReturnsRows, Executable):
+class FromStatement(GroupedElement, Generative, ExecutableReturnsRows):
     """Core construct that represents a load of ORM objects from various
     :class:`.ReturnsRows` and other classes including:
 
index 0ca62b7e3584def7c6d96c6bdcea6d253e230558..6a5690be24ed8dab7178dafef803170f34b3930e 100644 (file)
@@ -61,7 +61,6 @@ from ..sql.schema import Column
 from ..sql.type_api import TypeEngine
 from ..util.typing import TypedDict
 
-
 if typing.TYPE_CHECKING:
     from ._typing import _EntityType
     from ._typing import _IdentityKeyType
@@ -106,12 +105,12 @@ class ORMStatementRole(roles.StatementRole):
     )
 
 
-class ORMColumnsClauseRole(roles.ColumnsClauseRole):
+class ORMColumnsClauseRole(roles.TypedColumnsClauseRole[_T]):
     __slots__ = ()
     _role_name = "ORM mapped entity, aliased entity, or Column expression"
 
 
-class ORMEntityColumnsClauseRole(ORMColumnsClauseRole):
+class ORMEntityColumnsClauseRole(ORMColumnsClauseRole[_T]):
     __slots__ = ()
     _role_name = "ORM mapped or aliased entity"
 
@@ -127,8 +126,8 @@ class ORMColumnDescription(TypedDict):
     # into "type" is a bad idea
     type: Union[Type[Any], TypeEngine[Any]]
     aliased: bool
-    expr: _ColumnsClauseArgument
-    entity: Optional[_ColumnsClauseArgument]
+    expr: _ColumnsClauseArgument[Any]
+    entity: Optional[_ColumnsClauseArgument[Any]]
 
 
 class _IntrospectsAnnotations:
@@ -282,7 +281,7 @@ class MapperProperty(
         query_entity: _MapperEntity,
         path: PathRegistry,
         mapper: Mapper[Any],
-        result: Result,
+        result: Result[Any],
         adapter: Optional[ColumnAdapter],
         populators: _PopulatorDict,
     ) -> None:
@@ -1170,7 +1169,7 @@ class LoaderStrategy:
         path: AbstractEntityRegistry,
         loadopt: Optional[_LoadElement],
         mapper: Mapper[Any],
-        result: Result,
+        result: Result[Any],
         adapter: Optional[ORMAdapter],
         populators: _PopulatorDict,
     ) -> None:
index b37c080eaf190f3041e4e6f7dca78f6021728609..0830350936f4091526a88893e94fc2ac9e472014 100644 (file)
@@ -96,7 +96,6 @@ if TYPE_CHECKING:
     from .descriptor_props import Synonym
     from .events import MapperEvents
     from .instrumentation import ClassManager
-    from .path_registry import AbstractEntityRegistry
     from .path_registry import CachingEntityRegistry
     from .properties import ColumnProperty
     from .relationships import Relationship
@@ -108,10 +107,10 @@ if TYPE_CHECKING:
     from ..sql.base import ReadOnlyColumnCollection
     from ..sql.elements import ColumnClause
     from ..sql.elements import ColumnElement
+    from ..sql.elements import KeyedColumnElement
     from ..sql.schema import Column
     from ..sql.schema import Table
     from ..sql.selectable import FromClause
-    from ..sql.selectable import TableClause
     from ..sql.util import ColumnAdapter
     from ..util import OrderedSet
 
@@ -161,7 +160,7 @@ _CONFIGURE_MUTEX = threading.RLock()
 @log.class_logger
 class Mapper(
     ORMFromClauseRole,
-    ORMEntityColumnsClauseRole,
+    ORMEntityColumnsClauseRole[_O],
     MemoizedHasCacheKey,
     InspectionAttr,
     log.Identified,
@@ -1006,7 +1005,7 @@ class Mapper(
 
     """
 
-    polymorphic_on: Optional[ColumnElement[Any]]
+    polymorphic_on: Optional[KeyedColumnElement[Any]]
     """The :class:`_schema.Column` or SQL expression specified as the
     ``polymorphic_on`` argument
     for this :class:`_orm.Mapper`, within an inheritance scenario.
@@ -1699,10 +1698,10 @@ class Mapper(
                     instrument = True
                 key = getattr(col, "key", None)
                 if key:
-                    if self._should_exclude(col.key, col.key, False, col):
+                    if self._should_exclude(key, key, False, col):
                         raise sa_exc.InvalidRequestError(
                             "Cannot exclude or override the "
-                            "discriminator column %r" % col.key
+                            "discriminator column %r" % key
                         )
                 else:
                     self.polymorphic_on = col = col.label("_sa_polymorphic_on")
@@ -2948,7 +2947,7 @@ class Mapper(
 
     def identity_key_from_row(
         self,
-        row: Optional[Union[Row, RowMapping]],
+        row: Optional[Union[Row[Any], RowMapping]],
         identity_token: Optional[Any] = None,
         adapter: Optional[ColumnAdapter] = None,
     ) -> _IdentityKeyType[_O]:
index 9f37e84571e929e631d0b5206f5135899f4b600f..0ca0559b4523b9c73ca56bde3812f7332cb61d16 100644 (file)
@@ -54,7 +54,7 @@ from ..util.typing import NoneType
 if TYPE_CHECKING:
     from ._typing import _ORMColumnExprArgument
     from ..sql._typing import _InfoType
-    from ..sql.elements import ColumnElement
+    from ..sql.elements import KeyedColumnElement
 
 _T = TypeVar("_T", bound=Any)
 _PT = TypeVar("_PT", bound=Any)
@@ -85,7 +85,8 @@ class ColumnProperty(
     inherit_cache = True
     _links_to_entity = False
 
-    columns: List[ColumnElement[Any]]
+    columns: List[KeyedColumnElement[Any]]
+    _orig_columns: List[KeyedColumnElement[Any]]
 
     _is_polymorphic_discriminator: bool
 
index 395d01a1eac66cc3eef369498481630943e3a039..5bd302b21d10dd0907456a7116fc2c80fd13bd53 100644 (file)
@@ -27,6 +27,8 @@ from typing import Generic
 from typing import Iterable
 from typing import List
 from typing import Optional
+from typing import overload
+from typing import Sequence
 from typing import Tuple
 from typing import TYPE_CHECKING
 from typing import TypeVar
@@ -36,6 +38,7 @@ from . import exc as orm_exc
 from . import interfaces
 from . import loading
 from . import util as orm_util
+from ._typing import _O
 from .base import _assertions
 from .context import _column_descriptions
 from .context import _determine_last_joined_entity
@@ -56,6 +59,7 @@ from .. import log
 from .. import sql
 from .. import util
 from ..engine import Result
+from ..engine import Row
 from ..sql import coercions
 from ..sql import expression
 from ..sql import roles
@@ -63,10 +67,12 @@ from ..sql import Select
 from ..sql import util as sql_util
 from ..sql import visitors
 from ..sql._typing import _FromClauseArgument
+from ..sql._typing import _TP
 from ..sql.annotation import SupportsCloneAnnotations
 from ..sql.base import _entity_namespace_key
 from ..sql.base import _generative
 from ..sql.base import Executable
+from ..sql.base import Generative
 from ..sql.expression import Exists
 from ..sql.selectable import _MemoizedSelectEntities
 from ..sql.selectable import _SelectFromElements
@@ -75,10 +81,33 @@ from ..sql.selectable import HasHints
 from ..sql.selectable import HasPrefixes
 from ..sql.selectable import HasSuffixes
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..util.typing import Literal
 
 if TYPE_CHECKING:
+    from ._typing import _EntityType
+    from .session import Session
+    from ..engine.result import ScalarResult
+    from ..engine.row import Row
+    from ..sql._typing import _ColumnExpressionArgument
+    from ..sql._typing import _ColumnsClauseArgument
+    from ..sql._typing import _MAYBE_ENTITY
+    from ..sql._typing import _no_kw
+    from ..sql._typing import _NOT_ENTITY
+    from ..sql._typing import _PropagateAttrsType
+    from ..sql._typing import _T0
+    from ..sql._typing import _T1
+    from ..sql._typing import _T2
+    from ..sql._typing import _T3
+    from ..sql._typing import _T4
+    from ..sql._typing import _T5
+    from ..sql._typing import _T6
+    from ..sql._typing import _T7
+    from ..sql._typing import _TypedColumnClauseArgument as _TCCA
+    from ..sql.roles import TypedColumnsClauseRole
     from ..sql.selectable import _SetupJoinsElement
     from ..sql.selectable import Alias
+    from ..sql.selectable import ExecutableReturnsRows
+    from ..sql.selectable import ScalarSelect
     from ..sql.selectable import Subquery
 
 __all__ = ["Query", "QueryContext"]
@@ -97,6 +126,7 @@ class Query(
     HasSuffixes,
     HasHints,
     log.Identified,
+    Generative,
     Executable,
     Generic[_T],
 ):
@@ -159,9 +189,15 @@ class Query(
     # mirrors that of ClauseElement, used to propagate the "orm"
     # plugin as well as the "subject" of the plugin, e.g. the mapper
     # we are querying against.
-    _propagate_attrs = util.immutabledict()
+    @util.memoized_property
+    def _propagate_attrs(self) -> _PropagateAttrsType:
+        return util.EMPTY_DICT
 
-    def __init__(self, entities, session=None):
+    def __init__(
+        self,
+        entities: Sequence[_ColumnsClauseArgument[Any]],
+        session: Optional[Session] = None,
+    ):
         """Construct a :class:`_query.Query` directly.
 
         E.g.::
@@ -207,6 +243,36 @@ class Query(
             for ent in util.to_list(entities)
         ]
 
+    @overload
+    def tuples(self: Query[Row[_TP]]) -> Query[_TP]:
+        ...
+
+    @overload
+    def tuples(self: Query[_O]) -> Query[Tuple[_O]]:
+        ...
+
+    def tuples(self) -> Query[Any]:
+        """return a tuple-typed form of this :class:`.Query`.
+
+        This method invokes the :meth:`.Query.only_return_tuples`
+        method with a value of ``True``, which by itself ensures that this
+        :class:`.Query` will always return :class:`.Row` objects, even
+        if the query is made against a single entity.  It then also
+        at the typing level will return a "typed" query, if possible,
+        that will type result rows as ``Tuple`` objects with typed
+        elements.
+
+        This method can be compared to the :meth:`.Result.tuples` method,
+        which returns "self", but from a typing perspective returns an object
+        that will yield typed ``Tuple`` objects for results.   Typing
+        takes effect only if this :class:`.Query` object is a typed
+        query object already.
+
+        .. versionadded:: 2.0
+
+        """
+        return self.only_return_tuples(True)
+
     def _entity_from_pre_ent_zero(self):
         if not self._raw_columns:
             return None
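
A sketch of the intended call pattern for :meth:`.Query.tuples` (the :class:`.Session` and the ``User`` mapping are assumed)::

    from sqlalchemy.orm import Session

    def print_users(session: Session) -> None:
        # .tuples() types the query as Query[Tuple[User]] while also
        # forcing Row results at runtime, so each row unpacks as (User,)
        for (user_obj,) in session.query(User).tuples():
            print(user_obj)
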
@@ -582,20 +648,52 @@ class Query(
 
         return self.enable_eagerloads(False).statement.label(name)
 
+    @overload
+    def as_scalar(
+        self: Query[Tuple[_MAYBE_ENTITY]],
+    ) -> ScalarSelect[_MAYBE_ENTITY]:
+        ...
+
+    @overload
+    def as_scalar(
+        self: Query[Tuple[_NOT_ENTITY]],
+    ) -> ScalarSelect[_NOT_ENTITY]:
+        ...
+
+    @overload
+    def as_scalar(self) -> ScalarSelect[Any]:
+        ...
+
     @util.deprecated(
         "1.4",
         "The :meth:`_query.Query.as_scalar` method is deprecated and will be "
         "removed in a future release.  Please refer to "
         ":meth:`_query.Query.scalar_subquery`.",
     )
-    def as_scalar(self):
+    def as_scalar(self) -> ScalarSelect[Any]:
         """Return the full SELECT statement represented by this
         :class:`_query.Query`, converted to a scalar subquery.
 
         """
         return self.scalar_subquery()
 
-    def scalar_subquery(self):
+    @overload
+    def scalar_subquery(
+        self: Query[Tuple[_MAYBE_ENTITY]],
+    ) -> ScalarSelect[Any]:
+        ...
+
+    @overload
+    def scalar_subquery(
+        self: Query[Tuple[_NOT_ENTITY]],
+    ) -> ScalarSelect[_NOT_ENTITY]:
+        ...
+
+    @overload
+    def scalar_subquery(self) -> ScalarSelect[Any]:
+        ...
+
+    def scalar_subquery(self) -> ScalarSelect[Any]:
         """Return the full SELECT statement represented by this
         :class:`_query.Query`, converted to a scalar subquery.
 
@@ -630,16 +728,31 @@ class Query(
             .statement
         )
 
-    @_generative
-    def only_return_tuples(self: SelfQuery, value) -> SelfQuery:
-        """When set to True, the query results will always be a tuple.
+    @overload
+    def only_return_tuples(
+        self: Query[_O], value: Literal[True]
+    ) -> RowReturningQuery[Tuple[_O]]:
+        ...
 
-        This is specifically for single element queries. The default is False.
+    @overload
+    def only_return_tuples(
+        self: Query[_O], value: Literal[False]
+    ) -> Query[_O]:
+        ...
 
-        .. versionadded:: 1.2.5
+    @_generative
+    def only_return_tuples(self, value: bool) -> Query[Any]:
+        """When set to True, the query results will always be a
+        :class:`.Row` object.
+
+        This can change a query that normally returns a single entity
+        as a scalar to return a :class:`.Row` result in all cases.
 
         .. seealso::
 
+            :meth:`.Query.tuples` - returns tuples, but also at the typing
+            level will type results as ``Tuple``.
+
             :meth:`_query.Query.is_single_entity`
 
         """
@@ -1077,7 +1190,11 @@ class Query(
         return self.filter(with_parent(instance, property, entity_zero.entity))
 
     @_generative
-    def add_entity(self: SelfQuery, entity, alias=None) -> SelfQuery:
+    def add_entity(
+        self,
+        entity: _EntityType[Any],
+        alias: Optional[Union[Alias, Subquery]] = None,
+    ) -> Query[Any]:
         """add a mapped entity to the list of result columns
         to be returned."""
 
@@ -1209,8 +1326,107 @@ class Query(
         except StopIteration:
             return None
 
+    @overload
+    def with_entities(
+        self, _entity: _EntityType[_O], **kwargs: Any
+    ) -> ScalarInstanceQuery[_O]:
+        ...
+
+    @overload
+    def with_entities(
+        self, _colexpr: TypedColumnsClauseRole[_T]
+    ) -> RowReturningQuery[Tuple[_T]]:
+        ...
+
+    # START OVERLOADED FUNCTIONS self.with_entities RowReturningQuery 2-8
+
+    # code within this block is **programmatically,
+    # statically generated** by tools/generate_tuple_map_overloads.py
+
+    @overload
+    def with_entities(
+        self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+    ) -> RowReturningQuery[Tuple[_T0, _T1]]:
+        ...
+
+    @overload
+    def with_entities(
+        self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+    ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]:
+        ...
+
+    @overload
+    def with_entities(
+        self,
+        __ent0: _TCCA[_T0],
+        __ent1: _TCCA[_T1],
+        __ent2: _TCCA[_T2],
+        __ent3: _TCCA[_T3],
+    ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]:
+        ...
+
+    @overload
+    def with_entities(
+        self,
+        __ent0: _TCCA[_T0],
+        __ent1: _TCCA[_T1],
+        __ent2: _TCCA[_T2],
+        __ent3: _TCCA[_T3],
+        __ent4: _TCCA[_T4],
+    ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+        ...
+
+    @overload
+    def with_entities(
+        self,
+        __ent0: _TCCA[_T0],
+        __ent1: _TCCA[_T1],
+        __ent2: _TCCA[_T2],
+        __ent3: _TCCA[_T3],
+        __ent4: _TCCA[_T4],
+        __ent5: _TCCA[_T5],
+    ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+        ...
+
+    @overload
+    def with_entities(
+        self,
+        __ent0: _TCCA[_T0],
+        __ent1: _TCCA[_T1],
+        __ent2: _TCCA[_T2],
+        __ent3: _TCCA[_T3],
+        __ent4: _TCCA[_T4],
+        __ent5: _TCCA[_T5],
+        __ent6: _TCCA[_T6],
+    ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+        ...
+
+    @overload
+    def with_entities(
+        self,
+        __ent0: _TCCA[_T0],
+        __ent1: _TCCA[_T1],
+        __ent2: _TCCA[_T2],
+        __ent3: _TCCA[_T3],
+        __ent4: _TCCA[_T4],
+        __ent5: _TCCA[_T5],
+        __ent6: _TCCA[_T6],
+        __ent7: _TCCA[_T7],
+    ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+        ...
+
+    # END OVERLOADED FUNCTIONS self.with_entities
+
+    @overload
+    def with_entities(
+        self: SelfQuery, *entities: _ColumnsClauseArgument[Any]
+    ) -> SelfQuery:
+        ...
+
     @_generative
-    def with_entities(self: SelfQuery, *entities) -> SelfQuery:
+    def with_entities(
+        self: SelfQuery, *entities: _ColumnsClauseArgument[Any], **__kw: Any
+    ) -> SelfQuery:
         r"""Return a new :class:`_query.Query`
         replacing the SELECT list with the
         given entities.
@@ -1234,12 +1450,14 @@ class Query(
                         limit(1)
 
         """
+        if __kw:
+            raise _no_kw()
         _MemoizedSelectEntities._generate_for_statement(self)
         self._set_entities(entities)
         return self
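
The generated tuple-map overloads above (two through eight positional column
arguments) let ``with_entities()`` carry the element types of the replacement
SELECT list into the query's row type, while the final catch-all overload
preserves untyped usage.  A short, hypothetical sketch of the intent (the
``User`` mapping is illustrative, not part of this patch)::

    from __future__ import annotations

    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class User(Base):
        __tablename__ = "user_account"

        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str] = mapped_column()


    session = Session()

    q = session.query(User)
    # q: Query[User]

    cols = q.with_entities(User.id, User.name)
    # two typed column arguments hit the generated two-element overload,
    # so the intent is: RowReturningQuery[Tuple[int, str]]

    # the **__kw collector on the implementation exists only so the
    # overloads and the real signature stay compatible; passing any
    # keyword raises sqlalchemy.exc.ArgumentError at runtime via _no_kw()
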
 
     @_generative
-    def add_columns(self: SelfQuery, *column) -> SelfQuery:
+    def add_columns(
+        self, *column: _ColumnExpressionArgument[Any]
+    ) -> Query[Any]:
         """Add one or more column expressions to the list
         of result columns to be returned."""
 
@@ -1262,7 +1480,7 @@ class Query(
         "is deprecated and will be removed in a "
         "future release.  Please use :meth:`_query.Query.add_columns`",
     )
-    def add_column(self, column):
+    def add_column(self, column) -> Query[Any]:
         """Add a column expression to the list of result columns to be
         returned.
 
@@ -1472,7 +1690,9 @@ class Query(
 
     @_generative
     @_assertions(_no_statement_condition, _no_limit_offset)
-    def filter(self: SelfQuery, *criterion) -> SelfQuery:
+    def filter(
+        self: SelfQuery, *criterion: _ColumnExpressionArgument[bool]
+    ) -> SelfQuery:
         r"""Apply the given filtering criterion to a copy
         of this :class:`_query.Query`, using SQL expressions.
 
@@ -1556,7 +1776,7 @@ class Query(
 
         return self._raw_columns[0]
 
-    def filter_by(self, **kwargs):
+    def filter_by(self: SelfQuery, **kwargs: Any) -> SelfQuery:
         r"""Apply the given filtering criterion to a copy
         of this :class:`_query.Query`, using keyword expressions.
 
@@ -1597,7 +1817,9 @@ class Query(
 
     @_generative
     @_assertions(_no_statement_condition, _no_limit_offset)
-    def order_by(self: SelfQuery, *clauses) -> SelfQuery:
+    def order_by(
+        self: SelfQuery, *clauses: _ColumnExpressionArgument[Any]
+    ) -> SelfQuery:
         """Apply one or more ORDER BY criteria to the query and return
         the newly resulting :class:`_query.Query`.
 
@@ -1635,7 +1857,9 @@ class Query(
 
     @_generative
     @_assertions(_no_statement_condition, _no_limit_offset)
-    def group_by(self: SelfQuery, *clauses) -> SelfQuery:
+    def group_by(
+        self: SelfQuery, *clauses: _ColumnExpressionArgument[Any]
+    ) -> SelfQuery:
         """Apply one or more GROUP BY criterion to the query and return
         the newly resulting :class:`_query.Query`.
 
@@ -1667,7 +1891,9 @@ class Query(
 
     @_generative
     @_assertions(_no_statement_condition, _no_limit_offset)
-    def having(self: SelfQuery, criterion) -> SelfQuery:
+    def having(
+        self: SelfQuery, *having: _ColumnExpressionArgument[bool]
+    ) -> SelfQuery:
         r"""Apply a HAVING criterion to the query and return the
         newly resulting :class:`_query.Query`.
 
@@ -1684,17 +1910,17 @@ class Query(
 
         """
 
-        self._having_criteria += (
-            coercions.expect(
-                roles.WhereHavingRole, criterion, apply_propagate_attrs=self
-            ),
-        )
+        for criterion in having:
+            having_criteria = coercions.expect(
+                roles.WhereHavingRole, criterion, apply_propagate_attrs=self
+            )
+            self._having_criteria += (having_criteria,)
         return self
 
     def _set_op(self, expr_fn, *q):
         return self._from_selectable(expr_fn(*([self] + list(q))).subquery())
 
-    def union(self, *q):
+    def union(self: SelfQuery, *q: Query[Any]) -> SelfQuery:
         """Produce a UNION of this Query against one or more queries.
 
         e.g.::
@@ -1733,7 +1959,7 @@ class Query(
         """
         return self._set_op(expression.union, *q)
 
-    def union_all(self, *q):
+    def union_all(self: SelfQuery, *q: Query[Any]) -> SelfQuery:
         """Produce a UNION ALL of this Query against one or more queries.
 
         Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
@@ -1742,7 +1968,7 @@ class Query(
         """
         return self._set_op(expression.union_all, *q)
 
-    def intersect(self, *q):
+    def intersect(self: SelfQuery, *q: Query[Any]) -> SelfQuery:
         """Produce an INTERSECT of this Query against one or more queries.
 
         Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
@@ -1751,7 +1977,7 @@ class Query(
         """
         return self._set_op(expression.intersect, *q)
 
-    def intersect_all(self, *q):
+    def intersect_all(self: SelfQuery, *q: Query[Any]) -> SelfQuery:
         """Produce an INTERSECT ALL of this Query against one or more queries.
 
         Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
@@ -1760,7 +1986,7 @@ class Query(
         """
         return self._set_op(expression.intersect_all, *q)
 
-    def except_(self, *q):
+    def except_(self: SelfQuery, *q: Query[Any]) -> SelfQuery:
         """Produce an EXCEPT of this Query against one or more queries.
 
         Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
@@ -1769,7 +1995,7 @@ class Query(
         """
         return self._set_op(expression.except_, *q)
 
-    def except_all(self, *q):
+    def except_all(self: SelfQuery, *q: Query[Any]) -> SelfQuery:
         """Produce an EXCEPT ALL of this Query against one or more queries.
 
         Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
@@ -2194,7 +2420,9 @@ class Query(
 
     @_generative
     @_assertions(_no_clauseelement_condition)
-    def from_statement(self: SelfQuery, statement) -> SelfQuery:
+    def from_statement(
+        self: SelfQuery, statement: ExecutableReturnsRows
+    ) -> SelfQuery:
         """Execute the given SELECT statement and return results.
 
         This method bypasses all internal statement compilation, and the
@@ -2283,7 +2511,7 @@ class Query(
             :meth:`_query.Query.one_or_none`
 
         """
-        return self._iter().one()
+        return self._iter().one()  # type: ignore
 
     def scalar(self) -> Any:
         """Return the first element of the first result or None
@@ -2316,7 +2544,7 @@ class Query(
     def __iter__(self) -> Iterable[_T]:
         return self._iter().__iter__()
 
-    def _iter(self):
+    def _iter(self) -> Union[ScalarResult[_T], Result[_T]]:
         # new style execution.
         params = self._params
 
@@ -2837,3 +3065,7 @@ class BulkUpdate(BulkUD):
 
 class BulkDelete(BulkUD):
     """BulkUD which handles DELETEs."""
+
+
+class RowReturningQuery(Query[Row[_TP]]):
+    pass
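
``RowReturningQuery`` is the piece that ties the tuple-typed overloads
together on the ORM side: a ``Query`` whose rows are :class:`.Row` objects
parametrized by the selected column types.  A runnable sketch against an
in-memory SQLite database, again with a hypothetical ``User`` mapping that is
not part of this patch::

    from __future__ import annotations

    from sqlalchemy import create_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class User(Base):
        __tablename__ = "user_account"

        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str] = mapped_column()


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(User(name="spongebob"))
        session.commit()

        # a column-based query is a RowReturningQuery, so iteration
        # produces Row objects carrying the (id, name) values
        for row in session.query(User.id, User.name):
            print(row.id, row.name)

        # an entity query stays a plain Query[User] and yields User objects
        for user in session.query(User):
            print(user.name)
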
index 93d18b8d79e2dc6071bed167a4c26f3159ca0aff..9220c44c7fb8790efe282e36d3a3cea75978fb19 100644 (file)
@@ -13,6 +13,7 @@ from typing import Dict
 from typing import Iterable
 from typing import Iterator
 from typing import Optional
+from typing import overload
 from typing import Sequence
 from typing import Tuple
 from typing import Type
@@ -20,8 +21,6 @@ from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
-from . import exc as orm_exc
-from .base import class_mapper
 from .session import Session
 from .. import exc as sa_exc
 from .. import util
@@ -33,11 +32,13 @@ from ..util import warn_deprecated
 from ..util.typing import Protocol
 
 if TYPE_CHECKING:
+    from ._typing import _EntityType
     from ._typing import _IdentityKeyType
     from .identity import IdentityMap
     from .interfaces import ORMOption
     from .mapper import Mapper
     from .query import Query
+    from .query import RowReturningQuery
     from .session import _BindArguments
     from .session import _EntityBindKey
     from .session import _PKIdentityArgument
@@ -48,19 +49,33 @@ if TYPE_CHECKING:
     from ..engine import Engine
     from ..engine import Result
     from ..engine import Row
+    from ..engine import RowMapping
     from ..engine.interfaces import _CoreAnyExecuteParams
     from ..engine.interfaces import _CoreSingleExecuteParams
     from ..engine.interfaces import _ExecuteOptions
     from ..engine.interfaces import _ExecuteOptionsParameter
     from ..engine.result import ScalarResult
     from ..sql._typing import _ColumnsClauseArgument
+    from ..sql._typing import _T0
+    from ..sql._typing import _T1
+    from ..sql._typing import _T2
+    from ..sql._typing import _T3
+    from ..sql._typing import _T4
+    from ..sql._typing import _T5
+    from ..sql._typing import _T6
+    from ..sql._typing import _T7
+    from ..sql._typing import _TypedColumnClauseArgument as _TCCA
     from ..sql.base import Executable
     from ..sql.elements import ClauseElement
+    from ..sql.roles import TypedColumnsClauseRole
     from ..sql.selectable import ForUpdateArg
+    from ..sql.selectable import TypedReturnsRows
+
+_T = TypeVar("_T", bound=Any)
 
 
 class _QueryDescriptorType(Protocol):
-    def __get__(self, instance: Any, owner: Type[Any]) -> Optional[Query[Any]]:
+    def __get__(self, instance: Any, owner: Type[_T]) -> Query[_T]:
         ...
 
 
@@ -236,7 +251,7 @@ class scoped_session:
         self.registry.clear()
 
     def query_property(
-        self, query_cls: Optional[Type[Query[Any]]] = None
+        self, query_cls: Optional[Type[Query[_T]]] = None
     ) -> _QueryDescriptorType:
         """return a class property which produces a :class:`_query.Query`
         object
@@ -264,20 +279,13 @@ class scoped_session:
         """
 
         class query:
-            def __get__(
-                s, instance: Any, owner: Type[Any]
-            ) -> Optional[Query[Any]]:
-                try:
-                    mapper = class_mapper(owner)
-                    assert mapper is not None
-                    if query_cls:
-                        # custom query class
-                        return query_cls(mapper, session=self.registry())
-                    else:
-                        # session's configured query class
-                        return self.registry().query(mapper)
-                except orm_exc.UnmappedClassError:
-                    return None
+            def __get__(s, instance: Any, owner: Type[_O]) -> Query[_O]:
+                if query_cls:
+                    # custom query class
+                    return query_cls(owner, session=self.registry())  # type: ignore  # noqa: E501
+                else:
+                    # session's configured query class
+                    return self.registry().query(owner)
 
         return query()
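
The rewritten descriptor no longer returns ``Optional[Query[Any]]`` and no
longer silently swallows ``UnmappedClassError``; instead the returned query is
typed by the owning class.  A hypothetical sketch of the classic
``query_property()`` recipe under the new typing (names and mapping are
illustrative only)::

    from __future__ import annotations

    from sqlalchemy import create_engine
    from sqlalchemy.orm import (
        DeclarativeBase,
        Mapped,
        mapped_column,
        scoped_session,
        sessionmaker,
    )

    engine = create_engine("sqlite://")
    SessionFactory = scoped_session(sessionmaker(engine))


    class Base(DeclarativeBase):
        pass


    class User(Base):
        __tablename__ = "user_account"

        # the descriptor now types the query by its owning class, so
        # User.query is seen as Query[User]
        query = SessionFactory.query_property()

        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str] = mapped_column()


    Base.metadata.create_all(engine)

    spongebobs = User.query.filter(User.name == "spongebob").all()
    # a list of User objects (empty here, since nothing was inserted)
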
 
@@ -548,6 +556,32 @@ class scoped_session:
 
         return self._proxied.delete(instance)
 
+    @overload
+    def execute(
+        self,
+        statement: TypedReturnsRows[_T],
+        params: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        _parent_execute_state: Optional[Any] = None,
+        _add_event: Optional[Any] = None,
+    ) -> Result[_T]:
+        ...
+
+    @overload
+    def execute(
+        self,
+        statement: Executable,
+        params: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        _parent_execute_state: Optional[Any] = None,
+        _add_event: Optional[Any] = None,
+    ) -> Result[Any]:
+        ...
+
     def execute(
         self,
         statement: Executable,
@@ -557,7 +591,7 @@ class scoped_session:
         bind_arguments: Optional[_BindArguments] = None,
         _parent_execute_state: Optional[Any] = None,
         _add_event: Optional[Any] = None,
-    ) -> Result:
+    ) -> Result[Any]:
         r"""Execute a SQL expression construct.
 
         .. container:: class_bases
@@ -1430,8 +1464,103 @@ class scoped_session:
 
         return self._proxied.merge(instance, load=load, options=options)
 
+    @overload
+    def query(self, _entity: _EntityType[_O]) -> Query[_O]:
+        ...
+
+    @overload
     def query(
-        self, *entities: _ColumnsClauseArgument, **kwargs: Any
+        self, _colexpr: TypedColumnsClauseRole[_T]
+    ) -> RowReturningQuery[Tuple[_T]]:
+        ...
+
+    # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8
+
+    # code within this block is **programmatically,
+    # statically generated** by tools/generate_tuple_map_overloads.py
+
+    @overload
+    def query(
+        self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+    ) -> RowReturningQuery[Tuple[_T0, _T1]]:
+        ...
+
+    @overload
+    def query(
+        self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+    ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]:
+        ...
+
+    @overload
+    def query(
+        self,
+        __ent0: _TCCA[_T0],
+        __ent1: _TCCA[_T1],
+        __ent2: _TCCA[_T2],
+        __ent3: _TCCA[_T3],
+    ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]:
+        ...
+
+    @overload
+    def query(
+        self,
+        __ent0: _TCCA[_T0],
+        __ent1: _TCCA[_T1],
+        __ent2: _TCCA[_T2],
+        __ent3: _TCCA[_T3],
+        __ent4: _TCCA[_T4],
+    ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+        ...
+
+    @overload
+    def query(
+        self,
+        __ent0: _TCCA[_T0],
+        __ent1: _TCCA[_T1],
+        __ent2: _TCCA[_T2],
+        __ent3: _TCCA[_T3],
+        __ent4: _TCCA[_T4],
+        __ent5: _TCCA[_T5],
+    ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+        ...
+
+    @overload
+    def query(
+        self,
+        __ent0: _TCCA[_T0],
+        __ent1: _TCCA[_T1],
+        __ent2: _TCCA[_T2],
+        __ent3: _TCCA[_T3],
+        __ent4: _TCCA[_T4],
+        __ent5: _TCCA[_T5],
+        __ent6: _TCCA[_T6],
+    ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+        ...
+
+    @overload
+    def query(
+        self,
+        __ent0: _TCCA[_T0],
+        __ent1: _TCCA[_T1],
+        __ent2: _TCCA[_T2],
+        __ent3: _TCCA[_T3],
+        __ent4: _TCCA[_T4],
+        __ent5: _TCCA[_T5],
+        __ent6: _TCCA[_T6],
+        __ent7: _TCCA[_T7],
+    ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+        ...
+
+    # END OVERLOADED FUNCTIONS self.query
+
+    @overload
+    def query(
+        self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any
+    ) -> Query[Any]:
+        ...
+
+    def query(
+        self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any
     ) -> Query[Any]:
         r"""Return a new :class:`_query.Query` object corresponding to this
         :class:`_orm.Session`.
@@ -1559,6 +1688,30 @@ class scoped_session:
 
         return self._proxied.rollback()
 
+    @overload
+    def scalar(
+        self,
+        statement: TypedReturnsRows[Tuple[_T]],
+        params: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> Optional[_T]:
+        ...
+
+    @overload
+    def scalar(
+        self,
+        statement: Executable,
+        params: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> Any:
+        ...
+
     def scalar(
         self,
         statement: Executable,
@@ -1590,6 +1743,30 @@ class scoped_session:
             **kw,
         )
 
+    @overload
+    def scalars(
+        self,
+        statement: TypedReturnsRows[Tuple[_T]],
+        params: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> ScalarResult[_T]:
+        ...
+
+    @overload
+    def scalars(
+        self,
+        statement: Executable,
+        params: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> ScalarResult[Any]:
+        ...
+
     def scalars(
         self,
         statement: Executable,
@@ -1848,7 +2025,7 @@ class scoped_session:
         ident: Union[Any, Tuple[Any, ...]] = None,
         *,
         instance: Optional[Any] = None,
-        row: Optional[Row] = None,
+        row: Optional[Union[Row[Any], RowMapping]] = None,
         identity_token: Optional[Any] = None,
     ) -> _IdentityKeyType[Any]:
         r"""Return an identity key.
index 74035ec0aaa33e65f968a3176eaccf67ba10760b..263d5610193b9ddae01e93533764fc72771e0f7c 100644 (file)
@@ -27,6 +27,7 @@ from typing import Set
 from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
+from typing import TypeVar
 from typing import Union
 import weakref
 
@@ -85,12 +86,16 @@ from ..util.typing import Literal
 from ..util.typing import Protocol
 
 if typing.TYPE_CHECKING:
+    from ._typing import _EntityType
     from ._typing import _IdentityKeyType
     from ._typing import _InstanceDict
+    from ._typing import _O
+    from .context import FromStatement
     from .interfaces import ORMOption
     from .interfaces import UserDefinedOption
     from .mapper import Mapper
     from .path_registry import PathRegistry
+    from .query import RowReturningQuery
     from ..engine import Result
     from ..engine import Row
     from ..engine import RowMapping
@@ -104,10 +109,23 @@ if typing.TYPE_CHECKING:
     from ..event import _InstanceLevelDispatch
     from ..sql._typing import _ColumnsClauseArgument
     from ..sql._typing import _InfoType
+    from ..sql._typing import _T0
+    from ..sql._typing import _T1
+    from ..sql._typing import _T2
+    from ..sql._typing import _T3
+    from ..sql._typing import _T4
+    from ..sql._typing import _T5
+    from ..sql._typing import _T6
+    from ..sql._typing import _T7
+    from ..sql._typing import _TypedColumnClauseArgument as _TCCA
     from ..sql.base import Executable
     from ..sql.elements import ClauseElement
+    from ..sql.roles import TypedColumnsClauseRole
     from ..sql.schema import Table
-    from ..sql.selectable import TableClause
+    from ..sql.selectable import Select
+    from ..sql.selectable import TypedReturnsRows
+
+_T = TypeVar("_T", bound=Any)
 
 __all__ = [
     "Session",
@@ -189,7 +207,7 @@ class _SessionClassMethods:
         ident: Union[Any, Tuple[Any, ...]] = None,
         *,
         instance: Optional[Any] = None,
-        row: Optional[Union[Row, RowMapping]] = None,
+        row: Optional[Union[Row[Any], RowMapping]] = None,
         identity_token: Optional[Any] = None,
     ) -> _IdentityKeyType[Any]:
         """Return an identity key.
@@ -295,7 +313,7 @@ class ORMExecuteState(util.MemoizedSlots):
         params: Optional[_CoreAnyExecuteParams] = None,
         execution_options: Optional[_ExecuteOptionsParameter] = None,
         bind_arguments: Optional[_BindArguments] = None,
-    ) -> Result:
+    ) -> Result[Any]:
         """Execute the statement represented by this
         :class:`.ORMExecuteState`, without re-invoking events that have
         already proceeded.
@@ -1718,7 +1736,7 @@ class Session(_SessionClassMethods, EventTarget):
         _parent_execute_state: Optional[Any] = None,
         _add_event: Optional[Any] = None,
         _scalar_result: bool = ...,
-    ) -> Result:
+    ) -> Result[Any]:
         ...
 
     def _execute_internal(
@@ -1789,7 +1807,7 @@ class Session(_SessionClassMethods, EventTarget):
             )
             for idx, fn in enumerate(events_todo):
                 orm_exec_state._starting_event_idx = idx
-                fn_result: Optional[Result] = fn(orm_exec_state)
+                fn_result: Optional[Result[Any]] = fn(orm_exec_state)
                 if fn_result:
                     if _scalar_result:
                         return fn_result.scalar()
@@ -1806,10 +1824,12 @@ class Session(_SessionClassMethods, EventTarget):
         if _scalar_result and not compile_state_cls:
             if TYPE_CHECKING:
                 params = cast(_CoreSingleExecuteParams, params)
-            return conn.scalar(statement, params or {}, execution_options)
+            return conn.scalar(
+                statement, params or {}, execution_options=execution_options
+            )
 
-        result: Result = conn.execute(
-            statement, params or {}, execution_options
+        result: Result[Any] = conn.execute(
+            statement, params or {}, execution_options=execution_options
         )
 
         if compile_state_cls:
@@ -1827,6 +1847,32 @@ class Session(_SessionClassMethods, EventTarget):
         else:
             return result
 
+    @overload
+    def execute(
+        self,
+        statement: TypedReturnsRows[_T],
+        params: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        _parent_execute_state: Optional[Any] = None,
+        _add_event: Optional[Any] = None,
+    ) -> Result[_T]:
+        ...
+
+    @overload
+    def execute(
+        self,
+        statement: Executable,
+        params: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        _parent_execute_state: Optional[Any] = None,
+        _add_event: Optional[Any] = None,
+    ) -> Result[Any]:
+        ...
+
     def execute(
         self,
         statement: Executable,
@@ -1836,7 +1882,7 @@ class Session(_SessionClassMethods, EventTarget):
         bind_arguments: Optional[_BindArguments] = None,
         _parent_execute_state: Optional[Any] = None,
         _add_event: Optional[Any] = None,
-    ) -> Result:
+    ) -> Result[Any]:
         r"""Execute a SQL expression construct.
 
         Returns a :class:`_engine.Result` object representing
@@ -1897,6 +1943,30 @@ class Session(_SessionClassMethods, EventTarget):
             _add_event=_add_event,
         )
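
Together with the ``select()`` tuple-map overloads elsewhere in this patch,
these ``execute()`` / ``scalar()`` / ``scalars()`` overloads carry a
statement's row type through to the result.  A runnable sketch against
in-memory SQLite, with a hypothetical ``User`` mapping that is not part of
this patch::

    from __future__ import annotations

    from sqlalchemy import create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class User(Base):
        __tablename__ = "user_account"

        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str] = mapped_column()


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(User(name="sandy"))
        session.commit()

        # select(User.id, User.name) is typed roughly as
        # Select[Tuple[int, str]], so the TypedReturnsRows overload of
        # execute() applies and the Result is typed the same way
        result = session.execute(select(User.id, User.name))
        print(result.all())

        # scalars() narrows to the first element type of the row
        names = session.scalars(select(User.name)).all()
        print(names)

        # scalar() returns an Optional of that element type
        first_id = session.scalar(select(User.id))
        print(first_id)
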
 
+    @overload
+    def scalar(
+        self,
+        statement: TypedReturnsRows[Tuple[_T]],
+        params: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> Optional[_T]:
+        ...
+
+    @overload
+    def scalar(
+        self,
+        statement: Executable,
+        params: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> Any:
+        ...
+
     def scalar(
         self,
         statement: Executable,
@@ -1923,6 +1993,30 @@ class Session(_SessionClassMethods, EventTarget):
             **kw,
         )
 
+    @overload
+    def scalars(
+        self,
+        statement: TypedReturnsRows[Tuple[_T]],
+        params: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> ScalarResult[_T]:
+        ...
+
+    @overload
+    def scalars(
+        self,
+        statement: Executable,
+        params: Optional[_CoreSingleExecuteParams] = None,
+        *,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> ScalarResult[Any]:
+        ...
+
     def scalars(
         self,
         statement: Executable,
@@ -2284,8 +2378,103 @@ class Session(_SessionClassMethods, EventTarget):
             f'{", ".join(context)} or this Session.'
         )
 
+    @overload
+    def query(self, _entity: _EntityType[_O]) -> Query[_O]:
+        ...
+
+    @overload
+    def query(
+        self, _colexpr: TypedColumnsClauseRole[_T]
+    ) -> RowReturningQuery[Tuple[_T]]:
+        ...
+
+    # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8
+
+    # code within this block is **programmatically,
+    # statically generated** by tools/generate_tuple_map_overloads.py
+
+    @overload
+    def query(
+        self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+    ) -> RowReturningQuery[Tuple[_T0, _T1]]:
+        ...
+
+    @overload
+    def query(
+        self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+    ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]:
+        ...
+
+    @overload
+    def query(
+        self,
+        __ent0: _TCCA[_T0],
+        __ent1: _TCCA[_T1],
+        __ent2: _TCCA[_T2],
+        __ent3: _TCCA[_T3],
+    ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]:
+        ...
+
+    @overload
+    def query(
+        self,
+        __ent0: _TCCA[_T0],
+        __ent1: _TCCA[_T1],
+        __ent2: _TCCA[_T2],
+        __ent3: _TCCA[_T3],
+        __ent4: _TCCA[_T4],
+    ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+        ...
+
+    @overload
+    def query(
+        self,
+        __ent0: _TCCA[_T0],
+        __ent1: _TCCA[_T1],
+        __ent2: _TCCA[_T2],
+        __ent3: _TCCA[_T3],
+        __ent4: _TCCA[_T4],
+        __ent5: _TCCA[_T5],
+    ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+        ...
+
+    @overload
+    def query(
+        self,
+        __ent0: _TCCA[_T0],
+        __ent1: _TCCA[_T1],
+        __ent2: _TCCA[_T2],
+        __ent3: _TCCA[_T3],
+        __ent4: _TCCA[_T4],
+        __ent5: _TCCA[_T5],
+        __ent6: _TCCA[_T6],
+    ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+        ...
+
+    @overload
+    def query(
+        self,
+        __ent0: _TCCA[_T0],
+        __ent1: _TCCA[_T1],
+        __ent2: _TCCA[_T2],
+        __ent3: _TCCA[_T3],
+        __ent4: _TCCA[_T4],
+        __ent5: _TCCA[_T5],
+        __ent6: _TCCA[_T6],
+        __ent7: _TCCA[_T7],
+    ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+        ...
+
+    # END OVERLOADED FUNCTIONS self.query
+
+    @overload
+    def query(
+        self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any
+    ) -> Query[Any]:
+        ...
+
     def query(
-        self, *entities: _ColumnsClauseArgument, **kwargs: Any
+        self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any
     ) -> Query[Any]:
         """Return a new :class:`_query.Query` object corresponding to this
         :class:`_orm.Session`.
@@ -2486,7 +2675,7 @@ class Session(_SessionClassMethods, EventTarget):
 
         with_for_update = ForUpdateArg._from_argument(with_for_update)
 
-        stmt = sql.select(object_mapper(instance))
+        stmt: Select[Any] = sql.select(object_mapper(instance))
         if (
             loading.load_on_ident(
                 self,
index 58f141997e0a2a9e1c2f2f703db2cfb5ee4fc31c..ab32a3981a9e715be6d03f2e951eec2a69a4cc9f 100644 (file)
@@ -656,13 +656,13 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]):
     @classmethod
     def _instance_level_callable_processor(
         cls, manager: ClassManager[_O], fn: _LoaderCallable, key: Any
-    ) -> Callable[[InstanceState[_O], _InstanceDict, Row], None]:
+    ) -> Callable[[InstanceState[_O], _InstanceDict, Row[Any]], None]:
         impl = manager[key].impl
         if is_collection_impl(impl):
             fixed_impl = impl
 
             def _set_callable(
-                state: InstanceState[_O], dict_: _InstanceDict, row: Row
+                state: InstanceState[_O], dict_: _InstanceDict, row: Row[Any]
             ) -> None:
                 if "callables" not in state.__dict__:
                     state.callables = {}
@@ -674,7 +674,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]):
         else:
 
             def _set_callable(
-                state: InstanceState[_O], dict_: _InstanceDict, row: Row
+                state: InstanceState[_O], dict_: _InstanceDict, row: Row[Any]
             ) -> None:
                 if "callables" not in state.__dict__:
                     state.callables = {}
index 3934de5355325803ec906110e5fee9bd9f1d3f5c..8148793b122961aacb2eb705aa3ca32196a7f314 100644 (file)
@@ -28,6 +28,7 @@ from typing import Union
 import weakref
 
 from . import attributes  # noqa
+from . import exc
 from ._typing import _O
 from ._typing import insp_is_aliased_class
 from ._typing import insp_is_mapper
@@ -41,6 +42,7 @@ from .base import InspectionAttr as InspectionAttr
 from .base import instance_str as instance_str
 from .base import object_mapper as object_mapper
 from .base import object_state as object_state
+from .base import opt_manager_of_class
 from .base import state_attribute_str as state_attribute_str
 from .base import state_class_str as state_class_str
 from .base import state_str as state_str
@@ -68,6 +70,7 @@ from ..sql.base import ColumnCollection
 from ..sql.cache_key import HasCacheKey
 from ..sql.cache_key import MemoizedHasCacheKey
 from ..sql.elements import ColumnElement
+from ..sql.elements import KeyedColumnElement
 from ..sql.selectable import FromClause
 from ..util.langhelpers import MemoizedSlots
 from ..util.typing import de_stringify_annotation
@@ -95,9 +98,7 @@ if typing.TYPE_CHECKING:
     from ..sql.selectable import _ColumnsClauseElement
     from ..sql.selectable import Alias
     from ..sql.selectable import Subquery
-    from ..sql.visitors import _ET
     from ..sql.visitors import anon_map
-    from ..sql.visitors import ExternallyTraversible
 
 _T = TypeVar("_T", bound=Any)
 
@@ -341,7 +342,7 @@ def identity_key(
     ident: Union[Any, Tuple[Any, ...]] = None,
     *,
     instance: Optional[_T] = None,
-    row: Optional[Union[Row, RowMapping]] = None,
+    row: Optional[Union[Row[Any], RowMapping]] = None,
     identity_token: Optional[Any] = None,
 ) -> _IdentityKeyType[_T]:
     r"""Generate "identity key" tuples, as are used as keys in the
@@ -468,7 +469,9 @@ class ORMAdapter(sql_util.ColumnAdapter):
         return not entity or entity.isa(self.mapper)
 
 
-class AliasedClass(inspection.Inspectable["AliasedInsp[_O]"], Generic[_O]):
+class AliasedClass(
+    inspection.Inspectable["AliasedInsp[_O]"], ORMColumnsClauseRole[_O]
+):
     r"""Represents an "aliased" form of a mapped class for usage with Query.
 
     The ORM equivalent of a :func:`~sqlalchemy.sql.expression.alias`
@@ -663,7 +666,7 @@ class AliasedClass(inspection.Inspectable["AliasedInsp[_O]"], Generic[_O]):
 
 @inspection._self_inspects
 class AliasedInsp(
-    ORMEntityColumnsClauseRole,
+    ORMEntityColumnsClauseRole[_O],
     ORMFromClauseRole,
     HasCacheKey,
     InspectionAttr,
@@ -1276,12 +1279,29 @@ class LoaderCriteriaOption(CriteriaOption):
 inspection._inspects(AliasedClass)(lambda target: target._aliased_insp)
 
 
+@inspection._inspects(type)
+def _inspect_mc(
+    class_: Type[_O],
+) -> Optional[Mapper[_O]]:
+
+    try:
+        class_manager = opt_manager_of_class(class_)
+        if class_manager is None or not class_manager.is_mapped:
+            return None
+        mapper = class_manager.mapper
+    except exc.NO_STATE:
+
+        return None
+    else:
+        return mapper
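
This new ``@inspection._inspects(type)`` hook is what lets ``inspect()``
accept a plain class object and resolve it through the class manager: a mapped
class yields its :class:`.Mapper`, while an unmapped class yields ``None`` (so
``inspect()`` raises ``NoInspectionAvailable`` unless called with
``raiseerr=False``).  A small, hypothetical sketch::

    from sqlalchemy import inspect
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


    class Base(DeclarativeBase):
        pass


    class User(Base):
        __tablename__ = "user_account"

        id: Mapped[int] = mapped_column(primary_key=True)


    mapper = inspect(User)
    print(list(mapper.columns.keys()))  # ['id']


    class NotMapped:
        pass


    # an unmapped class resolves to None from the hook above
    print(inspect(NotMapped, raiseerr=False))  # None
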
+
+
 @inspection._self_inspects
 class Bundle(
-    ORMColumnsClauseRole,
+    ORMColumnsClauseRole[_T],
     SupportsCloneAnnotations,
     MemoizedHasCacheKey,
-    inspection.Inspectable["Bundle"],
+    inspection.Inspectable["Bundle[_T]"],
     InspectionAttr,
 ):
     """A grouping of SQL expressions that are returned by a :class:`.Query`
@@ -1373,10 +1393,10 @@ class Bundle(
     @property
     def entity_namespace(
         self,
-    ) -> ReadOnlyColumnCollection[str, ColumnElement[Any]]:
+    ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
         return self.c
 
-    columns: ReadOnlyColumnCollection[str, ColumnElement[Any]]
+    columns: ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]
 
     """A namespace of SQL expressions referred to by this :class:`.Bundle`.
 
@@ -1402,7 +1422,7 @@ class Bundle(
 
     """
 
-    c: ReadOnlyColumnCollection[str, ColumnElement[Any]]
+    c: ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]
     """An alias for :attr:`.Bundle.columns`."""
 
     def _clone(self):
@@ -1908,9 +1928,10 @@ def _extract_mapped_subtype(
     raw_annotation: Union[type, str],
     cls: type,
     key: str,
-    attr_cls: type,
+    attr_cls: Type[Any],
     required: bool,
     is_dataclass_field: bool,
+    superclasses: Optional[Tuple[Type[Any], ...]] = None,
 ) -> Optional[Union[type, str]]:
 
     if raw_annotation is None:
@@ -1930,9 +1951,13 @@ def _extract_mapped_subtype(
     if is_dataclass_field:
         return annotated
     else:
-        if (
-            not hasattr(annotated, "__origin__")
-            or not issubclass(annotated.__origin__, attr_cls)  # type: ignore
+        # TODO: there don't seem to be tests for the failure
+        # conditions here
+        if not hasattr(annotated, "__origin__") or (
+            not issubclass(
+                annotated.__origin__,  # type: ignore
+                superclasses if superclasses else attr_cls,
+            )
             and not issubclass(attr_cls, annotated.__origin__)  # type: ignore
         ):
             our_annotated_str = (
index 84913225d7bcdbe45b043b4628eeba5a06146715..c3ebb4596066b47e82c36050b6f87a2fcf6d9aa8 100644 (file)
@@ -121,7 +121,6 @@ def __go(lcls: Any) -> None:
     coercions.lambdas = lambdas
     coercions.schema = schema
     coercions.selectable = selectable
-    coercions.traversals = traversals
 
     from .annotation import _prepare_annotations
     from .annotation import Annotated
index 37d44976a278d9e4c1b9c84ece9579d720c9c195..f89e8f578d8d7c87eaa753054f9172d4e1ac74a6 100644 (file)
@@ -9,12 +9,16 @@ from __future__ import annotations
 
 from typing import Any
 from typing import Optional
+from typing import overload
+from typing import Tuple
 from typing import TYPE_CHECKING
+from typing import TypeVar
 from typing import Union
 
 from . import coercions
 from . import roles
 from ._typing import _ColumnsClauseArgument
+from ._typing import _no_kw
 from .elements import ColumnClause
 from .selectable import Alias
 from .selectable import CompoundSelect
@@ -34,6 +38,17 @@ if TYPE_CHECKING:
     from ._typing import _FromClauseArgument
     from ._typing import _OnClauseArgument
     from ._typing import _SelectStatementForCompoundArgument
+    from ._typing import _T0
+    from ._typing import _T1
+    from ._typing import _T2
+    from ._typing import _T3
+    from ._typing import _T4
+    from ._typing import _T5
+    from ._typing import _T6
+    from ._typing import _T7
+    from ._typing import _T8
+    from ._typing import _T9
+    from ._typing import _TypedColumnClauseArgument as _TCCA
     from .functions import Function
     from .selectable import CTE
     from .selectable import HasCTE
@@ -41,6 +56,9 @@ if TYPE_CHECKING:
     from .selectable import SelectBase
 
 
+_T = TypeVar("_T", bound=Any)
+
+
 def alias(
     selectable: FromClause, name: Optional[str] = None, flat: bool = False
 ) -> NamedFromClause:
@@ -89,7 +107,9 @@ def cte(
     )
 
 
-def except_(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
+def except_(
+    *selects: _SelectStatementForCompoundArgument,
+) -> CompoundSelect:
     r"""Return an ``EXCEPT`` of multiple selectables.
 
     The returned object is an instance of
@@ -119,7 +139,7 @@ def except_all(
 
 def exists(
     __argument: Optional[
-        Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]]
+        Union[_ColumnsClauseArgument[Any], SelectBase, ScalarSelect[Any]]
     ] = None,
 ) -> Exists:
     """Construct a new :class:`_expression.Exists` construct.
@@ -162,7 +182,9 @@ def exists(
     return Exists(__argument)
 
 
-def intersect(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
+def intersect(
+    *selects: _SelectStatementForCompoundArgument,
+) -> CompoundSelect:
     r"""Return an ``INTERSECT`` of multiple selectables.
 
     The returned object is an instance of
@@ -306,7 +328,129 @@ def outerjoin(
     return Join(left, right, onclause, isouter=True, full=full)
 
 
-def select(*entities: _ColumnsClauseArgument) -> Select:
+# START OVERLOADED FUNCTIONS select Select 1-10
+
+# code within this block is **programmatically,
+# statically generated** by tools/generate_tuple_map_overloads.py
+
+
+@overload
+def select(__ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]:
+    ...
+
+
+@overload
+def select(__ent0: _TCCA[_T0], __ent1: _TCCA[_T1]) -> Select[Tuple[_T0, _T1]]:
+    ...
+
+
+@overload
+def select(
+    __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+) -> Select[Tuple[_T0, _T1, _T2]]:
+    ...
+
+
+@overload
+def select(
+    __ent0: _TCCA[_T0],
+    __ent1: _TCCA[_T1],
+    __ent2: _TCCA[_T2],
+    __ent3: _TCCA[_T3],
+) -> Select[Tuple[_T0, _T1, _T2, _T3]]:
+    ...
+
+
+@overload
+def select(
+    __ent0: _TCCA[_T0],
+    __ent1: _TCCA[_T1],
+    __ent2: _TCCA[_T2],
+    __ent3: _TCCA[_T3],
+    __ent4: _TCCA[_T4],
+) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+    ...
+
+
+@overload
+def select(
+    __ent0: _TCCA[_T0],
+    __ent1: _TCCA[_T1],
+    __ent2: _TCCA[_T2],
+    __ent3: _TCCA[_T3],
+    __ent4: _TCCA[_T4],
+    __ent5: _TCCA[_T5],
+) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+    ...
+
+
+@overload
+def select(
+    __ent0: _TCCA[_T0],
+    __ent1: _TCCA[_T1],
+    __ent2: _TCCA[_T2],
+    __ent3: _TCCA[_T3],
+    __ent4: _TCCA[_T4],
+    __ent5: _TCCA[_T5],
+    __ent6: _TCCA[_T6],
+) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+    ...
+
+
+@overload
+def select(
+    __ent0: _TCCA[_T0],
+    __ent1: _TCCA[_T1],
+    __ent2: _TCCA[_T2],
+    __ent3: _TCCA[_T3],
+    __ent4: _TCCA[_T4],
+    __ent5: _TCCA[_T5],
+    __ent6: _TCCA[_T6],
+    __ent7: _TCCA[_T7],
+) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+    ...
+
+
+@overload
+def select(
+    __ent0: _TCCA[_T0],
+    __ent1: _TCCA[_T1],
+    __ent2: _TCCA[_T2],
+    __ent3: _TCCA[_T3],
+    __ent4: _TCCA[_T4],
+    __ent5: _TCCA[_T5],
+    __ent6: _TCCA[_T6],
+    __ent7: _TCCA[_T7],
+    __ent8: _TCCA[_T8],
+) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8]]:
+    ...
+
+
+@overload
+def select(
+    __ent0: _TCCA[_T0],
+    __ent1: _TCCA[_T1],
+    __ent2: _TCCA[_T2],
+    __ent3: _TCCA[_T3],
+    __ent4: _TCCA[_T4],
+    __ent5: _TCCA[_T5],
+    __ent6: _TCCA[_T6],
+    __ent7: _TCCA[_T7],
+    __ent8: _TCCA[_T8],
+    __ent9: _TCCA[_T9],
+) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9]]:
+    ...
+
+
+# END OVERLOADED FUNCTIONS select
+
+
+@overload
+def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]:
+    ...
+
+
+def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]:
     r"""Construct a new :class:`_expression.Select`.
 
 
@@ -343,7 +487,11 @@ def select(*entities: _ColumnsClauseArgument) -> Select:
       given, as well as ORM-mapped classes.
 
     """
-
+    # the **__kw collector is required for the typing to work out with
+    # the varargs, as opposed to declaring named keyword arguments that
+    # aren't always present.
+    if __kw:
+        raise _no_kw()
     return Select(*entities)
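
The generated overloads above are what give Core ``select()`` a typed row
shape when the individual column expressions carry a Python type; the final
catch-all overload keeps arbitrary and untyped invocations working.  A small,
hypothetical sketch of the intent::

    from sqlalchemy import Integer, String, column, select

    # each column() expression carries its Python type, so the two-argument
    # overload applies and the intent is roughly Select[Tuple[int, str]]
    stmt = select(column("id", Integer), column("name", String))
    print(stmt)
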
 
 
@@ -425,7 +573,9 @@ def tablesample(
     return TableSample._factory(selectable, sampling, name=name, seed=seed)
 
 
-def union(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
+def union(
+    *selects: _SelectStatementForCompoundArgument,
+) -> CompoundSelect:
     r"""Return a ``UNION`` of multiple selectables.
 
     The returned object is an instance of
@@ -445,7 +595,9 @@ def union(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
     return CompoundSelect._create_union(*selects)
 
 
-def union_all(*selects: _SelectStatementForCompoundArgument) -> CompoundSelect:
+def union_all(
+    *selects: _SelectStatementForCompoundArgument,
+) -> CompoundSelect:
     r"""Return a ``UNION ALL`` of multiple selectables.
 
     The returned object is an instance of
index 53d29b628f5ce16ef43efaab91496bfa1515a7d1..1df530dbd6b110a88da85a3b442783049a926511 100644 (file)
@@ -5,18 +5,27 @@ from typing import Any
 from typing import Callable
 from typing import Dict
 from typing import Set
+from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
 from . import roles
+from .. import exc
 from .. import util
 from ..inspection import Inspectable
 from ..util.typing import Literal
 from ..util.typing import Protocol
 
 if TYPE_CHECKING:
+    from datetime import date
+    from datetime import datetime
+    from datetime import time
+    from datetime import timedelta
+    from decimal import Decimal
+    from uuid import UUID
+
     from .base import Executable
     from .compiler import Compiled
     from .compiler import DDLCompiler
@@ -26,17 +35,15 @@ if TYPE_CHECKING:
     from .elements import ClauseElement
     from .elements import ColumnClause
     from .elements import ColumnElement
+    from .elements import KeyedColumnElement
     from .elements import quoted_name
-    from .elements import SQLCoreOperations
     from .elements import TextClause
     from .lambdas import LambdaElement
     from .roles import ColumnsClauseRole
     from .roles import FromClauseRole
     from .schema import Column
-    from .schema import DefaultGenerator
-    from .schema import Sequence
-    from .schema import Table
     from .selectable import Alias
+    from .selectable import CTE
     from .selectable import FromClause
     from .selectable import Join
     from .selectable import NamedFromClause
@@ -61,6 +68,30 @@ class _HasClauseElement(Protocol):
         ...
 
 
+# match column types that are not ORM entities
+_NOT_ENTITY = TypeVar(
+    "_NOT_ENTITY",
+    int,
+    str,
+    "datetime",
+    "date",
+    "time",
+    "timedelta",
+    "UUID",
+    float,
+    "Decimal",
+)
+
+_MAYBE_ENTITY = TypeVar(
+    "_MAYBE_ENTITY",
+    roles.ColumnsClauseRole,
+    Literal["*", 1],
+    Type[Any],
+    Inspectable[_HasClauseElement],
+    _HasClauseElement,
+)
+
+
 # convention:
 # XYZArgument - something that the end user is passing to a public API method
 # XYZElement - the internal representation that we use for the thing.
@@ -76,9 +107,10 @@ _TextCoercedExpressionArgument = Union[
 ]
 
 _ColumnsClauseArgument = Union[
-    Literal["*", 1],
+    roles.TypedColumnsClauseRole[_T],
     roles.ColumnsClauseRole,
-    Type[Any],
+    Literal["*", 1],
+    Type[_T],
     Inspectable[_HasClauseElement],
     _HasClauseElement,
 ]
@@ -92,6 +124,24 @@ sets; select(...), insert().returning(...), etc.
 
 """
 
+_TypedColumnClauseArgument = Union[
+    roles.TypedColumnsClauseRole[_T], roles.ExpressionElementRole[_T], Type[_T]
+]
+
+_TP = TypeVar("_TP", bound=Tuple[Any, ...])
+
+_T0 = TypeVar("_T0", bound=Any)
+_T1 = TypeVar("_T1", bound=Any)
+_T2 = TypeVar("_T2", bound=Any)
+_T3 = TypeVar("_T3", bound=Any)
+_T4 = TypeVar("_T4", bound=Any)
+_T5 = TypeVar("_T5", bound=Any)
+_T6 = TypeVar("_T6", bound=Any)
+_T7 = TypeVar("_T7", bound=Any)
+_T8 = TypeVar("_T8", bound=Any)
+_T9 = TypeVar("_T9", bound=Any)
+
+
 _ColumnExpressionArgument = Union[
     "ColumnElement[_T]",
     _HasClauseElement,
@@ -169,6 +219,7 @@ _DMLTableArgument = Union[
     "TableClause",
     "Join",
     "Alias",
+    "CTE",
     Type[Any],
     Inspectable[_HasClauseElement],
     _HasClauseElement,
@@ -194,6 +245,11 @@ if TYPE_CHECKING:
     def is_column_element(c: ClauseElement) -> TypeGuard[ColumnElement[Any]]:
         ...
 
+    def is_keyed_column_element(
+        c: ClauseElement,
+    ) -> TypeGuard[KeyedColumnElement[Any]]:
+        ...
+
     def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]:
         ...
 
@@ -216,7 +272,7 @@ if TYPE_CHECKING:
 
     def is_select_statement(
         t: Union[Executable, ReturnsRows]
-    ) -> TypeGuard[Select]:
+    ) -> TypeGuard[Select[Any]]:
         ...
 
     def is_table(t: FromClause) -> TypeGuard[TableClause]:
@@ -234,6 +290,7 @@ else:
     is_ddl_compiler = operator.attrgetter("is_ddl")
     is_named_from_clause = operator.attrgetter("named_with_column")
     is_column_element = operator.attrgetter("_is_column_element")
+    is_keyed_column_element = operator.attrgetter("_is_keyed_column_element")
     is_text_clause = operator.attrgetter("_is_text_clause")
     is_from_clause = operator.attrgetter("_is_from_clause")
     is_tuple_type = operator.attrgetter("_is_tuple_type")
@@ -260,3 +317,10 @@ def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement]:
 
 def is_insert_update(c: ClauseElement) -> TypeGuard[ValuesBase]:
     return c.is_dml and (c.is_insert or c.is_update)  # type: ignore
+
+
+def _no_kw() -> exc.ArgumentError:
+    return exc.ArgumentError(
+        "Additional keyword arguments are not accepted by this "
+        "function/method.  The presence of **kw is for pep-484 typing purposes"
+    )
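
``_no_kw()`` backs the ``**__kw`` collectors added to ``select()``,
``Query.with_entities()`` and similar signatures in this patch: the keyword
slot exists purely so the overloads and the implementation signatures stay
compatible, and any actual keyword is rejected.  A tiny sketch of the runtime
behavior::

    from sqlalchemy import column, exc, select

    try:
        select(column("x"), some_keyword="oops")
    except exc.ArgumentError as err:
        # "Additional keyword arguments are not accepted by this
        # function/method. ..."
        print(err)
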
index f81878d55d8f2ac4ab1f69009b2c8559ef3efa82..790edefc6e6a8caa26c19bc9e98c2c344dd704c2 100644 (file)
@@ -62,10 +62,10 @@ if TYPE_CHECKING:
     from . import coercions
     from . import elements
     from . import type_api
-    from ._typing import _ColumnsClauseArgument
     from .elements import BindParameter
-    from .elements import ColumnClause
+    from .elements import ColumnClause  # noqa
     from .elements import ColumnElement
+    from .elements import KeyedColumnElement
     from .elements import NamedColumn
     from .elements import SQLCoreOperations
     from .elements import TextClause
@@ -74,7 +74,6 @@ if TYPE_CHECKING:
     from .selectable import FromClause
     from ..engine import Connection
     from ..engine import CursorResult
-    from ..engine import Result
     from ..engine.base import _CompiledCacheType
     from ..engine.interfaces import _CoreMultiExecuteParams
     from ..engine.interfaces import _ExecuteOptions
@@ -704,8 +703,11 @@ class InPlaceGenerative(HasMemoized):
     """Provide a method-chaining pattern in conjunction with the
     @_generative decorator that mutates in place."""
 
+    __slots__ = ()
+
     def _generate(self):
         skip = self._memoized_keys
+        # note __dict__ needs to be in __slots__ if this is used
         for k in skip:
             self.__dict__.pop(k, None)
         return self
@@ -937,7 +939,7 @@ class ExecutableOption(HasCopyInternals):
 SelfExecutable = TypeVar("SelfExecutable", bound="Executable")
 
 
-class Executable(roles.StatementRole, Generative):
+class Executable(roles.StatementRole):
     """Mark a :class:`_expression.ClauseElement` as supporting execution.
 
     :class:`.Executable` is a superclass for all "statement" types
@@ -994,7 +996,7 @@ class Executable(roles.StatementRole, Generative):
             connection: Connection,
             distilled_params: _CoreMultiExecuteParams,
             execution_options: _ExecuteOptionsParameter,
-        ) -> CursorResult:
+        ) -> CursorResult[Any]:
             ...
 
         def _execute_on_scalar(
@@ -1253,7 +1255,7 @@ class SchemaVisitor(ClauseVisitor):
 _COLKEY = TypeVar("_COLKEY", Union[None, str], str)
 
 _COL_co = TypeVar("_COL_co", bound="ColumnElement[Any]", covariant=True)
-_COL = TypeVar("_COL", bound="ColumnElement[Any]")
+_COL = TypeVar("_COL", bound="KeyedColumnElement[Any]")
 
 
 class ColumnCollection(Generic[_COLKEY, _COL_co]):
@@ -1505,6 +1507,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
     ) -> None:
         """populate from an iterator of (key, column)"""
         cols = list(iter_)
+
         self._collection[:] = cols
         self._colset.update(c for k, c in self._collection)
         self._index.update(
index 0659709ab4385e8aaaa9c915a36abce84a822dc7..9b7231360e2ae4c37a72db0eb302d858249b35e8 100644 (file)
@@ -29,6 +29,7 @@ from typing import Union
 from . import operators
 from . import roles
 from . import visitors
+from ._typing import is_from_clause
 from .base import ExecutableOption
 from .base import Options
 from .cache_key import HasCacheKey
@@ -38,25 +39,18 @@ from .. import inspection
 from .. import util
 from ..util.typing import Literal
 
-if not typing.TYPE_CHECKING:
-    elements = None
-    lambdas = None
-    schema = None
-    selectable = None
-    traversals = None
-
 if typing.TYPE_CHECKING:
     from . import elements
     from . import lambdas
     from . import schema
     from . import selectable
-    from . import traversals
     from ._typing import _ColumnExpressionArgument
     from ._typing import _ColumnsClauseArgument
     from ._typing import _DDLColumnArgument
     from ._typing import _DMLTableArgument
     from ._typing import _FromClauseArgument
     from .dml import _DMLTableElement
+    from .elements import BindParameter
     from .elements import ClauseElement
     from .elements import ColumnClause
     from .elements import ColumnElement
@@ -64,9 +58,7 @@ if typing.TYPE_CHECKING:
     from .elements import SQLCoreOperations
     from .schema import Column
     from .selectable import _ColumnsClauseElement
-    from .selectable import _JoinTargetElement
     from .selectable import _JoinTargetProtocol
-    from .selectable import _OnClauseElement
     from .selectable import FromClause
     from .selectable import HasCTE
     from .selectable import SelectBase
@@ -168,6 +160,15 @@ def expect(
     ...
 
 
+@overload
+def expect(
+    role: Type[roles.LiteralValueRole],
+    element: Any,
+    **kw: Any,
+) -> BindParameter[Any]:
+    ...
+
+
 @overload
 def expect(
     role: Type[roles.DDLReferredColumnRole],
@@ -272,7 +273,7 @@ def expect(
 @overload
 def expect(
     role: Type[roles.ColumnsClauseRole],
-    element: _ColumnsClauseArgument,
+    element: _ColumnsClauseArgument[Any],
     **kw: Any,
 ) -> _ColumnsClauseElement:
     ...
@@ -933,7 +934,7 @@ class GroupByImpl(ByOfImpl, RoleImpl):
         argname: Optional[str] = None,
         **kw: Any,
     ) -> Any:
-        if isinstance(resolved, roles.StrictFromClauseRole):
+        if is_from_clause(resolved):
             return elements.ClauseList(*resolved.c)
         else:
             return resolved
index c524a2602c9c066bc0ce634e87ffdd0c77188e74..a1b25b8a6b2ebd3ba0192e98eebaf5dce643c692 100644 (file)
@@ -80,7 +80,6 @@ from ..util.typing import Protocol
 from ..util.typing import TypedDict
 
 if typing.TYPE_CHECKING:
-    from . import roles
     from .annotation import _AnnotationDict
     from .base import _AmbiguousTableNameMap
     from .base import CompileState
@@ -95,7 +94,6 @@ if typing.TYPE_CHECKING:
     from .elements import ColumnElement
     from .elements import Label
     from .functions import Function
-    from .selectable import Alias
     from .selectable import AliasedReturnsRows
     from .selectable import CompoundSelectState
     from .selectable import CTE
@@ -386,7 +384,7 @@ class _CompilerStackEntry(_BaseCompilerStackEntry, total=False):
     need_result_map_for_nested: bool
     need_result_map_for_compound: bool
     select_0: ReturnsRows
-    insert_from_select: Select
+    insert_from_select: Select[Any]
 
 
 class ExpandedState(NamedTuple):
@@ -2834,15 +2832,31 @@ class SQLCompiler(Compiled):
                         "unique bind parameter of the same name" % name
                     )
                 elif existing._is_crud or bindparam._is_crud:
-                    raise exc.CompileError(
-                        "bindparam() name '%s' is reserved "
-                        "for automatic usage in the VALUES or SET "
-                        "clause of this "
-                        "insert/update statement.   Please use a "
-                        "name other than column name when using bindparam() "
-                        "with insert() or update() (for example, 'b_%s')."
-                        % (bindparam.key, bindparam.key)
-                    )
+                    if existing._is_crud and bindparam._is_crud:
+                        # TODO: this condition is not well understood.
+                        # see tests in test/sql/test_update.py
+                        raise exc.CompileError(
+                            "Encountered unsupported case when compiling an "
+                            "INSERT or UPDATE statement.  If this is a "
+                            "multi-table "
+                            "UPDATE statement, please provide string-named "
+                            "arguments to the "
+                            "values() method with distinct names; support for "
+                            "multi-table UPDATE statements that "
+                            "target multiple tables for UPDATE is very "
+                            "limited",
+                        )
+                    else:
+                        raise exc.CompileError(
+                            f"bindparam() name '{bindparam.key}' is reserved "
+                            "for automatic usage in the VALUES or SET "
+                            "clause of this "
+                            "insert/update statement.   Please use a "
+                            "name other than column name when using "
+                            "bindparam() "
+                            "with insert() or update() (for example, "
+                            f"'b_{bindparam.key}')."
+                        )
 
         self.binds[bindparam.key] = self.binds[name] = bindparam
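
Per the commit message, this branch targets the MySQL-style multi-table UPDATE
case where the same column name is assigned in two tables at once, which
produces colliding crud-generated bind parameters.  A hypothetical sketch of
the shape of statement the new message is aimed at (table and column names are
illustrative)::

    from sqlalchemy import (
        Column,
        Integer,
        MetaData,
        String,
        Table,
        exc,
        update,
    )
    from sqlalchemy.dialects import mysql

    metadata = MetaData()
    a = Table(
        "a",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(30)),
    )
    b = Table(
        "b",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("a_id", Integer),
        Column("name", String(30)),
    )

    # both tables have a column called "name"; asking a multi-table UPDATE
    # to SET both of them generates two bind parameters named "name"
    stmt = (
        update(a)
        .where(a.c.id == b.c.a_id)
        .values({a.c.name: "new a name", b.c.name: "new b name"})
    )

    try:
        print(stmt.compile(dialect=mysql.dialect()))
    except exc.CompileError as err:
        print(f"CompileError: {err}")
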
 
@@ -3881,7 +3895,7 @@ class SQLCompiler(Compiled):
         return text
 
     def _setup_select_hints(
-        self, select: Select
+        self, select: Select[Any]
     ) -> Tuple[str, _FromHintsType]:
         byfrom = dict(
             [
index e4408cd31647007fffc2c4fb14796e1e1c34afc0..29d7b45d7ad7279de2a94ec6c390a94ff607455c 100644 (file)
@@ -22,6 +22,7 @@ from typing import MutableMapping
 from typing import NamedTuple
 from typing import Optional
 from typing import overload
+from typing import Sequence
 from typing import Tuple
 from typing import TYPE_CHECKING
 from typing import Union
@@ -30,8 +31,10 @@ from . import coercions
 from . import dml
 from . import elements
 from . import roles
+from .elements import ColumnClause
 from .schema import default_is_clause_element
 from .schema import default_is_sequence
+from .selectable import TableClause
 from .. import exc
 from .. import util
 from ..util.typing import Literal
@@ -41,16 +44,9 @@ if TYPE_CHECKING:
     from .compiler import SQLCompiler
     from .dml import _DMLColumnElement
     from .dml import DMLState
-    from .dml import Insert
-    from .dml import Update
-    from .dml import UpdateDMLState
     from .dml import ValuesBase
-    from .elements import ClauseElement
-    from .elements import ColumnClause
     from .elements import ColumnElement
-    from .elements import TextClause
     from .schema import _SQLExprDefault
-    from .schema import Column
     from .selectable import TableClause
 
 REQUIRED = util.symbol(
@@ -68,12 +64,20 @@ values present.
 )
 
 
+def _as_dml_column(c: ColumnElement[Any]) -> ColumnClause[Any]:
+    if not isinstance(c, ColumnClause):
+        raise exc.CompileError(
+            f"Can't create DML statement against column expression {c!r}"
+        )
+    return c
+
+
 class _CrudParams(NamedTuple):
-    single_params: List[
-        Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]]
+    single_params: Sequence[
+        Tuple[ColumnElement[Any], str, Optional[Union[str, _SQLExprDefault]]]
     ]
     all_multi_params: List[
-        List[
+        Sequence[
             Tuple[
                 ColumnClause[Any],
                 str,
@@ -274,7 +278,7 @@ def _get_crud_params(
             compiler,
             stmt,
             compile_state,
-            cast("List[Tuple[ColumnClause[Any], str, str]]", values),
+            cast("Sequence[Tuple[ColumnClause[Any], str, str]]", values),
             cast("Callable[..., str]", _column_as_key),
             kw,
         )
@@ -290,7 +294,7 @@ def _get_crud_params(
         # insert_executemany_returning mode :)
         values = [
             (
-                stmt.table.columns[0],
+                _as_dml_column(stmt.table.columns[0]),
                 compiler.preparer.format_column(stmt.table.columns[0]),
                 "DEFAULT",
             )
@@ -1135,10 +1139,10 @@ def _extend_values_for_multiparams(
     compiler: SQLCompiler,
     stmt: ValuesBase,
     compile_state: DMLState,
-    initial_values: List[Tuple[ColumnClause[Any], str, str]],
+    initial_values: Sequence[Tuple[ColumnClause[Any], str, str]],
     _column_as_key: Callable[..., str],
     kw: Dict[str, Any],
-) -> List[List[Tuple[ColumnClause[Any], str, str]]]:
+) -> List[Sequence[Tuple[ColumnClause[Any], str, str]]]:
     values_0 = initial_values
     values = [initial_values]
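
For context on the ``_as_dml_column()`` helper added above: it narrows a ``ColumnElement`` down to a ``ColumnClause`` so that the DEFAULT-values path can rely on a plain named column, raising ``CompileError`` otherwise.  A rough illustration (this is a private helper, shown only to clarify intent)::

    from sqlalchemy import column
    from sqlalchemy.sql.crud import _as_dml_column

    c = column("x")

    _as_dml_column(c)      # ColumnClause passes through unchanged
    _as_dml_column(c + 1)  # a BinaryExpression is not a ColumnClause;
                           # raises CompileError per the guard above
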
 
index 8307f64003fcee3a7fd82919b2c3424e9ae8f8a4..e0f162fc8698b6fbba623731e15b5e87d8512298 100644 (file)
@@ -22,15 +22,19 @@ from typing import List
 from typing import MutableMapping
 from typing import NoReturn
 from typing import Optional
+from typing import overload
 from typing import Sequence
 from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
+from typing import TypeVar
 from typing import Union
 
 from . import coercions
 from . import roles
 from . import util as sql_util
+from ._typing import _no_kw
+from ._typing import _TP
 from ._typing import is_column_element
 from ._typing import is_named_from_clause
 from .base import _entity_namespace_key
@@ -42,6 +46,7 @@ from .base import ColumnCollection
 from .base import CompileState
 from .base import DialectKWArgs
 from .base import Executable
+from .base import Generative
 from .base import HasCompileState
 from .elements import BooleanClauseList
 from .elements import ClauseElement
@@ -49,12 +54,13 @@ from .elements import ColumnClause
 from .elements import ColumnElement
 from .elements import Null
 from .selectable import Alias
+from .selectable import ExecutableReturnsRows
 from .selectable import FromClause
 from .selectable import HasCTE
 from .selectable import HasPrefixes
 from .selectable import Join
-from .selectable import ReturnsRows
 from .selectable import TableClause
+from .selectable import TypedReturnsRows
 from .sqltypes import NullType
 from .visitors import InternalTraversal
 from .. import exc
@@ -66,9 +72,19 @@ if TYPE_CHECKING:
     from ._typing import _ColumnsClauseArgument
     from ._typing import _DMLColumnArgument
     from ._typing import _DMLTableArgument
-    from ._typing import _FromClauseArgument
+    from ._typing import _T0  # noqa
+    from ._typing import _T1  # noqa
+    from ._typing import _T2  # noqa
+    from ._typing import _T3  # noqa
+    from ._typing import _T4  # noqa
+    from ._typing import _T5  # noqa
+    from ._typing import _T6  # noqa
+    from ._typing import _T7  # noqa
+    from ._typing import _TypedColumnClauseArgument as _TCCA  # noqa
     from .base import ReadOnlyColumnCollection
     from .compiler import SQLCompiler
+    from .elements import ColumnElement
+    from .elements import KeyedColumnElement
     from .selectable import _ColumnsClauseElement
     from .selectable import _SelectIterable
     from .selectable import Select
@@ -88,6 +104,8 @@ else:
     isinsert = operator.attrgetter("isinsert")
 
 
+_T = TypeVar("_T", bound=Any)
+
 _DMLColumnElement = Union[str, ColumnClause[Any]]
 _DMLTableElement = Union[TableClause, Alias, Join]
 
@@ -185,6 +203,11 @@ class DMLState(CompileState):
                 "%s construct does not support "
                 "multiple parameter sets." % statement.__visit_name__.upper()
             )
+        else:
+            assert isinstance(statement, Insert)
+
+            # which implies...
+            # assert isinstance(statement.table, TableClause)
 
         for parameters in statement._multi_values:
             multi_parameters: List[MutableMapping[_DMLColumnElement, Any]] = [
@@ -291,7 +314,9 @@ class UpdateDMLState(DMLState):
         elif statement._multi_values:
             self._process_multi_values(statement)
         self._extra_froms = ef = self._make_extra_froms(statement)
-        self.is_multitable = mt = ef and self._dict_parameters
+
+        self.is_multitable = mt = ef
+
         self.include_table_with_column_exprs = bool(
             mt and compiler.render_table_with_column_in_update_from
         )
@@ -317,8 +342,8 @@ class UpdateBase(
     HasCompileState,
     DialectKWArgs,
     HasPrefixes,
-    ReturnsRows,
-    Executable,
+    Generative,
+    ExecutableReturnsRows,
     ClauseElement,
 ):
     """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements."""
@@ -383,8 +408,8 @@ class UpdateBase(
 
     @_generative
     def returning(
-        self: SelfUpdateBase, *cols: _ColumnsClauseArgument
-    ) -> SelfUpdateBase:
+        self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
+    ) -> UpdateBase:
         r"""Add a :term:`RETURNING` or equivalent clause to this statement.
 
         e.g.:
@@ -454,6 +479,8 @@ class UpdateBase(
           :ref:`tutorial_insert_returning` - in the :ref:`unified_tutorial`
 
         """  # noqa: E501
+        if __kw:
+            raise _no_kw()
         if self._return_defaults:
             raise exc.InvalidRequestError(
                 "return_defaults() is already configured on this statement"
@@ -464,7 +491,7 @@ class UpdateBase(
         return self
 
     def corresponding_column(
-        self, column: ColumnElement[Any], require_embedded: bool = False
+        self, column: KeyedColumnElement[Any], require_embedded: bool = False
     ) -> Optional[ColumnElement[Any]]:
         return self.exported_columns.corresponding_column(
             column, require_embedded=require_embedded
@@ -628,7 +655,7 @@ class ValuesBase(UpdateBase):
 
     _supports_multi_parameters = False
 
-    select: Optional[Select] = None
+    select: Optional[Select[Any]] = None
     """SELECT statement for INSERT .. FROM SELECT"""
 
     _post_values_clause: Optional[ClauseElement] = None
@@ -804,11 +831,15 @@ class ValuesBase(UpdateBase):
                 )
 
             elif isinstance(arg, collections_abc.Sequence):
-
                 if arg and isinstance(arg[0], (list, dict, tuple)):
                     self._multi_values += (arg,)
                     return self
 
+                if TYPE_CHECKING:
+                    # crud.py raises during compilation if this is not the
+                    # case
+                    assert isinstance(self, Insert)
+
                 # tuple values
                 arg = {c.key: value for c, value in zip(self.table.c, arg)}
 
@@ -1010,7 +1041,7 @@ class Insert(ValuesBase):
     def from_select(
         self: SelfInsert,
         names: List[str],
-        select: Select,
+        select: Select[Any],
         include_defaults: bool = True,
     ) -> SelfInsert:
         """Return a new :class:`_expression.Insert` construct which represents
@@ -1073,6 +1104,114 @@ class Insert(ValuesBase):
         self.select = coercions.expect(roles.DMLSelectRole, select)
         return self
 
+    if TYPE_CHECKING:
+
+        # START OVERLOADED FUNCTIONS self.returning ReturningInsert 1-8
+
+        # code within this block is **programmatically,
+        # statically generated** by tools/generate_tuple_map_overloads.py
+
+        @overload
+        def returning(self, __ent0: _TCCA[_T0]) -> ReturningInsert[Tuple[_T0]]:
+            ...
+
+        @overload
+        def returning(
+            self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+        ) -> ReturningInsert[Tuple[_T0, _T1]]:
+            ...
+
+        @overload
+        def returning(
+            self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+        ) -> ReturningInsert[Tuple[_T0, _T1, _T2]]:
+            ...
+
+        @overload
+        def returning(
+            self,
+            __ent0: _TCCA[_T0],
+            __ent1: _TCCA[_T1],
+            __ent2: _TCCA[_T2],
+            __ent3: _TCCA[_T3],
+        ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3]]:
+            ...
+
+        @overload
+        def returning(
+            self,
+            __ent0: _TCCA[_T0],
+            __ent1: _TCCA[_T1],
+            __ent2: _TCCA[_T2],
+            __ent3: _TCCA[_T3],
+            __ent4: _TCCA[_T4],
+        ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+            ...
+
+        @overload
+        def returning(
+            self,
+            __ent0: _TCCA[_T0],
+            __ent1: _TCCA[_T1],
+            __ent2: _TCCA[_T2],
+            __ent3: _TCCA[_T3],
+            __ent4: _TCCA[_T4],
+            __ent5: _TCCA[_T5],
+        ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+            ...
+
+        @overload
+        def returning(
+            self,
+            __ent0: _TCCA[_T0],
+            __ent1: _TCCA[_T1],
+            __ent2: _TCCA[_T2],
+            __ent3: _TCCA[_T3],
+            __ent4: _TCCA[_T4],
+            __ent5: _TCCA[_T5],
+            __ent6: _TCCA[_T6],
+        ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+            ...
+
+        @overload
+        def returning(
+            self,
+            __ent0: _TCCA[_T0],
+            __ent1: _TCCA[_T1],
+            __ent2: _TCCA[_T2],
+            __ent3: _TCCA[_T3],
+            __ent4: _TCCA[_T4],
+            __ent5: _TCCA[_T5],
+            __ent6: _TCCA[_T6],
+            __ent7: _TCCA[_T7],
+        ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+            ...
+
+        # END OVERLOADED FUNCTIONS self.returning
+
+        @overload
+        def returning(
+            self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
+        ) -> ReturningInsert[Any]:
+            ...
+
+        def returning(
+            self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
+        ) -> ReturningInsert[Any]:
+            ...
+
+
+class ReturningInsert(Insert, TypedReturnsRows[_TP]):
+    """Typing-only class that establishes a generic type form of
+    :class:`.Insert` which tracks returned column types.
+
+    This datatype is delivered when calling the
+    :meth:`.Insert.returning` method.
+
+    .. versionadded:: 2.0
+
+    """
+
 
 SelfDMLWhereBase = typing.TypeVar("SelfDMLWhereBase", bound="DMLWhereBase")
 
@@ -1264,6 +1403,113 @@ class Update(DMLWhereBase, ValuesBase):
         self._inline = True
         return self
 
+    if TYPE_CHECKING:
+        # START OVERLOADED FUNCTIONS self.returning ReturningUpdate 1-8
+
+        # code within this block is **programmatically,
+        # statically generated** by tools/generate_tuple_map_overloads.py
+
+        @overload
+        def returning(self, __ent0: _TCCA[_T0]) -> ReturningUpdate[Tuple[_T0]]:
+            ...
+
+        @overload
+        def returning(
+            self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+        ) -> ReturningUpdate[Tuple[_T0, _T1]]:
+            ...
+
+        @overload
+        def returning(
+            self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+        ) -> ReturningUpdate[Tuple[_T0, _T1, _T2]]:
+            ...
+
+        @overload
+        def returning(
+            self,
+            __ent0: _TCCA[_T0],
+            __ent1: _TCCA[_T1],
+            __ent2: _TCCA[_T2],
+            __ent3: _TCCA[_T3],
+        ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3]]:
+            ...
+
+        @overload
+        def returning(
+            self,
+            __ent0: _TCCA[_T0],
+            __ent1: _TCCA[_T1],
+            __ent2: _TCCA[_T2],
+            __ent3: _TCCA[_T3],
+            __ent4: _TCCA[_T4],
+        ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+            ...
+
+        @overload
+        def returning(
+            self,
+            __ent0: _TCCA[_T0],
+            __ent1: _TCCA[_T1],
+            __ent2: _TCCA[_T2],
+            __ent3: _TCCA[_T3],
+            __ent4: _TCCA[_T4],
+            __ent5: _TCCA[_T5],
+        ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+            ...
+
+        @overload
+        def returning(
+            self,
+            __ent0: _TCCA[_T0],
+            __ent1: _TCCA[_T1],
+            __ent2: _TCCA[_T2],
+            __ent3: _TCCA[_T3],
+            __ent4: _TCCA[_T4],
+            __ent5: _TCCA[_T5],
+            __ent6: _TCCA[_T6],
+        ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+            ...
+
+        @overload
+        def returning(
+            self,
+            __ent0: _TCCA[_T0],
+            __ent1: _TCCA[_T1],
+            __ent2: _TCCA[_T2],
+            __ent3: _TCCA[_T3],
+            __ent4: _TCCA[_T4],
+            __ent5: _TCCA[_T5],
+            __ent6: _TCCA[_T6],
+            __ent7: _TCCA[_T7],
+        ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+            ...
+
+        # END OVERLOADED FUNCTIONS self.returning
+
+        @overload
+        def returning(
+            self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
+        ) -> ReturningUpdate[Any]:
+            ...
+
+        def returning(
+            self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
+        ) -> ReturningUpdate[Any]:
+            ...
+
+
+class ReturningUpdate(Update, TypedReturnsRows[_TP]):
+    """Typing-only class that establishes a generic type form of
+    :class:`.Update` which tracks returned column types.
+
+    This datatype is delivered when calling the
+    :meth:`.Update.returning` method.
+
+    .. versionadded:: 2.0
+
+    """
+
 
 SelfDelete = typing.TypeVar("SelfDelete", bound="Delete")
 
@@ -1297,3 +1543,111 @@ class Delete(DMLWhereBase, UpdateBase):
         self.table = coercions.expect(
             roles.DMLTableRole, table, apply_propagate_attrs=self
         )
+
+    if TYPE_CHECKING:
+
+        # START OVERLOADED FUNCTIONS self.returning ReturningDelete 1-8
+
+        # code within this block is **programmatically,
+        # statically generated** by tools/generate_tuple_map_overloads.py
+
+        @overload
+        def returning(self, __ent0: _TCCA[_T0]) -> ReturningDelete[Tuple[_T0]]:
+            ...
+
+        @overload
+        def returning(
+            self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+        ) -> ReturningDelete[Tuple[_T0, _T1]]:
+            ...
+
+        @overload
+        def returning(
+            self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+        ) -> ReturningDelete[Tuple[_T0, _T1, _T2]]:
+            ...
+
+        @overload
+        def returning(
+            self,
+            __ent0: _TCCA[_T0],
+            __ent1: _TCCA[_T1],
+            __ent2: _TCCA[_T2],
+            __ent3: _TCCA[_T3],
+        ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3]]:
+            ...
+
+        @overload
+        def returning(
+            self,
+            __ent0: _TCCA[_T0],
+            __ent1: _TCCA[_T1],
+            __ent2: _TCCA[_T2],
+            __ent3: _TCCA[_T3],
+            __ent4: _TCCA[_T4],
+        ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+            ...
+
+        @overload
+        def returning(
+            self,
+            __ent0: _TCCA[_T0],
+            __ent1: _TCCA[_T1],
+            __ent2: _TCCA[_T2],
+            __ent3: _TCCA[_T3],
+            __ent4: _TCCA[_T4],
+            __ent5: _TCCA[_T5],
+        ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+            ...
+
+        @overload
+        def returning(
+            self,
+            __ent0: _TCCA[_T0],
+            __ent1: _TCCA[_T1],
+            __ent2: _TCCA[_T2],
+            __ent3: _TCCA[_T3],
+            __ent4: _TCCA[_T4],
+            __ent5: _TCCA[_T5],
+            __ent6: _TCCA[_T6],
+        ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+            ...
+
+        @overload
+        def returning(
+            self,
+            __ent0: _TCCA[_T0],
+            __ent1: _TCCA[_T1],
+            __ent2: _TCCA[_T2],
+            __ent3: _TCCA[_T3],
+            __ent4: _TCCA[_T4],
+            __ent5: _TCCA[_T5],
+            __ent6: _TCCA[_T6],
+            __ent7: _TCCA[_T7],
+        ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+            ...
+
+        # END OVERLOADED FUNCTIONS self.returning
+
+        @overload
+        def returning(
+            self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
+        ) -> ReturningDelete[Any]:
+            ...
+
+        def returning(
+            self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
+        ) -> ReturningDelete[Any]:
+            ...
+
+
+class ReturningDelete(Delete, TypedReturnsRows[_TP]):
+    """Typing-only class that establishes a generic type form of
+    :class:`.Delete` which tracks returned column types.
+
+    This datatype is delivered when calling the
+    :meth:`.Delete.returning` method.
+
+    .. versionadded:: 2.0
+
+    """
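
To show what the generated ``returning()`` overloads provide at type-checking time, a short sketch using a hypothetical 2.0-style mapped class (the ``User`` model and its attribute types are assumptions made for the example)::

    from sqlalchemy import delete, insert, update
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


    class Base(DeclarativeBase):
        pass


    class User(Base):
        __tablename__ = "user_account"

        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str] = mapped_column()


    ins = insert(User).returning(User.id, User.name)
    # a type checker should see ReturningInsert[Tuple[int, str]]

    upd = update(User).values(name="x").returning(User.id)
    # ReturningUpdate[Tuple[int]]

    dlt = delete(User).returning(User.id)
    # ReturningDelete[Tuple[int]]

Nine or more arguments fall through to the ``*cols`` catch-all overload and come back as the ``Any``-typed form, e.g. ``ReturningInsert[Any]``.
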
index 34d5127ab73301ef0b2fdfb50aa460c5acbbc8f1..a29561291839dd01bf4867c1df257b230e41a63e 100644 (file)
@@ -54,6 +54,7 @@ from .base import _clone
 from .base import _generative
 from .base import _NoArg
 from .base import Executable
+from .base import Generative
 from .base import HasMemoized
 from .base import Immutable
 from .base import NO_ARG
@@ -94,10 +95,7 @@ if typing.TYPE_CHECKING:
     from .selectable import _SelectIterable
     from .selectable import FromClause
     from .selectable import NamedFromClause
-    from .selectable import ReturnsRows
     from .selectable import Select
-    from .selectable import TableClause
-    from .sqltypes import Boolean
     from .sqltypes import TupleType
     from .type_api import TypeEngine
     from .visitors import _CloneCallableType
@@ -122,7 +120,9 @@ _NT = TypeVar("_NT", bound="_NUMERIC")
 _NMT = TypeVar("_NMT", bound="_NUMBER")
 
 
-def literal(value, type_=None):
+def literal(
+    value: Any, type_: Optional[_TypeEngineArgument[_T]] = None
+) -> BindParameter[_T]:
     r"""Return a literal clause, bound to a bind parameter.
 
     Literal clauses are created automatically when non-
@@ -144,7 +144,9 @@ def literal(value, type_=None):
     return coercions.expect(roles.LiteralValueRole, value, type_=type_)
 
 
-def literal_column(text, type_=None):
+def literal_column(
+    text: str, type_: Optional[_TypeEngineArgument[_T]] = None
+) -> ColumnClause[_T]:
     r"""Produce a :class:`.ColumnClause` object that has the
     :paramref:`_expression.column.is_literal` flag set to True.
 
@@ -316,6 +318,7 @@ class ClauseElement(
     is_selectable = False
     is_dml = False
     _is_column_element = False
+    _is_keyed_column_element = False
     _is_table = False
     _is_textual = False
     _is_from_clause = False
@@ -342,7 +345,7 @@ class ClauseElement(
     if typing.TYPE_CHECKING:
 
         def get_children(
-            self, omit_attrs: typing_Tuple[str, ...] = ..., **kw: Any
+            self, *, omit_attrs: typing_Tuple[str, ...] = ..., **kw: Any
         ) -> Iterable[ClauseElement]:
             ...
 
@@ -455,7 +458,7 @@ class ClauseElement(
         connection: Connection,
         distilled_params: _CoreMultiExecuteParams,
         execution_options: _ExecuteOptions,
-    ) -> Result:
+    ) -> Result[Any]:
         if self.supports_execution:
             if TYPE_CHECKING:
                 assert isinstance(self, Executable)
@@ -833,13 +836,13 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
 
         def in_(
             self,
-            other: Union[Sequence[Any], BindParameter[Any], Select],
+            other: Union[Sequence[Any], BindParameter[Any], Select[Any]],
         ) -> BinaryExpression[bool]:
             ...
 
         def not_in(
             self,
-            other: Union[Sequence[Any], BindParameter[Any], Select],
+            other: Union[Sequence[Any], BindParameter[Any], Select[Any]],
         ) -> BinaryExpression[bool]:
             ...
 
@@ -1699,6 +1702,14 @@ class ColumnElement(
         return self._anon_label(label, add_hash=idx)
 
 
+class KeyedColumnElement(ColumnElement[_T]):
+    """ColumnElement where ``.key`` is non-None."""
+
+    _is_keyed_column_element = True
+
+    key: str
+
+
 class WrapsColumnExpression(ColumnElement[_T]):
     """Mixin that defines a :class:`_expression.ColumnElement`
     as a wrapper with special
@@ -1760,7 +1771,7 @@ class WrapsColumnExpression(ColumnElement[_T]):
 SelfBindParameter = TypeVar("SelfBindParameter", bound="BindParameter[Any]")
 
 
-class BindParameter(roles.InElementRole, ColumnElement[_T]):
+class BindParameter(roles.InElementRole, KeyedColumnElement[_T]):
     r"""Represent a "bound expression".
 
     :class:`.BindParameter` is invoked explicitly using the
@@ -2073,6 +2084,7 @@ class TextClause(
     roles.FromClauseRole,
     roles.SelectStatementRole,
     roles.InElementRole,
+    Generative,
     Executable,
     DQLDMLClauseElement,
     roles.BinaryElementRole[Any],
@@ -4160,7 +4172,7 @@ class FunctionFilter(ColumnElement[_T]):
         )
 
 
-class NamedColumn(ColumnElement[_T]):
+class NamedColumn(KeyedColumnElement[_T]):
     is_literal = False
     table: Optional[FromClause] = None
     name: str
@@ -4502,7 +4514,7 @@ class ColumnClause(
 
         self.is_literal = is_literal
 
-    def get_children(self, column_tables=False, **kw):
+    def get_children(self, *, column_tables=False, **kw):
         # override base get_children() to not return the Table
         # or selectable that is parent to this column.  Traversals
         # expect the columns of tables and subqueries to be leaf nodes.
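
With the annotated ``literal()`` and ``literal_column()`` signatures earlier in this file, the element's Python type now flows from the given SQL type; a brief sketch (values are arbitrary)::

    from sqlalchemy import Integer, String, literal, literal_column

    five = literal(5, Integer)             # BindParameter[int]
    tag = literal_column("'abc'", String)  # ColumnClause[str]

Both constructs also now sit under the new ``KeyedColumnElement`` hierarchy introduced above: ``BindParameter`` directly, ``ColumnClause`` by way of ``NamedColumn``.
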
index 64816823555844ba3fba4cf49cf5fd48ac3f7a26..b827df3df8dd153c819ad713bd40dc96628b845d 100644 (file)
@@ -175,7 +175,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
         connection: Connection,
         distilled_params: _CoreMultiExecuteParams,
         execution_options: _ExecuteOptionsParameter,
-    ) -> CursorResult:
+    ) -> CursorResult[Any]:
         return connection._execute_function(
             self, distilled_params, execution_options
         )
@@ -623,7 +623,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
             joins_implicitly=joins_implicitly,
         )
 
-    def select(self) -> "Select":
+    def select(self) -> Select[Any]:
         """Produce a :func:`_expression.select` construct
         against this :class:`.FunctionElement`.
 
@@ -632,7 +632,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
             s = select(function_element)
 
         """
-        s = Select(self)
+        s: Select[Any] = Select(self)
         if self._execution_options:
             s = s.execution_options(**self._execution_options)
         return s
@@ -846,7 +846,7 @@ class _FunctionGenerator:
 
     @overload
     def __call__(
-        self, *c: Any, type_: TypeEngine[_T], **kwargs: Any
+        self, *c: Any, type_: _TypeEngineArgument[_T], **kwargs: Any
     ) -> Function[_T]:
         ...
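
A quick sketch of the effect of the widened ``type_`` parameter above, together with the annotated ``FunctionElement.select()`` (column and function choices are arbitrary)::

    from sqlalchemy import Integer, column, func

    total = func.coalesce(column("x"), 0, type_=Integer)  # Function[int]
    stmt = total.select()                                  # Select[Any]
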
 
index 231c70a5bad54188ee498372e6a03eaafd200910..09d4b35ad02bacadb0690313e4941c7a497a572f 100644 (file)
@@ -8,8 +8,6 @@ from __future__ import annotations
 
 from typing import Any
 from typing import Generic
-from typing import Iterable
-from typing import List
 from typing import Optional
 from typing import TYPE_CHECKING
 from typing import TypeVar
@@ -19,12 +17,7 @@ from ..util.typing import Literal
 
 if TYPE_CHECKING:
     from ._typing import _PropagateAttrsType
-    from .base import _EntityNamespace
-    from .base import ColumnCollection
-    from .base import ReadOnlyColumnCollection
-    from .elements import ColumnClause
     from .elements import Label
-    from .elements import NamedColumn
     from .selectable import _SelectIterable
     from .selectable import FromClause
     from .selectable import Subquery
@@ -108,13 +101,21 @@ class TruncatedLabelRole(StringRole, SQLRole):
 
 class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole):
     __slots__ = ()
-    _role_name = "Column expression or FROM clause"
+    _role_name = (
+        "Column expression, FROM clause, or other columns clause element"
+    )
 
     @property
     def _select_iterable(self) -> _SelectIterable:
         raise NotImplementedError()
 
 
+class TypedColumnsClauseRole(Generic[_T], SQLRole):
+    """element-typed form of ColumnsClauseRole"""
+
+    __slots__ = ()
+
+
 class LimitOffsetRole(SQLRole):
     __slots__ = ()
     _role_name = "LIMIT / OFFSET expression"
@@ -161,7 +162,7 @@ class WhereHavingRole(OnClauseRole):
     _role_name = "SQL expression for WHERE/HAVING role"
 
 
-class ExpressionElementRole(Generic[_T], SQLRole):
+class ExpressionElementRole(TypedColumnsClauseRole[_T]):
     # note when using generics for ExpressionElementRole,
     # the generic type needs to be in
     # sqlalchemy.sql.coercions._impl_lookup mapping also.
@@ -212,39 +213,11 @@ class FromClauseRole(ColumnsClauseRole, JoinTargetRole):
 
     named_with_column: bool
 
-    if TYPE_CHECKING:
-
-        @util.ro_non_memoized_property
-        def c(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]:
-            ...
-
-        @util.ro_non_memoized_property
-        def columns(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]:
-            ...
-
-        @util.ro_non_memoized_property
-        def entity_namespace(self) -> _EntityNamespace:
-            ...
-
-        @util.ro_non_memoized_property
-        def _hide_froms(self) -> Iterable[FromClause]:
-            ...
-
-        @util.ro_non_memoized_property
-        def _from_objects(self) -> List[FromClause]:
-            ...
-
 
 class StrictFromClauseRole(FromClauseRole):
     __slots__ = ()
     # does not allow text() or select() objects
 
-    if TYPE_CHECKING:
-
-        @util.ro_non_memoized_property
-        def description(self) -> str:
-            ...
-
 
 class AnonymizedFromClauseRole(StrictFromClauseRole):
     __slots__ = ()
@@ -317,16 +290,6 @@ class DMLTableRole(FromClauseRole):
     __slots__ = ()
     _role_name = "subject table for an INSERT, UPDATE or DELETE"
 
-    if TYPE_CHECKING:
-
-        @util.ro_non_memoized_property
-        def primary_key(self) -> Iterable[NamedColumn[Any]]:
-            ...
-
-        @util.ro_non_memoized_property
-        def columns(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]:
-            ...
-
 
 class DMLColumnRole(SQLRole):
     __slots__ = ()
index 52ba60a62c8f991b93c53d03953948a93554d251..27456d2be4ffed00c1c0d1abbcfb1dbf242199fe 100644 (file)
@@ -86,7 +86,6 @@ if typing.TYPE_CHECKING:
     from ._typing import _InfoType
     from ._typing import _TextCoercedExpressionArgument
     from ._typing import _TypeEngineArgument
-    from .base import ColumnCollection
     from .base import DedupeColumnCollection
     from .base import ReadOnlyColumnCollection
     from .compiler import DDLCompiler
@@ -97,9 +96,7 @@ if typing.TYPE_CHECKING:
     from .visitors import anon_map
     from ..engine import Connection
     from ..engine import Engine
-    from ..engine.cursor import CursorResult
     from ..engine.interfaces import _CoreMultiExecuteParams
-    from ..engine.interfaces import _CoreSingleExecuteParams
     from ..engine.interfaces import _ExecuteOptionsParameter
     from ..engine.interfaces import ExecutionContext
     from ..engine.mock import MockConnection
@@ -2609,8 +2606,10 @@ class ForeignKey(DialectKWArgs, SchemaItem):
         :class:`_schema.Table`.
 
         """
-
-        return table.columns.corresponding_column(self.column)
+        # our column is a Column, and any subquery etc. proxying us
+        # would be doing so via another Column, so that's what would
+        # be returned here
+        return table.columns.corresponding_column(self.column)  # type: ignore
 
     @util.memoized_property
     def _column_tokens(self) -> Tuple[Optional[str], str, Optional[str]]:
index 9d4d1d6c79c90ea7387c4f806c58c886dfa0e364..b08f13f993e8f658fde2a874e421e75cda6f2f0b 100644 (file)
@@ -23,6 +23,7 @@ from typing import Any
 from typing import Callable
 from typing import cast
 from typing import Dict
+from typing import Generic
 from typing import Iterable
 from typing import Iterator
 from typing import List
@@ -46,6 +47,8 @@ from . import traversals
 from . import type_api
 from . import visitors
 from ._typing import _ColumnsClauseArgument
+from ._typing import _no_kw
+from ._typing import _TP
 from ._typing import is_column_element
 from ._typing import is_select_statement
 from ._typing import is_subquery
@@ -103,9 +106,20 @@ if TYPE_CHECKING:
     from ._typing import _ColumnExpressionArgument
     from ._typing import _FromClauseArgument
     from ._typing import _JoinTargetArgument
+    from ._typing import _MAYBE_ENTITY
+    from ._typing import _NOT_ENTITY
     from ._typing import _OnClauseArgument
     from ._typing import _SelectStatementForCompoundArgument
+    from ._typing import _T0
+    from ._typing import _T1
+    from ._typing import _T2
+    from ._typing import _T3
+    from ._typing import _T4
+    from ._typing import _T5
+    from ._typing import _T6
+    from ._typing import _T7
     from ._typing import _TextCoercedExpressionArgument
+    from ._typing import _TypedColumnClauseArgument as _TCCA
     from ._typing import _TypeEngineArgument
     from .base import _AmbiguousTableNameMap
     from .base import ExecutableOption
@@ -115,14 +129,13 @@ if TYPE_CHECKING:
     from .dml import Delete
     from .dml import Insert
     from .dml import Update
+    from .elements import KeyedColumnElement
     from .elements import NamedColumn
     from .elements import TextClause
     from .functions import Function
-    from .schema import Column
     from .schema import ForeignKey
     from .schema import ForeignKeyConstraint
     from .type_api import TypeEngine
-    from .util import ClauseAdapter
     from .visitors import _CloneCallableType
 
 
@@ -245,6 +258,14 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement):
         raise NotImplementedError()
 
 
+class ExecutableReturnsRows(Executable, ReturnsRows):
+    """base for executable statements that return rows."""
+
+
+class TypedReturnsRows(ExecutableReturnsRows, Generic[_TP]):
+    """base for executable statements that return rows, parameterized
+    by the tuple type of the rows they return."""
+
+
 SelfSelectable = TypeVar("SelfSelectable", bound="Selectable")
 
 
@@ -293,8 +314,8 @@ class Selectable(ReturnsRows):
         )
 
     def corresponding_column(
-        self, column: ColumnElement[Any], require_embedded: bool = False
-    ) -> Optional[ColumnElement[Any]]:
+        self, column: KeyedColumnElement[Any], require_embedded: bool = False
+    ) -> Optional[KeyedColumnElement[Any]]:
         """Given a :class:`_expression.ColumnElement`, return the exported
         :class:`_expression.ColumnElement` object from the
         :attr:`_expression.Selectable.exported_columns`
@@ -593,7 +614,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
 
     _use_schema_map = False
 
-    def select(self) -> Select:
+    def select(self) -> Select[Any]:
         r"""Return a SELECT of this :class:`_expression.FromClause`.
 
 
@@ -795,7 +816,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
         )
 
     @util.ro_non_memoized_property
-    def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]:
+    def exported_columns(
+        self,
+    ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
         """A :class:`_expression.ColumnCollection`
         that represents the "exported"
         columns of this :class:`_expression.Selectable`.
@@ -817,7 +840,9 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
         return self.c
 
     @util.ro_non_memoized_property
-    def columns(self) -> ReadOnlyColumnCollection[str, Any]:
+    def columns(
+        self,
+    ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
         """A named-based collection of :class:`_expression.ColumnElement`
         objects maintained by this :class:`_expression.FromClause`.
 
@@ -833,7 +858,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
         return self.c
 
     @util.ro_memoized_property
-    def c(self) -> ReadOnlyColumnCollection[str, Any]:
+    def c(self) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
         """
         A synonym for :attr:`.FromClause.columns`
 
@@ -1223,7 +1248,7 @@ class Join(roles.DMLTableRole, FromClause):
     @util.preload_module("sqlalchemy.sql.util")
     def _populate_column_collection(self):
         sqlutil = util.preloaded.sql_util
-        columns: List[ColumnClause[Any]] = [c for c in self.left.c] + [
+        columns: List[KeyedColumnElement[Any]] = [c for c in self.left.c] + [
             c for c in self.right.c
         ]
 
@@ -1458,7 +1483,7 @@ class Join(roles.DMLTableRole, FromClause):
                 "join explicitly." % (a.description, b.description)
             )
 
-    def select(self) -> "Select":
+    def select(self) -> Select[Any]:
         r"""Create a :class:`_expression.Select` from this
         :class:`_expression.Join`.
 
@@ -2764,6 +2789,7 @@ class Subquery(AliasedReturnsRows):
         cls, selectable: SelectBase, name: Optional[str] = None
     ) -> Subquery:
         """Return a :class:`.Subquery` object."""
+
         return coercions.expect(
             roles.SelectStatementRole, selectable
         ).subquery(name=name)
@@ -3216,7 +3242,6 @@ class SelectBase(
     roles.CompoundElementRole,
     roles.InElementRole,
     HasCTE,
-    Executable,
     SupportsCloneAnnotations,
     Selectable,
 ):
@@ -3239,7 +3264,9 @@ class SelectBase(
         self._reset_memoizations()
 
     @util.ro_non_memoized_property
-    def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]:
+    def selected_columns(
+        self,
+    ) -> ColumnCollection[str, ColumnElement[Any]]:
         """A :class:`_expression.ColumnCollection`
         representing the columns that
         this SELECT statement or similar construct returns in its result set.
@@ -3284,7 +3311,9 @@ class SelectBase(
         raise NotImplementedError()
 
     @property
-    def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]:
+    def exported_columns(
+        self,
+    ) -> ReadOnlyColumnCollection[str, ColumnElement[Any]]:
         """A :class:`_expression.ColumnCollection`
         that represents the "exported"
         columns of this :class:`_expression.Selectable`, not including
@@ -3377,7 +3406,7 @@ class SelectBase(
     def as_scalar(self):
         return self.scalar_subquery()
 
-    def exists(self):
+    def exists(self) -> Exists:
         """Return an :class:`_sql.Exists` representation of this selectable,
         which can be used as a column expression.
 
@@ -3394,7 +3423,7 @@ class SelectBase(
         """
         return Exists(self)
 
-    def scalar_subquery(self):
+    def scalar_subquery(self) -> ScalarSelect[Any]:
         """Return a 'scalar' representation of this selectable, which can be
         used as a column expression.
 
@@ -3607,7 +3636,7 @@ SelfGenerativeSelect = typing.TypeVar(
 )
 
 
-class GenerativeSelect(SelectBase):
+class GenerativeSelect(SelectBase, Generative):
     """Base class for SELECT statements where additional elements can be
     added.
 
@@ -4128,7 +4157,7 @@ class _CompoundSelectKeyword(Enum):
     INTERSECT_ALL = "INTERSECT ALL"
 
 
-class CompoundSelect(HasCompileState, GenerativeSelect):
+class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
     """Forms the basis of ``UNION``, ``UNION ALL``, and other
     SELECT-based set operations.
 
@@ -4293,7 +4322,9 @@ class CompoundSelect(HasCompileState, GenerativeSelect):
         return self.selects[0]._all_selected_columns
 
     @util.ro_non_memoized_property
-    def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]:
+    def selected_columns(
+        self,
+    ) -> ColumnCollection[str, ColumnElement[Any]]:
         """A :class:`_expression.ColumnCollection`
         representing the columns that
         this SELECT statement or similar construct returns in its result set,
@@ -4343,7 +4374,10 @@ class SelectState(util.MemoizedSlots, CompileState):
             ...
 
     def __init__(
-        self, statement: Select, compiler: Optional[SQLCompiler], **kw: Any
+        self,
+        statement: Select[Any],
+        compiler: Optional[SQLCompiler],
+        **kw: Any,
     ):
         self.statement = statement
         self.from_clauses = statement._from_obj
@@ -4369,7 +4403,7 @@ class SelectState(util.MemoizedSlots, CompileState):
 
     @classmethod
     def get_column_descriptions(
-        cls, statement: Select
+        cls, statement: Select[Any]
     ) -> List[Dict[str, Any]]:
         return [
             {
@@ -4384,12 +4418,14 @@ class SelectState(util.MemoizedSlots, CompileState):
 
     @classmethod
     def from_statement(
-        cls, statement: Select, from_statement: ReturnsRows
-    ) -> Any:
+        cls, statement: Select[Any], from_statement: ExecutableReturnsRows
+    ) -> ExecutableReturnsRows:
         cls._plugin_not_implemented()
 
     @classmethod
-    def get_columns_clause_froms(cls, statement: Select) -> List[FromClause]:
+    def get_columns_clause_froms(
+        cls, statement: Select[Any]
+    ) -> List[FromClause]:
         return cls._normalize_froms(
             itertools.chain.from_iterable(
                 element._from_objects for element in statement._raw_columns
@@ -4439,7 +4475,7 @@ class SelectState(util.MemoizedSlots, CompileState):
 
         return go
 
-    def _get_froms(self, statement: Select) -> List[FromClause]:
+    def _get_froms(self, statement: Select[Any]) -> List[FromClause]:
         ambiguous_table_name_map: _AmbiguousTableNameMap
         self._ambiguous_table_name_map = ambiguous_table_name_map = {}
 
@@ -4467,7 +4503,7 @@ class SelectState(util.MemoizedSlots, CompileState):
     def _normalize_froms(
         cls,
         iterable_of_froms: Iterable[FromClause],
-        check_statement: Optional[Select] = None,
+        check_statement: Optional[Select[Any]] = None,
         ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None,
     ) -> List[FromClause]:
         """given an iterable of things to select FROM, reduce them to what
@@ -4615,7 +4651,7 @@ class SelectState(util.MemoizedSlots, CompileState):
 
     @classmethod
     def determine_last_joined_entity(
-        cls, stmt: Select
+        cls, stmt: Select[Any]
     ) -> Optional[_JoinTargetElement]:
         if stmt._setup_joins:
             return stmt._setup_joins[-1][0]
@@ -4623,7 +4659,7 @@ class SelectState(util.MemoizedSlots, CompileState):
             return None
 
     @classmethod
-    def all_selected_columns(cls, statement: Select) -> _SelectIterable:
+    def all_selected_columns(cls, statement: Select[Any]) -> _SelectIterable:
         return [c for c in _select_iterables(statement._raw_columns)]
 
     def _setup_joins(
@@ -4876,7 +4912,7 @@ class _MemoizedSelectEntities(
         return c  # type: ignore
 
     @classmethod
-    def _generate_for_statement(cls, select_stmt: Select) -> None:
+    def _generate_for_statement(cls, select_stmt: Select[Any]) -> None:
         if select_stmt._setup_joins or select_stmt._with_options:
             self = _MemoizedSelectEntities()
             self._raw_columns = select_stmt._raw_columns
@@ -4888,7 +4924,7 @@ class _MemoizedSelectEntities(
             select_stmt._setup_joins = select_stmt._with_options = ()
 
 
-SelfSelect = typing.TypeVar("SelfSelect", bound="Select")
+SelfSelect = typing.TypeVar("SelfSelect", bound="Select[Any]")
 
 
 class Select(
@@ -4898,6 +4934,7 @@ class Select(
     HasCompileState,
     _SelectFromElements,
     GenerativeSelect,
+    TypedReturnsRows[_TP],
 ):
     """Represents a ``SELECT`` statement.
 
@@ -4973,7 +5010,7 @@ class Select(
     _compile_state_factory: Type[SelectState]
 
     @classmethod
-    def _create_raw_select(cls, **kw: Any) -> Select:
+    def _create_raw_select(cls, **kw: Any) -> Select[Any]:
         """Create a :class:`.Select` using raw ``__new__`` with no coercions.
 
         Used internally to build up :class:`.Select` constructs with
@@ -4985,7 +5022,7 @@ class Select(
         stmt.__dict__.update(kw)
         return stmt
 
-    def __init__(self, *entities: _ColumnsClauseArgument):
+    def __init__(self, *entities: _ColumnsClauseArgument[Any]):
         r"""Construct a new :class:`_expression.Select`.
 
         The public constructor for :class:`_expression.Select` is the
@@ -5013,7 +5050,9 @@ class Select(
         cols = list(elem._select_iterable)
         return cols[0].type
 
-    def filter(self: SelfSelect, *criteria: ColumnElement[Any]) -> SelfSelect:
+    def filter(
+        self: SelfSelect, *criteria: _ColumnExpressionArgument[bool]
+    ) -> SelfSelect:
         """A synonym for the :meth:`_future.Select.where` method."""
 
         return self.where(*criteria)
@@ -5032,7 +5071,28 @@ class Select(
 
         return self._raw_columns[0]
 
-    def filter_by(self, **kwargs):
+    if TYPE_CHECKING:
+
+        @overload
+        def scalar_subquery(
+            self: Select[Tuple[_MAYBE_ENTITY]],
+        ) -> ScalarSelect[Any]:
+            ...
+
+        @overload
+        def scalar_subquery(
+            self: Select[Tuple[_NOT_ENTITY]],
+        ) -> ScalarSelect[_NOT_ENTITY]:
+            ...
+
+        @overload
+        def scalar_subquery(self) -> ScalarSelect[Any]:
+            ...
+
+        def scalar_subquery(self) -> ScalarSelect[Any]:
+            ...
+
+    def filter_by(self: SelfSelect, **kwargs: Any) -> SelfSelect:
         r"""apply the given filtering criterion as a WHERE clause
         to this select.
 
@@ -5046,7 +5106,7 @@ class Select(
         return self.filter(*clauses)
 
     @property
-    def column_descriptions(self):
+    def column_descriptions(self) -> Any:
         """Return a :term:`plugin-enabled` 'column descriptions' structure
         referring to the columns which are SELECTed by this statement.
 
@@ -5089,7 +5149,9 @@ class Select(
         meth = SelectState.get_plugin_class(self).get_column_descriptions
         return meth(self)
 
-    def from_statement(self, statement):
+    def from_statement(
+        self, statement: ExecutableReturnsRows
+    ) -> ExecutableReturnsRows:
         """Apply the columns which this :class:`.Select` would select
         onto another statement.
 
@@ -5410,7 +5472,7 @@ class Select(
         )
 
     @property
-    def inner_columns(self):
+    def inner_columns(self) -> _SelectIterable:
         """An iterator of all :class:`_expression.ColumnElement`
         expressions which would
         be rendered into the columns clause of the resulting SELECT statement.
@@ -5487,18 +5549,19 @@ class Select(
 
         self._reset_memoizations()
 
-    def get_children(self, **kwargs):
+    def get_children(self, **kw: Any) -> Iterable[ClauseElement]:
         return itertools.chain(
             super(Select, self).get_children(
-                omit_attrs=("_from_obj", "_correlate", "_correlate_except")
+                omit_attrs=("_from_obj", "_correlate", "_correlate_except"),
+                **kw,
             ),
             self._iterate_from_elements(),
         )
 
     @_generative
     def add_columns(
-        self: SelfSelect, *columns: _ColumnsClauseArgument
-    ) -> SelfSelect:
+        self, *columns: _ColumnsClauseArgument[Any]
+    ) -> Select[Any]:
         """Return a new :func:`_expression.select` construct with
         the given column expressions added to its columns clause.
 
@@ -5523,7 +5586,7 @@ class Select(
         return self
 
     def _set_entities(
-        self, entities: Iterable[_ColumnsClauseArgument]
+        self, entities: Iterable[_ColumnsClauseArgument[Any]]
     ) -> None:
         self._raw_columns = [
             coercions.expect(
@@ -5538,7 +5601,7 @@ class Select(
         "be removed in a future release.  Please use "
         ":meth:`_expression.Select.add_columns`",
     )
-    def column(self: SelfSelect, column: _ColumnsClauseArgument) -> SelfSelect:
+    def column(self, column: _ColumnsClauseArgument[Any]) -> Select[Any]:
         """Return a new :func:`_expression.select` construct with
         the given column expression added to its columns clause.
 
@@ -5555,9 +5618,7 @@ class Select(
         return self.add_columns(column)
 
     @util.preload_module("sqlalchemy.sql.util")
-    def reduce_columns(
-        self: SelfSelect, only_synonyms: bool = True
-    ) -> SelfSelect:
+    def reduce_columns(self, only_synonyms: bool = True) -> Select[Any]:
         """Return a new :func:`_expression.select` construct with redundantly
         named, equivalently-valued columns removed from the columns clause.
 
@@ -5580,20 +5641,115 @@ class Select(
          all columns that are equivalent to another are removed.
 
         """
-        return self.with_only_columns(
+        woc: Select[Any]
+        woc = self.with_only_columns(
             *util.preloaded.sql_util.reduce_columns(
                 self._all_selected_columns,
                 only_synonyms=only_synonyms,
                 *(self._where_criteria + self._from_obj),
             )
         )
+        return woc
+
+    # START OVERLOADED FUNCTIONS self.with_only_columns Select 8
+
+    # code within this block is **programmatically,
+    # statically generated** by tools/generate_sel_v1_overloads.py
+
+    @overload
+    def with_only_columns(self, __ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]:
+        ...
+
+    @overload
+    def with_only_columns(
+        self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+    ) -> Select[Tuple[_T0, _T1]]:
+        ...
+
+    @overload
+    def with_only_columns(
+        self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+    ) -> Select[Tuple[_T0, _T1, _T2]]:
+        ...
+
+    @overload
+    def with_only_columns(
+        self,
+        __ent0: _TCCA[_T0],
+        __ent1: _TCCA[_T1],
+        __ent2: _TCCA[_T2],
+        __ent3: _TCCA[_T3],
+    ) -> Select[Tuple[_T0, _T1, _T2, _T3]]:
+        ...
+
+    @overload
+    def with_only_columns(
+        self,
+        __ent0: _TCCA[_T0],
+        __ent1: _TCCA[_T1],
+        __ent2: _TCCA[_T2],
+        __ent3: _TCCA[_T3],
+        __ent4: _TCCA[_T4],
+    ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+        ...
+
+    @overload
+    def with_only_columns(
+        self,
+        __ent0: _TCCA[_T0],
+        __ent1: _TCCA[_T1],
+        __ent2: _TCCA[_T2],
+        __ent3: _TCCA[_T3],
+        __ent4: _TCCA[_T4],
+        __ent5: _TCCA[_T5],
+    ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+        ...
+
+    @overload
+    def with_only_columns(
+        self,
+        __ent0: _TCCA[_T0],
+        __ent1: _TCCA[_T1],
+        __ent2: _TCCA[_T2],
+        __ent3: _TCCA[_T3],
+        __ent4: _TCCA[_T4],
+        __ent5: _TCCA[_T5],
+        __ent6: _TCCA[_T6],
+    ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+        ...
+
+    @overload
+    def with_only_columns(
+        self,
+        __ent0: _TCCA[_T0],
+        __ent1: _TCCA[_T1],
+        __ent2: _TCCA[_T2],
+        __ent3: _TCCA[_T3],
+        __ent4: _TCCA[_T4],
+        __ent5: _TCCA[_T5],
+        __ent6: _TCCA[_T6],
+        __ent7: _TCCA[_T7],
+    ) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+        ...
+
+    # END OVERLOADED FUNCTIONS self.with_only_columns
+
+    @overload
+    def with_only_columns(
+        self,
+        *columns: _ColumnsClauseArgument[Any],
+        maintain_column_froms: bool = False,
+        **__kw: Any,
+    ) -> Select[Any]:
+        ...
 
     @_generative
     def with_only_columns(
-        self: SelfSelect,
-        *columns: _ColumnsClauseArgument,
+        self,
+        *columns: _ColumnsClauseArgument[Any],
         maintain_column_froms: bool = False,
-    ) -> SelfSelect:
+        **__kw: Any,
+    ) -> Select[Any]:
         r"""Return a new :func:`_expression.select` construct with its columns
         clause replaced with the given columns.
 
@@ -5647,6 +5803,9 @@ class Select(
 
         """  # noqa: E501
 
+        if __kw:
+            raise _no_kw()
+
         # memoizations should be cleared here as of
         # I95c560ffcbfa30b26644999412fb6a385125f663 , asserting this
         # is the case for now.
@@ -5915,7 +6074,9 @@ class Select(
         return self
 
     @HasMemoized_ro_memoized_attribute
-    def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]:
+    def selected_columns(
+        self,
+    ) -> ColumnCollection[str, ColumnElement[Any]]:
         """A :class:`_expression.ColumnCollection`
         representing the columns that
         this SELECT statement or similar construct returns in its result set,
@@ -6215,7 +6376,7 @@ class ScalarSelect(
         by this :class:`_expression.ScalarSelect`.
 
         """
-        self.element = cast(Select, self.element).where(crit)
+        self.element = cast("Select[Any]", self.element).where(crit)
         return self
 
     @overload
@@ -6269,7 +6430,9 @@ class ScalarSelect(
 
 
         """
-        self.element = cast(Select, self.element).correlate(*fromclauses)
+        self.element = cast("Select[Any]", self.element).correlate(
+            *fromclauses
+        )
         return self
 
     @_generative
@@ -6307,7 +6470,7 @@ class ScalarSelect(
 
         """
 
-        self.element = cast(Select, self.element).correlate_except(
+        self.element = cast("Select[Any]", self.element).correlate_except(
             *fromclauses
         )
         return self
@@ -6331,12 +6494,18 @@ class Exists(UnaryExpression[bool]):
     def __init__(
         self,
         __argument: Optional[
-            Union[_ColumnsClauseArgument, SelectBase, ScalarSelect[bool]]
+            Union[_ColumnsClauseArgument[Any], SelectBase, ScalarSelect[Any]]
         ] = None,
     ):
+        s: ScalarSelect[Any]
+
+        # TODO: this seems like we should be using coercions for this
         if __argument is None:
             s = Select(literal_column("*")).scalar_subquery()
-        elif isinstance(__argument, (SelectBase, ScalarSelect)):
+        elif isinstance(__argument, SelectBase):
+            s = __argument.scalar_subquery()
+            s._propagate_attrs = __argument._propagate_attrs
+        elif isinstance(__argument, ScalarSelect):
             s = __argument
         else:
             s = Select(__argument).scalar_subquery()
@@ -6358,7 +6527,7 @@ class Exists(UnaryExpression[bool]):
         element = fn(element)
         return element.self_group(against=operators.exists)
 
-    def select(self) -> Select:
+    def select(self) -> Select[Any]:
         r"""Return a SELECT of this :class:`_expression.Exists`.
 
         e.g.::
@@ -6452,7 +6621,7 @@ class Exists(UnaryExpression[bool]):
 SelfTextualSelect = typing.TypeVar("SelfTextualSelect", bound="TextualSelect")
 
 
-class TextualSelect(SelectBase):
+class TextualSelect(SelectBase, Executable, Generative):
     """Wrap a :class:`_expression.TextClause` construct within a
     :class:`_expression.SelectBase`
     interface.
@@ -6503,7 +6672,9 @@ class TextualSelect(SelectBase):
         self.positional = positional
 
     @HasMemoized_ro_memoized_attribute
-    def selected_columns(self) -> ColumnCollection[str, ColumnElement[Any]]:
+    def selected_columns(
+        self,
+    ) -> ColumnCollection[str, KeyedColumnElement[Any]]:
         """A :class:`_expression.ColumnCollection`
         representing the columns that
         this SELECT statement or similar construct returns in its result set,
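
Putting the selectable changes together, a sketch of how the tuple typing surfaces on ``Select`` (this assumes the companion ``select()`` construction overloads that are part of the same change; column names are arbitrary)::

    from sqlalchemy import Integer, String, column, select

    a = column("a", Integer)  # ColumnClause[int]
    b = column("b", String)   # ColumnClause[str]

    stmt = select(a, b)                   # Select[Tuple[int, str]]
    narrowed = stmt.with_only_columns(b)  # Select[Tuple[str]]
    scalar = select(a).scalar_subquery()  # ScalarSelect[int], via the
                                          # _NOT_ENTITY overload above
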
index d08fef60a9b93fa42c677dcf0e8f2954c7951faf..8c45ba4101298ea2fe21ff9cb2594de236bd8226 100644 (file)
@@ -50,6 +50,7 @@ from .elements import ClauseElement
 from .elements import ColumnClause
 from .elements import ColumnElement
 from .elements import Grouping
+from .elements import KeyedColumnElement
 from .elements import Label
 from .elements import Null
 from .elements import UnaryExpression
@@ -72,9 +73,7 @@ if typing.TYPE_CHECKING:
     from ._typing import _EquivalentColumnMap
     from ._typing import _TypeEngineArgument
     from .elements import TextClause
-    from .roles import FromClauseRole
     from .selectable import _JoinTargetElement
-    from .selectable import _OnClauseElement
     from .selectable import _SelectIterable
     from .selectable import Selectable
     from .visitors import _TraverseCallableType
@@ -569,7 +568,7 @@ class _repr_row(_repr_base):
 
     __slots__ = ("row",)
 
-    def __init__(self, row: "Row", max_chars: int = 300):
+    def __init__(self, row: "Row[Any]", max_chars: int = 300):
         self.row = row
         self.max_chars = max_chars
 
@@ -1068,7 +1067,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
             col = col._annotations["adapt_column"]
 
         if TYPE_CHECKING:
-            assert isinstance(col, ColumnElement)
+            assert isinstance(col, KeyedColumnElement)
 
         if self.adapt_from_selectables and col not in self.equivalents:
             for adp in self.adapt_from_selectables:
@@ -1078,7 +1077,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
                 return None
 
         if TYPE_CHECKING:
-            assert isinstance(col, ColumnElement)
+            assert isinstance(col, KeyedColumnElement)
 
         if self.include_fn and not self.include_fn(col):
             return None
index e0a66fbcf48edf7345f688245064337c168e5cbb..88586d834ada2655e864b5294690f03058dbe46d 100644 (file)
@@ -450,7 +450,7 @@ class HasTraverseInternals:
 
     @util.preload_module("sqlalchemy.sql.traversals")
     def get_children(
-        self, omit_attrs: Tuple[str, ...] = (), **kw: Any
+        self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any
     ) -> Iterable[HasTraverseInternals]:
         r"""Return immediate child :class:`.visitors.HasTraverseInternals`
         elements of this :class:`.visitors.HasTraverseInternals`.
@@ -594,7 +594,7 @@ class ExternallyTraversible(HasTraverseInternals, Visitable):
     if typing.TYPE_CHECKING:
 
         def get_children(
-            self, omit_attrs: Tuple[str, ...] = (), **kw: Any
+            self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any
         ) -> Iterable[ExternallyTraversible]:
             ...
 
index 49c5d693afa7c527b9311a50c1830444750a54fd..da3fbc718a64bdf0fece07827d99511e9409ca10 100644 (file)
@@ -18,6 +18,7 @@ import hashlib
 import inspect
 import itertools
 import operator
+import os
 import re
 import sys
 import textwrap
@@ -32,6 +33,7 @@ from typing import Generic
 from typing import Iterator
 from typing import List
 from typing import Mapping
+from typing import no_type_check
 from typing import NoReturn
 from typing import Optional
 from typing import overload
@@ -2106,3 +2108,45 @@ def has_compiled_ext(raise_=False):
         )
     else:
         return False
+
+
+@no_type_check
+def console_scripts(
+    path: str, options: dict, ignore_output: bool = False
+) -> None:
+
+    import subprocess
+    import shlex
+    from pathlib import Path
+
+    is_posix = os.name == "posix"
+
+    entrypoint_name = options["entrypoint"]
+
+    for entry in compat.importlib_metadata_get("console_scripts"):
+        if entry.name == entrypoint_name:
+            impl = entry
+            break
+    else:
+        raise Exception(
+            f"Could not find entrypoint console_scripts.{entrypoint_name}"
+        )
+    cmdline_options_str = options.get("options", "")
+    cmdline_options_list = shlex.split(cmdline_options_str, posix=is_posix) + [
+        path
+    ]
+
+    kw = {}
+    if ignore_output:
+        kw["stdout"] = kw["stderr"] = subprocess.DEVNULL
+
+    subprocess.run(
+        [
+            sys.executable,
+            "-c",
+            "import %s; %s.%s()" % (impl.module, impl.module, impl.attr),
+        ]
+        + cmdline_options_list,
+        cwd=Path(__file__).parent.parent,
+        **kw,
+    )
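
The console_scripts() helper added above is relocated into langhelpers from
tools/generate_proxy_methods.py (its removal appears further down) so that
both code-generation tools can share it.  A rough usage sketch, assuming a
hypothetical target file and the "black" entrypoint (neither is part of this
commit):

    # illustrative only: look up the installed "black" console_scripts
    # entrypoint and run it in a subprocess against a placeholder file,
    # discarding its output
    from sqlalchemy.util.langhelpers import console_scripts

    console_scripts(
        "path/to/generated_module.py",
        {"entrypoint": "black", "options": "-l 79"},
        ignore_output=True,
    )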
index d192dc06bfc0d9e98206b7b11dd3f9087bd499af..2a215c4f1a9fb4f9c8c9c3b7cdcbd7bd939bb7e3 100644 (file)
@@ -14,7 +14,7 @@ from typing import Type
 from typing import TypeVar
 from typing import Union
 
-from typing_extensions import NotRequired as NotRequired  # noqa
+from typing_extensions import NotRequired as NotRequired
 
 from . import compat
 
index d16f03c032848922629c546b6c1b37712b0f9e9e..516831bca55bd8984e827970dd42e3d5ce0bc5e3 100644 (file)
@@ -12,7 +12,6 @@ target-version = ['py37']
 
 [tool.zimports]
 black-line-length = 79
-keep-unused-type-checking = true
 
 [tool.slotscheck]
 exclude-modules = '^sqlalchemy\.testing'
index bc7bfefa491fe74d7428b4f30ca8548d36451188..90938263f517a64463566a6c711be7ccc15e51a1 100644 (file)
@@ -276,6 +276,13 @@ class ResultTest(fixtures.TestBase):
         eq_(m1.fetchone(), {"a": 1, "b": 1, "c": 1})
         eq_(r1.fetchone(), (2, 1, 2))
 
+    def test_tuples_plus_base(self):
+        r1 = self._fixture()
+
+        t1 = r1.tuples()
+        eq_(t1.fetchone(), (1, 1, 1))
+        eq_(r1.fetchone(), (2, 1, 2))
+
     def test_scalar_plus_base(self):
         r1 = self._fixture()
 
index e8b57a0c0226d2c1d296b9ef71574d11b5a1ed74..cb9f0b85d7dc85646e416db51a87e048cd29a175 100644 (file)
@@ -40,8 +40,8 @@ class Address(Base):
 u1 = User()
 
 if typing.TYPE_CHECKING:
-    # EXPECTED_TYPE: sqlalchemy.*.associationproxy.AssociationProxyInstance\[builtins.set\*?\[builtins.str\]\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.associationproxy.AssociationProxyInstance\[builtins.set\*?\[builtins.str\]\]
     reveal_type(User.email_addresses)
 
-    # EXPECTED_TYPE: builtins.set\*?\[builtins.str\]
+    # EXPECTED_RE_TYPE: builtins.set\*?\[builtins.str\]
     reveal_type(u1.email_addresses)
index 1a1649e4ec5cafad9e428cddfe80a0cd99f657d5..20e252cddc2398098e652fe85a4274ed7539b795 100644 (file)
@@ -14,11 +14,11 @@ c1 = cols[0]
 
 if typing.TYPE_CHECKING:
 
-    # EXPECTED_TYPE: sqlalchemy.engine.base.Engine
+    # EXPECTED_RE_TYPE: sqlalchemy.engine.base.Engine
     reveal_type(e)
 
-    # EXPECTED_TYPE: sqlalchemy.engine.reflection.Inspector.*
+    # EXPECTED_RE_TYPE: sqlalchemy.engine.reflection.Inspector.*
     reveal_type(insp)
 
-    # EXPECTED_TYPE: .*list.*TypedDict.*ReflectedColumn.*
+    # EXPECTED_RE_TYPE: .*list.*TypedDict.*ReflectedColumn.*
     reveal_type(cols)
index fe2742072c306c461b3986d7d6bf948c3eb39cb6..a8d81426e7d6714952a52928581d240d2e49999e 100644 (file)
@@ -49,20 +49,20 @@ class Address(Base):
 
 
 if typing.TYPE_CHECKING:
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Union\[builtins.str, None\]\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Union\[builtins.str, None\]\]
     reveal_type(User.extra)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Union\[builtins.str, None\]\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Union\[builtins.str, None\]\]
     reveal_type(User.extra_name)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*?\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*?\]
     reveal_type(Address.email)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*?\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*?\]
     reveal_type(Address.email_name)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[experimental_relationship.Address\]\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[experimental_relationship.Address\]\]
     reveal_type(User.addresses_style_one)
 
-    # EXPECTED_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*?\[experimental_relationship.Address\]\]
+    # EXPECTED_RE_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*?\[experimental_relationship.Address\]\]
     reveal_type(User.addresses_style_two)
index d9f97ebcff5156abe69d05d8d298b393c4ef0deb..12c7c204c573cf28c22df46bcdc65b042267ebf1 100644 (file)
@@ -47,17 +47,17 @@ expr2 = Interval.contains(7)
 expr3 = Interval.intersects(i2)
 
 if typing.TYPE_CHECKING:
-    # EXPECTED_TYPE: builtins.int\*?
+    # EXPECTED_RE_TYPE: builtins.int\*?
     reveal_type(i1.length)
 
-    # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\]
     reveal_type(Interval.length)
 
-    # EXPECTED_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\]
     reveal_type(expr1)
 
-    # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\]
     reveal_type(expr2)
 
-    # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\]
     reveal_type(expr3)
index ab2970656ed0655aed507ea631e02241d2f1edeb..430d796c6023624f95a412203c05ec4f55abbeaf 100644 (file)
@@ -47,10 +47,10 @@ class Interval(Base):
 
         # while we are here, check some Float[] / div type stuff
         if typing.TYPE_CHECKING:
-            # EXPECTED_TYPE: sqlalchemy.*Function\[builtins.float\*?\]
+            # EXPECTED_RE_TYPE: sqlalchemy.*Function\[builtins.float\*?\]
             reveal_type(f1)
 
-            # EXPECTED_TYPE: sqlalchemy.*ColumnElement\[builtins.float\*?\]
+            # EXPECTED_RE_TYPE: sqlalchemy.*ColumnElement\[builtins.float\*?\]
             reveal_type(expr)
         return expr
 
@@ -69,23 +69,23 @@ expr3 = Interval.radius.in_([0.5, 5.2])
 
 
 if typing.TYPE_CHECKING:
-    # EXPECTED_TYPE: builtins.int\*?
+    # EXPECTED_RE_TYPE: builtins.int\*?
     reveal_type(i1.length)
 
-    # EXPECTED_TYPE: builtins.float\*?
+    # EXPECTED_RE_TYPE: builtins.float\*?
     reveal_type(i2.radius)
 
-    # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\]
     reveal_type(Interval.length)
 
-    # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.float\*?\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.float\*?\]
     reveal_type(Interval.radius)
 
-    # EXPECTED_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\]
     reveal_type(expr1)
 
-    # EXPECTED_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.float\*?\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.float\*?\]
     reveal_type(expr2)
 
-    # EXPECTED_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\]
     reveal_type(expr3)
index 199d3a804fc9cd7a3f283af7d0a16d7f2af2bec3..0dfa0a75201eae70de67ccfdc8ccd0bab21da4ae 100644 (file)
@@ -4,7 +4,6 @@ from typing import List
 
 from sqlalchemy import create_engine
 from sqlalchemy import ForeignKey
-from sqlalchemy import select
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
@@ -43,7 +42,20 @@ with Session(e) as sess:
     sess.add_all([Address(user=u1, email="e1"), Address(user=u1, email="e2")])
     sess.commit()
 
-with Session(e) as sess:
-    users: List[User] = sess.scalars(
-        select(User), execution_options={"stream_results": False}
-    ).all()
+    q = sess.query(User).filter_by(id=7)
+
+    # EXPECTED_TYPE: Query[User]
+    reveal_type(q)
+
+    rows1 = q.all()
+
+    # EXPECTED_RE_TYPE: builtins.[Ll]ist\[.*User\*?\]
+    reveal_type(rows1)
+
+    q2 = sess.query(User.id).filter_by(id=7)
+    rows2 = q2.all()
+
+    # EXPECTED_TYPE: List[Row[Tuple[int]]]
+    reveal_type(rows2)
+
+# more result tests in typed_results.py
index f9b9b2ffe50f26ffed193fb2ccc5c4f9b72fdab6..6b06535bf10f4fda75bc11572d77f261ae061250 100644 (file)
@@ -40,32 +40,32 @@ if typing.TYPE_CHECKING:
     # as far as if this is ColumnElement, BinaryElement, SQLCoreOperations,
     # that might change.  main thing is it's SomeSQLColThing[bool] and
     # not 'bool' or 'Any'.
-    # EXPECTED_TYPE: sqlalchemy..*ColumnElement\[builtins.bool\]
+    # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[builtins.bool\]
     reveal_type(expr1)
 
-    # EXPECTED_TYPE: sqlalchemy..*ColumnClause\[builtins.str.?\]
+    # EXPECTED_RE_TYPE: sqlalchemy..*ColumnClause\[builtins.str.?\]
     reveal_type(c1)
 
-    # EXPECTED_TYPE: sqlalchemy..*ColumnClause\[builtins.int.?\]
+    # EXPECTED_RE_TYPE: sqlalchemy..*ColumnClause\[builtins.int.?\]
     reveal_type(c2)
 
-    # EXPECTED_TYPE: sqlalchemy..*BinaryExpression\[builtins.bool\]
+    # EXPECTED_RE_TYPE: sqlalchemy..*BinaryExpression\[builtins.bool\]
     reveal_type(expr2)
 
-    # EXPECTED_TYPE: sqlalchemy..*ColumnElement\[Union\[builtins.float, decimal.Decimal\]\]
+    # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[Union\[builtins.float, decimal.Decimal\]\]
     reveal_type(expr3)
 
-    # EXPECTED_TYPE: sqlalchemy..*UnaryExpression\[builtins.int.?\]
+    # EXPECTED_RE_TYPE: sqlalchemy..*UnaryExpression\[builtins.int.?\]
     reveal_type(expr4)
 
-    # EXPECTED_TYPE: sqlalchemy..*ColumnElement\[builtins.bool.?\]
+    # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[builtins.bool.?\]
     reveal_type(expr5)
 
-    # EXPECTED_TYPE: sqlalchemy..*ColumnElement\[builtins.bool.?\]
+    # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[builtins.bool.?\]
     reveal_type(expr6)
 
-    # EXPECTED_TYPE: sqlalchemy..*ColumnElement\[builtins.str\]
+    # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[builtins.str\]
     reveal_type(expr7)
 
-    # EXPECTED_TYPE: sqlalchemy..*ColumnElement\[builtins.int.?\]
+    # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[builtins.int.?\]
     reveal_type(expr8)
index af7d292be7ae8c42f0793fad7d6be489213d70a1..4d17dab78d7e1a7d08742c90a0d7d2175629ffd8 100644 (file)
@@ -101,45 +101,45 @@ class Address(Base):
 
 
 if typing.TYPE_CHECKING:
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[trad_relationship_uselist.Address\]\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[trad_relationship_uselist.Address\]\]
     reveal_type(User.addresses_style_one)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*?\[trad_relationship_uselist.Address\]\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*?\[trad_relationship_uselist.Address\]\]
     reveal_type(User.addresses_style_two)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
     reveal_type(User.addresses_style_three)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
     reveal_type(User.addresses_style_three_cast)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
     reveal_type(User.addresses_style_four)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
     reveal_type(Address.user_style_one)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*?\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*?\]
     reveal_type(Address.user_style_one_typed)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
     reveal_type(Address.user_style_two)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*?\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*?\]
     reveal_type(Address.user_style_two_typed)
 
     # reveal_type(Address.user_style_six)
 
     # reveal_type(Address.user_style_seven)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
     reveal_type(Address.user_style_eight)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
     reveal_type(Address.user_style_nine)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
     reveal_type(Address.user_style_ten)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.dict\*?\[builtins.str, trad_relationship_uselist.User\]\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.dict\*?\[builtins.str, trad_relationship_uselist.User\]\]
     reveal_type(Address.user_style_ten_typed)
index ce131dd00413ecec69d5054e344e8f73d560a67d..0edeb694adb2a23789e6a212014b6e3584be5cf2 100644 (file)
@@ -60,29 +60,29 @@ class Address(Base):
 
 
 if typing.TYPE_CHECKING:
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.Address\]\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.Address\]\]
     reveal_type(User.addresses_style_one)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*?\[traditional_relationship.Address\]\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*?\[traditional_relationship.Address\]\]
     reveal_type(User.addresses_style_two)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
     reveal_type(Address.user_style_one)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*?\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*?\]
     reveal_type(Address.user_style_one_typed)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
     reveal_type(Address.user_style_two)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*?\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*?\]
     reveal_type(Address.user_style_two_typed)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.User\]\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.User\]\]
     reveal_type(Address.user_style_three)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.User\]\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[traditional_relationship.User\]\]
     reveal_type(Address.user_style_four)
 
-    # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\]
     reveal_type(Address.user_style_five)
diff --git a/test/ext/mypy/plain_files/typed_queries.py b/test/ext/mypy/plain_files/typed_queries.py
new file mode 100644 (file)
index 0000000..234c4da
--- /dev/null
@@ -0,0 +1,433 @@
+from __future__ import annotations
+
+from typing import Tuple
+
+from sqlalchemy import column
+from sqlalchemy import create_engine
+from sqlalchemy import delete
+from sqlalchemy import func
+from sqlalchemy import insert
+from sqlalchemy import Select
+from sqlalchemy import select
+from sqlalchemy import update
+from sqlalchemy.orm import aliased
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import Session
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class User(Base):
+    __tablename__ = "user"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    name: Mapped[str]
+    data: Mapped[str]
+
+
+session = Session()
+
+e = create_engine("sqlite://")
+connection = e.connect()
+
+
+def t_select_1() -> None:
+    stmt = select(User.id, User.name).filter(User.id == 5)
+
+    # EXPECTED_TYPE: Select[Tuple[int, str]]
+    reveal_type(stmt)
+
+    result = session.execute(stmt)
+
+    # EXPECTED_TYPE: Result[Tuple[int, str]]
+    reveal_type(result)
+
+
+def t_select_2() -> None:
+    stmt = select(User).filter(User.id == 5)
+
+    # EXPECTED_TYPE: Select[Tuple[User]]
+    reveal_type(stmt)
+
+    result = session.execute(stmt)
+
+    # EXPECTED_TYPE: Result[Tuple[User]]
+    reveal_type(result)
+
+
+def t_select_3() -> None:
+    ua = aliased(User)
+
+    # this will fail at runtime, but since aliased(_T) is currently typed
+    # as returning _T, typing tools see the constructor as fine.
+    # this line would ideally have a typing error but we'd need the ability
+    # for aliased() to return some namespace of User that's not User.
+    # AsAliased superclass type was tested for this but it had its own
+    # awkwardnesses that aren't really worth it
+    x = ua(id=1, name="foo")
+
+    # EXPECTED_TYPE: Type[User]
+    reveal_type(ua)
+
+    stmt = select(ua.id, ua.name).filter(User.id == 5)
+
+    # EXPECTED_TYPE: Select[Tuple[int, str]]
+    reveal_type(stmt)
+
+    result = session.execute(stmt)
+
+    # EXPECTED_TYPE: Result[Tuple[int, str]]
+    reveal_type(result)
+
+
+def t_select_4() -> None:
+    ua = aliased(User)
+    stmt = select(ua, User).filter(User.id == 5)
+
+    # EXPECTED_TYPE: Select[Tuple[User, User]]
+    reveal_type(stmt)
+
+    result = session.execute(stmt)
+
+    # EXPECTED_TYPE: Result[Tuple[User, User]]
+    reveal_type(result)
+
+
+def t_legacy_query_single_entity() -> None:
+    q1 = session.query(User).filter(User.id == 5)
+
+    # EXPECTED_TYPE: Query[User]
+    reveal_type(q1)
+
+    # EXPECTED_TYPE: User
+    reveal_type(q1.one())
+
+    # EXPECTED_TYPE: List[User]
+    reveal_type(q1.all())
+
+    # mypy switches to builtins.list for some reason here
+    # EXPECTED_RE_TYPE: .*\.[Ll]ist\[.*Row\*?\[Tuple\[.*User\]\]\]
+    reveal_type(q1.only_return_tuples(True).all())
+
+    # EXPECTED_TYPE: List[Tuple[User]]
+    reveal_type(q1.tuples().all())
+
+
+def t_legacy_query_cols_1() -> None:
+    q1 = session.query(User.id, User.name).filter(User.id == 5)
+
+    # EXPECTED_TYPE: RowReturningQuery[Tuple[int, str]]
+    reveal_type(q1)
+
+    # EXPECTED_TYPE: Row[Tuple[int, str]]
+    reveal_type(q1.one())
+
+    r1 = q1.one()
+
+    x, y = r1.t
+
+    # EXPECTED_TYPE: int
+    reveal_type(x)
+
+    # EXPECTED_TYPE: str
+    reveal_type(y)
+
+
+def t_legacy_query_cols_tupleq_1() -> None:
+    q1 = session.query(User.id, User.name).filter(User.id == 5)
+
+    # EXPECTED_TYPE: RowReturningQuery[Tuple[int, str]]
+    reveal_type(q1)
+
+    q2 = q1.tuples()
+
+    # EXPECTED_TYPE: Tuple[int, str]
+    reveal_type(q2.one())
+
+    r1 = q2.one()
+
+    x, y = r1
+
+    # EXPECTED_TYPE: int
+    reveal_type(x)
+
+    # EXPECTED_TYPE: str
+    reveal_type(y)
+
+
+def t_legacy_query_cols_1_with_entities() -> None:
+    q1 = session.query(User).filter(User.id == 5)
+
+    # EXPECTED_TYPE: Query[User]
+    reveal_type(q1)
+
+    q2 = q1.with_entities(User.id, User.name)
+
+    # EXPECTED_TYPE: RowReturningQuery[Tuple[int, str]]
+    reveal_type(q2)
+
+    # EXPECTED_TYPE: Row[Tuple[int, str]]
+    reveal_type(q2.one())
+
+    r1 = q2.one()
+
+    x, y = r1.t
+
+    # EXPECTED_TYPE: int
+    reveal_type(x)
+
+    # EXPECTED_TYPE: str
+    reveal_type(y)
+
+
+def t_select_with_only_cols() -> None:
+    q1 = select(User).where(User.id == 5)
+
+    # EXPECTED_TYPE: Select[Tuple[User]]
+    reveal_type(q1)
+
+    q2 = q1.with_only_columns(User.id, User.name)
+
+    # EXPECTED_TYPE: Select[Tuple[int, str]]
+    reveal_type(q2)
+
+    row = connection.execute(q2).one()
+
+    # EXPECTED_TYPE: Row[Tuple[int, str]]
+    reveal_type(row)
+
+    x, y = row.t
+
+    # EXPECTED_TYPE: int
+    reveal_type(x)
+
+    # EXPECTED_TYPE: str
+    reveal_type(y)
+
+
+def t_legacy_query_cols_2() -> None:
+    a1 = aliased(User)
+    q1 = session.query(User, a1, User.name).filter(User.id == 5)
+
+    # EXPECTED_TYPE: RowReturningQuery[Tuple[User, User, str]]
+    reveal_type(q1)
+
+    # EXPECTED_TYPE: Row[Tuple[User, User, str]]
+    reveal_type(q1.one())
+
+    r1 = q1.one()
+
+    x, y, z = r1.t
+
+    # EXPECTED_TYPE: User
+    reveal_type(x)
+
+    # EXPECTED_TYPE: User
+    reveal_type(y)
+
+    # EXPECTED_TYPE: str
+    reveal_type(z)
+
+
+def t_legacy_query_cols_2_with_entities() -> None:
+
+    q1 = session.query(User)
+
+    # EXPECTED_TYPE: Query[User]
+    reveal_type(q1)
+
+    a1 = aliased(User)
+    q2 = q1.with_entities(User, a1, User.name).filter(User.id == 5)
+
+    # EXPECTED_TYPE: RowReturningQuery[Tuple[User, User, str]]
+    reveal_type(q2)
+
+    # EXPECTED_TYPE: Row[Tuple[User, User, str]]
+    reveal_type(q2.one())
+
+    r1 = q2.one()
+
+    x, y, z = r1.t
+
+    # EXPECTED_TYPE: User
+    reveal_type(x)
+
+    # EXPECTED_TYPE: User
+    reveal_type(y)
+
+    # EXPECTED_TYPE: str
+    reveal_type(z)
+
+
+def t_select_add_col_loses_type() -> None:
+    q1 = select(User.id, User.name).filter(User.id == 5)
+
+    q2 = q1.add_columns(User.data)
+
+    # note this should not match Select
+    # EXPECTED_TYPE: Select[Any]
+    reveal_type(q2)
+
+
+def t_legacy_query_add_col_loses_type() -> None:
+    q1 = session.query(User.id, User.name).filter(User.id == 5)
+
+    q2 = q1.add_columns(User.data)
+
+    # this should match only Any
+    # EXPECTED_TYPE: Query[Any]
+    reveal_type(q2)
+
+    ua = aliased(User)
+    q3 = q1.add_entity(ua)
+
+    # EXPECTED_TYPE: Query[Any]
+    reveal_type(q3)
+
+
+def t_legacy_query_scalar_subquery() -> None:
+    """scalar subquery should receive the type if first element is a
+    column only"""
+    q1 = session.query(User.id)
+
+    q2 = q1.scalar_subquery()
+
+    # this should be int but mypy can't see it due to the
+    # overload that tries to match an entity.
+    # EXPECTED_RE_TYPE: .*ScalarSelect\[(?:int|Any)\]
+    reveal_type(q2)
+
+    q3 = session.query(User)
+
+    q4 = q3.scalar_subquery()
+
+    # EXPECTED_TYPE: ScalarSelect[Any]
+    reveal_type(q4)
+
+    q5 = session.query(User, User.name)
+
+    q6 = q5.scalar_subquery()
+
+    # EXPECTED_TYPE: ScalarSelect[Any]
+    reveal_type(q6)
+
+    # try to simulate the problem with select()
+    q7 = session.query(User).only_return_tuples(True)
+    q8 = q7.scalar_subquery()
+
+    # EXPECTED_TYPE: ScalarSelect[Any]
+    reveal_type(q8)
+
+
+def t_select_scalar_subquery() -> None:
+    """scalar subquery should receive the type if first element is a
+    column only"""
+    s1 = select(User.id)
+    s2 = s1.scalar_subquery()
+
+    # this should be int but mypy can't see it due to the
+    # overload that tries to match an entity.
+    # EXPECTED_TYPE: ScalarSelect[Any]
+    reveal_type(s2)
+
+    s3 = select(User)
+    s4 = s3.scalar_subquery()
+
+    # it's more important that mypy doesn't get a false positive of
+    # 'User' here
+    # EXPECTED_TYPE: ScalarSelect[Any]
+    reveal_type(s4)
+
+
+def t_select_w_core_selectables() -> None:
+    """things that come from .c. or are FromClause objects currently are not
+    typed.  Make sure we are still getting Select at least.
+
+    """
+    s1 = select(User.id, User.name).subquery()
+
+    # EXPECTED_TYPE: KeyedColumnElement[Any]
+    reveal_type(s1.c.name)
+
+    s2 = select(User.id, s1.c.name)
+
+    # this one unfortunately is not working in mypy.
+    # pylance gets the correct type
+    #   EXPECTED_TYPE: Select[Tuple[int, Any]]
+    # when experimenting with having a separate TypedSelect class for typing,
+    # mypy would downgrade to Any rather than picking the basemost type.
+    # with typing integrated into Select etc. we can at least get a Select
+    # object back.
+    # EXPECTED_TYPE: Select[Any]
+    reveal_type(s2)
+
+    # so a fully explicit type may be given
+    s2_typed: Select[Tuple[int, str]] = select(User.id, s1.c.name)
+
+    # EXPECTED_TYPE: Select[Tuple[int, str]]
+    reveal_type(s2_typed)
+
+    # plain FromClause etc we at least get Select
+    s3 = select(s1)
+
+    # EXPECTED_TYPE: Select[Any]
+    reveal_type(s3)
+
+    t1 = User.__table__
+    assert t1 is not None
+
+    # EXPECTED_TYPE: FromClause
+    reveal_type(t1)
+
+    s4 = select(t1)
+
+    # EXPECTED_TYPE: Select[Any]
+    reveal_type(s4)
+
+
+def t_dml_insert() -> None:
+    s1 = insert(User).returning(User.id, User.name)
+
+    r1 = session.execute(s1)
+
+    # EXPECTED_TYPE: Result[Tuple[int, str]]
+    reveal_type(r1)
+
+    s2 = insert(User).returning(User)
+
+    r2 = session.execute(s2)
+
+    # EXPECTED_TYPE: Result[Tuple[User]]
+    reveal_type(r2)
+
+    s3 = insert(User).returning(func.foo(), column("q"))
+
+    # EXPECTED_TYPE: ReturningInsert[Any]
+    reveal_type(s3)
+
+    r3 = session.execute(s3)
+
+    # EXPECTED_TYPE: Result[Any]
+    reveal_type(r3)
+
+
+def t_dml_update() -> None:
+    s1 = update(User).returning(User.id, User.name)
+
+    r1 = session.execute(s1)
+
+    # EXPECTED_TYPE: Result[Tuple[int, str]]
+    reveal_type(r1)
+
+
+def t_dml_delete() -> None:
+    s1 = delete(User).returning(User.id, User.name)
+
+    r1 = session.execute(s1)
+
+    # EXPECTED_TYPE: Result[Tuple[int, str]]
+    reveal_type(r1)
diff --git a/test/ext/mypy/plain_files/typed_results.py b/test/ext/mypy/plain_files/typed_results.py
new file mode 100644 (file)
index 0000000..1eecb4c
--- /dev/null
@@ -0,0 +1,554 @@
+from __future__ import annotations
+
+import asyncio
+from typing import cast
+
+from sqlalchemy import Column
+from sqlalchemy import column
+from sqlalchemy import create_engine
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import table
+from sqlalchemy.ext.asyncio import AsyncConnection
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.orm import aliased
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import Session
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class User(Base):
+    __tablename__ = "user"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    name: Mapped[str]
+
+
+e = create_engine("sqlite://")
+ae = create_async_engine("sqlite+aiosqlite://")
+
+
+connection = e.connect()
+session = Session(connection)
+
+
+async def async_connect() -> AsyncConnection:
+    return await ae.connect()
+
+
+# the thing with the \*? seems like it could go away
+# as of mypy 0.950
+
+async_connection = asyncio.run(async_connect())
+
+# EXPECTED_RE_TYPE: sqlalchemy..*AsyncConnection\*?
+reveal_type(async_connection)
+
+async_session = AsyncSession(async_connection)
+
+
+# EXPECTED_RE_TYPE: sqlalchemy..*AsyncSession\*?
+reveal_type(async_session)
+
+
+single_stmt = select(User.name).where(User.name == "foo")
+
+# EXPECTED_RE_TYPE: sqlalchemy..*Select\*?\[Tuple\[builtins.str\*?\]\]
+reveal_type(single_stmt)
+
+multi_stmt = select(User.id, User.name).where(User.name == "foo")
+
+# EXPECTED_RE_TYPE: sqlalchemy..*Select\*?\[Tuple\[builtins.int\*?, builtins.str\*?\]\]
+reveal_type(multi_stmt)
+
+
+def t_entity_varieties() -> None:
+
+    a1 = aliased(User)
+
+    s1 = select(User.id, User, User.name).where(User.name == "foo")
+
+    r1 = session.execute(s1)
+
+    # EXPECTED_RE_TYPE: sqlalchemy..*.Result\[Tuple\[builtins.int\*?, typed_results.User\*?, builtins.str\*?\]\]
+    reveal_type(r1)
+
+    s2 = select(User, a1).where(User.name == "foo")
+
+    r2 = session.execute(s2)
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*Result\[Tuple\[typed_results.User\*?, typed_results.User\*?\]\]
+    reveal_type(r2)
+
+    row = r2.t.one()
+
+    # EXPECTED_RE_TYPE: .*typed_results.User\*?
+    reveal_type(row[0])
+    # EXPECTED_RE_TYPE: .*typed_results.User\*?
+    reveal_type(row[1])
+
+    # testing that plain Mapped[x] gets picked up as well as
+    # aliased class
+    # there is unfortunately no way for attributes on an AliasedClass to be
+    # automatically typed since they are dynamically generated
+    a1_id = cast(Mapped[int], a1.id)
+    s3 = select(User.id, a1_id, a1, User).where(User.name == "foo")
+    # EXPECTED_RE_TYPE: sqlalchemy.*Select\*?\[Tuple\[builtins.int\*?, builtins.int\*?, typed_results.User\*?, typed_results.User\*?\]\]
+    reveal_type(s3)
+
+    # testing Mapped[entity]
+    some_mp = cast(Mapped[User], object())
+    s4 = select(some_mp, a1, User).where(User.name == "foo")
+
+    # NOTEXPECTED_RE_TYPE: sqlalchemy..*Select\*?\[Tuple\[typed_results.User\*?, typed_results.User\*?, typed_results.User\*?\]\]
+
+    # sqlalchemy.sql._gen_overloads.Select[Tuple[typed_results.User, typed_results.User, typed_results.User]]
+
+    # EXPECTED_TYPE: Select[Tuple[User, User, User]]
+    reveal_type(s4)
+
+    # test plain core expressions
+    x = Column("x", Integer)
+    y = x + 5
+
+    s5 = select(x, y, User.name + "hi")
+
+    # EXPECTED_RE_TYPE: sqlalchemy..*Select\*?\[Tuple\[builtins.int\*?, builtins.int\*?\, builtins.str\*?]\]
+    reveal_type(s5)
+
+
+def t_ambiguous_result_type_one() -> None:
+    stmt = select(column("q", Integer), table("x", column("y")))
+
+    # EXPECTED_TYPE: Select[Any]
+    reveal_type(stmt)
+
+    result = session.execute(stmt)
+
+    # EXPECTED_TYPE: Result[Any]
+    reveal_type(result)
+
+
+def t_ambiguous_result_type_two() -> None:
+
+    stmt = select(column("q"))
+
+    # EXPECTED_TYPE: Select[Tuple[Any]]
+    reveal_type(stmt)
+    result = session.execute(stmt)
+
+    # EXPECTED_TYPE: Result[Any]
+    reveal_type(result)
+
+
+def t_aliased() -> None:
+
+    a1 = aliased(User)
+
+    s1 = select(a1)
+    # EXPECTED_TYPE: Select[Tuple[User]]
+    reveal_type(s1)
+
+    s4 = select(a1.name, a1, a1, User).where(User.name == "foo")
+    # EXPECTED_TYPE: Select[Tuple[str, User, User, User]]
+    reveal_type(s4)
+
+
+def t_result_scalar_accessors() -> None:
+    result = connection.execute(single_stmt)
+
+    r1 = result.scalar()
+
+    # EXPECTED_RE_TYPE: Union\[builtins.str\*?, None\]
+    reveal_type(r1)
+
+    r2 = result.scalar_one()
+
+    # EXPECTED_RE_TYPE: builtins.str\*?
+    reveal_type(r2)
+
+    r3 = result.scalar_one_or_none()
+
+    # EXPECTED_RE_TYPE: Union\[builtins.str\*?, None\]
+    reveal_type(r3)
+
+    r4 = result.scalars()
+
+    # EXPECTED_RE_TYPE: sqlalchemy..*ScalarResult\[builtins.str.*?\]
+    reveal_type(r4)
+
+    r5 = result.scalars(0)
+
+    # EXPECTED_RE_TYPE: sqlalchemy..*ScalarResult\[builtins.str.*?\]
+    reveal_type(r5)
+
+
+async def t_async_result_scalar_accessors() -> None:
+    result = await async_connection.stream(single_stmt)
+
+    r1 = await result.scalar()
+
+    # EXPECTED_RE_TYPE: Union\[builtins.str\*?, None\]
+    reveal_type(r1)
+
+    r2 = await result.scalar_one()
+
+    # EXPECTED_RE_TYPE: builtins.str\*?
+    reveal_type(r2)
+
+    r3 = await result.scalar_one_or_none()
+
+    # EXPECTED_RE_TYPE: Union\[builtins.str\*?, None\]
+    reveal_type(r3)
+
+    r4 = result.scalars()
+
+    # EXPECTED_RE_TYPE: sqlalchemy..*ScalarResult\[builtins.str.*?\]
+    reveal_type(r4)
+
+    r5 = result.scalars(0)
+
+    # EXPECTED_RE_TYPE: sqlalchemy..*ScalarResult\[builtins.str.*?\]
+    reveal_type(r5)
+
+
+def t_connection_execute_multi_row_t() -> None:
+    result = connection.execute(multi_stmt)
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*CursorResult\[Tuple\[builtins.int\*?, builtins.str\*?\]\]
+    reveal_type(result)
+    row = result.one()
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*Row\[Tuple\[builtins.int\*?, builtins.str\*?\]\]
+    reveal_type(row)
+
+    x, y = row.t
+
+    # EXPECTED_RE_TYPE: builtins.int\*?
+    reveal_type(x)
+
+    # EXPECTED_RE_TYPE: builtins.str\*?
+    reveal_type(y)
+
+
+def t_connection_execute_multi() -> None:
+    result = connection.execute(multi_stmt).t
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[Tuple\[builtins.int\*?, builtins.str\*?\]\]
+    reveal_type(result)
+    row = result.one()
+
+    # EXPECTED_RE_TYPE: Tuple\[builtins.int\*?, builtins.str\*?\]
+    reveal_type(row)
+
+    x, y = row
+
+    # EXPECTED_RE_TYPE: builtins.int\*?
+    reveal_type(x)
+
+    # EXPECTED_RE_TYPE: builtins.str\*?
+    reveal_type(y)
+
+
+def t_connection_execute_single() -> None:
+    result = connection.execute(single_stmt).t
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[Tuple\[builtins.str\*?\]\]
+    reveal_type(result)
+    row = result.one()
+
+    # EXPECTED_RE_TYPE: Tuple\[builtins.str\*?\]
+    reveal_type(row)
+
+    (x,) = row
+
+    # EXPECTED_RE_TYPE: builtins.str\*?
+    reveal_type(x)
+
+
+def t_connection_execute_single_row_scalar() -> None:
+    result = connection.execute(single_stmt).t
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[Tuple\[builtins.str\*?\]\]
+    reveal_type(result)
+
+    x = result.scalar()
+
+    # EXPECTED_RE_TYPE: Union\[builtins.str\*?, None\]
+    reveal_type(x)
+
+
+def t_connection_scalar() -> None:
+    obj = connection.scalar(single_stmt)
+
+    # EXPECTED_RE_TYPE: Union\[builtins.str\*?, None\]
+    reveal_type(obj)
+
+
+def t_connection_scalars() -> None:
+    result = connection.scalars(single_stmt)
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*ScalarResult\[builtins.str\*?\]
+    reveal_type(result)
+    data = result.all()
+
+    # EXPECTED_RE_TYPE: typing.Sequence\[builtins.str\*?\]
+    reveal_type(data)
+
+
+def t_session_execute_multi() -> None:
+    result = session.execute(multi_stmt).t
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[Tuple\[builtins.int\*?, builtins.str\*?\]\]
+    reveal_type(result)
+    row = result.one()
+
+    # EXPECTED_RE_TYPE: Tuple\[builtins.int\*?, builtins.str\*?\]
+    reveal_type(row)
+
+    x, y = row
+
+    # EXPECTED_RE_TYPE: builtins.int\*?
+    reveal_type(x)
+
+    # EXPECTED_RE_TYPE: builtins.str\*?
+    reveal_type(y)
+
+
+def t_session_execute_single() -> None:
+    result = session.execute(single_stmt).t
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[Tuple\[builtins.str\*?\]\]
+    reveal_type(result)
+    row = result.one()
+
+    # EXPECTED_RE_TYPE: Tuple\[builtins.str\*?\]
+    reveal_type(row)
+
+    (x,) = row
+
+    # EXPECTED_RE_TYPE: builtins.str\*?
+    reveal_type(x)
+
+
+def t_session_scalar() -> None:
+    obj = session.scalar(single_stmt)
+
+    # EXPECTED_RE_TYPE: Union\[builtins.str\*?, None\]
+    reveal_type(obj)
+
+
+def t_session_scalars() -> None:
+    result = session.scalars(single_stmt)
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*ScalarResult\[builtins.str\*?\]
+    reveal_type(result)
+    data = result.all()
+
+    # EXPECTED_RE_TYPE: typing.Sequence\[builtins.str\*?\]
+    reveal_type(data)
+
+
+async def t_async_connection_execute_multi() -> None:
+    result = (await async_connection.execute(multi_stmt)).t
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[Tuple\[builtins.int\*?, builtins.str\*?\]\]
+    reveal_type(result)
+    row = result.one()
+
+    # EXPECTED_RE_TYPE: Tuple\[builtins.int\*?, builtins.str\*?\]
+    reveal_type(row)
+
+    x, y = row
+
+    # EXPECTED_RE_TYPE: builtins.int\*?
+    reveal_type(x)
+
+    # EXPECTED_RE_TYPE: builtins.str\*?
+    reveal_type(y)
+
+
+async def t_async_connection_execute_single() -> None:
+    result = (await async_connection.execute(single_stmt)).t
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[Tuple\[builtins.str\*?\]\]
+    reveal_type(result)
+
+    row = result.one()
+
+    # EXPECTED_RE_TYPE: Tuple\[builtins.str\*?\]
+    reveal_type(row)
+
+    (x,) = row
+
+    # EXPECTED_RE_TYPE: builtins.str\*?
+    reveal_type(x)
+
+
+async def t_async_connection_scalar() -> None:
+    obj = await async_connection.scalar(single_stmt)
+
+    # EXPECTED_RE_TYPE: Union\[builtins.str\*?, None\]
+    reveal_type(obj)
+
+
+async def t_async_connection_scalars() -> None:
+    result = await async_connection.scalars(single_stmt)
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*ScalarResult\*?\[builtins.str\*?\]
+    reveal_type(result)
+    data = result.all()
+
+    # EXPECTED_RE_TYPE: typing.Sequence\[builtins.str\*?\]
+    reveal_type(data)
+
+
+async def t_async_session_execute_multi() -> None:
+    result = (await async_session.execute(multi_stmt)).t
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[Tuple\[builtins.int\*?, builtins.str\*?\]\]
+    reveal_type(result)
+    row = result.one()
+
+    # EXPECTED_RE_TYPE: Tuple\[builtins.int\*?, builtins.str\*?\]
+    reveal_type(row)
+
+    x, y = row
+
+    # EXPECTED_RE_TYPE: builtins.int\*?
+    reveal_type(x)
+
+    # EXPECTED_RE_TYPE: builtins.str\*?
+    reveal_type(y)
+
+
+async def t_async_session_execute_single() -> None:
+    result = (await async_session.execute(single_stmt)).t
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[Tuple\[builtins.str\*?\]\]
+    reveal_type(result)
+    row = result.one()
+
+    # EXPECTED_RE_TYPE: Tuple\[builtins.str\*?\]
+    reveal_type(row)
+
+    (x,) = row
+
+    # EXPECTED_RE_TYPE: builtins.str\*?
+    reveal_type(x)
+
+
+async def t_async_session_scalar() -> None:
+    obj = await async_session.scalar(single_stmt)
+
+    # EXPECTED_RE_TYPE: Union\[builtins.str\*?, None\]
+    reveal_type(obj)
+
+
+async def t_async_session_scalars() -> None:
+    result = await async_session.scalars(single_stmt)
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*ScalarResult\*?\[builtins.str\*?\]
+    reveal_type(result)
+    data = result.all()
+
+    # EXPECTED_RE_TYPE: typing.Sequence\[builtins.str\*?\]
+    reveal_type(data)
+
+
+async def t_async_connection_stream_multi() -> None:
+    result = (await async_connection.stream(multi_stmt)).t
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*AsyncTupleResult\[Tuple\[builtins.int\*?, builtins.str\*?\]\]
+    reveal_type(result)
+    row = await result.one()
+
+    # EXPECTED_RE_TYPE: Tuple\[builtins.int\*?, builtins.str\*?\]
+    reveal_type(row)
+
+    x, y = row
+
+    # EXPECTED_RE_TYPE: builtins.int\*?
+    reveal_type(x)
+
+    # EXPECTED_RE_TYPE: builtins.str\*?
+    reveal_type(y)
+
+
+async def t_async_connection_stream_single() -> None:
+    result = (await async_connection.stream(single_stmt)).t
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*AsyncTupleResult\[Tuple\[builtins.str\*?\]\]
+    reveal_type(result)
+    row = await result.one()
+
+    # EXPECTED_RE_TYPE: Tuple\[builtins.str\*?\]
+    reveal_type(row)
+
+    (x,) = row
+
+    # EXPECTED_RE_TYPE: builtins.str\*?
+    reveal_type(x)
+
+
+async def t_async_connection_stream_scalars() -> None:
+    result = await async_connection.stream_scalars(single_stmt)
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*AsyncScalarResult\*?\[builtins.str\*?\]
+    reveal_type(result)
+    data = await result.all()
+
+    # EXPECTED_RE_TYPE: typing.Sequence\*?\[builtins.str\*?\]
+    reveal_type(data)
+
+
+async def t_async_session_stream_multi() -> None:
+    result = (await async_session.stream(multi_stmt)).t
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*TupleResult\[Tuple\[builtins.int\*?, builtins.str\*?\]\]
+    reveal_type(result)
+    row = await result.one()
+
+    # EXPECTED_RE_TYPE: Tuple\[builtins.int\*?, builtins.str\*?\]
+    reveal_type(row)
+
+    x, y = row
+
+    # EXPECTED_RE_TYPE: builtins.int\*?
+    reveal_type(x)
+
+    # EXPECTED_RE_TYPE: builtins.str\*?
+    reveal_type(y)
+
+
+async def t_async_session_stream_single() -> None:
+    result = (await async_session.stream(single_stmt)).t
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*AsyncTupleResult\[Tuple\[builtins.str\*?\]\]
+    reveal_type(result)
+    row = await result.one()
+
+    # EXPECTED_RE_TYPE: Tuple\[builtins.str\*?\]
+    reveal_type(row)
+
+    (x,) = row
+
+    # EXPECTED_RE_TYPE: builtins.str\*?
+    reveal_type(x)
+
+
+async def t_async_session_stream_scalars() -> None:
+    result = await async_session.stream_scalars(single_stmt)
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*AsyncScalarResult\*?\[builtins.str\*?\]
+    reveal_type(result)
+    data = await result.all()
+
+    # EXPECTED_RE_TYPE: typing.Sequence\*?\[builtins.str\*?\]
+    reveal_type(data)
index 9928b5a335d7ea523932049f9ca1f38b7a0f1bd5..8ad69dbd0f462717759cbb223e43a29425045558 100644 (file)
@@ -66,5 +66,5 @@ class Address:
         _mypy_mapped_attrs = [id, user_id, email_address]
 
 
-stmt = select(User.name).where(User.id.in_([1, 2, 3]))
-stmt = select(Address).where(Address.email_address.contains(["foo"]))
+stmt1 = select(User.name).where(User.id.in_([1, 2, 3]))
+stmt2 = select(Address).where(Address.email_address.contains(["foo"]))
index 3a932021de25e2d2966d06d4c0b1f2b94975a4be..1086f187af451ad8843c5b3d6287879967e2ae97 100644 (file)
@@ -233,6 +233,41 @@ class MypyPluginTest(fixtures.TestBase):
 
                     expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(4))
                     if is_type:
+                        if not is_re:
+                            # the goal here is that we can cut-and-paste
+                            # from vscode -> pylance into the
+                            # EXPECTED_TYPE: line, then the test suite will
+                            # validate that line against what mypy produces
+                            expected_msg = re.sub(
+                                r"([\[\]])",
+                                lambda m: rf"\{m.group(0)}",
+                                expected_msg,
+                            )
+
+                            # note: require any preceding text to end in
+                            # a dot, so that an expected "Select" does
+                            # not match "TypedSelect"
+                            expected_msg = re.sub(
+                                r"([\w_]+)",
+                                lambda m: rf"(?:.*\.)?{m.group(1)}\*?",
+                                expected_msg,
+                            )
+
+                            expected_msg = re.sub(
+                                "List", "builtins.list", expected_msg
+                            )
+
+                            expected_msg = re.sub(
+                                r"(int|str|float|bool)",
+                                lambda m: rf"builtins.{m.group(0)}\*?",
+                                expected_msg,
+                            )
+                            # expected_msg = re.sub(
+                            #     r"(Sequence|Tuple|List|Union)",
+                            #     lambda m: fr"typing.{m.group(0)}\*?",
+                            #     expected_msg,
+                            # )
+
                         is_mypy = is_re = True
                         expected_msg = f'Revealed type is "{expected_msg}"'
                     current_assert_messages.append(
@@ -295,7 +330,7 @@ class MypyPluginTest(fixtures.TestBase):
                 del output[idx]
 
             if output:
-                print("messages from mypy that were not consumed:")
+                print(f"{len(output)} messages from mypy were not consumed:")
                 print("\n".join(msg for _, msg in output))
                 assert False, "errors and/or notes remain, see stdout"
 
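
To make the cut-and-paste behavior above concrete, here is a rough standalone
sketch of the same substitution chain (omitting the List handling) applied to
a plain EXPECTED_TYPE string; the revealed-type value is a representative
mypy output, not taken from this commit:

    import re

    expected = "Select[Tuple[int, str]]"

    # escape brackets, allow an optional module prefix and trailing "*" on
    # each word, then qualify builtin names, mirroring the re.sub calls in
    # the test harness above
    pattern = re.sub(r"([\[\]])", lambda m: rf"\{m.group(0)}", expected)
    pattern = re.sub(
        r"([\w_]+)", lambda m: rf"(?:.*\.)?{m.group(1)}\*?", pattern
    )
    pattern = re.sub(
        r"(int|str|float|bool)",
        lambda m: rf"builtins.{m.group(0)}\*?",
        pattern,
    )

    revealed = (
        "sqlalchemy.sql.selectable.Select"
        "[Tuple[builtins.int, builtins.str]]"
    )
    assert re.match(pattern, revealed)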
index 9585da125b41a538821e1e5fdfda301d61971bd6..01a8698a4a88dedabe9ba8e3815562c2dea513a7 100644 (file)
@@ -492,7 +492,7 @@ class EntityFromSubqueryTest(QueryTest, AssertsCompiledSQL):
 
         assert_raises_message(
             sa_exc.ArgumentError,
-            "Column expression or FROM clause expected, got "
+            "Column expression, FROM clause, or other .* expected, got "
             "<sqlalchemy.sql.selectable.Select .*> object resolved from "
             "<AliasedClass .* User> object. To create a FROM clause from "
             "a <class 'sqlalchemy.sql.selectable.Select'> object",
index 55414364c5439b357071ed5690ca1f4722cc007f..8374d05d4db0ae6ea1752a936f90f2dcad2c7346 100644 (file)
@@ -186,6 +186,14 @@ class OnlyReturnTuplesTest(QueryTest):
         assert isinstance(row, collections_abc.Sequence)
         assert isinstance(row._mapping, collections_abc.Mapping)
 
+    def test_single_entity_tuples(self):
+        User = self.classes.User
+        query = fixture_session().query(User).tuples()
+        is_false(query.is_single_entity)
+        row = query.first()
+        assert isinstance(row, collections_abc.Sequence)
+        assert isinstance(row._mapping, collections_abc.Mapping)
+
     def test_multiple_entity_false(self):
         User = self.classes.User
         query = (
@@ -204,6 +212,14 @@ class OnlyReturnTuplesTest(QueryTest):
         assert isinstance(row, collections_abc.Sequence)
         assert isinstance(row._mapping, collections_abc.Mapping)
 
+    def test_multiple_entity_true(self):
+        User = self.classes.User
+        query = fixture_session().query(User.id, User).tuples()
+        is_false(query.is_single_entity)
+        row = query.first()
+        assert isinstance(row, collections_abc.Sequence)
+        assert isinstance(row._mapping, collections_abc.Mapping)
+
 
 class RowTupleTest(QueryTest, AssertsCompiledSQL):
     run_setup_mappers = None
index b175f96633e4ffe349a9af27d282b87802944a33..1a070fcf52258764c34edb0bdf69261d882aadba 100644 (file)
@@ -22,6 +22,7 @@ from sqlalchemy import Integer
 from sqlalchemy import MetaData
 from sqlalchemy import PrimaryKeyConstraint
 from sqlalchemy import schema
+from sqlalchemy import select
 from sqlalchemy import Sequence
 from sqlalchemy import String
 from sqlalchemy import Table
@@ -1337,6 +1338,20 @@ class ToMetaDataTest(fixtures.TestBase, AssertsCompiledSQL, ComparesTables):
 
         self._assert_fk(t2, None, "p.t1.x", referred_schema_fn=ref_fn)
 
+    def test_fk_get_referent_is_always_a_column(self):
+        """test the annotation on ForeignKey.get_referent() in that it does
+        in fact return Column even if given a labeled expr in a subquery"""
+
+        m = MetaData()
+        a = Table("a", m, Column("id", Integer, primary_key=True))
+        b = Table("b", m, Column("aid", Integer, ForeignKey("a.id")))
+
+        stmt = select(a.c.id.label("somelabel")).subquery()
+
+        referent = list(b.c.aid.foreign_keys)[0].get_referent(stmt)
+        is_(referent, stmt.c.somelabel)
+        assert isinstance(referent, Column)
+
     def test_copy_info(self):
         m = MetaData()
         fk = ForeignKey("t2.id")
index ff70fc184aeba4191cf612de9ac80ddd78d85d72..cb9f930180e795ef2e131688a58cd3172b113044 100644 (file)
@@ -156,6 +156,50 @@ class CursorResultTest(fixtures.TablesTest):
             rows.append(row)
         eq_(len(rows), 3)
 
+    def test_scalars(self, connection):
+        users = self.tables.users
+
+        connection.execute(
+            users.insert(),
+            [
+                {"user_id": 7, "user_name": "jack"},
+                {"user_id": 8, "user_name": "ed"},
+                {"user_id": 9, "user_name": "fred"},
+            ],
+        )
+        r = connection.scalars(users.select().order_by(users.c.user_id))
+        eq_(r.all(), [7, 8, 9])
+
+    def test_result_tuples(self, connection):
+        users = self.tables.users
+
+        connection.execute(
+            users.insert(),
+            [
+                {"user_id": 7, "user_name": "jack"},
+                {"user_id": 8, "user_name": "ed"},
+                {"user_id": 9, "user_name": "fred"},
+            ],
+        )
+        r = connection.execute(
+            users.select().order_by(users.c.user_id)
+        ).tuples()
+        eq_(r.all(), [(7, "jack"), (8, "ed"), (9, "fred")])
+
+    def test_row_tuple(self, connection):
+        users = self.tables.users
+
+        connection.execute(
+            users.insert(),
+            [
+                {"user_id": 7, "user_name": "jack"},
+                {"user_id": 8, "user_name": "ed"},
+                {"user_id": 9, "user_name": "fred"},
+            ],
+        )
+        r = connection.execute(users.select().order_by(users.c.user_id))
+        eq_([row.t for row in r], [(7, "jack"), (8, "ed"), (9, "fred")])
+
     def test_row_next(self, connection):
         users = self.tables.users
 
index be64e205e46b1649dd55004725542cb3fdf94390..d91e50e6373271c317dadafee023ff96ad51d844 100644 (file)
@@ -61,7 +61,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_old_bracket_style_fail(self):
         with expect_raises_message(
             exc.ArgumentError,
-            r"Column expression or FROM clause expected, "
+            r"Column expression, FROM clause, or other columns clause .*"
             r".*Did you mean to say",
         ):
             select([table1.c.myid])
index 619cbd863b185a59c482bd02e1ae6d3bdb65c7c6..e93900bbda9932f841058dede9d5673d16841eda 100644 (file)
@@ -25,6 +25,7 @@ from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.schema import Column
@@ -1016,6 +1017,46 @@ class UpdateFromCompileTest(
             dialect="mysql",
         )
 
+    def test_update_from_join_unsupported_cases(self):
+        """
+        found_during_typing
+
+        It's unclear how to cleanly guard against this case without producing
+        false positives, particularly due to the support for UPDATE
+        of a CTE.  I'm also not sure of the nature of the failure and why
+        it happens this way.
+
+        """
+        users, addresses = self.tables.users, self.tables.addresses
+
+        j = users.join(addresses)
+
+        with expect_raises_message(
+            exc.CompileError,
+            r"Encountered unsupported case when compiling an INSERT or UPDATE "
+            r"statement.  If this is a multi-table "
+            r"UPDATE statement, please provide string-named arguments to the "
+            r"values\(\) method with distinct names; support for multi-table "
+            r"UPDATE statements that "
+            r"target multiple tables for UPDATE is very limited",
+        ):
+            update(j).where(addresses.c.email_address == "e1").values(
+                {users.c.id: 10, addresses.c.email_address: "asdf"}
+            ).compile(dialect=mysql.dialect())
+
+        with expect_raises_message(
+            exc.CompileError,
+            r"Encountered unsupported case when compiling an INSERT or UPDATE "
+            r"statement.  If this is a multi-table "
+            r"UPDATE statement, please provide string-named arguments to the "
+            r"values\(\) method with distinct names; support for multi-table "
+            r"UPDATE statements that "
+            r"target multiple tables for UPDATE is very limited",
+        ):
+            update(j).where(addresses.c.email_address == "e1").compile(
+                dialect=mysql.dialect()
+            )
+
     def test_update_from_join_mysql_whereclause(self):
         users, addresses = self.tables.users, self.tables.addresses
 
index ffc470972f51dc6dfc0cbd79686b5fb2db8245a3..91a8918824c322eee399f0f8f8ac63b6b44555b7 100644 (file)
@@ -31,6 +31,12 @@ A similar approach is used in Alembic where a dynamic approach towards creating
 alembic "ops" was enhanced to generate a .pyi stubs file statically for
 consumption by typing tools.
 
+Note that the usual OO approach of having a common interface class with
+concrete subtypes doesn't really solve any problems here; the concrete subtypes
+must still list out all methods, arguments, typing annotations, and docstrings,
+all of which is copied by this script rather than requiring it all be
+typed by hand.
+
 .. versionadded:: 2.0
 
 """
@@ -43,9 +49,7 @@ import inspect
 import os
 from pathlib import Path
 import re
-import shlex
 import shutil
-import subprocess
 import sys
 from tempfile import NamedTemporaryFile
 import textwrap
@@ -61,6 +65,7 @@ from typing import TypeVar
 from sqlalchemy import util
 from sqlalchemy.util import compat
 from sqlalchemy.util import langhelpers
+from sqlalchemy.util.langhelpers import console_scripts
 from sqlalchemy.util.langhelpers import format_argspec_plus
 from sqlalchemy.util.langhelpers import inject_docstring_text
 
@@ -122,6 +127,47 @@ def create_proxy_methods(
     return decorate
 
 
+def _grab_overloads(fn):
+    """grab @overload entries for a function, assuming black-formatted
+    code ;) so that we can do a simple regex
+
+    """
+
+    # functions that use @util.deprecated and whatnot will have a string
+    # generated fn.  we can look at __wrapped__ but these functions don't
+    # have any overloads in any case right now so skip
+    if fn.__code__.co_filename == "<string>":
+        return []
+
+    with open(fn.__code__.co_filename) as f:
+        lines = [l for i, l in zip(range(fn.__code__.co_firstlineno), f)]
+
+        lines.reverse()
+
+    output = []
+
+    current_ov = []
+    for line in lines[1:]:
+        current_ov.append(line)
+        outside_block_match = re.match(r"^\w", line)
+        if outside_block_match:
+            current_ov[:] = []
+            break
+
+        fn_match = re.match(rf"^    (?:async )?def (.*)\($", line)
+        if fn_match and fn_match.group(1) != fn.__name__:
+            current_ov[:] = []
+            break
+
+        ov_match = re.match(r"^    @overload$", line)
+        if ov_match:
+            output.append("".join(reversed(current_ov)))
+            current_ov[:] = []
+
+    output.reverse()
+    return output
+
+
 def process_class(
     buf: TextIO,
     target_cls: Type[Any],
@@ -145,6 +191,12 @@ def process_class(
 
     def instrument(buf: TextIO, name: str, clslevel: bool = False) -> None:
         fn = getattr(target_cls, name)
+
+        overloads = _grab_overloads(fn)
+
+        for overload in overloads:
+            buf.write(overload)
+
         spec = compat.inspect_getfullargspec(fn)
 
         iscoroutine = inspect.iscoroutinefunction(fn)
@@ -311,7 +363,7 @@ def process_module(modname: str, filename: str) -> str:
                     "\n    # code within this block is "
                     "**programmatically, \n"
                     "    # statically generated** by"
-                    " tools/generate_proxy_methods.py\n\n"
+                    f" tools/{os.path.basename(__file__)}\n\n"
                 )
 
                 process_class(buf, *args)
@@ -323,41 +375,6 @@ def process_module(modname: str, filename: str) -> str:
     return buf.name
 
 
-def console_scripts(
-    path: str, options: dict, ignore_output: bool = False
-) -> None:
-
-    entrypoint_name = options["entrypoint"]
-
-    for entry in compat.importlib_metadata_get("console_scripts"):
-        if entry.name == entrypoint_name:
-            impl = entry
-            break
-    else:
-        raise Exception(
-            f"Could not find entrypoint console_scripts.{entrypoint_name}"
-        )
-    cmdline_options_str = options.get("options", "")
-    cmdline_options_list = shlex.split(cmdline_options_str, posix=is_posix) + [
-        path
-    ]
-
-    kw = {}
-    if ignore_output:
-        kw["stdout"] = kw["stderr"] = subprocess.DEVNULL
-
-    subprocess.run(
-        [
-            sys.executable,
-            "-c",
-            "import %s; %s.%s()" % (impl.module, impl.module, impl.attr),
-        ]
-        + cmdline_options_list,
-        cwd=Path(__file__).parent.parent,
-        **kw,
-    )
-
-
 def run_module(modname, stdout):
 
     sys.stderr.write(f"importing module {modname}\n")
diff --git a/tools/generate_tuple_map_overloads.py b/tools/generate_tuple_map_overloads.py
new file mode 100644 (file)
index 0000000..ff8f378
--- /dev/null
@@ -0,0 +1,174 @@
+r"""Generate tuple mapping overloads.
+
+the problem solved by this script is that there's no way in current
+pep-484 typing to unpack \*args: _T into Tuple[_T].  pep-646 is the first
+pep to provide this, but it doesn't work for the actual Tuple class
+and also mypy does not have support for pep-646 as of yet.  Better pep-646
+support would allow us to use a TypeVarTuple with Unpack, but TypeVarTuple
+does not have support for sequence operations like ``__getitem__`` and
+iteration; there's also no way for TypeVarTuple to be translated back to a
+Tuple which does have those things without a combinatoric hardcoding approach
+to each length of tuple.
+
+So here, the script creates a map from `*args` to a Tuple directly using a
+combinatoric generated code approach.
+
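+As a rough illustration (the actual function name, return type and
+argument counts come from marker comments in the target modules), a
+generated two-argument overload for ``select()`` would look like::
+
+    @overload
+    def select(
+        __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+    ) -> Select[Tuple[_T0, _T1]]:
+        ...
+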
+.. versionadded:: 2.0
+
+"""
+from __future__ import annotations
+
+from argparse import ArgumentParser
+import importlib
+import os
+from pathlib import Path
+import re
+import shutil
+import sys
+from tempfile import NamedTemporaryFile
+import textwrap
+
+from sqlalchemy.util.langhelpers import console_scripts
+
+is_posix = os.name == "posix"
+
+
+sys.path.append(str(Path(__file__).parent.parent))
+
+
+def process_module(modname: str, filename: str) -> str:
+
+    # use tempfile in same path as the module, or at least in the
+    # current working directory, so that black / zimports use
+    # local pyproject.toml
+    with NamedTemporaryFile(
+        mode="w", delete=False, suffix=".py", dir=Path(filename).parent
+    ) as buf, open(filename) as orig_py:
+        indent = ""
+        in_block = False
+        current_fnname = given_fnname = None
+        for line in orig_py:
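+            # a marker line looks roughly like (illustrative):
+            #   # START OVERLOADED FUNCTIONS select Select 1-10
+            # i.e. function name (optionally "self."-prefixed), the
+            # return type, and the min/max argument counts to generate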
+            m = re.match(
+                r"^( *)# START OVERLOADED FUNCTIONS ([\.\w_]+) ([\w_]+) (\d+)-(\d+)$",  # noqa: E501
+                line,
+            )
+            if m:
+                indent = m.group(1)
+                given_fnname = current_fnname = m.group(2)
+                if current_fnname.startswith("self."):
+                    use_self = True
+                    current_fnname = current_fnname.split(".")[1]
+                else:
+                    use_self = False
+                return_type = m.group(3)
+                start_index = int(m.group(4))
+                end_index = int(m.group(5))
+
+                sys.stderr.write(
+                    f"Generating {start_index}-{end_index} overloads "
+                    f"attributes for "
+                    f"class {'self.' if use_self else ''}{current_fnname} "
+                    f"-> {return_type}\n"
+                )
+                in_block = True
+                buf.write(line)
+                buf.write(
+                    "\n    # code within this block is "
+                    "**programmatically, \n"
+                    "    # statically generated** by"
+                    f" tools/{os.path.basename(__file__)}\n\n"
+                )
+
+                for num_args in range(start_index, end_index + 1):
+                    combinations = [
+                        [
+                            f"__ent{arg}: _TCCA[_T{arg}]"
+                            for arg in range(num_args)
+                        ]
+                    ]
+                    for combination in combinations:
+                        buf.write(
+                            textwrap.indent(
+                                f"""
+@overload
+def {current_fnname}(
+    {'self, ' if use_self else ''}{", ".join(combination)}
+) -> {return_type}[Tuple[{', '.join(f'_T{i}' for i in range(num_args))}]]:
+    ...
+
+""",  # noqa: E501
+                                indent,
+                            )
+                        )
+
+            if in_block and line.startswith(
+                f"{indent}# END OVERLOADED FUNCTIONS {given_fnname}"
+            ):
+                in_block = False
+
+            if not in_block:
+                buf.write(line)
+    return buf.name
+
+
+def run_module(modname, stdout):
+
+    sys.stderr.write(f"importing module {modname}\n")
+    mod = importlib.import_module(modname)
+    filename = destination_path = mod.__file__
+    assert filename is not None
+
+    tempfile = process_module(modname, filename)
+
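+    # when emitting to stdout, suppress the output of the zimports /
+    # black subprocesses so that only the generated module is printed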
+    ignore_output = stdout
+
+    console_scripts(
+        str(tempfile),
+        {"entrypoint": "zimports"},
+        ignore_output=ignore_output,
+    )
+
+    console_scripts(
+        str(tempfile),
+        {"entrypoint": "black"},
+        ignore_output=ignore_output,
+    )
+
+    if stdout:
+        with open(tempfile) as tf:
+            print(tf.read())
+        os.unlink(tempfile)
+    else:
+        sys.stderr.write(f"Writing {destination_path}...\n")
+        shutil.move(tempfile, destination_path)
+
+
+def main(args):
+    for modname in entries:
+        if args.module in {"all", modname}:
+            run_module(modname, args.stdout)
+
+
+entries = [
+    "sqlalchemy.sql._selectable_constructors",
+    "sqlalchemy.orm.session",
+    "sqlalchemy.orm.query",
+    "sqlalchemy.sql.selectable",
+    "sqlalchemy.sql.dml",
+]
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--module",
+        choices=entries + ["all"],
+        default="all",
+        help="Which file to generate. Default is to regenerate all files",
+    )
+    parser.add_argument(
+        "--stdout",
+        action="store_true",
+        help="Write to stdout instead of saving to file",
+    )
+    args = parser.parse_args()
+    main(args)