]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Type mysql dialect
authorPablo Estevez <pablo22estevez@gmail.com>
Tue, 13 May 2025 13:39:19 +0000 (09:39 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 20 May 2025 02:00:45 +0000 (22:00 -0400)
Closes: #12164
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12164
Pull-request-sha: 545e2c39d5ee4f3938111b26e098fa2aa2b6e800
Co-authored-by: Mike Bayer <mike_mp@zzzcomputing.com>
Change-Id: I37bd98049ff1a64d58e9490b0e5e2ea764dd1f73

29 files changed:
lib/sqlalchemy/connectors/asyncio.py
lib/sqlalchemy/connectors/pyodbc.py
lib/sqlalchemy/dialects/__init__.py
lib/sqlalchemy/dialects/mysql/aiomysql.py
lib/sqlalchemy/dialects/mysql/asyncmy.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/cymysql.py
lib/sqlalchemy/dialects/mysql/enumerated.py
lib/sqlalchemy/dialects/mysql/expression.py
lib/sqlalchemy/dialects/mysql/json.py
lib/sqlalchemy/dialects/mysql/mariadb.py
lib/sqlalchemy/dialects/mysql/mariadbconnector.py
lib/sqlalchemy/dialects/mysql/mysqlconnector.py
lib/sqlalchemy/dialects/mysql/mysqldb.py
lib/sqlalchemy/dialects/mysql/provision.py
lib/sqlalchemy/dialects/mysql/pymysql.py
lib/sqlalchemy/dialects/mysql/pyodbc.py
lib/sqlalchemy/dialects/mysql/reflection.py
lib/sqlalchemy/dialects/mysql/reserved_words.py
lib/sqlalchemy/dialects/mysql/types.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/pool/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/ddl.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/sql/type_api.py
pyproject.toml

index bce08d9cc353e018eb6940560b7a249fc007a938..2037c248efc4b14ab1058f80900c9275d3636cc8 100644 (file)
@@ -20,13 +20,17 @@ from typing import NoReturn
 from typing import Optional
 from typing import Protocol
 from typing import Sequence
+from typing import TYPE_CHECKING
 
 from ..engine import AdaptedConnection
-from ..engine.interfaces import _DBAPICursorDescription
-from ..engine.interfaces import _DBAPIMultiExecuteParams
-from ..engine.interfaces import _DBAPISingleExecuteParams
 from ..util.concurrency import await_
-from ..util.typing import Self
+
+if TYPE_CHECKING:
+    from ..engine.interfaces import _DBAPICursorDescription
+    from ..engine.interfaces import _DBAPIMultiExecuteParams
+    from ..engine.interfaces import _DBAPISingleExecuteParams
+    from ..engine.interfaces import DBAPIModule
+    from ..util.typing import Self
 
 
 class AsyncIODBAPIConnection(Protocol):
@@ -36,7 +40,8 @@ class AsyncIODBAPIConnection(Protocol):
 
     """
 
-    async def close(self) -> None: ...
+    # note that async DBAPIs dont agree if close() should be awaitable,
+    # so it is omitted here and picked up by the __getattr__ hook below
 
     async def commit(self) -> None: ...
 
@@ -44,6 +49,10 @@ class AsyncIODBAPIConnection(Protocol):
 
     async def rollback(self) -> None: ...
 
+    def __getattr__(self, key: str) -> Any: ...
+
+    def __setattr__(self, key: str, value: Any) -> None: ...
+
 
 class AsyncIODBAPICursor(Protocol):
     """protocol representing an async adapted version
@@ -101,6 +110,16 @@ class AsyncIODBAPICursor(Protocol):
     def __aiter__(self) -> AsyncIterator[Any]: ...
 
 
+class AsyncAdapt_dbapi_module:
+    if TYPE_CHECKING:
+        Error = DBAPIModule.Error
+        OperationalError = DBAPIModule.OperationalError
+        InterfaceError = DBAPIModule.InterfaceError
+        IntegrityError = DBAPIModule.IntegrityError
+
+        def __getattr__(self, key: str) -> Any: ...
+
+
 class AsyncAdapt_dbapi_cursor:
     server_side = False
     __slots__ = (
index 8aaf223d4d9d1aabe2023650ea17dfd0bdc31759..d66836e038ed5e87c87eceb5e02331c7fb9854a7 100644 (file)
@@ -8,7 +8,6 @@
 from __future__ import annotations
 
 import re
-from types import ModuleType
 import typing
 from typing import Any
 from typing import Dict
@@ -28,6 +27,7 @@ from ..engine import URL
 from ..sql.type_api import TypeEngine
 
 if typing.TYPE_CHECKING:
+    from ..engine.interfaces import DBAPIModule
     from ..engine.interfaces import IsolationLevel
 
 
@@ -47,15 +47,13 @@ class PyODBCConnector(Connector):
     # hold the desired driver name
     pyodbc_driver_name: Optional[str] = None
 
-    dbapi: ModuleType
-
     def __init__(self, use_setinputsizes: bool = False, **kw: Any):
         super().__init__(**kw)
         if use_setinputsizes:
             self.bind_typing = interfaces.BindTyping.SETINPUTSIZES
 
     @classmethod
-    def import_dbapi(cls) -> ModuleType:
+    def import_dbapi(cls) -> DBAPIModule:
         return __import__("pyodbc")
 
     def create_connect_args(self, url: URL) -> ConnectArgsType:
@@ -150,7 +148,7 @@ class PyODBCConnector(Connector):
         ],
         cursor: Optional[interfaces.DBAPICursor],
     ) -> bool:
-        if isinstance(e, self.dbapi.ProgrammingError):
+        if isinstance(e, self.loaded_dbapi.ProgrammingError):
             return "The cursor's connection has been closed." in str(
                 e
             ) or "Attempt to use a closed connection." in str(e)
index 31ce6d64b52ec293eb431bfb66d9d566f1a8c894..30928a98455aba407cd7ae9b198c4602b7224cd5 100644 (file)
@@ -7,6 +7,7 @@
 
 from __future__ import annotations
 
+from typing import Any
 from typing import Callable
 from typing import Optional
 from typing import Type
@@ -39,7 +40,7 @@ def _auto_fn(name: str) -> Optional[Callable[[], Type[Dialect]]]:
             # hardcoded.   if mysql / mariadb etc were third party dialects
             # they would just publish all the entrypoints, which would actually
             # look much nicer.
-            module = __import__(
+            module: Any = __import__(
                 "sqlalchemy.dialects.mysql.mariadb"
             ).dialects.mysql.mariadb
             return module.loader(driver)  # type: ignore
index 66dd91110432f31ddf22e9a21e49464f481e97d7..d9828d0a27df5d7086ed9b258e7cdb77aa9ec5a7 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 r"""
 .. dialect:: mysql+aiomysql
@@ -29,17 +28,39 @@ This dialect should normally be used only with the
     )
 
 """  # noqa
+from __future__ import annotations
+
+from types import ModuleType
+from typing import Any
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
+
 from .pymysql import MySQLDialect_pymysql
 from ...connectors.asyncio import AsyncAdapt_dbapi_connection
 from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
+from ...connectors.asyncio import AsyncAdapt_dbapi_module
 from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
 from ...util.concurrency import await_
 
+if TYPE_CHECKING:
+
+    from ...connectors.asyncio import AsyncIODBAPIConnection
+    from ...connectors.asyncio import AsyncIODBAPICursor
+    from ...engine.interfaces import ConnectArgsType
+    from ...engine.interfaces import DBAPIConnection
+    from ...engine.interfaces import DBAPICursor
+    from ...engine.interfaces import DBAPIModule
+    from ...engine.interfaces import PoolProxiedConnection
+    from ...engine.url import URL
+
 
 class AsyncAdapt_aiomysql_cursor(AsyncAdapt_dbapi_cursor):
     __slots__ = ()
 
-    def _make_new_cursor(self, connection):
+    def _make_new_cursor(
+        self, connection: AsyncIODBAPIConnection
+    ) -> AsyncIODBAPICursor:
         return connection.cursor(self._adapt_connection.dbapi.Cursor)
 
 
@@ -48,7 +69,9 @@ class AsyncAdapt_aiomysql_ss_cursor(
 ):
     __slots__ = ()
 
-    def _make_new_cursor(self, connection):
+    def _make_new_cursor(
+        self, connection: AsyncIODBAPIConnection
+    ) -> AsyncIODBAPICursor:
         return connection.cursor(
             self._adapt_connection.dbapi.aiomysql.cursors.SSCursor
         )
@@ -60,17 +83,17 @@ class AsyncAdapt_aiomysql_connection(AsyncAdapt_dbapi_connection):
     _cursor_cls = AsyncAdapt_aiomysql_cursor
     _ss_cursor_cls = AsyncAdapt_aiomysql_ss_cursor
 
-    def ping(self, reconnect):
+    def ping(self, reconnect: bool) -> None:
         assert not reconnect
-        return await_(self._connection.ping(reconnect))
+        await_(self._connection.ping(reconnect))
 
-    def character_set_name(self):
-        return self._connection.character_set_name()
+    def character_set_name(self) -> Optional[str]:
+        return self._connection.character_set_name()  # type: ignore[no-any-return]  # noqa: E501
 
-    def autocommit(self, value):
+    def autocommit(self, value: Any) -> None:
         await_(self._connection.autocommit(value))
 
-    def terminate(self):
+    def terminate(self) -> None:
         # it's not awaitable.
         self._connection.close()
 
@@ -78,15 +101,15 @@ class AsyncAdapt_aiomysql_connection(AsyncAdapt_dbapi_connection):
         await_(self._connection.ensure_closed())
 
 
-class AsyncAdapt_aiomysql_dbapi:
-    def __init__(self, aiomysql, pymysql):
+class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module):
+    def __init__(self, aiomysql: ModuleType, pymysql: ModuleType):
         self.aiomysql = aiomysql
         self.pymysql = pymysql
         self.paramstyle = "format"
         self._init_dbapi_attributes()
         self.Cursor, self.SSCursor = self._init_cursors_subclasses()
 
-    def _init_dbapi_attributes(self):
+    def _init_dbapi_attributes(self) -> None:
         for name in (
             "Warning",
             "Error",
@@ -112,7 +135,7 @@ class AsyncAdapt_aiomysql_dbapi:
         ):
             setattr(self, name, getattr(self.pymysql, name))
 
-    def connect(self, *arg, **kw):
+    def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_aiomysql_connection:
         creator_fn = kw.pop("async_creator_fn", self.aiomysql.connect)
 
         return AsyncAdapt_aiomysql_connection(
@@ -120,57 +143,72 @@ class AsyncAdapt_aiomysql_dbapi:
             await_(creator_fn(*arg, **kw)),
         )
 
-    def _init_cursors_subclasses(self):
+    def _init_cursors_subclasses(
+        self,
+    ) -> tuple[AsyncIODBAPICursor, AsyncIODBAPICursor]:
         # suppress unconditional warning emitted by aiomysql
-        class Cursor(self.aiomysql.Cursor):
-            async def _show_warnings(self, conn):
+        class Cursor(self.aiomysql.Cursor):  # type: ignore[misc, name-defined]
+            async def _show_warnings(
+                self, conn: AsyncIODBAPIConnection
+            ) -> None:
                 pass
 
-        class SSCursor(self.aiomysql.SSCursor):
-            async def _show_warnings(self, conn):
+        class SSCursor(self.aiomysql.SSCursor):  # type: ignore[misc, name-defined]   # noqa: E501
+            async def _show_warnings(
+                self, conn: AsyncIODBAPIConnection
+            ) -> None:
                 pass
 
-        return Cursor, SSCursor
+        return Cursor, SSCursor  # type: ignore[return-value]
 
 
 class MySQLDialect_aiomysql(MySQLDialect_pymysql):
     driver = "aiomysql"
     supports_statement_cache = True
 
-    supports_server_side_cursors = True
+    supports_server_side_cursors = True  # type: ignore[assignment]
     _sscursor = AsyncAdapt_aiomysql_ss_cursor
 
     is_async = True
     has_terminate = True
 
     @classmethod
-    def import_dbapi(cls):
+    def import_dbapi(cls) -> AsyncAdapt_aiomysql_dbapi:
         return AsyncAdapt_aiomysql_dbapi(
             __import__("aiomysql"), __import__("pymysql")
         )
 
-    def do_terminate(self, dbapi_connection) -> None:
+    def do_terminate(self, dbapi_connection: DBAPIConnection) -> None:
         dbapi_connection.terminate()
 
-    def create_connect_args(self, url):
+    def create_connect_args(
+        self, url: URL, _translate_args: Optional[dict[str, Any]] = None
+    ) -> ConnectArgsType:
         return super().create_connect_args(
             url, _translate_args=dict(username="user", database="db")
         )
 
-    def is_disconnect(self, e, connection, cursor):
+    def is_disconnect(
+        self,
+        e: DBAPIModule.Error,
+        connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
+        cursor: Optional[DBAPICursor],
+    ) -> bool:
         if super().is_disconnect(e, connection, cursor):
             return True
         else:
             str_e = str(e).lower()
             return "not connected" in str_e
 
-    def _found_rows_client_flag(self):
-        from pymysql.constants import CLIENT
+    def _found_rows_client_flag(self) -> int:
+        from pymysql.constants import CLIENT  # type: ignore
 
-        return CLIENT.FOUND_ROWS
+        return CLIENT.FOUND_ROWS  # type: ignore[no-any-return]
 
-    def get_driver_connection(self, connection):
-        return connection._connection
+    def get_driver_connection(
+        self, connection: DBAPIConnection
+    ) -> AsyncIODBAPIConnection:
+        return connection._connection  # type: ignore[no-any-return]
 
 
 dialect = MySQLDialect_aiomysql
index 86c78d65d5b2c6948cb2c118ac5f03b9b1b6ad98..a2e1fffec696f38793f9d5c2e92682ed0e5745de 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 r"""
 .. dialect:: mysql+asyncmy
@@ -29,13 +28,32 @@ This dialect should normally be used only with the
 """  # noqa
 from __future__ import annotations
 
+from types import ModuleType
+from typing import Any
+from typing import NoReturn
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
+
 from .pymysql import MySQLDialect_pymysql
 from ... import util
 from ...connectors.asyncio import AsyncAdapt_dbapi_connection
 from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
+from ...connectors.asyncio import AsyncAdapt_dbapi_module
 from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
 from ...util.concurrency import await_
 
+if TYPE_CHECKING:
+
+    from ...connectors.asyncio import AsyncIODBAPIConnection
+    from ...connectors.asyncio import AsyncIODBAPICursor
+    from ...engine.interfaces import ConnectArgsType
+    from ...engine.interfaces import DBAPIConnection
+    from ...engine.interfaces import DBAPICursor
+    from ...engine.interfaces import DBAPIModule
+    from ...engine.interfaces import PoolProxiedConnection
+    from ...engine.url import URL
+
 
 class AsyncAdapt_asyncmy_cursor(AsyncAdapt_dbapi_cursor):
     __slots__ = ()
@@ -46,7 +64,9 @@ class AsyncAdapt_asyncmy_ss_cursor(
 ):
     __slots__ = ()
 
-    def _make_new_cursor(self, connection):
+    def _make_new_cursor(
+        self, connection: AsyncIODBAPIConnection
+    ) -> AsyncIODBAPICursor:
         return connection.cursor(
             self._adapt_connection.dbapi.asyncmy.cursors.SSCursor
         )
@@ -58,7 +78,7 @@ class AsyncAdapt_asyncmy_connection(AsyncAdapt_dbapi_connection):
     _cursor_cls = AsyncAdapt_asyncmy_cursor
     _ss_cursor_cls = AsyncAdapt_asyncmy_ss_cursor
 
-    def _handle_exception(self, error):
+    def _handle_exception(self, error: Exception) -> NoReturn:
         if isinstance(error, AttributeError):
             raise self.dbapi.InternalError(
                 "network operation failed due to asyncmy attribute error"
@@ -66,24 +86,24 @@ class AsyncAdapt_asyncmy_connection(AsyncAdapt_dbapi_connection):
 
         raise error
 
-    def ping(self, reconnect):
+    def ping(self, reconnect: bool) -> None:
         assert not reconnect
         return await_(self._do_ping())
 
-    async def _do_ping(self):
+    async def _do_ping(self) -> None:
         try:
             async with self._execute_mutex:
-                return await self._connection.ping(False)
+                await self._connection.ping(False)
         except Exception as error:
             self._handle_exception(error)
 
-    def character_set_name(self):
-        return self._connection.character_set_name()
+    def character_set_name(self) -> Optional[str]:
+        return self._connection.character_set_name()  # type: ignore[no-any-return]  # noqa: E501
 
-    def autocommit(self, value):
+    def autocommit(self, value: Any) -> None:
         await_(self._connection.autocommit(value))
 
-    def terminate(self):
+    def terminate(self) -> None:
         # it's not awaitable.
         self._connection.close()
 
@@ -91,18 +111,13 @@ class AsyncAdapt_asyncmy_connection(AsyncAdapt_dbapi_connection):
         await_(self._connection.ensure_closed())
 
 
-def _Binary(x):
-    """Return x as a binary type."""
-    return bytes(x)
-
-
-class AsyncAdapt_asyncmy_dbapi:
-    def __init__(self, asyncmy):
+class AsyncAdapt_asyncmy_dbapi(AsyncAdapt_dbapi_module):
+    def __init__(self, asyncmy: ModuleType):
         self.asyncmy = asyncmy
         self.paramstyle = "format"
         self._init_dbapi_attributes()
 
-    def _init_dbapi_attributes(self):
+    def _init_dbapi_attributes(self) -> None:
         for name in (
             "Warning",
             "Error",
@@ -123,9 +138,9 @@ class AsyncAdapt_asyncmy_dbapi:
     BINARY = util.symbol("BINARY")
     DATETIME = util.symbol("DATETIME")
     TIMESTAMP = util.symbol("TIMESTAMP")
-    Binary = staticmethod(_Binary)
+    Binary = staticmethod(bytes)
 
-    def connect(self, *arg, **kw):
+    def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_asyncmy_connection:
         creator_fn = kw.pop("async_creator_fn", self.asyncmy.connect)
 
         return AsyncAdapt_asyncmy_connection(
@@ -138,25 +153,30 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql):
     driver = "asyncmy"
     supports_statement_cache = True
 
-    supports_server_side_cursors = True
+    supports_server_side_cursors = True  # type: ignore[assignment]
     _sscursor = AsyncAdapt_asyncmy_ss_cursor
 
     is_async = True
     has_terminate = True
 
     @classmethod
-    def import_dbapi(cls):
+    def import_dbapi(cls) -> DBAPIModule:
         return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy"))
 
-    def do_terminate(self, dbapi_connection) -> None:
+    def do_terminate(self, dbapi_connection: DBAPIConnection) -> None:
         dbapi_connection.terminate()
 
-    def create_connect_args(self, url):
+    def create_connect_args(self, url: URL) -> ConnectArgsType:  # type: ignore[override]  # noqa: E501
         return super().create_connect_args(
             url, _translate_args=dict(username="user", database="db")
         )
 
-    def is_disconnect(self, e, connection, cursor):
+    def is_disconnect(
+        self,
+        e: DBAPIModule.Error,
+        connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
+        cursor: Optional[DBAPICursor],
+    ) -> bool:
         if super().is_disconnect(e, connection, cursor):
             return True
         else:
@@ -165,13 +185,15 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql):
                 "not connected" in str_e or "network operation failed" in str_e
             )
 
-    def _found_rows_client_flag(self):
-        from asyncmy.constants import CLIENT
+    def _found_rows_client_flag(self) -> int:
+        from asyncmy.constants import CLIENT  # type: ignore
 
-        return CLIENT.FOUND_ROWS
+        return CLIENT.FOUND_ROWS  # type: ignore[no-any-return]
 
-    def get_driver_connection(self, connection):
-        return connection._connection
+    def get_driver_connection(
+        self, connection: DBAPIConnection
+    ) -> AsyncIODBAPIConnection:
+        return connection._connection  # type: ignore[no-any-return]
 
 
 dialect = MySQLDialect_asyncmy
index 2951b17d3b56b55069b66c44192e2756b4f291b6..ef37ba05652bf632d9a2fb7d6d29bda85779a798 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 
 r"""
@@ -1065,11 +1064,18 @@ output:
 """  # noqa
 from __future__ import annotations
 
-from array import array as _array
 from collections import defaultdict
 from itertools import compress
 import re
+from typing import Any
+from typing import Callable
 from typing import cast
+from typing import NoReturn
+from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import TYPE_CHECKING
+from typing import Union
 
 from . import reflection as _reflection
 from .enumerated import ENUM
@@ -1113,7 +1119,6 @@ from .types import VARCHAR
 from .types import YEAR
 from ... import exc
 from ... import literal_column
-from ... import log
 from ... import schema as sa_schema
 from ... import sql
 from ... import util
@@ -1137,10 +1142,50 @@ from ...types import BINARY
 from ...types import BLOB
 from ...types import BOOLEAN
 from ...types import DATE
+from ...types import LargeBinary
 from ...types import UUID
 from ...types import VARBINARY
 from ...util import topological
 
+if TYPE_CHECKING:
+
+    from ...dialects.mysql import expression
+    from ...dialects.mysql.dml import DMLLimitClause
+    from ...dialects.mysql.dml import OnDuplicateClause
+    from ...engine.base import Connection
+    from ...engine.cursor import CursorResult
+    from ...engine.interfaces import DBAPIConnection
+    from ...engine.interfaces import DBAPICursor
+    from ...engine.interfaces import DBAPIModule
+    from ...engine.interfaces import IsolationLevel
+    from ...engine.interfaces import PoolProxiedConnection
+    from ...engine.interfaces import ReflectedCheckConstraint
+    from ...engine.interfaces import ReflectedColumn
+    from ...engine.interfaces import ReflectedForeignKeyConstraint
+    from ...engine.interfaces import ReflectedIndex
+    from ...engine.interfaces import ReflectedPrimaryKeyConstraint
+    from ...engine.interfaces import ReflectedTableComment
+    from ...engine.interfaces import ReflectedUniqueConstraint
+    from ...engine.result import _Ts
+    from ...engine.row import Row
+    from ...engine.url import URL
+    from ...schema import Table
+    from ...sql import ddl
+    from ...sql import selectable
+    from ...sql.dml import _DMLTableElement
+    from ...sql.dml import Delete
+    from ...sql.dml import Update
+    from ...sql.dml import ValuesBase
+    from ...sql.functions import aggregate_strings
+    from ...sql.functions import random
+    from ...sql.functions import rollup
+    from ...sql.functions import sysdate
+    from ...sql.schema import Sequence as Sequence_SchemaItem
+    from ...sql.type_api import TypeEngine
+    from ...sql.visitors import ExternallyTraversible
+    from ...util.typing import TupleAny
+    from ...util.typing import Unpack
+
 
 SET_RE = re.compile(
     r"\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w", re.I | re.UNICODE
@@ -1236,7 +1281,7 @@ ischema_names = {
 
 
 class MySQLExecutionContext(default.DefaultExecutionContext):
-    def post_exec(self):
+    def post_exec(self) -> None:
         if (
             self.isdelete
             and cast(SQLCompiler, self.compiled).effective_returning
@@ -1253,7 +1298,7 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
                 _cursor.FullyBufferedCursorFetchStrategy(
                     self.cursor,
                     [
-                        (entry.keyname, None)
+                        (entry.keyname, None)  # type: ignore[misc]
                         for entry in cast(
                             SQLCompiler, self.compiled
                         )._result_columns
@@ -1262,14 +1307,18 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
                 )
             )
 
-    def create_server_side_cursor(self):
+    def create_server_side_cursor(self) -> DBAPICursor:
         if self.dialect.supports_server_side_cursors:
-            return self._dbapi_connection.cursor(self.dialect._sscursor)
+            return self._dbapi_connection.cursor(
+                self.dialect._sscursor  # type: ignore[attr-defined]
+            )
         else:
             raise NotImplementedError()
 
-    def fire_sequence(self, seq, type_):
-        return self._execute_scalar(
+    def fire_sequence(
+        self, seq: Sequence_SchemaItem, type_: sqltypes.Integer
+    ) -> int:
+        return self._execute_scalar(  # type: ignore[no-any-return]
             (
                 "select nextval(%s)"
                 % self.identifier_preparer.format_sequence(seq)
@@ -1279,46 +1328,51 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
 
 
 class MySQLCompiler(compiler.SQLCompiler):
+    dialect: MySQLDialect
     render_table_with_column_in_update_from = True
     """Overridden from base SQLCompiler value"""
 
     extract_map = compiler.SQLCompiler.extract_map.copy()
     extract_map.update({"milliseconds": "millisecond"})
 
-    def default_from(self):
+    def default_from(self) -> str:
         """Called when a ``SELECT`` statement has no froms,
         and no ``FROM`` clause is to be appended.
 
         """
         if self.stack:
             stmt = self.stack[-1]["selectable"]
-            if stmt._where_criteria:
+            if stmt._where_criteria:  # type: ignore[attr-defined]
                 return " FROM DUAL"
 
         return ""
 
-    def visit_random_func(self, fn, **kw):
+    def visit_random_func(self, fn: random, **kw: Any) -> str:
         return "rand%s" % self.function_argspec(fn)
 
-    def visit_rollup_func(self, fn, **kw):
+    def visit_rollup_func(self, fn: rollup[Any], **kw: Any) -> str:
         clause = ", ".join(
             elem._compiler_dispatch(self, **kw) for elem in fn.clauses
         )
         return f"{clause} WITH ROLLUP"
 
-    def visit_aggregate_strings_func(self, fn, **kw):
+    def visit_aggregate_strings_func(
+        self, fn: aggregate_strings, **kw: Any
+    ) -> str:
         expr, delimeter = (
             elem._compiler_dispatch(self, **kw) for elem in fn.clauses
         )
         return f"group_concat({expr} SEPARATOR {delimeter})"
 
-    def visit_sequence(self, seq, **kw):
-        return "nextval(%s)" % self.preparer.format_sequence(seq)
+    def visit_sequence(self, sequence: sa_schema.Sequence, **kw: Any) -> str:
+        return "nextval(%s)" % self.preparer.format_sequence(sequence)
 
-    def visit_sysdate_func(self, fn, **kw):
+    def visit_sysdate_func(self, fn: sysdate, **kw: Any) -> str:
         return "SYSDATE()"
 
-    def _render_json_extract_from_binary(self, binary, operator, **kw):
+    def _render_json_extract_from_binary(
+        self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+    ) -> str:
         # note we are intentionally calling upon the process() calls in the
         # order in which they appear in the SQL String as this is used
         # by positional parameter rendering
@@ -1345,9 +1399,10 @@ class MySQLCompiler(compiler.SQLCompiler):
                 )
             )
         elif binary.type._type_affinity in (sqltypes.Numeric, sqltypes.Float):
+            binary_type = cast(sqltypes.Numeric[Any], binary.type)
             if (
-                binary.type.scale is not None
-                and binary.type.precision is not None
+                binary_type.scale is not None
+                and binary_type.precision is not None
             ):
                 # using DECIMAL here because MySQL does not recognize NUMERIC
                 type_expression = (
@@ -1355,8 +1410,8 @@ class MySQLCompiler(compiler.SQLCompiler):
                     % (
                         self.process(binary.left, **kw),
                         self.process(binary.right, **kw),
-                        binary.type.precision,
-                        binary.type.scale,
+                        binary_type.precision,
+                        binary_type.scale,
                     )
                 )
             else:
@@ -1390,15 +1445,22 @@ class MySQLCompiler(compiler.SQLCompiler):
 
         return case_expression + " " + type_expression + " END"
 
-    def visit_json_getitem_op_binary(self, binary, operator, **kw):
+    def visit_json_getitem_op_binary(
+        self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+    ) -> str:
         return self._render_json_extract_from_binary(binary, operator, **kw)
 
-    def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
+    def visit_json_path_getitem_op_binary(
+        self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+    ) -> str:
         return self._render_json_extract_from_binary(binary, operator, **kw)
 
-    def visit_on_duplicate_key_update(self, on_duplicate, **kw):
-        statement = self.current_executable
+    def visit_on_duplicate_key_update(
+        self, on_duplicate: OnDuplicateClause, **kw: Any
+    ) -> str:
+        statement: ValuesBase = self.current_executable
 
+        cols: list[elements.KeyedColumnElement[Any]]
         if on_duplicate._parameter_ordering:
             parameter_ordering = [
                 coercions.expect(roles.DMLColumnRole, key)
@@ -1411,7 +1473,7 @@ class MySQLCompiler(compiler.SQLCompiler):
                 if key in statement.table.c
             ] + [c for c in statement.table.c if c.key not in ordered_keys]
         else:
-            cols = statement.table.c
+            cols = list(statement.table.c)
 
         clauses = []
 
@@ -1420,7 +1482,7 @@ class MySQLCompiler(compiler.SQLCompiler):
         )
 
         if requires_mysql8_alias:
-            if statement.table.name.lower() == "new":
+            if statement.table.name.lower() == "new":  # type: ignore[union-attr]  # noqa: E501
                 _on_dup_alias_name = "new_1"
             else:
                 _on_dup_alias_name = "new"
@@ -1434,24 +1496,26 @@ class MySQLCompiler(compiler.SQLCompiler):
         for column in (col for col in cols if col.key in on_duplicate_update):
             val = on_duplicate_update[column.key]
 
-            def replace(obj):
+            def replace(
+                element: ExternallyTraversible, **kw: Any
+            ) -> Optional[ExternallyTraversible]:
                 if (
-                    isinstance(obj, elements.BindParameter)
-                    and obj.type._isnull
+                    isinstance(element, elements.BindParameter)
+                    and element.type._isnull
                 ):
-                    return obj._with_binary_element_type(column.type)
+                    return element._with_binary_element_type(column.type)
                 elif (
-                    isinstance(obj, elements.ColumnClause)
-                    and obj.table is on_duplicate.inserted_alias
+                    isinstance(element, elements.ColumnClause)
+                    and element.table is on_duplicate.inserted_alias
                 ):
                     if requires_mysql8_alias:
                         column_literal_clause = (
                             f"{_on_dup_alias_name}."
-                            f"{self.preparer.quote(obj.name)}"
+                            f"{self.preparer.quote(element.name)}"
                         )
                     else:
                         column_literal_clause = (
-                            f"VALUES({self.preparer.quote(obj.name)})"
+                            f"VALUES({self.preparer.quote(element.name)})"
                         )
                     return literal_column(column_literal_clause)
                 else:
@@ -1470,7 +1534,7 @@ class MySQLCompiler(compiler.SQLCompiler):
                 "Additional column names not matching "
                 "any column keys in table '%s': %s"
                 % (
-                    self.statement.table.name,
+                    self.statement.table.name,  # type: ignore[union-attr]
                     (", ".join("'%s'" % c for c in non_matching)),
                 )
             )
@@ -1484,13 +1548,15 @@ class MySQLCompiler(compiler.SQLCompiler):
             return f"ON DUPLICATE KEY UPDATE {', '.join(clauses)}"
 
     def visit_concat_op_expression_clauselist(
-        self, clauselist, operator, **kw
-    ):
+        self, clauselist: elements.ClauseList, operator: Any, **kw: Any
+    ) -> str:
         return "concat(%s)" % (
             ", ".join(self.process(elem, **kw) for elem in clauselist.clauses)
         )
 
-    def visit_concat_op_binary(self, binary, operator, **kw):
+    def visit_concat_op_binary(
+        self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+    ) -> str:
         return "concat(%s, %s)" % (
             self.process(binary.left, **kw),
             self.process(binary.right, **kw),
@@ -1513,10 +1579,12 @@ class MySQLCompiler(compiler.SQLCompiler):
         "WITH QUERY EXPANSION",
     )
 
-    def visit_mysql_match(self, element, **kw):
+    def visit_mysql_match(self, element: expression.match, **kw: Any) -> str:
         return self.visit_match_op_binary(element, element.operator, **kw)
 
-    def visit_match_op_binary(self, binary, operator, **kw):
+    def visit_match_op_binary(
+        self, binary: expression.match, operator: Any, **kw: Any
+    ) -> str:
         """
         Note that `mysql_boolean_mode` is enabled by default because of
         backward compatibility
@@ -1537,12 +1605,11 @@ class MySQLCompiler(compiler.SQLCompiler):
                 "with_query_expansion=%s" % query_expansion,
             )
 
-            flags = ", ".join(flags)
+            flags_str = ", ".join(flags)
 
-            raise exc.CompileError("Invalid MySQL match flags: %s" % flags)
+            raise exc.CompileError("Invalid MySQL match flags: %s" % flags_str)
 
-        match_clause = binary.left
-        match_clause = self.process(match_clause, **kw)
+        match_clause = self.process(binary.left, **kw)
         against_clause = self.process(binary.right, **kw)
 
         if any(flag_combination):
@@ -1551,21 +1618,25 @@ class MySQLCompiler(compiler.SQLCompiler):
                 flag_combination,
             )
 
-            against_clause = [against_clause]
-            against_clause.extend(flag_expressions)
-
-            against_clause = " ".join(against_clause)
+            against_clause = " ".join([against_clause, *flag_expressions])
 
         return "MATCH (%s) AGAINST (%s)" % (match_clause, against_clause)
 
-    def get_from_hint_text(self, table, text):
+    def get_from_hint_text(
+        self, table: selectable.FromClause, text: Optional[str]
+    ) -> Optional[str]:
         return text
 
-    def visit_typeclause(self, typeclause, type_=None, **kw):
+    def visit_typeclause(
+        self,
+        typeclause: elements.TypeClause,
+        type_: Optional[TypeEngine[Any]] = None,
+        **kw: Any,
+    ) -> Optional[str]:
         if type_ is None:
             type_ = typeclause.type.dialect_impl(self.dialect)
         if isinstance(type_, sqltypes.TypeDecorator):
-            return self.visit_typeclause(typeclause, type_.impl, **kw)
+            return self.visit_typeclause(typeclause, type_.impl, **kw)  # type: ignore[arg-type]  # noqa: E501
         elif isinstance(type_, sqltypes.Integer):
             if getattr(type_, "unsigned", False):
                 return "UNSIGNED INTEGER"
@@ -1604,7 +1675,7 @@ class MySQLCompiler(compiler.SQLCompiler):
         else:
             return None
 
-    def visit_cast(self, cast, **kw):
+    def visit_cast(self, cast: elements.Cast[Any], **kw: Any) -> str:
         type_ = self.process(cast.typeclause)
         if type_ is None:
             util.warn(
@@ -1618,7 +1689,9 @@ class MySQLCompiler(compiler.SQLCompiler):
 
         return "CAST(%s AS %s)" % (self.process(cast.clause, **kw), type_)
 
-    def render_literal_value(self, value, type_):
+    def render_literal_value(
+        self, value: Optional[str], type_: TypeEngine[Any]
+    ) -> str:
         value = super().render_literal_value(value, type_)
         if self.dialect._backslash_escapes:
             value = value.replace("\\", "\\\\")
@@ -1626,13 +1699,15 @@ class MySQLCompiler(compiler.SQLCompiler):
 
     # override native_boolean=False behavior here, as
     # MySQL still supports native boolean
-    def visit_true(self, element, **kw):
+    def visit_true(self, expr: elements.True_, **kw: Any) -> str:
         return "true"
 
-    def visit_false(self, element, **kw):
+    def visit_false(self, expr: elements.False_, **kw: Any) -> str:
         return "false"
 
-    def get_select_precolumns(self, select, **kw):
+    def get_select_precolumns(
+        self, select: selectable.Select[Any], **kw: Any
+    ) -> str:
         """Add special MySQL keywords in place of DISTINCT.
 
         .. deprecated:: 1.4 This usage is deprecated.
@@ -1652,7 +1727,13 @@ class MySQLCompiler(compiler.SQLCompiler):
 
         return super().get_select_precolumns(select, **kw)
 
-    def visit_join(self, join, asfrom=False, from_linter=None, **kwargs):
+    def visit_join(
+        self,
+        join: selectable.Join,
+        asfrom: bool = False,
+        from_linter: Optional[compiler.FromLinter] = None,
+        **kwargs: Any,
+    ) -> str:
         if from_linter:
             from_linter.edges.add((join.left, join.right))
 
@@ -1673,18 +1754,21 @@ class MySQLCompiler(compiler.SQLCompiler):
                     join.right, asfrom=True, from_linter=from_linter, **kwargs
                 ),
                 " ON ",
-                self.process(join.onclause, from_linter=from_linter, **kwargs),
+                self.process(join.onclause, from_linter=from_linter, **kwargs),  # type: ignore[arg-type]  # noqa: E501
             )
         )
 
-    def for_update_clause(self, select, **kw):
+    def for_update_clause(
+        self, select: selectable.GenerativeSelect, **kw: Any
+    ) -> str:
+        assert select._for_update_arg is not None
         if select._for_update_arg.read:
             tmp = " LOCK IN SHARE MODE"
         else:
             tmp = " FOR UPDATE"
 
         if select._for_update_arg.of and self.dialect.supports_for_update_of:
-            tables = util.OrderedSet()
+            tables: util.OrderedSet[elements.ClauseElement] = util.OrderedSet()
             for c in select._for_update_arg.of:
                 tables.update(sql_util.surface_selectables_only(c))
 
@@ -1701,7 +1785,9 @@ class MySQLCompiler(compiler.SQLCompiler):
 
         return tmp
 
-    def limit_clause(self, select, **kw):
+    def limit_clause(
+        self, select: selectable.GenerativeSelect, **kw: Any
+    ) -> str:
         # MySQL supports:
         #   LIMIT <limit>
         #   LIMIT <offset>, <limit>
@@ -1737,10 +1823,13 @@ class MySQLCompiler(compiler.SQLCompiler):
                     self.process(limit_clause, **kw),
                 )
         else:
+            assert limit_clause is not None
             # No offset provided, so just use the limit
             return " \n LIMIT %s" % (self.process(limit_clause, **kw),)
 
-    def update_post_criteria_clause(self, update_stmt, **kw):
+    def update_post_criteria_clause(
+        self, update_stmt: Update, **kw: Any
+    ) -> Optional[str]:
         limit = update_stmt.kwargs.get("%s_limit" % self.dialect.name, None)
         supertext = super().update_post_criteria_clause(update_stmt, **kw)
 
@@ -1753,7 +1842,9 @@ class MySQLCompiler(compiler.SQLCompiler):
         else:
             return supertext
 
-    def delete_post_criteria_clause(self, delete_stmt, **kw):
+    def delete_post_criteria_clause(
+        self, delete_stmt: Delete, **kw: Any
+    ) -> Optional[str]:
         limit = delete_stmt.kwargs.get("%s_limit" % self.dialect.name, None)
         supertext = super().delete_post_criteria_clause(delete_stmt, **kw)
 
@@ -1766,11 +1857,19 @@ class MySQLCompiler(compiler.SQLCompiler):
         else:
             return supertext
 
-    def visit_mysql_dml_limit_clause(self, element, **kw):
+    def visit_mysql_dml_limit_clause(
+        self, element: DMLLimitClause, **kw: Any
+    ) -> str:
         kw["literal_execute"] = True
         return f"LIMIT {self.process(element._limit_clause, **kw)}"
 
-    def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
+    def update_tables_clause(
+        self,
+        update_stmt: Update,
+        from_table: _DMLTableElement,
+        extra_froms: list[selectable.FromClause],
+        **kw: Any,
+    ) -> str:
         kw["asfrom"] = True
         return ", ".join(
             t._compiler_dispatch(self, **kw)
@@ -1778,11 +1877,22 @@ class MySQLCompiler(compiler.SQLCompiler):
         )
 
     def update_from_clause(
-        self, update_stmt, from_table, extra_froms, from_hints, **kw
-    ):
+        self,
+        update_stmt: Update,
+        from_table: _DMLTableElement,
+        extra_froms: list[selectable.FromClause],
+        from_hints: Any,
+        **kw: Any,
+    ) -> None:
         return None
 
-    def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw):
+    def delete_table_clause(
+        self,
+        delete_stmt: Delete,
+        from_table: _DMLTableElement,
+        extra_froms: list[selectable.FromClause],
+        **kw: Any,
+    ) -> str:
         """If we have extra froms make sure we render any alias as hint."""
         ashint = False
         if extra_froms:
@@ -1792,8 +1902,13 @@ class MySQLCompiler(compiler.SQLCompiler):
         )
 
     def delete_extra_from_clause(
-        self, delete_stmt, from_table, extra_froms, from_hints, **kw
-    ):
+        self,
+        delete_stmt: Delete,
+        from_table: _DMLTableElement,
+        extra_froms: list[selectable.FromClause],
+        from_hints: Any,
+        **kw: Any,
+    ) -> str:
         """Render the DELETE .. USING clause specific to MySQL."""
         kw["asfrom"] = True
         return "USING " + ", ".join(
@@ -1801,7 +1916,9 @@ class MySQLCompiler(compiler.SQLCompiler):
             for t in [from_table] + extra_froms
         )
 
-    def visit_empty_set_expr(self, element_types, **kw):
+    def visit_empty_set_expr(
+        self, element_types: list[TypeEngine[Any]], **kw: Any
+    ) -> str:
         return (
             "SELECT %(outer)s FROM (SELECT %(inner)s) "
             "as _empty_set WHERE 1!=1"
@@ -1816,25 +1933,38 @@ class MySQLCompiler(compiler.SQLCompiler):
             }
         )
 
-    def visit_is_distinct_from_binary(self, binary, operator, **kw):
+    def visit_is_distinct_from_binary(
+        self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+    ) -> str:
         return "NOT (%s <=> %s)" % (
             self.process(binary.left),
             self.process(binary.right),
         )
 
-    def visit_is_not_distinct_from_binary(self, binary, operator, **kw):
+    def visit_is_not_distinct_from_binary(
+        self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+    ) -> str:
         return "%s <=> %s" % (
             self.process(binary.left),
             self.process(binary.right),
         )
 
-    def _mariadb_regexp_flags(self, flags, pattern, **kw):
+    def _mariadb_regexp_flags(
+        self, flags: str, pattern: elements.ColumnElement[Any], **kw: Any
+    ) -> str:
         return "CONCAT('(?', %s, ')', %s)" % (
             self.render_literal_value(flags, sqltypes.STRINGTYPE),
             self.process(pattern, **kw),
         )
 
-    def _regexp_match(self, op_string, binary, operator, **kw):
+    def _regexp_match(
+        self,
+        op_string: str,
+        binary: elements.BinaryExpression[Any],
+        operator: Any,
+        **kw: Any,
+    ) -> str:
+        assert binary.modifiers is not None
         flags = binary.modifiers["flags"]
         if flags is None:
             return self._generate_generic_binary(binary, op_string, **kw)
@@ -1855,13 +1985,20 @@ class MySQLCompiler(compiler.SQLCompiler):
             else:
                 return text
 
-    def visit_regexp_match_op_binary(self, binary, operator, **kw):
+    def visit_regexp_match_op_binary(
+        self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+    ) -> str:
         return self._regexp_match(" REGEXP ", binary, operator, **kw)
 
-    def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
+    def visit_not_regexp_match_op_binary(
+        self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+    ) -> str:
         return self._regexp_match(" NOT REGEXP ", binary, operator, **kw)
 
-    def visit_regexp_replace_op_binary(self, binary, operator, **kw):
+    def visit_regexp_replace_op_binary(
+        self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+    ) -> str:
+        assert binary.modifiers is not None
         flags = binary.modifiers["flags"]
         if flags is None:
             return "REGEXP_REPLACE(%s, %s)" % (
@@ -1883,7 +2020,11 @@ class MySQLCompiler(compiler.SQLCompiler):
 
 
 class MySQLDDLCompiler(compiler.DDLCompiler):
-    def get_column_specification(self, column, **kw):
+    dialect: MySQLDialect
+
+    def get_column_specification(
+        self, column: sa_schema.Column[Any], **kw: Any
+    ) -> str:
         """Builds column DDL."""
         if (
             self.dialect.is_mariadb is True
@@ -1949,7 +2090,7 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
                     colspec.append("DEFAULT " + default)
         return " ".join(colspec)
 
-    def post_create_table(self, table):
+    def post_create_table(self, table: sa_schema.Table) -> str:
         """Build table-level CREATE options like ENGINE and COLLATE."""
 
         table_opts = []
@@ -2033,16 +2174,16 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
 
         return " ".join(table_opts)
 
-    def visit_create_index(self, create, **kw):
+    def visit_create_index(self, create: ddl.CreateIndex, **kw: Any) -> str:  # type: ignore[override]  # noqa: E501
         index = create.element
         self._verify_index_table(index)
         preparer = self.preparer
-        table = preparer.format_table(index.table)
+        table = preparer.format_table(index.table)  # type: ignore[arg-type]
 
         columns = [
             self.sql_compiler.process(
                 (
-                    elements.Grouping(expr)
+                    elements.Grouping(expr)  # type: ignore[arg-type]
                     if (
                         isinstance(expr, elements.BinaryExpression)
                         or (
@@ -2081,10 +2222,10 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
                 # length value can be a (column_name --> integer value)
                 # mapping specifying the prefix length for each column of the
                 # index
-                columns = ", ".join(
+                columns_str = ", ".join(
                     (
-                        "%s(%d)" % (expr, length[col.name])
-                        if col.name in length
+                        "%s(%d)" % (expr, length[col.name])  # type: ignore[union-attr]  # noqa: E501
+                        if col.name in length  # type: ignore[union-attr]
                         else (
                             "%s(%d)" % (expr, length[expr])
                             if expr in length
@@ -2096,12 +2237,12 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
             else:
                 # or can be an integer value specifying the same
                 # prefix length for all columns of the index
-                columns = ", ".join(
+                columns_str = ", ".join(
                     "%s(%d)" % (col, length) for col in columns
                 )
         else:
-            columns = ", ".join(columns)
-        text += "(%s)" % columns
+            columns_str = ", ".join(columns)
+        text += "(%s)" % columns_str
 
         parser = index.dialect_options["mysql"]["with_parser"]
         if parser is not None:
@@ -2113,14 +2254,16 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
 
         return text
 
-    def visit_primary_key_constraint(self, constraint, **kw):
+    def visit_primary_key_constraint(
+        self, constraint: sa_schema.PrimaryKeyConstraint, **kw: Any
+    ) -> str:
         text = super().visit_primary_key_constraint(constraint)
         using = constraint.dialect_options["mysql"]["using"]
         if using:
             text += " USING %s" % (self.preparer.quote(using))
         return text
 
-    def visit_drop_index(self, drop, **kw):
+    def visit_drop_index(self, drop: ddl.DropIndex, **kw: Any) -> str:
         index = drop.element
         text = "\nDROP INDEX "
         if drop.if_exists:
@@ -2128,10 +2271,12 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
 
         return text + "%s ON %s" % (
             self._prepared_index_name(index, include_schema=False),
-            self.preparer.format_table(index.table),
+            self.preparer.format_table(index.table),  # type: ignore[arg-type]
         )
 
-    def visit_drop_constraint(self, drop, **kw):
+    def visit_drop_constraint(
+        self, drop: ddl.DropConstraint, **kw: Any
+    ) -> str:
         constraint = drop.element
         if isinstance(constraint, sa_schema.ForeignKeyConstraint):
             qual = "FOREIGN KEY "
@@ -2157,7 +2302,9 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
             const,
         )
 
-    def define_constraint_match(self, constraint):
+    def define_constraint_match(
+        self, constraint: sa_schema.ForeignKeyConstraint
+    ) -> str:
         if constraint.match is not None:
             raise exc.CompileError(
                 "MySQL ignores the 'MATCH' keyword while at the same time "
@@ -2165,7 +2312,9 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
             )
         return ""
 
-    def visit_set_table_comment(self, create, **kw):
+    def visit_set_table_comment(
+        self, create: ddl.SetTableComment, **kw: Any
+    ) -> str:
         return "ALTER TABLE %s COMMENT %s" % (
             self.preparer.format_table(create.element),
             self.sql_compiler.render_literal_value(
@@ -2173,12 +2322,16 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
             ),
         )
 
-    def visit_drop_table_comment(self, create, **kw):
+    def visit_drop_table_comment(
+        self, drop: ddl.DropTableComment, **kw: Any
+    ) -> str:
         return "ALTER TABLE %s COMMENT ''" % (
-            self.preparer.format_table(create.element)
+            self.preparer.format_table(drop.element)
         )
 
-    def visit_set_column_comment(self, create, **kw):
+    def visit_set_column_comment(
+        self, create: ddl.SetColumnComment, **kw: Any
+    ) -> str:
         return "ALTER TABLE %s CHANGE %s %s" % (
             self.preparer.format_table(create.element.table),
             self.preparer.format_column(create.element),
@@ -2187,7 +2340,7 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
 
 
 class MySQLTypeCompiler(compiler.GenericTypeCompiler):
-    def _extend_numeric(self, type_, spec):
+    def _extend_numeric(self, type_: _NumericCommonType, spec: str) -> str:
         "Extend a numeric-type declaration with MySQL specific extensions."
 
         if not self._mysql_type(type_):
@@ -2199,13 +2352,15 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
             spec += " ZEROFILL"
         return spec
 
-    def _extend_string(self, type_, defaults, spec):
+    def _extend_string(
+        self, type_: _StringType, defaults: dict[str, Any], spec: str
+    ) -> str:
         """Extend a string-type declaration with standard SQL CHARACTER SET /
         COLLATE annotations and MySQL specific extensions.
 
         """
 
-        def attr(name):
+        def attr(name: str) -> Any:
             return getattr(type_, name, defaults.get(name))
 
         if attr("charset"):
@@ -2215,6 +2370,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         elif attr("unicode"):
             charset = "UNICODE"
         else:
+
             charset = None
 
         if attr("collation"):
@@ -2233,10 +2389,10 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
             [c for c in (spec, charset, collation) if c is not None]
         )
 
-    def _mysql_type(self, type_):
+    def _mysql_type(self, type_: Any) -> bool:
         return isinstance(type_, (_StringType, _NumericCommonType))
 
-    def visit_NUMERIC(self, type_, **kw):
+    def visit_NUMERIC(self, type_: NUMERIC, **kw: Any) -> str:  # type: ignore[override]  # NOQA: E501
         if type_.precision is None:
             return self._extend_numeric(type_, "NUMERIC")
         elif type_.scale is None:
@@ -2251,7 +2407,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
                 % {"precision": type_.precision, "scale": type_.scale},
             )
 
-    def visit_DECIMAL(self, type_, **kw):
+    def visit_DECIMAL(self, type_: DECIMAL, **kw: Any) -> str:  # type: ignore[override]  # NOQA: E501
         if type_.precision is None:
             return self._extend_numeric(type_, "DECIMAL")
         elif type_.scale is None:
@@ -2266,7 +2422,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
                 % {"precision": type_.precision, "scale": type_.scale},
             )
 
-    def visit_DOUBLE(self, type_, **kw):
+    def visit_DOUBLE(self, type_: DOUBLE, **kw: Any) -> str:  # type: ignore[override]  # NOQA: E501
         if type_.precision is not None and type_.scale is not None:
             return self._extend_numeric(
                 type_,
@@ -2276,7 +2432,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_numeric(type_, "DOUBLE")
 
-    def visit_REAL(self, type_, **kw):
+    def visit_REAL(self, type_: REAL, **kw: Any) -> str:  # type: ignore[override]  # NOQA: E501
         if type_.precision is not None and type_.scale is not None:
             return self._extend_numeric(
                 type_,
@@ -2286,7 +2442,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_numeric(type_, "REAL")
 
-    def visit_FLOAT(self, type_, **kw):
+    def visit_FLOAT(self, type_: FLOAT, **kw: Any) -> str:  # type: ignore[override]  # NOQA: E501
         if (
             self._mysql_type(type_)
             and type_.scale is not None
@@ -2302,7 +2458,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_numeric(type_, "FLOAT")
 
-    def visit_INTEGER(self, type_, **kw):
+    def visit_INTEGER(self, type_: INTEGER, **kw: Any) -> str:  # type: ignore[override]  # NOQA: E501
         if self._mysql_type(type_) and type_.display_width is not None:
             return self._extend_numeric(
                 type_,
@@ -2312,7 +2468,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_numeric(type_, "INTEGER")
 
-    def visit_BIGINT(self, type_, **kw):
+    def visit_BIGINT(self, type_: BIGINT, **kw: Any) -> str:  # type: ignore[override]  # NOQA: E501
         if self._mysql_type(type_) and type_.display_width is not None:
             return self._extend_numeric(
                 type_,
@@ -2322,7 +2478,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_numeric(type_, "BIGINT")
 
-    def visit_MEDIUMINT(self, type_, **kw):
+    def visit_MEDIUMINT(self, type_: MEDIUMINT, **kw: Any) -> str:
         if self._mysql_type(type_) and type_.display_width is not None:
             return self._extend_numeric(
                 type_,
@@ -2332,7 +2488,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_numeric(type_, "MEDIUMINT")
 
-    def visit_TINYINT(self, type_, **kw):
+    def visit_TINYINT(self, type_: TINYINT, **kw: Any) -> str:
         if self._mysql_type(type_) and type_.display_width is not None:
             return self._extend_numeric(
                 type_, "TINYINT(%s)" % type_.display_width
@@ -2340,7 +2496,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_numeric(type_, "TINYINT")
 
-    def visit_SMALLINT(self, type_, **kw):
+    def visit_SMALLINT(self, type_: SMALLINT, **kw: Any) -> str:  # type: ignore[override]  # NOQA: E501
         if self._mysql_type(type_) and type_.display_width is not None:
             return self._extend_numeric(
                 type_,
@@ -2350,55 +2506,55 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_numeric(type_, "SMALLINT")
 
-    def visit_BIT(self, type_, **kw):
+    def visit_BIT(self, type_: BIT, **kw: Any) -> str:
         if type_.length is not None:
             return "BIT(%s)" % type_.length
         else:
             return "BIT"
 
-    def visit_DATETIME(self, type_, **kw):
+    def visit_DATETIME(self, type_: DATETIME, **kw: Any) -> str:  # type: ignore[override]  # NOQA: E501
         if getattr(type_, "fsp", None):
-            return "DATETIME(%d)" % type_.fsp
+            return "DATETIME(%d)" % type_.fsp  # type: ignore[str-format]
         else:
             return "DATETIME"
 
-    def visit_DATE(self, type_, **kw):
+    def visit_DATE(self, type_: DATE, **kw: Any) -> str:  # type: ignore[override]  # NOQA: E501
         return "DATE"
 
-    def visit_TIME(self, type_, **kw):
+    def visit_TIME(self, type_: TIME, **kw: Any) -> str:  # type: ignore[override]  # NOQA: E501
         if getattr(type_, "fsp", None):
-            return "TIME(%d)" % type_.fsp
+            return "TIME(%d)" % type_.fsp  # type: ignore[str-format]
         else:
             return "TIME"
 
-    def visit_TIMESTAMP(self, type_, **kw):
+    def visit_TIMESTAMP(self, type_: TIMESTAMP, **kw: Any) -> str:  # type: ignore[override]  # NOQA: E501
         if getattr(type_, "fsp", None):
-            return "TIMESTAMP(%d)" % type_.fsp
+            return "TIMESTAMP(%d)" % type_.fsp  # type: ignore[str-format]
         else:
             return "TIMESTAMP"
 
-    def visit_YEAR(self, type_, **kw):
+    def visit_YEAR(self, type_: YEAR, **kw: Any) -> str:
         if type_.display_width is None:
             return "YEAR"
         else:
             return "YEAR(%s)" % type_.display_width
 
-    def visit_TEXT(self, type_, **kw):
+    def visit_TEXT(self, type_: TEXT, **kw: Any) -> str:  # type: ignore[override]  # NOQA: E501
         if type_.length is not None:
             return self._extend_string(type_, {}, "TEXT(%d)" % type_.length)
         else:
             return self._extend_string(type_, {}, "TEXT")
 
-    def visit_TINYTEXT(self, type_, **kw):
+    def visit_TINYTEXT(self, type_: TINYTEXT, **kw: Any) -> str:
         return self._extend_string(type_, {}, "TINYTEXT")
 
-    def visit_MEDIUMTEXT(self, type_, **kw):
+    def visit_MEDIUMTEXT(self, type_: MEDIUMTEXT, **kw: Any) -> str:
         return self._extend_string(type_, {}, "MEDIUMTEXT")
 
-    def visit_LONGTEXT(self, type_, **kw):
+    def visit_LONGTEXT(self, type_: LONGTEXT, **kw: Any) -> str:
         return self._extend_string(type_, {}, "LONGTEXT")
 
-    def visit_VARCHAR(self, type_, **kw):
+    def visit_VARCHAR(self, type_: VARCHAR, **kw: Any) -> str:  # type: ignore[override]  # NOQA: E501
         if type_.length is not None:
             return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length)
         else:
@@ -2406,7 +2562,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
                 "VARCHAR requires a length on dialect %s" % self.dialect.name
             )
 
-    def visit_CHAR(self, type_, **kw):
+    def visit_CHAR(self, type_: CHAR, **kw: Any) -> str:  # type: ignore[override]  # NOQA: E501
         if type_.length is not None:
             return self._extend_string(
                 type_, {}, "CHAR(%(length)s)" % {"length": type_.length}
@@ -2414,7 +2570,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_string(type_, {}, "CHAR")
 
-    def visit_NVARCHAR(self, type_, **kw):
+    def visit_NVARCHAR(self, type_: NVARCHAR, **kw: Any) -> str:  # type: ignore[override]  # NOQA: E501
         # We'll actually generate the equiv. "NATIONAL VARCHAR" instead
         # of "NVARCHAR".
         if type_.length is not None:
@@ -2428,7 +2584,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
                 "NVARCHAR requires a length on dialect %s" % self.dialect.name
             )
 
-    def visit_NCHAR(self, type_, **kw):
+    def visit_NCHAR(self, type_: NCHAR, **kw: Any) -> str:  # type: ignore[override]  # NOQA: E501
         # We'll actually generate the equiv.
         # "NATIONAL CHAR" instead of "NCHAR".
         if type_.length is not None:
@@ -2440,40 +2596,42 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_string(type_, {"national": True}, "CHAR")
 
-    def visit_UUID(self, type_, **kw):
+    def visit_UUID(self, type_: UUID[Any], **kw: Any) -> str:  # type: ignore[override]  # NOQA: E501
         return "UUID"
 
-    def visit_VARBINARY(self, type_, **kw):
-        return "VARBINARY(%d)" % type_.length
+    def visit_VARBINARY(self, type_: VARBINARY, **kw: Any) -> str:
+        return "VARBINARY(%d)" % type_.length  # type: ignore[str-format]
 
-    def visit_JSON(self, type_, **kw):
+    def visit_JSON(self, type_: JSON, **kw: Any) -> str:
         return "JSON"
 
-    def visit_large_binary(self, type_, **kw):
+    def visit_large_binary(self, type_: LargeBinary, **kw: Any) -> str:
         return self.visit_BLOB(type_)
 
-    def visit_enum(self, type_, **kw):
+    def visit_enum(self, type_: ENUM, **kw: Any) -> str:  # type: ignore[override]  # NOQA: E501
         if not type_.native_enum:
             return super().visit_enum(type_)
         else:
             return self._visit_enumerated_values("ENUM", type_, type_.enums)
 
-    def visit_BLOB(self, type_, **kw):
+    def visit_BLOB(self, type_: LargeBinary, **kw: Any) -> str:
         if type_.length is not None:
             return "BLOB(%d)" % type_.length
         else:
             return "BLOB"
 
-    def visit_TINYBLOB(self, type_, **kw):
+    def visit_TINYBLOB(self, type_: TINYBLOB, **kw: Any) -> str:
         return "TINYBLOB"
 
-    def visit_MEDIUMBLOB(self, type_, **kw):
+    def visit_MEDIUMBLOB(self, type_: MEDIUMBLOB, **kw: Any) -> str:
         return "MEDIUMBLOB"
 
-    def visit_LONGBLOB(self, type_, **kw):
+    def visit_LONGBLOB(self, type_: LONGBLOB, **kw: Any) -> str:
         return "LONGBLOB"
 
-    def _visit_enumerated_values(self, name, type_, enumerated_values):
+    def _visit_enumerated_values(
+        self, name: str, type_: _StringType, enumerated_values: Sequence[str]
+    ) -> str:
         quoted_enums = []
         for e in enumerated_values:
             if self.dialect.identifier_preparer._double_percents:
@@ -2483,20 +2641,25 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
             type_, {}, "%s(%s)" % (name, ",".join(quoted_enums))
         )
 
-    def visit_ENUM(self, type_, **kw):
+    def visit_ENUM(self, type_: ENUM, **kw: Any) -> str:
         return self._visit_enumerated_values("ENUM", type_, type_.enums)
 
-    def visit_SET(self, type_, **kw):
+    def visit_SET(self, type_: SET, **kw: Any) -> str:
         return self._visit_enumerated_values("SET", type_, type_.values)
 
-    def visit_BOOLEAN(self, type_, **kw):
+    def visit_BOOLEAN(self, type_: sqltypes.Boolean, **kw: Any) -> str:
         return "BOOL"
 
 
 class MySQLIdentifierPreparer(compiler.IdentifierPreparer):
     reserved_words = RESERVED_WORDS_MYSQL
 
-    def __init__(self, dialect, server_ansiquotes=False, **kw):
+    def __init__(
+        self,
+        dialect: default.DefaultDialect,
+        server_ansiquotes: bool = False,
+        **kw: Any,
+    ):
         if not server_ansiquotes:
             quote = "`"
         else:
@@ -2504,7 +2667,7 @@ class MySQLIdentifierPreparer(compiler.IdentifierPreparer):
 
         super().__init__(dialect, initial_quote=quote, escape_quote=quote)
 
-    def _quote_free_identifiers(self, *ids):
+    def _quote_free_identifiers(self, *ids: Optional[str]) -> tuple[str, ...]:
         """Unilaterally identifier-quote any number of strings."""
 
         return tuple([self.quote_identifier(i) for i in ids if i is not None])
@@ -2514,7 +2677,6 @@ class MariaDBIdentifierPreparer(MySQLIdentifierPreparer):
     reserved_words = RESERVED_WORDS_MARIADB
 
 
-@log.class_logger
 class MySQLDialect(default.DefaultDialect):
     """Details of the MySQL dialect.
     Not used directly in application code.
@@ -2581,9 +2743,9 @@ class MySQLDialect(default.DefaultDialect):
     ddl_compiler = MySQLDDLCompiler
     type_compiler_cls = MySQLTypeCompiler
     ischema_names = ischema_names
-    preparer = MySQLIdentifierPreparer
+    preparer: type[MySQLIdentifierPreparer] = MySQLIdentifierPreparer
 
-    is_mariadb = False
+    is_mariadb: bool = False
     _mariadb_normalized_version_info = None
 
     # default SQL compilation settings -
@@ -2592,6 +2754,9 @@ class MySQLDialect(default.DefaultDialect):
     _backslash_escapes = True
     _server_ansiquotes = False
 
+    server_version_info: tuple[int, ...]
+    identifier_preparer: MySQLIdentifierPreparer
+
     construct_arguments = [
         (sa_schema.Table, {"*": None}),
         (sql.Update, {"limit": None}),
@@ -2610,18 +2775,20 @@ class MySQLDialect(default.DefaultDialect):
 
     def __init__(
         self,
-        json_serializer=None,
-        json_deserializer=None,
-        is_mariadb=None,
-        **kwargs,
-    ):
+        json_serializer: Optional[Callable[..., Any]] = None,
+        json_deserializer: Optional[Callable[..., Any]] = None,
+        is_mariadb: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> None:
         kwargs.pop("use_ansiquotes", None)  # legacy
         default.DefaultDialect.__init__(self, **kwargs)
         self._json_serializer = json_serializer
         self._json_deserializer = json_deserializer
-        self._set_mariadb(is_mariadb, None)
+        self._set_mariadb(is_mariadb, ())
 
-    def get_isolation_level_values(self, dbapi_conn):
+    def get_isolation_level_values(
+        self, dbapi_conn: DBAPIConnection
+    ) -> Sequence[IsolationLevel]:
         return (
             "SERIALIZABLE",
             "READ UNCOMMITTED",
@@ -2629,13 +2796,17 @@ class MySQLDialect(default.DefaultDialect):
             "REPEATABLE READ",
         )
 
-    def set_isolation_level(self, dbapi_connection, level):
+    def set_isolation_level(
+        self, dbapi_connection: DBAPIConnection, level: IsolationLevel
+    ) -> None:
         cursor = dbapi_connection.cursor()
         cursor.execute(f"SET SESSION TRANSACTION ISOLATION LEVEL {level}")
         cursor.execute("COMMIT")
         cursor.close()
 
-    def get_isolation_level(self, dbapi_connection):
+    def get_isolation_level(
+        self, dbapi_connection: DBAPIConnection
+    ) -> IsolationLevel:
         cursor = dbapi_connection.cursor()
         if self._is_mysql and self.server_version_info >= (5, 7, 20):
             cursor.execute("SELECT @@transaction_isolation")
@@ -2652,10 +2823,10 @@ class MySQLDialect(default.DefaultDialect):
         cursor.close()
         if isinstance(val, bytes):
             val = val.decode()
-        return val.upper().replace("-", " ")
+        return val.upper().replace("-", " ")  # type: ignore[no-any-return]
 
     @classmethod
-    def _is_mariadb_from_url(cls, url):
+    def _is_mariadb_from_url(cls, url: URL) -> bool:
         dbapi = cls.import_dbapi()
         dialect = cls(dbapi=dbapi)
 
@@ -2664,7 +2835,7 @@ class MySQLDialect(default.DefaultDialect):
         try:
             cursor = conn.cursor()
             cursor.execute("SELECT VERSION() LIKE '%MariaDB%'")
-            val = cursor.fetchone()[0]
+            val = cursor.fetchone()[0]  # type: ignore[index]
         except:
             raise
         else:
@@ -2672,22 +2843,25 @@ class MySQLDialect(default.DefaultDialect):
         finally:
             conn.close()
 
-    def _get_server_version_info(self, connection):
+    def _get_server_version_info(
+        self, connection: Connection
+    ) -> tuple[int, ...]:
         # get database server version info explicitly over the wire
         # to avoid proxy servers like MaxScale getting in the
         # way with their own values, see #4205
         dbapi_con = connection.connection
         cursor = dbapi_con.cursor()
         cursor.execute("SELECT VERSION()")
-        val = cursor.fetchone()[0]
+
+        val = cursor.fetchone()[0]  # type: ignore[index]
         cursor.close()
         if isinstance(val, bytes):
             val = val.decode()
 
         return self._parse_server_version(val)
 
-    def _parse_server_version(self, val):
-        version = []
+    def _parse_server_version(self, val: str) -> tuple[int, ...]:
+        version: list[int] = []
         is_mariadb = False
 
         r = re.compile(r"[.\-+]")
@@ -2708,7 +2882,7 @@ class MySQLDialect(default.DefaultDialect):
         server_version_info = tuple(version)
 
         self._set_mariadb(
-            server_version_info and is_mariadb, server_version_info
+            bool(server_version_info and is_mariadb), server_version_info
         )
 
         if not is_mariadb:
@@ -2724,7 +2898,9 @@ class MySQLDialect(default.DefaultDialect):
         self.server_version_info = server_version_info
         return server_version_info
 
-    def _set_mariadb(self, is_mariadb, server_version_info):
+    def _set_mariadb(
+        self, is_mariadb: Optional[bool], server_version_info: tuple[int, ...]
+    ) -> None:
         if is_mariadb is None:
             return
 
@@ -2748,38 +2924,54 @@ class MySQLDialect(default.DefaultDialect):
 
         self.is_mariadb = is_mariadb
 
-    def do_begin_twophase(self, connection, xid):
+    def do_begin_twophase(self, connection: Connection, xid: Any) -> None:
         connection.execute(sql.text("XA BEGIN :xid"), dict(xid=xid))
 
-    def do_prepare_twophase(self, connection, xid):
+    def do_prepare_twophase(self, connection: Connection, xid: Any) -> None:
         connection.execute(sql.text("XA END :xid"), dict(xid=xid))
         connection.execute(sql.text("XA PREPARE :xid"), dict(xid=xid))
 
     def do_rollback_twophase(
-        self, connection, xid, is_prepared=True, recover=False
-    ):
+        self,
+        connection: Connection,
+        xid: Any,
+        is_prepared: bool = True,
+        recover: bool = False,
+    ) -> None:
         if not is_prepared:
             connection.execute(sql.text("XA END :xid"), dict(xid=xid))
         connection.execute(sql.text("XA ROLLBACK :xid"), dict(xid=xid))
 
     def do_commit_twophase(
-        self, connection, xid, is_prepared=True, recover=False
-    ):
+        self,
+        connection: Connection,
+        xid: Any,
+        is_prepared: bool = True,
+        recover: bool = False,
+    ) -> None:
         if not is_prepared:
             self.do_prepare_twophase(connection, xid)
         connection.execute(sql.text("XA COMMIT :xid"), dict(xid=xid))
 
-    def do_recover_twophase(self, connection):
+    def do_recover_twophase(self, connection: Connection) -> list[Any]:
         resultset = connection.exec_driver_sql("XA RECOVER")
-        return [row["data"][0 : row["gtrid_length"]] for row in resultset]
+        return [
+            row["data"][0 : row["gtrid_length"]]
+            for row in resultset.mappings()
+        ]
 
-    def is_disconnect(self, e, connection, cursor):
+    def is_disconnect(
+        self,
+        e: DBAPIModule.Error,
+        connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
+        cursor: Optional[DBAPICursor],
+    ) -> bool:
         if isinstance(
             e,
             (
-                self.dbapi.OperationalError,
-                self.dbapi.ProgrammingError,
-                self.dbapi.InterfaceError,
+                self.dbapi.OperationalError,  # type: ignore
+                self.dbapi.ProgrammingError,  # type: ignore
+                self.dbapi.InterfaceError,  # type: ignore
             ),
         ) and self._extract_error_code(e) in (
             1927,
@@ -2792,7 +2984,7 @@ class MySQLDialect(default.DefaultDialect):
         ):
             return True
         elif isinstance(
-            e, (self.dbapi.InterfaceError, self.dbapi.InternalError)
+            e, (self.dbapi.InterfaceError, self.dbapi.InternalError)  # type: ignore  # noqa: E501
         ):
             # if underlying connection is closed,
             # this is the error you get
@@ -2800,13 +2992,17 @@ class MySQLDialect(default.DefaultDialect):
         else:
             return False
 
-    def _compat_fetchall(self, rp, charset=None):
+    def _compat_fetchall(
+        self, rp: CursorResult[Unpack[TupleAny]], charset: Optional[str] = None
+    ) -> Union[Sequence[Row[Unpack[TupleAny]]], Sequence[_DecodingRow]]:
         """Proxy result rows to smooth over MySQL-Python driver
         inconsistencies."""
 
         return [_DecodingRow(row, charset) for row in rp.fetchall()]
 
-    def _compat_fetchone(self, rp, charset=None):
+    def _compat_fetchone(
+        self, rp: CursorResult[Unpack[TupleAny]], charset: Optional[str] = None
+    ) -> Union[Row[Unpack[TupleAny]], None, _DecodingRow]:
         """Proxy a result row to smooth over MySQL-Python driver
         inconsistencies."""
 
@@ -2816,7 +3012,9 @@ class MySQLDialect(default.DefaultDialect):
         else:
             return None
 
-    def _compat_first(self, rp, charset=None):
+    def _compat_first(
+        self, rp: CursorResult[Unpack[TupleAny]], charset: Optional[str] = None
+    ) -> Optional[_DecodingRow]:
         """Proxy a result row to smooth over MySQL-Python driver
         inconsistencies."""
 
@@ -2826,14 +3024,22 @@ class MySQLDialect(default.DefaultDialect):
         else:
             return None
 
-    def _extract_error_code(self, exception):
+    def _extract_error_code(
+        self, exception: DBAPIModule.Error
+    ) -> Optional[int]:
         raise NotImplementedError()
 
-    def _get_default_schema_name(self, connection):
-        return connection.exec_driver_sql("SELECT DATABASE()").scalar()
+    def _get_default_schema_name(self, connection: Connection) -> str:
+        return connection.exec_driver_sql("SELECT DATABASE()").scalar()  # type: ignore[return-value]  # noqa: E501
 
     @reflection.cache
-    def has_table(self, connection, table_name, schema=None, **kw):
+    def has_table(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: Optional[str] = None,
+        **kw: Any,
+    ) -> bool:
         self._ensure_has_table_connection(connection)
 
         if schema is None:
@@ -2874,12 +3080,18 @@ class MySQLDialect(default.DefaultDialect):
             #
             # there's more "doesn't exist" kinds of messages but they are
             # less clear if mysql 8 would suddenly start using one of those
-            if self._extract_error_code(e.orig) in (1146, 1049, 1051):
+            if self._extract_error_code(e.orig) in (1146, 1049, 1051):  # type: ignore  # noqa: E501
                 return False
             raise
 
     @reflection.cache
-    def has_sequence(self, connection, sequence_name, schema=None, **kw):
+    def has_sequence(
+        self,
+        connection: Connection,
+        sequence_name: str,
+        schema: Optional[str] = None,
+        **kw: Any,
+    ) -> bool:
         if not self.supports_sequences:
             self._sequences_not_supported()
         if not schema:
@@ -2899,14 +3111,16 @@ class MySQLDialect(default.DefaultDialect):
         )
         return cursor.first() is not None
 
-    def _sequences_not_supported(self):
+    def _sequences_not_supported(self) -> NoReturn:
         raise NotImplementedError(
             "Sequences are supported only by the "
             "MariaDB series 10.3 or greater"
         )
 
     @reflection.cache
-    def get_sequence_names(self, connection, schema=None, **kw):
+    def get_sequence_names(
+        self, connection: Connection, schema: Optional[str] = None, **kw: Any
+    ) -> list[str]:
         if not self.supports_sequences:
             self._sequences_not_supported()
         if not schema:
@@ -2926,10 +3140,12 @@ class MySQLDialect(default.DefaultDialect):
             )
         ]
 
-    def initialize(self, connection):
+    def initialize(self, connection: Connection) -> None:
         # this is driver-based, does not need server version info
         # and is fairly critical for even basic SQL operations
-        self._connection_charset = self._detect_charset(connection)
+        self._connection_charset: Optional[str] = self._detect_charset(
+            connection
+        )
 
         # call super().initialize() because we need to have
         # server_version_info set up.  in 1.4 under python 2 only this does the
@@ -2973,9 +3189,10 @@ class MySQLDialect(default.DefaultDialect):
 
         self._warn_for_known_db_issues()
 
-    def _warn_for_known_db_issues(self):
+    def _warn_for_known_db_issues(self) -> None:
         if self.is_mariadb:
             mdb_version = self._mariadb_normalized_version_info
+            assert mdb_version is not None
             if mdb_version > (10, 2) and mdb_version < (10, 2, 9):
                 util.warn(
                     "MariaDB %r before 10.2.9 has known issues regarding "
@@ -2988,7 +3205,7 @@ class MySQLDialect(default.DefaultDialect):
                 )
 
     @property
-    def _support_float_cast(self):
+    def _support_float_cast(self) -> bool:
         if not self.server_version_info:
             return False
         elif self.is_mariadb:
@@ -2999,7 +3216,7 @@ class MySQLDialect(default.DefaultDialect):
             return self.server_version_info >= (8, 0, 17)
 
     @property
-    def _support_default_function(self):
+    def _support_default_function(self) -> bool:
         if not self.server_version_info:
             return False
         elif self.is_mariadb:
@@ -3010,32 +3227,38 @@ class MySQLDialect(default.DefaultDialect):
             return self.server_version_info >= (8, 0, 13)
 
     @property
-    def _is_mariadb(self):
+    def _is_mariadb(self) -> bool:
         return self.is_mariadb
 
     @property
-    def _is_mysql(self):
+    def _is_mysql(self) -> bool:
         return not self.is_mariadb
 
     @property
-    def _is_mariadb_102(self):
-        return self.is_mariadb and self._mariadb_normalized_version_info > (
-            10,
-            2,
+    def _is_mariadb_102(self) -> bool:
+        return (
+            self.is_mariadb
+            and self._mariadb_normalized_version_info  # type:ignore[operator]
+            > (
+                10,
+                2,
+            )
         )
 
     @reflection.cache
-    def get_schema_names(self, connection, **kw):
+    def get_schema_names(self, connection: Connection, **kw: Any) -> list[str]:
         rp = connection.exec_driver_sql("SHOW schemas")
         return [r[0] for r in rp]
 
     @reflection.cache
-    def get_table_names(self, connection, schema=None, **kw):
+    def get_table_names(
+        self, connection: Connection, schema: Optional[str] = None, **kw: Any
+    ) -> list[str]:
         """Return a Unicode SHOW TABLES from a given schema."""
         if schema is not None:
-            current_schema = schema
+            current_schema: str = schema
         else:
-            current_schema = self.default_schema_name
+            current_schema = self.default_schema_name  # type: ignore
 
         charset = self._connection_charset
 
@@ -3051,9 +3274,12 @@ class MySQLDialect(default.DefaultDialect):
         ]
 
     @reflection.cache
-    def get_view_names(self, connection, schema=None, **kw):
+    def get_view_names(
+        self, connection: Connection, schema: Optional[str] = None, **kw: Any
+    ) -> list[str]:
         if schema is None:
             schema = self.default_schema_name
+        assert schema is not None
         charset = self._connection_charset
         rp = connection.exec_driver_sql(
             "SHOW FULL TABLES FROM %s"
@@ -3066,7 +3292,13 @@ class MySQLDialect(default.DefaultDialect):
         ]
 
     @reflection.cache
-    def get_table_options(self, connection, table_name, schema=None, **kw):
+    def get_table_options(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: Optional[str] = None,
+        **kw: Any,
+    ) -> dict[str, Any]:
         parsed_state = self._parsed_state_or_create(
             connection, table_name, schema, **kw
         )
@@ -3076,7 +3308,13 @@ class MySQLDialect(default.DefaultDialect):
             return ReflectionDefaults.table_options()
 
     @reflection.cache
-    def get_columns(self, connection, table_name, schema=None, **kw):
+    def get_columns(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: Optional[str] = None,
+        **kw: Any,
+    ) -> list[ReflectedColumn]:
         parsed_state = self._parsed_state_or_create(
             connection, table_name, schema, **kw
         )
@@ -3086,7 +3324,13 @@ class MySQLDialect(default.DefaultDialect):
             return ReflectionDefaults.columns()
 
     @reflection.cache
-    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+    def get_pk_constraint(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: Optional[str] = None,
+        **kw: Any,
+    ) -> ReflectedPrimaryKeyConstraint:
         parsed_state = self._parsed_state_or_create(
             connection, table_name, schema, **kw
         )
@@ -3098,13 +3342,19 @@ class MySQLDialect(default.DefaultDialect):
         return ReflectionDefaults.pk_constraint()
 
     @reflection.cache
-    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+    def get_foreign_keys(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: Optional[str] = None,
+        **kw: Any,
+    ) -> list[ReflectedForeignKeyConstraint]:
         parsed_state = self._parsed_state_or_create(
             connection, table_name, schema, **kw
         )
         default_schema = None
 
-        fkeys = []
+        fkeys: list[ReflectedForeignKeyConstraint] = []
 
         for spec in parsed_state.fk_constraints:
             ref_name = spec["table"][-1]
@@ -3124,7 +3374,7 @@ class MySQLDialect(default.DefaultDialect):
                 if spec.get(opt, False) not in ("NO ACTION", None):
                     con_kw[opt] = spec[opt]
 
-            fkey_d = {
+            fkey_d: ReflectedForeignKeyConstraint = {
                 "name": spec["name"],
                 "constrained_columns": loc_names,
                 "referred_schema": ref_schema,
@@ -3139,7 +3389,11 @@ class MySQLDialect(default.DefaultDialect):
 
         return fkeys if fkeys else ReflectionDefaults.foreign_keys()
 
-    def _correct_for_mysql_bugs_88718_96365(self, fkeys, connection):
+    def _correct_for_mysql_bugs_88718_96365(
+        self,
+        fkeys: list[ReflectedForeignKeyConstraint],
+        connection: Connection,
+    ) -> None:
         # Foreign key is always in lower case (MySQL 8.0)
         # https://bugs.mysql.com/bug.php?id=88718
         # issue #4344 for SQLAlchemy
@@ -3155,22 +3409,24 @@ class MySQLDialect(default.DefaultDialect):
 
         if self._casing in (1, 2):
 
-            def lower(s):
+            def lower(s: str) -> str:
                 return s.lower()
 
         else:
             # if on case sensitive, there can be two tables referenced
             # with the same name different casing, so we need to use
             # case-sensitive matching.
-            def lower(s):
+            def lower(s: str) -> str:
                 return s
 
-        default_schema_name = connection.dialect.default_schema_name
+        default_schema_name: str = connection.dialect.default_schema_name  # type: ignore  # noqa: E501
 
         # NOTE: using (table_schema, table_name, lower(column_name)) in (...)
         # is very slow since mysql does not seem able to properly use indexse.
         # Unpack the where condition instead.
-        schema_by_table_by_column = defaultdict(lambda: defaultdict(list))
+        schema_by_table_by_column: defaultdict[
+            str, defaultdict[str, list[str]]
+        ] = defaultdict(lambda: defaultdict(list))
         for rec in fkeys:
             sch = lower(rec["referred_schema"] or default_schema_name)
             tbl = lower(rec["referred_table"])
@@ -3205,7 +3461,9 @@ class MySQLDialect(default.DefaultDialect):
                 _info_columns.c.column_name,
             ).where(condition)
 
-            correct_for_wrong_fk_case = connection.execute(select)
+            correct_for_wrong_fk_case: CursorResult[str, str, str] = (
+                connection.execute(select)
+            )
 
             # in casing=0, table name and schema name come back in their
             # exact case.
@@ -3217,35 +3475,41 @@ class MySQLDialect(default.DefaultDialect):
             # SHOW CREATE TABLE converts them to *lower case*, therefore
             # not matching.  So for this case, case-insensitive lookup
             # is necessary
-            d = defaultdict(dict)
+            d: defaultdict[tuple[str, str], dict[str, str]] = defaultdict(dict)
             for schema, tname, cname in correct_for_wrong_fk_case:
                 d[(lower(schema), lower(tname))]["SCHEMANAME"] = schema
                 d[(lower(schema), lower(tname))]["TABLENAME"] = tname
                 d[(lower(schema), lower(tname))][cname.lower()] = cname
 
             for fkey in fkeys:
-                rec = d[
+                rec_b = d[
                     (
                         lower(fkey["referred_schema"] or default_schema_name),
                         lower(fkey["referred_table"]),
                     )
                 ]
 
-                fkey["referred_table"] = rec["TABLENAME"]
+                fkey["referred_table"] = rec_b["TABLENAME"]
                 if fkey["referred_schema"] is not None:
-                    fkey["referred_schema"] = rec["SCHEMANAME"]
+                    fkey["referred_schema"] = rec_b["SCHEMANAME"]
 
                 fkey["referred_columns"] = [
-                    rec[col.lower()] for col in fkey["referred_columns"]
+                    rec_b[col.lower()] for col in fkey["referred_columns"]
                 ]
 
     @reflection.cache
-    def get_check_constraints(self, connection, table_name, schema=None, **kw):
+    def get_check_constraints(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: Optional[str] = None,
+        **kw: Any,
+    ) -> list[ReflectedCheckConstraint]:
         parsed_state = self._parsed_state_or_create(
             connection, table_name, schema, **kw
         )
 
-        cks = [
+        cks: list[ReflectedCheckConstraint] = [
             {"name": spec["name"], "sqltext": spec["sqltext"]}
             for spec in parsed_state.ck_constraints
         ]
@@ -3253,7 +3517,13 @@ class MySQLDialect(default.DefaultDialect):
         return cks if cks else ReflectionDefaults.check_constraints()
 
     @reflection.cache
-    def get_table_comment(self, connection, table_name, schema=None, **kw):
+    def get_table_comment(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: Optional[str] = None,
+        **kw: Any,
+    ) -> ReflectedTableComment:
         parsed_state = self._parsed_state_or_create(
             connection, table_name, schema, **kw
         )
@@ -3264,12 +3534,18 @@ class MySQLDialect(default.DefaultDialect):
             return ReflectionDefaults.table_comment()
 
     @reflection.cache
-    def get_indexes(self, connection, table_name, schema=None, **kw):
+    def get_indexes(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: Optional[str] = None,
+        **kw: Any,
+    ) -> list[ReflectedIndex]:
         parsed_state = self._parsed_state_or_create(
             connection, table_name, schema, **kw
         )
 
-        indexes = []
+        indexes: list[ReflectedIndex] = []
 
         for spec in parsed_state.keys:
             dialect_options = {}
@@ -3281,32 +3557,30 @@ class MySQLDialect(default.DefaultDialect):
                 unique = True
             elif flavor in ("FULLTEXT", "SPATIAL"):
                 dialect_options["%s_prefix" % self.name] = flavor
-            elif flavor is None:
-                pass
-            else:
-                self.logger.info(
+            elif flavor is not None:
+                util.warn(
                     "Converting unknown KEY type %s to a plain KEY", flavor
                 )
-                pass
 
             if spec["parser"]:
                 dialect_options["%s_with_parser" % (self.name)] = spec[
                     "parser"
                 ]
 
-            index_d = {}
+            index_d: ReflectedIndex = {
+                "name": spec["name"],
+                "column_names": [s[0] for s in spec["columns"]],
+                "unique": unique,
+            }
 
-            index_d["name"] = spec["name"]
-            index_d["column_names"] = [s[0] for s in spec["columns"]]
             mysql_length = {
                 s[0]: s[1] for s in spec["columns"] if s[1] is not None
             }
             if mysql_length:
                 dialect_options["%s_length" % self.name] = mysql_length
 
-            index_d["unique"] = unique
             if flavor:
-                index_d["type"] = flavor
+                index_d["type"] = flavor  # type: ignore[typeddict-unknown-key]
 
             if dialect_options:
                 index_d["dialect_options"] = dialect_options
@@ -3317,13 +3591,17 @@ class MySQLDialect(default.DefaultDialect):
 
     @reflection.cache
     def get_unique_constraints(
-        self, connection, table_name, schema=None, **kw
-    ):
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: Optional[str] = None,
+        **kw: Any,
+    ) -> list[ReflectedUniqueConstraint]:
         parsed_state = self._parsed_state_or_create(
             connection, table_name, schema, **kw
         )
 
-        ucs = [
+        ucs: list[ReflectedUniqueConstraint] = [
             {
                 "name": key["name"],
                 "column_names": [col[0] for col in key["columns"]],
@@ -3339,7 +3617,13 @@ class MySQLDialect(default.DefaultDialect):
             return ReflectionDefaults.unique_constraints()
 
     @reflection.cache
-    def get_view_definition(self, connection, view_name, schema=None, **kw):
+    def get_view_definition(
+        self,
+        connection: Connection,
+        view_name: str,
+        schema: Optional[str] = None,
+        **kw: Any,
+    ) -> str:
         charset = self._connection_charset
         full_name = ".".join(
             self.identifier_preparer._quote_free_identifiers(schema, view_name)
@@ -3353,8 +3637,12 @@ class MySQLDialect(default.DefaultDialect):
         return sql
 
     def _parsed_state_or_create(
-        self, connection, table_name, schema=None, **kw
-    ):
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: Optional[str] = None,
+        **kw: Any,
+    ) -> _reflection.ReflectedState:
         return self._setup_parser(
             connection,
             table_name,
@@ -3363,7 +3651,7 @@ class MySQLDialect(default.DefaultDialect):
         )
 
     @util.memoized_property
-    def _tabledef_parser(self):
+    def _tabledef_parser(self) -> _reflection.MySQLTableDefinitionParser:
         """return the MySQLTableDefinitionParser, generate if needed.
 
         The deferred creation ensures that the dialect has
@@ -3374,7 +3662,13 @@ class MySQLDialect(default.DefaultDialect):
         return _reflection.MySQLTableDefinitionParser(self, preparer)
 
     @reflection.cache
-    def _setup_parser(self, connection, table_name, schema=None, **kw):
+    def _setup_parser(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: Optional[str] = None,
+        **kw: Any,
+    ) -> _reflection.ReflectedState:
         charset = self._connection_charset
         parser = self._tabledef_parser
         full_name = ".".join(
@@ -3390,10 +3684,14 @@ class MySQLDialect(default.DefaultDialect):
             columns = self._describe_table(
                 connection, None, charset, full_name=full_name
             )
-            sql = parser._describe_to_create(table_name, columns)
+            sql = parser._describe_to_create(
+                table_name, columns  # type: ignore[arg-type]
+            )
         return parser.parse(sql, charset)
 
-    def _fetch_setting(self, connection, setting_name):
+    def _fetch_setting(
+        self, connection: Connection, setting_name: str
+    ) -> Optional[str]:
         charset = self._connection_charset
 
         if self.server_version_info and self.server_version_info < (5, 6):
@@ -3408,12 +3706,12 @@ class MySQLDialect(default.DefaultDialect):
         if not row:
             return None
         else:
-            return row[fetch_col]
+            return cast("Optional[str]", row[fetch_col])
 
-    def _detect_charset(self, connection):
+    def _detect_charset(self, connection: Connection) -> str:
         raise NotImplementedError()
 
-    def _detect_casing(self, connection):
+    def _detect_casing(self, connection: Connection) -> int:
         """Sniff out identifier case sensitivity.
 
         Cached per-connection. This value can not change without a server
@@ -3437,7 +3735,7 @@ class MySQLDialect(default.DefaultDialect):
         self._casing = cs
         return cs
 
-    def _detect_collations(self, connection):
+    def _detect_collations(self, connection: Connection) -> dict[str, str]:
         """Pull the active COLLATIONS list from the server.
 
         Cached per-connection.
@@ -3450,7 +3748,7 @@ class MySQLDialect(default.DefaultDialect):
             collations[row[0]] = row[1]
         return collations
 
-    def _detect_sql_mode(self, connection):
+    def _detect_sql_mode(self, connection: Connection) -> None:
         setting = self._fetch_setting(connection, "sql_mode")
 
         if setting is None:
@@ -3462,7 +3760,7 @@ class MySQLDialect(default.DefaultDialect):
         else:
             self._sql_mode = setting or ""
 
-    def _detect_ansiquotes(self, connection):
+    def _detect_ansiquotes(self, connection: Connection) -> None:
         """Detect and adjust for the ANSI_QUOTES sql mode."""
 
         mode = self._sql_mode
@@ -3477,12 +3775,35 @@ class MySQLDialect(default.DefaultDialect):
         # as of MySQL 5.0.1
         self._backslash_escapes = "NO_BACKSLASH_ESCAPES" not in mode
 
+    @overload
     def _show_create_table(
-        self, connection, table, charset=None, full_name=None
-    ):
+        self,
+        connection: Connection,
+        table: Optional[Table],
+        charset: Optional[str],
+        full_name: str,
+    ) -> str: ...
+
+    @overload
+    def _show_create_table(
+        self,
+        connection: Connection,
+        table: Table,
+        charset: Optional[str] = None,
+        full_name: None = None,
+    ) -> str: ...
+
+    def _show_create_table(
+        self,
+        connection: Connection,
+        table: Optional[Table],
+        charset: Optional[str] = None,
+        full_name: Optional[str] = None,
+    ) -> str:
         """Run SHOW CREATE TABLE for a ``Table``."""
 
         if full_name is None:
+            assert table is not None
             full_name = self.identifier_preparer.format_table(table)
         st = "SHOW CREATE TABLE %s" % full_name
 
@@ -3491,19 +3812,44 @@ class MySQLDialect(default.DefaultDialect):
                 skip_user_error_events=True
             ).exec_driver_sql(st)
         except exc.DBAPIError as e:
-            if self._extract_error_code(e.orig) == 1146:
+            if self._extract_error_code(e.orig) == 1146:  # type: ignore[arg-type] # noqa: E501
                 raise exc.NoSuchTableError(full_name) from e
             else:
                 raise
         row = self._compat_first(rp, charset=charset)
         if not row:
             raise exc.NoSuchTableError(full_name)
-        return row[1].strip()
+        return cast("str", row[1]).strip()
+
+    @overload
+    def _describe_table(
+        self,
+        connection: Connection,
+        table: Optional[Table],
+        charset: Optional[str],
+        full_name: str,
+    ) -> Union[Sequence[Row[Unpack[TupleAny]]], Sequence[_DecodingRow]]: ...
+
+    @overload
+    def _describe_table(
+        self,
+        connection: Connection,
+        table: Table,
+        charset: Optional[str] = None,
+        full_name: None = None,
+    ) -> Union[Sequence[Row[Unpack[TupleAny]]], Sequence[_DecodingRow]]: ...
 
-    def _describe_table(self, connection, table, charset=None, full_name=None):
+    def _describe_table(
+        self,
+        connection: Connection,
+        table: Optional[Table],
+        charset: Optional[str] = None,
+        full_name: Optional[str] = None,
+    ) -> Union[Sequence[Row[Unpack[TupleAny]]], Sequence[_DecodingRow]]:
         """Run DESCRIBE for a ``Table`` and return processed rows."""
 
         if full_name is None:
+            assert table is not None
             full_name = self.identifier_preparer.format_table(table)
         st = "DESCRIBE %s" % full_name
 
@@ -3514,7 +3860,7 @@ class MySQLDialect(default.DefaultDialect):
                     skip_user_error_events=True
                 ).exec_driver_sql(st)
             except exc.DBAPIError as e:
-                code = self._extract_error_code(e.orig)
+                code = self._extract_error_code(e.orig)  # type: ignore[arg-type] # noqa: E501
                 if code == 1146:
                     raise exc.NoSuchTableError(full_name) from e
 
@@ -3546,7 +3892,7 @@ class _DecodingRow:
     # sets.Set(['value']) (seriously) but thankfully that doesn't
     # seem to come up in DDL queries.
 
-    _encoding_compat = {
+    _encoding_compat: dict[str, str] = {
         "koi8r": "koi8_r",
         "koi8u": "koi8_u",
         "utf16": "utf-16-be",  # MySQL's uft16 is always bigendian
@@ -3556,24 +3902,23 @@ class _DecodingRow:
         "eucjpms": "ujis",
     }
 
-    def __init__(self, rowproxy, charset):
+    def __init__(self, rowproxy: Row[Unpack[_Ts]], charset: Optional[str]):
         self.rowproxy = rowproxy
-        self.charset = self._encoding_compat.get(charset, charset)
+        self.charset = (
+            self._encoding_compat.get(charset, charset)
+            if charset is not None
+            else None
+        )
 
-    def __getitem__(self, index):
+    def __getitem__(self, index: int) -> Any:
         item = self.rowproxy[index]
-        if isinstance(item, _array):
-            item = item.tostring()
-
         if self.charset and isinstance(item, bytes):
             return item.decode(self.charset)
         else:
             return item
 
-    def __getattr__(self, attr):
+    def __getattr__(self, attr: str) -> Any:
         item = getattr(self.rowproxy, attr)
-        if isinstance(item, _array):
-            item = item.tostring()
         if self.charset and isinstance(item, bytes):
             return item.decode(self.charset)
         else:
index 5c00ada9f9400a42e220be0ad96c8bb3623615fa..1d48c4e88bc80c3e5b6d5ed8b6e9ca3a0c9a86db 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 r"""
 
@@ -21,18 +20,36 @@ r"""
     dialects are mysqlclient and PyMySQL.
 
 """  # noqa
+from __future__ import annotations
+
+from typing import Any
+from typing import Iterable
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
 
-from .base import BIT
 from .base import MySQLDialect
 from .mysqldb import MySQLDialect_mysqldb
+from .types import BIT
 from ... import util
 
+if TYPE_CHECKING:
+    from ...engine.base import Connection
+    from ...engine.interfaces import DBAPIConnection
+    from ...engine.interfaces import DBAPICursor
+    from ...engine.interfaces import DBAPIModule
+    from ...engine.interfaces import Dialect
+    from ...engine.interfaces import PoolProxiedConnection
+    from ...sql.type_api import _ResultProcessorType
+
 
 class _cymysqlBIT(BIT):
-    def result_processor(self, dialect, coltype):
+    def result_processor(
+        self, dialect: Dialect, coltype: object
+    ) -> Optional[_ResultProcessorType[Any]]:
         """Convert MySQL's 64 bit, variable length binary string to a long."""
 
-        def process(value):
+        def process(value: Optional[Iterable[int]]) -> Optional[int]:
             if value is not None:
                 v = 0
                 for i in iter(value):
@@ -55,17 +72,22 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb):
     colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _cymysqlBIT})
 
     @classmethod
-    def import_dbapi(cls):
+    def import_dbapi(cls) -> DBAPIModule:
         return __import__("cymysql")
 
-    def _detect_charset(self, connection):
-        return connection.connection.charset
+    def _detect_charset(self, connection: Connection) -> str:
+        return connection.connection.charset  # type: ignore[no-any-return]
 
-    def _extract_error_code(self, exception):
-        return exception.errno
+    def _extract_error_code(self, exception: DBAPIModule.Error) -> int:
+        return exception.errno  # type: ignore[no-any-return]
 
-    def is_disconnect(self, e, connection, cursor):
-        if isinstance(e, self.dbapi.OperationalError):
+    def is_disconnect(
+        self,
+        e: DBAPIModule.Error,
+        connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
+        cursor: Optional[DBAPICursor],
+    ) -> bool:
+        if isinstance(e, self.loaded_dbapi.OperationalError):
             return self._extract_error_code(e) in (
                 2006,
                 2013,
@@ -73,7 +95,7 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb):
                 2045,
                 2055,
             )
-        elif isinstance(e, self.dbapi.InterfaceError):
+        elif isinstance(e, self.loaded_dbapi.InterfaceError):
             # if underlying connection is closed,
             # this is the error you get
             return True
index f0917f07fa3382424198c65312f4ab0ff06ece8c..c32364507df71a5ac3b4835d1e32405e8377c8ae 100644 (file)
@@ -4,26 +4,41 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
+from __future__ import annotations
 
+import enum
 import re
+from typing import Any
+from typing import Optional
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
 
 from .types import _StringType
 from ... import exc
 from ... import sql
 from ... import util
 from ...sql import sqltypes
+from ...sql import type_api
 
+if TYPE_CHECKING:
+    from ...engine.interfaces import Dialect
+    from ...sql.elements import ColumnElement
+    from ...sql.type_api import _BindProcessorType
+    from ...sql.type_api import _ResultProcessorType
+    from ...sql.type_api import TypeEngine
+    from ...sql.type_api import TypeEngineMixin
 
-class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType):
+
+class ENUM(type_api.NativeForEmulated, sqltypes.Enum, _StringType):
     """MySQL ENUM type."""
 
     __visit_name__ = "ENUM"
 
     native_enum = True
 
-    def __init__(self, *enums, **kw):
+    def __init__(self, *enums: Union[str, Type[enum.Enum]], **kw: Any) -> None:
         """Construct an ENUM.
 
         E.g.::
@@ -59,21 +74,27 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType):
 
         """
         kw.pop("strict", None)
-        self._enum_init(enums, kw)
+        self._enum_init(enums, kw)  # type: ignore[arg-type]
         _StringType.__init__(self, length=self.length, **kw)
 
     @classmethod
-    def adapt_emulated_to_native(cls, impl, **kw):
+    def adapt_emulated_to_native(
+        cls,
+        impl: Union[TypeEngine[Any], TypeEngineMixin],
+        **kw: Any,
+    ) -> ENUM:
         """Produce a MySQL native :class:`.mysql.ENUM` from plain
         :class:`.Enum`.
 
         """
+        if TYPE_CHECKING:
+            assert isinstance(impl, ENUM)
         kw.setdefault("validate_strings", impl.validate_strings)
         kw.setdefault("values_callable", impl.values_callable)
         kw.setdefault("omit_aliases", impl._omit_aliases)
         return cls(**kw)
 
-    def _object_value_for_elem(self, elem):
+    def _object_value_for_elem(self, elem: str) -> Union[str, enum.Enum]:
         # mysql sends back a blank string for any value that
         # was persisted that was not in the enums; that is, it does no
         # validation on the incoming data, it "truncates" it to be
@@ -83,18 +104,22 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType):
         else:
             return super()._object_value_for_elem(elem)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return util.generic_repr(
             self, to_inspect=[ENUM, _StringType, sqltypes.Enum]
         )
 
 
+# TODO: SET is a string as far as configuration but does not act like
+# a string at the python level.  We either need to make a py-type agnostic
+# version of String as a base to be used for this, make this some kind of
+# TypeDecorator, or just vendor it out as its own type.
 class SET(_StringType):
     """MySQL SET type."""
 
     __visit_name__ = "SET"
 
-    def __init__(self, *values, **kw):
+    def __init__(self, *values: str, **kw: Any):
         """Construct a SET.
 
         E.g.::
@@ -147,17 +172,19 @@ class SET(_StringType):
                 "setting retrieve_as_bitwise=True"
             )
         if self.retrieve_as_bitwise:
-            self._bitmap = {
+            self._inversed_bitmap: dict[str, int] = {
                 value: 2**idx for idx, value in enumerate(self.values)
             }
-            self._bitmap.update(
-                (2**idx, value) for idx, value in enumerate(self.values)
-            )
+            self._bitmap: dict[int, str] = {
+                2**idx: value for idx, value in enumerate(self.values)
+            }
         length = max([len(v) for v in values] + [0])
         kw.setdefault("length", length)
         super().__init__(**kw)
 
-    def column_expression(self, colexpr):
+    def column_expression(
+        self, colexpr: ColumnElement[Any]
+    ) -> ColumnElement[Any]:
         if self.retrieve_as_bitwise:
             return sql.type_coerce(
                 sql.type_coerce(colexpr, sqltypes.Integer) + 0, self
@@ -165,10 +192,12 @@ class SET(_StringType):
         else:
             return colexpr
 
-    def result_processor(self, dialect, coltype):
+    def result_processor(
+        self, dialect: Dialect, coltype: Any
+    ) -> Optional[_ResultProcessorType[Any]]:
         if self.retrieve_as_bitwise:
 
-            def process(value):
+            def process(value: Union[str, int, None]) -> Optional[set[str]]:
                 if value is not None:
                     value = int(value)
 
@@ -179,11 +208,14 @@ class SET(_StringType):
         else:
             super_convert = super().result_processor(dialect, coltype)
 
-            def process(value):
+            def process(value: Union[str, set[str], None]) -> Optional[set[str]]:  # type: ignore[misc]  # noqa: E501
                 if isinstance(value, str):
                     # MySQLdb returns a string, let's parse
                     if super_convert:
                         value = super_convert(value)
+                        assert value is not None
+                    if TYPE_CHECKING:
+                        assert isinstance(value, str)
                     return set(re.findall(r"[^,]+", value))
                 else:
                     # mysql-connector-python does a naive
@@ -194,43 +226,48 @@ class SET(_StringType):
 
         return process
 
-    def bind_processor(self, dialect):
+    def bind_processor(
+        self, dialect: Dialect
+    ) -> _BindProcessorType[Union[str, int]]:
         super_convert = super().bind_processor(dialect)
         if self.retrieve_as_bitwise:
 
-            def process(value):
+            def process(
+                value: Union[str, int, set[str], None],
+            ) -> Union[str, int, None]:
                 if value is None:
                     return None
                 elif isinstance(value, (int, str)):
                     if super_convert:
-                        return super_convert(value)
+                        return super_convert(value)  # type: ignore[arg-type, no-any-return]  # noqa: E501
                     else:
                         return value
                 else:
                     int_value = 0
                     for v in value:
-                        int_value |= self._bitmap[v]
+                        int_value |= self._inversed_bitmap[v]
                     return int_value
 
         else:
 
-            def process(value):
+            def process(
+                value: Union[str, int, set[str], None],
+            ) -> Union[str, int, None]:
                 # accept strings and int (actually bitflag) values directly
                 if value is not None and not isinstance(value, (int, str)):
                     value = ",".join(value)
-
                 if super_convert:
-                    return super_convert(value)
+                    return super_convert(value)  # type: ignore
                 else:
                     return value
 
         return process
 
-    def adapt(self, impltype, **kw):
+    def adapt(self, cls: type, **kw: Any) -> Any:
         kw["retrieve_as_bitwise"] = self.retrieve_as_bitwise
-        return util.constructor_copy(self, impltype, *self.values, **kw)
+        return util.constructor_copy(self, cls, *self.values, **kw)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return util.generic_repr(
             self,
             to_inspect=[SET, _StringType],
index b60a0888517fd3aab0a75a9bd5ba3f03c386af4f..9d19d52de5e1eafd3d41adbf097ff84f6f15d97a 100644 (file)
@@ -4,8 +4,10 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
+from __future__ import annotations
+
+from typing import Any
 
 from ... import exc
 from ... import util
@@ -18,7 +20,7 @@ from ...sql.base import Generative
 from ...util.typing import Self
 
 
-class match(Generative, elements.BinaryExpression):
+class match(Generative, elements.BinaryExpression[Any]):
     """Produce a ``MATCH (X, Y) AGAINST ('TEXT')`` clause.
 
     E.g.::
@@ -73,8 +75,9 @@ class match(Generative, elements.BinaryExpression):
     __visit_name__ = "mysql_match"
 
     inherit_cache = True
+    modifiers: util.immutabledict[str, Any]
 
-    def __init__(self, *cols, **kw):
+    def __init__(self, *cols: elements.ColumnElement[Any], **kw: Any):
         if not cols:
             raise exc.ArgumentError("columns are required")
 
index 8912af36631f93470d7882493cd2b85d59c148f8..e654a61941dfd3ca95610a2b3aa4317ee23e0f0d 100644 (file)
@@ -4,10 +4,18 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+from __future__ import annotations
+
+from typing import Any
+from typing import TYPE_CHECKING
 
 from ... import types as sqltypes
 
+if TYPE_CHECKING:
+    from ...engine.interfaces import Dialect
+    from ...sql.type_api import _BindProcessorType
+    from ...sql.type_api import _LiteralProcessorType
+
 
 class JSON(sqltypes.JSON):
     """MySQL JSON type.
@@ -34,13 +42,13 @@ class JSON(sqltypes.JSON):
 
 
 class _FormatTypeMixin:
-    def _format_value(self, value):
+    def _format_value(self, value: Any) -> str:
         raise NotImplementedError()
 
-    def bind_processor(self, dialect):
-        super_proc = self.string_bind_processor(dialect)
+    def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]:
+        super_proc = self.string_bind_processor(dialect)  # type: ignore[attr-defined]  # noqa: E501
 
-        def process(value):
+        def process(value: Any) -> Any:
             value = self._format_value(value)
             if super_proc:
                 value = super_proc(value)
@@ -48,29 +56,31 @@ class _FormatTypeMixin:
 
         return process
 
-    def literal_processor(self, dialect):
-        super_proc = self.string_literal_processor(dialect)
+    def literal_processor(
+        self, dialect: Dialect
+    ) -> _LiteralProcessorType[Any]:
+        super_proc = self.string_literal_processor(dialect)  # type: ignore[attr-defined]  # noqa: E501
 
-        def process(value):
+        def process(value: Any) -> str:
             value = self._format_value(value)
             if super_proc:
                 value = super_proc(value)
-            return value
+            return value  # type: ignore[no-any-return]
 
         return process
 
 
 class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
-    def _format_value(self, value):
+    def _format_value(self, value: Any) -> str:
         if isinstance(value, int):
-            value = "$[%s]" % value
+            formatted_value = "$[%s]" % value
         else:
-            value = '$."%s"' % value
-        return value
+            formatted_value = '$."%s"' % value
+        return formatted_value
 
 
 class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
-    def _format_value(self, value):
+    def _format_value(self, value: Any) -> str:
         return "$%s" % (
             "".join(
                 [
index ff5214798f2a0704a9904d10cd897effa8d44557..8b66531131c902f372218f2a53ff6e61ac93f062 100644 (file)
@@ -4,15 +4,28 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+
+from __future__ import annotations
+
+from typing import Any
+from typing import Callable
+from typing import Optional
+from typing import TYPE_CHECKING
+
 from .base import MariaDBIdentifierPreparer
 from .base import MySQLDialect
+from .base import MySQLIdentifierPreparer
 from .base import MySQLTypeCompiler
 from ... import util
 from ...sql import sqltypes
+from ...sql.sqltypes import _UUID_RETURN
 from ...sql.sqltypes import UUID
 from ...sql.sqltypes import Uuid
 
+if TYPE_CHECKING:
+    from ...engine.base import Connection
+    from ...sql.type_api import _BindProcessorType
+
 
 class INET4(sqltypes.TypeEngine[str]):
     """INET4 column type for MariaDB
@@ -32,7 +45,7 @@ class INET6(sqltypes.TypeEngine[str]):
     __visit_name__ = "INET6"
 
 
-class _MariaDBUUID(UUID):
+class _MariaDBUUID(UUID[_UUID_RETURN]):
     def __init__(self, as_uuid: bool = True, native_uuid: bool = True):
         self.as_uuid = as_uuid
 
@@ -46,23 +59,23 @@ class _MariaDBUUID(UUID):
         self.native_uuid = False
 
     @property
-    def native(self):
+    def native(self) -> bool:  # type: ignore[override]
         # override to return True, this is a native type, just turning
         # off native_uuid for internal data handling
         return True
 
-    def bind_processor(self, dialect):
+    def bind_processor(self, dialect: MariaDBDialect) -> Optional[_BindProcessorType[_UUID_RETURN]]:  # type: ignore[override] # noqa: E501
         if not dialect.supports_native_uuid or not dialect._allows_uuid_binds:
-            return super().bind_processor(dialect)
+            return super().bind_processor(dialect)  # type: ignore[return-value] # noqa: E501
         else:
             return None
 
 
 class MariaDBTypeCompiler(MySQLTypeCompiler):
-    def visit_INET4(self, type_, **kwargs) -> str:
+    def visit_INET4(self, type_: INET4, **kwargs: Any) -> str:
         return "INET4"
 
-    def visit_INET6(self, type_, **kwargs) -> str:
+    def visit_INET6(self, type_: INET6, **kwargs: Any) -> str:
         return "INET6"
 
 
@@ -74,12 +87,12 @@ class MariaDBDialect(MySQLDialect):
     _allows_uuid_binds = True
 
     name = "mariadb"
-    preparer = MariaDBIdentifierPreparer
+    preparer: type[MySQLIdentifierPreparer] = MariaDBIdentifierPreparer
     type_compiler_cls = MariaDBTypeCompiler
 
     colspecs = util.update_copy(MySQLDialect.colspecs, {Uuid: _MariaDBUUID})
 
-    def initialize(self, connection):
+    def initialize(self, connection: Connection) -> None:
         super().initialize(connection)
 
         self.supports_native_uuid = (
@@ -88,7 +101,7 @@ class MariaDBDialect(MySQLDialect):
         )
 
 
-def loader(driver):
+def loader(driver: str) -> Callable[[], type[MariaDBDialect]]:
     dialect_mod = __import__(
         "sqlalchemy.dialects.mysql.%s" % driver
     ).dialects.mysql
@@ -96,7 +109,7 @@ def loader(driver):
     driver_mod = getattr(dialect_mod, driver)
     if hasattr(driver_mod, "mariadb_dialect"):
         driver_cls = driver_mod.mariadb_dialect
-        return driver_cls
+        return driver_cls  # type: ignore[no-any-return]
     else:
         driver_cls = driver_mod.dialect
 
index fbc60037971de8df2255529de2e47be3eb5e947c..944549f9a5ea26bb821bdb04c66bec3566d989ba 100644 (file)
@@ -4,8 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
-
 
 """
 
@@ -29,7 +27,14 @@ be ``mysqldb``. ``mariadb+mariadbconnector://`` is required to use this driver.
 .. mariadb: https://github.com/mariadb-corporation/mariadb-connector-python
 
 """  # noqa
+from __future__ import annotations
+
 import re
+from typing import Any
+from typing import Optional
+from typing import Sequence
+from typing import TYPE_CHECKING
+from typing import Union
 from uuid import UUID as _python_UUID
 
 from .base import MySQLCompiler
@@ -40,6 +45,19 @@ from ... import sql
 from ... import util
 from ...sql import sqltypes
 
+if TYPE_CHECKING:
+    from ...engine.base import Connection
+    from ...engine.interfaces import ConnectArgsType
+    from ...engine.interfaces import DBAPIConnection
+    from ...engine.interfaces import DBAPICursor
+    from ...engine.interfaces import DBAPIModule
+    from ...engine.interfaces import Dialect
+    from ...engine.interfaces import IsolationLevel
+    from ...engine.interfaces import PoolProxiedConnection
+    from ...engine.url import URL
+    from ...sql.compiler import SQLCompiler
+    from ...sql.type_api import _ResultProcessorType
+
 
 mariadb_cpy_minimum_version = (1, 0, 1)
 
@@ -48,10 +66,12 @@ class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]):
     # work around JIRA issue
     # https://jira.mariadb.org/browse/CONPY-270.  When that issue is fixed,
     # this type can be removed.
-    def result_processor(self, dialect, coltype):
+    def result_processor(
+        self, dialect: Dialect, coltype: object
+    ) -> Optional[_ResultProcessorType[Any]]:
         if self.as_uuid:
 
-            def process(value):
+            def process(value: Any) -> Any:
                 if value is not None:
                     if hasattr(value, "decode"):
                         value = value.decode("ascii")
@@ -61,7 +81,7 @@ class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]):
             return process
         else:
 
-            def process(value):
+            def process(value: Any) -> Any:
                 if value is not None:
                     if hasattr(value, "decode"):
                         value = value.decode("ascii")
@@ -72,23 +92,27 @@ class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]):
 
 
 class MySQLExecutionContext_mariadbconnector(MySQLExecutionContext):
-    _lastrowid = None
+    _lastrowid: Optional[int] = None
 
-    def create_server_side_cursor(self):
+    def create_server_side_cursor(self) -> DBAPICursor:
         return self._dbapi_connection.cursor(buffered=False)
 
-    def create_default_cursor(self):
+    def create_default_cursor(self) -> DBAPICursor:
         return self._dbapi_connection.cursor(buffered=True)
 
-    def post_exec(self):
+    def post_exec(self) -> None:
         super().post_exec()
 
         self._rowcount = self.cursor.rowcount
 
+        if TYPE_CHECKING:
+            assert isinstance(self.compiled, SQLCompiler)
         if self.isinsert and self.compiled.postfetch_lastrowid:
             self._lastrowid = self.cursor.lastrowid
 
-    def get_lastrowid(self):
+    def get_lastrowid(self) -> int:
+        if TYPE_CHECKING:
+            assert self._lastrowid is not None
         return self._lastrowid
 
 
@@ -127,7 +151,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
     )
 
     @util.memoized_property
-    def _dbapi_version(self):
+    def _dbapi_version(self) -> tuple[int, ...]:
         if self.dbapi and hasattr(self.dbapi, "__version__"):
             return tuple(
                 [
@@ -140,7 +164,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
         else:
             return (99, 99, 99)
 
-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any) -> None:
         super().__init__(**kwargs)
         self.paramstyle = "qmark"
         if self.dbapi is not None:
@@ -152,19 +176,24 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
                 )
 
     @classmethod
-    def import_dbapi(cls):
+    def import_dbapi(cls) -> DBAPIModule:
         return __import__("mariadb")
 
-    def is_disconnect(self, e, connection, cursor):
+    def is_disconnect(
+        self,
+        e: DBAPIModule.Error,
+        connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
+        cursor: Optional[DBAPICursor],
+    ) -> bool:
         if super().is_disconnect(e, connection, cursor):
             return True
-        elif isinstance(e, self.dbapi.Error):
+        elif isinstance(e, self.loaded_dbapi.Error):
             str_e = str(e).lower()
             return "not connected" in str_e or "isn't valid" in str_e
         else:
             return False
 
-    def create_connect_args(self, url):
+    def create_connect_args(self, url: URL) -> ConnectArgsType:
         opts = url.translate_connect_args()
         opts.update(url.query)
 
@@ -201,19 +230,21 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
             except (AttributeError, ImportError):
                 self.supports_sane_rowcount = False
             opts["client_flag"] = client_flag
-        return [[], opts]
+        return [], opts
 
-    def _extract_error_code(self, exception):
+    def _extract_error_code(self, exception: DBAPIModule.Error) -> int:
         try:
-            rc = exception.errno
+            rc: int = exception.errno
         except:
             rc = -1
         return rc
 
-    def _detect_charset(self, connection):
+    def _detect_charset(self, connection: Connection) -> str:
         return "utf8mb4"
 
-    def get_isolation_level_values(self, dbapi_connection):
+    def get_isolation_level_values(
+        self, dbapi_conn: DBAPIConnection
+    ) -> Sequence[IsolationLevel]:
         return (
             "SERIALIZABLE",
             "READ UNCOMMITTED",
@@ -222,21 +253,23 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
             "AUTOCOMMIT",
         )
 
-    def set_isolation_level(self, connection, level):
+    def set_isolation_level(
+        self, dbapi_connection: DBAPIConnection, level: IsolationLevel
+    ) -> None:
         if level == "AUTOCOMMIT":
-            connection.autocommit = True
+            dbapi_connection.autocommit = True
         else:
-            connection.autocommit = False
-            super().set_isolation_level(connection, level)
+            dbapi_connection.autocommit = False
+            super().set_isolation_level(dbapi_connection, level)
 
-    def do_begin_twophase(self, connection, xid):
+    def do_begin_twophase(self, connection: Connection, xid: Any) -> None:
         connection.execute(
             sql.text("XA BEGIN :xid").bindparams(
                 sql.bindparam("xid", xid, literal_execute=True)
             )
         )
 
-    def do_prepare_twophase(self, connection, xid):
+    def do_prepare_twophase(self, connection: Connection, xid: Any) -> None:
         connection.execute(
             sql.text("XA END :xid").bindparams(
                 sql.bindparam("xid", xid, literal_execute=True)
@@ -249,8 +282,12 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
         )
 
     def do_rollback_twophase(
-        self, connection, xid, is_prepared=True, recover=False
-    ):
+        self,
+        connection: Connection,
+        xid: Any,
+        is_prepared: bool = True,
+        recover: bool = False,
+    ) -> None:
         if not is_prepared:
             connection.execute(
                 sql.text("XA END :xid").bindparams(
@@ -264,8 +301,12 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
         )
 
     def do_commit_twophase(
-        self, connection, xid, is_prepared=True, recover=False
-    ):
+        self,
+        connection: Connection,
+        xid: Any,
+        is_prepared: bool = True,
+        recover: bool = False,
+    ) -> None:
         if not is_prepared:
             self.do_prepare_twophase(connection, xid)
         connection.execute(
index faeae16abd58b915504f2ab8920ac2da5f8d3764..b36248cb35ae6eef2c0f566337de0772640b6778 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 
 r"""
@@ -46,29 +45,54 @@ charset/collation will allow connectivity.
 
 
 """  # noqa
+from __future__ import annotations
 
 import re
+from typing import Any
+from typing import cast
+from typing import Optional
+from typing import Sequence
+from typing import TYPE_CHECKING
+from typing import Union
 
-from .base import BIT
 from .base import MariaDBIdentifierPreparer
 from .base import MySQLCompiler
 from .base import MySQLDialect
 from .base import MySQLExecutionContext
 from .base import MySQLIdentifierPreparer
 from .mariadb import MariaDBDialect
+from .types import BIT
 from ... import util
 
+if TYPE_CHECKING:
+
+    from ...engine.base import Connection
+    from ...engine.cursor import CursorResult
+    from ...engine.interfaces import ConnectArgsType
+    from ...engine.interfaces import DBAPIConnection
+    from ...engine.interfaces import DBAPICursor
+    from ...engine.interfaces import DBAPIModule
+    from ...engine.interfaces import IsolationLevel
+    from ...engine.interfaces import PoolProxiedConnection
+    from ...engine.row import Row
+    from ...engine.url import URL
+    from ...sql.elements import BinaryExpression
+    from ...util.typing import TupleAny
+    from ...util.typing import Unpack
+
 
 class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext):
-    def create_server_side_cursor(self):
+    def create_server_side_cursor(self) -> DBAPICursor:
         return self._dbapi_connection.cursor(buffered=False)
 
-    def create_default_cursor(self):
+    def create_default_cursor(self) -> DBAPICursor:
         return self._dbapi_connection.cursor(buffered=True)
 
 
 class MySQLCompiler_mysqlconnector(MySQLCompiler):
-    def visit_mod_binary(self, binary, operator, **kw):
+    def visit_mod_binary(
+        self, binary: BinaryExpression[Any], operator: Any, **kw: Any
+    ) -> str:
         return (
             self.process(binary.left, **kw)
             + " % "
@@ -78,32 +102,35 @@ class MySQLCompiler_mysqlconnector(MySQLCompiler):
 
 class IdentifierPreparerCommon_mysqlconnector:
     @property
-    def _double_percents(self):
+    def _double_percents(self) -> bool:
         return False
 
     @_double_percents.setter
-    def _double_percents(self, value):
+    def _double_percents(self, value: Any) -> None:
         pass
 
-    def _escape_identifier(self, value):
-        value = value.replace(self.escape_quote, self.escape_to_quote)
+    def _escape_identifier(self, value: str) -> str:
+        value = value.replace(
+            self.escape_quote,  # type:ignore[attr-defined]
+            self.escape_to_quote,  # type:ignore[attr-defined]
+        )
         return value
 
 
-class MySQLIdentifierPreparer_mysqlconnector(
+class MySQLIdentifierPreparer_mysqlconnector(  # type:ignore[misc]
     IdentifierPreparerCommon_mysqlconnector, MySQLIdentifierPreparer
 ):
     pass
 
 
-class MariaDBIdentifierPreparer_mysqlconnector(
+class MariaDBIdentifierPreparer_mysqlconnector(  # type:ignore[misc]
     IdentifierPreparerCommon_mysqlconnector, MariaDBIdentifierPreparer
 ):
     pass
 
 
 class _myconnpyBIT(BIT):
-    def result_processor(self, dialect, coltype):
+    def result_processor(self, dialect: Any, coltype: Any) -> None:
         """MySQL-connector already converts mysql bits, so."""
 
         return None
@@ -128,21 +155,21 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
 
     execution_ctx_cls = MySQLExecutionContext_mysqlconnector
 
-    preparer = MySQLIdentifierPreparer_mysqlconnector
+    preparer: type[MySQLIdentifierPreparer] = (
+        MySQLIdentifierPreparer_mysqlconnector
+    )
 
     colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _myconnpyBIT})
 
     @classmethod
-    def import_dbapi(cls):
-        from mysql import connector
+    def import_dbapi(cls) -> DBAPIModule:
+        return cast(DBAPIModule, __import__("mysql.connector").connector)
 
-        return connector
-
-    def do_ping(self, dbapi_connection):
+    def do_ping(self, dbapi_connection: DBAPIConnection) -> bool:
         dbapi_connection.ping(False)
         return True
 
-    def create_connect_args(self, url):
+    def create_connect_args(self, url: URL) -> ConnectArgsType:
         opts = url.translate_connect_args(username="user")
 
         opts.update(url.query)
@@ -177,7 +204,9 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
         # supports_sane_rowcount.
         if self.dbapi is not None:
             try:
-                from mysql.connector.constants import ClientFlag
+                from mysql.connector import constants  # type: ignore
+
+                ClientFlag = constants.ClientFlag
 
                 client_flags = opts.get(
                     "client_flags", ClientFlag.get_default()
@@ -187,27 +216,33 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
             except Exception:
                 pass
 
-        return [[], opts]
+        return [], opts
 
     @util.memoized_property
-    def _mysqlconnector_version_info(self):
+    def _mysqlconnector_version_info(self) -> Optional[tuple[int, ...]]:
         if self.dbapi and hasattr(self.dbapi, "__version__"):
             m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__)
             if m:
                 return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
+        return None
 
-    def _detect_charset(self, connection):
-        return connection.connection.charset
+    def _detect_charset(self, connection: Connection) -> str:
+        return connection.connection.charset  # type: ignore
 
-    def _extract_error_code(self, exception):
-        return exception.errno
+    def _extract_error_code(self, exception: BaseException) -> int:
+        return exception.errno  # type: ignore
 
-    def is_disconnect(self, e, connection, cursor):
+    def is_disconnect(
+        self,
+        e: Exception,
+        connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
+        cursor: Optional[DBAPICursor],
+    ) -> bool:
         errnos = (2006, 2013, 2014, 2045, 2055, 2048)
         exceptions = (
-            self.dbapi.OperationalError,
-            self.dbapi.InterfaceError,
-            self.dbapi.ProgrammingError,
+            self.loaded_dbapi.OperationalError,  #
+            self.loaded_dbapi.InterfaceError,
+            self.loaded_dbapi.ProgrammingError,
         )
         if isinstance(e, exceptions):
             return (
@@ -218,13 +253,23 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
         else:
             return False
 
-    def _compat_fetchall(self, rp, charset=None):
+    def _compat_fetchall(
+        self,
+        rp: CursorResult[Unpack[TupleAny]],
+        charset: Optional[str] = None,
+    ) -> Sequence[Row[Unpack[TupleAny]]]:
         return rp.fetchall()
 
-    def _compat_fetchone(self, rp, charset=None):
+    def _compat_fetchone(
+        self,
+        rp: CursorResult[Unpack[TupleAny]],
+        charset: Optional[str] = None,
+    ) -> Optional[Row[Unpack[TupleAny]]]:
         return rp.fetchone()
 
-    def get_isolation_level_values(self, dbapi_connection):
+    def get_isolation_level_values(
+        self, dbapi_conn: DBAPIConnection
+    ) -> Sequence[IsolationLevel]:
         return (
             "SERIALIZABLE",
             "READ UNCOMMITTED",
@@ -233,12 +278,14 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
             "AUTOCOMMIT",
         )
 
-    def set_isolation_level(self, connection, level):
+    def set_isolation_level(
+        self, dbapi_connection: DBAPIConnection, level: IsolationLevel
+    ) -> None:
         if level == "AUTOCOMMIT":
-            connection.autocommit = True
+            dbapi_connection.autocommit = True
         else:
-            connection.autocommit = False
-            super().set_isolation_level(connection, level)
+            dbapi_connection.autocommit = False
+            super().set_isolation_level(dbapi_connection, level)
 
 
 class MariaDBDialect_mysqlconnector(
index 3cf56c1fd0942122193202550ffe8167b308630f..14a4c00e4c0e03219e33f58bd312eb409814264b 100644 (file)
@@ -4,8 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
-
 
 """
 
@@ -86,17 +84,34 @@ Server Side Cursors
 The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`.
 
 """
+from __future__ import annotations
 
 import re
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Literal
+from typing import Optional
+from typing import TYPE_CHECKING
 
 from .base import MySQLCompiler
 from .base import MySQLDialect
 from .base import MySQLExecutionContext
 from .base import MySQLIdentifierPreparer
-from .base import TEXT
-from ... import sql
 from ... import util
 
+if TYPE_CHECKING:
+
+    from ...engine.base import Connection
+    from ...engine.interfaces import _DBAPIMultiExecuteParams
+    from ...engine.interfaces import ConnectArgsType
+    from ...engine.interfaces import DBAPIConnection
+    from ...engine.interfaces import DBAPICursor
+    from ...engine.interfaces import DBAPIModule
+    from ...engine.interfaces import ExecutionContext
+    from ...engine.interfaces import IsolationLevel
+    from ...engine.url import URL
+
 
 class MySQLExecutionContext_mysqldb(MySQLExecutionContext):
     pass
@@ -119,8 +134,9 @@ class MySQLDialect_mysqldb(MySQLDialect):
     execution_ctx_cls = MySQLExecutionContext_mysqldb
     statement_compiler = MySQLCompiler_mysqldb
     preparer = MySQLIdentifierPreparer
+    server_version_info: tuple[int, ...]
 
-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any):
         super().__init__(**kwargs)
         self._mysql_dbapi_version = (
             self._parse_dbapi_version(self.dbapi.__version__)
@@ -128,7 +144,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
             else (0, 0, 0)
         )
 
-    def _parse_dbapi_version(self, version):
+    def _parse_dbapi_version(self, version: str) -> tuple[int, ...]:
         m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
         if m:
             return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
@@ -136,7 +152,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
             return (0, 0, 0)
 
     @util.langhelpers.memoized_property
-    def supports_server_side_cursors(self):
+    def supports_server_side_cursors(self) -> bool:  # type: ignore[override]
         try:
             cursors = __import__("MySQLdb.cursors").cursors
             self._sscursor = cursors.SSCursor
@@ -145,13 +161,13 @@ class MySQLDialect_mysqldb(MySQLDialect):
             return False
 
     @classmethod
-    def import_dbapi(cls):
+    def import_dbapi(cls) -> DBAPIModule:
         return __import__("MySQLdb")
 
-    def on_connect(self):
+    def on_connect(self) -> Callable[[DBAPIConnection], None]:
         super_ = super().on_connect()
 
-        def on_connect(conn):
+        def on_connect(conn: DBAPIConnection) -> None:
             if super_ is not None:
                 super_(conn)
 
@@ -164,43 +180,24 @@ class MySQLDialect_mysqldb(MySQLDialect):
 
         return on_connect
 
-    def do_ping(self, dbapi_connection):
+    def do_ping(self, dbapi_connection: DBAPIConnection) -> Literal[True]:
         dbapi_connection.ping()
         return True
 
-    def do_executemany(self, cursor, statement, parameters, context=None):
+    def do_executemany(
+        self,
+        cursor: DBAPICursor,
+        statement: str,
+        parameters: _DBAPIMultiExecuteParams,
+        context: Optional[ExecutionContext] = None,
+    ) -> None:
         rowcount = cursor.executemany(statement, parameters)
         if context is not None:
-            context._rowcount = rowcount
-
-    def _check_unicode_returns(self, connection):
-        # work around issue fixed in
-        # https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8
-        # specific issue w/ the utf8mb4_bin collation and unicode returns
-
-        collation = connection.exec_driver_sql(
-            "show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'"
-            % (
-                self.identifier_preparer.quote("Charset"),
-                self.identifier_preparer.quote("Collation"),
-            )
-        ).scalar()
-        has_utf8mb4_bin = self.server_version_info > (5,) and collation
-        if has_utf8mb4_bin:
-            additional_tests = [
-                sql.collate(
-                    sql.cast(
-                        sql.literal_column("'test collated returns'"),
-                        TEXT(charset="utf8mb4"),
-                    ),
-                    "utf8mb4_bin",
-                )
-            ]
-        else:
-            additional_tests = []
-        return super()._check_unicode_returns(connection, additional_tests)
+            cast(MySQLExecutionContext, context)._rowcount = rowcount
 
-    def create_connect_args(self, url, _translate_args=None):
+    def create_connect_args(
+        self, url: URL, _translate_args: Optional[dict[str, Any]] = None
+    ) -> ConnectArgsType:
         if _translate_args is None:
             _translate_args = dict(
                 database="db", username="user", password="passwd"
@@ -249,9 +246,9 @@ class MySQLDialect_mysqldb(MySQLDialect):
         if client_flag_found_rows is not None:
             client_flag |= client_flag_found_rows
             opts["client_flag"] = client_flag
-        return [[], opts]
+        return [], opts
 
-    def _found_rows_client_flag(self):
+    def _found_rows_client_flag(self) -> Optional[int]:
         if self.dbapi is not None:
             try:
                 CLIENT_FLAGS = __import__(
@@ -260,20 +257,23 @@ class MySQLDialect_mysqldb(MySQLDialect):
             except (AttributeError, ImportError):
                 return None
             else:
-                return CLIENT_FLAGS.FOUND_ROWS
+                return CLIENT_FLAGS.FOUND_ROWS  # type: ignore
         else:
             return None
 
-    def _extract_error_code(self, exception):
-        return exception.args[0]
+    def _extract_error_code(self, exception: DBAPIModule.Error) -> int:
+        return exception.args[0]  # type: ignore[no-any-return]
 
-    def _detect_charset(self, connection):
+    def _detect_charset(self, connection: Connection) -> str:
         """Sniff out the character set in use for connection results."""
 
         try:
             # note: the SQL here would be
             # "SHOW VARIABLES LIKE 'character_set%%'"
-            cset_name = connection.connection.character_set_name
+
+            cset_name: Callable[[], str] = (
+                connection.connection.character_set_name
+            )
         except AttributeError:
             util.warn(
                 "No 'character_set_name' can be detected with "
@@ -285,7 +285,9 @@ class MySQLDialect_mysqldb(MySQLDialect):
         else:
             return cset_name()
 
-    def get_isolation_level_values(self, dbapi_connection):
+    def get_isolation_level_values(
+        self, dbapi_conn: DBAPIConnection
+    ) -> tuple[IsolationLevel, ...]:
         return (
             "SERIALIZABLE",
             "READ UNCOMMITTED",
@@ -294,7 +296,9 @@ class MySQLDialect_mysqldb(MySQLDialect):
             "AUTOCOMMIT",
         )
 
-    def set_isolation_level(self, dbapi_connection, level):
+    def set_isolation_level(
+        self, dbapi_connection: DBAPIConnection, level: IsolationLevel
+    ) -> None:
         if level == "AUTOCOMMIT":
             dbapi_connection.autocommit(True)
         else:
index 46070848cb11743e5874366e2498ff8afa869b42..fe97672ad85791ffeb88d35dc9fe7fa53623fd7e 100644 (file)
@@ -5,7 +5,6 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 # mypy: ignore-errors
-
 from ... import exc
 from ...testing.provision import configure_follower
 from ...testing.provision import create_db
index 67cb4cdd766c40661d89d4532a23c17c7126bf3d..e754bb6fcfca59a5baa4f174926898640467be83 100644 (file)
@@ -4,8 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
-
 
 r"""
 
@@ -49,10 +47,26 @@ and targets 100% compatibility.   Most behavioral notes for MySQL-python apply
 to the pymysql driver as well.
 
 """  # noqa
+from __future__ import annotations
+
+from typing import Any
+from typing import Literal
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
 
 from .mysqldb import MySQLDialect_mysqldb
 from ...util import langhelpers
 
+if TYPE_CHECKING:
+
+    from ...engine.interfaces import ConnectArgsType
+    from ...engine.interfaces import DBAPIConnection
+    from ...engine.interfaces import DBAPICursor
+    from ...engine.interfaces import DBAPIModule
+    from ...engine.interfaces import PoolProxiedConnection
+    from ...engine.url import URL
+
 
 class MySQLDialect_pymysql(MySQLDialect_mysqldb):
     driver = "pymysql"
@@ -61,7 +75,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
     description_encoding = None
 
     @langhelpers.memoized_property
-    def supports_server_side_cursors(self):
+    def supports_server_side_cursors(self) -> bool:  # type: ignore[override]
         try:
             cursors = __import__("pymysql.cursors").cursors
             self._sscursor = cursors.SSCursor
@@ -70,11 +84,11 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
             return False
 
     @classmethod
-    def import_dbapi(cls):
+    def import_dbapi(cls) -> DBAPIModule:
         return __import__("pymysql")
 
     @langhelpers.memoized_property
-    def _send_false_to_ping(self):
+    def _send_false_to_ping(self) -> bool:
         """determine if pymysql has deprecated, changed the default of,
         or removed the 'reconnect' argument of connection.ping().
 
@@ -101,7 +115,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
                     not insp.defaults or insp.defaults[0] is not False
                 )
 
-    def do_ping(self, dbapi_connection):
+    def do_ping(self, dbapi_connection: DBAPIConnection) -> Literal[True]:  # type: ignore # noqa: E501
         if self._send_false_to_ping:
             dbapi_connection.ping(False)
         else:
@@ -109,17 +123,24 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
 
         return True
 
-    def create_connect_args(self, url, _translate_args=None):
+    def create_connect_args(
+        self, url: URL, _translate_args: Optional[dict[str, Any]] = None
+    ) -> ConnectArgsType:
         if _translate_args is None:
             _translate_args = dict(username="user")
         return super().create_connect_args(
             url, _translate_args=_translate_args
         )
 
-    def is_disconnect(self, e, connection, cursor):
+    def is_disconnect(
+        self,
+        e: DBAPIModule.Error,
+        connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
+        cursor: Optional[DBAPICursor],
+    ) -> bool:
         if super().is_disconnect(e, connection, cursor):
             return True
-        elif isinstance(e, self.dbapi.Error):
+        elif isinstance(e, self.loaded_dbapi.Error):
             str_e = str(e).lower()
             return (
                 "already closed" in str_e or "connection was killed" in str_e
@@ -127,7 +148,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
         else:
             return False
 
-    def _extract_error_code(self, exception):
+    def _extract_error_code(self, exception: BaseException) -> Any:
         if isinstance(exception.args[0], Exception):
             exception = exception.args[0]
         return exception.args[0]
index 6d44bd3837067715424398a5ebcbf7bcb0d55569..86b19bd84de85c4be4839c6e49b71502987e44a9 100644 (file)
@@ -4,12 +4,10 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 
 r"""
 
-
 .. dialect:: mysql+pyodbc
     :name: PyODBC
     :dbapi: pyodbc
@@ -44,8 +42,15 @@ Pass through exact pyodbc connection string::
     connection_uri = "mysql+pyodbc:///?odbc_connect=%s" % params
 
 """  # noqa
+from __future__ import annotations
 
+import datetime
 import re
+from typing import Any
+from typing import Callable
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
 
 from .base import MySQLDialect
 from .base import MySQLExecutionContext
@@ -55,23 +60,31 @@ from ... import util
 from ...connectors.pyodbc import PyODBCConnector
 from ...sql.sqltypes import Time
 
+if TYPE_CHECKING:
+    from ...engine import Connection
+    from ...engine.interfaces import DBAPIConnection
+    from ...engine.interfaces import Dialect
+    from ...sql.type_api import _ResultProcessorType
+
 
 class _pyodbcTIME(TIME):
-    def result_processor(self, dialect, coltype):
-        def process(value):
+    def result_processor(
+        self, dialect: Dialect, coltype: object
+    ) -> _ResultProcessorType[datetime.time]:
+        def process(value: Any) -> Union[datetime.time, None]:
             # pyodbc returns a datetime.time object; no need to convert
-            return value
+            return value  # type: ignore[no-any-return]
 
         return process
 
 
 class MySQLExecutionContext_pyodbc(MySQLExecutionContext):
-    def get_lastrowid(self):
+    def get_lastrowid(self) -> int:
         cursor = self.create_cursor()
         cursor.execute("SELECT LAST_INSERT_ID()")
-        lastrowid = cursor.fetchone()[0]
+        lastrowid = cursor.fetchone()[0]  # type: ignore[index]
         cursor.close()
-        return lastrowid
+        return lastrowid  # type: ignore[no-any-return]
 
 
 class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
@@ -82,7 +95,7 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
 
     pyodbc_driver_name = "MySQL"
 
-    def _detect_charset(self, connection):
+    def _detect_charset(self, connection: Connection) -> str:
         """Sniff out the character set in use for connection results."""
 
         # Prefer 'character_set_results' for the current connection over the
@@ -107,21 +120,25 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
         )
         return "latin1"
 
-    def _get_server_version_info(self, connection):
+    def _get_server_version_info(
+        self, connection: Connection
+    ) -> tuple[int, ...]:
         return MySQLDialect._get_server_version_info(self, connection)
 
-    def _extract_error_code(self, exception):
+    def _extract_error_code(self, exception: BaseException) -> Optional[int]:
         m = re.compile(r"\((\d+)\)").search(str(exception.args))
-        c = m.group(1)
+        if m is None:
+            return None
+        c: Optional[str] = m.group(1)
         if c:
             return int(c)
         else:
             return None
 
-    def on_connect(self):
+    def on_connect(self) -> Callable[[DBAPIConnection], None]:
         super_ = super().on_connect()
 
-        def on_connect(conn):
+        def on_connect(conn: DBAPIConnection) -> None:
             if super_ is not None:
                 super_(conn)
 
index d62390bb8457d63da4eb3e02fff96bc1e499001b..127667aae9ca11d432a09e570d35ed3d9123097c 100644 (file)
@@ -4,43 +4,59 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
-
+from __future__ import annotations
 
 import re
+from typing import Any
+from typing import Callable
+from typing import Literal
+from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import TYPE_CHECKING
+from typing import Union
 
 from .enumerated import ENUM
 from .enumerated import SET
 from .types import DATETIME
 from .types import TIME
 from .types import TIMESTAMP
-from ... import log
 from ... import types as sqltypes
 from ... import util
 
+if TYPE_CHECKING:
+    from .base import MySQLDialect
+    from .base import MySQLIdentifierPreparer
+    from ...engine.interfaces import ReflectedColumn
+
 
 class ReflectedState:
     """Stores raw information about a SHOW CREATE TABLE statement."""
 
-    def __init__(self):
-        self.columns = []
-        self.table_options = {}
-        self.table_name = None
-        self.keys = []
-        self.fk_constraints = []
-        self.ck_constraints = []
+    charset: Optional[str]
+
+    def __init__(self) -> None:
+        self.columns: list[ReflectedColumn] = []
+        self.table_options: dict[str, str] = {}
+        self.table_name: Optional[str] = None
+        self.keys: list[dict[str, Any]] = []
+        self.fk_constraints: list[dict[str, Any]] = []
+        self.ck_constraints: list[dict[str, Any]] = []
 
 
-@log.class_logger
 class MySQLTableDefinitionParser:
     """Parses the results of a SHOW CREATE TABLE statement."""
 
-    def __init__(self, dialect, preparer):
+    def __init__(
+        self, dialect: MySQLDialect, preparer: MySQLIdentifierPreparer
+    ):
         self.dialect = dialect
         self.preparer = preparer
         self._prep_regexes()
 
-    def parse(self, show_create, charset):
+    def parse(
+        self, show_create: str, charset: Optional[str]
+    ) -> ReflectedState:
         state = ReflectedState()
         state.charset = charset
         for line in re.split(r"\r?\n", show_create):
@@ -65,11 +81,11 @@ class MySQLTableDefinitionParser:
                 if type_ is None:
                     util.warn("Unknown schema content: %r" % line)
                 elif type_ == "key":
-                    state.keys.append(spec)
+                    state.keys.append(spec)  # type: ignore[arg-type]
                 elif type_ == "fk_constraint":
-                    state.fk_constraints.append(spec)
+                    state.fk_constraints.append(spec)  # type: ignore[arg-type]
                 elif type_ == "ck_constraint":
-                    state.ck_constraints.append(spec)
+                    state.ck_constraints.append(spec)  # type: ignore[arg-type]
                 else:
                     pass
         return state
@@ -77,7 +93,13 @@ class MySQLTableDefinitionParser:
     def _check_view(self, sql: str) -> bool:
         return bool(self._re_is_view.match(sql))
 
-    def _parse_constraints(self, line):
+    def _parse_constraints(self, line: str) -> Union[
+        tuple[None, str],
+        tuple[Literal["partition"], str],
+        tuple[
+            Literal["ck_constraint", "fk_constraint", "key"], dict[str, str]
+        ],
+    ]:
         """Parse a KEY or CONSTRAINT line.
 
         :param line: A line of SHOW CREATE TABLE output
@@ -127,7 +149,7 @@ class MySQLTableDefinitionParser:
         # No match.
         return (None, line)
 
-    def _parse_table_name(self, line, state):
+    def _parse_table_name(self, line: str, state: ReflectedState) -> None:
         """Extract the table name.
 
         :param line: The first line of SHOW CREATE TABLE
@@ -138,7 +160,7 @@ class MySQLTableDefinitionParser:
         if m:
             state.table_name = cleanup(m.group("name"))
 
-    def _parse_table_options(self, line, state):
+    def _parse_table_options(self, line: str, state: ReflectedState) -> None:
         """Build a dictionary of all reflected table-level options.
 
         :param line: The final line of SHOW CREATE TABLE output.
@@ -164,7 +186,9 @@ class MySQLTableDefinitionParser:
         for opt, val in options.items():
             state.table_options["%s_%s" % (self.dialect.name, opt)] = val
 
-    def _parse_partition_options(self, line, state):
+    def _parse_partition_options(
+        self, line: str, state: ReflectedState
+    ) -> None:
         options = {}
         new_line = line[:]
 
@@ -220,7 +244,7 @@ class MySQLTableDefinitionParser:
             else:
                 state.table_options["%s_%s" % (self.dialect.name, opt)] = val
 
-    def _parse_column(self, line, state):
+    def _parse_column(self, line: str, state: ReflectedState) -> None:
         """Extract column details.
 
         Falls back to a 'minimal support' variant if full parse fails.
@@ -283,7 +307,7 @@ class MySQLTableDefinitionParser:
 
         type_instance = col_type(*type_args, **type_kw)
 
-        col_kw = {}
+        col_kw: dict[str, Any] = {}
 
         # NOT NULL
         col_kw["nullable"] = True
@@ -324,9 +348,13 @@ class MySQLTableDefinitionParser:
             name=name, type=type_instance, default=default, comment=comment
         )
         col_d.update(col_kw)
-        state.columns.append(col_d)
+        state.columns.append(col_d)  # type: ignore[arg-type]
 
-    def _describe_to_create(self, table_name, columns):
+    def _describe_to_create(
+        self,
+        table_name: str,
+        columns: Sequence[tuple[str, str, str, str, str, str]],
+    ) -> str:
         """Re-format DESCRIBE output as a SHOW CREATE TABLE string.
 
         DESCRIBE is a much simpler reflection and is sufficient for
@@ -379,7 +407,9 @@ class MySQLTableDefinitionParser:
             ]
         )
 
-    def _parse_keyexprs(self, identifiers):
+    def _parse_keyexprs(
+        self, identifiers: str
+    ) -> list[tuple[str, Optional[int], str]]:
         """Unpack '"col"(2),"col" ASC'-ish strings into components."""
 
         return [
@@ -389,11 +419,12 @@ class MySQLTableDefinitionParser:
             )
         ]
 
-    def _prep_regexes(self):
+    def _prep_regexes(self) -> None:
         """Pre-compile regular expressions."""
 
-        self._re_columns = []
-        self._pr_options = []
+        self._pr_options: list[
+            tuple[re.Pattern[Any], Optional[Callable[[str], str]]]
+        ] = []
 
         _final = self.preparer.final_quote
 
@@ -582,21 +613,21 @@ class MySQLTableDefinitionParser:
 
     _optional_equals = r"(?:\s*(?:=\s*)|\s+)"
 
-    def _add_option_string(self, directive):
+    def _add_option_string(self, directive: str) -> None:
         regex = r"(?P<directive>%s)%s" r"'(?P<val>(?:[^']|'')*?)'(?!')" % (
             re.escape(directive),
             self._optional_equals,
         )
         self._pr_options.append(_pr_compile(regex, cleanup_text))
 
-    def _add_option_word(self, directive):
+    def _add_option_word(self, directive: str) -> None:
         regex = r"(?P<directive>%s)%s" r"(?P<val>\w+)" % (
             re.escape(directive),
             self._optional_equals,
         )
         self._pr_options.append(_pr_compile(regex))
 
-    def _add_partition_option_word(self, directive):
+    def _add_partition_option_word(self, directive: str) -> None:
         if directive == "PARTITION BY" or directive == "SUBPARTITION BY":
             regex = r"(?<!\S)(?P<directive>%s)%s" r"(?P<val>\w+.*)" % (
                 re.escape(directive),
@@ -611,7 +642,7 @@ class MySQLTableDefinitionParser:
             regex = r"(?<!\S)(?P<directive>%s)(?!\S)" % (re.escape(directive),)
         self._pr_options.append(_pr_compile(regex))
 
-    def _add_option_regex(self, directive, regex):
+    def _add_option_regex(self, directive: str, regex: str) -> None:
         regex = r"(?P<directive>%s)%s" r"(?P<val>%s)" % (
             re.escape(directive),
             self._optional_equals,
@@ -629,21 +660,35 @@ _options_of_type_string = (
 )
 
 
-def _pr_compile(regex, cleanup=None):
+@overload
+def _pr_compile(
+    regex: str, cleanup: Callable[[str], str]
+) -> tuple[re.Pattern[Any], Callable[[str], str]]: ...
+
+
+@overload
+def _pr_compile(
+    regex: str, cleanup: None = None
+) -> tuple[re.Pattern[Any], None]: ...
+
+
+def _pr_compile(
+    regex: str, cleanup: Optional[Callable[[str], str]] = None
+) -> tuple[re.Pattern[Any], Optional[Callable[[str], str]]]:
     """Prepare a 2-tuple of compiled regex and callable."""
 
     return (_re_compile(regex), cleanup)
 
 
-def _re_compile(regex):
+def _re_compile(regex: str) -> re.Pattern[Any]:
     """Compile a string to regex, I and UNICODE."""
 
     return re.compile(regex, re.I | re.UNICODE)
 
 
-def _strip_values(values):
+def _strip_values(values: Sequence[str]) -> list[str]:
     "Strip reflected values quotes"
-    strip_values = []
+    strip_values: list[str] = []
     for a in values:
         if a[0:1] == '"' or a[0:1] == "'":
             # strip enclosing quotes and unquote interior
@@ -655,7 +700,9 @@ def _strip_values(values):
 def cleanup_text(raw_text: str) -> str:
     if "\\" in raw_text:
         raw_text = re.sub(
-            _control_char_regexp, lambda s: _control_char_map[s[0]], raw_text
+            _control_char_regexp,
+            lambda s: _control_char_map[s[0]],  # type: ignore[index]
+            raw_text,
         )
     return raw_text.replace("''", "'")
 
index 34fecf42724bdbe796425aed50db7aec2e2ba635..ff526394a695a57cc98840f4334b085031165587 100644 (file)
@@ -11,7 +11,6 @@
 # https://mariadb.com/kb/en/reserved-words/
 # includes: Reserved Words, Oracle Mode (separate set unioned)
 # excludes: Exceptions, Function Names
-# mypy: ignore-errors
 
 RESERVED_WORDS_MARIADB = {
     "accessible",
index 015d51a105816ff0e26250029830194a8e7cde92..8621f5b9864dd4dd62d2cf1c7caf178367483fb0 100644 (file)
@@ -4,15 +4,26 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
-
+from __future__ import annotations
 
 import datetime
+import decimal
+from typing import Any
+from typing import Iterable
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
 
 from ... import exc
 from ... import util
 from ...sql import sqltypes
 
+if TYPE_CHECKING:
+    from .base import MySQLDialect
+    from ...engine.interfaces import Dialect
+    from ...sql.type_api import _BindProcessorType
+    from ...sql.type_api import _ResultProcessorType
+
 
 class _NumericCommonType:
     """Base for MySQL numeric types.
@@ -22,24 +33,36 @@ class _NumericCommonType:
 
     """
 
-    def __init__(self, unsigned=False, zerofill=False, **kw):
+    def __init__(
+        self, unsigned: bool = False, zerofill: bool = False, **kw: Any
+    ):
         self.unsigned = unsigned
         self.zerofill = zerofill
         super().__init__(**kw)
 
 
-class _NumericType(_NumericCommonType, sqltypes.Numeric):
+class _NumericType(
+    _NumericCommonType, sqltypes.Numeric[Union[decimal.Decimal, float]]
+):
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return util.generic_repr(
             self,
             to_inspect=[_NumericType, _NumericCommonType, sqltypes.Numeric],
         )
 
 
-class _FloatType(_NumericCommonType, sqltypes.Float):
+class _FloatType(
+    _NumericCommonType, sqltypes.Float[Union[decimal.Decimal, float]]
+):
 
-    def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
+    def __init__(
+        self,
+        precision: Optional[int] = None,
+        scale: Optional[int] = None,
+        asdecimal: bool = True,
+        **kw: Any,
+    ):
         if isinstance(self, (REAL, DOUBLE)) and (
             (precision is None and scale is not None)
             or (precision is not None and scale is None)
@@ -51,18 +74,18 @@ class _FloatType(_NumericCommonType, sqltypes.Float):
         super().__init__(precision=precision, asdecimal=asdecimal, **kw)
         self.scale = scale
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return util.generic_repr(
             self, to_inspect=[_FloatType, _NumericCommonType, sqltypes.Float]
         )
 
 
 class _IntegerType(_NumericCommonType, sqltypes.Integer):
-    def __init__(self, display_width=None, **kw):
+    def __init__(self, display_width: Optional[int] = None, **kw: Any):
         self.display_width = display_width
         super().__init__(**kw)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return util.generic_repr(
             self,
             to_inspect=[_IntegerType, _NumericCommonType, sqltypes.Integer],
@@ -74,13 +97,13 @@ class _StringType(sqltypes.String):
 
     def __init__(
         self,
-        charset=None,
-        collation=None,
-        ascii=False,  # noqa
-        binary=False,
-        unicode=False,
-        national=False,
-        **kw,
+        charset: Optional[str] = None,
+        collation: Optional[str] = None,
+        ascii: bool = False,  # noqa
+        binary: bool = False,
+        unicode: bool = False,
+        national: bool = False,
+        **kw: Any,
     ):
         self.charset = charset
 
@@ -93,25 +116,33 @@ class _StringType(sqltypes.String):
         self.national = national
         super().__init__(**kw)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return util.generic_repr(
             self, to_inspect=[_StringType, sqltypes.String]
         )
 
 
-class _MatchType(sqltypes.Float, sqltypes.MatchType):
-    def __init__(self, **kw):
+class _MatchType(
+    sqltypes.Float[Union[decimal.Decimal, float]], sqltypes.MatchType
+):
+    def __init__(self, **kw: Any):
         # TODO: float arguments?
-        sqltypes.Float.__init__(self)
+        sqltypes.Float.__init__(self)  # type: ignore[arg-type]
         sqltypes.MatchType.__init__(self)
 
 
-class NUMERIC(_NumericType, sqltypes.NUMERIC):
+class NUMERIC(_NumericType, sqltypes.NUMERIC[Union[decimal.Decimal, float]]):
     """MySQL NUMERIC type."""
 
     __visit_name__ = "NUMERIC"
 
-    def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
+    def __init__(
+        self,
+        precision: Optional[int] = None,
+        scale: Optional[int] = None,
+        asdecimal: bool = True,
+        **kw: Any,
+    ):
         """Construct a NUMERIC.
 
         :param precision: Total digits in this number.  If scale and precision
@@ -132,12 +163,18 @@ class NUMERIC(_NumericType, sqltypes.NUMERIC):
         )
 
 
-class DECIMAL(_NumericType, sqltypes.DECIMAL):
+class DECIMAL(_NumericType, sqltypes.DECIMAL[Union[decimal.Decimal, float]]):
     """MySQL DECIMAL type."""
 
     __visit_name__ = "DECIMAL"
 
-    def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
+    def __init__(
+        self,
+        precision: Optional[int] = None,
+        scale: Optional[int] = None,
+        asdecimal: bool = True,
+        **kw: Any,
+    ):
         """Construct a DECIMAL.
 
         :param precision: Total digits in this number.  If scale and precision
@@ -158,12 +195,18 @@ class DECIMAL(_NumericType, sqltypes.DECIMAL):
         )
 
 
-class DOUBLE(_FloatType, sqltypes.DOUBLE):
+class DOUBLE(_FloatType, sqltypes.DOUBLE[Union[decimal.Decimal, float]]):
     """MySQL DOUBLE type."""
 
     __visit_name__ = "DOUBLE"
 
-    def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
+    def __init__(
+        self,
+        precision: Optional[int] = None,
+        scale: Optional[int] = None,
+        asdecimal: bool = True,
+        **kw: Any,
+    ):
         """Construct a DOUBLE.
 
         .. note::
@@ -192,12 +235,18 @@ class DOUBLE(_FloatType, sqltypes.DOUBLE):
         )
 
 
-class REAL(_FloatType, sqltypes.REAL):
+class REAL(_FloatType, sqltypes.REAL[Union[decimal.Decimal, float]]):
     """MySQL REAL type."""
 
     __visit_name__ = "REAL"
 
-    def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
+    def __init__(
+        self,
+        precision: Optional[int] = None,
+        scale: Optional[int] = None,
+        asdecimal: bool = True,
+        **kw: Any,
+    ):
         """Construct a REAL.
 
         .. note::
@@ -226,12 +275,18 @@ class REAL(_FloatType, sqltypes.REAL):
         )
 
 
-class FLOAT(_FloatType, sqltypes.FLOAT):
+class FLOAT(_FloatType, sqltypes.FLOAT[Union[decimal.Decimal, float]]):
     """MySQL FLOAT type."""
 
     __visit_name__ = "FLOAT"
 
-    def __init__(self, precision=None, scale=None, asdecimal=False, **kw):
+    def __init__(
+        self,
+        precision: Optional[int] = None,
+        scale: Optional[int] = None,
+        asdecimal: bool = False,
+        **kw: Any,
+    ):
         """Construct a FLOAT.
 
         :param precision: Total digits in this number.  If scale and precision
@@ -251,7 +306,9 @@ class FLOAT(_FloatType, sqltypes.FLOAT):
             precision=precision, scale=scale, asdecimal=asdecimal, **kw
         )
 
-    def bind_processor(self, dialect):
+    def bind_processor(
+        self, dialect: Dialect
+    ) -> Optional[_BindProcessorType[Union[decimal.Decimal, float]]]:
         return None
 
 
@@ -260,7 +317,7 @@ class INTEGER(_IntegerType, sqltypes.INTEGER):
 
     __visit_name__ = "INTEGER"
 
-    def __init__(self, display_width=None, **kw):
+    def __init__(self, display_width: Optional[int] = None, **kw: Any):
         """Construct an INTEGER.
 
         :param display_width: Optional, maximum display width for this number.
@@ -281,7 +338,7 @@ class BIGINT(_IntegerType, sqltypes.BIGINT):
 
     __visit_name__ = "BIGINT"
 
-    def __init__(self, display_width=None, **kw):
+    def __init__(self, display_width: Optional[int] = None, **kw: Any):
         """Construct a BIGINTEGER.
 
         :param display_width: Optional, maximum display width for this number.
@@ -302,7 +359,7 @@ class MEDIUMINT(_IntegerType):
 
     __visit_name__ = "MEDIUMINT"
 
-    def __init__(self, display_width=None, **kw):
+    def __init__(self, display_width: Optional[int] = None, **kw: Any):
         """Construct a MEDIUMINTEGER
 
         :param display_width: Optional, maximum display width for this number.
@@ -323,7 +380,7 @@ class TINYINT(_IntegerType):
 
     __visit_name__ = "TINYINT"
 
-    def __init__(self, display_width=None, **kw):
+    def __init__(self, display_width: Optional[int] = None, **kw: Any):
         """Construct a TINYINT.
 
         :param display_width: Optional, maximum display width for this number.
@@ -344,7 +401,7 @@ class SMALLINT(_IntegerType, sqltypes.SMALLINT):
 
     __visit_name__ = "SMALLINT"
 
-    def __init__(self, display_width=None, **kw):
+    def __init__(self, display_width: Optional[int] = None, **kw: Any):
         """Construct a SMALLINTEGER.
 
         :param display_width: Optional, maximum display width for this number.
@@ -360,7 +417,7 @@ class SMALLINT(_IntegerType, sqltypes.SMALLINT):
         super().__init__(display_width=display_width, **kw)
 
 
-class BIT(sqltypes.TypeEngine):
+class BIT(sqltypes.TypeEngine[Any]):
     """MySQL BIT type.
 
     This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater
@@ -371,7 +428,7 @@ class BIT(sqltypes.TypeEngine):
 
     __visit_name__ = "BIT"
 
-    def __init__(self, length=None):
+    def __init__(self, length: Optional[int] = None):
         """Construct a BIT.
 
         :param length: Optional, number of bits.
@@ -379,19 +436,19 @@ class BIT(sqltypes.TypeEngine):
         """
         self.length = length
 
-    def result_processor(self, dialect, coltype):
+    def result_processor(
+        self, dialect: MySQLDialect, coltype: object  # type: ignore[override]
+    ) -> Optional[_ResultProcessorType[Any]]:
         """Convert a MySQL's 64 bit, variable length binary string to a
         long."""
 
         if dialect.supports_native_bit:
             return None
 
-        def process(value):
+        def process(value: Optional[Iterable[int]]) -> Optional[int]:
             if value is not None:
                 v = 0
                 for i in value:
-                    if not isinstance(i, int):
-                        i = ord(i)  # convert byte to int on Python 2
                     v = v << 8 | i
                 return v
             return value
@@ -404,7 +461,7 @@ class TIME(sqltypes.TIME):
 
     __visit_name__ = "TIME"
 
-    def __init__(self, timezone=False, fsp=None):
+    def __init__(self, timezone: bool = False, fsp: Optional[int] = None):
         """Construct a MySQL TIME type.
 
         :param timezone: not used by the MySQL dialect.
@@ -423,10 +480,12 @@ class TIME(sqltypes.TIME):
         super().__init__(timezone=timezone)
         self.fsp = fsp
 
-    def result_processor(self, dialect, coltype):
+    def result_processor(
+        self, dialect: Dialect, coltype: object
+    ) -> _ResultProcessorType[datetime.time]:
         time = datetime.time
 
-        def process(value):
+        def process(value: Any) -> Optional[datetime.time]:
             # convert from a timedelta value
             if value is not None:
                 microseconds = value.microseconds
@@ -449,7 +508,7 @@ class TIMESTAMP(sqltypes.TIMESTAMP):
 
     __visit_name__ = "TIMESTAMP"
 
-    def __init__(self, timezone=False, fsp=None):
+    def __init__(self, timezone: bool = False, fsp: Optional[int] = None):
         """Construct a MySQL TIMESTAMP type.
 
         :param timezone: not used by the MySQL dialect.
@@ -474,7 +533,7 @@ class DATETIME(sqltypes.DATETIME):
 
     __visit_name__ = "DATETIME"
 
-    def __init__(self, timezone=False, fsp=None):
+    def __init__(self, timezone: bool = False, fsp: Optional[int] = None):
         """Construct a MySQL DATETIME type.
 
         :param timezone: not used by the MySQL dialect.
@@ -494,12 +553,12 @@ class DATETIME(sqltypes.DATETIME):
         self.fsp = fsp
 
 
-class YEAR(sqltypes.TypeEngine):
+class YEAR(sqltypes.TypeEngine[Any]):
     """MySQL YEAR type, for single byte storage of years 1901-2155."""
 
     __visit_name__ = "YEAR"
 
-    def __init__(self, display_width=None):
+    def __init__(self, display_width: Optional[int] = None):
         self.display_width = display_width
 
 
@@ -508,7 +567,7 @@ class TEXT(_StringType, sqltypes.TEXT):
 
     __visit_name__ = "TEXT"
 
-    def __init__(self, length=None, **kw):
+    def __init__(self, length: Optional[int] = None, **kw: Any):
         """Construct a TEXT.
 
         :param length: Optional, if provided the server may optimize storage
@@ -544,7 +603,7 @@ class TINYTEXT(_StringType):
 
     __visit_name__ = "TINYTEXT"
 
-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any):
         """Construct a TINYTEXT.
 
         :param charset: Optional, a column-level character set for this string
@@ -577,7 +636,7 @@ class MEDIUMTEXT(_StringType):
 
     __visit_name__ = "MEDIUMTEXT"
 
-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any):
         """Construct a MEDIUMTEXT.
 
         :param charset: Optional, a column-level character set for this string
@@ -609,7 +668,7 @@ class LONGTEXT(_StringType):
 
     __visit_name__ = "LONGTEXT"
 
-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any):
         """Construct a LONGTEXT.
 
         :param charset: Optional, a column-level character set for this string
@@ -641,7 +700,7 @@ class VARCHAR(_StringType, sqltypes.VARCHAR):
 
     __visit_name__ = "VARCHAR"
 
-    def __init__(self, length=None, **kwargs):
+    def __init__(self, length: Optional[int] = None, **kwargs: Any) -> None:
         """Construct a VARCHAR.
 
         :param charset: Optional, a column-level character set for this string
@@ -673,7 +732,7 @@ class CHAR(_StringType, sqltypes.CHAR):
 
     __visit_name__ = "CHAR"
 
-    def __init__(self, length=None, **kwargs):
+    def __init__(self, length: Optional[int] = None, **kwargs: Any):
         """Construct a CHAR.
 
         :param length: Maximum data length, in characters.
@@ -689,7 +748,7 @@ class CHAR(_StringType, sqltypes.CHAR):
         super().__init__(length=length, **kwargs)
 
     @classmethod
-    def _adapt_string_for_cast(cls, type_):
+    def _adapt_string_for_cast(cls, type_: sqltypes.String) -> sqltypes.CHAR:
         # copy the given string type into a CHAR
         # for the purposes of rendering a CAST expression
         type_ = sqltypes.to_instance(type_)
@@ -718,7 +777,7 @@ class NVARCHAR(_StringType, sqltypes.NVARCHAR):
 
     __visit_name__ = "NVARCHAR"
 
-    def __init__(self, length=None, **kwargs):
+    def __init__(self, length: Optional[int] = None, **kwargs: Any):
         """Construct an NVARCHAR.
 
         :param length: Maximum data length, in characters.
@@ -744,7 +803,7 @@ class NCHAR(_StringType, sqltypes.NCHAR):
 
     __visit_name__ = "NCHAR"
 
-    def __init__(self, length=None, **kwargs):
+    def __init__(self, length: Optional[int] = None, **kwargs: Any):
         """Construct an NCHAR.
 
         :param length: Maximum data length, in characters.
index 8b704d2a1b749856b961a57b8637a1e958893f14..af087a9eb86909254f63dce89b4aa9310786e3a0 100644 (file)
@@ -86,6 +86,7 @@ if typing.TYPE_CHECKING:
     from .interfaces import _ParamStyle
     from .interfaces import ConnectArgsType
     from .interfaces import DBAPIConnection
+    from .interfaces import DBAPIModule
     from .interfaces import IsolationLevel
     from .row import Row
     from .url import URL
@@ -431,7 +432,7 @@ class DefaultDialect(Dialect):
     delete_executemany_returning = False
 
     @util.memoized_property
-    def loaded_dbapi(self) -> ModuleType:
+    def loaded_dbapi(self) -> DBAPIModule:
         if self.dbapi is None:
             raise exc.InvalidRequestError(
                 f"Dialect {self} does not have a Python DBAPI established "
@@ -563,7 +564,7 @@ class DefaultDialect(Dialect):
                 % (self.label_length, self.max_identifier_length)
             )
 
-    def on_connect(self) -> Optional[Callable[[Any], Any]]:
+    def on_connect(self) -> Optional[Callable[[Any], None]]:
         # inherits the docstring from interfaces.Dialect.on_connect
         return None
 
@@ -952,7 +953,7 @@ class DefaultDialect(Dialect):
 
     def is_disconnect(
         self,
-        e: Exception,
+        e: DBAPIModule.Error,
         connection: Union[
             pool.PoolProxiedConnection, interfaces.DBAPIConnection, None
         ],
@@ -1057,7 +1058,7 @@ class DefaultDialect(Dialect):
             name = name_upper
         return name
 
-    def get_driver_connection(self, connection):
+    def get_driver_connection(self, connection: DBAPIConnection) -> Any:
         return connection
 
     def _overrides_default(self, method):
index 3a949dbbad289538060fca268b0de10a1a0ac14d..966904ba5e551a81f451ec7c0735653270adfaf4 100644 (file)
@@ -10,7 +10,6 @@
 from __future__ import annotations
 
 from enum import Enum
-from types import ModuleType
 from typing import Any
 from typing import Awaitable
 from typing import Callable
@@ -36,7 +35,7 @@ from typing import Union
 from .. import util
 from ..event import EventTarget
 from ..pool import Pool
-from ..pool import PoolProxiedConnection
+from ..pool import PoolProxiedConnection as PoolProxiedConnection
 from ..sql.compiler import Compiled as Compiled
 from ..sql.compiler import Compiled  # noqa
 from ..sql.compiler import TypeCompiler as TypeCompiler
@@ -51,6 +50,7 @@ if TYPE_CHECKING:
     from .base import Engine
     from .cursor import CursorResult
     from .url import URL
+    from ..connectors.asyncio import AsyncIODBAPIConnection
     from ..event import _ListenerFnType
     from ..event import dispatcher
     from ..exc import StatementError
@@ -70,6 +70,7 @@ if TYPE_CHECKING:
     from ..sql.sqltypes import Integer
     from ..sql.type_api import _TypeMemoDict
     from ..sql.type_api import TypeEngine
+    from ..util.langhelpers import generic_fn_descriptor
 
 ConnectArgsType = Tuple[Sequence[str], MutableMapping[str, Any]]
 
@@ -106,6 +107,22 @@ class ExecuteStyle(Enum):
     """
 
 
+class DBAPIModule(Protocol):
+    class Error(Exception):
+        def __getattr__(self, key: str) -> Any: ...
+
+    class OperationalError(Error):
+        pass
+
+    class InterfaceError(Error):
+        pass
+
+    class IntegrityError(Error):
+        pass
+
+    def __getattr__(self, key: str) -> Any: ...
+
+
 class DBAPIConnection(Protocol):
     """protocol representing a :pep:`249` database connection.
 
@@ -126,7 +143,9 @@ class DBAPIConnection(Protocol):
 
     def rollback(self) -> None: ...
 
-    autocommit: bool
+    def __getattr__(self, key: str) -> Any: ...
+
+    def __setattr__(self, key: str, value: Any) -> None: ...
 
 
 class DBAPIType(Protocol):
@@ -653,7 +672,7 @@ class Dialect(EventTarget):
 
     dialect_description: str
 
-    dbapi: Optional[ModuleType]
+    dbapi: Optional[DBAPIModule]
     """A reference to the DBAPI module object itself.
 
     SQLAlchemy dialects import DBAPI modules using the classmethod
@@ -677,7 +696,7 @@ class Dialect(EventTarget):
     """
 
     @util.non_memoized_property
-    def loaded_dbapi(self) -> ModuleType:
+    def loaded_dbapi(self) -> DBAPIModule:
         """same as .dbapi, but is never None; will raise an error if no
         DBAPI was set up.
 
@@ -781,7 +800,7 @@ class Dialect(EventTarget):
     """The maximum length of constraint names if different from
     ``max_identifier_length``."""
 
-    supports_server_side_cursors: bool
+    supports_server_side_cursors: Union[generic_fn_descriptor[bool], bool]
     """indicates if the dialect supports server side cursors"""
 
     server_side_cursors: bool
@@ -1234,7 +1253,7 @@ class Dialect(EventTarget):
         raise NotImplementedError()
 
     @classmethod
-    def import_dbapi(cls) -> ModuleType:
+    def import_dbapi(cls) -> DBAPIModule:
         """Import the DBAPI module that is used by this dialect.
 
         The Python module object returned here will be assigned as an
@@ -2202,7 +2221,7 @@ class Dialect(EventTarget):
 
     def is_disconnect(
         self,
-        e: Exception,
+        e: DBAPIModule.Error,
         connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
         cursor: Optional[DBAPICursor],
     ) -> bool:
@@ -2306,7 +2325,7 @@ class Dialect(EventTarget):
         """
         return self.on_connect()
 
-    def on_connect(self) -> Optional[Callable[[Any], Any]]:
+    def on_connect(self) -> Optional[Callable[[Any], None]]:
         """return a callable which sets up a newly created DBAPI connection.
 
         The callable should accept a single argument "conn" which is the
@@ -3356,7 +3375,7 @@ class AdaptedConnection:
 
     __slots__ = ("_connection",)
 
-    _connection: Any
+    _connection: AsyncIODBAPIConnection
 
     @property
     def driver_connection(self) -> Any:
index 39194dbad9fb2e4b83fbac3254d6b014bac018c3..7c051f12afc9e151d0e55f370c5289394d28e4d0 100644 (file)
@@ -1077,6 +1077,8 @@ class PoolProxiedConnection(ManagesConnection):
 
         def rollback(self) -> None: ...
 
+        def __getattr__(self, key: str) -> Any: ...
+
     @property
     def is_valid(self) -> bool:
         """Return True if this :class:`.PoolProxiedConnection` still refers
index b123acbff1461a48f42bb26437e6d8fd3095d407..1961623ab55144b63a9b53a03bf4c87ce2af5d57 100644 (file)
@@ -95,6 +95,7 @@ if typing.TYPE_CHECKING:
     from .base import Executable
     from .cache_key import CacheKey
     from .ddl import ExecutableDDLElement
+    from .dml import Delete
     from .dml import Insert
     from .dml import Update
     from .dml import UpdateBase
@@ -6180,7 +6181,9 @@ class SQLCompiler(Compiled):
             "criteria within UPDATE"
         )
 
-    def update_post_criteria_clause(self, update_stmt, **kw):
+    def update_post_criteria_clause(
+        self, update_stmt: Update, **kw: Any
+    ) -> Optional[str]:
         """provide a hook to override generation after the WHERE criteria
         in an UPDATE statement
 
@@ -6195,7 +6198,9 @@ class SQLCompiler(Compiled):
         else:
             return None
 
-    def delete_post_criteria_clause(self, delete_stmt, **kw):
+    def delete_post_criteria_clause(
+        self, delete_stmt: Delete, **kw: Any
+    ) -> Optional[str]:
         """provide a hook to override generation after the WHERE criteria
         in a DELETE statement
 
@@ -6881,7 +6886,7 @@ class DDLCompiler(Compiled):
         else:
             schema_name = None
 
-        index_name = self.preparer.format_index(index)
+        index_name: str = self.preparer.format_index(index)
 
         if schema_name:
             index_name = schema_name + "." + index_name
index 8748c7c7be818bd157fac7b0176dbba3358add30..5487a170eae33888de23009679b2974eb05f208c 100644 (file)
@@ -432,6 +432,8 @@ class _CreateDropBase(ExecutableDDLElement, Generic[_SI]):
 
     """
 
+    element: _SI
+
     def __init__(self, element: _SI) -> None:
         self.element = self.target = element
         self._ddl_if = getattr(element, "_ddl_if", None)
index 42dfe6110647f1636bbf3649a9ddc6429c69d422..1907845fc20ca733d60a1062486cb94bcc2dc05f 100644 (file)
@@ -82,6 +82,7 @@ from ..util.typing import Self
 from ..util.typing import TupleAny
 from ..util.typing import Unpack
 
+
 if typing.TYPE_CHECKING:
     from ._typing import _ByArgument
     from ._typing import _ColumnExpressionArgument
@@ -119,6 +120,7 @@ if typing.TYPE_CHECKING:
     from ..engine.interfaces import SchemaTranslateMapType
     from ..engine.result import Result
 
+
 _NUMERIC = Union[float, Decimal]
 _NUMBER = Union[float, int, Decimal]
 
@@ -2127,8 +2129,8 @@ class BindParameter(roles.InElementRole, KeyedColumnElement[_T]):
         else:
             return self
 
-    def _with_binary_element_type(self, type_):
-        c: Self = ClauseElement._clone(self)  # type: ignore[assignment]
+    def _with_binary_element_type(self, type_: TypeEngine[Any]) -> Self:
+        c: Self = ClauseElement._clone(self)
         c.type = type_
         return c
 
index 050f94fd8087685aed5f7b0329bd3d2619ea6b88..375cb26f13f6d05a7aef1948c90be3625c6a7d1c 100644 (file)
@@ -787,7 +787,7 @@ class FunctionAsBinary(BinaryExpression[Any]):
         self.type = sqltypes.BOOLEANTYPE
         self.negate = None
         self._is_implicitly_boolean = True
-        self.modifiers = {}
+        self.modifiers = util.immutabledict({})
 
     @property
     def left_expr(self) -> ColumnElement[Any]:
index 5692ddba3c7568cc3818dc41ea5e700d03e55daa..becd500d5d4971081e7ab8b4941133797abbb3e3 100644 (file)
@@ -12,7 +12,6 @@
 from __future__ import annotations
 
 from enum import Enum
-from types import ModuleType
 import typing
 from typing import Any
 from typing import Callable
@@ -58,6 +57,7 @@ if typing.TYPE_CHECKING:
     from .sqltypes import NUMERICTYPE as NUMERICTYPE  # noqa: F401
     from .sqltypes import STRINGTYPE as STRINGTYPE  # noqa: F401
     from .sqltypes import TABLEVALUE as TABLEVALUE  # noqa: F401
+    from ..engine.interfaces import DBAPIModule
     from ..engine.interfaces import Dialect
     from ..util.typing import GenericProtocol
 
@@ -612,7 +612,7 @@ class TypeEngine(Visitable, Generic[_T]):
 
         return x == y  # type: ignore[no-any-return]
 
-    def get_dbapi_type(self, dbapi: ModuleType) -> Optional[Any]:
+    def get_dbapi_type(self, dbapi: DBAPIModule) -> Optional[Any]:
         """Return the corresponding type object from the underlying DB-API, if
         any.
 
@@ -2263,7 +2263,7 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]):
         instance.__dict__.update(self.__dict__)
         return instance
 
-    def get_dbapi_type(self, dbapi: ModuleType) -> Optional[Any]:
+    def get_dbapi_type(self, dbapi: DBAPIModule) -> Optional[Any]:
         """Return the DBAPI type object represented by this
         :class:`.TypeDecorator`.
 
index 4365a9a7f08d49dc19dccc7db0ace6b7abe0e3fe..a5bafbe65d57da5f0156efd9d145e8d572c05f97 100644 (file)
@@ -46,7 +46,7 @@ Discussions = "https://github.com/sqlalchemy/sqlalchemy/discussions"
 asyncio = ["greenlet>=1"]
 mypy = [
     "mypy >= 1.7",
-    "types-greenlet >= 2"
+    "types-greenlet >= 2",
 ]
 mssql = ["pyodbc"]
 mssql-pymssql = ["pymssql"]
@@ -67,6 +67,7 @@ postgresql-psycopg2cffi = ["psycopg2cffi"]
 postgresql-psycopg = ["psycopg>=3.0.7,!=3.1.15"]
 postgresql-psycopgbinary = ["psycopg[binary]>=3.0.7,!=3.1.15"]
 pymysql = ["pymysql"]
+cymysql = ["cymysql"]
 aiomysql = [
     "greenlet>=1",  # same as ".[asyncio]" if this syntax were supported
     "aiomysql",