pep 484 for types
author Mike Bayer <mike_mp@zzzcomputing.com>
Thu, 17 Mar 2022 20:18:55 +0000 (16:18 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Sun, 20 Mar 2022 03:15:15 +0000 (23:15 -0400)
strict types type_api.py, including TypeDecorator,
NativeForEmulated, etc.

Change-Id: Ib2eba26de0981324a83733954cb7044a29bbd7db

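For orientation, here is a minimal sketch (not part of this commit) of the kind of end-user code the strictly typed type_api.py is meant to check. It assumes the generic TypeDecorator[_T] of the 2.0 series and the public Dialect import from sqlalchemy.engine; the EpochSeconds type itself is hypothetical.

    from datetime import datetime, timezone
    from typing import Optional

    from sqlalchemy.engine import Dialect
    from sqlalchemy.types import Integer, TypeDecorator


    class EpochSeconds(TypeDecorator[datetime]):
        """Hypothetical decorator type: datetimes stored as integer epoch seconds."""

        impl = Integer
        cache_ok = True

        def process_bind_param(
            self, value: Optional[datetime], dialect: Dialect
        ) -> Optional[int]:
            # bound-parameter processor: Python datetime -> DB integer
            return None if value is None else int(value.timestamp())

        def process_result_value(
            self, value: Optional[int], dialect: Dialect
        ) -> Optional[datetime]:
            # result-row processor: DB integer -> Python datetime
            return None if value is None else datetime.fromtimestamp(
                value, tz=timezone.utc
            )
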
31 files changed:
lib/sqlalchemy/connectors/pyodbc.py
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/engine/processors.py
lib/sqlalchemy/engine/result.py
lib/sqlalchemy/engine/row.py
lib/sqlalchemy/pool/base.py
lib/sqlalchemy/sql/_elements_constructors.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/dml.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/util/deprecations.py
lib/sqlalchemy/util/langhelpers.py
pyproject.toml
test/engine/test_execute.py
test/engine/test_pool.py
test/ext/mypy/plain_files/sql_operations.py
test/sql/test_metadata.py
test/sql/test_quote.py
test/sql/test_types.py

lib/sqlalchemy/connectors/pyodbc.py
index 61b40d935e1bd0c91fe69e0bec8331a3bcea678a..59951cd041379f35f96e22cf040e22096ec05ded 100644 (file)
@@ -143,7 +143,9 @@ class PyODBCConnector(Connector):
     def is_disconnect(
         self,
         e: Exception,
-        connection: Optional[pool.PoolProxiedConnection],
+        connection: Optional[
+            Union[pool.PoolProxiedConnection, interfaces.DBAPIConnection]
+        ],
         cursor: Optional[interfaces.DBAPICursor],
     ) -> bool:
         if isinstance(e, self.dbapi.ProgrammingError):
lib/sqlalchemy/dialects/mssql/base.py
index dd792f70d7fb619301f5db799c0dbe2842ba2ab5..07ff495a717b8f3fe0678b4f7c2dfbbcb79b1089 100644 (file)
@@ -2311,7 +2311,7 @@ class MSDDLCompiler(compiler.DDLCompiler):
         if column.computed is not None:
             colspec += " " + self.process(column.computed)
         else:
-            colspec += " " + self.dialect.type_compiler.process(
+            colspec += " " + self.dialect.type_compiler_instance.process(
                 column.type, type_expression=column
             )
 
@@ -2719,7 +2719,7 @@ class MSDialect(default.DefaultDialect):
 
     statement_compiler = MSSQLCompiler
     ddl_compiler = MSDDLCompiler
-    type_compiler = MSTypeCompiler
+    type_compiler_cls = MSTypeCompiler
     preparer = MSIdentifierPreparer
 
     construct_arguments = [
lib/sqlalchemy/dialects/mysql/base.py
index b2ccfc90f272013eeab7963d59311499eb1842b8..cfdc3deb2baa8acabf19e72d118c5dc1db74e7b6 100644 (file)
@@ -1414,25 +1414,25 @@ class MySQLCompiler(compiler.SQLCompiler):
                 sqltypes.Time,
             ),
         ):
-            return self.dialect.type_compiler.process(type_)
+            return self.dialect.type_compiler_instance.process(type_)
         elif isinstance(type_, sqltypes.String) and not isinstance(
             type_, (ENUM, SET)
         ):
             adapted = CHAR._adapt_string_for_cast(type_)
-            return self.dialect.type_compiler.process(adapted)
+            return self.dialect.type_compiler_instance.process(adapted)
         elif isinstance(type_, sqltypes._Binary):
             return "BINARY"
         elif isinstance(type_, sqltypes.JSON):
             return "JSON"
         elif isinstance(type_, sqltypes.NUMERIC):
-            return self.dialect.type_compiler.process(type_).replace(
+            return self.dialect.type_compiler_instance.process(type_).replace(
                 "NUMERIC", "DECIMAL"
             )
         elif (
             isinstance(type_, sqltypes.Float)
             and self.dialect._support_float_cast
         ):
-            return self.dialect.type_compiler.process(type_)
+            return self.dialect.type_compiler_instance.process(type_)
         else:
             return None
 
@@ -1442,7 +1442,9 @@ class MySQLCompiler(compiler.SQLCompiler):
             util.warn(
                 "Datatype %s does not support CAST on MySQL/MariaDb; "
                 "the CAST will be skipped."
-                % self.dialect.type_compiler.process(cast.typeclause.type)
+                % self.dialect.type_compiler_instance.process(
+                    cast.typeclause.type
+                )
             )
             return self.process(cast.clause.self_group(), **kw)
 
@@ -1699,7 +1701,7 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
 
         colspec = [
             self.preparer.format_column(column),
-            self.dialect.type_compiler.process(
+            self.dialect.type_compiler_instance.process(
                 column.type, type_expression=column
             ),
         ]
@@ -2358,7 +2360,7 @@ class MySQLDialect(default.DefaultDialect):
 
     statement_compiler = MySQLCompiler
     ddl_compiler = MySQLDDLCompiler
-    type_compiler = MySQLTypeCompiler
+    type_compiler_cls = MySQLTypeCompiler
     ischema_names = ischema_names
     preparer = MySQLIdentifierPreparer
 
lib/sqlalchemy/dialects/oracle/base.py
index 1ae58b8f476c3d6f306e1fc7dbccbd6e8ecbe125..3ee38c0cf1836c87f9090de58283c586b1e28db0 100644 (file)
@@ -743,9 +743,6 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
             day_precision=self.day_precision,
         )
 
-    def coerce_compared_value(self, op, value):
-        return self
-
 
 class ROWID(sqltypes.TypeEngine):
     """Oracle ROWID type.
@@ -1537,7 +1534,7 @@ class OracleDialect(default.DefaultDialect):
 
     statement_compiler = OracleCompiler
     ddl_compiler = OracleDDLCompiler
-    type_compiler = OracleTypeCompiler
+    type_compiler_cls = OracleTypeCompiler
     preparer = OracleIdentifierPreparer
     execution_ctx_cls = OracleExecutionContext
 
lib/sqlalchemy/dialects/postgresql/base.py
index e265cd0f7fa3ef7c3ab5af6c89172a30b44101e3..a1401ea0310631b77f50ab7d3789d4167c109b8a 100644 (file)
@@ -1712,9 +1712,6 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
     def python_type(self):
         return dt.timedelta
 
-    def coerce_compared_value(self, op, value):
-        return self
-
 
 PGInterval = INTERVAL
 
@@ -2168,7 +2165,7 @@ ischema_names = {
 class PGCompiler(compiler.SQLCompiler):
     def render_bind_cast(self, type_, dbapi_type, sqltext):
         return f"""{sqltext}::{
-                self.dialect.type_compiler.process(
+                self.dialect.type_compiler_instance.process(
                     dbapi_type, identifier_preparer=self.preparer
                 )
             }"""
@@ -2318,7 +2315,7 @@ class PGCompiler(compiler.SQLCompiler):
         return "SELECT %s WHERE 1!=1" % (
             ", ".join(
                 "CAST(NULL AS %s)"
-                % self.dialect.type_compiler.process(
+                % self.dialect.type_compiler_instance.process(
                     INTEGER() if type_._isnull else type_
                 )
                 for type_ in element_types or [INTEGER()]
@@ -2604,7 +2601,7 @@ class PGDDLCompiler(compiler.DDLCompiler):
             else:
                 colspec += " SERIAL"
         else:
-            colspec += " " + self.dialect.type_compiler.process(
+            colspec += " " + self.dialect.type_compiler_instance.process(
                 column.type,
                 type_expression=column,
                 identifier_preparer=self.preparer,
@@ -3225,7 +3222,7 @@ class PGDialect(default.DefaultDialect):
 
     statement_compiler = PGCompiler
     ddl_compiler = PGDDLCompiler
-    type_compiler = PGTypeCompiler
+    type_compiler_cls = PGTypeCompiler
     preparer = PGIdentifierPreparer
     execution_ctx_cls = PGExecutionContext
     inspector = PGInspector
lib/sqlalchemy/dialects/sqlite/base.py
index 79068c75f08cb626f4e4651e913afac333b92e6a..03f35d5e2deb4cbe84a372485bbe762079468410 100644 (file)
@@ -1426,7 +1426,7 @@ class SQLiteCompiler(compiler.SQLCompiler):
 class SQLiteDDLCompiler(compiler.DDLCompiler):
     def get_column_specification(self, column, **kwargs):
 
-        coltype = self.dialect.type_compiler.process(
+        coltype = self.dialect.type_compiler_instance.process(
             column.type, type_expression=column
         )
         colspec = self.preparer.format_column(column) + " " + coltype
@@ -1815,7 +1815,7 @@ class SQLiteDialect(default.DefaultDialect):
     execution_ctx_cls = SQLiteExecutionContext
     statement_compiler = SQLiteCompiler
     ddl_compiler = SQLiteDDLCompiler
-    type_compiler = SQLiteTypeCompiler
+    type_compiler_cls = SQLiteTypeCompiler
     preparer = SQLiteIdentifierPreparer
     ischema_names = ischema_names
     colspecs = colspecs
lib/sqlalchemy/engine/base.py
index 5b66c537aa644b6f7d13ea2e4189f0b7358993bb..061794bded92743f35980d430721020c7dbf8ad7 100644 (file)
@@ -141,7 +141,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         if connection is None:
             try:
                 self._dbapi_connection = engine.raw_connection()
-            except dialect.dbapi.Error as err:
+            except dialect.loaded_dbapi.Error as err:
                 Connection._handle_dbapi_exception_noconnection(
                     err, dialect, engine
                 )
@@ -1809,7 +1809,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
 
         if not self._is_disconnect:
             self._is_disconnect = (
-                isinstance(e, self.dialect.dbapi.Error)
+                isinstance(e, self.dialect.loaded_dbapi.Error)
                 and not self.closed
                 and self.dialect.is_disconnect(
                     e,
@@ -1825,7 +1825,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
                 statement,
                 parameters,
                 e,
-                self.dialect.dbapi.Error,
+                self.dialect.loaded_dbapi.Error,
                 hide_parameters=self.engine.hide_parameters,
                 dialect=self.dialect,
                 ismulti=context.executemany if context is not None else None,
@@ -1834,7 +1834,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         try:
             # non-DBAPI error - if we already got a context,
             # or there's no string statement, don't wrap it
-            should_wrap = isinstance(e, self.dialect.dbapi.Error) or (
+            should_wrap = isinstance(e, self.dialect.loaded_dbapi.Error) or (
                 statement is not None
                 and context is None
                 and not is_exit_exception
@@ -1845,7 +1845,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
                     statement,
                     parameters,
                     cast(Exception, e),
-                    self.dialect.dbapi.Error,
+                    self.dialect.loaded_dbapi.Error,
                     hide_parameters=self.engine.hide_parameters,
                     connection_invalidated=self._is_disconnect,
                     dialect=self.dialect,
@@ -1943,17 +1943,17 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         exc_info = sys.exc_info()
 
         is_disconnect = isinstance(
-            e, dialect.dbapi.Error
+            e, dialect.loaded_dbapi.Error
         ) and dialect.is_disconnect(e, None, None)
 
-        should_wrap = isinstance(e, dialect.dbapi.Error)
+        should_wrap = isinstance(e, dialect.loaded_dbapi.Error)
 
         if should_wrap:
             sqlalchemy_exception = exc.DBAPIError.instance(
                 None,
                 None,
                 cast(Exception, e),
-                dialect.dbapi.Error,
+                dialect.loaded_dbapi.Error,
                 hide_parameters=engine.hide_parameters,
                 connection_invalidated=is_disconnect,
             )
lib/sqlalchemy/engine/cursor.py
index f776e597538bb01355a60a77056546c8450a182f..4f151e79ccac90673a75a70142e32d252c9cf506 100644 (file)
@@ -23,6 +23,7 @@ from typing import List
 from typing import Optional
 from typing import Sequence
 from typing import Tuple
+from typing import TYPE_CHECKING
 from typing import Union
 
 from .result import MergedResult
@@ -57,7 +58,7 @@ if typing.TYPE_CHECKING:
     from .result import _KeyMapType
     from .result import _KeyType
     from .result import _ProcessorsType
-    from .result import _ProcessorType
+    from ..sql.type_api import _ResultProcessorType
 
 # metadata entry tuple indexes.
 # using raw tuple is faster than namedtuple.
@@ -77,7 +78,7 @@ MD_UNTRANSLATED: Literal[6] = 6  # raw name from cursor.description
 
 
 _CursorKeyMapRecType = Tuple[
-    int, int, List[Any], str, str, Optional["_ProcessorType"], str
+    int, int, List[Any], str, str, Optional["_ResultProcessorType"], str
 ]
 
 _CursorKeyMapType = Dict["_KeyType", _CursorKeyMapRecType]
@@ -164,6 +165,9 @@ class CursorResultMetaData(ResultMetaData):
         compiled_statement = context.compiled.statement
         invoked_statement = context.invoked_statement
 
+        if TYPE_CHECKING:
+            assert isinstance(invoked_statement, elements.ClauseElement)
+
         if compiled_statement is invoked_statement:
             return self
 
lib/sqlalchemy/engine/default.py
index ba34a0d421eaa91dfa7e4af06e80a24d78da3467..4a833d2e549fb0804da38ef709362927cc149b48 100644 (file)
@@ -57,6 +57,8 @@ from ..sql.compiler import SQLCompiler
 from ..sql.elements import quoted_name
 
 if typing.TYPE_CHECKING:
+    from types import ModuleType
+
     from .base import Connection
     from .base import Engine
     from .characteristics import ConnectionCharacteristic
@@ -67,8 +69,10 @@ if typing.TYPE_CHECKING:
     from .interfaces import _DBAPIMultiExecuteParams
     from .interfaces import _DBAPISingleExecuteParams
     from .interfaces import _ExecuteOptions
+    from .interfaces import _IsolationLevel
     from .interfaces import _MutableCoreSingleExecuteParams
-    from .result import _ProcessorType
+    from .interfaces import _ParamStyle
+    from .interfaces import DBAPIConnection
     from .row import Row
     from .url import URL
     from ..event import _ListenerFnType
@@ -76,12 +80,16 @@ if typing.TYPE_CHECKING:
     from ..pool import PoolProxiedConnection
     from ..sql import Executable
     from ..sql.compiler import Compiled
+    from ..sql.compiler import Linting
     from ..sql.compiler import ResultColumnsEntry
     from ..sql.compiler import TypeCompiler
     from ..sql.dml import DMLState
+    from ..sql.dml import UpdateBase
     from ..sql.elements import BindParameter
     from ..sql.schema import Column
     from ..sql.schema import ColumnDefault
+    from ..sql.type_api import _BindProcessorType
+    from ..sql.type_api import _ResultProcessorType
     from ..sql.type_api import TypeEngine
 
 # When we're handed literal SQL, ensure it's a SELECT query
@@ -102,10 +110,7 @@ class DefaultDialect(Dialect):
 
     statement_compiler = compiler.SQLCompiler
     ddl_compiler = compiler.DDLCompiler
-    if typing.TYPE_CHECKING:
-        type_compiler: TypeCompiler
-    else:
-        type_compiler = compiler.GenericTypeCompiler
+    type_compiler_cls = compiler.GenericTypeCompiler
 
     preparer = compiler.IdentifierPreparer
     supports_alter = True
@@ -253,20 +258,19 @@ class DefaultDialect(Dialect):
     )
     def __init__(
         self,
-        paramstyle=None,
-        isolation_level=None,
-        dbapi=None,
-        implicit_returning=None,
-        supports_native_boolean=None,
-        max_identifier_length=None,
-        label_length=None,
-        # int() is because the @deprecated_params decorator cannot accommodate
-        # the direct reference to the "NO_LINTING" object
-        compiler_linting=int(compiler.NO_LINTING),
-        server_side_cursors=False,
-        **kwargs,
+        paramstyle: Optional[_ParamStyle] = None,
+        isolation_level: Optional[_IsolationLevel] = None,
+        dbapi: Optional[ModuleType] = None,
+        implicit_returning: Optional[bool] = None,
+        supports_native_boolean: Optional[bool] = None,
+        max_identifier_length: Optional[int] = None,
+        label_length: Optional[int] = None,
+        # util.deprecated_params decorator cannot render the
+        # Linting.NO_LINTING constant
+        compiler_linting: Linting = int(compiler.NO_LINTING),  # type: ignore
+        server_side_cursors: bool = False,
+        **kwargs: Any,
     ):
-
         if server_side_cursors:
             if not self.supports_server_side_cursors:
                 raise exc.ArgumentError(
@@ -286,7 +290,9 @@ class DefaultDialect(Dialect):
 
         self.positional = False
         self._ischema = None
+
         self.dbapi = dbapi
+
         if paramstyle is not None:
             self.paramstyle = paramstyle
         elif self.dbapi is not None:
@@ -299,11 +305,17 @@ class DefaultDialect(Dialect):
         self.identifier_preparer = self.preparer(self)
         self._on_connect_isolation_level = isolation_level
 
-        tt_callable = cast(
-            Type[compiler.GenericTypeCompiler],
-            self.type_compiler,
-        )
-        self.type_compiler = tt_callable(self)
+        legacy_tt_callable = getattr(self, "type_compiler", None)
+        if legacy_tt_callable is not None:
+            tt_callable = cast(
+                Type[compiler.GenericTypeCompiler],
+                self.type_compiler,
+            )
+        else:
+            tt_callable = self.type_compiler_cls
+
+        self.type_compiler_instance = self.type_compiler = tt_callable(self)
+
         if supports_native_boolean is not None:
             self.supports_native_boolean = supports_native_boolean
 
@@ -315,6 +327,15 @@ class DefaultDialect(Dialect):
         self.label_length = label_length
         self.compiler_linting = compiler_linting
 
+    @util.memoized_property
+    def loaded_dbapi(self) -> ModuleType:
+        if self.dbapi is None:
+            raise exc.InvalidRequestError(
+                f"Dialect {self} does not have a Python DBAPI established "
+                "and cannot be used for actual database interaction"
+            )
+        return self.dbapi
+
     @util.memoized_property
     def _bind_typing_render_casts(self):
         return self.bind_typing is interfaces.BindTyping.RENDER_CASTS
@@ -495,7 +516,7 @@ class DefaultDialect(Dialect):
 
     def connect(self, *cargs, **cparams):
         # inherits the docstring from interfaces.Dialect.connect
-        return self.dbapi.connect(*cargs, **cparams)
+        return self.loaded_dbapi.connect(*cargs, **cparams)
 
     def create_connect_args(self, url):
         # inherits the docstring from interfaces.Dialect.create_connect_args
@@ -584,7 +605,7 @@ class DefaultDialect(Dialect):
     def _dialect_specific_select_one(self):
         return str(expression.select(1).compile(dialect=self))
 
-    def do_ping(self, dbapi_connection):
+    def do_ping(self, dbapi_connection: DBAPIConnection) -> bool:
         cursor = None
         try:
             cursor = dbapi_connection.cursor()
@@ -592,7 +613,7 @@ class DefaultDialect(Dialect):
                 cursor.execute(self._dialect_specific_select_one)
             finally:
                 cursor.close()
-        except self.dbapi.Error as err:
+        except self.loaded_dbapi.Error as err:
             if self.is_disconnect(err, dbapi_connection, cursor):
                 return False
             else:
@@ -747,7 +768,7 @@ class StrCompileDialect(DefaultDialect):
 
     statement_compiler = compiler.StrSQLCompiler
     ddl_compiler = compiler.DDLCompiler
-    type_compiler = compiler.StrSQLTypeCompiler  # type: ignore
+    type_compiler_cls = compiler.StrSQLTypeCompiler
     preparer = compiler.IdentifierPreparer
 
     supports_statement_cache = True
@@ -906,6 +927,8 @@ class DefaultExecutionContext(ExecutionContext):
         self.is_text = compiled.isplaintext
 
         if self.isinsert or self.isupdate or self.isdelete:
+            if TYPE_CHECKING:
+                assert isinstance(compiled.statement, UpdateBase)
             self.is_crud = True
             self._is_explicit_returning = bool(compiled.statement._returning)
             self._is_implicit_returning = bool(
@@ -943,7 +966,7 @@ class DefaultExecutionContext(ExecutionContext):
         processors = compiled._bind_processors
 
         flattened_processors: Mapping[
-            str, _ProcessorType
+            str, _BindProcessorType[Any]
         ] = processors  # type: ignore[assignment]
 
         if compiled.literal_execute_params or compiled.post_compile_params:
@@ -1354,7 +1377,7 @@ class DefaultExecutionContext(ExecutionContext):
 
             type_ = bindparam.type
             impl_type = type_.dialect_impl(self.dialect)
-            dbapi_type = impl_type.get_dbapi_type(self.dialect.dbapi)
+            dbapi_type = impl_type.get_dbapi_type(self.dialect.loaded_dbapi)
             result_processor = impl_type.result_processor(
                 self.dialect, dbapi_type
             )
lib/sqlalchemy/engine/interfaces.py
index 3ca30d1bc95fd3036e405e5808e16c598296ede7..1b178641e1de660a38c534172c4511c8584e1ec5 100644 (file)
@@ -14,12 +14,14 @@ from types import ModuleType
 from typing import Any
 from typing import Awaitable
 from typing import Callable
+from typing import ClassVar
 from typing import Dict
 from typing import List
 from typing import Mapping
 from typing import MutableMapping
 from typing import Optional
 from typing import Sequence
+from typing import Set
 from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
@@ -60,6 +62,7 @@ if TYPE_CHECKING:
     from ..sql.schema import DefaultGenerator
     from ..sql.schema import Sequence as Sequence_SchemaItem
     from ..sql.sqltypes import Integer
+    from ..sql.type_api import _TypeMemoDict
     from ..sql.type_api import TypeEngine
 
 ConnectArgsType = Tuple[Tuple[str], MutableMapping[str, Any]]
@@ -166,7 +169,7 @@ class DBAPICursor(Protocol):
     def execute(
         self,
         operation: Any,
-        parameters: Optional[_DBAPISingleExecuteParams],
+        parameters: Optional[_DBAPISingleExecuteParams] = None,
     ) -> Any:
         ...
 
@@ -577,7 +580,7 @@ class Dialect(EventTarget):
     driver: str
     """identifying name for the dialect's DBAPI"""
 
-    dbapi: ModuleType
+    dbapi: Optional[ModuleType]
     """A reference to the DBAPI module object itself.
 
     SQLAlchemy dialects import DBAPI modules using the classmethod
@@ -600,6 +603,16 @@ class Dialect(EventTarget):
 
     """
 
+    @util.non_memoized_property
+    def loaded_dbapi(self) -> ModuleType:
+        """same as .dbapi, but is never None; will raise an error if no
+        DBAPI was set up.
+
+        .. versionadded:: 2.0
+
+        """
+        raise NotImplementedError()
+
     positional: bool
     """True if the paramstyle for this Dialect is positional."""
 
@@ -616,8 +629,28 @@ class Dialect(EventTarget):
     ddl_compiler: Type[DDLCompiler]
     """a :class:`.Compiled` class used to compile DDL statements"""
 
-    type_compiler: Union[Type[TypeCompiler], TypeCompiler]
-    """a :class:`.Compiled` class used to compile SQL type objects"""
+    type_compiler_cls: ClassVar[Type[TypeCompiler]]
+    """a :class:`.Compiled` class used to compile SQL type objects
+
+    .. versionadded:: 2.0
+
+    """
+
+    type_compiler_instance: TypeCompiler
+    """instance of a :class:`.Compiled` class used to compile SQL type
+    objects
+
+    .. versionadded:: 2.0
+
+    """
+
+    type_compiler: Any
+    """legacy; this is a TypeCompiler class at the class level, a
+    TypeCompiler instance at the instance level.
+
+    Refer to type_compiler_instance instead.
+
+    """
 
     preparer: Type[IdentifierPreparer]
     """a :class:`.IdentifierPreparer` class used to
@@ -683,9 +716,17 @@ class Dialect(EventTarget):
     """
 
     supports_default_values: bool
-    """Indicates if the construct ``INSERT INTO tablename DEFAULT
-      VALUES`` is supported
-    """
+    """dialect supports INSERT... DEFAULT VALUES syntax"""
+
+    supports_default_metavalue: bool
+    """dialect supports INSERT... VALUES (DEFAULT) syntax"""
+
+    supports_empty_insert: bool
+    """dialect supports INSERT () VALUES ()"""
+
+    supports_multivalues_insert: bool
+    """Target database supports INSERT...VALUES with multiple value
+    sets"""
 
     preexecute_autoincrement_sequences: bool
     """True if 'implicit' primary key functions must be executed separately
@@ -723,6 +764,12 @@ class Dialect(EventTarget):
       other backends.
     """
 
+    default_sequence_base: int
+    """the default value that will be rendered as the "START WITH" portion of
+    a CREATE SEQUENCE DDL statement.
+
+    """
+
     supports_native_enum: bool
     """Indicates if the dialect supports a native ENUM construct.
       This will prevent :class:`_types.Enum` from generating a CHECK
@@ -735,6 +782,10 @@ class Dialect(EventTarget):
       constraint when that type is used.
     """
 
+    supports_native_decimal: bool
+    """indicates if Decimal objects are handled and returned for precision
+    numeric types, or if floats are returned"""
+
     construct_arguments: Optional[
         List[Tuple[Type[ClauseElement], Mapping[str, Any]]]
     ] = None
@@ -842,6 +893,52 @@ class Dialect(EventTarget):
 
     """
 
+    label_length: Optional[int]
+    """optional user-defined max length for SQL labels"""
+
+    include_set_input_sizes: Optional[Set[Any]]
+    """set of DBAPI type objects that should be included in
+    automatic cursor.setinputsizes() calls.
+
+    This is only used if bind_typing is BindTyping.SET_INPUT_SIZES
+
+    """
+
+    exclude_set_input_sizes: Optional[Set[Any]]
+    """set of DBAPI type objects that should be excluded in
+    automatic cursor.setinputsizes() calls.
+
+    This is only used if bind_typing is BindTyping.SET_INPUT_SIZES
+
+    """
+
+    supports_simple_order_by_label: bool
+    """target database supports ORDER BY <labelname>, where <labelname>
+    refers to a label in the columns clause of the SELECT"""
+
+    div_is_floordiv: bool
+    """target database treats the / division operator as "floor division" """
+
+    tuple_in_values: bool
+    """target database supports tuple IN, i.e. (x, y) IN ((q, p), (r, z))"""
+
+    _bind_typing_render_casts: bool
+
+    supports_identity_columns: bool
+    """target database supports IDENTITY"""
+
+    cte_follows_insert: bool
+    """target database, when given a CTE with an INSERT statement, needs
+    the CTE to be below the INSERT"""
+
+    insert_executemany_returning: bool
+    """dialect / driver / database supports some means of providing RETURNING
+    support when dialect.do_executemany() is used.
+
+    """
+
+    _type_memos: MutableMapping[TypeEngine[Any], "_TypeMemoDict"]
+
     def _builtin_onconnect(self) -> Optional[_ListenerFnType]:
         raise NotImplementedError()
 
@@ -1495,7 +1592,7 @@ class Dialect(EventTarget):
     def is_disconnect(
         self,
         e: Exception,
-        connection: Optional[PoolProxiedConnection],
+        connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
         cursor: Optional[DBAPICursor],
     ) -> bool:
         """Return True if the given DB-API error indicates an invalid
lib/sqlalchemy/engine/processors.py
index 7a6a57c03ff2dc94f33112f0d03845dcfb624cd0..58792675299da65c0828c83cb75a00e8e04084bd 100644 (file)
@@ -20,23 +20,29 @@ from ._py_processors import str_to_datetime_processor_factory  # noqa
 from ..util._has_cy import HAS_CYEXTENSION
 
 if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
-    from ._py_processors import int_to_boolean  # noqa
-    from ._py_processors import str_to_date  # noqa
-    from ._py_processors import str_to_datetime  # noqa
-    from ._py_processors import str_to_time  # noqa
-    from ._py_processors import to_decimal_processor_factory  # noqa
-    from ._py_processors import to_float  # noqa
-    from ._py_processors import to_str  # noqa
+    from ._py_processors import int_to_boolean as int_to_boolean
+    from ._py_processors import str_to_date as str_to_date
+    from ._py_processors import str_to_datetime as str_to_datetime
+    from ._py_processors import str_to_time as str_to_time
+    from ._py_processors import (
+        to_decimal_processor_factory as to_decimal_processor_factory,
+    )
+    from ._py_processors import to_float as to_float
+    from ._py_processors import to_str as to_str
 else:
     from sqlalchemy.cyextension.processors import (
         DecimalResultProcessor,
     )  # noqa
-    from sqlalchemy.cyextension.processors import int_to_boolean  # noqa
-    from sqlalchemy.cyextension.processors import str_to_date  # noqa
-    from sqlalchemy.cyextension.processors import str_to_datetime  # noqa
-    from sqlalchemy.cyextension.processors import str_to_time  # noqa
-    from sqlalchemy.cyextension.processors import to_float  # noqa
-    from sqlalchemy.cyextension.processors import to_str  # noqa
+    from sqlalchemy.cyextension.processors import (
+        int_to_boolean as int_to_boolean,
+    )
+    from sqlalchemy.cyextension.processors import str_to_date as str_to_date
+    from sqlalchemy.cyextension.processors import (
+        str_to_datetime as str_to_datetime,
+    )
+    from sqlalchemy.cyextension.processors import str_to_time as str_to_time
+    from sqlalchemy.cyextension.processors import to_float as to_float
+    from sqlalchemy.cyextension.processors import to_str as to_str
 
     def to_decimal_processor_factory(target_class, scale):
         # Note that the scale argument is not taken into account for integer
lib/sqlalchemy/engine/result.py
index d428b8a9d4361846e54e252e792c7fff41ff45d1..05b06e84651f7463cf3fb5521507157326715d01 100644 (file)
@@ -46,6 +46,7 @@ else:
 if typing.TYPE_CHECKING:
     from .row import RowMapping
     from ..sql.schema import Column
+    from ..sql.type_api import _ResultProcessorType
 
 _KeyType = Union[str, "Column[Any]"]
 _KeyIndexType = Union[str, "Column[Any]", int]
@@ -70,8 +71,7 @@ across all the result types
 
 """
 
-_ProcessorType = Callable[[Any], Any]
-_ProcessorsType = Sequence[Optional[_ProcessorType]]
+_ProcessorsType = Sequence[Optional["_ResultProcessorType[Any]"]]
 _TupleGetterType = Callable[[Sequence[Any]], Tuple[Any, ...]]
 _UniqueFilterType = Callable[[Any], Any]
 _UniqueFilterStateType = Tuple[Set[Any], Optional[_UniqueFilterType]]
lib/sqlalchemy/engine/row.py
index ff63199d40e21434de281c38931bd59a5620167d..4ba39b55d6700f1e08b94f133a4aa34dc2edf478 100644 (file)
@@ -41,6 +41,7 @@ else:
 if typing.TYPE_CHECKING:
     from .result import _KeyType
     from .result import RMKeyView
+    from ..sql.type_api import _ResultProcessorType
 
 
 class Row(BaseRow, typing.Sequence[Any]):
@@ -105,7 +106,7 @@ class Row(BaseRow, typing.Sequence[Any]):
         )
 
     def _filter_on_values(
-        self, filters: Optional[Sequence[Optional[Callable[[Any], Any]]]]
+        self, filters: Optional[Sequence[Optional[_ResultProcessorType[Any]]]]
     ) -> Row:
         return Row(
             self._parent,
lib/sqlalchemy/pool/base.py
index 5a9d28a2899961f10e0aca4275f1436cb2d7f0df..9ab1477100445bae4b60ad109c44fcfbcdb12097 100644 (file)
@@ -82,7 +82,7 @@ class _ConnDialect:
     def do_close(self, dbapi_connection: DBAPIConnection) -> None:
         dbapi_connection.close()
 
-    def do_ping(self, dbapi_connection: DBAPIConnection) -> None:
+    def do_ping(self, dbapi_connection: DBAPIConnection) -> bool:
         raise NotImplementedError(
             "The ping feature requires that a dialect is "
             "passed to the connection pool."
lib/sqlalchemy/sql/_elements_constructors.py
index 770fbe40c982395bf8798f802096b2245e22c11e..aabd3871e11b2a10d8f889c03b68ff8a1ac32bd0 100644 (file)
@@ -58,7 +58,7 @@ if typing.TYPE_CHECKING:
 _T = TypeVar("_T")
 
 
-def all_(expr: _ColumnExpression[_T]) -> CollectionAggregate[_T]:
+def all_(expr: _ColumnExpression[_T]) -> CollectionAggregate[bool]:
     """Produce an ALL expression.
 
     For dialects such as that of PostgreSQL, this operator applies
@@ -173,7 +173,7 @@ def and_(*clauses: _ColumnExpression[bool]) -> BooleanClauseList:
     return BooleanClauseList.and_(*clauses)
 
 
-def any_(expr: _ColumnExpression[_T]) -> CollectionAggregate[_T]:
+def any_(expr: _ColumnExpression[_T]) -> CollectionAggregate[bool]:
     """Produce an ANY expression.
 
     For dialects such as that of PostgreSQL, this operator applies
lib/sqlalchemy/sql/base.py
index 29f9028c8b4a5589d1876433302a7ba265c9deec..6a6b389de8430d76c123e056ff548cbb64ccf949 100644 (file)
@@ -1132,10 +1132,12 @@ class SchemaEventTarget:
 
     """
 
-    def _set_parent(self, parent, **kw):
+    def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
         """Associate with this SchemaEvent's parent object."""
 
-    def _set_parent_with_dispatch(self, parent, **kw):
+    def _set_parent_with_dispatch(
+        self, parent: SchemaEventTarget, **kw: Any
+    ) -> None:
         self.dispatch.before_parent_attach(self, parent)
         self._set_parent(parent, **kw)
         self.dispatch.after_parent_attach(self, parent)
lib/sqlalchemy/sql/compiler.py
index 8f878b66c08a5614e502920d94e5dc04909d4f26..f8019b9c64e242f46509503fcdfdddc733b8f08d 100644 (file)
@@ -43,6 +43,7 @@ from typing import List
 from typing import Mapping
 from typing import MutableMapping
 from typing import NamedTuple
+from typing import NoReturn
 from typing import Optional
 from typing import Sequence
 from typing import Set
@@ -51,6 +52,7 @@ from typing import Type
 from typing import TYPE_CHECKING
 from typing import Union
 
+from sqlalchemy.sql.ddl import DDLElement
 from . import base
 from . import coercions
 from . import crud
@@ -61,7 +63,9 @@ from . import schema
 from . import selectable
 from . import sqltypes
 from .base import _from_objects
+from .base import Executable
 from .base import NO_ARG
+from .elements import ClauseElement
 from .elements import quoted_name
 from .schema import Column
 from .sqltypes import TupleType
@@ -78,6 +82,10 @@ if typing.TYPE_CHECKING:
     from .base import _AmbiguousTableNameMap
     from .base import CompileState
     from .cache_key import CacheKey
+    from .dml import Insert
+    from .dml import UpdateBase
+    from .dml import ValuesBase
+    from .elements import _truncated_label
     from .elements import BindParameter
     from .elements import ColumnClause
     from .elements import Label
@@ -91,12 +99,13 @@ if typing.TYPE_CHECKING:
     from .selectable import ReturnsRows
     from .selectable import Select
     from .selectable import SelectState
+    from .type_api import _BindProcessorType
     from ..engine.cursor import CursorResultMetaData
     from ..engine.interfaces import _CoreSingleExecuteParams
     from ..engine.interfaces import _ExecuteOptions
     from ..engine.interfaces import _MutableCoreSingleExecuteParams
     from ..engine.interfaces import _SchemaTranslateMapType
-    from ..engine.result import _ProcessorType
+    from ..engine.interfaces import Dialect
 
 _FromHintsType = Dict["FromClause", str]
 
@@ -378,7 +387,7 @@ class _CompilerStackEntry(_BaseCompilerStackEntry, total=False):
 class ExpandedState(NamedTuple):
     statement: str
     additional_parameters: _CoreSingleExecuteParams
-    processors: Mapping[str, _ProcessorType]
+    processors: Mapping[str, _BindProcessorType[Any]]
     positiontup: Optional[Sequence[str]]
     parameter_expansion: Mapping[str, List[str]]
 
@@ -531,11 +540,11 @@ class Compiled:
 
     def __init__(
         self,
-        dialect,
-        statement,
-        schema_translate_map=None,
-        render_schema_translate=False,
-        compile_kwargs=util.immutabledict(),
+        dialect: Dialect,
+        statement: Optional[ClauseElement],
+        schema_translate_map: Optional[_SchemaTranslateMapType] = None,
+        render_schema_translate: bool = False,
+        compile_kwargs: Mapping[str, Any] = util.immutabledict(),
     ):
         """Construct a new :class:`.Compiled` object.
 
@@ -571,6 +580,8 @@ class Compiled:
             self.can_execute = statement.supports_execution
             self._annotations = statement._annotations
             if self.can_execute:
+                if TYPE_CHECKING:
+                    assert isinstance(statement, Executable)
                 self.execution_options = statement._execution_options
             self.string = self.process(self.statement, **compile_kwargs)
 
@@ -636,10 +647,10 @@ class TypeCompiler(util.EnsureKWArg):
 
     ensure_kwarg = r"visit_\w+"
 
-    def __init__(self, dialect):
+    def __init__(self, dialect: Dialect):
         self.dialect = dialect
 
-    def process(self, type_, **kw):
+    def process(self, type_: TypeEngine[Any], **kw: Any) -> str:
         if (
             type_._variant_mapping
             and self.dialect.name in type_._variant_mapping
@@ -647,7 +658,9 @@ class TypeCompiler(util.EnsureKWArg):
             type_ = type_._variant_mapping[self.dialect.name]
         return type_._compiler_dispatch(self, **kw)
 
-    def visit_unsupported_compilation(self, element, err, **kw):
+    def visit_unsupported_compilation(
+        self, element: Any, err: Exception, **kw: Any
+    ) -> NoReturn:
         raise exc.UnsupportedCompilationError(self, element) from err
 
 
@@ -877,13 +890,13 @@ class SQLCompiler(Compiled):
 
     def __init__(
         self,
-        dialect,
-        statement,
-        cache_key=None,
-        column_keys=None,
-        for_executemany=False,
-        linting=NO_LINTING,
-        **kwargs,
+        dialect: Dialect,
+        statement: Optional[ClauseElement],
+        cache_key: Optional[CacheKey] = None,
+        column_keys: Optional[Sequence[str]] = None,
+        for_executemany: bool = False,
+        linting: Linting = NO_LINTING,
+        **kwargs: Any,
     ):
         """Construct a new :class:`.SQLCompiler` object.
 
@@ -954,15 +967,21 @@ class SQLCompiler(Compiled):
 
         # a map which tracks "truncated" names based on
         # dialect.label_length or dialect.max_identifier_length
-        self.truncated_names = {}
+        self.truncated_names: Dict[Tuple[str, str], str] = {}
+        self._truncated_counters: Dict[str, int] = {}
 
         Compiled.__init__(self, dialect, statement, **kwargs)
 
         if self.isinsert or self.isupdate or self.isdelete:
+            if TYPE_CHECKING:
+                assert isinstance(statement, UpdateBase)
+
             if statement._returning:
                 self.returning = statement._returning
 
             if self.isinsert or self.isupdate:
+                if TYPE_CHECKING:
+                    assert isinstance(statement, ValuesBase)
                 if statement._inline:
                     self.inline = True
                 elif self.for_executemany and (
@@ -1082,9 +1101,14 @@ class SQLCompiler(Compiled):
     @util.memoized_property
     def _bind_processors(
         self,
-    ) -> MutableMapping[str, Union[_ProcessorType, Sequence[_ProcessorType]]]:
+    ) -> MutableMapping[
+        str, Union[_BindProcessorType[Any], Sequence[_BindProcessorType[Any]]]
+    ]:
+
+        # mypy is not able to see the two value types as the above Union,
+        # it just sees "object".  don't know how to resolve
         return dict(
-            (key, value)
+            (key, value)  # type: ignore
             for key, value in (
                 (
                     self.bind_names[bindparam],
@@ -1301,12 +1325,14 @@ class SQLCompiler(Compiled):
             positiontup = None
 
         processors = self._bind_processors
-        single_processors = cast("Mapping[str, _ProcessorType]", processors)
+        single_processors = cast(
+            "Mapping[str, _BindProcessorType[Any]]", processors
+        )
         tuple_processors = cast(
-            "Mapping[str, Sequence[_ProcessorType]]", processors
+            "Mapping[str, Sequence[_BindProcessorType[Any]]]", processors
         )
 
-        new_processors: Dict[str, _ProcessorType] = {}
+        new_processors: Dict[str, _BindProcessorType[Any]] = {}
 
         if self.positional and self._numeric_binds:
             # I'm not familiar with any DBAPI that uses 'numeric'.
@@ -1484,6 +1510,10 @@ class SQLCompiler(Compiled):
         result = util.preloaded.engine_result
 
         param_key_getter = self._within_exec_param_key_getter
+
+        if TYPE_CHECKING:
+            assert isinstance(self.statement, Insert)
+
         table = self.statement.table
 
         getters = [
@@ -1530,6 +1560,9 @@ class SQLCompiler(Compiled):
         else:
             result = util.preloaded.engine_result
 
+        if TYPE_CHECKING:
+            assert isinstance(self.statement, Insert)
+
         param_key_getter = self._within_exec_param_key_getter
         table = self.statement.table
 
@@ -1796,7 +1829,9 @@ class SQLCompiler(Compiled):
     def visit_typeclause(self, typeclause, **kw):
         kw["type_expression"] = typeclause
         kw["identifier_preparer"] = self.preparer
-        return self.dialect.type_compiler.process(typeclause.type, **kw)
+        return self.dialect.type_compiler_instance.process(
+            typeclause.type, **kw
+        )
 
     def post_process_text(self, text):
         if self.preparer._double_percents:
@@ -2855,26 +2890,28 @@ class SQLCompiler(Compiled):
 
         return bind_name
 
-    def _truncated_identifier(self, ident_class, name):
+    def _truncated_identifier(
+        self, ident_class: str, name: _truncated_label
+    ) -> str:
         if (ident_class, name) in self.truncated_names:
             return self.truncated_names[(ident_class, name)]
 
         anonname = name.apply_map(self.anon_map)
 
         if len(anonname) > self.label_length - 6:
-            counter = self.truncated_names.get(ident_class, 1)
+            counter = self._truncated_counters.get(ident_class, 1)
             truncname = (
                 anonname[0 : max(self.label_length - 6, 0)]
                 + "_"
                 + hex(counter)[2:]
             )
-            self.truncated_names[ident_class] = counter + 1
+            self._truncated_counters[ident_class] = counter + 1
         else:
             truncname = anonname
         self.truncated_names[(ident_class, name)] = truncname
         return truncname
 
-    def _anonymize(self, name):
+    def _anonymize(self, name: str) -> str:
         return name % self.anon_map
 
     def bindparam_string(
@@ -3221,7 +3258,7 @@ class SQLCompiler(Compiled):
                         % (
                             self.preparer.quote(col.name),
                             " %s"
-                            % self.dialect.type_compiler.process(
+                            % self.dialect.type_compiler_instance.process(
                                 col.type, **kwargs
                             )
                             if alias._render_derived_w_types
@@ -4685,6 +4722,18 @@ class StrSQLCompiler(SQLCompiler):
 
 
 class DDLCompiler(Compiled):
+    if TYPE_CHECKING:
+
+        def __init__(
+            self,
+            dialect: Dialect,
+            statement: DDLElement,
+            schema_translate_map: Optional[_SchemaTranslateMapType] = ...,
+            render_schema_translate: bool = ...,
+            compile_kwargs: Mapping[str, Any] = ...,
+        ):
+            ...
+
     @util.memoized_property
     def sql_compiler(self):
         return self.dialect.statement_compiler(
@@ -4693,7 +4742,7 @@ class DDLCompiler(Compiled):
 
     @util.memoized_property
     def type_compiler(self):
-        return self.dialect.type_compiler
+        return self.dialect.type_compiler_instance
 
     def construct_params(
         self,
@@ -5010,7 +5059,7 @@ class DDLCompiler(Compiled):
         colspec = (
             self.preparer.format_column(column)
             + " "
-            + self.dialect.type_compiler.process(
+            + self.dialect.type_compiler_instance.process(
                 column.type, type_expression=column
             )
         )
lib/sqlalchemy/sql/dml.py
index 96e90b0ea183e9d4f14d3062c2cca85b2a6eb4b2..1271c5977c20f9e95b53e3d1100c85c5a6b98ce3 100644 (file)
@@ -433,7 +433,7 @@ class ValuesBase(UpdateBase):
     _multi_values = ()
     _ordered_values = None
     _select_names = None
-
+    _inline: bool = False
     _returning = ()
 
     def __init__(self, table):
@@ -742,7 +742,6 @@ class Insert(ValuesBase):
 
     select = None
     include_insert_from_select_defaults = False
-    _inline = False
 
     is_insert = True
 
@@ -959,7 +958,6 @@ class Update(DMLWhereBase, ValuesBase):
 
     is_update = True
     _preserve_parameter_order = False
-    _inline = False
 
     _traverse_internals = (
         [
lib/sqlalchemy/sql/elements.py
index 696d3c6f2e128240f09901e6c405c25f90d55a16..48c3c3be665628ee656bed990ac0bc143fc922c9 100644 (file)
@@ -258,6 +258,8 @@ class CompilerElement(Visitable):
         """Return a compiler appropriate for this ClauseElement, given a
         Dialect."""
 
+        if TYPE_CHECKING:
+            assert isinstance(self, ClauseElement)
         return dialect.statement_compiler(dialect, self, **kw)
 
     def __str__(self) -> str:
@@ -663,6 +665,11 @@ class DQLDMLClauseElement(ClauseElement):
 
     if typing.TYPE_CHECKING:
 
+        def _compiler(self, dialect: Dialect, **kw: Any) -> SQLCompiler:
+            """Return a compiler appropriate for this ClauseElement, given a
+            Dialect."""
+            ...
+
         def compile(  # noqa: A001
             self,
             bind: Optional[Union[Engine, Connection]] = None,
@@ -671,9 +678,6 @@ class DQLDMLClauseElement(ClauseElement):
         ) -> SQLCompiler:
             ...
 
-        def _compiler(self, dialect: Dialect, **kw: Any) -> SQLCompiler:
-            ...
-
 
 class CompilerColumnElement(
     roles.DMLColumnRole,
@@ -1272,6 +1276,12 @@ class ColumnElement(
 
     _alt_names: Sequence[str] = ()
 
+    @overload
+    def self_group(
+        self: ColumnElement[_T], against: Optional[OperatorType] = None
+    ) -> ColumnElement[_T]:
+        ...
+
     @overload
     def self_group(
         self: ColumnElement[bool], against: Optional[OperatorType] = None
@@ -1280,8 +1290,8 @@ class ColumnElement(
 
     @overload
     def self_group(
-        self: ColumnElement[_T], against: Optional[OperatorType] = None
-    ) -> ColumnElement[_T]:
+        self: ColumnElement[Any], against: Optional[OperatorType] = None
+    ) -> ColumnElement[Any]:
         ...
 
     def self_group(
@@ -1777,7 +1787,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]):
             value = None
 
         if quote is not None:
-            key = quoted_name(key, quote)
+            key = quoted_name.construct(key, quote)
 
         if unique:
             self.key = _anonymous_label.safe_construct(
@@ -3121,7 +3131,11 @@ class UnaryExpression(ColumnElement[_T]):
         self.element = element.self_group(
             against=self.operator or self.modifier
         )
-        self.type: TypeEngine[_T] = type_api.to_instance(type_)
+
+        # if type is None, we get NULLTYPE, which is our _T.  But I don't
+        # know how to get the overloads to express that correctly
+        self.type = type_api.to_instance(type_)  # type: ignore
+
         self.wraps_column_expression = wraps_column_expression
 
     @classmethod
@@ -3224,27 +3238,32 @@ class CollectionAggregate(UnaryExpression[_T]):
     @classmethod
     def _create_any(
         cls, expr: _ColumnExpression[_T]
-    ) -> CollectionAggregate[_T]:
-        expr = coercions.expect(roles.ExpressionElementRole, expr)
-
-        expr = expr.self_group()
-        return CollectionAggregate(
+    ) -> CollectionAggregate[bool]:
+        col_expr = coercions.expect(
+            roles.ExpressionElementRole,
             expr,
+        )
+        col_expr = col_expr.self_group()
+        return CollectionAggregate(
+            col_expr,
             operator=operators.any_op,
-            type_=type_api.NULLTYPE,
+            type_=type_api.BOOLEANTYPE,
             wraps_column_expression=False,
         )
 
     @classmethod
     def _create_all(
         cls, expr: _ColumnExpression[_T]
-    ) -> CollectionAggregate[_T]:
-        expr = coercions.expect(roles.ExpressionElementRole, expr)
-        expr = expr.self_group()
-        return CollectionAggregate(
+    ) -> CollectionAggregate[bool]:
+        col_expr = coercions.expect(
+            roles.ExpressionElementRole,
             expr,
+        )
+        col_expr = col_expr.self_group()
+        return CollectionAggregate(
+            col_expr,
             operator=operators.all_op,
-            type_=type_api.NULLTYPE,
+            type_=type_api.BOOLEANTYPE,
             wraps_column_expression=False,
         )
 
@@ -3347,7 +3366,11 @@ class BinaryExpression(ColumnElement[_T]):
         self.left = left.self_group(against=operator)
         self.right = right.self_group(against=operator)
         self.operator = operator
-        self.type: TypeEngine[_T] = type_api.to_instance(type_)
+
+        # if type is None, we get NULLTYPE, which is our _T.  But I don't
+        # know how to get the overloads to express that correctly
+        self.type = type_api.to_instance(type_)  # type: ignore
+
         self.negate = negate
         self._is_implicitly_boolean = operators.is_boolean(operator)
 
@@ -3509,7 +3532,9 @@ class Grouping(GroupedElement, ColumnElement[_T]):
         self, element: Union[TextClause, ClauseList, ColumnElement[_T]]
     ):
         self.element = element
-        self.type = getattr(element, "type", type_api.NULLTYPE)
+
+        # nulltype assignment issue
+        self.type = getattr(element, "type", type_api.NULLTYPE)  # type: ignore
 
     def _with_binary_element_type(self, type_):
         return self.__class__(self.element._with_binary_element_type(type_))
@@ -3926,10 +3951,13 @@ class Label(roles.LabeledColumnExprRole[_T], ColumnElement[_T]):
 
         self.key = self._tq_label = self._tq_key_label = self.name
         self._element = element
-        # self._type = type_
-        self.type = type_api.to_instance(
-            type_ or getattr(self._element, "type", None)
+
+        self.type = (
+            type_api.to_instance(type_)
+            if type_ is not None
+            else self._element.type
         )
+
         self._proxies = [element]
 
     def __reduce__(self):
@@ -4178,7 +4206,11 @@ class ColumnClause(
     ):
         self.key = self.name = text
         self.table = _selectable
-        self.type: TypeEngine[_T] = type_api.to_instance(type_)
+
+        # if type is None, we get NULLTYPE, which is our _T.  But I don't
+        # know how to get the overloads to express that correctly
+        self.type = type_api.to_instance(type_)  # type: ignore
+
         self.is_literal = is_literal
 
     def get_children(self, column_tables=False, **kw):
@@ -4465,19 +4497,32 @@ class quoted_name(util.MemoizedSlots, str):
 
     quote: Optional[bool]
 
-    def __new__(cls, value, quote):
+    @overload
+    @classmethod
+    def construct(cls, value: str, quote: Optional[bool]) -> quoted_name:
+        ...
+
+    @overload
+    @classmethod
+    def construct(cls, value: None, quote: Optional[bool]) -> None:
+        ...
+
+    @classmethod
+    def construct(
+        cls, value: Optional[str], quote: Optional[bool]
+    ) -> Optional[quoted_name]:
         if value is None:
             return None
-        # experimental - don't bother with quoted_name
-        # if quote flag is None.  doesn't seem to make any dent
-        # in performance however
-        # elif not sprcls and quote is None:
-        #   return value
-        elif isinstance(value, cls) and (
-            quote is None or value.quote == quote
-        ):
+        else:
+            return quoted_name(value, quote)
+
+    def __new__(cls, value: str, quote: Optional[bool]) -> quoted_name:
+        assert (
+            value is not None
+        ), "use quoted_name.construct() for None passthrough"
+        if isinstance(value, cls) and (quote is None or value.quote == quote):
             return value
-        self = super(quoted_name, cls).__new__(cls, value)
+        self = super().__new__(cls, value)
 
         self.quote = quote
         return self
@@ -4579,15 +4624,15 @@ class _truncated_label(quoted_name):
 
     __slots__ = ()
 
-    def __new__(cls, value, quote=None):
+    def __new__(cls, value: str, quote: Optional[bool] = None) -> Any:
         quote = getattr(value, "quote", quote)
         # return super(_truncated_label, cls).__new__(cls, value, quote, True)
         return super(_truncated_label, cls).__new__(cls, value, quote)
 
-    def __reduce__(self):
+    def __reduce__(self) -> Any:
         return self.__class__, (str(self), self.quote)
 
-    def apply_map(self, map_):
+    def apply_map(self, map_: Mapping[str, Any]) -> str:
         return self
 
 
lib/sqlalchemy/sql/schema.py
index 78d5241278da55f4e977b8b6117449ad2f884920..5cfb55603feea6d66f0d0bc132c595a43082bb06 100644 (file)
@@ -3077,7 +3077,7 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator):
         elif metadata is not None and schema is None and metadata.schema:
             self.schema = schema = metadata.schema
         else:
-            self.schema = quoted_name(schema, quote_schema)
+            self.schema = quoted_name.construct(schema, quote_schema)
         self.metadata = metadata
         self._key = _get_table_key(name, schema)
         if metadata:
@@ -4258,7 +4258,7 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
         """
         self.table = table = None
 
-        self.name = quoted_name(name, kw.pop("quote", None))
+        self.name = quoted_name.construct(name, kw.pop("quote", None))
         self.unique = kw.pop("unique", False)
         _column_flag = kw.pop("_column_flag", False)
         if "info" in kw:
@@ -4493,7 +4493,7 @@ class MetaData(HasSchemaAttr):
 
         """
         self.tables = util.FacadeDict()
-        self.schema = quoted_name(schema, quote_schema)
+        self.schema = quoted_name.construct(schema, quote_schema)
         self.naming_convention = (
             naming_convention
             if naming_convention
lib/sqlalchemy/sql/sqltypes.py
index 4d0169370c4924db2ab0ca1f6b1cc0f9c1817bbd..829c1b72e7c1690c3e82d7aac59731ea4c60f276 100644 (file)
@@ -17,9 +17,15 @@ import enum
 import json
 import pickle
 from typing import Any
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Optional
 from typing import overload
 from typing import Sequence
 from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
@@ -40,6 +46,7 @@ from .type_api import NativeForEmulated  # noqa
 from .type_api import to_instance
 from .type_api import TypeDecorator
 from .type_api import TypeEngine
+from .type_api import TypeEngineMixin
 from .type_api import Variant  # noqa
 from .visitors import InternalTraversal
 from .. import event
@@ -51,11 +58,19 @@ from ..util import langhelpers
 from ..util import OrderedDict
 from ..util.typing import Literal
 
+if TYPE_CHECKING:
+    from .operators import OperatorType
+    from .type_api import _BindProcessorType
+    from .type_api import _ComparatorFactory
+    from .type_api import _ResultProcessorType
+    from ..engine.interfaces import Dialect
 
 _T = TypeVar("_T", bound="Any")
+_CT = TypeVar("_CT", bound=Any)
+_TE = TypeVar("_TE", bound="TypeEngine[Any]")
 
 
-class _LookupExpressionAdapter:
+class HasExpressionLookup(TypeEngineMixin):
 
     """Mixin expression adaptations based on lookup tables.
 
@@ -68,11 +83,18 @@ class _LookupExpressionAdapter:
     def _expression_adaptations(self):
         raise NotImplementedError()
 
-    class Comparator(TypeEngine.Comparator[_T]):
-        _blank_dict = util.immutabledict()
+    class Comparator(TypeEngine.Comparator[_CT]):
+
+        _blank_dict = util.EMPTY_DICT
 
-        def _adapt_expression(self, op, other_comparator):
+        def _adapt_expression(
+            self,
+            op: OperatorType,
+            other_comparator: TypeEngine.Comparator[Any],
+        ) -> Tuple[OperatorType, TypeEngine[Any]]:
             othertype = other_comparator.type._type_affinity
+            if TYPE_CHECKING:
+                assert isinstance(self.type, HasExpressionLookup)
             lookup = self.type._expression_adaptations.get(
                 op, self._blank_dict
             ).get(othertype, self.type)
@@ -83,16 +105,20 @@ class _LookupExpressionAdapter:
             else:
                 return (op, to_instance(lookup))
 
-    comparator_factory = Comparator
+    comparator_factory: _ComparatorFactory[Any] = Comparator
 
 
-class Concatenable:
+class Concatenable(TypeEngineMixin):
 
     """A mixin that marks a type as supporting 'concatenation',
     typically strings."""
 
     class Comparator(TypeEngine.Comparator[_T]):
-        def _adapt_expression(self, op, other_comparator):
+        def _adapt_expression(
+            self,
+            op: OperatorType,
+            other_comparator: TypeEngine.Comparator[Any],
+        ) -> Tuple[OperatorType, TypeEngine[Any]]:
             if op is operators.add and isinstance(
                 other_comparator,
                 (Concatenable.Comparator, NullType.Comparator),
@@ -103,10 +129,10 @@ class Concatenable:
                     op, other_comparator
                 )
 
-    comparator_factory = Comparator
+    comparator_factory: _ComparatorFactory[Any] = Comparator
 
 
-class Indexable:
+class Indexable(TypeEngineMixin):
     """A mixin that marks a type as supporting indexing operations,
     such as array or JSON structures.
 
@@ -151,7 +177,7 @@ class String(Concatenable, TypeEngine[str]):
         # note pylance appears to require the "self" type in a constructor
         # for the _T type to be correctly recognized when we send the
         # class as the argument, e.g. `column("somecol", String)`
-        self: "String",
+        self,
         length=None,
         collation=None,
     ):
@@ -313,7 +339,7 @@ class UnicodeText(Text):
         super(UnicodeText, self).__init__(length=length, **kwargs)
 
 
-class Integer(_LookupExpressionAdapter, TypeEngine[int]):
+class Integer(HasExpressionLookup, TypeEngine[int]):
 
     """A type for ``int`` integers."""
 
@@ -378,7 +404,7 @@ class BigInteger(Integer):
 _N = TypeVar("_N", bound=Union[decimal.Decimal, float])
 
 
-class Numeric(_LookupExpressionAdapter, TypeEngine[_N]):
+class Numeric(HasExpressionLookup, TypeEngine[_N]):
 
     """A type for fixed precision numbers, such as ``NUMERIC`` or ``DECIMAL``.
 
@@ -423,7 +449,7 @@ class Numeric(_LookupExpressionAdapter, TypeEngine[_N]):
     _default_decimal_return_scale = 10
 
     def __init__(
-        self: "Numeric",
+        self,
         precision=None,
         scale=None,
         decimal_return_scale=None,
@@ -573,34 +599,26 @@ class Float(Numeric[_N]):
     @overload
     def __init__(
         self: Float[float],
-        precision=...,
-        decimal_return_scale=...,
+        precision: Optional[int] = ...,
+        asdecimal: Literal[False] = ...,
+        decimal_return_scale: Optional[int] = ...,
     ):
         ...
 
     @overload
     def __init__(
         self: Float[decimal.Decimal],
-        precision=...,
+        precision: Optional[int] = ...,
         asdecimal: Literal[True] = ...,
-        decimal_return_scale=...,
-    ):
-        ...
-
-    @overload
-    def __init__(
-        self: Float[float],
-        precision=...,
-        asdecimal: Literal[False] = ...,
-        decimal_return_scale=...,
+        decimal_return_scale: Optional[int] = ...,
     ):
         ...
 
     def __init__(
         self: Float[_N],
-        precision=None,
-        asdecimal=False,
-        decimal_return_scale=None,
+        precision: Optional[int] = None,
+        asdecimal: bool = False,
+        decimal_return_scale: Optional[int] = None,
     ):
         r"""
         Construct a Float.
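
A quick illustration (not part of the patch) of what the Literal-based overloads above are meant to provide under a type checker: the asdecimal flag selects the generic parameter, while runtime behavior is unchanged.

from sqlalchemy import Float

f_plain = Float(53)                    # a checker infers Float[float]
f_decimal = Float(53, asdecimal=True)  # a checker infers Float[decimal.Decimal]

# at runtime the overloads add nothing; the flag is simply stored
assert f_plain.asdecimal is False
assert f_decimal.asdecimal is True
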
@@ -662,7 +680,7 @@ class Float(Numeric[_N]):
             return None
 
 
-class Double(Float):
+class Double(Float[_N]):
     """A type for double ``FLOAT`` floating point types.
 
     Typically generates a ``DOUBLE`` or ``DOUBLE_PRECISION`` in DDL,
@@ -676,7 +694,7 @@ class Double(Float):
     __visit_name__ = "double"
 
 
-class DateTime(_LookupExpressionAdapter, TypeEngine[dt.datetime]):
+class DateTime(HasExpressionLookup, TypeEngine[dt.datetime]):
 
     """A type for ``datetime.datetime()`` objects.
 
@@ -738,7 +756,7 @@ class DateTime(_LookupExpressionAdapter, TypeEngine[dt.datetime]):
         }
 
 
-class Date(_LookupExpressionAdapter, TypeEngine[dt.date]):
+class Date(HasExpressionLookup, TypeEngine[dt.date]):
 
     """A type for ``datetime.date()`` objects."""
 
@@ -776,7 +794,7 @@ class Date(_LookupExpressionAdapter, TypeEngine[dt.date]):
         }
 
 
-class Time(_LookupExpressionAdapter, TypeEngine[dt.time]):
+class Time(HasExpressionLookup, TypeEngine[dt.time]):
 
     """A type for ``datetime.time()`` objects."""
 
@@ -895,9 +913,10 @@ class LargeBinary(_Binary):
         _Binary.__init__(self, length=length)
 
 
-class SchemaType(SchemaEventTarget):
+class SchemaType(SchemaEventTarget, TypeEngineMixin):
 
-    """Mark a type as possibly requiring schema-level DDL for usage.
+    """Add capabilities to a type which allow for schema-level DDL to be
+    associated with a type.
 
     Supports types that must be explicitly created/dropped (i.e. PG ENUM type)
     as well as types that are complemented by table or schema level
@@ -920,6 +939,8 @@ class SchemaType(SchemaEventTarget):
 
     _use_schema_map = True
 
+    name: Optional[str]
+
     def __init__(
         self,
         name=None,
@@ -1021,33 +1042,37 @@ class SchemaType(SchemaEventTarget):
             )
 
     def copy(self, **kw):
-        return self.adapt(self.__class__, _create_events=True)
-
-    def adapt(self, impltype, **kw):
-        schema = kw.pop("schema", self.schema)
-        metadata = kw.pop("metadata", self.metadata)
-        _create_events = kw.pop("_create_events", False)
-        return impltype(
-            name=self.name,
-            schema=schema,
-            inherit_schema=self.inherit_schema,
-            metadata=metadata,
-            _create_events=_create_events,
-            **kw,
+        return self.adapt(
+            cast("Type[TypeEngine[Any]]", self.__class__),
+            _create_events=True,
         )
 
+    @overload
+    def adapt(self, cls: Type[_TE], **kw: Any) -> _TE:
+        ...
+
+    @overload
+    def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]:
+        ...
+
+    def adapt(
+        self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any
+    ) -> TypeEngine[Any]:
+        kw.setdefault("_create_events", False)
+        return super().adapt(cls, **kw)
+
     def create(self, bind, checkfirst=False):
         """Issue CREATE DDL for this type, if applicable."""
 
         t = self.dialect_impl(bind.dialect)
-        if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+        if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
             t.create(bind, checkfirst=checkfirst)
 
     def drop(self, bind, checkfirst=False):
         """Issue DROP DDL for this type, if applicable."""
 
         t = self.dialect_impl(bind.dialect)
-        if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+        if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
             t.drop(bind, checkfirst=checkfirst)
 
     def _on_table_create(self, target, bind, **kw):
@@ -1055,7 +1080,7 @@ class SchemaType(SchemaEventTarget):
             return
 
         t = self.dialect_impl(bind.dialect)
-        if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+        if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
             t._on_table_create(target, bind, **kw)
 
     def _on_table_drop(self, target, bind, **kw):
@@ -1063,7 +1088,7 @@ class SchemaType(SchemaEventTarget):
             return
 
         t = self.dialect_impl(bind.dialect)
-        if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+        if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
             t._on_table_drop(target, bind, **kw)
 
     def _on_metadata_create(self, target, bind, **kw):
@@ -1071,7 +1096,7 @@ class SchemaType(SchemaEventTarget):
             return
 
         t = self.dialect_impl(bind.dialect)
-        if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+        if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
             t._on_metadata_create(target, bind, **kw)
 
     def _on_metadata_drop(self, target, bind, **kw):
@@ -1079,7 +1104,7 @@ class SchemaType(SchemaEventTarget):
             return
 
         t = self.dialect_impl(bind.dialect)
-        if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+        if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
             t._on_metadata_drop(target, bind, **kw)
 
     def _is_impl_for_variant(self, dialect, kw):
@@ -1112,7 +1137,7 @@ class SchemaType(SchemaEventTarget):
             return _we_are_the_impl(variant_mapping["_default"])
 
 
-class Enum(Emulated, String, TypeEngine[Union[str, enum.Enum]], SchemaType):
+class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
     """Generic Enum Type.
 
     The :class:`.Enum` type provides a set of possible string values
@@ -1464,8 +1489,14 @@ class Enum(Emulated, String, TypeEngine[Union[str, enum.Enum]], SchemaType):
                     )
                 ) from err
 
-    class Comparator(String.Comparator[_T]):
-        def _adapt_expression(self, op, other_comparator):
+    class Comparator(String.Comparator[str]):
+        type: String
+
+        def _adapt_expression(
+            self,
+            op: OperatorType,
+            other_comparator: TypeEngine.Comparator[Any],
+        ) -> Tuple[OperatorType, TypeEngine[Any]]:
             op, typ = super(Enum.Comparator, self)._adapt_expression(
                 op, other_comparator
             )
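
For orientation, a usage sketch of the schema-level DDL that SchemaType contributes (the database URL is a placeholder; on backends with no native enum type the calls are effectively no-ops).

from sqlalchemy import Enum, create_engine

engine = create_engine("postgresql://scott:tiger@localhost/test")  # placeholder

status = Enum("draft", "published", name="doc_status")

# SchemaType.create()/drop() hand off to the dialect-level impl
# (e.g. postgresql.ENUM), which emits CREATE TYPE / DROP TYPE as needed
status.create(engine, checkfirst=True)
status.drop(engine, checkfirst=True)
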
@@ -1663,15 +1694,16 @@ class PickleType(TypeDecorator[object]):
         return PickleType, (self.protocol, None, self.comparator)
 
     def bind_processor(self, dialect):
-        impl_processor = self.impl.bind_processor(dialect)
+        impl_processor = self.impl_instance.bind_processor(dialect)
         dumps = self.pickler.dumps
         protocol = self.protocol
         if impl_processor:
+            fixed_impl_processor = impl_processor
 
             def process(value):
                 if value is not None:
                     value = dumps(value, protocol)
-                return impl_processor(value)
+                return fixed_impl_processor(value)
 
         else:
 
@@ -1683,12 +1715,13 @@ class PickleType(TypeDecorator[object]):
         return process
 
     def result_processor(self, dialect, coltype):
-        impl_processor = self.impl.result_processor(dialect, coltype)
+        impl_processor = self.impl_instance.result_processor(dialect, coltype)
         loads = self.pickler.loads
         if impl_processor:
+            fixed_impl_processor = impl_processor
 
             def process(value):
-                value = impl_processor(value)
+                value = fixed_impl_processor(value)
                 if value is None:
                     return None
                 return loads(value)
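
The fixed_impl_processor re-binding above is purely a typing aid; reduced to its essence (illustrative helper, not SQLAlchemy API):

from typing import Callable, Optional


def make_process(
    impl_processor: Optional[Callable[[object], object]]
) -> Callable[[object], object]:
    if impl_processor:
        # bind the narrowed, non-Optional callable to a local name so the
        # nested function closes over that narrowed type rather than the
        # Optional argument
        fixed_impl_processor = impl_processor

        def process(value: object) -> object:
            return fixed_impl_processor(value)

    else:

        def process(value: object) -> object:
            return value

    return process
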
@@ -1709,7 +1742,7 @@ class PickleType(TypeDecorator[object]):
             return x == y
 
 
-class Boolean(Emulated, TypeEngine[bool], SchemaType):
+class Boolean(SchemaType, Emulated, TypeEngine[bool]):
 
     """A bool datatype.
 
@@ -1733,7 +1766,7 @@ class Boolean(Emulated, TypeEngine[bool], SchemaType):
     native = True
 
     def __init__(
-        self: "Boolean",
+        self,
         create_constraint=False,
         name=None,
         _create_events=True,
@@ -1818,6 +1851,9 @@ class Boolean(Emulated, TypeEngine[bool], SchemaType):
 
     def bind_processor(self, dialect):
         _strict_as_bool = self._strict_as_bool
+
+        _coerce: Union[Type[bool], Type[int]]
+
         if dialect.supports_native_boolean:
             _coerce = bool
         else:
@@ -1838,7 +1874,7 @@ class Boolean(Emulated, TypeEngine[bool], SchemaType):
             return processors.int_to_boolean
 
 
-class _AbstractInterval(_LookupExpressionAdapter, TypeEngine[dt.timedelta]):
+class _AbstractInterval(HasExpressionLookup, TypeEngine[dt.timedelta]):
     @util.memoized_property
     def _expression_adaptations(self):
         # Based on https://www.postgresql.org/docs/current/\
@@ -1856,16 +1892,12 @@ class _AbstractInterval(_LookupExpressionAdapter, TypeEngine[dt.timedelta]):
             operators.truediv: {Numeric: self.__class__},
         }
 
-    @property
-    def _type_affinity(self):
+    @util.non_memoized_property
+    def _type_affinity(self) -> Optional[Type[TypeEngine[Any]]]:
         return Interval
 
-    def coerce_compared_value(self, op, value):
-        """See :meth:`.TypeEngine.coerce_compared_value` for a description."""
-        return self.impl.coerce_compared_value(op, value)
 
-
-class Interval(Emulated, _AbstractInterval, TypeDecorator):
+class Interval(Emulated, _AbstractInterval, TypeDecorator[dt.timedelta]):
 
     """A type for ``datetime.timedelta()`` objects.
 
@@ -1909,6 +1941,14 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator):
         self.second_precision = second_precision
         self.day_precision = day_precision
 
+    class Comparator(
+        TypeDecorator.Comparator[_CT],
+        _AbstractInterval.Comparator[_CT],
+    ):
+        pass
+
+    comparator_factory = Comparator
+
     @property
     def python_type(self):
         return dt.timedelta
@@ -1916,42 +1956,63 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator):
     def adapt_to_emulated(self, impltype, **kw):
         return _AbstractInterval.adapt(self, impltype, **kw)
 
-    def bind_processor(self, dialect):
-        impl_processor = self.impl.bind_processor(dialect)
+    def coerce_compared_value(self, op, value):
+        return self.impl_instance.coerce_compared_value(op, value)
+
+    def bind_processor(
+        self, dialect: Dialect
+    ) -> _BindProcessorType[dt.timedelta]:
+        if TYPE_CHECKING:
+            assert isinstance(self.impl_instance, DateTime)
+        impl_processor = self.impl_instance.bind_processor(dialect)
         epoch = self.epoch
         if impl_processor:
+            fixed_impl_processor = impl_processor
 
-            def process(value):
+            def process(
+                value: Optional[dt.timedelta],
+            ) -> Any:
                 if value is not None:
-                    value = epoch + value
-                return impl_processor(value)
+                    dt_value = epoch + value
+                else:
+                    dt_value = None
+                return fixed_impl_processor(dt_value)
 
         else:
 
-            def process(value):
+            def process(
+                value: Optional[dt.timedelta],
+            ) -> Any:
                 if value is not None:
-                    value = epoch + value
-                return value
+                    dt_value = epoch + value
+                else:
+                    dt_value = None
+                return dt_value
 
         return process
 
-    def result_processor(self, dialect, coltype):
-        impl_processor = self.impl.result_processor(dialect, coltype)
+    def result_processor(
+        self, dialect: Dialect, coltype: Any
+    ) -> _ResultProcessorType[dt.timedelta]:
+        if TYPE_CHECKING:
+            assert isinstance(self.impl_instance, DateTime)
+        impl_processor = self.impl_instance.result_processor(dialect, coltype)
         epoch = self.epoch
         if impl_processor:
+            fixed_impl_processor = impl_processor
 
-            def process(value):
-                value = impl_processor(value)
-                if value is None:
+            def process(value: Any) -> Optional[dt.timedelta]:
+                dt_value = fixed_impl_processor(value)
+                if dt_value is None:
                     return None
-                return value - epoch
+                return dt_value - epoch
 
         else:
 
-            def process(value):
+            def process(value: Any) -> Optional[dt.timedelta]:
                 if value is None:
                     return None
-                return value - epoch
+                return value - epoch  # type: ignore
 
         return process
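
The conversion above in plain arithmetic (epoch below stands in for Interval.epoch; a sketch, not the implementation):

import datetime as dt

epoch = dt.datetime(1970, 1, 1)  # stand-in for Interval.epoch
delta = dt.timedelta(days=2, hours=3)

stored = epoch + delta           # what bind_processor passes to the DateTime impl
assert stored - epoch == delta   # what result_processor recovers on the way back
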
 
@@ -2233,7 +2294,7 @@ class JSON(Indexable, TypeEngine[Any]):
         """
         self.none_as_null = none_as_null
 
-    class JSONElementType(TypeEngine):
+    class JSONElementType(TypeEngine[Any]):
         """Common function for index / path elements in a JSON expression."""
 
         _integer = Integer()
@@ -2457,7 +2518,7 @@ class JSON(Indexable, TypeEngine[Any]):
     def python_type(self):
         return dict
 
-    @property
+    @property  # type: ignore  # mypy property bug
     def should_evaluate_none(self):
         """Alias of :attr:`_types.JSON.none_as_null`"""
         return not self.none_as_null
@@ -2632,23 +2693,26 @@ class ARRAY(
         """
 
         def _setup_getitem(self, index):
+
+            arr_type = cast(ARRAY, self.type)
+
             if isinstance(index, slice):
-                return_type = self.type
-                if self.type.zero_indexes:
+                return_type = arr_type
+                if arr_type.zero_indexes:
                     index = slice(index.start + 1, index.stop + 1, index.step)
                 slice_ = Slice(
                     index.start, index.stop, index.step, _name=self.expr.key
                 )
                 return operators.getitem, slice_, return_type
             else:
-                if self.type.zero_indexes:
+                if arr_type.zero_indexes:
                     index += 1
-                if self.type.dimensions is None or self.type.dimensions == 1:
-                    return_type = self.type.item_type
+                if arr_type.dimensions is None or arr_type.dimensions == 1:
+                    return_type = arr_type.item_type
                 else:
-                    adapt_kw = {"dimensions": self.type.dimensions - 1}
-                    return_type = self.type.adapt(
-                        self.type.__class__, **adapt_kw
+                    adapt_kw = {"dimensions": arr_type.dimensions - 1}
+                    return_type = arr_type.adapt(
+                        arr_type.__class__, **adapt_kw
                     )
 
                 return operators.getitem, index, return_type
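
A small sketch of the dimension bookkeeping performed by _setup_getitem above (column and names are arbitrary):

from sqlalchemy import Column, Integer
from sqlalchemy.types import ARRAY

matrix = Column("matrix", ARRAY(Integer, dimensions=2))

row = matrix[1]      # one index consumed: row.type is an ARRAY with dimensions=1
cell = matrix[1][2]  # all dimensions consumed: cell.type is the Integer item type
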
@@ -2853,7 +2917,7 @@ class TupleType(TypeEngine[Tuple[Any, ...]]):
         )
 
 
-class REAL(Float):
+class REAL(Float[_N]):
 
     """The SQL REAL type.
 
@@ -2866,7 +2930,7 @@ class REAL(Float):
     __visit_name__ = "REAL"
 
 
-class FLOAT(Float):
+class FLOAT(Float[_N]):
 
     """The SQL FLOAT type.
 
@@ -2879,7 +2943,7 @@ class FLOAT(Float):
     __visit_name__ = "FLOAT"
 
 
-class DOUBLE(Double):
+class DOUBLE(Double[_N]):
     """The SQL DOUBLE type.
 
     .. versionadded:: 2.0
@@ -2893,7 +2957,7 @@ class DOUBLE(Double):
     __visit_name__ = "DOUBLE"
 
 
-class DOUBLE_PRECISION(Double):
+class DOUBLE_PRECISION(Double[_N]):
     """The SQL DOUBLE PRECISION type.
 
     .. versionadded:: 2.0
@@ -2907,7 +2971,7 @@ class DOUBLE_PRECISION(Double):
     __visit_name__ = "DOUBLE_PRECISION"
 
 
-class NUMERIC(Numeric):
+class NUMERIC(Numeric[_N]):
 
     """The SQL NUMERIC type.
 
@@ -2920,7 +2984,7 @@ class NUMERIC(Numeric):
     __visit_name__ = "NUMERIC"
 
 
-class DECIMAL(Numeric):
+class DECIMAL(Numeric[_N]):
 
     """The SQL DECIMAL type.
 
@@ -3099,7 +3163,7 @@ class BOOLEAN(Boolean):
     __visit_name__ = "BOOLEAN"
 
 
-class NullType(TypeEngine):
+class NullType(TypeEngine[None]):
 
     """An unknown type.
 
@@ -3139,7 +3203,11 @@ class NullType(TypeEngine):
         return process
 
     class Comparator(TypeEngine.Comparator[_T]):
-        def _adapt_expression(self, op, other_comparator):
+        def _adapt_expression(
+            self,
+            op: OperatorType,
+            other_comparator: TypeEngine.Comparator[Any],
+        ) -> Tuple[OperatorType, TypeEngine[Any]]:
             if isinstance(
                 other_comparator, NullType.Comparator
             ) or not operators.is_commutative(op):
@@ -3150,7 +3218,7 @@ class NullType(TypeEngine):
     comparator_factory = Comparator
 
 
-class TableValueType(HasCacheKey, TypeEngine):
+class TableValueType(HasCacheKey, TypeEngine[Any]):
     """Refers to a table value type."""
 
     _is_table_value = True
@@ -3195,7 +3263,7 @@ _TIME = Time()
 _STRING = String()
 _UNICODE = Unicode()
 
-_type_map = {
+_type_map: Dict[Type[Any], TypeEngine[Any]] = {
     int: Integer(),
     float: Float(),
     bool: BOOLEANTYPE,
@@ -3204,7 +3272,7 @@ _type_map = {
     dt.datetime: _DATETIME,
     dt.time: _TIME,
     dt.timedelta: Interval(),
-    util.NoneType: NULLTYPE,
+    type(None): NULLTYPE,
     bytes: LargeBinary(),
     str: _STRING,
 }
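
This map backs type inference for plain Python literals; a quick sketch of the observable effect, including the util.NoneType -> type(None) change:

from sqlalchemy import Integer, literal
from sqlalchemy.types import NullType

expr = literal(5)            # resolved through _type_map
assert isinstance(expr.type, Integer)

none_expr = literal(None)    # type(None) maps to NULLTYPE
assert isinstance(none_expr.type, NullType)
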
@@ -3213,7 +3281,7 @@ _type_map = {
 _type_map_get = _type_map.get
 
 
-def _resolve_value_to_type(value):
+def _resolve_value_to_type(value: Any) -> TypeEngine[Any]:
     _result_type = _type_map_get(type(value), False)
     if _result_type is False:
         # use inspect() to detect SQLAlchemy built-in
@@ -3231,7 +3299,9 @@ def _resolve_value_to_type(value):
             )
         return NULLTYPE
     else:
-        return _result_type._resolve_for_literal(value)
+        return _result_type._resolve_for_literal(  # type: ignore [union-attr]
+            value
+        )
 
 
 # back-assign to type_api
@@ -3240,7 +3310,6 @@ type_api.STRINGTYPE = STRINGTYPE
 type_api.INTEGERTYPE = INTEGERTYPE
 type_api.NULLTYPE = NULLTYPE
 type_api.MATCHTYPE = MATCHTYPE
-type_api.INDEXABLE = Indexable
+type_api.INDEXABLE = INDEXABLE = Indexable
 type_api.TABLEVALUE = TABLEVALUE
 type_api._resolve_value_to_type = _resolve_value_to_type
-TypeEngine.Comparator.BOOLEANTYPE = BOOLEANTYPE
index 55997556afd286644bb5bc020ee9f49e18a28200..5a0aba694f9552edc88cb71214b7d7386375c637 100644 (file)
 
 from __future__ import annotations
 
+from types import ModuleType
 import typing
 from typing import Any
 from typing import Callable
+from typing import cast
+from typing import Dict
 from typing import Generic
+from typing import Mapping
 from typing import Optional
+from typing import overload
+from typing import Sequence
 from typing import Tuple
 from typing import Type
+from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
 from .base import SchemaEventTarget
+from .cache_key import CacheConst
 from .cache_key import NO_CACHE
 from .operators import ColumnOperators
 from .visitors import Visitable
 from .. import exc
 from .. import util
+from ..util.typing import Protocol
+from ..util.typing import TypedDict
+from ..util.typing import TypeGuard
 
 # these are back-assigned by sqltypes.
 if typing.TYPE_CHECKING:
+    from .elements import BindParameter
     from .elements import ColumnElement
     from .operators import OperatorType
+    from .schema import Column
     from .sqltypes import _resolve_value_to_type as _resolve_value_to_type
     from .sqltypes import BOOLEANTYPE as BOOLEANTYPE
-    from .sqltypes import Indexable as INDEXABLE
+    from .sqltypes import INDEXABLE as INDEXABLE
     from .sqltypes import INTEGERTYPE as INTEGERTYPE
     from .sqltypes import MATCHTYPE as MATCHTYPE
     from .sqltypes import NULLTYPE as NULLTYPE
+    from .sqltypes import NullType
+    from .sqltypes import STRINGTYPE as STRINGTYPE
+    from .sqltypes import TABLEVALUE as TABLEVALUE
+    from ..engine.interfaces import Dialect
 
 
 _T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
+_T_con = TypeVar("_T_con", bound=Any, contravariant=True)
+_O = TypeVar("_O", bound=object)
 _TE = TypeVar("_TE", bound="TypeEngine[Any]")
 _CT = TypeVar("_CT", bound=Any)
 
 # replace with pep-673 when applicable
-SelfTypeEngine = typing.TypeVar("SelfTypeEngine", bound="TypeEngine")
+SelfTypeEngine = typing.TypeVar("SelfTypeEngine", bound="TypeEngine[Any]")
+
+
+class _LiteralProcessorType(Protocol[_T_co]):
+    def __call__(self, value: Any) -> str:
+        ...
+
+
+class _BindProcessorType(Protocol[_T_con]):
+    def __call__(self, value: Optional[_T_con]) -> Any:
+        ...
+
+
+class _ResultProcessorType(Protocol[_T_co]):
+    def __call__(self, value: Any) -> Optional[_T_co]:
+        ...
+
+
+class _BaseTypeMemoDict(TypedDict):
+    impl: TypeEngine[Any]
+    result: Dict[Any, Optional[_ResultProcessorType[Any]]]
+
+
+class _TypeMemoDict(_BaseTypeMemoDict, total=False):
+    literal: Optional[_LiteralProcessorType[Any]]
+    bind: Optional[_BindProcessorType[Any]]
+    custom: Dict[Any, object]
+
+
+class _ComparatorFactory(Protocol[_T]):
+    def __call__(self, expr: ColumnElement[_T]) -> TypeEngine.Comparator[_T]:
+        ...
 
 
 class TypeEngine(Visitable, Generic[_T]):
@@ -70,8 +121,6 @@ class TypeEngine(Visitable, Generic[_T]):
     _is_array = False
     _is_type_decorator = False
 
-    _block_from_type_affinity = False
-
     render_bind_cast = False
     """Render bind casts for :attr:`.BindTyping.RENDER_CASTS` mode.
 
@@ -99,38 +148,41 @@ class TypeEngine(Visitable, Generic[_T]):
 
         __slots__ = "expr", "type"
 
-        default_comparator = None
+        expr: ColumnElement[_CT]
+        type: TypeEngine[_CT]
 
-        def __clause_element__(self):
+        def __clause_element__(self) -> ColumnElement[_CT]:
             return self.expr
 
-        def __init__(self, expr: "ColumnElement[_CT]"):
+        def __init__(self, expr: ColumnElement[_CT]):
             self.expr = expr
-            self.type: TypeEngine[_CT] = expr.type
+            self.type = expr.type
 
         @util.preload_module("sqlalchemy.sql.default_comparator")
         def operate(
-            self, op: "OperatorType", *other, **kwargs
-        ) -> "ColumnElement":
+            self, op: OperatorType, *other: Any, **kwargs: Any
+        ) -> ColumnElement[_CT]:
             default_comparator = util.preloaded.sql_default_comparator
             op_fn, addtl_kw = default_comparator.operator_lookup[op.__name__]
             if kwargs:
                 addtl_kw = addtl_kw.union(kwargs)
-            return op_fn(self.expr, op, *other, **addtl_kw)
+            return op_fn(self.expr, op, *other, **addtl_kw)  # type: ignore
 
         @util.preload_module("sqlalchemy.sql.default_comparator")
         def reverse_operate(
-            self, op: "OperatorType", other, **kwargs
-        ) -> "ColumnElement":
+            self, op: OperatorType, other: Any, **kwargs: Any
+        ) -> ColumnElement[_CT]:
             default_comparator = util.preloaded.sql_default_comparator
             op_fn, addtl_kw = default_comparator.operator_lookup[op.__name__]
             if kwargs:
                 addtl_kw = addtl_kw.union(kwargs)
-            return op_fn(self.expr, op, other, reverse=True, **addtl_kw)
+            return op_fn(self.expr, op, other, reverse=True, **addtl_kw)  # type: ignore  # noqa E501
 
         def _adapt_expression(
-            self, op: "OperatorType", other_comparator
-        ) -> Tuple["OperatorType", "TypeEngine[_CT]"]:
+            self,
+            op: OperatorType,
+            other_comparator: TypeEngine.Comparator[Any],
+        ) -> Tuple[OperatorType, TypeEngine[Any]]:
             """evaluate the return type of <self> <op> <othertype>,
             and apply any adaptations to the given operator.
 
@@ -159,7 +211,7 @@ class TypeEngine(Visitable, Generic[_T]):
 
             return op, self.type
 
-        def __reduce__(self):
+        def __reduce__(self) -> Any:
             return _reconstitute_comparator, (self.expr,)
 
     hashable = True
@@ -169,7 +221,7 @@ class TypeEngine(Visitable, Generic[_T]):
 
     """
 
-    comparator_factory = Comparator
+    comparator_factory: _ComparatorFactory[Any] = Comparator
     """A :class:`.TypeEngine.Comparator` class which will apply
     to operations performed by owning :class:`_expression.ColumnElement`
     objects.
@@ -193,7 +245,7 @@ class TypeEngine(Visitable, Generic[_T]):
 
     """
 
-    sort_key_function = None
+    sort_key_function: Optional[Callable[[Any], Any]] = None
     """A sorting function that can be passed as the key to sorted.
 
     The default value of ``None`` indicates that the values stored by
@@ -203,7 +255,7 @@ class TypeEngine(Visitable, Generic[_T]):
 
     """
 
-    should_evaluate_none = False
+    should_evaluate_none: bool = False
     """If True, the Python constant ``None`` is considered to be handled
     explicitly by this type.
 
@@ -226,9 +278,11 @@ class TypeEngine(Visitable, Generic[_T]):
 
     """
 
-    _variant_mapping = util.EMPTY_DICT
+    _variant_mapping: util.immutabledict[
+        str, TypeEngine[Any]
+    ] = util.EMPTY_DICT
 
-    def evaluates_none(self):
+    def evaluates_none(self: SelfTypeEngine) -> SelfTypeEngine:
         """Return a copy of this type which has the :attr:`.should_evaluate_none`
         flag set to True.
 
@@ -280,10 +334,12 @@ class TypeEngine(Visitable, Generic[_T]):
         typ.should_evaluate_none = True
         return typ
 
-    def copy(self, **kw):
+    def copy(self: SelfTypeEngine, **kw: Any) -> SelfTypeEngine:
         return self.adapt(self.__class__)
 
-    def compare_against_backend(self, dialect, conn_type):
+    def compare_against_backend(
+        self, dialect: Dialect, conn_type: TypeEngine[Any]
+    ) -> Optional[bool]:
         """Compare this type against the given backend type.
 
         This function is currently not implemented for SQLAlchemy
@@ -310,10 +366,12 @@ class TypeEngine(Visitable, Generic[_T]):
         """
         return None
 
-    def copy_value(self, value):
+    def copy_value(self, value: Any) -> Any:
         return value
 
-    def literal_processor(self, dialect):
+    def literal_processor(
+        self, dialect: Dialect
+    ) -> Optional[_LiteralProcessorType[_T]]:
         """Return a conversion function for processing literal values that are
         to be rendered directly without using binds.
 
@@ -348,7 +406,9 @@ class TypeEngine(Visitable, Generic[_T]):
         """
         return None
 
-    def bind_processor(self, dialect):
+    def bind_processor(
+        self, dialect: Dialect
+    ) -> Optional[_BindProcessorType[_T]]:
         """Return a conversion function for processing bind values.
 
         Returns a callable which will receive a bind parameter value
@@ -382,7 +442,9 @@ class TypeEngine(Visitable, Generic[_T]):
         """
         return None
 
-    def result_processor(self, dialect, coltype):
+    def result_processor(
+        self, dialect: Dialect, coltype: object
+    ) -> Optional[_ResultProcessorType[_T]]:
         """Return a conversion function for processing result row values.
 
         Returns a callable which will receive a result row column
@@ -417,7 +479,9 @@ class TypeEngine(Visitable, Generic[_T]):
         """
         return None
 
-    def column_expression(self, colexpr):
+    def column_expression(
+        self, colexpr: ColumnElement[_T]
+    ) -> Optional[ColumnElement[_T]]:
         """Given a SELECT column expression, return a wrapping SQL expression.
 
         This is typically a SQL function that wraps a column expression
@@ -461,7 +525,7 @@ class TypeEngine(Visitable, Generic[_T]):
         return None
 
     @util.memoized_property
-    def _has_column_expression(self):
+    def _has_column_expression(self) -> bool:
         """memoized boolean, check if column_expression is implemented.
 
         Allows the method to be skipped for the vast majority of expression
@@ -474,7 +538,9 @@ class TypeEngine(Visitable, Generic[_T]):
             is not TypeEngine.column_expression.__code__
         )
 
-    def bind_expression(self, bindvalue):
+    def bind_expression(
+        self, bindvalue: BindParameter[_T]
+    ) -> Optional[ColumnElement[_T]]:
         """Given a bind value (i.e. a :class:`.BindParameter` instance),
         return a SQL expression in its place.
 
@@ -521,7 +587,7 @@ class TypeEngine(Visitable, Generic[_T]):
         return None
 
     @util.memoized_property
-    def _has_bind_expression(self):
+    def _has_bind_expression(self) -> bool:
         """memoized boolean, check if bind_expression is implemented.
 
         Allows the method to be skipped for the vast majority of expression
@@ -535,12 +601,12 @@ class TypeEngine(Visitable, Generic[_T]):
     def _to_instance(cls_or_self: Union[Type[_TE], _TE]) -> _TE:
         return to_instance(cls_or_self)
 
-    def compare_values(self, x, y):
+    def compare_values(self, x: Any, y: Any) -> bool:
         """Compare two values for equality."""
 
-        return x == y
+        return x == y  # type: ignore[no-any-return]
 
-    def get_dbapi_type(self, dbapi):
+    def get_dbapi_type(self, dbapi: ModuleType) -> Optional[Any]:
         """Return the corresponding type object from the underlying DB-API, if
         any.
 
@@ -550,7 +616,7 @@ class TypeEngine(Visitable, Generic[_T]):
         return None
 
     @property
-    def python_type(self):
+    def python_type(self) -> Type[Any]:
         """Return the Python type object expected to be returned
         by instances of this type, if known.
 
@@ -569,7 +635,7 @@ class TypeEngine(Visitable, Generic[_T]):
         raise NotImplementedError()
 
     def with_variant(
-        self: SelfTypeEngine, type_: "TypeEngine", *dialect_names: str
+        self: SelfTypeEngine, type_: TypeEngine[Any], *dialect_names: str
     ) -> SelfTypeEngine:
         r"""Produce a copy of this type object that will utilize the given
         type when applied to the dialect of the given name.
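
A usage sketch of with_variant() as now typed (dialect and type choices are arbitrary):

from sqlalchemy import String
from sqlalchemy.dialects import mysql

# a String that renders as MySQL LONGTEXT on the "mysql" dialect only;
# all other backends keep the plain String(255)
t = String(255).with_variant(mysql.LONGTEXT(), "mysql")
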
@@ -626,7 +692,9 @@ class TypeEngine(Visitable, Generic[_T]):
         )
         return new_type
 
-    def _resolve_for_literal(self, value):
+    def _resolve_for_literal(
+        self: SelfTypeEngine, value: Any
+    ) -> SelfTypeEngine:
         """adjust this type given a literal Python value that will be
         stored in a bound parameter.
 
@@ -638,28 +706,28 @@ class TypeEngine(Visitable, Generic[_T]):
         return self
 
     @util.memoized_property
-    def _type_affinity(self):
+    def _type_affinity(self) -> Optional[Type[TypeEngine[_T]]]:
         """Return a rudimental 'affinity' value expressing the general class
         of type."""
 
         typ = None
         for t in self.__class__.__mro__:
-            if t in (TypeEngine, UserDefinedType):
+            if t is TypeEngine or TypeEngineMixin in t.__bases__:
                 return typ
-            elif issubclass(
-                t, (TypeEngine, UserDefinedType)
-            ) and not t.__dict__.get("_block_from_type_affinity", False):
+            elif issubclass(t, TypeEngine):
                 typ = t
         else:
             return self.__class__
 
     @util.memoized_property
-    def _generic_type_affinity(self):
+    def _generic_type_affinity(
+        self,
+    ) -> Type[TypeEngine[_T]]:
         best_camelcase = None
         best_uppercase = None
 
-        if not isinstance(self, (TypeEngine, UserDefinedType)):
-            return self.__class__
+        if not isinstance(self, TypeEngine):
+            return self.__class__  # type: ignore  # mypy bug?
 
         for t in self.__class__.__mro__:
             if (
@@ -669,7 +737,8 @@ class TypeEngine(Visitable, Generic[_T]):
                     "sqlalchemy.sql.type_api",
                 )
                 and issubclass(t, TypeEngine)
-                and t is not TypeEngine
+                and TypeEngineMixin not in t.__bases__
+                and t not in (TypeEngine, TypeEngineMixin)
                 and t.__name__[0] != "_"
             ):
                 if t.__name__.isupper() and not best_uppercase:
@@ -677,9 +746,13 @@ class TypeEngine(Visitable, Generic[_T]):
                 elif not t.__name__.isupper() and not best_camelcase:
                     best_camelcase = t
 
-        return best_camelcase or best_uppercase or NULLTYPE.__class__
+        return (
+            best_camelcase
+            or best_uppercase
+            or cast("Type[TypeEngine[_T]]", NULLTYPE.__class__)
+        )
 
-    def as_generic(self, allow_nulltype=False):
+    def as_generic(self, allow_nulltype: bool = False) -> TypeEngine[_T]:
         """
         Return an instance of the generic type corresponding to this type
         using heuristic rule. The method may be overridden if this
@@ -719,18 +792,20 @@ class TypeEngine(Visitable, Generic[_T]):
 
         return util.constructor_copy(self, self._generic_type_affinity)
 
-    def dialect_impl(self, dialect):
+    def dialect_impl(self, dialect: Dialect) -> TypeEngine[_T]:
         """Return a dialect-specific implementation for this
         :class:`.TypeEngine`.
 
         """
         try:
-            return dialect._type_memos[self]["impl"]
+            tm = dialect._type_memos[self]
         except KeyError:
             pass
+        else:
+            return tm["impl"]
         return self._dialect_info(dialect)["impl"]
 
-    def _unwrapped_dialect_impl(self, dialect):
+    def _unwrapped_dialect_impl(self, dialect: Dialect) -> TypeEngine[_T]:
         """Return the 'unwrapped' dialect impl for this type.
 
         For a type that applies wrapping logic (e.g. TypeDecorator), give
@@ -744,60 +819,80 @@ class TypeEngine(Visitable, Generic[_T]):
         """
         return self.dialect_impl(dialect)
 
-    def _cached_literal_processor(self, dialect):
+    def _cached_literal_processor(
+        self, dialect: Dialect
+    ) -> Optional[_LiteralProcessorType[_T]]:
         """Return a dialect-specific literal processor for this type."""
+
         try:
             return dialect._type_memos[self]["literal"]
         except KeyError:
             pass
+
         # avoid KeyError context coming into literal_processor() function
         # raises
         d = self._dialect_info(dialect)
         d["literal"] = lp = d["impl"].literal_processor(dialect)
         return lp
 
-    def _cached_bind_processor(self, dialect):
+    def _cached_bind_processor(
+        self, dialect: Dialect
+    ) -> Optional[_BindProcessorType[_T]]:
         """Return a dialect-specific bind processor for this type."""
 
         try:
             return dialect._type_memos[self]["bind"]
         except KeyError:
             pass
+
         # avoid KeyError context coming into bind_processor() function
         # raises
         d = self._dialect_info(dialect)
         d["bind"] = bp = d["impl"].bind_processor(dialect)
         return bp
 
-    def _cached_result_processor(self, dialect, coltype):
+    def _cached_result_processor(
+        self, dialect: Dialect, coltype: Any
+    ) -> Optional[_ResultProcessorType[_T]]:
         """Return a dialect-specific result processor for this type."""
 
         try:
-            return dialect._type_memos[self][coltype]
+            return dialect._type_memos[self]["result"][coltype]
         except KeyError:
             pass
+
         # avoid KeyError context coming into result_processor() function
         # raises
         d = self._dialect_info(dialect)
         # key assumption: DBAPI type codes are
         # constants.  Else this dictionary would
         # grow unbounded.
-        d[coltype] = rp = d["impl"].result_processor(dialect, coltype)
+        rp = d["impl"].result_processor(dialect, coltype)
+        d["result"][coltype] = rp
         return rp
 
-    def _cached_custom_processor(self, dialect, key, fn):
+    def _cached_custom_processor(
+        self, dialect: Dialect, key: str, fn: Callable[[TypeEngine[_T]], _O]
+    ) -> _O:
+        """return a dialect-specific processing object for
+        custom purposes.
+
+        The cx_Oracle dialect uses this at the moment.
+
+        """
         try:
-            return dialect._type_memos[self][key]
+            return cast(_O, dialect._type_memos[self]["custom"][key])
         except KeyError:
             pass
         # avoid KeyError context coming into fn() function
         # raises
         d = self._dialect_info(dialect)
         impl = d["impl"]
-        d[key] = result = fn(impl)
+        custom_dict = d.setdefault("custom", {})
+        custom_dict[key] = result = fn(impl)
         return result
 
-    def _dialect_info(self, dialect):
+    def _dialect_info(self, dialect: Dialect) -> _TypeMemoDict:
         """Return a dialect-specific registry which
         caches a dialect-specific implementation, bind processing
         function, and one or more result processing functions."""
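
For orientation, the per-dialect cache entry described here now has the shape spelled out by _TypeMemoDict; roughly (values illustrative):

memo = {
    "impl": None,     # required: the adapted dialect-level TypeEngine
    "result": {},     # required: DBAPI coltype -> cached result processor
    "bind": None,     # optional: cached bind processor
    "literal": None,  # optional: cached literal processor
    "custom": {},     # optional: ad-hoc objects (the cx_Oracle dialect uses this)
}
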
@@ -810,10 +905,11 @@ class TypeEngine(Visitable, Generic[_T]):
                 impl = self.adapt(type(self))
             # this can't be self, else we create a cycle
             assert impl is not self
-            dialect._type_memos[self] = d = {"impl": impl}
+            d: _TypeMemoDict = {"impl": impl, "result": {}}
+            dialect._type_memos[self] = d
             return d
 
-    def _gen_dialect_impl(self, dialect):
+    def _gen_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
         if dialect.name in self._variant_mapping:
             return self._variant_mapping[dialect.name]._gen_dialect_impl(
                 dialect
@@ -822,7 +918,9 @@ class TypeEngine(Visitable, Generic[_T]):
             return dialect.type_descriptor(self)
 
     @util.memoized_property
-    def _static_cache_key(self):
+    def _static_cache_key(
+        self,
+    ) -> Union[CacheConst, Tuple[Any, ...]]:
         names = util.get_cls_kwargs(self.__class__)
         return (self.__class__,) + tuple(
             (
@@ -835,7 +933,17 @@ class TypeEngine(Visitable, Generic[_T]):
             if k in self.__dict__ and not k.startswith("_")
         )
 
-    def adapt(self, cls, **kw):
+    @overload
+    def adapt(self, cls: Type[_TE], **kw: Any) -> _TE:
+        ...
+
+    @overload
+    def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]:
+        ...
+
+    def adapt(
+        self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any
+    ) -> TypeEngine[Any]:
         """Produce an "adapted" form of this type, given an "impl" class
         to work with.
 
@@ -843,9 +951,13 @@ class TypeEngine(Visitable, Generic[_T]):
         types with "implementation" types that are specific to a particular
         dialect.
         """
-        return util.constructor_copy(self, cls, **kw)
+        return util.constructor_copy(
+            self, cast(Type[TypeEngine[Any]], cls), **kw
+        )
 
-    def coerce_compared_value(self, op, value):
+    def coerce_compared_value(
+        self, op: Optional[OperatorType], value: Any
+    ) -> TypeEngine[Any]:
         """Suggest a type for a 'coerced' Python value in an expression.
 
         Given an operator and value, gives the type a chance
@@ -873,10 +985,10 @@ class TypeEngine(Visitable, Generic[_T]):
         else:
             return _coerced_type
 
-    def _compare_type_affinity(self, other):
+    def _compare_type_affinity(self, other: TypeEngine[Any]) -> bool:
         return self._type_affinity is other._type_affinity
 
-    def compile(self, dialect=None):
+    def compile(self, dialect: Optional[Dialect] = None) -> str:
         """Produce a string-compiled form of this :class:`.TypeEngine`.
 
         When called with no arguments, uses a "default" dialect
@@ -888,24 +1000,65 @@ class TypeEngine(Visitable, Generic[_T]):
         # arg, return value is inconsistent with
         # ClauseElement.compile()....this is a mistake.
 
-        if not dialect:
+        if dialect is None:
             dialect = self._default_dialect()
 
-        return dialect.type_compiler.process(self)
+        return dialect.type_compiler_instance.process(self)
 
     @util.preload_module("sqlalchemy.engine.default")
-    def _default_dialect(self):
-        default = util.preloaded.engine_default
-        return default.StrCompileDialect()
+    def _default_dialect(self) -> Dialect:
 
-    def __str__(self):
+        if TYPE_CHECKING:
+            from ..engine import default
+        else:
+            default = util.preloaded.engine_default
+
+        # dmypy / mypy seems to sporadically keep thinking this line is
+        # returning Any, which seems to be caused by the @deprecated_params
+        # decorator on the DefaultDialect constructor
+        return default.StrCompileDialect()  # type: ignore
+
+    def __str__(self) -> str:
         return str(self.compile())
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return util.generic_repr(self)
 
 
-class ExternalType:
+class TypeEngineMixin:
+    """classes which subclass this can act as "mixin" classes for
+    TypeEngine."""
+
+    __slots__ = ()
+
+    if TYPE_CHECKING:
+
+        @util.memoized_property
+        def _static_cache_key(
+            self,
+        ) -> Union[CacheConst, Tuple[Any, ...]]:
+            ...
+
+        @overload
+        def adapt(self, cls: Type[_TE], **kw: Any) -> _TE:
+            ...
+
+        @overload
+        def adapt(
+            self, cls: Type[TypeEngineMixin], **kw: Any
+        ) -> TypeEngine[Any]:
+            ...
+
+        def adapt(
+            self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any
+        ) -> TypeEngine[Any]:
+            ...
+
+        def dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
+            ...
+
+
+class ExternalType(TypeEngineMixin):
     """mixin that defines attributes and behaviors specific to third-party
     datatypes.
 
@@ -1057,13 +1210,18 @@ class ExternalType:
 
     """  # noqa: E501
 
-    @property
-    def _static_cache_key(self):
+    @util.non_memoized_property
+    def _static_cache_key(
+        self,
+    ) -> Union[CacheConst, Tuple[Any, ...]]:
         cache_ok = self.__class__.__dict__.get("cache_ok", None)
 
         if cache_ok is None:
-            subtype_idx = self.__class__.__mro__.index(ExternalType)
-            subtype = self.__class__.__mro__[max(subtype_idx - 1, 0)]
+            for subtype in self.__class__.__mro__:
+                if ExternalType in subtype.__bases__:
+                    break
+            else:
+                subtype = self.__class__.__mro__[1]
 
             util.warn(
                 "%s %r will not produce a cache key because "
@@ -1076,12 +1234,14 @@ class ExternalType:
                 code="cprf",
             )
         elif cache_ok is True:
-            return super(ExternalType, self)._static_cache_key
+            return super()._static_cache_key
 
         return NO_CACHE
 
 
-class UserDefinedType(ExternalType, TypeEngine, util.EnsureKWArg):
+class UserDefinedType(
+    ExternalType, TypeEngineMixin, TypeEngine[_T], util.EnsureKWArg
+):
     """Base for user defined types.
 
     This should be the base of new types.  Note that
@@ -1148,7 +1308,9 @@ class UserDefinedType(ExternalType, TypeEngine, util.EnsureKWArg):
 
     ensure_kwarg = "get_col_spec"
 
-    def coerce_compared_value(self, op, value):
+    def coerce_compared_value(
+        self, op: Optional[OperatorType], value: Any
+    ) -> TypeEngine[Any]:
         """Suggest a type for a 'coerced' Python value in an expression.
 
         Default behavior for :class:`.UserDefinedType` is the
@@ -1162,7 +1324,7 @@ class UserDefinedType(ExternalType, TypeEngine, util.EnsureKWArg):
         return self
 
 
-class Emulated:
+class Emulated(TypeEngineMixin):
     """Mixin for base types that emulate the behavior of a DB-native type.
 
     An :class:`.Emulated` type will use an available database type
@@ -1180,7 +1342,13 @@ class Emulated:
 
     """
 
-    def adapt_to_emulated(self, impltype, **kw):
+    native: bool
+
+    def adapt_to_emulated(
+        self,
+        impltype: Type[Union[TypeEngine[Any], TypeEngineMixin]],
+        **kw: Any,
+    ) -> TypeEngine[Any]:
         """Given an impl class, adapt this type to the impl assuming "emulated".
 
         The impl should also be an "emulated" version of this type,
@@ -1189,27 +1357,43 @@ class Emulated:
         e.g.: sqltypes.Enum adapts to the Enum class.
 
         """
-        return super(Emulated, self).adapt(impltype, **kw)
+        return super().adapt(impltype, **kw)
 
-    def adapt(self, impltype, **kw):
-        if hasattr(impltype, "adapt_emulated_to_native"):
+    @overload
+    def adapt(self, cls: Type[_TE], **kw: Any) -> _TE:
+        ...
+
+    @overload
+    def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]:
+        ...
+
+    def adapt(
+        self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any
+    ) -> TypeEngine[Any]:
+        if _is_native_for_emulated(cls):
             if self.native:
                 # native support requested, dialect gave us a native
                 # implementor, pass control over to it
-                return impltype.adapt_emulated_to_native(self, **kw)
+                return cls.adapt_emulated_to_native(self, **kw)
             else:
                 # non-native support, let the native implementor
                 # decide also, at the moment this is just to help debugging
                 # as only the default logic is implemented.
-                return impltype.adapt_native_to_emulated(self, **kw)
+                return cls.adapt_native_to_emulated(self, **kw)
         else:
-            if issubclass(impltype, self.__class__):
-                return self.adapt_to_emulated(impltype, **kw)
+            if issubclass(cls, self.__class__):
+                return self.adapt_to_emulated(cls, **kw)
             else:
-                return super(Emulated, self).adapt(impltype, **kw)
+                return super().adapt(cls, **kw)
+
 
+def _is_native_for_emulated(
+    typ: Type[Union[TypeEngine[Any], TypeEngineMixin]],
+) -> TypeGuard["Type[NativeForEmulated]"]:
+    return hasattr(typ, "adapt_emulated_to_native")
 
-class NativeForEmulated:
+
+class NativeForEmulated(TypeEngineMixin):
     """Indicates DB-native types supported by an :class:`.Emulated` type.
 
     .. versionadded:: 1.2.0b3
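
A sketch of the emulated-to-native handoff that adapt() implements above, using the public Enum / PostgreSQL ENUM pair (illustrative; exact keyword pass-through may differ):

from sqlalchemy import Enum
from sqlalchemy.dialects import postgresql

generic = Enum("a", "b", name="my_enum")  # an Emulated type, native by default

# _is_native_for_emulated(postgresql.ENUM) is True, so adapt() defers to
# ENUM.adapt_emulated_to_native() and a DB-native PostgreSQL ENUM comes back
native = generic.adapt(postgresql.ENUM)
assert isinstance(native, postgresql.ENUM)
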
@@ -1217,7 +1401,11 @@ class NativeForEmulated:
     """
 
     @classmethod
-    def adapt_native_to_emulated(cls, impl, **kw):
+    def adapt_native_to_emulated(
+        cls,
+        impl: Union[TypeEngine[Any], TypeEngineMixin],
+        **kw: Any,
+    ) -> TypeEngine[Any]:
         """Given an impl, adapt this type's class to the impl assuming
         "emulated".
 
@@ -1227,7 +1415,12 @@ class NativeForEmulated:
         return impl.adapt(impltype, **kw)
 
     @classmethod
-    def adapt_emulated_to_native(cls, impl, **kw):
+    def adapt_emulated_to_native(
+        cls,
+        impl: Union[TypeEngine[Any], TypeEngineMixin],
+        **kw: Any,
+    ) -> TypeEngine[Any]:
+
         """Given an impl, adapt this type's class to the impl assuming "native".
 
         The impl will be an :class:`.Emulated` class but not a
@@ -1236,10 +1429,20 @@ class NativeForEmulated:
         e.g.: postgresql.ENUM produces a type given an Enum instance.
 
         """
-        return cls(**kw)
 
+        # dmypy seems to crash on this
+        return cls(**kw)  # type: ignore
 
-class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
+    # dmypy seems to crash with this, on repeated runs with changes
+    # if TYPE_CHECKING:
+    #    def __init__(self, **kw: Any):
+    #        ...
+
+
+SelfTypeDecorator = TypeVar("SelfTypeDecorator", bound="TypeDecorator[Any]")
+
+
+class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]):
     """Allows the creation of types which add additional functionality
     to an existing type.
 
@@ -1358,9 +1561,24 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
 
     _is_type_decorator = True
 
+    # this is that pattern I've used in a few places (Dialect.dbapi,
+    # Dialect.type_compiler) where the "cls.attr" is a class to make something,
+    # and "instance.attr" is an instance of that thing.  It's such a nifty,
+    # great pattern, and there is zero chance Python typing tools will ever be
+    # OK with it.  For TypeDecorator.impl, this is a highly public attribute so
+    # we really can't change its behavior without a major deprecation routine.
     impl: Union[TypeEngine[Any], Type[TypeEngine[Any]]]
 
-    def __init__(self, *args, **kwargs):
+    # we are changing its behavior *slightly*, which is that we now consume
+    # the instance level version from this memoized property instead, so you
+    # can't reassign "impl" on an existing TypeDecorator that's already been
+    # used (something one shouldn't do anyway) without also updating
+    # impl_instance.
+    @util.memoized_property
+    def impl_instance(self) -> TypeEngine[Any]:
+        return self.impl  # type: ignore
+
+    def __init__(self, *args: Any, **kwargs: Any):
         """Construct a :class:`.TypeDecorator`.
 
         Arguments sent here are passed to the constructor
@@ -1385,9 +1603,10 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
                 "'impl' which refers to the class of "
                 "type being decorated"
             )
+
         self.impl = to_instance(self.__class__.impl, *args, **kwargs)
 
-    coerce_to_is_types = (util.NoneType,)
+    coerce_to_is_types: Sequence[Type[Any]] = (type(None),)
     """Specify those Python types which should be coerced at the expression
     level to "IS <constant>" when compared using ``==`` (and same for
     ``IS NOT`` in conjunction with ``!=``).
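
A sketch of the class-level vs. instance-level "impl" split discussed in the comments above (LowerString is hypothetical):

from sqlalchemy import String
from sqlalchemy.types import TypeDecorator


class LowerString(TypeDecorator):
    impl = String        # class attribute: a TypeEngine *class*
    cache_ok = True

    def process_bind_param(self, value, dialect):
        return value.lower() if value is not None else None


lt = LowerString(50)
# __init__ replaces the class-level reference with a String(50) instance;
# the memoized impl_instance accessor exposes exactly that instance
assert isinstance(lt.impl_instance, String)
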
@@ -1416,33 +1635,42 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
 
         __slots__ = ()
 
-        def operate(self, op, *other, **kwargs):
+        def operate(
+            self, op: OperatorType, *other: Any, **kwargs: Any
+        ) -> ColumnElement[_CT]:
+            if TYPE_CHECKING:
+                assert isinstance(self.expr.type, TypeDecorator)
             kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types
             return super(TypeDecorator.Comparator, self).operate(
                 op, *other, **kwargs
             )
 
-        def reverse_operate(self, op, other, **kwargs):
+        def reverse_operate(
+            self, op: OperatorType, other: Any, **kwargs: Any
+        ) -> ColumnElement[_CT]:
+            if TYPE_CHECKING:
+                assert isinstance(self.expr.type, TypeDecorator)
             kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types
             return super(TypeDecorator.Comparator, self).reverse_operate(
                 op, other, **kwargs
             )
 
     @property
-    def comparator_factory(self) -> Callable[..., TypeEngine.Comparator[_T]]:
-        if TypeDecorator.Comparator in self.impl.comparator_factory.__mro__:
+    def comparator_factory(  # type: ignore  # mypy properties bug
+        self,
+    ) -> _ComparatorFactory[Any]:
+        if TypeDecorator.Comparator in self.impl.comparator_factory.__mro__:  # type: ignore # noqa E501
             return self.impl.comparator_factory
         else:
+            # reconcile the Comparator class on the impl with that
+            # of TypeDecorator
             return type(
                 "TDComparator",
-                (TypeDecorator.Comparator, self.impl.comparator_factory),
+                (TypeDecorator.Comparator, self.impl.comparator_factory),  # type: ignore # noqa E501
                 {},
             )
 
-    def _gen_dialect_impl(self, dialect):
-        """
-        #todo
-        """
+    def _gen_dialect_impl(self, dialect: Dialect) -> TypeEngine[_T]:
         if dialect.name in self._variant_mapping:
             adapted = dialect.type_descriptor(
                 self._variant_mapping[dialect.name]
@@ -1463,35 +1691,34 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
                 "implement the copy() method, it must "
                 "return an object of type %s" % (self, self.__class__)
             )
-        tt.impl = typedesc
+        tt.impl = tt.impl_instance = typedesc
         return tt
 
-    @property
-    def _type_affinity(self):
-        """
-        #todo
-        """
-        return self.impl._type_affinity
+    @util.non_memoized_property
+    def _type_affinity(self) -> Optional[Type[TypeEngine[Any]]]:
+        return self.impl_instance._type_affinity
 
-    def _set_parent(self, column, outer=False, **kw):
+    def _set_parent(
+        self, parent: SchemaEventTarget, outer: bool = False, **kw: Any
+    ) -> None:
         """Support SchemaEventTarget"""
 
-        super(TypeDecorator, self)._set_parent(column)
+        super()._set_parent(parent)
 
-        if not outer and isinstance(self.impl, SchemaEventTarget):
-            self.impl._set_parent(column, outer=False, **kw)
+        if not outer and isinstance(self.impl_instance, SchemaEventTarget):
+            self.impl_instance._set_parent(parent, outer=False, **kw)
 
-    def _set_parent_with_dispatch(self, parent):
+    def _set_parent_with_dispatch(
+        self, parent: SchemaEventTarget, **kw: Any
+    ) -> None:
         """Support SchemaEventTarget"""
 
-        super(TypeDecorator, self)._set_parent_with_dispatch(
-            parent, outer=True
-        )
+        super()._set_parent_with_dispatch(parent, outer=True, **kw)
 
-        if isinstance(self.impl, SchemaEventTarget):
-            self.impl._set_parent_with_dispatch(parent)
+        if isinstance(self.impl_instance, SchemaEventTarget):
+            self.impl_instance._set_parent_with_dispatch(parent)
 
-    def type_engine(self, dialect):
+    def type_engine(self, dialect: Dialect) -> TypeEngine[Any]:
         """Return a dialect-specific :class:`.TypeEngine` instance
         for this :class:`.TypeDecorator`.
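
The type_engine() hook whose docstring begins here, together with load_dialect_impl() in the following hunks, is how a TypeDecorator resolves a dialect-specific impl. A short, hedged usage sketch of the standard override; MaybeJSONB is an invented example name, while TypeDecorator, JSON, JSONB and dialect.type_descriptor() are existing public SQLAlchemy API:

    # Sketch only: MaybeJSONB is an invented example type.
    from sqlalchemy.types import JSON, TypeDecorator
    from sqlalchemy.dialects.postgresql import JSONB

    class MaybeJSONB(TypeDecorator):
        """Use JSONB on PostgreSQL, generic JSON elsewhere."""

        impl = JSON
        cache_ok = True

        def load_dialect_impl(self, dialect):
            # consulted by type_engine() when no variant mapping applies
            if dialect.name == "postgresql":
                return dialect.type_descriptor(JSONB())
            return dialect.type_descriptor(JSON())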
 
@@ -1508,7 +1735,7 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
         else:
             return self.load_dialect_impl(dialect)
 
-    def load_dialect_impl(self, dialect):
+    def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
         """Return a :class:`.TypeEngine` object corresponding to a dialect.
 
         This is an end-user override hook that can be used to provide
@@ -1520,9 +1747,9 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
         By default returns ``self.impl``.
 
         """
-        return self.impl
+        return self.impl_instance
 
-    def _unwrapped_dialect_impl(self, dialect):
+    def _unwrapped_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
         """Return the 'unwrapped' dialect impl for this type.
 
         This is used by the :meth:`.DefaultDialect.set_input_sizes`
@@ -1540,12 +1767,14 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
         else:
             return typ
 
-    def __getattr__(self, key):
+    def __getattr__(self, key: str) -> Any:
         """Proxy all other undefined accessors to the underlying
         implementation."""
-        return getattr(self.impl, key)
+        return getattr(self.impl_instance, key)
 
-    def process_literal_param(self, value, dialect):
+    def process_literal_param(
+        self, value: Optional[_T], dialect: Dialect
+    ) -> str:
         """Receive a literal parameter value to be rendered inline within
         a statement.
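
The __getattr__ change just above keeps TypeDecorator's long-standing attribute proxying, now aimed at impl_instance. A minimal standalone sketch of that pattern; Wrapper and Impl are invented names, not SQLAlchemy classes:

    # Invented stand-ins illustrating the __getattr__ delegation pattern.
    class Impl:
        collation = "utf8"

        def length(self):
            return 50

    class Wrapper:
        def __init__(self, impl):
            self.impl_instance = impl

        def __getattr__(self, key):
            # only consulted for names not found on Wrapper itself
            return getattr(self.impl_instance, key)

    w = Wrapper(Impl())
    print(w.collation, w.length())  # utf8 50 -- both proxied to the impl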
 
@@ -1568,7 +1797,7 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
         """
         raise NotImplementedError()
 
-    def process_bind_param(self, value, dialect):
+    def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Any:
         """Receive a bound parameter value to be converted.
 
         Custom subclasses of :class:`_types.TypeDecorator` should override
@@ -1595,7 +1824,9 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
 
         raise NotImplementedError()
 
-    def process_result_value(self, value, dialect):
+    def process_result_value(
+        self, value: Optional[Any], dialect: Dialect
+    ) -> Optional[_T]:
         """Receive a result-row column value to be converted.
 
         Custom subclasses of :class:`_types.TypeDecorator` should override
@@ -1624,7 +1855,7 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
         raise NotImplementedError()
 
     @util.memoized_property
-    def _has_bind_processor(self):
+    def _has_bind_processor(self) -> bool:
         """memoized boolean, check if process_bind_param is implemented.
 
         Allows the base process_bind_param to raise
@@ -1638,14 +1869,16 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
         )
 
     @util.memoized_property
-    def _has_literal_processor(self):
+    def _has_literal_processor(self) -> bool:
         """memoized boolean, check if process_literal_param is implemented."""
 
         return util.method_is_overridden(
             self, TypeDecorator.process_literal_param
         )
 
-    def literal_processor(self, dialect):
+    def literal_processor(
+        self, dialect: Dialect
+    ) -> Optional[_LiteralProcessorType[_T]]:
         """Provide a literal processing function for the given
         :class:`.Dialect`.
 
@@ -1661,34 +1894,59 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
             "inner" processing provided by the implementing type is maintained.
 
         """
+
         if self._has_literal_processor:
-            process_param = self.process_literal_param
+            process_literal_param = self.process_literal_param
+            process_bind_param = None
         elif self._has_bind_processor:
-            # the bind processor should normally be OK
-            # for TypeDecorator since it isn't doing DB-level
-            # handling, the handling here won't be different for bound vs.
-            # literals.
-            process_param = self.process_bind_param
+            # fall back to the bind processor if we don't have a literal
+            # processor; the impl's literal processor is still applied below
+            process_literal_param = None
+            process_bind_param = self.process_bind_param
         else:
-            process_param = None
+            process_literal_param = None
+            process_bind_param = None
 
-        if process_param:
-            impl_processor = self.impl.literal_processor(dialect)
+        if process_literal_param is not None:
+            impl_processor = self.impl_instance.literal_processor(dialect)
             if impl_processor:
 
-                def process(value):
-                    return impl_processor(process_param(value, dialect))
+                fixed_impl_processor = impl_processor
+                fixed_process_literal_param = process_literal_param
+
+                def process(value: Any) -> str:
+                    return fixed_impl_processor(
+                        fixed_process_literal_param(value, dialect)
+                    )
 
             else:
+                fixed_process_literal_param = process_literal_param
 
-                def process(value):
-                    return process_param(value, dialect)
+                def process(value: Any) -> str:
+                    return fixed_process_literal_param(value, dialect)
 
             return process
+
+        elif process_bind_param is not None:
+            impl_processor = self.impl_instance.literal_processor(dialect)
+            if not impl_processor:
+                return None
+            else:
+                fixed_impl_processor = impl_processor
+                fixed_process_bind_param = process_bind_param
+
+                def process(value: Any) -> str:
+                    return fixed_impl_processor(
+                        fixed_process_bind_param(value, dialect)
+                    )
+
+                return process
         else:
-            return self.impl.literal_processor(dialect)
+            return self.impl_instance.literal_processor(dialect)
 
-    def bind_processor(self, dialect):
+    def bind_processor(
+        self, dialect: Dialect
+    ) -> Optional[_BindProcessorType[_T]]:
         """Provide a bound value processing function for the
         given :class:`.Dialect`.
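
The rewritten literal_processor() above chains the decorator-level conversion with the impl-level processor, re-binding each Optional callable to a "fixed_*" local first so that the closure captures a provably non-None value. A reduced sketch of that closure-chaining pattern; outer, impl and build_chain are invented names, not SQLAlchemy code:

    from typing import Callable, Optional

    def outer(value: str) -> str:  # stands in for the decorator-level step
        return value.upper()

    def impl(value: str) -> str:  # stands in for the impl-level step
        return "'%s'" % value.replace("'", "''")

    def build_chain(
        inner: Optional[Callable[[str], str]]
    ) -> Callable[[str], str]:
        if inner is None:
            return outer
        # re-bind to a local so the closure sees a non-Optional callable,
        # the same trick the patch applies with fixed_impl_processor etc.
        fixed_inner = inner

        def process(value: str) -> str:
            return fixed_inner(outer(value))

        return process

    print(build_chain(impl)("o'brien"))  # 'O''BRIEN'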
 
@@ -1708,23 +1966,28 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
         """
         if self._has_bind_processor:
             process_param = self.process_bind_param
-            impl_processor = self.impl.bind_processor(dialect)
+            impl_processor = self.impl_instance.bind_processor(dialect)
             if impl_processor:
+                fixed_impl_processor = impl_processor
+                fixed_process_param = process_param
 
-                def process(value):
-                    return impl_processor(process_param(value, dialect))
+                def process(value: Optional[_T]) -> Any:
+                    return fixed_impl_processor(
+                        fixed_process_param(value, dialect)
+                    )
 
             else:
+                fixed_process_param = process_param
 
-                def process(value):
-                    return process_param(value, dialect)
+                def process(value: Optional[_T]) -> Any:
+                    return fixed_process_param(value, dialect)
 
             return process
         else:
-            return self.impl.bind_processor(dialect)
+            return self.impl_instance.bind_processor(dialect)
 
     @util.memoized_property
-    def _has_result_processor(self):
+    def _has_result_processor(self) -> bool:
         """memoized boolean, check if process_result_value is implemented.
 
         Allows the base process_result_value to raise
@@ -1737,7 +2000,9 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
             self, TypeDecorator.process_result_value
         )
 
-    def result_processor(self, dialect, coltype):
+    def result_processor(
+        self, dialect: Dialect, coltype: Any
+    ) -> Optional[_ResultProcessorType[_T]]:
         """Provide a result value processing function for the given
         :class:`.Dialect`.
 
@@ -1758,30 +2023,39 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
         """
         if self._has_result_processor:
             process_value = self.process_result_value
-            impl_processor = self.impl.result_processor(dialect, coltype)
+            impl_processor = self.impl_instance.result_processor(
+                dialect, coltype
+            )
             if impl_processor:
+                fixed_process_value = process_value
+                fixed_impl_processor = impl_processor
 
-                def process(value):
-                    return process_value(impl_processor(value), dialect)
+                def process(value: Any) -> Optional[_T]:
+                    return fixed_process_value(
+                        fixed_impl_processor(value), dialect
+                    )
 
             else:
+                fixed_process_value = process_value
 
-                def process(value):
-                    return process_value(value, dialect)
+                def process(value: Any) -> Optional[_T]:
+                    return fixed_process_value(value, dialect)
 
             return process
         else:
-            return self.impl.result_processor(dialect, coltype)
+            return self.impl_instance.result_processor(dialect, coltype)
 
     @util.memoized_property
-    def _has_bind_expression(self):
+    def _has_bind_expression(self) -> bool:
 
         return (
             util.method_is_overridden(self, TypeDecorator.bind_expression)
-            or self.impl._has_bind_expression
+            or self.impl_instance._has_bind_expression
         )
 
-    def bind_expression(self, bindparam):
+    def bind_expression(
+        self, bindparam: BindParameter[_T]
+    ) -> Optional[ColumnElement[_T]]:
         """Given a bind value (i.e. a :class:`.BindParameter` instance),
         return a SQL expression which will typically wrap the given parameter.
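
The bind_processor() / result_processor() plumbing above is what ultimately invokes the process_bind_param() / process_result_value() hooks documented earlier in this file. A short usage sketch of those hooks using the public TypeDecorator API; the CSVList name and the pipe delimiter are invented for illustration:

    from sqlalchemy.types import String, TypeDecorator

    class CSVList(TypeDecorator):
        """Invented example: persist a list of strings as one
        pipe-delimited string."""

        impl = String
        cache_ok = True

        def process_bind_param(self, value, dialect):
            # Python list -> stored string, before the impl's bind handling
            return "|".join(value) if value is not None else None

        def process_result_value(self, value, dialect):
            # stored string -> Python list, after the impl's result handling
            return value.split("|") if value is not None else None

SQL-level wrapping, by contrast, goes through the bind_expression() / column_expression() hooks described in the surrounding docstrings.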
 
@@ -1800,10 +2074,10 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
         type.
 
         """
-        return self.impl.bind_expression(bindparam)
+        return self.impl_instance.bind_expression(bindparam)
 
     @util.memoized_property
-    def _has_column_expression(self):
+    def _has_column_expression(self) -> bool:
         """memoized boolean, check if column_expression is implemented.
 
         Allows the method to be skipped for the vast majority of expression
@@ -1813,10 +2087,12 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
 
         return (
             util.method_is_overridden(self, TypeDecorator.column_expression)
-            or self.impl._has_column_expression
+            or self.impl_instance._has_column_expression
         )
 
-    def column_expression(self, column):
+    def column_expression(
+        self, column: ColumnElement[_T]
+    ) -> Optional[ColumnElement[_T]]:
         """Given a SELECT column expression, return a wrapping SQL expression.
 
         .. note::
@@ -1838,9 +2114,11 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
 
         """
 
-        return self.impl.column_expression(column)
+        return self.impl_instance.column_expression(column)
 
-    def coerce_compared_value(self, op, value):
+    def coerce_compared_value(
+        self, op: Optional[OperatorType], value: Any
+    ) -> Any:
         """Suggest a type for a 'coerced' Python value in an expression.
 
         By default, returns self.   This method is called by
@@ -1858,7 +2136,7 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
         """
         return self
 
-    def copy(self, **kw):
+    def copy(self: SelfTypeDecorator, **kw: Any) -> SelfTypeDecorator:
         """Produce a copy of this :class:`.TypeDecorator` instance.
 
         This is a shallow copy and is provided to fulfill part of
@@ -1872,16 +2150,16 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
         instance.__dict__.update(self.__dict__)
         return instance
 
-    def get_dbapi_type(self, dbapi):
+    def get_dbapi_type(self, dbapi: ModuleType) -> Optional[Any]:
         """Return the DBAPI type object represented by this
         :class:`.TypeDecorator`.
 
         By default this calls upon :meth:`.TypeEngine.get_dbapi_type` of the
         underlying "impl".
         """
-        return self.impl.get_dbapi_type(dbapi)
+        return self.impl_instance.get_dbapi_type(dbapi)
 
-    def compare_values(self, x, y):
+    def compare_values(self, x: Any, y: Any) -> bool:
         """Given two values, compare them for equality.
 
         By default this calls upon :meth:`.TypeEngine.compare_values`
@@ -1894,46 +2172,60 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]):
         has occurred.
 
         """
-        return self.impl.compare_values(x, y)
+        return self.impl_instance.compare_values(x, y)
 
+    # mypy property bug
     @property
-    def sort_key_function(self):
-        return self.impl.sort_key_function
+    def sort_key_function(self) -> Optional[Callable[[Any], Any]]:  # type: ignore # noqa E501
+        return self.impl_instance.sort_key_function
 
-    def __repr__(self):
-        return util.generic_repr(self, to_inspect=self.impl)
+    def __repr__(self) -> str:
+        return util.generic_repr(self, to_inspect=self.impl_instance)
 
 
-class Variant(TypeDecorator):
+class Variant(TypeDecorator[_T]):
     """deprecated.  symbol is present for backwards-compatibility with
     workaround recipes, however this actual type should not be used.
 
     """
 
-    def __init__(self, *arg, **kw):
+    def __init__(self, *arg: Any, **kw: Any):
         raise NotImplementedError(
             "Variant is no longer used in SQLAlchemy; this is a "
             "placeholder symbol for backwards compatibility."
         )
 
 
-def _reconstitute_comparator(expression):
+def _reconstitute_comparator(expression: Any) -> Any:
     return expression.comparator
 
 
+@overload
+def to_instance(typeobj: Union[Type[_TE], _TE], *arg: Any, **kw: Any) -> _TE:
+    ...
+
+
+@overload
+def to_instance(typeobj: None, *arg: Any, **kw: Any) -> TypeEngine[None]:
+    ...
+
+
 def to_instance(
-    typeobj: Union[Type[TypeEngine[_T]], TypeEngine[_T], None], *arg, **kw
-) -> TypeEngine[_T]:
+    typeobj: Union[Type[_TE], _TE, None], *arg: Any, **kw: Any
+) -> Union[_TE, TypeEngine[None]]:
     if typeobj is None:
         return NULLTYPE
 
     if callable(typeobj):
-        return typeobj(*arg, **kw)
+        return typeobj(*arg, **kw)  # type: ignore  # for pyright
     else:
         return typeobj
 
 
-def adapt_type(typeobj, colspecs):
+def adapt_type(
+    typeobj: TypeEngine[Any],
+    colspecs: Mapping[Type[Any], Type[TypeEngine[Any]]],
+) -> TypeEngine[Any]:
     if isinstance(typeobj, type):
         typeobj = typeobj()
     for t in typeobj.__class__.__mro__[0:-1]:
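
The @overload pair added to to_instance() above lets a type checker hand back the specific TypeEngine subtype it was given while still narrowing the None case. A generic sketch of that pattern; Base, NullBase and make_instance are invented names, not SQLAlchemy symbols:

    from typing import Type, TypeVar, Union, overload

    class Base: ...
    class NullBase(Base): ...

    _B = TypeVar("_B", bound=Base)

    @overload
    def make_instance(obj: Union[Type[_B], _B]) -> _B: ...
    @overload
    def make_instance(obj: None) -> NullBase: ...

    def make_instance(obj: Union[Type[_B], _B, None]) -> Union[_B, NullBase]:
        if obj is None:
            return NullBase()  # the None overload pins this case to NullBase
        if isinstance(obj, type):
            return obj()  # a class gets instantiated
        return obj  # an instance passes through unchanged

    b: Base = make_instance(Base)  # checker infers Base, not NullBase
    n: NullBase = make_instance(None)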
index a8e58a8bfb87b63c79f70c0cfb2da77fabe6ebe1..5c536b675fa40bda4ac96ce755e8d3e5e541229e 100644 (file)
@@ -36,7 +36,7 @@ _T = TypeVar("_T", bound=Any)
 
 
 # https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators
-_F = TypeVar("_F", bound=Callable[..., Any])
+_F = TypeVar("_F", bound="Callable[..., Any]")
 
 
 def _warn_with_version(
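
The _F TypeVar above (its bound is now quoted only as a forward-reference string) is the standard way to type decorators that return the same callable they receive, per the mypy link in the surrounding context. A minimal sketch; the logged() decorator and add() function are invented examples:

    from typing import Any, Callable, TypeVar

    _F = TypeVar("_F", bound=Callable[..., Any])

    def logged(fn: _F) -> _F:
        def wrapper(*args: Any, **kw: Any) -> Any:
            print("calling", fn.__name__)
            return fn(*args, **kw)

        return wrapper  # type: ignore[return-value]

    @logged
    def add(x: int, y: int) -> int:
        return x + y

    add(2, 3)  # checkers still see add as (x: int, y: int) -> int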
index ef8520c4f8ad9f944f71327a798f0af6a0cbd25d..35294715cbbf411467d93a5fe90d17567b6242a0 100644 (file)
@@ -178,7 +178,8 @@ def clsname_as_plain_name(cls: Type[Any]) -> str:
 
 
 def method_is_overridden(
-    instance_or_cls: Union[Type[Any], object], against_method: types.MethodType
+    instance_or_cls: Union[Type[Any], object],
+    against_method: Callable[..., Any],
 ) -> bool:
     """Return True if the two class methods don't match."""
 
@@ -815,7 +816,12 @@ def unbound_method_to_callable(func_or_cls):
         return func_or_cls
 
 
-def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()):
+def generic_repr(
+    obj: Any,
+    additional_kw: Sequence[Tuple[str, Any]] = (),
+    to_inspect: Optional[Union[object, List[object]]] = None,
+    omit_kwarg: Sequence[str] = (),
+) -> str:
     """Produce a __repr__() based on direct association of the __init__()
     specification vs. same-named attributes present.
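
generic_repr(), now fully typed above, builds a repr by pairing __init__() parameters with same-named attributes. A toy, non-SQLAlchemy version of that idea for readers unfamiliar with the helper; toy_repr and Point are invented names:

    import inspect

    def toy_repr(obj: object) -> str:
        # pair __init__ parameter names with same-named attributes
        params = inspect.signature(type(obj).__init__).parameters
        pairs = [
            "%s=%r" % (name, getattr(obj, name))
            for name in params
            if name != "self" and hasattr(obj, name)
        ]
        return "%s(%s)" % (type(obj).__name__, ", ".join(pairs))

    class Point:
        def __init__(self, x: int, y: int = 0) -> None:
            self.x, self.y = x, y

    print(toy_repr(Point(1, 2)))  # Point(x=1, y=2)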
 
index 6cfa8db4695557b38bf12cce9227dca5bae64cf4..7117c2689f38b459750d3cc0e96b7a30e3df3236 100644 (file)
@@ -86,6 +86,7 @@ module = [
     "sqlalchemy.sql.cache_key",
     "sqlalchemy.sql._elements_constructors",
     "sqlalchemy.sql.operators",
+    "sqlalchemy.sql.type_api",
     "sqlalchemy.sql.roles",
     "sqlalchemy.sql.visitors",
     "sqlalchemy.sql._py_util",
@@ -109,6 +110,7 @@ strict = true
 
 module = [
     #"sqlalchemy.sql.*",
+    "sqlalchemy.sql.sqltypes",
     "sqlalchemy.sql.elements",
     "sqlalchemy.sql.coercions",
     "sqlalchemy.sql.compiler",
index 8b950026f2f5a860ea5c56ae6425bd7ce6879caf..e50d515fb4e8a83d771848028ae4454b23dac3ed 100644 (file)
@@ -355,7 +355,6 @@ class ExecuteTest(fixtures.TablesTest):
                     )
                 )
             )
-
             assert_raises_message(
                 TypeError,
                 "I'm not a DBAPI error",
@@ -379,6 +378,8 @@ class ExecuteTest(fixtures.TablesTest):
         # have any special behaviors
         with patch.object(
             testing.db.dialect, "dbapi", Mock(Error=DBAPIError)
+        ), patch.object(
+            testing.db.dialect, "loaded_dbapi", Mock(Error=DBAPIError)
         ), patch.object(
             testing.db.dialect, "is_disconnect", lambda *arg: False
         ), patch.object(
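
The test change above stubs both dialect.dbapi and the newly separate dialect.loaded_dbapi so that error wrapping sees the same Error class on each. A reduced, self-contained sketch of the stacked patch.object() idiom used here; FakeDialect and FakeError are invented stand-ins:

    from unittest.mock import Mock, patch

    class FakeError(Exception):
        pass

    class FakeDialect:
        dbapi = None
        loaded_dbapi = None

    dialect = FakeDialect()

    with patch.object(
        dialect, "dbapi", Mock(Error=FakeError)
    ), patch.object(
        dialect, "loaded_dbapi", Mock(Error=FakeError)
    ):
        # both attributes now expose the same Error class, which is what
        # the dialect-level error handling expects to find
        assert dialect.dbapi.Error is FakeError
        assert dialect.loaded_dbapi.Error is FakeError

    assert dialect.dbapi is None  # patches are reverted on exit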
index c1613069e111cd044ac39b8ac203234402fdcf1a..7f2dd434c35d4fbea05caa6b8dba3f3312654528 100644 (file)
@@ -1682,7 +1682,7 @@ class QueuePoolTest(PoolTestBase):
 
         dialect = Mock()
         dialect.is_disconnect = lambda *arg, **kw: True
-        dialect.dbapi.Error = Error
+        dialect.dbapi.Error = dialect.loaded_dbapi.Error = Error
 
         pools = []
 
index b7bae0185d22192dad50d946b2b2d15c1e04d982..78b0a467cefd175a474d07fda1bc3fdfe073857f 100644 (file)
@@ -25,7 +25,8 @@ expr4 = -c2
 
 expr5 = ~(c2 == 5)
 
-expr6 = ~column("q", Boolean)
+q = column("q", Boolean)
+expr6 = ~q
 
 expr7 = c1 + "x"
 
index 09249416586621b806c69042628b3f620a2e35d6..11c3e83b7d44b5b132cd6b198ea7c9a908e52d7b 100644 (file)
@@ -2100,6 +2100,8 @@ class SchemaTypeTest(fixtures.TestBase):
     # causes collection-mutate-while-iterated errors in the event system
     # since the hooks here call upon the adapted type.  Need to figure out
     # why Enum and Boolean don't have this problem.
+    # NOTE: this is likely due to the SchemaType.adapt() method, which
+    # Enum / Boolean don't use (and which crashes if it comes first)
     class MyType(TrackEvents, sqltypes.SchemaType, sqltypes.TypeEngine):
         pass
 
@@ -2141,7 +2143,7 @@ class SchemaTypeTest(fixtures.TestBase):
         # [ticket:3832]
         # this also serves as the test for [ticket:6152]
 
-        class MySchemaType(sqltypes.TypeEngine, sqltypes.SchemaType):
+        class MySchemaType(sqltypes.SchemaType):
             pass
 
         target_typ = MySchemaType()
index 07003abcb58008dd77f69dbc503c115d3263b777..9812f84c154306ff5fc67aba67250869e576512a 100644 (file)
@@ -954,7 +954,7 @@ class QuotedIdentTest(fixtures.TestBase):
         eq_(q2.quote, False)
 
     def test_coerce_none(self):
-        q1 = quoted_name(None, False)
+        q1 = quoted_name.construct(None, False)
         eq_(q1, None)
 
     def test_apply_map_quoted(self):
index acf16565a01dfc76b18d0c901b336f2ce5d70469..d496b323b5e335518eaeda7dff674ca19f93065b 100644 (file)
@@ -72,7 +72,9 @@ from sqlalchemy.sql import null
 from sqlalchemy.sql import operators
 from sqlalchemy.sql import sqltypes
 from sqlalchemy.sql import table
+from sqlalchemy.sql import type_api
 from sqlalchemy.sql import visitors
+from sqlalchemy.sql.compiler import TypeCompiler
 from sqlalchemy.sql.sqltypes import TypeEngine
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
@@ -117,10 +119,15 @@ def _types_for_mod(mod):
 def _all_types(omit_special_types=False):
     seen = set()
     for typ in _types_for_mod(types):
-        if omit_special_types and typ in (
-            types.TypeDecorator,
-            types.TypeEngine,
-            types.Variant,
+        if omit_special_types and (
+            typ
+            in (
+                TypeEngine,
+                type_api.TypeEngineMixin,
+                types.Variant,
+                types.TypeDecorator,
+            )
+            or type_api.TypeEngineMixin in typ.__bases__
         ):
             continue
 
@@ -3553,7 +3560,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
                 return "MYINTEGER %s" % kw["type_expression"].name
 
         dialect = default.DefaultDialect()
-        dialect.type_compiler = SomeTypeCompiler(dialect)
+        dialect.type_compiler_instance = SomeTypeCompiler(dialect)
         self.assert_compile(
             ddl.CreateColumn(Column("bar", VARCHAR(50))),
             "bar MYVARCHAR",
@@ -3565,6 +3572,34 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             dialect=dialect,
         )
 
+    def test_legacy_typecompiler_attribute(self):
+        """the .type_compiler attribute was broken into
+        .type_compiler_cls and .type_compiler_instance for 2.0 so that it can
+        be properly typed.  However it is expected that the majority of
+        dialects make use of the .type_compiler attribute both at the class
+        level as well as the instance level, so make sure it still functions
+        in exactly the same way, both as the type compiler class to be
+        used as well as that it's present as an instance on an instance
+        of the dialect.
+
+        """
+
+        dialect = default.DefaultDialect()
+        assert isinstance(
+            dialect.type_compiler_instance, dialect.type_compiler_cls
+        )
+        is_(dialect.type_compiler_instance, dialect.type_compiler)
+
+        class MyTypeCompiler(TypeCompiler):
+            pass
+
+        class MyDialect(default.DefaultDialect):
+            type_compiler = MyTypeCompiler
+
+        dialect = MyDialect()
+        assert isinstance(dialect.type_compiler_instance, MyTypeCompiler)
+        is_(dialect.type_compiler_instance, dialect.type_compiler)
+
 
 class TestKWArgPassThru(AssertsCompiledSQL, fixtures.TestBase):
     __backend__ = True
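
The test_legacy_typecompiler_attribute test above verifies that the legacy .type_compiler name keeps working both as the class to use and as the per-dialect instance. One way such a dual-role attribute can be maintained is sketched below; TypeCompilerStandIn and DialectStandIn are generic stand-ins, not SQLAlchemy's actual implementation:

    class TypeCompilerStandIn:
        def __init__(self, dialect: "DialectStandIn") -> None:
            self.dialect = dialect

    class DialectStandIn:
        # subclasses may keep assigning their compiler class to the old name
        type_compiler = TypeCompilerStandIn

        def __init__(self) -> None:
            # resolve the class once, then publish the instance under both
            # the new name and the legacy name
            self.type_compiler_cls = type(self).type_compiler
            self.type_compiler_instance = self.type_compiler_cls(self)
            self.type_compiler = self.type_compiler_instance

    d = DialectStandIn()
    assert isinstance(d.type_compiler_instance, d.type_compiler_cls)
    assert d.type_compiler is d.type_compiler_instance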