]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
additional mypy strictness
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 10 Mar 2022 16:57:00 +0000 (11:57 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 12 Mar 2022 16:42:50 +0000 (11:42 -0500)
enable type checking within untyped defs.  This allowed
some more internals to be fixed up with assertions etc.

some internals that were unnecessary or not even used
at all were removed.  BaseCursorResult was no longer
necessary since we only have one kind of CursorResult
now.  The different ResultProxy subclasses that had
alternate "strategies" dont appear to be used at all
even in 1.4.x, as there's no code that accesses the
_cursor_strategy_cls attribute, which is also removed.
As these were mostly private constructs that weren't
even functioning correctly in any case,
it's fine to remove these over the 2.0 boundary.

Change-Id: Ifd536987d104b1cd8b546cefdbd5c1e5d1801082

19 files changed:
lib/sqlalchemy/__init__.py
lib/sqlalchemy/engine/__init__.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/engine/result.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/types.py
lib/sqlalchemy/util/_collections.py
lib/sqlalchemy/util/_py_collections.py
lib/sqlalchemy/util/langhelpers.py
lib/sqlalchemy/util/typing.py
pyproject.toml
test/engine/test_execute.py
test/sql/test_types.py

index 7ceb33c7ca0505ae43d66cd6397b85f42a94f271..de01a1b46104727ee29aa62d7ab75f976db7cf7d 100644 (file)
@@ -9,12 +9,8 @@ from __future__ import annotations
 
 from . import util as _util
 from .engine import AdaptedConnection as AdaptedConnection
-from .engine import BaseCursorResult as BaseCursorResult
 from .engine import BaseRow as BaseRow
 from .engine import BindTyping as BindTyping
-from .engine import BufferedColumnResultProxy as BufferedColumnResultProxy
-from .engine import BufferedColumnRow as BufferedColumnRow
-from .engine import BufferedRowResultProxy as BufferedRowResultProxy
 from .engine import ChunkedIteratorResult as ChunkedIteratorResult
 from .engine import Compiled as Compiled
 from .engine import Connection as Connection
@@ -28,7 +24,6 @@ from .engine import engine_from_config as engine_from_config
 from .engine import ExceptionContext as ExceptionContext
 from .engine import ExecutionContext as ExecutionContext
 from .engine import FrozenResult as FrozenResult
-from .engine import FullyBufferedResultProxy as FullyBufferedResultProxy
 from .engine import Inspector as Inspector
 from .engine import IteratorResult as IteratorResult
 from .engine import make_url as make_url
index 32f3f2eccd3f79bd29d9ecef32bbf89f2f7539f3..29dd6aff90489e1cad37541af1803dd01d81405b 100644 (file)
@@ -25,12 +25,7 @@ from .base import Transaction as Transaction
 from .base import TwoPhaseTransaction as TwoPhaseTransaction
 from .create import create_engine as create_engine
 from .create import engine_from_config as engine_from_config
-from .cursor import BaseCursorResult as BaseCursorResult
-from .cursor import BufferedColumnResultProxy as BufferedColumnResultProxy
-from .cursor import BufferedColumnRow as BufferedColumnRow
-from .cursor import BufferedRowResultProxy as BufferedRowResultProxy
 from .cursor import CursorResult as CursorResult
-from .cursor import FullyBufferedResultProxy as FullyBufferedResultProxy
 from .cursor import ResultProxy as ResultProxy
 from .interfaces import AdaptedConnection as AdaptedConnection
 from .interfaces import BindTyping as BindTyping
index 37faa880ecd3ae53bc70600d8c68bb41f50091e7..d8009e26c606b6ffcf67281dbb5a689f76b2166b 100644 (file)
@@ -23,6 +23,7 @@ from typing import Tuple
 from typing import Type
 from typing import Union
 
+from .interfaces import _IsolationLevel
 from .interfaces import BindTyping
 from .interfaces import ConnectionEventsTarget
 from .interfaces import DBAPICursor
@@ -510,7 +511,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
             self._handle_dbapi_exception(e, None, None, None, None)
 
     @property
-    def default_isolation_level(self) -> str:
+    def default_isolation_level(self) -> Optional[_IsolationLevel]:
         """The default isolation level assigned to this
         :class:`_engine.Connection`.
 
index 78805bac1be4e40f1958b8bc3146887d88926997..821c0cb8e3334b2bd23d410a58b8891956e10588 100644 (file)
@@ -23,8 +23,9 @@ from typing import List
 from typing import Optional
 from typing import Sequence
 from typing import Tuple
-from typing import Type
+from typing import Union
 
+from .result import MergedResult
 from .result import Result
 from .result import ResultMetaData
 from .result import SimpleResultMetaData
@@ -36,10 +37,12 @@ from ..sql import elements
 from ..sql import sqltypes
 from ..sql import util as sql_util
 from ..sql.base import _generative
+from ..sql.compiler import ResultColumnsEntry
 from ..sql.compiler import RM_NAME
 from ..sql.compiler import RM_OBJECTS
 from ..sql.compiler import RM_RENDERED_NAME
 from ..sql.compiler import RM_TYPE
+from ..sql.type_api import TypeEngine
 from ..util import compat
 from ..util.typing import Literal
 
@@ -101,6 +104,7 @@ class CursorResultMetaData(ResultMetaData):
     _keymap_by_result_column_idx: Optional[Dict[int, _KeyMapRecType]]
     _unpickled: bool
     _safe_for_cache: bool
+    _translated_indexes: Optional[List[int]]
 
     returns_rows: ClassVar[bool] = True
 
@@ -123,7 +127,6 @@ class CursorResultMetaData(ResultMetaData):
 
         if self._translated_indexes:
             indexes = [self._translated_indexes[idx] for idx in indexes]
-
         tup = tuplegetter(*indexes)
 
         new_metadata = self.__class__.__new__(self.__class__)
@@ -526,7 +529,7 @@ class CursorResultMetaData(ResultMetaData):
     def _merge_textual_cols_by_position(
         self, context, cursor_description, result_columns
     ):
-        num_ctx_cols = len(result_columns) if result_columns else None
+        num_ctx_cols = len(result_columns)
 
         if num_ctx_cols > len(cursor_description):
             util.warn(
@@ -568,6 +571,8 @@ class CursorResultMetaData(ResultMetaData):
         match_map = self._create_description_match_map(
             result_columns, loose_column_name_matching
         )
+        mapped_type: TypeEngine[Any]
+
         for (
             idx,
             colname,
@@ -597,15 +602,17 @@ class CursorResultMetaData(ResultMetaData):
     @classmethod
     def _create_description_match_map(
         cls,
-        result_columns,
-        loose_column_name_matching=False,
-    ):
+        result_columns: List[ResultColumnsEntry],
+        loose_column_name_matching: bool = False,
+    ) -> Dict[Union[str, object], Tuple[str, List[Any], TypeEngine[Any], int]]:
         """when matching cursor.description to a set of names that are present
         in a Compiled object, as is the case with TextualSelect, get all the
         names we expect might match those in cursor.description.
         """
 
-        d = {}
+        d: Dict[
+            Union[str, object], Tuple[str, List[Any], TypeEngine[Any], int]
+        ] = {}
         for ridx, elem in enumerate(result_columns):
             key = elem[RM_RENDERED_NAME]
 
@@ -630,7 +637,6 @@ class CursorResultMetaData(ResultMetaData):
                         r_key,
                         (elem[RM_NAME], elem[RM_OBJECTS], elem[RM_TYPE], ridx),
                     )
-
         return d
 
     def _merge_cols_by_none(self, context, cursor_description):
@@ -739,7 +745,9 @@ class CursorResultMetaData(ResultMetaData):
         self._keys = state["_keys"]
         self._unpickled = True
         if state["_translated_indexes"]:
-            self._translated_indexes = state["_translated_indexes"]
+            self._translated_indexes = cast(
+                "List[int]", state["_translated_indexes"]
+            )
             self._tuplefilter = tuplegetter(*self._translated_indexes)
         else:
             self._translated_indexes = self._tuplefilter = None
@@ -1144,12 +1152,32 @@ class _NoResultMetaData(ResultMetaData):
 _NO_RESULT_METADATA = _NoResultMetaData()
 
 
-class BaseCursorResult:
-    """Base class for database result objects."""
+class CursorResult(Result):
+    """A Result that is representing state from a DBAPI cursor.
+
+    .. versionchanged:: 1.4  The :class:`.CursorResult``
+       class replaces the previous :class:`.ResultProxy` interface.
+       This classes are based on the :class:`.Result` calling API
+       which provides an updated usage model and calling facade for
+       SQLAlchemy Core and SQLAlchemy ORM.
+
+    Returns database rows via the :class:`.Row` class, which provides
+    additional API features and behaviors on top of the raw data returned by
+    the DBAPI.   Through the use of filters such as the :meth:`.Result.scalars`
+    method, other kinds of objects may also be returned.
+
+    .. seealso::
+
+        :ref:`coretutorial_selecting` - introductory material for accessing
+        :class:`_engine.CursorResult` and :class:`.Row` objects.
 
-    _metadata: ResultMetaData
+    """
+
+    _metadata: Union[CursorResultMetaData, _NoResultMetaData]
+    _no_result_metadata = _NO_RESULT_METADATA
     _soft_closed: bool = False
     closed: bool = False
+    _is_cursor = True
 
     def __init__(self, context, cursor_strategy, cursor_description):
         self.context = context
@@ -1169,11 +1197,11 @@ class BaseCursorResult:
             if echo:
                 log = self.context.connection._log_debug
 
-                def log_row(row):
+                def _log_row(row):
                     log("Row %r", sql_util._repr_row(row))
                     return row
 
-                self._row_logging_fn = log_row
+                self._row_logging_fn = log_row = _log_row
             else:
                 log_row = None
 
@@ -1188,13 +1216,16 @@ class BaseCursorResult:
             )
             if log_row:
 
-                def make_row(row):
+                def _make_row_2(row):
                     made_row = _make_row(row)
+                    assert log_row is not None
                     log_row(made_row)
                     return made_row
 
+                make_row = _make_row_2
             else:
                 make_row = _make_row
+
             self._set_memoized_attribute("_row_getter", make_row)
 
         else:
@@ -1208,7 +1239,7 @@ class BaseCursorResult:
             if compiled._cached_metadata:
                 metadata = compiled._cached_metadata
             else:
-                metadata = self._cursor_metadata(self, cursor_description)
+                metadata = CursorResultMetaData(self, cursor_description)
                 if metadata._safe_for_cache:
                     compiled._cached_metadata = metadata
 
@@ -1239,7 +1270,7 @@ class BaseCursorResult:
             self._metadata = metadata
 
         else:
-            self._metadata = metadata = self._cursor_metadata(
+            self._metadata = metadata = CursorResultMetaData(
                 self, cursor_description
             )
         if self._echo:
@@ -1669,33 +1700,6 @@ class BaseCursorResult:
         """
         return self.context.isinsert
 
-
-class CursorResult(BaseCursorResult, Result):
-    """A Result that is representing state from a DBAPI cursor.
-
-    .. versionchanged:: 1.4  The :class:`.CursorResult``
-       class replaces the previous :class:`.ResultProxy` interface.
-       This classes are based on the :class:`.Result` calling API
-       which provides an updated usage model and calling facade for
-       SQLAlchemy Core and SQLAlchemy ORM.
-
-    Returns database rows via the :class:`.Row` class, which provides
-    additional API features and behaviors on top of the raw data returned by
-    the DBAPI.   Through the use of filters such as the :meth:`.Result.scalars`
-    method, other kinds of objects may also be returned.
-
-    .. seealso::
-
-        :ref:`coretutorial_selecting` - introductory material for accessing
-        :class:`_engine.CursorResult` and :class:`.Row` objects.
-
-    """
-
-    _cursor_metadata: Type[ResultMetaData] = CursorResultMetaData
-    _cursor_strategy_cls = CursorFetchStrategy
-    _no_result_metadata = _NO_RESULT_METADATA
-    _is_cursor = True
-
     def _fetchiter_impl(self):
         fetchone = self.cursor_strategy.fetchone
 
@@ -1717,12 +1721,13 @@ class CursorResult(BaseCursorResult, Result):
     def _raw_row_iterator(self):
         return self._fetchiter_impl()
 
-    def merge(self, *others):
-        merged_result = super(CursorResult, self).merge(*others)
+    def merge(self, *others: Result) -> MergedResult:
+        merged_result = super().merge(*others)
         setup_rowcounts = not self._metadata.returns_rows
         if setup_rowcounts:
             merged_result.rowcount = sum(
-                result.rowcount for result in (self,) + others
+                cast(CursorResult, result).rowcount
+                for result in (self,) + others
             )
         return merged_result
 
@@ -1756,40 +1761,3 @@ class CursorResult(BaseCursorResult, Result):
 
 
 ResultProxy = CursorResult
-
-
-class BufferedRowResultProxy(ResultProxy):
-    """A ResultProxy with row buffering behavior.
-
-    .. deprecated::  1.4 this class is now supplied using a strategy object.
-       See :class:`.BufferedRowCursorFetchStrategy`.
-
-    """
-
-    _cursor_strategy_cls: Type[
-        CursorFetchStrategy
-    ] = BufferedRowCursorFetchStrategy
-
-
-class FullyBufferedResultProxy(ResultProxy):
-    """A result proxy that buffers rows fully upon creation.
-
-    .. deprecated::  1.4 this class is now supplied using a strategy object.
-       See :class:`.FullyBufferedCursorFetchStrategy`.
-
-    """
-
-    _cursor_strategy_cls = FullyBufferedCursorFetchStrategy
-
-
-class BufferedColumnRow(Row):
-    """Row is now BufferedColumn in all cases"""
-
-
-class BufferedColumnResultProxy(ResultProxy):
-    """A ResultProxy with column buffering behavior.
-
-    .. versionchanged:: 1.4   This is now the default behavior of the Row
-       and this class does not change behavior in any way.
-
-    """
index 0e0c76389a0b467e6f9f80a4c9c7f2f558f1bb16..2579f573c5f2bb90b3e22a38d462b29d926afb53 100644 (file)
@@ -55,6 +55,9 @@ from ..sql.compiler import SQLCompiler
 from ..sql.elements import quoted_name
 
 if typing.TYPE_CHECKING:
+    from .base import Connection
+    from .base import Engine
+    from .characteristics import ConnectionCharacteristic
     from .interfaces import _AnyMultiExecuteParams
     from .interfaces import _CoreMultiExecuteParams
     from .interfaces import _CoreSingleExecuteParams
@@ -62,6 +65,7 @@ if typing.TYPE_CHECKING:
     from .interfaces import _DBAPIMultiExecuteParams
     from .interfaces import _DBAPISingleExecuteParams
     from .interfaces import _ExecuteOptions
+    from .interfaces import _MutableCoreSingleExecuteParams
     from .result import _ProcessorType
     from .row import Row
     from .url import URL
@@ -71,6 +75,7 @@ if typing.TYPE_CHECKING:
     from ..sql import Executable
     from ..sql.compiler import Compiled
     from ..sql.compiler import ResultColumnsEntry
+    from ..sql.compiler import TypeCompiler
     from ..sql.schema import Column
     from ..sql.type_api import TypeEngine
 
@@ -92,7 +97,11 @@ class DefaultDialect(Dialect):
 
     statement_compiler = compiler.SQLCompiler
     ddl_compiler = compiler.DDLCompiler
-    type_compiler = compiler.GenericTypeCompiler  # type: ignore
+    if typing.TYPE_CHECKING:
+        type_compiler: TypeCompiler
+    else:
+        type_compiler = compiler.GenericTypeCompiler
+
     preparer = compiler.IdentifierPreparer
     supports_alter = True
     supports_comments = False
@@ -202,7 +211,7 @@ class DefaultDialect(Dialect):
 
     server_version_info = None
 
-    default_schema_name = None
+    default_schema_name: Optional[str] = None
 
     # indicates symbol names are
     # UPPERCASEd if they are case insensitive
@@ -290,7 +299,12 @@ class DefaultDialect(Dialect):
         self.positional = self.paramstyle in ("qmark", "format", "numeric")
         self.identifier_preparer = self.preparer(self)
         self._on_connect_isolation_level = isolation_level
-        self.type_compiler = self.type_compiler(self)
+
+        tt_callable = cast(
+            Type[compiler.GenericTypeCompiler],
+            self.type_compiler,
+        )
+        self.type_compiler = tt_callable(self)
         if supports_native_boolean is not None:
             self.supports_native_boolean = supports_native_boolean
 
@@ -490,12 +504,14 @@ class DefaultDialect(Dialect):
         opts.update(url.query)
         return [[], opts]
 
-    def set_engine_execution_options(self, engine, opts):
+    def set_engine_execution_options(
+        self, engine: Engine, opts: Mapping[str, str]
+    ) -> None:
         supported_names = set(self.connection_characteristics).intersection(
             opts
         )
         if supported_names:
-            characteristics = util.immutabledict(
+            characteristics: Mapping[str, str] = util.immutabledict(
                 (name, opts[name]) for name in supported_names
             )
 
@@ -505,12 +521,14 @@ class DefaultDialect(Dialect):
                     connection, characteristics
                 )
 
-    def set_connection_execution_options(self, connection, opts):
+    def set_connection_execution_options(
+        self, connection: Connection, opts: Mapping[str, str]
+    ) -> None:
         supported_names = set(self.connection_characteristics).intersection(
             opts
         )
         if supported_names:
-            characteristics = util.immutabledict(
+            characteristics: Mapping[str, str] = util.immutabledict(
                 (name, opts[name]) for name in supported_names
             )
             self._set_connection_characteristics(connection, characteristics)
@@ -800,7 +818,7 @@ class DefaultExecutionContext(ExecutionContext):
     dialect: Dialect
     unicode_statement: str
     cursor: DBAPICursor
-    compiled_parameters: _CoreMultiExecuteParams
+    compiled_parameters: List[_MutableCoreSingleExecuteParams]
     parameters: _DBAPIMultiExecuteParams
     extracted_parameters: _CoreSingleExecuteParams
 
@@ -1157,7 +1175,11 @@ class DefaultExecutionContext(ExecutionContext):
                 parameters = {}
 
         conn._cursor_execute(self.cursor, stmt, parameters, context=self)
-        r = self.cursor.fetchone()[0]
+        row = self.cursor.fetchone()
+        if row is not None:
+            r = row[0]
+        else:
+            r = None
         if type_ is not None:
             # apply type post processors to the result
             proc = type_._cached_result_processor(
@@ -1299,10 +1321,11 @@ class DefaultExecutionContext(ExecutionContext):
 
             result = _cursor.CursorResult(self, strategy, cursor_description)
 
+        compiled = self.compiled
         if (
-            self.compiled
+            compiled
             and not self.isddl
-            and self.compiled.has_out_parameters
+            and cast(SQLCompiler, compiled).has_out_parameters
         ):
             self._setup_out_parameters(result)
 
@@ -1311,10 +1334,11 @@ class DefaultExecutionContext(ExecutionContext):
         return result
 
     def _setup_out_parameters(self, result):
+        compiled = cast(SQLCompiler, self.compiled)
 
         out_bindparams = [
             (param, name)
-            for param, name in self.compiled.bind_names.items()
+            for param, name in compiled.bind_names.items()
             if param.isoutparam
         ]
         out_parameters = {}
@@ -1339,9 +1363,10 @@ class DefaultExecutionContext(ExecutionContext):
         result.out_parameters = out_parameters
 
     def _setup_dml_or_text_result(self):
+        compiled = cast(SQLCompiler, self.compiled)
 
         if self.isinsert:
-            if self.compiled.postfetch_lastrowid:
+            if compiled.postfetch_lastrowid:
                 self.inserted_primary_key_rows = (
                     self._setup_ins_pk_from_lastrowid()
                 )
@@ -1397,7 +1422,8 @@ class DefaultExecutionContext(ExecutionContext):
             result.rowcount
 
             row = result.fetchone()
-            self.returned_default_rows = [row]
+            if row is not None:
+                self.returned_default_rows = [row]
 
             result._soft_close()
 
@@ -1420,13 +1446,17 @@ class DefaultExecutionContext(ExecutionContext):
         return self._setup_ins_pk_from_empty()
 
     def _setup_ins_pk_from_lastrowid(self):
-        getter = self.compiled._inserted_primary_key_from_lastrowid_getter
+        getter = cast(
+            SQLCompiler, self.compiled
+        )._inserted_primary_key_from_lastrowid_getter
 
         lastrowid = self.get_lastrowid()
         return [getter(lastrowid, self.compiled_parameters[0])]
 
     def _setup_ins_pk_from_empty(self):
-        getter = self.compiled._inserted_primary_key_from_lastrowid_getter
+        getter = cast(
+            SQLCompiler, self.compiled
+        )._inserted_primary_key_from_lastrowid_getter
         return [getter(None, param) for param in self.compiled_parameters]
 
     def _setup_ins_pk_from_implicit_returning(self, result, rows):
@@ -1434,7 +1464,9 @@ class DefaultExecutionContext(ExecutionContext):
         if not rows:
             return []
 
-        getter = self.compiled._inserted_primary_key_from_returning_getter
+        getter = cast(
+            SQLCompiler, self.compiled
+        )._inserted_primary_key_from_returning_getter
         compiled_params = self.compiled_parameters
 
         return [
@@ -1443,7 +1475,7 @@ class DefaultExecutionContext(ExecutionContext):
 
     def lastrow_has_defaults(self):
         return (self.isinsert or self.isupdate) and bool(
-            self.compiled.postfetch
+            cast(SQLCompiler, self.compiled).postfetch
         )
 
     def _set_input_sizes(self):
@@ -1464,7 +1496,7 @@ class DefaultExecutionContext(ExecutionContext):
         if self.isddl or self.is_text:
             return
 
-        compiled = self.compiled
+        compiled = cast(SQLCompiler, self.compiled)
 
         inputsizes = compiled._get_set_input_sizes_lookup()
 
@@ -1487,7 +1519,8 @@ class DefaultExecutionContext(ExecutionContext):
 
         if dialect.positional:
             items = [
-                (key, compiled.binds[key]) for key in compiled.positiontup
+                (key, compiled.binds[key])
+                for key in compiled.positiontup or ()
             ]
         else:
             items = [
@@ -1495,7 +1528,7 @@ class DefaultExecutionContext(ExecutionContext):
                 for bindparam, key in compiled.bind_names.items()
             ]
 
-        generic_inputsizes = []
+        generic_inputsizes: List[Tuple[str, Any, TypeEngine[Any]]] = []
         for key, bindparam in items:
             if bindparam in compiled.literal_execute_params:
                 continue
@@ -1578,20 +1611,19 @@ class DefaultExecutionContext(ExecutionContext):
         compiled_params = compiled.construct_params()
         processors = compiled._bind_processors
         if compiled.positional:
-            positiontup = compiled.positiontup
             parameters = self.dialect.execute_sequence_format(
                 [
-                    processors[key](compiled_params[key])
+                    processors[key](compiled_params[key])  # type: ignore
                     if key in processors
                     else compiled_params[key]
-                    for key in positiontup
+                    for key in compiled.positiontup or ()
                 ]
             )
         else:
             parameters = dict(
                 (
                     key,
-                    processors[key](compiled_params[key])
+                    processors[key](compiled_params[key])  # type: ignore
                     if key in processors
                     else compiled_params[key],
                 )
@@ -1667,15 +1699,18 @@ class DefaultExecutionContext(ExecutionContext):
                 "get_current_parameters() can only be invoked in the "
                 "context of a Python side column default function"
             )
-
-        compile_state = self.compiled.compile_state
+        else:
+            assert column is not None
+            assert parameters is not None
+        compile_state = cast(SQLCompiler, self.compiled).compile_state
+        assert compile_state is not None
         if (
             isolate_multiinsert_groups
             and self.isinsert
             and compile_state._has_multi_parameters
         ):
             if column._is_multiparam_column:
-                index = column.index + 1
+                index = column.index + 1  # type: ignore
                 d = {column.original.key: parameters[column.key]}
             else:
                 d = {column.key: parameters[column.key]}
@@ -1701,12 +1736,14 @@ class DefaultExecutionContext(ExecutionContext):
             return self._exec_default(column, column.onupdate, column.type)
 
     def _process_executemany_defaults(self):
-        key_getter = self.compiled._within_exec_param_key_getter
+        compiled = cast(SQLCompiler, self.compiled)
 
-        scalar_defaults = {}
+        key_getter = compiled._within_exec_param_key_getter
 
-        insert_prefetch = self.compiled.insert_prefetch
-        update_prefetch = self.compiled.update_prefetch
+        scalar_defaults: Dict[Column[Any], Any] = {}
+
+        insert_prefetch = compiled.insert_prefetch
+        update_prefetch = compiled.update_prefetch
 
         # pre-determine scalar Python-side defaults
         # to avoid many calls of get_insert_default()/
@@ -1739,12 +1776,14 @@ class DefaultExecutionContext(ExecutionContext):
         del self.current_parameters
 
     def _process_executesingle_defaults(self):
-        key_getter = self.compiled._within_exec_param_key_getter
+        compiled = cast(SQLCompiler, self.compiled)
+
+        key_getter = compiled._within_exec_param_key_getter
         self.current_parameters = (
             compiled_parameters
         ) = self.compiled_parameters[0]
 
-        for c in self.compiled.insert_prefetch:
+        for c in compiled.insert_prefetch:
             if c.default and not c.default.is_sequence and c.default.is_scalar:
                 val = c.default.arg
             else:
@@ -1753,7 +1792,7 @@ class DefaultExecutionContext(ExecutionContext):
             if val is not None:
                 compiled_parameters[key_getter(c)] = val
 
-        for c in self.compiled.update_prefetch:
+        for c in compiled.update_prefetch:
             val = self.get_update_default(c)
 
             if val is not None:
index 5aefcf5b565a946812f69353c0febfa5dddc4d35..e65546eb777ffa55af9c65ef9c5dcfcd9b1dd44f 100644 (file)
@@ -36,7 +36,6 @@ from ..sql.compiler import TypeCompiler as TypeCompiler
 from ..sql.compiler import TypeCompiler  # noqa
 from ..util import immutabledict
 from ..util.concurrency import await_only
-from ..util.typing import _TypeToInstance
 from ..util.typing import Literal
 from ..util.typing import NotRequired
 from ..util.typing import Protocol
@@ -58,6 +57,8 @@ if TYPE_CHECKING:
     from ..sql.elements import ClauseElement
     from ..sql.schema import Column
     from ..sql.schema import ColumnDefault
+    from ..sql.schema import Sequence as Sequence_SchemaItem
+    from ..sql.sqltypes import Integer
     from ..sql.type_api import TypeEngine
 
 ConnectArgsType = Tuple[Tuple[str], MutableMapping[str, Any]]
@@ -156,6 +157,8 @@ class DBAPICursor(Protocol):
 
     arraysize: int
 
+    lastrowid: int
+
     def close(self) -> None:
         ...
 
@@ -196,6 +199,7 @@ class DBAPICursor(Protocol):
 
 
 _CoreSingleExecuteParams = Mapping[str, Any]
+_MutableCoreSingleExecuteParams = MutableMapping[str, Any]
 _CoreMultiExecuteParams = Sequence[_CoreSingleExecuteParams]
 _CoreAnyExecuteParams = Union[
     _CoreMultiExecuteParams, _CoreSingleExecuteParams
@@ -605,7 +609,7 @@ class Dialect(EventTarget):
     ddl_compiler: Type[DDLCompiler]
     """a :class:`.Compiled` class used to compile DDL statements"""
 
-    type_compiler: _TypeToInstance[TypeCompiler]
+    type_compiler: Union[Type[TypeCompiler], TypeCompiler]
     """a :class:`.Compiled` class used to compile SQL type objects"""
 
     preparer: Type[IdentifierPreparer]
@@ -633,7 +637,7 @@ class Dialect(EventTarget):
 
     """
 
-    default_isolation_level: _IsolationLevel
+    default_isolation_level: Optional[_IsolationLevel]
     """the isolation that is implicitly present on new connections"""
 
     execution_ctx_cls: Type["ExecutionContext"]
@@ -653,6 +657,13 @@ class Dialect(EventTarget):
     max_identifier_length: int
     """The maximum length of identifier names."""
 
+    supports_server_side_cursors: bool
+    """indicates if the dialect supports server side cursors"""
+
+    server_side_cursors: bool
+    """deprecated; indicates if the dialect should attempt to use server
+    side cursors by default"""
+
     supports_sane_rowcount: bool
     """Indicate whether the dialect properly implements rowcount for
       ``UPDATE`` and ``DELETE`` statements.
@@ -2302,6 +2313,11 @@ class ExecutionContext:
     def _setup_result_proxy(self) -> Result:
         raise NotImplementedError()
 
+    def fire_sequence(self, seq: Sequence_SchemaItem, type_: Integer) -> int:
+        """given a :class:`.Sequence`, invoke it and return the next int
+        value"""
+        raise NotImplementedError()
+
     def create_cursor(self) -> DBAPICursor:
         """Return a new cursor generated from this ExecutionContext's
         connection.
index 0951d57702acedb3bad8e6d0f2bea11698df4414..87d3cac1c779fe7fff852fdad2e5fdd33d4b25e8 100644 (file)
@@ -1880,6 +1880,7 @@ class MergedResult(IteratorResult):
     """
 
     closed = False
+    rowcount: Optional[int]
 
     def __init__(
         self, cursor_metadata: ResultMetaData, results: Sequence[Result]
index 09e38a5ab96ca81218550ffe1cfb7666d61576b1..423c3d446e3f09c485cbe92edfa38ed373e3ff5a 100644 (file)
@@ -34,6 +34,7 @@ import re
 from time import perf_counter
 import typing
 from typing import Any
+from typing import Callable
 from typing import Dict
 from typing import List
 from typing import Mapping
@@ -629,11 +630,11 @@ class SQLCompiler(Compiled):
     """list of columns that can be post-fetched after INSERT or UPDATE to
     receive server-updated values"""
 
-    insert_prefetch: Optional[List[Column[Any]]]
+    insert_prefetch: Sequence[Column[Any]] = ()
     """list of columns for which default values should be evaluated before
     an INSERT takes place"""
 
-    update_prefetch: Optional[List[Column[Any]]]
+    update_prefetch: Sequence[Column[Any]] = ()
     """list of columns for which onupdate default values should be evaluated
     before an UPDATE takes place"""
 
@@ -739,8 +740,6 @@ class SQLCompiler(Compiled):
     """if True, there are bindparam() objects that have the isoutparam
     flag set."""
 
-    insert_prefetch = update_prefetch = ()
-
     postfetch_lastrowid = False
     """if True, and this in insert, use cursor.lastrowid to populate
     result.inserted_primary_key. """
@@ -1340,7 +1339,7 @@ class SQLCompiler(Compiled):
         )
 
     @util.memoized_property
-    def _within_exec_param_key_getter(self):
+    def _within_exec_param_key_getter(self) -> Callable[[Any], str]:
         getter = self._key_getters_for_crud_column[2]
         if self.escaped_bind_names:
 
index 4c38c4efabe491941f80a79166ab5e76d1705d74..168da17ccc007fa4df644e146c8ac3a4aa8d28be 100644 (file)
@@ -58,12 +58,13 @@ from ..util.langhelpers import TypingOnly
 if typing.TYPE_CHECKING:
     from decimal import Decimal
 
+    from .compiler import Compiled
+    from .compiler import SQLCompiler
     from .operators import OperatorType
     from .selectable import FromClause
     from .selectable import Select
     from .sqltypes import Boolean  # noqa
     from .type_api import TypeEngine
-    from ..engine import Compiled
     from ..engine import Connection
     from ..engine import Dialect
     from ..engine import Engine
@@ -573,6 +574,25 @@ class ClauseElement(
             )
 
 
+class DQLDMLClauseElement(ClauseElement):
+    """represents a :class:`.ClauseElement` that compiles to a DQL or DML
+    expression, not DDL.
+
+    .. versionadded:: 2.0
+
+    """
+
+    if typing.TYPE_CHECKING:
+
+        def compile(  # noqa: A001
+            self,
+            bind: Optional[Union[Engine, Connection]] = None,
+            dialect: Optional[Dialect] = None,
+            **kw: Any,
+        ) -> SQLCompiler:
+            ...
+
+
 class CompilerColumnElement(
     roles.DMLColumnRole,
     roles.DDLConstraintColumnRole,
@@ -955,7 +975,7 @@ class ColumnElement(
     roles.DDLExpressionRole,
     SQLCoreOperations[_T],
     operators.ColumnOperators[SQLCoreOperations],
-    ClauseElement,
+    DQLDMLClauseElement,
 ):
     """Represent a column-oriented SQL expression suitable for usage in the
     "columns" clause, WHERE clause etc. of a statement.
@@ -1820,7 +1840,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
         )
 
 
-class TypeClause(ClauseElement):
+class TypeClause(DQLDMLClauseElement):
     """Handle a type keyword in a SQL statement.
 
     Used by the ``Case`` statement.
@@ -1849,7 +1869,7 @@ class TextClause(
     roles.BinaryElementRole,
     roles.InElementRole,
     Executable,
-    ClauseElement,
+    DQLDMLClauseElement,
 ):
     """Represent a literal SQL text fragment.
 
@@ -2285,7 +2305,7 @@ class ClauseList(
     roles.OrderByRole,
     roles.ColumnsClauseRole,
     roles.DMLColumnRole,
-    ClauseElement,
+    DQLDMLClauseElement,
 ):
     """Describe a list of clauses, separated by an operator.
 
@@ -3205,7 +3225,7 @@ class IndexExpression(BinaryExpression):
     inherit_cache = True
 
 
-class GroupedElement(ClauseElement):
+class GroupedElement(DQLDMLClauseElement):
     """Represent any parenthesized expression"""
 
     __visit_name__ = "grouping"
index fdae4d7b04e5f3e3777d9b5cea2a6437cc8fb0e1..c270e15648b846dfc5ca557330b9dc1864e43348 100644 (file)
@@ -1131,6 +1131,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
     __visit_name__ = "column"
 
     inherit_cache = True
+    key: str
 
     @overload
     def __init__(
index a5cbffb5e1efb2f2d81bf5b4f79f08727f028dd8..e5c2bef6860ffc86d01f5644e304f3e42039c6f5 100644 (file)
@@ -62,6 +62,7 @@ from .elements import ClauseElement
 from .elements import ClauseList
 from .elements import ColumnClause
 from .elements import ColumnElement
+from .elements import DQLDMLClauseElement
 from .elements import GroupedElement
 from .elements import Grouping
 from .elements import literal_column
@@ -85,7 +86,7 @@ class _OffsetLimitParam(BindParameter):
         return self.effective_value
 
 
-class ReturnsRows(roles.ReturnsRowsRole, ClauseElement):
+class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement):
     """The base-most class for Core constructs that have some concept of
     columns that can represent rows.
 
index 45e31aaf7ff6f122ac5d8e9abf730a886cbe1885..b0df99c4150b32c5b30defb5b0286d0b637e6638 100644 (file)
 
 from __future__ import annotations
 
-from .sql.sqltypes import _Binary
-from .sql.sqltypes import ARRAY
-from .sql.sqltypes import BIGINT
-from .sql.sqltypes import BigInteger
-from .sql.sqltypes import BINARY
-from .sql.sqltypes import BLOB
-from .sql.sqltypes import BOOLEAN
-from .sql.sqltypes import Boolean
-from .sql.sqltypes import CHAR
-from .sql.sqltypes import CLOB
-from .sql.sqltypes import Concatenable
-from .sql.sqltypes import DATE
-from .sql.sqltypes import Date
-from .sql.sqltypes import DATETIME
-from .sql.sqltypes import DateTime
-from .sql.sqltypes import DECIMAL
-from .sql.sqltypes import DOUBLE
-from .sql.sqltypes import Double
-from .sql.sqltypes import DOUBLE_PRECISION
-from .sql.sqltypes import Enum
-from .sql.sqltypes import FLOAT
-from .sql.sqltypes import Float
-from .sql.sqltypes import Indexable
-from .sql.sqltypes import INT
-from .sql.sqltypes import INTEGER
-from .sql.sqltypes import Integer
-from .sql.sqltypes import Interval
-from .sql.sqltypes import JSON
-from .sql.sqltypes import LargeBinary
-from .sql.sqltypes import MatchType
-from .sql.sqltypes import NCHAR
-from .sql.sqltypes import NULLTYPE
-from .sql.sqltypes import NullType
-from .sql.sqltypes import NUMERIC
-from .sql.sqltypes import Numeric
-from .sql.sqltypes import NVARCHAR
-from .sql.sqltypes import PickleType
-from .sql.sqltypes import REAL
-from .sql.sqltypes import SchemaType
-from .sql.sqltypes import SMALLINT
-from .sql.sqltypes import SmallInteger
-from .sql.sqltypes import String
-from .sql.sqltypes import STRINGTYPE
-from .sql.sqltypes import TEXT
-from .sql.sqltypes import Text
-from .sql.sqltypes import TIME
-from .sql.sqltypes import Time
-from .sql.sqltypes import TIMESTAMP
-from .sql.sqltypes import TupleType
-from .sql.sqltypes import Unicode
-from .sql.sqltypes import UnicodeText
-from .sql.sqltypes import VARBINARY
-from .sql.sqltypes import VARCHAR
-from .sql.type_api import adapt_type
-from .sql.type_api import ExternalType
-from .sql.type_api import to_instance
-from .sql.type_api import TypeDecorator
-from .sql.type_api import TypeEngine
-from .sql.type_api import UserDefinedType
-from .sql.type_api import Variant
-
-__all__ = [
-    "TypeEngine",
-    "TypeDecorator",
-    "UserDefinedType",
-    "ExternalType",
-    "INT",
-    "CHAR",
-    "VARCHAR",
-    "NCHAR",
-    "NVARCHAR",
-    "TEXT",
-    "Text",
-    "FLOAT",
-    "NUMERIC",
-    "REAL",
-    "DECIMAL",
-    "TIMESTAMP",
-    "DATETIME",
-    "CLOB",
-    "BLOB",
-    "BINARY",
-    "VARBINARY",
-    "BOOLEAN",
-    "BIGINT",
-    "SMALLINT",
-    "INTEGER",
-    "DATE",
-    "TIME",
-    "TupleType",
-    "String",
-    "Integer",
-    "SmallInteger",
-    "BigInteger",
-    "Numeric",
-    "Float",
-    "Double",
-    "DOUBLE",
-    "DOUBLE_PRECISION",
-    "DateTime",
-    "Date",
-    "Time",
-    "LargeBinary",
-    "Boolean",
-    "Unicode",
-    "Concatenable",
-    "UnicodeText",
-    "PickleType",
-    "Interval",
-    "Enum",
-    "Indexable",
-    "ARRAY",
-    "JSON",
-]
+from .sql.sqltypes import _Binary as _Binary
+from .sql.sqltypes import ARRAY as ARRAY
+from .sql.sqltypes import BIGINT as BIGINT
+from .sql.sqltypes import BigInteger as BigInteger
+from .sql.sqltypes import BINARY as BINARY
+from .sql.sqltypes import BLOB as BLOB
+from .sql.sqltypes import BOOLEAN as BOOLEAN
+from .sql.sqltypes import Boolean as Boolean
+from .sql.sqltypes import CHAR as CHAR
+from .sql.sqltypes import CLOB as CLOB
+from .sql.sqltypes import Concatenable as Concatenable
+from .sql.sqltypes import DATE as DATE
+from .sql.sqltypes import Date as Date
+from .sql.sqltypes import DATETIME as DATETIME
+from .sql.sqltypes import DateTime as DateTime
+from .sql.sqltypes import DECIMAL as DECIMAL
+from .sql.sqltypes import DOUBLE as DOUBLE
+from .sql.sqltypes import Double as Double
+from .sql.sqltypes import DOUBLE_PRECISION as DOUBLE_PRECISION
+from .sql.sqltypes import Enum as Enum
+from .sql.sqltypes import FLOAT as FLOAT
+from .sql.sqltypes import Float as Float
+from .sql.sqltypes import Indexable as Indexable
+from .sql.sqltypes import INT as INT
+from .sql.sqltypes import INTEGER as INTEGER
+from .sql.sqltypes import Integer as Integer
+from .sql.sqltypes import Interval as Interval
+from .sql.sqltypes import JSON as JSON
+from .sql.sqltypes import LargeBinary as LargeBinary
+from .sql.sqltypes import MatchType as MatchType
+from .sql.sqltypes import NCHAR as NCHAR
+from .sql.sqltypes import NULLTYPE as NULLTYPE
+from .sql.sqltypes import NullType as NullType
+from .sql.sqltypes import NUMERIC as NUMERIC
+from .sql.sqltypes import Numeric as Numeric
+from .sql.sqltypes import NVARCHAR as NVARCHAR
+from .sql.sqltypes import PickleType as PickleType
+from .sql.sqltypes import REAL as REAL
+from .sql.sqltypes import SchemaType as SchemaType
+from .sql.sqltypes import SMALLINT as SMALLINT
+from .sql.sqltypes import SmallInteger as SmallInteger
+from .sql.sqltypes import String as String
+from .sql.sqltypes import STRINGTYPE as STRINGTYPE
+from .sql.sqltypes import TEXT as TEXT
+from .sql.sqltypes import Text as Text
+from .sql.sqltypes import TIME as TIME
+from .sql.sqltypes import Time as Time
+from .sql.sqltypes import TIMESTAMP as TIMESTAMP
+from .sql.sqltypes import TupleType as TupleType
+from .sql.sqltypes import Unicode as Unicode
+from .sql.sqltypes import UnicodeText as UnicodeText
+from .sql.sqltypes import VARBINARY as VARBINARY
+from .sql.sqltypes import VARCHAR as VARCHAR
+from .sql.type_api import adapt_type as adapt_type
+from .sql.type_api import ExternalType as ExternalType
+from .sql.type_api import to_instance as to_instance
+from .sql.type_api import TypeDecorator as TypeDecorator
+from .sql.type_api import TypeEngine as TypeEngine
+from .sql.type_api import UserDefinedType as UserDefinedType
+from .sql.type_api import Variant as Variant
index e0b53b44508c720cc66bf51796cc70580386b988..06a009c5b6362e40424b52754b762401ca30da1f 100644 (file)
@@ -34,6 +34,7 @@ import weakref
 
 from ._has_cy import HAS_CYEXTENSION
 from .typing import Literal
+from .typing import Protocol
 
 if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
     from ._py_collections import immutabledict as immutabledict
@@ -62,7 +63,7 @@ else:
 _T = TypeVar("_T", bound=Any)
 _KT = TypeVar("_KT", bound=Any)
 _VT = TypeVar("_VT", bound=Any)
-
+_T_co = TypeVar("_T_co", covariant=True)
 
 EMPTY_SET: FrozenSet[Any] = frozenset()
 
@@ -597,7 +598,17 @@ class LRUCache(typing.MutableMapping[_KT, _VT]):
             self._mutex.release()
 
 
-class ScopedRegistry:
+class _CreateFuncType(Protocol[_T_co]):
+    def __call__(self) -> _T_co:
+        ...
+
+
+class _ScopeFuncType(Protocol):
+    def __call__(self) -> Any:
+        ...
+
+
+class ScopedRegistry(Generic[_T]):
     """A Registry that can store one or multiple instances of a single
     class on the basis of a "scope" function.
 
@@ -614,6 +625,10 @@ class ScopedRegistry:
 
     __slots__ = "createfunc", "scopefunc", "registry"
 
+    createfunc: _CreateFuncType[_T]
+    scopefunc: _ScopeFuncType
+    registry: Any
+
     def __init__(self, createfunc, scopefunc):
         """Construct a new :class:`.ScopedRegistry`.
 
@@ -629,24 +644,24 @@ class ScopedRegistry:
         self.scopefunc = scopefunc
         self.registry = {}
 
-    def __call__(self):
+    def __call__(self) -> _T:
         key = self.scopefunc()
         try:
-            return self.registry[key]
+            return self.registry[key]  # type: ignore[no-any-return]
         except KeyError:
-            return self.registry.setdefault(key, self.createfunc())
+            return self.registry.setdefault(key, self.createfunc())  # type: ignore[no-any-return] # noqa: E501
 
-    def has(self):
+    def has(self) -> bool:
         """Return True if an object is present in the current scope."""
 
         return self.scopefunc() in self.registry
 
-    def set(self, obj):
+    def set(self, obj: _T) -> None:
         """Set the value for the current scope."""
 
         self.registry[self.scopefunc()] = obj
 
-    def clear(self):
+    def clear(self) -> None:
         """Clear the current scope, if any."""
 
         try:
@@ -655,32 +670,32 @@ class ScopedRegistry:
             pass
 
 
-class ThreadLocalRegistry(ScopedRegistry):
+class ThreadLocalRegistry(ScopedRegistry[_T]):
     """A :class:`.ScopedRegistry` that uses a ``threading.local()``
     variable for storage.
 
     """
 
-    def __init__(self, createfunc):
+    def __init__(self, createfunc: Callable[[], _T]):
         self.createfunc = createfunc
         self.registry = threading.local()
 
-    def __call__(self):
+    def __call__(self) -> _T:
         try:
-            return self.registry.value
+            return self.registry.value  # type: ignore[no-any-return]
         except AttributeError:
             val = self.registry.value = self.createfunc()
-            return val
+            return val  # type: ignore[no-any-return]
 
-    def has(self):
+    def has(self) -> bool:
         return hasattr(self.registry, "value")
 
-    def set(self, obj):
+    def set(self, obj: _T) -> None:
         self.registry.value = obj
 
-    def clear(self):
+    def clear(self) -> None:
         try:
-            del self.registry.value  # type: ignore
+            del self.registry.value
         except AttributeError:
             pass
 
index 771e974e93b26898a62b60bb7317c91269d7906a..d503529303359f5a4e714b34b3eec94fea384eac 100644 (file)
@@ -11,6 +11,7 @@ from itertools import filterfalse
 from typing import AbstractSet
 from typing import Any
 from typing import cast
+from typing import Collection
 from typing import Dict
 from typing import Iterable
 from typing import Iterator
@@ -67,7 +68,9 @@ class immutabledict(ImmutableDictBase[_KT, _VT]):
         dict.__init__(new, *args)
         return new
 
-    def __init__(self, *args: Union[Mapping[_KT, _VT], Tuple[_KT, _VT]]):
+    def __init__(
+        self, *args: Union[Mapping[_KT, _VT], Iterable[Tuple[_KT, _VT]]]
+    ):
         pass
 
     def __reduce__(self):
@@ -369,6 +372,8 @@ class IdentitySet:
 
     def difference(self, iterable):
         result = self.__new__(self.__class__)
+        other: Collection[Any]
+
         if isinstance(iterable, self.__class__):
             other = iterable._members
         else:
@@ -394,6 +399,9 @@ class IdentitySet:
 
     def intersection(self, iterable):
         result = self.__new__(self.__class__)
+
+        other: Collection[Any]
+
         if isinstance(iterable, self.__class__):
             other = iterable._members
         else:
@@ -466,7 +474,7 @@ class IdentitySet:
 
 
 def unique_list(seq, hashfunc=None):
-    seen = set()
+    seen: Set[Any] = set()
     seen_add = seen.add
     if not hashfunc:
         return [x for x in seq if x not in seen and not seen_add(x)]
index 5674e19afe2e70834fe68370d159b52b96ce494f..8cb84f73f5193a8f3fbbaed45b49b962481999e0 100644 (file)
@@ -679,7 +679,7 @@ def create_proxy_methods(
 
     def decorate(cls):
         def instrument(name, clslevel=False):
-            fn = cast(Callable[..., Any], getattr(target_cls, name))
+            fn = cast(types.FunctionType, getattr(target_cls, name))
             spec = compat.inspect_getfullargspec(fn)
             env = {"__name__": fn.__module__}
 
@@ -709,7 +709,7 @@ def create_proxy_methods(
                 )
 
             proxy_fn = cast(
-                Callable[..., Any], _exec_code_in_env(code, env, fn.__name__)
+                types.FunctionType, _exec_code_in_env(code, env, fn.__name__)
             )
             proxy_fn.__defaults__ = getattr(fn, "__func__", fn).__defaults__
             proxy_fn.__doc__ = inject_docstring_text(
@@ -721,9 +721,9 @@ def create_proxy_methods(
             )
 
             if clslevel:
-                proxy_fn = classmethod(proxy_fn)
-
-            return proxy_fn
+                return classmethod(proxy_fn)
+            else:
+                return proxy_fn
 
         def makeprop(name):
             attr = target_cls.__dict__.get(name, None)
@@ -824,7 +824,7 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()):
     missing = object()
 
     pos_args = []
-    kw_args = _collections.OrderedDict()
+    kw_args: _collections.OrderedDict[str, Any] = _collections.OrderedDict()
     vargs = None
     for i, insp in enumerate(to_inspect):
         try:
@@ -855,7 +855,7 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()):
                         )
                     ]
                 )
-    output = []
+    output: List[str] = []
 
     output.extend(repr(getattr(obj, arg, None)) for arg in pos_args)
 
@@ -1007,7 +1007,7 @@ def monkeypatch_proxied_specials(
             if not hasattr(maybe_fn, "__call__"):
                 continue
             maybe_fn = getattr(maybe_fn, "__func__", maybe_fn)
-            fn = cast(Callable[..., Any], maybe_fn)
+            fn = cast(types.FunctionType, maybe_fn)
 
         except AttributeError:
             continue
@@ -1024,7 +1024,9 @@ def monkeypatch_proxied_specials(
             "return %(name)s.%(method)s%(d_args)s" % locals()
         )
 
-        env = from_instance is not None and {name: from_instance} or {}
+        env: Dict[str, types.FunctionType] = (
+            from_instance is not None and {name: from_instance} or {}
+        )
         exec(py, env)
         try:
             env[method].__defaults__ = fn.__defaults__
@@ -1482,6 +1484,7 @@ def dictlike_iteritems(dictlike):
 
         def iterator():
             for key in dictlike.iterkeys():
+                assert getter is not None
                 yield key, getter(key)
 
         return iterator()
@@ -1989,7 +1992,7 @@ def quoted_token_parser(value):
     # 0 = outside of quotes
     # 1 = inside of quotes
     state = 0
-    result = [[]]
+    result: List[List[str]] = [[]]
     idx = 0
     lv = len(value)
     while idx < lv:
index 291061561d7437eca980d4875a4605c04abba1fc..160eabd85fcdf4e3bce0296dec7dd6b26af67e66 100644 (file)
@@ -7,8 +7,6 @@ from typing import Callable  # noqa
 from typing import cast
 from typing import Dict
 from typing import ForwardRef
-from typing import Generic
-from typing import overload
 from typing import Type
 from typing import TypeVar
 from typing import Union
@@ -58,35 +56,6 @@ else:
     from typing import ParamSpec as ParamSpec  # noqa F401
 
 
-class _TypeToInstance(Generic[_T]):
-    """describe a variable that moves between a class and an instance of
-    that class.
-
-    """
-
-    @overload
-    def __get__(self, instance: None, owner: Any) -> Type[_T]:
-        ...
-
-    @overload
-    def __get__(self, instance: object, owner: Any) -> _T:
-        ...
-
-    def __get__(self, instance: object, owner: Any) -> Union[Type[_T], _T]:
-        ...
-
-    @overload
-    def __set__(self, instance: None, value: Type[_T]) -> None:
-        ...
-
-    @overload
-    def __set__(self, instance: object, value: _T) -> None:
-        ...
-
-    def __set__(self, instance: object, value: Union[Type[_T], _T]) -> None:
-        ...
-
-
 def de_stringify_annotation(
     cls: Type[Any], annotation: Union[str, Type[Any]]
 ) -> Union[str, Type[Any]]:
index 963b546ed910591f1fa6d1eec7b317c2bf1bf13b..b90feae49896707a52eac6ea094ec87a643377cd 100644 (file)
@@ -125,8 +125,8 @@ module = [
 
 ignore_errors = false
 
+# mostly strict without requiring totally untyped things to be
+# typed
+strict = true
 allow_untyped_defs = true
-check_untyped_defs = false
 allow_untyped_calls = true
-
-
index dbd957703f7bd82a3d49605f268b4bf05a8c578c..8b950026f2f5a860ea5c56ae6425bd7ce6879caf 100644 (file)
@@ -2983,7 +2983,7 @@ class HandleErrorTest(fixtures.TestBase):
             the_conn.append(connection)
 
         with mock.patch(
-            "sqlalchemy.engine.cursor.BaseCursorResult.__init__",
+            "sqlalchemy.engine.cursor.CursorResult.__init__",
             Mock(side_effect=tsa.exc.InvalidRequestError("duplicate col")),
         ):
             with engine.connect() as conn:
@@ -3019,7 +3019,7 @@ class HandleErrorTest(fixtures.TestBase):
         conn = engine.connect()
 
         with mock.patch(
-            "sqlalchemy.engine.cursor.BaseCursorResult.__init__",
+            "sqlalchemy.engine.cursor.CursorResult.__init__",
             Mock(side_effect=tsa.exc.InvalidRequestError("duplicate col")),
         ):
             assert_raises(
index da96f6c3a0d2b68fa963af993ab011264d67c647..acf16565a01dfc76b18d0c901b336f2ce5d70469 100644 (file)
@@ -141,7 +141,7 @@ class AdaptTest(fixtures.TestBase):
     def test_uppercase_importable(self, typ):
         if typ.__name__ == typ.__name__.upper():
             assert getattr(sa, typ.__name__) is typ
-            assert typ.__name__ in types.__all__
+            assert typ.__name__ in dir(types)
 
     @testing.combinations(
         ((d.name, d) for d in _all_dialects()), argnames="dialect", id_="ia"