from typing import Optional
from typing import Protocol
from typing import Sequence
+from typing import TYPE_CHECKING
from ..engine import AdaptedConnection
-from ..engine.interfaces import _DBAPICursorDescription
-from ..engine.interfaces import _DBAPIMultiExecuteParams
-from ..engine.interfaces import _DBAPISingleExecuteParams
from ..util.concurrency import await_
-from ..util.typing import Self
+
+if TYPE_CHECKING:
+ from ..engine.interfaces import _DBAPICursorDescription
+ from ..engine.interfaces import _DBAPIMultiExecuteParams
+ from ..engine.interfaces import _DBAPISingleExecuteParams
+ from ..engine.interfaces import DBAPIModule
+ from ..util.typing import Self
class AsyncIODBAPIConnection(Protocol):
"""
- async def close(self) -> None: ...
+ # note that async DBAPIs don't agree on whether close() should be awaitable,
+ # so it is omitted here and picked up by the __getattr__ hook below
async def commit(self) -> None: ...
async def rollback(self) -> None: ...
+ def __getattr__(self, key: str) -> Any: ...
+
+ def __setattr__(self, key: str, value: Any) -> None: ...
+
class AsyncIODBAPICursor(Protocol):
"""protocol representing an async adapted version
def __aiter__(self) -> AsyncIterator[Any]: ...
+class AsyncAdapt_dbapi_module:
+ if TYPE_CHECKING:
+ Error = DBAPIModule.Error
+ OperationalError = DBAPIModule.OperationalError
+ InterfaceError = DBAPIModule.InterfaceError
+ IntegrityError = DBAPIModule.IntegrityError
+
+ def __getattr__(self, key: str) -> Any: ...
+
+
class AsyncAdapt_dbapi_cursor:
server_side = False
__slots__ = (
from __future__ import annotations
import re
-from types import ModuleType
import typing
from typing import Any
from typing import Dict
from ..sql.type_api import TypeEngine
if typing.TYPE_CHECKING:
+ from ..engine.interfaces import DBAPIModule
from ..engine.interfaces import IsolationLevel
# hold the desired driver name
pyodbc_driver_name: Optional[str] = None
- dbapi: ModuleType
-
def __init__(self, use_setinputsizes: bool = False, **kw: Any):
super().__init__(**kw)
if use_setinputsizes:
self.bind_typing = interfaces.BindTyping.SETINPUTSIZES
@classmethod
- def import_dbapi(cls) -> ModuleType:
+ def import_dbapi(cls) -> DBAPIModule:
return __import__("pyodbc")
def create_connect_args(self, url: URL) -> ConnectArgsType:
],
cursor: Optional[interfaces.DBAPICursor],
) -> bool:
- if isinstance(e, self.dbapi.ProgrammingError):
+ if isinstance(e, self.loaded_dbapi.ProgrammingError):
return "The cursor's connection has been closed." in str(
e
) or "Attempt to use a closed connection." in str(e)
from __future__ import annotations
+from typing import Any
from typing import Callable
from typing import Optional
from typing import Type
# hardcoded. if mysql / mariadb etc were third party dialects
# they would just publish all the entrypoints, which would actually
# look much nicer.
- module = __import__(
+ module: Any = __import__(
"sqlalchemy.dialects.mysql.mariadb"
).dialects.mysql.mariadb
return module.loader(driver) # type: ignore
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
r"""
.. dialect:: mysql+aiomysql
)
""" # noqa
+from __future__ import annotations
+
+from types import ModuleType
+from typing import Any
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
+
from .pymysql import MySQLDialect_pymysql
from ...connectors.asyncio import AsyncAdapt_dbapi_connection
from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
+from ...connectors.asyncio import AsyncAdapt_dbapi_module
from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
from ...util.concurrency import await_
+if TYPE_CHECKING:
+
+ from ...connectors.asyncio import AsyncIODBAPIConnection
+ from ...connectors.asyncio import AsyncIODBAPICursor
+ from ...engine.interfaces import ConnectArgsType
+ from ...engine.interfaces import DBAPIConnection
+ from ...engine.interfaces import DBAPICursor
+ from ...engine.interfaces import DBAPIModule
+ from ...engine.interfaces import PoolProxiedConnection
+ from ...engine.url import URL
+
class AsyncAdapt_aiomysql_cursor(AsyncAdapt_dbapi_cursor):
__slots__ = ()
- def _make_new_cursor(self, connection):
+ def _make_new_cursor(
+ self, connection: AsyncIODBAPIConnection
+ ) -> AsyncIODBAPICursor:
return connection.cursor(self._adapt_connection.dbapi.Cursor)
):
__slots__ = ()
- def _make_new_cursor(self, connection):
+ def _make_new_cursor(
+ self, connection: AsyncIODBAPIConnection
+ ) -> AsyncIODBAPICursor:
return connection.cursor(
self._adapt_connection.dbapi.aiomysql.cursors.SSCursor
)
_cursor_cls = AsyncAdapt_aiomysql_cursor
_ss_cursor_cls = AsyncAdapt_aiomysql_ss_cursor
- def ping(self, reconnect):
+ def ping(self, reconnect: bool) -> None:
assert not reconnect
- return await_(self._connection.ping(reconnect))
+ await_(self._connection.ping(reconnect))
- def character_set_name(self):
- return self._connection.character_set_name()
+ def character_set_name(self) -> Optional[str]:
+ return self._connection.character_set_name() # type: ignore[no-any-return] # noqa: E501
- def autocommit(self, value):
+ def autocommit(self, value: Any) -> None:
await_(self._connection.autocommit(value))
- def terminate(self):
+ def terminate(self) -> None:
# it's not awaitable.
self._connection.close()
await_(self._connection.ensure_closed())
-class AsyncAdapt_aiomysql_dbapi:
- def __init__(self, aiomysql, pymysql):
+class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module):
+ def __init__(self, aiomysql: ModuleType, pymysql: ModuleType):
self.aiomysql = aiomysql
self.pymysql = pymysql
self.paramstyle = "format"
self._init_dbapi_attributes()
self.Cursor, self.SSCursor = self._init_cursors_subclasses()
- def _init_dbapi_attributes(self):
+ def _init_dbapi_attributes(self) -> None:
for name in (
"Warning",
"Error",
):
setattr(self, name, getattr(self.pymysql, name))
- def connect(self, *arg, **kw):
+ def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_aiomysql_connection:
creator_fn = kw.pop("async_creator_fn", self.aiomysql.connect)
return AsyncAdapt_aiomysql_connection(
await_(creator_fn(*arg, **kw)),
)
- def _init_cursors_subclasses(self):
+ def _init_cursors_subclasses(
+ self,
+ ) -> tuple[AsyncIODBAPICursor, AsyncIODBAPICursor]:
# suppress unconditional warning emitted by aiomysql
- class Cursor(self.aiomysql.Cursor):
- async def _show_warnings(self, conn):
+ class Cursor(self.aiomysql.Cursor): # type: ignore[misc, name-defined]
+ async def _show_warnings(
+ self, conn: AsyncIODBAPIConnection
+ ) -> None:
pass
- class SSCursor(self.aiomysql.SSCursor):
- async def _show_warnings(self, conn):
+ class SSCursor(self.aiomysql.SSCursor): # type: ignore[misc, name-defined] # noqa: E501
+ async def _show_warnings(
+ self, conn: AsyncIODBAPIConnection
+ ) -> None:
pass
- return Cursor, SSCursor
+ return Cursor, SSCursor # type: ignore[return-value]
class MySQLDialect_aiomysql(MySQLDialect_pymysql):
driver = "aiomysql"
supports_statement_cache = True
- supports_server_side_cursors = True
+ supports_server_side_cursors = True # type: ignore[assignment]
_sscursor = AsyncAdapt_aiomysql_ss_cursor
is_async = True
has_terminate = True
@classmethod
- def import_dbapi(cls):
+ def import_dbapi(cls) -> AsyncAdapt_aiomysql_dbapi:
return AsyncAdapt_aiomysql_dbapi(
__import__("aiomysql"), __import__("pymysql")
)
- def do_terminate(self, dbapi_connection) -> None:
+ def do_terminate(self, dbapi_connection: DBAPIConnection) -> None:
dbapi_connection.terminate()
- def create_connect_args(self, url):
+ def create_connect_args(
+ self, url: URL, _translate_args: Optional[dict[str, Any]] = None
+ ) -> ConnectArgsType:
return super().create_connect_args(
url, _translate_args=dict(username="user", database="db")
)
- def is_disconnect(self, e, connection, cursor):
+ def is_disconnect(
+ self,
+ e: DBAPIModule.Error,
+ connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
+ cursor: Optional[DBAPICursor],
+ ) -> bool:
if super().is_disconnect(e, connection, cursor):
return True
else:
str_e = str(e).lower()
return "not connected" in str_e
- def _found_rows_client_flag(self):
- from pymysql.constants import CLIENT
+ def _found_rows_client_flag(self) -> int:
+ from pymysql.constants import CLIENT # type: ignore
- return CLIENT.FOUND_ROWS
+ return CLIENT.FOUND_ROWS # type: ignore[no-any-return]
- def get_driver_connection(self, connection):
- return connection._connection
+ def get_driver_connection(
+ self, connection: DBAPIConnection
+ ) -> AsyncIODBAPIConnection:
+ return connection._connection # type: ignore[no-any-return]
dialect = MySQLDialect_aiomysql
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
r"""
.. dialect:: mysql+asyncmy
""" # noqa
from __future__ import annotations
+from types import ModuleType
+from typing import Any
+from typing import NoReturn
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
+
from .pymysql import MySQLDialect_pymysql
from ... import util
from ...connectors.asyncio import AsyncAdapt_dbapi_connection
from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
+from ...connectors.asyncio import AsyncAdapt_dbapi_module
from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
from ...util.concurrency import await_
+if TYPE_CHECKING:
+
+ from ...connectors.asyncio import AsyncIODBAPIConnection
+ from ...connectors.asyncio import AsyncIODBAPICursor
+ from ...engine.interfaces import ConnectArgsType
+ from ...engine.interfaces import DBAPIConnection
+ from ...engine.interfaces import DBAPICursor
+ from ...engine.interfaces import DBAPIModule
+ from ...engine.interfaces import PoolProxiedConnection
+ from ...engine.url import URL
+
class AsyncAdapt_asyncmy_cursor(AsyncAdapt_dbapi_cursor):
__slots__ = ()
):
__slots__ = ()
- def _make_new_cursor(self, connection):
+ def _make_new_cursor(
+ self, connection: AsyncIODBAPIConnection
+ ) -> AsyncIODBAPICursor:
return connection.cursor(
self._adapt_connection.dbapi.asyncmy.cursors.SSCursor
)
_cursor_cls = AsyncAdapt_asyncmy_cursor
_ss_cursor_cls = AsyncAdapt_asyncmy_ss_cursor
- def _handle_exception(self, error):
+ def _handle_exception(self, error: Exception) -> NoReturn:
if isinstance(error, AttributeError):
raise self.dbapi.InternalError(
"network operation failed due to asyncmy attribute error"
raise error
- def ping(self, reconnect):
+ def ping(self, reconnect: bool) -> None:
assert not reconnect
return await_(self._do_ping())
- async def _do_ping(self):
+ async def _do_ping(self) -> None:
try:
async with self._execute_mutex:
- return await self._connection.ping(False)
+ await self._connection.ping(False)
except Exception as error:
self._handle_exception(error)
- def character_set_name(self):
- return self._connection.character_set_name()
+ def character_set_name(self) -> Optional[str]:
+ return self._connection.character_set_name() # type: ignore[no-any-return] # noqa: E501
- def autocommit(self, value):
+ def autocommit(self, value: Any) -> None:
await_(self._connection.autocommit(value))
- def terminate(self):
+ def terminate(self) -> None:
# it's not awaitable.
self._connection.close()
await_(self._connection.ensure_closed())
-def _Binary(x):
- """Return x as a binary type."""
- return bytes(x)
-
-
-class AsyncAdapt_asyncmy_dbapi:
- def __init__(self, asyncmy):
+class AsyncAdapt_asyncmy_dbapi(AsyncAdapt_dbapi_module):
+ def __init__(self, asyncmy: ModuleType):
self.asyncmy = asyncmy
self.paramstyle = "format"
self._init_dbapi_attributes()
- def _init_dbapi_attributes(self):
+ def _init_dbapi_attributes(self) -> None:
for name in (
"Warning",
"Error",
BINARY = util.symbol("BINARY")
DATETIME = util.symbol("DATETIME")
TIMESTAMP = util.symbol("TIMESTAMP")
- Binary = staticmethod(_Binary)
+ Binary = staticmethod(bytes)
- def connect(self, *arg, **kw):
+ def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_asyncmy_connection:
creator_fn = kw.pop("async_creator_fn", self.asyncmy.connect)
return AsyncAdapt_asyncmy_connection(
driver = "asyncmy"
supports_statement_cache = True
- supports_server_side_cursors = True
+ supports_server_side_cursors = True # type: ignore[assignment]
_sscursor = AsyncAdapt_asyncmy_ss_cursor
is_async = True
has_terminate = True
@classmethod
- def import_dbapi(cls):
+ def import_dbapi(cls) -> DBAPIModule:
return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy"))
- def do_terminate(self, dbapi_connection) -> None:
+ def do_terminate(self, dbapi_connection: DBAPIConnection) -> None:
dbapi_connection.terminate()
- def create_connect_args(self, url):
+ def create_connect_args(self, url: URL) -> ConnectArgsType: # type: ignore[override] # noqa: E501
return super().create_connect_args(
url, _translate_args=dict(username="user", database="db")
)
- def is_disconnect(self, e, connection, cursor):
+ def is_disconnect(
+ self,
+ e: DBAPIModule.Error,
+ connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
+ cursor: Optional[DBAPICursor],
+ ) -> bool:
if super().is_disconnect(e, connection, cursor):
return True
else:
"not connected" in str_e or "network operation failed" in str_e
)
- def _found_rows_client_flag(self):
- from asyncmy.constants import CLIENT
+ def _found_rows_client_flag(self) -> int:
+ from asyncmy.constants import CLIENT # type: ignore
- return CLIENT.FOUND_ROWS
+ return CLIENT.FOUND_ROWS # type: ignore[no-any-return]
- def get_driver_connection(self, connection):
- return connection._connection
+ def get_driver_connection(
+ self, connection: DBAPIConnection
+ ) -> AsyncIODBAPIConnection:
+ return connection._connection # type: ignore[no-any-return]
dialect = MySQLDialect_asyncmy
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
r"""
""" # noqa
from __future__ import annotations
-from array import array as _array
from collections import defaultdict
from itertools import compress
import re
+from typing import Any
+from typing import Callable
from typing import cast
+from typing import NoReturn
+from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import TYPE_CHECKING
+from typing import Union
from . import reflection as _reflection
from .enumerated import ENUM
from .types import YEAR
from ... import exc
from ... import literal_column
-from ... import log
from ... import schema as sa_schema
from ... import sql
from ... import util
from ...types import BLOB
from ...types import BOOLEAN
from ...types import DATE
+from ...types import LargeBinary
from ...types import UUID
from ...types import VARBINARY
from ...util import topological
+if TYPE_CHECKING:
+
+ from ...dialects.mysql import expression
+ from ...dialects.mysql.dml import DMLLimitClause
+ from ...dialects.mysql.dml import OnDuplicateClause
+ from ...engine.base import Connection
+ from ...engine.cursor import CursorResult
+ from ...engine.interfaces import DBAPIConnection
+ from ...engine.interfaces import DBAPICursor
+ from ...engine.interfaces import DBAPIModule
+ from ...engine.interfaces import IsolationLevel
+ from ...engine.interfaces import PoolProxiedConnection
+ from ...engine.interfaces import ReflectedCheckConstraint
+ from ...engine.interfaces import ReflectedColumn
+ from ...engine.interfaces import ReflectedForeignKeyConstraint
+ from ...engine.interfaces import ReflectedIndex
+ from ...engine.interfaces import ReflectedPrimaryKeyConstraint
+ from ...engine.interfaces import ReflectedTableComment
+ from ...engine.interfaces import ReflectedUniqueConstraint
+ from ...engine.result import _Ts
+ from ...engine.row import Row
+ from ...engine.url import URL
+ from ...schema import Table
+ from ...sql import ddl
+ from ...sql import selectable
+ from ...sql.dml import _DMLTableElement
+ from ...sql.dml import Delete
+ from ...sql.dml import Update
+ from ...sql.dml import ValuesBase
+ from ...sql.functions import aggregate_strings
+ from ...sql.functions import random
+ from ...sql.functions import rollup
+ from ...sql.functions import sysdate
+ from ...sql.schema import Sequence as Sequence_SchemaItem
+ from ...sql.type_api import TypeEngine
+ from ...sql.visitors import ExternallyTraversible
+ from ...util.typing import TupleAny
+ from ...util.typing import Unpack
+
SET_RE = re.compile(
r"\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w", re.I | re.UNICODE
class MySQLExecutionContext(default.DefaultExecutionContext):
- def post_exec(self):
+ def post_exec(self) -> None:
if (
self.isdelete
and cast(SQLCompiler, self.compiled).effective_returning
_cursor.FullyBufferedCursorFetchStrategy(
self.cursor,
[
- (entry.keyname, None)
+ (entry.keyname, None) # type: ignore[misc]
for entry in cast(
SQLCompiler, self.compiled
)._result_columns
)
)
- def create_server_side_cursor(self):
+ def create_server_side_cursor(self) -> DBAPICursor:
if self.dialect.supports_server_side_cursors:
- return self._dbapi_connection.cursor(self.dialect._sscursor)
+ return self._dbapi_connection.cursor(
+ self.dialect._sscursor # type: ignore[attr-defined]
+ )
else:
raise NotImplementedError()
- def fire_sequence(self, seq, type_):
- return self._execute_scalar(
+ def fire_sequence(
+ self, seq: Sequence_SchemaItem, type_: sqltypes.Integer
+ ) -> int:
+ return self._execute_scalar( # type: ignore[no-any-return]
(
"select nextval(%s)"
% self.identifier_preparer.format_sequence(seq)
class MySQLCompiler(compiler.SQLCompiler):
+ dialect: MySQLDialect
render_table_with_column_in_update_from = True
"""Overridden from base SQLCompiler value"""
extract_map = compiler.SQLCompiler.extract_map.copy()
extract_map.update({"milliseconds": "millisecond"})
- def default_from(self):
+ def default_from(self) -> str:
"""Called when a ``SELECT`` statement has no froms,
and no ``FROM`` clause is to be appended.
"""
if self.stack:
stmt = self.stack[-1]["selectable"]
- if stmt._where_criteria:
+ if stmt._where_criteria: # type: ignore[attr-defined]
return " FROM DUAL"
return ""
- def visit_random_func(self, fn, **kw):
+ def visit_random_func(self, fn: random, **kw: Any) -> str:
return "rand%s" % self.function_argspec(fn)
- def visit_rollup_func(self, fn, **kw):
+ def visit_rollup_func(self, fn: rollup[Any], **kw: Any) -> str:
clause = ", ".join(
elem._compiler_dispatch(self, **kw) for elem in fn.clauses
)
return f"{clause} WITH ROLLUP"
- def visit_aggregate_strings_func(self, fn, **kw):
+ def visit_aggregate_strings_func(
+ self, fn: aggregate_strings, **kw: Any
+ ) -> str:
expr, delimeter = (
elem._compiler_dispatch(self, **kw) for elem in fn.clauses
)
return f"group_concat({expr} SEPARATOR {delimeter})"
- def visit_sequence(self, seq, **kw):
- return "nextval(%s)" % self.preparer.format_sequence(seq)
+ def visit_sequence(self, sequence: sa_schema.Sequence, **kw: Any) -> str:
+ return "nextval(%s)" % self.preparer.format_sequence(sequence)
- def visit_sysdate_func(self, fn, **kw):
+ def visit_sysdate_func(self, fn: sysdate, **kw: Any) -> str:
return "SYSDATE()"
- def _render_json_extract_from_binary(self, binary, operator, **kw):
+ def _render_json_extract_from_binary(
+ self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+ ) -> str:
# note we are intentionally calling upon the process() calls in the
# order in which they appear in the SQL String as this is used
# by positional parameter rendering
)
)
elif binary.type._type_affinity in (sqltypes.Numeric, sqltypes.Float):
+ binary_type = cast(sqltypes.Numeric[Any], binary.type)
if (
- binary.type.scale is not None
- and binary.type.precision is not None
+ binary_type.scale is not None
+ and binary_type.precision is not None
):
# using DECIMAL here because MySQL does not recognize NUMERIC
type_expression = (
% (
self.process(binary.left, **kw),
self.process(binary.right, **kw),
- binary.type.precision,
- binary.type.scale,
+ binary_type.precision,
+ binary_type.scale,
)
)
else:
return case_expression + " " + type_expression + " END"
- def visit_json_getitem_op_binary(self, binary, operator, **kw):
+ def visit_json_getitem_op_binary(
+ self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+ ) -> str:
return self._render_json_extract_from_binary(binary, operator, **kw)
- def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
+ def visit_json_path_getitem_op_binary(
+ self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+ ) -> str:
return self._render_json_extract_from_binary(binary, operator, **kw)
- def visit_on_duplicate_key_update(self, on_duplicate, **kw):
- statement = self.current_executable
+ def visit_on_duplicate_key_update(
+ self, on_duplicate: OnDuplicateClause, **kw: Any
+ ) -> str:
+ statement: ValuesBase = self.current_executable
+ cols: list[elements.KeyedColumnElement[Any]]
if on_duplicate._parameter_ordering:
parameter_ordering = [
coercions.expect(roles.DMLColumnRole, key)
if key in statement.table.c
] + [c for c in statement.table.c if c.key not in ordered_keys]
else:
- cols = statement.table.c
+ cols = list(statement.table.c)
clauses = []
)
if requires_mysql8_alias:
- if statement.table.name.lower() == "new":
+ if statement.table.name.lower() == "new": # type: ignore[union-attr] # noqa: E501
_on_dup_alias_name = "new_1"
else:
_on_dup_alias_name = "new"
for column in (col for col in cols if col.key in on_duplicate_update):
val = on_duplicate_update[column.key]
- def replace(obj):
+ def replace(
+ element: ExternallyTraversible, **kw: Any
+ ) -> Optional[ExternallyTraversible]:
if (
- isinstance(obj, elements.BindParameter)
- and obj.type._isnull
+ isinstance(element, elements.BindParameter)
+ and element.type._isnull
):
- return obj._with_binary_element_type(column.type)
+ return element._with_binary_element_type(column.type)
elif (
- isinstance(obj, elements.ColumnClause)
- and obj.table is on_duplicate.inserted_alias
+ isinstance(element, elements.ColumnClause)
+ and element.table is on_duplicate.inserted_alias
):
if requires_mysql8_alias:
column_literal_clause = (
f"{_on_dup_alias_name}."
- f"{self.preparer.quote(obj.name)}"
+ f"{self.preparer.quote(element.name)}"
)
else:
column_literal_clause = (
- f"VALUES({self.preparer.quote(obj.name)})"
+ f"VALUES({self.preparer.quote(element.name)})"
)
return literal_column(column_literal_clause)
else:
"Additional column names not matching "
"any column keys in table '%s': %s"
% (
- self.statement.table.name,
+ self.statement.table.name, # type: ignore[union-attr]
(", ".join("'%s'" % c for c in non_matching)),
)
)
return f"ON DUPLICATE KEY UPDATE {', '.join(clauses)}"
def visit_concat_op_expression_clauselist(
- self, clauselist, operator, **kw
- ):
+ self, clauselist: elements.ClauseList, operator: Any, **kw: Any
+ ) -> str:
return "concat(%s)" % (
", ".join(self.process(elem, **kw) for elem in clauselist.clauses)
)
- def visit_concat_op_binary(self, binary, operator, **kw):
+ def visit_concat_op_binary(
+ self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+ ) -> str:
return "concat(%s, %s)" % (
self.process(binary.left, **kw),
self.process(binary.right, **kw),
"WITH QUERY EXPANSION",
)
- def visit_mysql_match(self, element, **kw):
+ def visit_mysql_match(self, element: expression.match, **kw: Any) -> str:
return self.visit_match_op_binary(element, element.operator, **kw)
- def visit_match_op_binary(self, binary, operator, **kw):
+ def visit_match_op_binary(
+ self, binary: expression.match, operator: Any, **kw: Any
+ ) -> str:
"""
Note that `mysql_boolean_mode` is enabled by default because of
backward compatibility
"with_query_expansion=%s" % query_expansion,
)
- flags = ", ".join(flags)
+ flags_str = ", ".join(flags)
- raise exc.CompileError("Invalid MySQL match flags: %s" % flags)
+ raise exc.CompileError("Invalid MySQL match flags: %s" % flags_str)
- match_clause = binary.left
- match_clause = self.process(match_clause, **kw)
+ match_clause = self.process(binary.left, **kw)
against_clause = self.process(binary.right, **kw)
if any(flag_combination):
flag_combination,
)
- against_clause = [against_clause]
- against_clause.extend(flag_expressions)
-
- against_clause = " ".join(against_clause)
+ against_clause = " ".join([against_clause, *flag_expressions])
return "MATCH (%s) AGAINST (%s)" % (match_clause, against_clause)
- def get_from_hint_text(self, table, text):
+ def get_from_hint_text(
+ self, table: selectable.FromClause, text: Optional[str]
+ ) -> Optional[str]:
return text
- def visit_typeclause(self, typeclause, type_=None, **kw):
+ def visit_typeclause(
+ self,
+ typeclause: elements.TypeClause,
+ type_: Optional[TypeEngine[Any]] = None,
+ **kw: Any,
+ ) -> Optional[str]:
if type_ is None:
type_ = typeclause.type.dialect_impl(self.dialect)
if isinstance(type_, sqltypes.TypeDecorator):
- return self.visit_typeclause(typeclause, type_.impl, **kw)
+ return self.visit_typeclause(typeclause, type_.impl, **kw) # type: ignore[arg-type] # noqa: E501
elif isinstance(type_, sqltypes.Integer):
if getattr(type_, "unsigned", False):
return "UNSIGNED INTEGER"
else:
return None
- def visit_cast(self, cast, **kw):
+ def visit_cast(self, cast: elements.Cast[Any], **kw: Any) -> str:
type_ = self.process(cast.typeclause)
if type_ is None:
util.warn(
return "CAST(%s AS %s)" % (self.process(cast.clause, **kw), type_)
- def render_literal_value(self, value, type_):
+ def render_literal_value(
+ self, value: Optional[str], type_: TypeEngine[Any]
+ ) -> str:
value = super().render_literal_value(value, type_)
if self.dialect._backslash_escapes:
value = value.replace("\\", "\\\\")
# override native_boolean=False behavior here, as
# MySQL still supports native boolean
- def visit_true(self, element, **kw):
+ def visit_true(self, expr: elements.True_, **kw: Any) -> str:
return "true"
- def visit_false(self, element, **kw):
+ def visit_false(self, expr: elements.False_, **kw: Any) -> str:
return "false"
- def get_select_precolumns(self, select, **kw):
+ def get_select_precolumns(
+ self, select: selectable.Select[Any], **kw: Any
+ ) -> str:
"""Add special MySQL keywords in place of DISTINCT.
.. deprecated:: 1.4 This usage is deprecated.
return super().get_select_precolumns(select, **kw)
- def visit_join(self, join, asfrom=False, from_linter=None, **kwargs):
+ def visit_join(
+ self,
+ join: selectable.Join,
+ asfrom: bool = False,
+ from_linter: Optional[compiler.FromLinter] = None,
+ **kwargs: Any,
+ ) -> str:
if from_linter:
from_linter.edges.add((join.left, join.right))
join.right, asfrom=True, from_linter=from_linter, **kwargs
),
" ON ",
- self.process(join.onclause, from_linter=from_linter, **kwargs),
+ self.process(join.onclause, from_linter=from_linter, **kwargs), # type: ignore[arg-type] # noqa: E501
)
)
- def for_update_clause(self, select, **kw):
+ def for_update_clause(
+ self, select: selectable.GenerativeSelect, **kw: Any
+ ) -> str:
+ assert select._for_update_arg is not None
if select._for_update_arg.read:
tmp = " LOCK IN SHARE MODE"
else:
tmp = " FOR UPDATE"
if select._for_update_arg.of and self.dialect.supports_for_update_of:
- tables = util.OrderedSet()
+ tables: util.OrderedSet[elements.ClauseElement] = util.OrderedSet()
for c in select._for_update_arg.of:
tables.update(sql_util.surface_selectables_only(c))
return tmp
- def limit_clause(self, select, **kw):
+ def limit_clause(
+ self, select: selectable.GenerativeSelect, **kw: Any
+ ) -> str:
# MySQL supports:
# LIMIT <limit>
# LIMIT <offset>, <limit>
self.process(limit_clause, **kw),
)
else:
+ assert limit_clause is not None
# No offset provided, so just use the limit
return " \n LIMIT %s" % (self.process(limit_clause, **kw),)
- def update_post_criteria_clause(self, update_stmt, **kw):
+ def update_post_criteria_clause(
+ self, update_stmt: Update, **kw: Any
+ ) -> Optional[str]:
limit = update_stmt.kwargs.get("%s_limit" % self.dialect.name, None)
supertext = super().update_post_criteria_clause(update_stmt, **kw)
else:
return supertext
- def delete_post_criteria_clause(self, delete_stmt, **kw):
+ def delete_post_criteria_clause(
+ self, delete_stmt: Delete, **kw: Any
+ ) -> Optional[str]:
limit = delete_stmt.kwargs.get("%s_limit" % self.dialect.name, None)
supertext = super().delete_post_criteria_clause(delete_stmt, **kw)
else:
return supertext
- def visit_mysql_dml_limit_clause(self, element, **kw):
+ def visit_mysql_dml_limit_clause(
+ self, element: DMLLimitClause, **kw: Any
+ ) -> str:
kw["literal_execute"] = True
return f"LIMIT {self.process(element._limit_clause, **kw)}"
- def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
+ def update_tables_clause(
+ self,
+ update_stmt: Update,
+ from_table: _DMLTableElement,
+ extra_froms: list[selectable.FromClause],
+ **kw: Any,
+ ) -> str:
kw["asfrom"] = True
return ", ".join(
t._compiler_dispatch(self, **kw)
)
def update_from_clause(
- self, update_stmt, from_table, extra_froms, from_hints, **kw
- ):
+ self,
+ update_stmt: Update,
+ from_table: _DMLTableElement,
+ extra_froms: list[selectable.FromClause],
+ from_hints: Any,
+ **kw: Any,
+ ) -> None:
return None
- def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw):
+ def delete_table_clause(
+ self,
+ delete_stmt: Delete,
+ from_table: _DMLTableElement,
+ extra_froms: list[selectable.FromClause],
+ **kw: Any,
+ ) -> str:
"""If we have extra froms make sure we render any alias as hint."""
ashint = False
if extra_froms:
)
def delete_extra_from_clause(
- self, delete_stmt, from_table, extra_froms, from_hints, **kw
- ):
+ self,
+ delete_stmt: Delete,
+ from_table: _DMLTableElement,
+ extra_froms: list[selectable.FromClause],
+ from_hints: Any,
+ **kw: Any,
+ ) -> str:
"""Render the DELETE .. USING clause specific to MySQL."""
kw["asfrom"] = True
return "USING " + ", ".join(
for t in [from_table] + extra_froms
)
- def visit_empty_set_expr(self, element_types, **kw):
+ def visit_empty_set_expr(
+ self, element_types: list[TypeEngine[Any]], **kw: Any
+ ) -> str:
return (
"SELECT %(outer)s FROM (SELECT %(inner)s) "
"as _empty_set WHERE 1!=1"
}
)
- def visit_is_distinct_from_binary(self, binary, operator, **kw):
+ def visit_is_distinct_from_binary(
+ self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+ ) -> str:
return "NOT (%s <=> %s)" % (
self.process(binary.left),
self.process(binary.right),
)
- def visit_is_not_distinct_from_binary(self, binary, operator, **kw):
+ def visit_is_not_distinct_from_binary(
+ self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+ ) -> str:
return "%s <=> %s" % (
self.process(binary.left),
self.process(binary.right),
)
- def _mariadb_regexp_flags(self, flags, pattern, **kw):
+ def _mariadb_regexp_flags(
+ self, flags: str, pattern: elements.ColumnElement[Any], **kw: Any
+ ) -> str:
return "CONCAT('(?', %s, ')', %s)" % (
self.render_literal_value(flags, sqltypes.STRINGTYPE),
self.process(pattern, **kw),
)
- def _regexp_match(self, op_string, binary, operator, **kw):
+ def _regexp_match(
+ self,
+ op_string: str,
+ binary: elements.BinaryExpression[Any],
+ operator: Any,
+ **kw: Any,
+ ) -> str:
+ assert binary.modifiers is not None
flags = binary.modifiers["flags"]
if flags is None:
return self._generate_generic_binary(binary, op_string, **kw)
else:
return text
- def visit_regexp_match_op_binary(self, binary, operator, **kw):
+ def visit_regexp_match_op_binary(
+ self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+ ) -> str:
return self._regexp_match(" REGEXP ", binary, operator, **kw)
- def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
+ def visit_not_regexp_match_op_binary(
+ self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+ ) -> str:
return self._regexp_match(" NOT REGEXP ", binary, operator, **kw)
- def visit_regexp_replace_op_binary(self, binary, operator, **kw):
+ def visit_regexp_replace_op_binary(
+ self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+ ) -> str:
+ assert binary.modifiers is not None
flags = binary.modifiers["flags"]
if flags is None:
return "REGEXP_REPLACE(%s, %s)" % (
class MySQLDDLCompiler(compiler.DDLCompiler):
- def get_column_specification(self, column, **kw):
+ dialect: MySQLDialect
+
+ def get_column_specification(
+ self, column: sa_schema.Column[Any], **kw: Any
+ ) -> str:
"""Builds column DDL."""
if (
self.dialect.is_mariadb is True
colspec.append("DEFAULT " + default)
return " ".join(colspec)
- def post_create_table(self, table):
+ def post_create_table(self, table: sa_schema.Table) -> str:
"""Build table-level CREATE options like ENGINE and COLLATE."""
table_opts = []
return " ".join(table_opts)
- def visit_create_index(self, create, **kw):
+ def visit_create_index(self, create: ddl.CreateIndex, **kw: Any) -> str: # type: ignore[override] # noqa: E501
index = create.element
self._verify_index_table(index)
preparer = self.preparer
- table = preparer.format_table(index.table)
+ table = preparer.format_table(index.table) # type: ignore[arg-type]
columns = [
self.sql_compiler.process(
(
- elements.Grouping(expr)
+ elements.Grouping(expr) # type: ignore[arg-type]
if (
isinstance(expr, elements.BinaryExpression)
or (
# length value can be a (column_name --> integer value)
# mapping specifying the prefix length for each column of the
# index
- columns = ", ".join(
+ columns_str = ", ".join(
(
- "%s(%d)" % (expr, length[col.name])
- if col.name in length
+ "%s(%d)" % (expr, length[col.name]) # type: ignore[union-attr] # noqa: E501
+ if col.name in length # type: ignore[union-attr]
else (
"%s(%d)" % (expr, length[expr])
if expr in length
else:
# or can be an integer value specifying the same
# prefix length for all columns of the index
- columns = ", ".join(
+ columns_str = ", ".join(
"%s(%d)" % (col, length) for col in columns
)
else:
- columns = ", ".join(columns)
- text += "(%s)" % columns
+ columns_str = ", ".join(columns)
+ text += "(%s)" % columns_str
parser = index.dialect_options["mysql"]["with_parser"]
if parser is not None:
return text
- def visit_primary_key_constraint(self, constraint, **kw):
+ def visit_primary_key_constraint(
+ self, constraint: sa_schema.PrimaryKeyConstraint, **kw: Any
+ ) -> str:
text = super().visit_primary_key_constraint(constraint)
using = constraint.dialect_options["mysql"]["using"]
if using:
text += " USING %s" % (self.preparer.quote(using))
return text
- def visit_drop_index(self, drop, **kw):
+ def visit_drop_index(self, drop: ddl.DropIndex, **kw: Any) -> str:
index = drop.element
text = "\nDROP INDEX "
if drop.if_exists:
return text + "%s ON %s" % (
self._prepared_index_name(index, include_schema=False),
- self.preparer.format_table(index.table),
+ self.preparer.format_table(index.table), # type: ignore[arg-type]
)
- def visit_drop_constraint(self, drop, **kw):
+ def visit_drop_constraint(
+ self, drop: ddl.DropConstraint, **kw: Any
+ ) -> str:
constraint = drop.element
if isinstance(constraint, sa_schema.ForeignKeyConstraint):
qual = "FOREIGN KEY "
const,
)
- def define_constraint_match(self, constraint):
+ def define_constraint_match(
+ self, constraint: sa_schema.ForeignKeyConstraint
+ ) -> str:
if constraint.match is not None:
raise exc.CompileError(
"MySQL ignores the 'MATCH' keyword while at the same time "
)
return ""
- def visit_set_table_comment(self, create, **kw):
+ def visit_set_table_comment(
+ self, create: ddl.SetTableComment, **kw: Any
+ ) -> str:
return "ALTER TABLE %s COMMENT %s" % (
self.preparer.format_table(create.element),
self.sql_compiler.render_literal_value(
),
)
- def visit_drop_table_comment(self, create, **kw):
+ def visit_drop_table_comment(
+ self, drop: ddl.DropTableComment, **kw: Any
+ ) -> str:
return "ALTER TABLE %s COMMENT ''" % (
- self.preparer.format_table(create.element)
+ self.preparer.format_table(drop.element)
)
- def visit_set_column_comment(self, create, **kw):
+ def visit_set_column_comment(
+ self, create: ddl.SetColumnComment, **kw: Any
+ ) -> str:
return "ALTER TABLE %s CHANGE %s %s" % (
self.preparer.format_table(create.element.table),
self.preparer.format_column(create.element),
class MySQLTypeCompiler(compiler.GenericTypeCompiler):
- def _extend_numeric(self, type_, spec):
+ def _extend_numeric(self, type_: _NumericCommonType, spec: str) -> str:
"Extend a numeric-type declaration with MySQL specific extensions."
if not self._mysql_type(type_):
spec += " ZEROFILL"
return spec
- def _extend_string(self, type_, defaults, spec):
+ def _extend_string(
+ self, type_: _StringType, defaults: dict[str, Any], spec: str
+ ) -> str:
"""Extend a string-type declaration with standard SQL CHARACTER SET /
COLLATE annotations and MySQL specific extensions.
"""
- def attr(name):
+ def attr(name: str) -> Any:
return getattr(type_, name, defaults.get(name))
if attr("charset"):
elif attr("unicode"):
charset = "UNICODE"
else:
+
charset = None
if attr("collation"):
[c for c in (spec, charset, collation) if c is not None]
)
- def _mysql_type(self, type_):
+ def _mysql_type(self, type_: Any) -> bool:
return isinstance(type_, (_StringType, _NumericCommonType))
- def visit_NUMERIC(self, type_, **kw):
+ def visit_NUMERIC(self, type_: NUMERIC, **kw: Any) -> str: # type: ignore[override] # NOQA: E501
if type_.precision is None:
return self._extend_numeric(type_, "NUMERIC")
elif type_.scale is None:
% {"precision": type_.precision, "scale": type_.scale},
)
- def visit_DECIMAL(self, type_, **kw):
+ def visit_DECIMAL(self, type_: DECIMAL, **kw: Any) -> str: # type: ignore[override] # NOQA: E501
if type_.precision is None:
return self._extend_numeric(type_, "DECIMAL")
elif type_.scale is None:
% {"precision": type_.precision, "scale": type_.scale},
)
- def visit_DOUBLE(self, type_, **kw):
+ def visit_DOUBLE(self, type_: DOUBLE, **kw: Any) -> str: # type: ignore[override] # NOQA: E501
if type_.precision is not None and type_.scale is not None:
return self._extend_numeric(
type_,
else:
return self._extend_numeric(type_, "DOUBLE")
- def visit_REAL(self, type_, **kw):
+ def visit_REAL(self, type_: REAL, **kw: Any) -> str: # type: ignore[override] # NOQA: E501
if type_.precision is not None and type_.scale is not None:
return self._extend_numeric(
type_,
else:
return self._extend_numeric(type_, "REAL")
- def visit_FLOAT(self, type_, **kw):
+ def visit_FLOAT(self, type_: FLOAT, **kw: Any) -> str: # type: ignore[override] # NOQA: E501
if (
self._mysql_type(type_)
and type_.scale is not None
else:
return self._extend_numeric(type_, "FLOAT")
- def visit_INTEGER(self, type_, **kw):
+ def visit_INTEGER(self, type_: INTEGER, **kw: Any) -> str: # type: ignore[override] # NOQA: E501
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(
type_,
else:
return self._extend_numeric(type_, "INTEGER")
- def visit_BIGINT(self, type_, **kw):
+ def visit_BIGINT(self, type_: BIGINT, **kw: Any) -> str: # type: ignore[override] # NOQA: E501
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(
type_,
else:
return self._extend_numeric(type_, "BIGINT")
- def visit_MEDIUMINT(self, type_, **kw):
+ def visit_MEDIUMINT(self, type_: MEDIUMINT, **kw: Any) -> str:
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(
type_,
else:
return self._extend_numeric(type_, "MEDIUMINT")
- def visit_TINYINT(self, type_, **kw):
+ def visit_TINYINT(self, type_: TINYINT, **kw: Any) -> str:
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(
type_, "TINYINT(%s)" % type_.display_width
else:
return self._extend_numeric(type_, "TINYINT")
- def visit_SMALLINT(self, type_, **kw):
+ def visit_SMALLINT(self, type_: SMALLINT, **kw: Any) -> str: # type: ignore[override] # NOQA: E501
if self._mysql_type(type_) and type_.display_width is not None:
return self._extend_numeric(
type_,
else:
return self._extend_numeric(type_, "SMALLINT")
- def visit_BIT(self, type_, **kw):
+ def visit_BIT(self, type_: BIT, **kw: Any) -> str:
if type_.length is not None:
return "BIT(%s)" % type_.length
else:
return "BIT"
- def visit_DATETIME(self, type_, **kw):
+ def visit_DATETIME(self, type_: DATETIME, **kw: Any) -> str: # type: ignore[override] # NOQA: E501
if getattr(type_, "fsp", None):
- return "DATETIME(%d)" % type_.fsp
+ return "DATETIME(%d)" % type_.fsp # type: ignore[str-format]
else:
return "DATETIME"
- def visit_DATE(self, type_, **kw):
+ def visit_DATE(self, type_: DATE, **kw: Any) -> str: # type: ignore[override] # NOQA: E501
return "DATE"
- def visit_TIME(self, type_, **kw):
+ def visit_TIME(self, type_: TIME, **kw: Any) -> str: # type: ignore[override] # NOQA: E501
if getattr(type_, "fsp", None):
- return "TIME(%d)" % type_.fsp
+ return "TIME(%d)" % type_.fsp # type: ignore[str-format]
else:
return "TIME"
- def visit_TIMESTAMP(self, type_, **kw):
+ def visit_TIMESTAMP(self, type_: TIMESTAMP, **kw: Any) -> str: # type: ignore[override] # NOQA: E501
if getattr(type_, "fsp", None):
- return "TIMESTAMP(%d)" % type_.fsp
+ return "TIMESTAMP(%d)" % type_.fsp # type: ignore[str-format]
else:
return "TIMESTAMP"
- def visit_YEAR(self, type_, **kw):
+ def visit_YEAR(self, type_: YEAR, **kw: Any) -> str:
if type_.display_width is None:
return "YEAR"
else:
return "YEAR(%s)" % type_.display_width
- def visit_TEXT(self, type_, **kw):
+ def visit_TEXT(self, type_: TEXT, **kw: Any) -> str: # type: ignore[override] # NOQA: E501
if type_.length is not None:
return self._extend_string(type_, {}, "TEXT(%d)" % type_.length)
else:
return self._extend_string(type_, {}, "TEXT")
- def visit_TINYTEXT(self, type_, **kw):
+ def visit_TINYTEXT(self, type_: TINYTEXT, **kw: Any) -> str:
return self._extend_string(type_, {}, "TINYTEXT")
- def visit_MEDIUMTEXT(self, type_, **kw):
+ def visit_MEDIUMTEXT(self, type_: MEDIUMTEXT, **kw: Any) -> str:
return self._extend_string(type_, {}, "MEDIUMTEXT")
- def visit_LONGTEXT(self, type_, **kw):
+ def visit_LONGTEXT(self, type_: LONGTEXT, **kw: Any) -> str:
return self._extend_string(type_, {}, "LONGTEXT")
- def visit_VARCHAR(self, type_, **kw):
+ def visit_VARCHAR(self, type_: VARCHAR, **kw: Any) -> str: # type: ignore[override] # NOQA: E501
if type_.length is not None:
return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length)
else:
"VARCHAR requires a length on dialect %s" % self.dialect.name
)
- def visit_CHAR(self, type_, **kw):
+ def visit_CHAR(self, type_: CHAR, **kw: Any) -> str: # type: ignore[override] # NOQA: E501
if type_.length is not None:
return self._extend_string(
type_, {}, "CHAR(%(length)s)" % {"length": type_.length}
else:
return self._extend_string(type_, {}, "CHAR")
- def visit_NVARCHAR(self, type_, **kw):
+ def visit_NVARCHAR(self, type_: NVARCHAR, **kw: Any) -> str: # type: ignore[override] # NOQA: E501
# We'll actually generate the equiv. "NATIONAL VARCHAR" instead
# of "NVARCHAR".
if type_.length is not None:
"NVARCHAR requires a length on dialect %s" % self.dialect.name
)
- def visit_NCHAR(self, type_, **kw):
+ def visit_NCHAR(self, type_: NCHAR, **kw: Any) -> str: # type: ignore[override] # NOQA: E501
# We'll actually generate the equiv.
# "NATIONAL CHAR" instead of "NCHAR".
if type_.length is not None:
else:
return self._extend_string(type_, {"national": True}, "CHAR")
- def visit_UUID(self, type_, **kw):
+ def visit_UUID(self, type_: UUID[Any], **kw: Any) -> str: # type: ignore[override] # NOQA: E501
return "UUID"
- def visit_VARBINARY(self, type_, **kw):
- return "VARBINARY(%d)" % type_.length
+ def visit_VARBINARY(self, type_: VARBINARY, **kw: Any) -> str:
+ return "VARBINARY(%d)" % type_.length # type: ignore[str-format]
- def visit_JSON(self, type_, **kw):
+ def visit_JSON(self, type_: JSON, **kw: Any) -> str:
return "JSON"
- def visit_large_binary(self, type_, **kw):
+ def visit_large_binary(self, type_: LargeBinary, **kw: Any) -> str:
return self.visit_BLOB(type_)
- def visit_enum(self, type_, **kw):
+ def visit_enum(self, type_: ENUM, **kw: Any) -> str: # type: ignore[override] # NOQA: E501
if not type_.native_enum:
return super().visit_enum(type_)
else:
return self._visit_enumerated_values("ENUM", type_, type_.enums)
- def visit_BLOB(self, type_, **kw):
+ def visit_BLOB(self, type_: LargeBinary, **kw: Any) -> str:
if type_.length is not None:
return "BLOB(%d)" % type_.length
else:
return "BLOB"
- def visit_TINYBLOB(self, type_, **kw):
+ def visit_TINYBLOB(self, type_: TINYBLOB, **kw: Any) -> str:
return "TINYBLOB"
- def visit_MEDIUMBLOB(self, type_, **kw):
+ def visit_MEDIUMBLOB(self, type_: MEDIUMBLOB, **kw: Any) -> str:
return "MEDIUMBLOB"
- def visit_LONGBLOB(self, type_, **kw):
+ def visit_LONGBLOB(self, type_: LONGBLOB, **kw: Any) -> str:
return "LONGBLOB"
- def _visit_enumerated_values(self, name, type_, enumerated_values):
+ def _visit_enumerated_values(
+ self, name: str, type_: _StringType, enumerated_values: Sequence[str]
+ ) -> str:
quoted_enums = []
for e in enumerated_values:
if self.dialect.identifier_preparer._double_percents:
type_, {}, "%s(%s)" % (name, ",".join(quoted_enums))
)
- def visit_ENUM(self, type_, **kw):
+ def visit_ENUM(self, type_: ENUM, **kw: Any) -> str:
return self._visit_enumerated_values("ENUM", type_, type_.enums)
- def visit_SET(self, type_, **kw):
+ def visit_SET(self, type_: SET, **kw: Any) -> str:
return self._visit_enumerated_values("SET", type_, type_.values)
- def visit_BOOLEAN(self, type_, **kw):
+ def visit_BOOLEAN(self, type_: sqltypes.Boolean, **kw: Any) -> str:
return "BOOL"
class MySQLIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = RESERVED_WORDS_MYSQL
- def __init__(self, dialect, server_ansiquotes=False, **kw):
+ def __init__(
+ self,
+ dialect: default.DefaultDialect,
+ server_ansiquotes: bool = False,
+ **kw: Any,
+ ):
if not server_ansiquotes:
quote = "`"
else:
super().__init__(dialect, initial_quote=quote, escape_quote=quote)
- def _quote_free_identifiers(self, *ids):
+ def _quote_free_identifiers(self, *ids: Optional[str]) -> tuple[str, ...]:
"""Unilaterally identifier-quote any number of strings."""
return tuple([self.quote_identifier(i) for i in ids if i is not None])
reserved_words = RESERVED_WORDS_MARIADB
-@log.class_logger
class MySQLDialect(default.DefaultDialect):
"""Details of the MySQL dialect.
Not used directly in application code.
ddl_compiler = MySQLDDLCompiler
type_compiler_cls = MySQLTypeCompiler
ischema_names = ischema_names
- preparer = MySQLIdentifierPreparer
+ preparer: type[MySQLIdentifierPreparer] = MySQLIdentifierPreparer
- is_mariadb = False
+ is_mariadb: bool = False
_mariadb_normalized_version_info = None
# default SQL compilation settings -
_backslash_escapes = True
_server_ansiquotes = False
+ server_version_info: tuple[int, ...]
+ identifier_preparer: MySQLIdentifierPreparer
+
construct_arguments = [
(sa_schema.Table, {"*": None}),
(sql.Update, {"limit": None}),
def __init__(
self,
- json_serializer=None,
- json_deserializer=None,
- is_mariadb=None,
- **kwargs,
- ):
+ json_serializer: Optional[Callable[..., Any]] = None,
+ json_deserializer: Optional[Callable[..., Any]] = None,
+ is_mariadb: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> None:
kwargs.pop("use_ansiquotes", None) # legacy
default.DefaultDialect.__init__(self, **kwargs)
self._json_serializer = json_serializer
self._json_deserializer = json_deserializer
- self._set_mariadb(is_mariadb, None)
+ self._set_mariadb(is_mariadb, ())
- def get_isolation_level_values(self, dbapi_conn):
+ def get_isolation_level_values(
+ self, dbapi_conn: DBAPIConnection
+ ) -> Sequence[IsolationLevel]:
return (
"SERIALIZABLE",
"READ UNCOMMITTED",
"REPEATABLE READ",
)
- def set_isolation_level(self, dbapi_connection, level):
+ def set_isolation_level(
+ self, dbapi_connection: DBAPIConnection, level: IsolationLevel
+ ) -> None:
cursor = dbapi_connection.cursor()
cursor.execute(f"SET SESSION TRANSACTION ISOLATION LEVEL {level}")
cursor.execute("COMMIT")
cursor.close()
- def get_isolation_level(self, dbapi_connection):
+ def get_isolation_level(
+ self, dbapi_connection: DBAPIConnection
+ ) -> IsolationLevel:
cursor = dbapi_connection.cursor()
if self._is_mysql and self.server_version_info >= (5, 7, 20):
cursor.execute("SELECT @@transaction_isolation")
cursor.close()
if isinstance(val, bytes):
val = val.decode()
- return val.upper().replace("-", " ")
+ return val.upper().replace("-", " ") # type: ignore[no-any-return]
@classmethod
- def _is_mariadb_from_url(cls, url):
+ def _is_mariadb_from_url(cls, url: URL) -> bool:
dbapi = cls.import_dbapi()
dialect = cls(dbapi=dbapi)
try:
cursor = conn.cursor()
cursor.execute("SELECT VERSION() LIKE '%MariaDB%'")
- val = cursor.fetchone()[0]
+ val = cursor.fetchone()[0] # type: ignore[index]
except:
raise
else:
finally:
conn.close()
- def _get_server_version_info(self, connection):
+ def _get_server_version_info(
+ self, connection: Connection
+ ) -> tuple[int, ...]:
# get database server version info explicitly over the wire
# to avoid proxy servers like MaxScale getting in the
# way with their own values, see #4205
dbapi_con = connection.connection
cursor = dbapi_con.cursor()
cursor.execute("SELECT VERSION()")
- val = cursor.fetchone()[0]
+
+ val = cursor.fetchone()[0] # type: ignore[index]
cursor.close()
if isinstance(val, bytes):
val = val.decode()
return self._parse_server_version(val)
- def _parse_server_version(self, val):
- version = []
+ def _parse_server_version(self, val: str) -> tuple[int, ...]:
+ version: list[int] = []
is_mariadb = False
r = re.compile(r"[.\-+]")
server_version_info = tuple(version)
self._set_mariadb(
- server_version_info and is_mariadb, server_version_info
+ bool(server_version_info and is_mariadb), server_version_info
)
if not is_mariadb:
self.server_version_info = server_version_info
return server_version_info
- def _set_mariadb(self, is_mariadb, server_version_info):
+ def _set_mariadb(
+ self, is_mariadb: Optional[bool], server_version_info: tuple[int, ...]
+ ) -> None:
if is_mariadb is None:
return
self.is_mariadb = is_mariadb
- def do_begin_twophase(self, connection, xid):
+ def do_begin_twophase(self, connection: Connection, xid: Any) -> None:
connection.execute(sql.text("XA BEGIN :xid"), dict(xid=xid))
- def do_prepare_twophase(self, connection, xid):
+ def do_prepare_twophase(self, connection: Connection, xid: Any) -> None:
connection.execute(sql.text("XA END :xid"), dict(xid=xid))
connection.execute(sql.text("XA PREPARE :xid"), dict(xid=xid))
def do_rollback_twophase(
- self, connection, xid, is_prepared=True, recover=False
- ):
+ self,
+ connection: Connection,
+ xid: Any,
+ is_prepared: bool = True,
+ recover: bool = False,
+ ) -> None:
if not is_prepared:
connection.execute(sql.text("XA END :xid"), dict(xid=xid))
connection.execute(sql.text("XA ROLLBACK :xid"), dict(xid=xid))
def do_commit_twophase(
- self, connection, xid, is_prepared=True, recover=False
- ):
+ self,
+ connection: Connection,
+ xid: Any,
+ is_prepared: bool = True,
+ recover: bool = False,
+ ) -> None:
if not is_prepared:
self.do_prepare_twophase(connection, xid)
connection.execute(sql.text("XA COMMIT :xid"), dict(xid=xid))
- def do_recover_twophase(self, connection):
+ def do_recover_twophase(self, connection: Connection) -> list[Any]:
resultset = connection.exec_driver_sql("XA RECOVER")
- return [row["data"][0 : row["gtrid_length"]] for row in resultset]
+ return [
+ row["data"][0 : row["gtrid_length"]]
+ for row in resultset.mappings()
+ ]
- def is_disconnect(self, e, connection, cursor):
+ def is_disconnect(
+ self,
+ e: DBAPIModule.Error,
+ connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
+ cursor: Optional[DBAPICursor],
+ ) -> bool:
if isinstance(
e,
(
- self.dbapi.OperationalError,
- self.dbapi.ProgrammingError,
- self.dbapi.InterfaceError,
+ self.dbapi.OperationalError, # type: ignore
+ self.dbapi.ProgrammingError, # type: ignore
+ self.dbapi.InterfaceError, # type: ignore
),
) and self._extract_error_code(e) in (
1927,
):
return True
elif isinstance(
- e, (self.dbapi.InterfaceError, self.dbapi.InternalError)
+ e, (self.dbapi.InterfaceError, self.dbapi.InternalError) # type: ignore # noqa: E501
):
# if underlying connection is closed,
# this is the error you get
else:
return False
- def _compat_fetchall(self, rp, charset=None):
+ def _compat_fetchall(
+ self, rp: CursorResult[Unpack[TupleAny]], charset: Optional[str] = None
+ ) -> Union[Sequence[Row[Unpack[TupleAny]]], Sequence[_DecodingRow]]:
"""Proxy result rows to smooth over MySQL-Python driver
inconsistencies."""
return [_DecodingRow(row, charset) for row in rp.fetchall()]
- def _compat_fetchone(self, rp, charset=None):
+ def _compat_fetchone(
+ self, rp: CursorResult[Unpack[TupleAny]], charset: Optional[str] = None
+ ) -> Union[Row[Unpack[TupleAny]], None, _DecodingRow]:
"""Proxy a result row to smooth over MySQL-Python driver
inconsistencies."""
else:
return None
- def _compat_first(self, rp, charset=None):
+ def _compat_first(
+ self, rp: CursorResult[Unpack[TupleAny]], charset: Optional[str] = None
+ ) -> Optional[_DecodingRow]:
"""Proxy a result row to smooth over MySQL-Python driver
inconsistencies."""
else:
return None
- def _extract_error_code(self, exception):
+ def _extract_error_code(
+ self, exception: DBAPIModule.Error
+ ) -> Optional[int]:
raise NotImplementedError()
- def _get_default_schema_name(self, connection):
- return connection.exec_driver_sql("SELECT DATABASE()").scalar()
+ def _get_default_schema_name(self, connection: Connection) -> str:
+ return connection.exec_driver_sql("SELECT DATABASE()").scalar() # type: ignore[return-value] # noqa: E501
@reflection.cache
- def has_table(self, connection, table_name, schema=None, **kw):
+ def has_table(
+ self,
+ connection: Connection,
+ table_name: str,
+ schema: Optional[str] = None,
+ **kw: Any,
+ ) -> bool:
self._ensure_has_table_connection(connection)
if schema is None:
#
# there's more "doesn't exist" kinds of messages but they are
# less clear if mysql 8 would suddenly start using one of those
- if self._extract_error_code(e.orig) in (1146, 1049, 1051):
+ if self._extract_error_code(e.orig) in (1146, 1049, 1051): # type: ignore # noqa: E501
return False
raise
@reflection.cache
- def has_sequence(self, connection, sequence_name, schema=None, **kw):
+ def has_sequence(
+ self,
+ connection: Connection,
+ sequence_name: str,
+ schema: Optional[str] = None,
+ **kw: Any,
+ ) -> bool:
if not self.supports_sequences:
self._sequences_not_supported()
if not schema:
)
return cursor.first() is not None
- def _sequences_not_supported(self):
+ def _sequences_not_supported(self) -> NoReturn:
raise NotImplementedError(
"Sequences are supported only by the "
"MariaDB series 10.3 or greater"
)
@reflection.cache
- def get_sequence_names(self, connection, schema=None, **kw):
+ def get_sequence_names(
+ self, connection: Connection, schema: Optional[str] = None, **kw: Any
+ ) -> list[str]:
if not self.supports_sequences:
self._sequences_not_supported()
if not schema:
)
]
- def initialize(self, connection):
+ def initialize(self, connection: Connection) -> None:
# this is driver-based, does not need server version info
# and is fairly critical for even basic SQL operations
- self._connection_charset = self._detect_charset(connection)
+ self._connection_charset: Optional[str] = self._detect_charset(
+ connection
+ )
# call super().initialize() because we need to have
# server_version_info set up. in 1.4 under python 2 only this does the
self._warn_for_known_db_issues()
- def _warn_for_known_db_issues(self):
+ def _warn_for_known_db_issues(self) -> None:
if self.is_mariadb:
mdb_version = self._mariadb_normalized_version_info
+ assert mdb_version is not None
if mdb_version > (10, 2) and mdb_version < (10, 2, 9):
util.warn(
"MariaDB %r before 10.2.9 has known issues regarding "
)
@property
- def _support_float_cast(self):
+ def _support_float_cast(self) -> bool:
if not self.server_version_info:
return False
elif self.is_mariadb:
return self.server_version_info >= (8, 0, 17)
@property
- def _support_default_function(self):
+ def _support_default_function(self) -> bool:
if not self.server_version_info:
return False
elif self.is_mariadb:
return self.server_version_info >= (8, 0, 13)
@property
- def _is_mariadb(self):
+ def _is_mariadb(self) -> bool:
return self.is_mariadb
@property
- def _is_mysql(self):
+ def _is_mysql(self) -> bool:
return not self.is_mariadb
@property
- def _is_mariadb_102(self):
- return self.is_mariadb and self._mariadb_normalized_version_info > (
- 10,
- 2,
+ def _is_mariadb_102(self) -> bool:
+ return (
+ self.is_mariadb
+ and self._mariadb_normalized_version_info # type:ignore[operator]
+ > (
+ 10,
+ 2,
+ )
)
@reflection.cache
- def get_schema_names(self, connection, **kw):
+ def get_schema_names(self, connection: Connection, **kw: Any) -> list[str]:
rp = connection.exec_driver_sql("SHOW schemas")
return [r[0] for r in rp]
@reflection.cache
- def get_table_names(self, connection, schema=None, **kw):
+ def get_table_names(
+ self, connection: Connection, schema: Optional[str] = None, **kw: Any
+ ) -> list[str]:
"""Return a Unicode SHOW TABLES from a given schema."""
if schema is not None:
- current_schema = schema
+ current_schema: str = schema
else:
- current_schema = self.default_schema_name
+ current_schema = self.default_schema_name # type: ignore
charset = self._connection_charset
]
@reflection.cache
- def get_view_names(self, connection, schema=None, **kw):
+ def get_view_names(
+ self, connection: Connection, schema: Optional[str] = None, **kw: Any
+ ) -> list[str]:
if schema is None:
schema = self.default_schema_name
+ assert schema is not None
charset = self._connection_charset
rp = connection.exec_driver_sql(
"SHOW FULL TABLES FROM %s"
]
@reflection.cache
- def get_table_options(self, connection, table_name, schema=None, **kw):
+ def get_table_options(
+ self,
+ connection: Connection,
+ table_name: str,
+ schema: Optional[str] = None,
+ **kw: Any,
+ ) -> dict[str, Any]:
parsed_state = self._parsed_state_or_create(
connection, table_name, schema, **kw
)
return ReflectionDefaults.table_options()
@reflection.cache
- def get_columns(self, connection, table_name, schema=None, **kw):
+ def get_columns(
+ self,
+ connection: Connection,
+ table_name: str,
+ schema: Optional[str] = None,
+ **kw: Any,
+ ) -> list[ReflectedColumn]:
parsed_state = self._parsed_state_or_create(
connection, table_name, schema, **kw
)
return ReflectionDefaults.columns()
@reflection.cache
- def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+ def get_pk_constraint(
+ self,
+ connection: Connection,
+ table_name: str,
+ schema: Optional[str] = None,
+ **kw: Any,
+ ) -> ReflectedPrimaryKeyConstraint:
parsed_state = self._parsed_state_or_create(
connection, table_name, schema, **kw
)
return ReflectionDefaults.pk_constraint()
@reflection.cache
- def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+ def get_foreign_keys(
+ self,
+ connection: Connection,
+ table_name: str,
+ schema: Optional[str] = None,
+ **kw: Any,
+ ) -> list[ReflectedForeignKeyConstraint]:
parsed_state = self._parsed_state_or_create(
connection, table_name, schema, **kw
)
default_schema = None
- fkeys = []
+ fkeys: list[ReflectedForeignKeyConstraint] = []
for spec in parsed_state.fk_constraints:
ref_name = spec["table"][-1]
if spec.get(opt, False) not in ("NO ACTION", None):
con_kw[opt] = spec[opt]
- fkey_d = {
+ fkey_d: ReflectedForeignKeyConstraint = {
"name": spec["name"],
"constrained_columns": loc_names,
"referred_schema": ref_schema,
return fkeys if fkeys else ReflectionDefaults.foreign_keys()
- def _correct_for_mysql_bugs_88718_96365(self, fkeys, connection):
+ def _correct_for_mysql_bugs_88718_96365(
+ self,
+ fkeys: list[ReflectedForeignKeyConstraint],
+ connection: Connection,
+ ) -> None:
# Foreign key is always in lower case (MySQL 8.0)
# https://bugs.mysql.com/bug.php?id=88718
# issue #4344 for SQLAlchemy
if self._casing in (1, 2):
- def lower(s):
+ def lower(s: str) -> str:
return s.lower()
else:
# if on case sensitive, there can be two tables referenced
# with the same name different casing, so we need to use
# case-sensitive matching.
- def lower(s):
+ def lower(s: str) -> str:
return s
- default_schema_name = connection.dialect.default_schema_name
+ default_schema_name: str = connection.dialect.default_schema_name # type: ignore # noqa: E501
# NOTE: using (table_schema, table_name, lower(column_name)) in (...)
# is very slow since mysql does not seem able to properly use indexes.
# Unpack the where condition instead.
- schema_by_table_by_column = defaultdict(lambda: defaultdict(list))
+ schema_by_table_by_column: defaultdict[
+ str, defaultdict[str, list[str]]
+ ] = defaultdict(lambda: defaultdict(list))
for rec in fkeys:
sch = lower(rec["referred_schema"] or default_schema_name)
tbl = lower(rec["referred_table"])
_info_columns.c.column_name,
).where(condition)
- correct_for_wrong_fk_case = connection.execute(select)
+ correct_for_wrong_fk_case: CursorResult[str, str, str] = (
+ connection.execute(select)
+ )
# in casing=0, table name and schema name come back in their
# exact case.
# SHOW CREATE TABLE converts them to *lower case*, therefore
# not matching. So for this case, case-insensitive lookup
# is necessary
- d = defaultdict(dict)
+ d: defaultdict[tuple[str, str], dict[str, str]] = defaultdict(dict)
for schema, tname, cname in correct_for_wrong_fk_case:
d[(lower(schema), lower(tname))]["SCHEMANAME"] = schema
d[(lower(schema), lower(tname))]["TABLENAME"] = tname
d[(lower(schema), lower(tname))][cname.lower()] = cname
for fkey in fkeys:
- rec = d[
+ rec_b = d[
(
lower(fkey["referred_schema"] or default_schema_name),
lower(fkey["referred_table"]),
)
]
- fkey["referred_table"] = rec["TABLENAME"]
+ fkey["referred_table"] = rec_b["TABLENAME"]
if fkey["referred_schema"] is not None:
- fkey["referred_schema"] = rec["SCHEMANAME"]
+ fkey["referred_schema"] = rec_b["SCHEMANAME"]
fkey["referred_columns"] = [
- rec[col.lower()] for col in fkey["referred_columns"]
+ rec_b[col.lower()] for col in fkey["referred_columns"]
]
@reflection.cache
- def get_check_constraints(self, connection, table_name, schema=None, **kw):
+ def get_check_constraints(
+ self,
+ connection: Connection,
+ table_name: str,
+ schema: Optional[str] = None,
+ **kw: Any,
+ ) -> list[ReflectedCheckConstraint]:
parsed_state = self._parsed_state_or_create(
connection, table_name, schema, **kw
)
- cks = [
+ cks: list[ReflectedCheckConstraint] = [
{"name": spec["name"], "sqltext": spec["sqltext"]}
for spec in parsed_state.ck_constraints
]
return cks if cks else ReflectionDefaults.check_constraints()
@reflection.cache
- def get_table_comment(self, connection, table_name, schema=None, **kw):
+ def get_table_comment(
+ self,
+ connection: Connection,
+ table_name: str,
+ schema: Optional[str] = None,
+ **kw: Any,
+ ) -> ReflectedTableComment:
parsed_state = self._parsed_state_or_create(
connection, table_name, schema, **kw
)
return ReflectionDefaults.table_comment()
@reflection.cache
- def get_indexes(self, connection, table_name, schema=None, **kw):
+ def get_indexes(
+ self,
+ connection: Connection,
+ table_name: str,
+ schema: Optional[str] = None,
+ **kw: Any,
+ ) -> list[ReflectedIndex]:
parsed_state = self._parsed_state_or_create(
connection, table_name, schema, **kw
)
- indexes = []
+ indexes: list[ReflectedIndex] = []
for spec in parsed_state.keys:
dialect_options = {}
unique = True
elif flavor in ("FULLTEXT", "SPATIAL"):
dialect_options["%s_prefix" % self.name] = flavor
- elif flavor is None:
- pass
- else:
- self.logger.info(
- "Converting unknown KEY type %s to a plain KEY", flavor
- )
- pass
+ elif flavor is not None:
+ # util.warn() is not a logger: its second positional parameter is a
+ # warning *code*, not a lazy %-interpolation argument, so the message
+ # must be formatted before the call.
+ util.warn(
+ "Converting unknown KEY type %s to a plain KEY" % flavor
+ )
if spec["parser"]:
dialect_options["%s_with_parser" % (self.name)] = spec[
"parser"
]
- index_d = {}
+ index_d: ReflectedIndex = {
+ "name": spec["name"],
+ "column_names": [s[0] for s in spec["columns"]],
+ "unique": unique,
+ }
- index_d["name"] = spec["name"]
- index_d["column_names"] = [s[0] for s in spec["columns"]]
mysql_length = {
s[0]: s[1] for s in spec["columns"] if s[1] is not None
}
if mysql_length:
dialect_options["%s_length" % self.name] = mysql_length
- index_d["unique"] = unique
if flavor:
- index_d["type"] = flavor
+ index_d["type"] = flavor # type: ignore[typeddict-unknown-key]
if dialect_options:
index_d["dialect_options"] = dialect_options
@reflection.cache
def get_unique_constraints(
- self, connection, table_name, schema=None, **kw
- ):
+ self,
+ connection: Connection,
+ table_name: str,
+ schema: Optional[str] = None,
+ **kw: Any,
+ ) -> list[ReflectedUniqueConstraint]:
parsed_state = self._parsed_state_or_create(
connection, table_name, schema, **kw
)
- ucs = [
+ ucs: list[ReflectedUniqueConstraint] = [
{
"name": key["name"],
"column_names": [col[0] for col in key["columns"]],
return ReflectionDefaults.unique_constraints()
@reflection.cache
- def get_view_definition(self, connection, view_name, schema=None, **kw):
+ def get_view_definition(
+ self,
+ connection: Connection,
+ view_name: str,
+ schema: Optional[str] = None,
+ **kw: Any,
+ ) -> str:
charset = self._connection_charset
full_name = ".".join(
self.identifier_preparer._quote_free_identifiers(schema, view_name)
return sql
def _parsed_state_or_create(
- self, connection, table_name, schema=None, **kw
- ):
+ self,
+ connection: Connection,
+ table_name: str,
+ schema: Optional[str] = None,
+ **kw: Any,
+ ) -> _reflection.ReflectedState:
return self._setup_parser(
connection,
table_name,
)
@util.memoized_property
- def _tabledef_parser(self):
+ def _tabledef_parser(self) -> _reflection.MySQLTableDefinitionParser:
"""return the MySQLTableDefinitionParser, generate if needed.
The deferred creation ensures that the dialect has
return _reflection.MySQLTableDefinitionParser(self, preparer)
@reflection.cache
- def _setup_parser(self, connection, table_name, schema=None, **kw):
+ def _setup_parser(
+ self,
+ connection: Connection,
+ table_name: str,
+ schema: Optional[str] = None,
+ **kw: Any,
+ ) -> _reflection.ReflectedState:
charset = self._connection_charset
parser = self._tabledef_parser
full_name = ".".join(
columns = self._describe_table(
connection, None, charset, full_name=full_name
)
- sql = parser._describe_to_create(table_name, columns)
+ sql = parser._describe_to_create(
+ table_name, columns # type: ignore[arg-type]
+ )
return parser.parse(sql, charset)
- def _fetch_setting(self, connection, setting_name):
+ def _fetch_setting(
+ self, connection: Connection, setting_name: str
+ ) -> Optional[str]:
charset = self._connection_charset
if self.server_version_info and self.server_version_info < (5, 6):
if not row:
return None
else:
- return row[fetch_col]
+ return cast("Optional[str]", row[fetch_col])
- def _detect_charset(self, connection):
+ def _detect_charset(self, connection: Connection) -> str:
raise NotImplementedError()
- def _detect_casing(self, connection):
+ def _detect_casing(self, connection: Connection) -> int:
"""Sniff out identifier case sensitivity.
Cached per-connection. This value can not change without a server
self._casing = cs
return cs
- def _detect_collations(self, connection):
+ def _detect_collations(self, connection: Connection) -> dict[str, str]:
"""Pull the active COLLATIONS list from the server.
Cached per-connection.
collations[row[0]] = row[1]
return collations
- def _detect_sql_mode(self, connection):
+ def _detect_sql_mode(self, connection: Connection) -> None:
setting = self._fetch_setting(connection, "sql_mode")
if setting is None:
else:
self._sql_mode = setting or ""
- def _detect_ansiquotes(self, connection):
+ def _detect_ansiquotes(self, connection: Connection) -> None:
"""Detect and adjust for the ANSI_QUOTES sql mode."""
mode = self._sql_mode
# as of MySQL 5.0.1
self._backslash_escapes = "NO_BACKSLASH_ESCAPES" not in mode
+ @overload
def _show_create_table(
- self, connection, table, charset=None, full_name=None
- ):
+ self,
+ connection: Connection,
+ table: Optional[Table],
+ charset: Optional[str],
+ full_name: str,
+ ) -> str: ...
+
+ @overload
+ def _show_create_table(
+ self,
+ connection: Connection,
+ table: Table,
+ charset: Optional[str] = None,
+ full_name: None = None,
+ ) -> str: ...
+
+ def _show_create_table(
+ self,
+ connection: Connection,
+ table: Optional[Table],
+ charset: Optional[str] = None,
+ full_name: Optional[str] = None,
+ ) -> str:
"""Run SHOW CREATE TABLE for a ``Table``."""
if full_name is None:
+ assert table is not None
full_name = self.identifier_preparer.format_table(table)
st = "SHOW CREATE TABLE %s" % full_name
skip_user_error_events=True
).exec_driver_sql(st)
except exc.DBAPIError as e:
- if self._extract_error_code(e.orig) == 1146:
+ if self._extract_error_code(e.orig) == 1146: # type: ignore[arg-type] # noqa: E501
raise exc.NoSuchTableError(full_name) from e
else:
raise
row = self._compat_first(rp, charset=charset)
if not row:
raise exc.NoSuchTableError(full_name)
- return row[1].strip()
+ return cast("str", row[1]).strip()
+
+ @overload
+ def _describe_table(
+ self,
+ connection: Connection,
+ table: Optional[Table],
+ charset: Optional[str],
+ full_name: str,
+ ) -> Union[Sequence[Row[Unpack[TupleAny]]], Sequence[_DecodingRow]]: ...
+
+ @overload
+ def _describe_table(
+ self,
+ connection: Connection,
+ table: Table,
+ charset: Optional[str] = None,
+ full_name: None = None,
+ ) -> Union[Sequence[Row[Unpack[TupleAny]]], Sequence[_DecodingRow]]: ...
- def _describe_table(self, connection, table, charset=None, full_name=None):
+ def _describe_table(
+ self,
+ connection: Connection,
+ table: Optional[Table],
+ charset: Optional[str] = None,
+ full_name: Optional[str] = None,
+ ) -> Union[Sequence[Row[Unpack[TupleAny]]], Sequence[_DecodingRow]]:
"""Run DESCRIBE for a ``Table`` and return processed rows."""
if full_name is None:
+ assert table is not None
full_name = self.identifier_preparer.format_table(table)
st = "DESCRIBE %s" % full_name
skip_user_error_events=True
).exec_driver_sql(st)
except exc.DBAPIError as e:
- code = self._extract_error_code(e.orig)
+ code = self._extract_error_code(e.orig) # type: ignore[arg-type] # noqa: E501
if code == 1146:
raise exc.NoSuchTableError(full_name) from e
# sets.Set(['value']) (seriously) but thankfully that doesn't
# seem to come up in DDL queries.
- _encoding_compat = {
+ _encoding_compat: dict[str, str] = {
"koi8r": "koi8_r",
"koi8u": "koi8_u",
    "utf16": "utf-16-be",  # MySQL's utf16 is always big-endian
"eucjpms": "ujis",
}
- def __init__(self, rowproxy, charset):
+ def __init__(self, rowproxy: Row[Unpack[_Ts]], charset: Optional[str]):
self.rowproxy = rowproxy
- self.charset = self._encoding_compat.get(charset, charset)
+ self.charset = (
+ self._encoding_compat.get(charset, charset)
+ if charset is not None
+ else None
+ )
- def __getitem__(self, index):
+ def __getitem__(self, index: int) -> Any:
item = self.rowproxy[index]
- if isinstance(item, _array):
- item = item.tostring()
-
if self.charset and isinstance(item, bytes):
return item.decode(self.charset)
else:
return item
- def __getattr__(self, attr):
+ def __getattr__(self, attr: str) -> Any:
item = getattr(self.rowproxy, attr)
- if isinstance(item, _array):
- item = item.tostring()
if self.charset and isinstance(item, bytes):
return item.decode(self.charset)
else:
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
r"""
dialects are mysqlclient and PyMySQL.
""" # noqa
+from __future__ import annotations
+
+from typing import Any
+from typing import Iterable
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
-from .base import BIT
from .base import MySQLDialect
from .mysqldb import MySQLDialect_mysqldb
+from .types import BIT
from ... import util
+if TYPE_CHECKING:
+ from ...engine.base import Connection
+ from ...engine.interfaces import DBAPIConnection
+ from ...engine.interfaces import DBAPICursor
+ from ...engine.interfaces import DBAPIModule
+ from ...engine.interfaces import Dialect
+ from ...engine.interfaces import PoolProxiedConnection
+ from ...sql.type_api import _ResultProcessorType
+
class _cymysqlBIT(BIT):
- def result_processor(self, dialect, coltype):
+ def result_processor(
+ self, dialect: Dialect, coltype: object
+ ) -> Optional[_ResultProcessorType[Any]]:
"""Convert MySQL's 64 bit, variable length binary string to a long."""
- def process(value):
+ def process(value: Optional[Iterable[int]]) -> Optional[int]:
if value is not None:
v = 0
for i in iter(value):
colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _cymysqlBIT})
@classmethod
- def import_dbapi(cls):
+ def import_dbapi(cls) -> DBAPIModule:
return __import__("cymysql")
- def _detect_charset(self, connection):
- return connection.connection.charset
+ def _detect_charset(self, connection: Connection) -> str:
+ return connection.connection.charset # type: ignore[no-any-return]
- def _extract_error_code(self, exception):
- return exception.errno
+ def _extract_error_code(self, exception: DBAPIModule.Error) -> int:
+ return exception.errno # type: ignore[no-any-return]
- def is_disconnect(self, e, connection, cursor):
- if isinstance(e, self.dbapi.OperationalError):
+ def is_disconnect(
+ self,
+ e: DBAPIModule.Error,
+ connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
+ cursor: Optional[DBAPICursor],
+ ) -> bool:
+ if isinstance(e, self.loaded_dbapi.OperationalError):
return self._extract_error_code(e) in (
2006,
2013,
2045,
2055,
)
- elif isinstance(e, self.dbapi.InterfaceError):
+ elif isinstance(e, self.loaded_dbapi.InterfaceError):
# if underlying connection is closed,
# this is the error you get
return True
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+from __future__ import annotations
+import enum
import re
+from typing import Any
+from typing import Optional
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
from .types import _StringType
from ... import exc
from ... import sql
from ... import util
from ...sql import sqltypes
+from ...sql import type_api
+if TYPE_CHECKING:
+ from ...engine.interfaces import Dialect
+ from ...sql.elements import ColumnElement
+ from ...sql.type_api import _BindProcessorType
+ from ...sql.type_api import _ResultProcessorType
+ from ...sql.type_api import TypeEngine
+ from ...sql.type_api import TypeEngineMixin
-class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType):
+
+class ENUM(type_api.NativeForEmulated, sqltypes.Enum, _StringType):
"""MySQL ENUM type."""
__visit_name__ = "ENUM"
native_enum = True
- def __init__(self, *enums, **kw):
+ def __init__(self, *enums: Union[str, Type[enum.Enum]], **kw: Any) -> None:
"""Construct an ENUM.
E.g.::
"""
kw.pop("strict", None)
- self._enum_init(enums, kw)
+ self._enum_init(enums, kw) # type: ignore[arg-type]
_StringType.__init__(self, length=self.length, **kw)
@classmethod
- def adapt_emulated_to_native(cls, impl, **kw):
+ def adapt_emulated_to_native(
+ cls,
+ impl: Union[TypeEngine[Any], TypeEngineMixin],
+ **kw: Any,
+ ) -> ENUM:
"""Produce a MySQL native :class:`.mysql.ENUM` from plain
:class:`.Enum`.
"""
+ if TYPE_CHECKING:
+ assert isinstance(impl, ENUM)
kw.setdefault("validate_strings", impl.validate_strings)
kw.setdefault("values_callable", impl.values_callable)
kw.setdefault("omit_aliases", impl._omit_aliases)
return cls(**kw)
- def _object_value_for_elem(self, elem):
+ def _object_value_for_elem(self, elem: str) -> Union[str, enum.Enum]:
# mysql sends back a blank string for any value that
# was persisted that was not in the enums; that is, it does no
# validation on the incoming data, it "truncates" it to be
else:
return super()._object_value_for_elem(elem)
- def __repr__(self):
+ def __repr__(self) -> str:
return util.generic_repr(
self, to_inspect=[ENUM, _StringType, sqltypes.Enum]
)
+# TODO: SET is a string as far as configuration but does not act like
+# a string at the python level. We either need to make a py-type agnostic
+# version of String as a base to be used for this, make this some kind of
+# TypeDecorator, or just vendor it out as its own type.
class SET(_StringType):
"""MySQL SET type."""
__visit_name__ = "SET"
- def __init__(self, *values, **kw):
+ def __init__(self, *values: str, **kw: Any):
"""Construct a SET.
E.g.::
"setting retrieve_as_bitwise=True"
)
if self.retrieve_as_bitwise:
- self._bitmap = {
+ self._inversed_bitmap: dict[str, int] = {
value: 2**idx for idx, value in enumerate(self.values)
}
- self._bitmap.update(
- (2**idx, value) for idx, value in enumerate(self.values)
- )
+ self._bitmap: dict[int, str] = {
+ 2**idx: value for idx, value in enumerate(self.values)
+ }
length = max([len(v) for v in values] + [0])
kw.setdefault("length", length)
super().__init__(**kw)
- def column_expression(self, colexpr):
+ def column_expression(
+ self, colexpr: ColumnElement[Any]
+ ) -> ColumnElement[Any]:
if self.retrieve_as_bitwise:
return sql.type_coerce(
sql.type_coerce(colexpr, sqltypes.Integer) + 0, self
else:
return colexpr
- def result_processor(self, dialect, coltype):
+ def result_processor(
+ self, dialect: Dialect, coltype: Any
+ ) -> Optional[_ResultProcessorType[Any]]:
if self.retrieve_as_bitwise:
- def process(value):
+ def process(value: Union[str, int, None]) -> Optional[set[str]]:
if value is not None:
value = int(value)
else:
super_convert = super().result_processor(dialect, coltype)
- def process(value):
+ def process(value: Union[str, set[str], None]) -> Optional[set[str]]: # type: ignore[misc] # noqa: E501
if isinstance(value, str):
# MySQLdb returns a string, let's parse
if super_convert:
value = super_convert(value)
+ assert value is not None
+ if TYPE_CHECKING:
+ assert isinstance(value, str)
return set(re.findall(r"[^,]+", value))
else:
# mysql-connector-python does a naive
return process
- def bind_processor(self, dialect):
+ def bind_processor(
+ self, dialect: Dialect
+ ) -> _BindProcessorType[Union[str, int]]:
super_convert = super().bind_processor(dialect)
if self.retrieve_as_bitwise:
- def process(value):
+ def process(
+ value: Union[str, int, set[str], None],
+ ) -> Union[str, int, None]:
if value is None:
return None
elif isinstance(value, (int, str)):
if super_convert:
- return super_convert(value)
+ return super_convert(value) # type: ignore[arg-type, no-any-return] # noqa: E501
else:
return value
else:
int_value = 0
for v in value:
- int_value |= self._bitmap[v]
+ int_value |= self._inversed_bitmap[v]
return int_value
else:
- def process(value):
+ def process(
+ value: Union[str, int, set[str], None],
+ ) -> Union[str, int, None]:
# accept strings and int (actually bitflag) values directly
if value is not None and not isinstance(value, (int, str)):
value = ",".join(value)
-
if super_convert:
- return super_convert(value)
+ return super_convert(value) # type: ignore
else:
return value
return process
- def adapt(self, impltype, **kw):
+ def adapt(self, cls: type, **kw: Any) -> Any:
kw["retrieve_as_bitwise"] = self.retrieve_as_bitwise
- return util.constructor_copy(self, impltype, *self.values, **kw)
+ return util.constructor_copy(self, cls, *self.values, **kw)
- def __repr__(self):
+ def __repr__(self) -> str:
return util.generic_repr(
self,
to_inspect=[SET, _StringType],
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+from __future__ import annotations
+
+from typing import Any
from ... import exc
from ... import util
from ...util.typing import Self
-class match(Generative, elements.BinaryExpression):
+class match(Generative, elements.BinaryExpression[Any]):
"""Produce a ``MATCH (X, Y) AGAINST ('TEXT')`` clause.
E.g.::
__visit_name__ = "mysql_match"
inherit_cache = True
+ modifiers: util.immutabledict[str, Any]
- def __init__(self, *cols, **kw):
+ def __init__(self, *cols: elements.ColumnElement[Any], **kw: Any):
if not cols:
raise exc.ArgumentError("columns are required")
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+from __future__ import annotations
+
+from typing import Any
+from typing import TYPE_CHECKING
from ... import types as sqltypes
+if TYPE_CHECKING:
+ from ...engine.interfaces import Dialect
+ from ...sql.type_api import _BindProcessorType
+ from ...sql.type_api import _LiteralProcessorType
+
class JSON(sqltypes.JSON):
"""MySQL JSON type.
class _FormatTypeMixin:
- def _format_value(self, value):
+ def _format_value(self, value: Any) -> str:
raise NotImplementedError()
- def bind_processor(self, dialect):
- super_proc = self.string_bind_processor(dialect)
+ def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]:
+ super_proc = self.string_bind_processor(dialect) # type: ignore[attr-defined] # noqa: E501
- def process(value):
+ def process(value: Any) -> Any:
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return process
- def literal_processor(self, dialect):
- super_proc = self.string_literal_processor(dialect)
+ def literal_processor(
+ self, dialect: Dialect
+ ) -> _LiteralProcessorType[Any]:
+ super_proc = self.string_literal_processor(dialect) # type: ignore[attr-defined] # noqa: E501
- def process(value):
+ def process(value: Any) -> str:
value = self._format_value(value)
if super_proc:
value = super_proc(value)
- return value
+ return value # type: ignore[no-any-return]
return process
class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
- def _format_value(self, value):
+ def _format_value(self, value: Any) -> str:
if isinstance(value, int):
- value = "$[%s]" % value
+ formatted_value = "$[%s]" % value
else:
- value = '$."%s"' % value
- return value
+ formatted_value = '$."%s"' % value
+ return formatted_value
class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
- def _format_value(self, value):
+ def _format_value(self, value: Any) -> str:
return "$%s" % (
"".join(
[
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+
+from __future__ import annotations
+
+from typing import Any
+from typing import Callable
+from typing import Optional
+from typing import TYPE_CHECKING
+
from .base import MariaDBIdentifierPreparer
from .base import MySQLDialect
+from .base import MySQLIdentifierPreparer
from .base import MySQLTypeCompiler
from ... import util
from ...sql import sqltypes
+from ...sql.sqltypes import _UUID_RETURN
from ...sql.sqltypes import UUID
from ...sql.sqltypes import Uuid
+if TYPE_CHECKING:
+ from ...engine.base import Connection
+ from ...sql.type_api import _BindProcessorType
+
class INET4(sqltypes.TypeEngine[str]):
"""INET4 column type for MariaDB
__visit_name__ = "INET6"
-class _MariaDBUUID(UUID):
+class _MariaDBUUID(UUID[_UUID_RETURN]):
def __init__(self, as_uuid: bool = True, native_uuid: bool = True):
self.as_uuid = as_uuid
self.native_uuid = False
@property
- def native(self):
+ def native(self) -> bool: # type: ignore[override]
# override to return True, this is a native type, just turning
# off native_uuid for internal data handling
return True
- def bind_processor(self, dialect):
+ def bind_processor(self, dialect: MariaDBDialect) -> Optional[_BindProcessorType[_UUID_RETURN]]: # type: ignore[override] # noqa: E501
if not dialect.supports_native_uuid or not dialect._allows_uuid_binds:
- return super().bind_processor(dialect)
+ return super().bind_processor(dialect) # type: ignore[return-value] # noqa: E501
else:
return None
class MariaDBTypeCompiler(MySQLTypeCompiler):
- def visit_INET4(self, type_, **kwargs) -> str:
+ def visit_INET4(self, type_: INET4, **kwargs: Any) -> str:
return "INET4"
- def visit_INET6(self, type_, **kwargs) -> str:
+ def visit_INET6(self, type_: INET6, **kwargs: Any) -> str:
return "INET6"
_allows_uuid_binds = True
name = "mariadb"
- preparer = MariaDBIdentifierPreparer
+ preparer: type[MySQLIdentifierPreparer] = MariaDBIdentifierPreparer
type_compiler_cls = MariaDBTypeCompiler
colspecs = util.update_copy(MySQLDialect.colspecs, {Uuid: _MariaDBUUID})
- def initialize(self, connection):
+ def initialize(self, connection: Connection) -> None:
super().initialize(connection)
self.supports_native_uuid = (
)
-def loader(driver):
+def loader(driver: str) -> Callable[[], type[MariaDBDialect]]:
dialect_mod = __import__(
"sqlalchemy.dialects.mysql.%s" % driver
).dialects.mysql
driver_mod = getattr(dialect_mod, driver)
if hasattr(driver_mod, "mariadb_dialect"):
driver_cls = driver_mod.mariadb_dialect
- return driver_cls
+ return driver_cls # type: ignore[no-any-return]
else:
driver_cls = driver_mod.dialect
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
-
"""
.. mariadb: https://github.com/mariadb-corporation/mariadb-connector-python
""" # noqa
+from __future__ import annotations
+
import re
+from typing import Any
+from typing import Optional
+from typing import Sequence
+from typing import TYPE_CHECKING
+from typing import Union
from uuid import UUID as _python_UUID
from .base import MySQLCompiler
from ... import util
from ...sql import sqltypes
+if TYPE_CHECKING:
+ from ...engine.base import Connection
+ from ...engine.interfaces import ConnectArgsType
+ from ...engine.interfaces import DBAPIConnection
+ from ...engine.interfaces import DBAPICursor
+ from ...engine.interfaces import DBAPIModule
+ from ...engine.interfaces import Dialect
+ from ...engine.interfaces import IsolationLevel
+ from ...engine.interfaces import PoolProxiedConnection
+ from ...engine.url import URL
+ from ...sql.compiler import SQLCompiler
+ from ...sql.type_api import _ResultProcessorType
+
mariadb_cpy_minimum_version = (1, 0, 1)
# work around JIRA issue
# https://jira.mariadb.org/browse/CONPY-270. When that issue is fixed,
# this type can be removed.
- def result_processor(self, dialect, coltype):
+ def result_processor(
+ self, dialect: Dialect, coltype: object
+ ) -> Optional[_ResultProcessorType[Any]]:
if self.as_uuid:
- def process(value):
+ def process(value: Any) -> Any:
if value is not None:
if hasattr(value, "decode"):
value = value.decode("ascii")
return process
else:
- def process(value):
+ def process(value: Any) -> Any:
if value is not None:
if hasattr(value, "decode"):
value = value.decode("ascii")
class MySQLExecutionContext_mariadbconnector(MySQLExecutionContext):
- _lastrowid = None
+ _lastrowid: Optional[int] = None
- def create_server_side_cursor(self):
+ def create_server_side_cursor(self) -> DBAPICursor:
return self._dbapi_connection.cursor(buffered=False)
- def create_default_cursor(self):
+ def create_default_cursor(self) -> DBAPICursor:
return self._dbapi_connection.cursor(buffered=True)
- def post_exec(self):
+ def post_exec(self) -> None:
super().post_exec()
self._rowcount = self.cursor.rowcount
+ if TYPE_CHECKING:
+ assert isinstance(self.compiled, SQLCompiler)
if self.isinsert and self.compiled.postfetch_lastrowid:
self._lastrowid = self.cursor.lastrowid
- def get_lastrowid(self):
+ def get_lastrowid(self) -> int:
+ if TYPE_CHECKING:
+ assert self._lastrowid is not None
return self._lastrowid
)
@util.memoized_property
- def _dbapi_version(self):
+ def _dbapi_version(self) -> tuple[int, ...]:
if self.dbapi and hasattr(self.dbapi, "__version__"):
return tuple(
[
else:
return (99, 99, 99)
- def __init__(self, **kwargs):
+ def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.paramstyle = "qmark"
if self.dbapi is not None:
)
@classmethod
- def import_dbapi(cls):
+ def import_dbapi(cls) -> DBAPIModule:
return __import__("mariadb")
- def is_disconnect(self, e, connection, cursor):
+ def is_disconnect(
+ self,
+ e: DBAPIModule.Error,
+ connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
+ cursor: Optional[DBAPICursor],
+ ) -> bool:
if super().is_disconnect(e, connection, cursor):
return True
- elif isinstance(e, self.dbapi.Error):
+ elif isinstance(e, self.loaded_dbapi.Error):
str_e = str(e).lower()
return "not connected" in str_e or "isn't valid" in str_e
else:
return False
- def create_connect_args(self, url):
+ def create_connect_args(self, url: URL) -> ConnectArgsType:
opts = url.translate_connect_args()
opts.update(url.query)
except (AttributeError, ImportError):
self.supports_sane_rowcount = False
opts["client_flag"] = client_flag
- return [[], opts]
+ return [], opts
- def _extract_error_code(self, exception):
+ def _extract_error_code(self, exception: DBAPIModule.Error) -> int:
try:
- rc = exception.errno
+ rc: int = exception.errno
except:
rc = -1
return rc
- def _detect_charset(self, connection):
+ def _detect_charset(self, connection: Connection) -> str:
return "utf8mb4"
- def get_isolation_level_values(self, dbapi_connection):
+ def get_isolation_level_values(
+ self, dbapi_conn: DBAPIConnection
+ ) -> Sequence[IsolationLevel]:
return (
"SERIALIZABLE",
"READ UNCOMMITTED",
"AUTOCOMMIT",
)
- def set_isolation_level(self, connection, level):
+ def set_isolation_level(
+ self, dbapi_connection: DBAPIConnection, level: IsolationLevel
+ ) -> None:
if level == "AUTOCOMMIT":
- connection.autocommit = True
+ dbapi_connection.autocommit = True
else:
- connection.autocommit = False
- super().set_isolation_level(connection, level)
+ dbapi_connection.autocommit = False
+ super().set_isolation_level(dbapi_connection, level)
- def do_begin_twophase(self, connection, xid):
+ def do_begin_twophase(self, connection: Connection, xid: Any) -> None:
connection.execute(
sql.text("XA BEGIN :xid").bindparams(
sql.bindparam("xid", xid, literal_execute=True)
)
)
- def do_prepare_twophase(self, connection, xid):
+ def do_prepare_twophase(self, connection: Connection, xid: Any) -> None:
connection.execute(
sql.text("XA END :xid").bindparams(
sql.bindparam("xid", xid, literal_execute=True)
)
def do_rollback_twophase(
- self, connection, xid, is_prepared=True, recover=False
- ):
+ self,
+ connection: Connection,
+ xid: Any,
+ is_prepared: bool = True,
+ recover: bool = False,
+ ) -> None:
if not is_prepared:
connection.execute(
sql.text("XA END :xid").bindparams(
)
def do_commit_twophase(
- self, connection, xid, is_prepared=True, recover=False
- ):
+ self,
+ connection: Connection,
+ xid: Any,
+ is_prepared: bool = True,
+ recover: bool = False,
+ ) -> None:
if not is_prepared:
self.do_prepare_twophase(connection, xid)
connection.execute(
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
r"""
""" # noqa
+from __future__ import annotations
import re
+from typing import Any
+from typing import cast
+from typing import Optional
+from typing import Sequence
+from typing import TYPE_CHECKING
+from typing import Union
-from .base import BIT
from .base import MariaDBIdentifierPreparer
from .base import MySQLCompiler
from .base import MySQLDialect
from .base import MySQLExecutionContext
from .base import MySQLIdentifierPreparer
from .mariadb import MariaDBDialect
+from .types import BIT
from ... import util
+if TYPE_CHECKING:
+
+ from ...engine.base import Connection
+ from ...engine.cursor import CursorResult
+ from ...engine.interfaces import ConnectArgsType
+ from ...engine.interfaces import DBAPIConnection
+ from ...engine.interfaces import DBAPICursor
+ from ...engine.interfaces import DBAPIModule
+ from ...engine.interfaces import IsolationLevel
+ from ...engine.interfaces import PoolProxiedConnection
+ from ...engine.row import Row
+ from ...engine.url import URL
+ from ...sql.elements import BinaryExpression
+ from ...util.typing import TupleAny
+ from ...util.typing import Unpack
+
class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext):
- def create_server_side_cursor(self):
+ def create_server_side_cursor(self) -> DBAPICursor:
return self._dbapi_connection.cursor(buffered=False)
- def create_default_cursor(self):
+ def create_default_cursor(self) -> DBAPICursor:
return self._dbapi_connection.cursor(buffered=True)
class MySQLCompiler_mysqlconnector(MySQLCompiler):
- def visit_mod_binary(self, binary, operator, **kw):
+ def visit_mod_binary(
+ self, binary: BinaryExpression[Any], operator: Any, **kw: Any
+ ) -> str:
return (
self.process(binary.left, **kw)
+ " % "
class IdentifierPreparerCommon_mysqlconnector:
@property
- def _double_percents(self):
+ def _double_percents(self) -> bool:
return False
@_double_percents.setter
- def _double_percents(self, value):
+ def _double_percents(self, value: Any) -> None:
pass
- def _escape_identifier(self, value):
- value = value.replace(self.escape_quote, self.escape_to_quote)
+ def _escape_identifier(self, value: str) -> str:
+ value = value.replace(
+ self.escape_quote, # type:ignore[attr-defined]
+ self.escape_to_quote, # type:ignore[attr-defined]
+ )
return value
-class MySQLIdentifierPreparer_mysqlconnector(
+class MySQLIdentifierPreparer_mysqlconnector( # type:ignore[misc]
IdentifierPreparerCommon_mysqlconnector, MySQLIdentifierPreparer
):
pass
-class MariaDBIdentifierPreparer_mysqlconnector(
+class MariaDBIdentifierPreparer_mysqlconnector( # type:ignore[misc]
IdentifierPreparerCommon_mysqlconnector, MariaDBIdentifierPreparer
):
pass
class _myconnpyBIT(BIT):
- def result_processor(self, dialect, coltype):
+ def result_processor(self, dialect: Any, coltype: Any) -> None:
"""MySQL-connector already converts mysql bits, so."""
return None
execution_ctx_cls = MySQLExecutionContext_mysqlconnector
- preparer = MySQLIdentifierPreparer_mysqlconnector
+ preparer: type[MySQLIdentifierPreparer] = (
+ MySQLIdentifierPreparer_mysqlconnector
+ )
colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _myconnpyBIT})
@classmethod
- def import_dbapi(cls):
- from mysql import connector
+ def import_dbapi(cls) -> DBAPIModule:
+ return cast(DBAPIModule, __import__("mysql.connector").connector)
- return connector
-
- def do_ping(self, dbapi_connection):
+ def do_ping(self, dbapi_connection: DBAPIConnection) -> bool:
dbapi_connection.ping(False)
return True
- def create_connect_args(self, url):
+ def create_connect_args(self, url: URL) -> ConnectArgsType:
opts = url.translate_connect_args(username="user")
opts.update(url.query)
# supports_sane_rowcount.
if self.dbapi is not None:
try:
- from mysql.connector.constants import ClientFlag
+ from mysql.connector import constants # type: ignore
+
+ ClientFlag = constants.ClientFlag
client_flags = opts.get(
"client_flags", ClientFlag.get_default()
except Exception:
pass
- return [[], opts]
+ return [], opts
@util.memoized_property
- def _mysqlconnector_version_info(self):
+ def _mysqlconnector_version_info(self) -> Optional[tuple[int, ...]]:
if self.dbapi and hasattr(self.dbapi, "__version__"):
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__)
if m:
return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
+ return None
- def _detect_charset(self, connection):
- return connection.connection.charset
+ def _detect_charset(self, connection: Connection) -> str:
+ return connection.connection.charset # type: ignore
- def _extract_error_code(self, exception):
- return exception.errno
+ def _extract_error_code(self, exception: BaseException) -> int:
+ return exception.errno # type: ignore
- def is_disconnect(self, e, connection, cursor):
+ def is_disconnect(
+ self,
+ e: Exception,
+ connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
+ cursor: Optional[DBAPICursor],
+ ) -> bool:
errnos = (2006, 2013, 2014, 2045, 2055, 2048)
exceptions = (
- self.dbapi.OperationalError,
- self.dbapi.InterfaceError,
- self.dbapi.ProgrammingError,
+ self.loaded_dbapi.OperationalError, #
+ self.loaded_dbapi.InterfaceError,
+ self.loaded_dbapi.ProgrammingError,
)
if isinstance(e, exceptions):
return (
else:
return False
- def _compat_fetchall(self, rp, charset=None):
+ def _compat_fetchall(
+ self,
+ rp: CursorResult[Unpack[TupleAny]],
+ charset: Optional[str] = None,
+ ) -> Sequence[Row[Unpack[TupleAny]]]:
return rp.fetchall()
- def _compat_fetchone(self, rp, charset=None):
+ def _compat_fetchone(
+ self,
+ rp: CursorResult[Unpack[TupleAny]],
+ charset: Optional[str] = None,
+ ) -> Optional[Row[Unpack[TupleAny]]]:
return rp.fetchone()
- def get_isolation_level_values(self, dbapi_connection):
+ def get_isolation_level_values(
+ self, dbapi_conn: DBAPIConnection
+ ) -> Sequence[IsolationLevel]:
return (
"SERIALIZABLE",
"READ UNCOMMITTED",
"AUTOCOMMIT",
)
- def set_isolation_level(self, connection, level):
+ def set_isolation_level(
+ self, dbapi_connection: DBAPIConnection, level: IsolationLevel
+ ) -> None:
if level == "AUTOCOMMIT":
- connection.autocommit = True
+ dbapi_connection.autocommit = True
else:
- connection.autocommit = False
- super().set_isolation_level(connection, level)
+ dbapi_connection.autocommit = False
+ super().set_isolation_level(dbapi_connection, level)
class MariaDBDialect_mysqlconnector(
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
-
"""
The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`.
"""
+from __future__ import annotations
import re
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Literal
+from typing import Optional
+from typing import TYPE_CHECKING
from .base import MySQLCompiler
from .base import MySQLDialect
from .base import MySQLExecutionContext
from .base import MySQLIdentifierPreparer
-from .base import TEXT
-from ... import sql
from ... import util
+if TYPE_CHECKING:
+
+ from ...engine.base import Connection
+ from ...engine.interfaces import _DBAPIMultiExecuteParams
+ from ...engine.interfaces import ConnectArgsType
+ from ...engine.interfaces import DBAPIConnection
+ from ...engine.interfaces import DBAPICursor
+ from ...engine.interfaces import DBAPIModule
+ from ...engine.interfaces import ExecutionContext
+ from ...engine.interfaces import IsolationLevel
+ from ...engine.url import URL
+
class MySQLExecutionContext_mysqldb(MySQLExecutionContext):
    # placeholder subclass: mysqldb needs no behavior beyond the shared
    # MySQL execution context, but keeps a dialect-specific hook point
    pass
execution_ctx_cls = MySQLExecutionContext_mysqldb
statement_compiler = MySQLCompiler_mysqldb
preparer = MySQLIdentifierPreparer
+ server_version_info: tuple[int, ...]
- def __init__(self, **kwargs):
+ def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self._mysql_dbapi_version = (
self._parse_dbapi_version(self.dbapi.__version__)
else (0, 0, 0)
)
- def _parse_dbapi_version(self, version):
+ def _parse_dbapi_version(self, version: str) -> tuple[int, ...]:
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
if m:
return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
return (0, 0, 0)
@util.langhelpers.memoized_property
- def supports_server_side_cursors(self):
+ def supports_server_side_cursors(self) -> bool: # type: ignore[override]
try:
cursors = __import__("MySQLdb.cursors").cursors
self._sscursor = cursors.SSCursor
return False
@classmethod
def import_dbapi(cls) -> DBAPIModule:
    """Import and return the MySQLdb DBAPI module."""
    return __import__("MySQLdb")
- def on_connect(self):
+ def on_connect(self) -> Callable[[DBAPIConnection], None]:
super_ = super().on_connect()
- def on_connect(conn):
+ def on_connect(conn: DBAPIConnection) -> None:
if super_ is not None:
super_(conn)
return on_connect
def do_ping(self, dbapi_connection: DBAPIConnection) -> Literal[True]:
    """Ping the server; returns True, raising if the connection is dead."""
    dbapi_connection.ping()
    return True
def do_executemany(
    self,
    cursor: DBAPICursor,
    statement: str,
    parameters: _DBAPIMultiExecuteParams,
    context: Optional[ExecutionContext] = None,
) -> None:
    """executemany() override that records the driver-reported rowcount.

    MySQLdb returns an aggregate rowcount from executemany(), which is
    stored on the execution context when one is present.
    """
    rowcount = cursor.executemany(statement, parameters)
    if context is not None:
        # context here is always our MySQLExecutionContext subclass
        cast(MySQLExecutionContext, context)._rowcount = rowcount
- def create_connect_args(self, url, _translate_args=None):
+ def create_connect_args(
+ self, url: URL, _translate_args: Optional[dict[str, Any]] = None
+ ) -> ConnectArgsType:
if _translate_args is None:
_translate_args = dict(
database="db", username="user", password="passwd"
if client_flag_found_rows is not None:
client_flag |= client_flag_found_rows
opts["client_flag"] = client_flag
- return [[], opts]
+ return [], opts
- def _found_rows_client_flag(self):
+ def _found_rows_client_flag(self) -> Optional[int]:
if self.dbapi is not None:
try:
CLIENT_FLAGS = __import__(
except (AttributeError, ImportError):
return None
else:
- return CLIENT_FLAGS.FOUND_ROWS
+ return CLIENT_FLAGS.FOUND_ROWS # type: ignore
else:
return None
- def _extract_error_code(self, exception):
- return exception.args[0]
+ def _extract_error_code(self, exception: DBAPIModule.Error) -> int:
+ return exception.args[0] # type: ignore[no-any-return]
- def _detect_charset(self, connection):
+ def _detect_charset(self, connection: Connection) -> str:
"""Sniff out the character set in use for connection results."""
try:
# note: the SQL here would be
# "SHOW VARIABLES LIKE 'character_set%%'"
- cset_name = connection.connection.character_set_name
+
+ cset_name: Callable[[], str] = (
+ connection.connection.character_set_name
+ )
except AttributeError:
util.warn(
"No 'character_set_name' can be detected with "
else:
return cset_name()
- def get_isolation_level_values(self, dbapi_connection):
+ def get_isolation_level_values(
+ self, dbapi_conn: DBAPIConnection
+ ) -> tuple[IsolationLevel, ...]:
return (
"SERIALIZABLE",
"READ UNCOMMITTED",
"AUTOCOMMIT",
)
- def set_isolation_level(self, dbapi_connection, level):
+ def set_isolation_level(
+ self, dbapi_connection: DBAPIConnection, level: IsolationLevel
+ ) -> None:
if level == "AUTOCOMMIT":
dbapi_connection.autocommit(True)
else:
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
-
from ... import exc
from ...testing.provision import configure_follower
from ...testing.provision import create_db
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
-
r"""
to the pymysql driver as well.
""" # noqa
+from __future__ import annotations
+
+from typing import Any
+from typing import Literal
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
from .mysqldb import MySQLDialect_mysqldb
from ...util import langhelpers
+if TYPE_CHECKING:
+
+ from ...engine.interfaces import ConnectArgsType
+ from ...engine.interfaces import DBAPIConnection
+ from ...engine.interfaces import DBAPICursor
+ from ...engine.interfaces import DBAPIModule
+ from ...engine.interfaces import PoolProxiedConnection
+ from ...engine.url import URL
+
class MySQLDialect_pymysql(MySQLDialect_mysqldb):
driver = "pymysql"
description_encoding = None
@langhelpers.memoized_property
- def supports_server_side_cursors(self):
+ def supports_server_side_cursors(self) -> bool: # type: ignore[override]
try:
cursors = __import__("pymysql.cursors").cursors
self._sscursor = cursors.SSCursor
return False
@classmethod
def import_dbapi(cls) -> DBAPIModule:
    """Import and return the pymysql DBAPI module."""
    return __import__("pymysql")
@langhelpers.memoized_property
- def _send_false_to_ping(self):
+ def _send_false_to_ping(self) -> bool:
"""determine if pymysql has deprecated, changed the default of,
or removed the 'reconnect' argument of connection.ping().
not insp.defaults or insp.defaults[0] is not False
)
- def do_ping(self, dbapi_connection):
+ def do_ping(self, dbapi_connection: DBAPIConnection) -> Literal[True]: # type: ignore # noqa: E501
if self._send_false_to_ping:
dbapi_connection.ping(False)
else:
return True
def create_connect_args(
    self, url: URL, _translate_args: Optional[dict[str, Any]] = None
) -> ConnectArgsType:
    """Build DBAPI connect arguments from the URL.

    pymysql spells the username parameter ``user``; subclasses may pass
    an alternate translation map via ``_translate_args``.
    """
    translate = (
        {"username": "user"} if _translate_args is None else _translate_args
    )
    return super().create_connect_args(url, _translate_args=translate)
- def is_disconnect(self, e, connection, cursor):
+ def is_disconnect(
+ self,
+ e: DBAPIModule.Error,
+ connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
+ cursor: Optional[DBAPICursor],
+ ) -> bool:
if super().is_disconnect(e, connection, cursor):
return True
- elif isinstance(e, self.dbapi.Error):
+ elif isinstance(e, self.loaded_dbapi.Error):
str_e = str(e).lower()
return (
"already closed" in str_e or "connection was killed" in str_e
else:
return False
- def _extract_error_code(self, exception):
+ def _extract_error_code(self, exception: BaseException) -> Any:
if isinstance(exception.args[0], Exception):
exception = exception.args[0]
return exception.args[0]
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
r"""
-
.. dialect:: mysql+pyodbc
:name: PyODBC
:dbapi: pyodbc
connection_uri = "mysql+pyodbc:///?odbc_connect=%s" % params
""" # noqa
+from __future__ import annotations
+import datetime
import re
+from typing import Any
+from typing import Callable
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
from .base import MySQLDialect
from .base import MySQLExecutionContext
from ...connectors.pyodbc import PyODBCConnector
from ...sql.sqltypes import Time
+if TYPE_CHECKING:
+ from ...engine import Connection
+ from ...engine.interfaces import DBAPIConnection
+ from ...engine.interfaces import Dialect
+ from ...sql.type_api import _ResultProcessorType
+
class _pyodbcTIME(TIME):
    def result_processor(
        self, dialect: Dialect, coltype: object
    ) -> _ResultProcessorType[datetime.time]:
        """Return a pass-through result processor.

        pyodbc already yields datetime.time objects for TIME columns,
        so no conversion is performed.
        """

        def process(value: Any) -> Union[datetime.time, None]:
            return value  # type: ignore[no-any-return]

        return process
class MySQLExecutionContext_pyodbc(MySQLExecutionContext):
    def get_lastrowid(self) -> int:
        """Fetch LAST_INSERT_ID() with a fresh cursor.

        pyodbc provides no lastrowid accessor, so the value is queried
        from the server explicitly.
        """
        cursor = self.create_cursor()
        cursor.execute("SELECT LAST_INSERT_ID()")
        value = cursor.fetchone()[0]  # type: ignore[index]
        cursor.close()
        return value  # type: ignore[no-any-return]
class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
pyodbc_driver_name = "MySQL"
- def _detect_charset(self, connection):
+ def _detect_charset(self, connection: Connection) -> str:
"""Sniff out the character set in use for connection results."""
# Prefer 'character_set_results' for the current connection over the
)
return "latin1"
def _get_server_version_info(
    self, connection: Connection
) -> tuple[int, ...]:
    # call MySQLDialect's implementation explicitly: PyODBCConnector is
    # first in the MRO (see the class declaration) and would otherwise
    # supply its own version detection
    return MySQLDialect._get_server_version_info(self, connection)
- def _extract_error_code(self, exception):
+ def _extract_error_code(self, exception: BaseException) -> Optional[int]:
m = re.compile(r"\((\d+)\)").search(str(exception.args))
- c = m.group(1)
+ if m is None:
+ return None
+ c: Optional[str] = m.group(1)
if c:
return int(c)
else:
return None
- def on_connect(self):
+ def on_connect(self) -> Callable[[DBAPIConnection], None]:
super_ = super().on_connect()
- def on_connect(conn):
+ def on_connect(conn: DBAPIConnection) -> None:
if super_ is not None:
super_(conn)
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
-
+from __future__ import annotations
import re
+from typing import Any
+from typing import Callable
+from typing import Literal
+from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import TYPE_CHECKING
+from typing import Union
from .enumerated import ENUM
from .enumerated import SET
from .types import DATETIME
from .types import TIME
from .types import TIMESTAMP
-from ... import log
from ... import types as sqltypes
from ... import util
+if TYPE_CHECKING:
+ from .base import MySQLDialect
+ from .base import MySQLIdentifierPreparer
+ from ...engine.interfaces import ReflectedColumn
+
class ReflectedState:
    """Accumulates the raw facts parsed from one SHOW CREATE TABLE
    statement."""

    # assigned by the parser before use rather than in __init__
    charset: Optional[str]

    def __init__(self) -> None:
        self.columns: list[ReflectedColumn] = []
        self.table_options: dict[str, str] = {}
        self.table_name: Optional[str] = None
        self.keys: list[dict[str, Any]] = []
        self.fk_constraints: list[dict[str, Any]] = []
        self.ck_constraints: list[dict[str, Any]] = []
-@log.class_logger
class MySQLTableDefinitionParser:
"""Parses the results of a SHOW CREATE TABLE statement."""
def __init__(
    self, dialect: MySQLDialect, preparer: MySQLIdentifierPreparer
):
    # keep references used throughout parsing, then precompile every
    # parsing regex once up front — a single reflection run applies
    # them to many lines of SHOW CREATE TABLE output
    self.dialect = dialect
    self.preparer = preparer
    self._prep_regexes()
- def parse(self, show_create, charset):
+ def parse(
+ self, show_create: str, charset: Optional[str]
+ ) -> ReflectedState:
state = ReflectedState()
state.charset = charset
for line in re.split(r"\r?\n", show_create):
if type_ is None:
util.warn("Unknown schema content: %r" % line)
elif type_ == "key":
- state.keys.append(spec)
+ state.keys.append(spec) # type: ignore[arg-type]
elif type_ == "fk_constraint":
- state.fk_constraints.append(spec)
+ state.fk_constraints.append(spec) # type: ignore[arg-type]
elif type_ == "ck_constraint":
- state.ck_constraints.append(spec)
+ state.ck_constraints.append(spec) # type: ignore[arg-type]
else:
pass
return state
def _check_view(self, sql: str) -> bool:
return bool(self._re_is_view.match(sql))
- def _parse_constraints(self, line):
+ def _parse_constraints(self, line: str) -> Union[
+ tuple[None, str],
+ tuple[Literal["partition"], str],
+ tuple[
+ Literal["ck_constraint", "fk_constraint", "key"], dict[str, str]
+ ],
+ ]:
"""Parse a KEY or CONSTRAINT line.
:param line: A line of SHOW CREATE TABLE output
# No match.
return (None, line)
- def _parse_table_name(self, line, state):
+ def _parse_table_name(self, line: str, state: ReflectedState) -> None:
"""Extract the table name.
:param line: The first line of SHOW CREATE TABLE
if m:
state.table_name = cleanup(m.group("name"))
- def _parse_table_options(self, line, state):
+ def _parse_table_options(self, line: str, state: ReflectedState) -> None:
"""Build a dictionary of all reflected table-level options.
:param line: The final line of SHOW CREATE TABLE output.
for opt, val in options.items():
state.table_options["%s_%s" % (self.dialect.name, opt)] = val
- def _parse_partition_options(self, line, state):
+ def _parse_partition_options(
+ self, line: str, state: ReflectedState
+ ) -> None:
options = {}
new_line = line[:]
else:
state.table_options["%s_%s" % (self.dialect.name, opt)] = val
- def _parse_column(self, line, state):
+ def _parse_column(self, line: str, state: ReflectedState) -> None:
"""Extract column details.
Falls back to a 'minimal support' variant if full parse fails.
type_instance = col_type(*type_args, **type_kw)
- col_kw = {}
+ col_kw: dict[str, Any] = {}
# NOT NULL
col_kw["nullable"] = True
name=name, type=type_instance, default=default, comment=comment
)
col_d.update(col_kw)
- state.columns.append(col_d)
+ state.columns.append(col_d) # type: ignore[arg-type]
- def _describe_to_create(self, table_name, columns):
+ def _describe_to_create(
+ self,
+ table_name: str,
+ columns: Sequence[tuple[str, str, str, str, str, str]],
+ ) -> str:
"""Re-format DESCRIBE output as a SHOW CREATE TABLE string.
DESCRIBE is a much simpler reflection and is sufficient for
]
)
- def _parse_keyexprs(self, identifiers):
+ def _parse_keyexprs(
+ self, identifiers: str
+ ) -> list[tuple[str, Optional[int], str]]:
"""Unpack '"col"(2),"col" ASC'-ish strings into components."""
return [
)
]
- def _prep_regexes(self):
+ def _prep_regexes(self) -> None:
"""Pre-compile regular expressions."""
- self._re_columns = []
- self._pr_options = []
+ self._pr_options: list[
+ tuple[re.Pattern[Any], Optional[Callable[[str], str]]]
+ ] = []
_final = self.preparer.final_quote
_optional_equals = r"(?:\s*(?:=\s*)|\s+)"
def _add_option_string(self, directive: str) -> None:
    """Register a parser for a single-quoted string table option."""
    # the value is a '...'-quoted string in which '' escapes a literal
    # quote; cleanup_text unescapes it during reflection
    pattern = r"(?P<directive>%s)%s'(?P<val>(?:[^']|'')*?)'(?!')" % (
        re.escape(directive),
        self._optional_equals,
    )
    self._pr_options.append(_pr_compile(pattern, cleanup_text))
def _add_option_word(self, directive: str) -> None:
    """Register a parser for a bare-word table option value."""
    pattern = r"(?P<directive>%s)%s(?P<val>\w+)" % (
        re.escape(directive),
        self._optional_equals,
    )
    self._pr_options.append(_pr_compile(pattern))
- def _add_partition_option_word(self, directive):
+ def _add_partition_option_word(self, directive: str) -> None:
if directive == "PARTITION BY" or directive == "SUBPARTITION BY":
regex = r"(?<!\S)(?P<directive>%s)%s" r"(?P<val>\w+.*)" % (
re.escape(directive),
regex = r"(?<!\S)(?P<directive>%s)(?!\S)" % (re.escape(directive),)
self._pr_options.append(_pr_compile(regex))
- def _add_option_regex(self, directive, regex):
+ def _add_option_regex(self, directive: str, regex: str) -> None:
regex = r"(?P<directive>%s)%s" r"(?P<val>%s)" % (
re.escape(directive),
self._optional_equals,
)
-def _pr_compile(regex, cleanup=None):
+@overload
+def _pr_compile(
+ regex: str, cleanup: Callable[[str], str]
+) -> tuple[re.Pattern[Any], Callable[[str], str]]: ...
+
+
+@overload
+def _pr_compile(
+ regex: str, cleanup: None = None
+) -> tuple[re.Pattern[Any], None]: ...
+
+
+def _pr_compile(
+ regex: str, cleanup: Optional[Callable[[str], str]] = None
+) -> tuple[re.Pattern[Any], Optional[Callable[[str], str]]]:
"""Prepare a 2-tuple of compiled regex and callable."""
return (_re_compile(regex), cleanup)
-def _re_compile(regex):
+def _re_compile(regex: str) -> re.Pattern[Any]:
"""Compile a string to regex, I and UNICODE."""
return re.compile(regex, re.I | re.UNICODE)
-def _strip_values(values):
+def _strip_values(values: Sequence[str]) -> list[str]:
"Strip reflected values quotes"
- strip_values = []
+ strip_values: list[str] = []
for a in values:
if a[0:1] == '"' or a[0:1] == "'":
# strip enclosing quotes and unquote interior
def cleanup_text(raw_text: str) -> str:
    """Unescape a reflected string value from SHOW CREATE TABLE output."""
    cleaned = raw_text
    if "\\" in cleaned:
        # translate backslash escape sequences via the control-char map
        cleaned = re.sub(
            _control_char_regexp,
            lambda match: _control_char_map[match[0]],  # type: ignore[index]
            cleaned,
        )
    # SQL doubles single quotes inside quoted strings; collapse them
    return cleaned.replace("''", "'")
# https://mariadb.com/kb/en/reserved-words/
# includes: Reserved Words, Oracle Mode (separate set unioned)
# excludes: Exceptions, Function Names
-# mypy: ignore-errors
RESERVED_WORDS_MARIADB = {
"accessible",
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
-
+from __future__ import annotations
import datetime
+import decimal
+from typing import Any
+from typing import Iterable
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
from ... import exc
from ... import util
from ...sql import sqltypes
+if TYPE_CHECKING:
+ from .base import MySQLDialect
+ from ...engine.interfaces import Dialect
+ from ...sql.type_api import _BindProcessorType
+ from ...sql.type_api import _ResultProcessorType
+
class _NumericCommonType:
"""Base for MySQL numeric types.
"""
def __init__(
    self, unsigned: bool = False, zerofill: bool = False, **kw: Any
):
    # MySQL-specific numeric modifiers shared by all numeric types;
    # remaining keywords pass through to the generic type constructor
    self.unsigned = unsigned
    self.zerofill = zerofill
    super().__init__(**kw)
-class _NumericType(_NumericCommonType, sqltypes.Numeric):
+class _NumericType(
+ _NumericCommonType, sqltypes.Numeric[Union[decimal.Decimal, float]]
+):
def __repr__(self) -> str:
    # inspect the whole class chain so MySQL-specific constructor
    # arguments (presumably unsigned/zerofill) are included in the repr
    return util.generic_repr(
        self,
        to_inspect=[_NumericType, _NumericCommonType, sqltypes.Numeric],
    )
-class _FloatType(_NumericCommonType, sqltypes.Float):
+class _FloatType(
+ _NumericCommonType, sqltypes.Float[Union[decimal.Decimal, float]]
+):
- def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
+ def __init__(
+ self,
+ precision: Optional[int] = None,
+ scale: Optional[int] = None,
+ asdecimal: bool = True,
+ **kw: Any,
+ ):
if isinstance(self, (REAL, DOUBLE)) and (
(precision is None and scale is not None)
or (precision is not None and scale is None)
super().__init__(precision=precision, asdecimal=asdecimal, **kw)
self.scale = scale
def __repr__(self) -> str:
    # inspect the whole class chain so MySQL-specific constructor
    # arguments are included in the repr
    return util.generic_repr(
        self, to_inspect=[_FloatType, _NumericCommonType, sqltypes.Float]
    )
class _IntegerType(_NumericCommonType, sqltypes.Integer):
def __init__(self, display_width: Optional[int] = None, **kw: Any):
    # display_width: optional maximum display width (see the INTEGER
    # subclasses' docstrings); remaining keywords pass through
    self.display_width = display_width
    super().__init__(**kw)
- def __repr__(self):
+ def __repr__(self) -> str:
return util.generic_repr(
self,
to_inspect=[_IntegerType, _NumericCommonType, sqltypes.Integer],
def __init__(
self,
- charset=None,
- collation=None,
- ascii=False, # noqa
- binary=False,
- unicode=False,
- national=False,
- **kw,
+ charset: Optional[str] = None,
+ collation: Optional[str] = None,
+ ascii: bool = False, # noqa
+ binary: bool = False,
+ unicode: bool = False,
+ national: bool = False,
+ **kw: Any,
):
self.charset = charset
self.national = national
super().__init__(**kw)
def __repr__(self) -> str:
    # inspect both classes so MySQL string modifiers (charset, national,
    # etc.) are included in the repr
    return util.generic_repr(
        self, to_inspect=[_StringType, sqltypes.String]
    )
class _MatchType(
    sqltypes.Float[Union[decimal.Decimal, float]], sqltypes.MatchType
):
    # combines Float with MatchType; both base __init__s are invoked
    # explicitly rather than via super() because their signatures differ.
    # NOTE(review): **kw is accepted but never forwarded — confirm that
    # is intentional
    def __init__(self, **kw: Any):
        # TODO: float arguments?
        sqltypes.Float.__init__(self)  # type: ignore[arg-type]
        sqltypes.MatchType.__init__(self)
-class NUMERIC(_NumericType, sqltypes.NUMERIC):
+class NUMERIC(_NumericType, sqltypes.NUMERIC[Union[decimal.Decimal, float]]):
"""MySQL NUMERIC type."""
__visit_name__ = "NUMERIC"
- def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
+ def __init__(
+ self,
+ precision: Optional[int] = None,
+ scale: Optional[int] = None,
+ asdecimal: bool = True,
+ **kw: Any,
+ ):
"""Construct a NUMERIC.
:param precision: Total digits in this number. If scale and precision
)
-class DECIMAL(_NumericType, sqltypes.DECIMAL):
+class DECIMAL(_NumericType, sqltypes.DECIMAL[Union[decimal.Decimal, float]]):
"""MySQL DECIMAL type."""
__visit_name__ = "DECIMAL"
- def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
+ def __init__(
+ self,
+ precision: Optional[int] = None,
+ scale: Optional[int] = None,
+ asdecimal: bool = True,
+ **kw: Any,
+ ):
"""Construct a DECIMAL.
:param precision: Total digits in this number. If scale and precision
)
-class DOUBLE(_FloatType, sqltypes.DOUBLE):
+class DOUBLE(_FloatType, sqltypes.DOUBLE[Union[decimal.Decimal, float]]):
"""MySQL DOUBLE type."""
__visit_name__ = "DOUBLE"
- def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
+ def __init__(
+ self,
+ precision: Optional[int] = None,
+ scale: Optional[int] = None,
+ asdecimal: bool = True,
+ **kw: Any,
+ ):
"""Construct a DOUBLE.
.. note::
)
-class REAL(_FloatType, sqltypes.REAL):
+class REAL(_FloatType, sqltypes.REAL[Union[decimal.Decimal, float]]):
"""MySQL REAL type."""
__visit_name__ = "REAL"
- def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
+ def __init__(
+ self,
+ precision: Optional[int] = None,
+ scale: Optional[int] = None,
+ asdecimal: bool = True,
+ **kw: Any,
+ ):
"""Construct a REAL.
.. note::
)
-class FLOAT(_FloatType, sqltypes.FLOAT):
+class FLOAT(_FloatType, sqltypes.FLOAT[Union[decimal.Decimal, float]]):
"""MySQL FLOAT type."""
__visit_name__ = "FLOAT"
- def __init__(self, precision=None, scale=None, asdecimal=False, **kw):
+ def __init__(
+ self,
+ precision: Optional[int] = None,
+ scale: Optional[int] = None,
+ asdecimal: bool = False,
+ **kw: Any,
+ ):
"""Construct a FLOAT.
:param precision: Total digits in this number. If scale and precision
precision=precision, scale=scale, asdecimal=asdecimal, **kw
)
def bind_processor(
    self, dialect: Dialect
) -> Optional[_BindProcessorType[Union[decimal.Decimal, float]]]:
    # returning None installs no bind-parameter conversion for FLOAT
    # values, overriding the inherited numeric processing
    return None
__visit_name__ = "INTEGER"
- def __init__(self, display_width=None, **kw):
+ def __init__(self, display_width: Optional[int] = None, **kw: Any):
"""Construct an INTEGER.
:param display_width: Optional, maximum display width for this number.
__visit_name__ = "BIGINT"
- def __init__(self, display_width=None, **kw):
+ def __init__(self, display_width: Optional[int] = None, **kw: Any):
"""Construct a BIGINTEGER.
:param display_width: Optional, maximum display width for this number.
__visit_name__ = "MEDIUMINT"
- def __init__(self, display_width=None, **kw):
+ def __init__(self, display_width: Optional[int] = None, **kw: Any):
"""Construct a MEDIUMINTEGER
:param display_width: Optional, maximum display width for this number.
__visit_name__ = "TINYINT"
- def __init__(self, display_width=None, **kw):
+ def __init__(self, display_width: Optional[int] = None, **kw: Any):
"""Construct a TINYINT.
:param display_width: Optional, maximum display width for this number.
__visit_name__ = "SMALLINT"
- def __init__(self, display_width=None, **kw):
+ def __init__(self, display_width: Optional[int] = None, **kw: Any):
"""Construct a SMALLINTEGER.
:param display_width: Optional, maximum display width for this number.
super().__init__(display_width=display_width, **kw)
-class BIT(sqltypes.TypeEngine):
+class BIT(sqltypes.TypeEngine[Any]):
"""MySQL BIT type.
This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater
__visit_name__ = "BIT"
def __init__(self, length: Optional[int] = None):
    """Construct a BIT.

    :param length: Optional, number of bits.
    """
    # None presumably renders BIT without an explicit width — confirm
    # against the DDL compiler
    self.length = length
- def result_processor(self, dialect, coltype):
+ def result_processor(
+ self, dialect: MySQLDialect, coltype: object # type: ignore[override]
+ ) -> Optional[_ResultProcessorType[Any]]:
"""Convert a MySQL's 64 bit, variable length binary string to a
long."""
if dialect.supports_native_bit:
return None
- def process(value):
+ def process(value: Optional[Iterable[int]]) -> Optional[int]:
if value is not None:
v = 0
for i in value:
- if not isinstance(i, int):
- i = ord(i) # convert byte to int on Python 2
v = v << 8 | i
return v
return value
__visit_name__ = "TIME"
- def __init__(self, timezone=False, fsp=None):
+ def __init__(self, timezone: bool = False, fsp: Optional[int] = None):
"""Construct a MySQL TIME type.
:param timezone: not used by the MySQL dialect.
super().__init__(timezone=timezone)
self.fsp = fsp
- def result_processor(self, dialect, coltype):
+ def result_processor(
+ self, dialect: Dialect, coltype: object
+ ) -> _ResultProcessorType[datetime.time]:
time = datetime.time
- def process(value):
+ def process(value: Any) -> Optional[datetime.time]:
# convert from a timedelta value
if value is not None:
microseconds = value.microseconds
__visit_name__ = "TIMESTAMP"
- def __init__(self, timezone=False, fsp=None):
+ def __init__(self, timezone: bool = False, fsp: Optional[int] = None):
"""Construct a MySQL TIMESTAMP type.
:param timezone: not used by the MySQL dialect.
__visit_name__ = "DATETIME"
- def __init__(self, timezone=False, fsp=None):
+ def __init__(self, timezone: bool = False, fsp: Optional[int] = None):
"""Construct a MySQL DATETIME type.
:param timezone: not used by the MySQL dialect.
self.fsp = fsp
class YEAR(sqltypes.TypeEngine[Any]):
    """MySQL YEAR type, for single byte storage of years 1901-2155."""

    __visit_name__ = "YEAR"

    def __init__(self, display_width: Optional[int] = None):
        # optional display width, mirroring the integer types
        self.display_width = display_width
__visit_name__ = "TEXT"
- def __init__(self, length=None, **kw):
+ def __init__(self, length: Optional[int] = None, **kw: Any):
"""Construct a TEXT.
:param length: Optional, if provided the server may optimize storage
__visit_name__ = "TINYTEXT"
- def __init__(self, **kwargs):
+ def __init__(self, **kwargs: Any):
"""Construct a TINYTEXT.
:param charset: Optional, a column-level character set for this string
__visit_name__ = "MEDIUMTEXT"
- def __init__(self, **kwargs):
+ def __init__(self, **kwargs: Any):
"""Construct a MEDIUMTEXT.
:param charset: Optional, a column-level character set for this string
__visit_name__ = "LONGTEXT"
- def __init__(self, **kwargs):
+ def __init__(self, **kwargs: Any):
"""Construct a LONGTEXT.
:param charset: Optional, a column-level character set for this string
__visit_name__ = "VARCHAR"
- def __init__(self, length=None, **kwargs):
+ def __init__(self, length: Optional[int] = None, **kwargs: Any) -> None:
"""Construct a VARCHAR.
:param charset: Optional, a column-level character set for this string
__visit_name__ = "CHAR"
- def __init__(self, length=None, **kwargs):
+ def __init__(self, length: Optional[int] = None, **kwargs: Any):
"""Construct a CHAR.
:param length: Maximum data length, in characters.
super().__init__(length=length, **kwargs)
@classmethod
- def _adapt_string_for_cast(cls, type_):
+ def _adapt_string_for_cast(cls, type_: sqltypes.String) -> sqltypes.CHAR:
# copy the given string type into a CHAR
# for the purposes of rendering a CAST expression
type_ = sqltypes.to_instance(type_)
__visit_name__ = "NVARCHAR"
- def __init__(self, length=None, **kwargs):
+ def __init__(self, length: Optional[int] = None, **kwargs: Any):
"""Construct an NVARCHAR.
:param length: Maximum data length, in characters.
__visit_name__ = "NCHAR"
- def __init__(self, length=None, **kwargs):
+ def __init__(self, length: Optional[int] = None, **kwargs: Any):
"""Construct an NCHAR.
:param length: Maximum data length, in characters.
from .interfaces import _ParamStyle
from .interfaces import ConnectArgsType
from .interfaces import DBAPIConnection
+ from .interfaces import DBAPIModule
from .interfaces import IsolationLevel
from .row import Row
from .url import URL
delete_executemany_returning = False
@util.memoized_property
- def loaded_dbapi(self) -> ModuleType:
+ def loaded_dbapi(self) -> DBAPIModule:
if self.dbapi is None:
raise exc.InvalidRequestError(
f"Dialect {self} does not have a Python DBAPI established "
% (self.label_length, self.max_identifier_length)
)
def on_connect(self) -> Optional[Callable[[Any], None]]:
    # inherits the docstring from interfaces.Dialect.on_connect;
    # None means no connection-setup callable is installed by default
    return None
def is_disconnect(
self,
- e: Exception,
+ e: DBAPIModule.Error,
connection: Union[
pool.PoolProxiedConnection, interfaces.DBAPIConnection, None
],
name = name_upper
return name
def get_driver_connection(self, connection: DBAPIConnection) -> Any:
    """Return the driver-level connection; by default they are one and
    the same object."""
    return connection
def _overrides_default(self, method):
from __future__ import annotations
from enum import Enum
-from types import ModuleType
from typing import Any
from typing import Awaitable
from typing import Callable
from .. import util
from ..event import EventTarget
from ..pool import Pool
-from ..pool import PoolProxiedConnection
+from ..pool import PoolProxiedConnection as PoolProxiedConnection
from ..sql.compiler import Compiled as Compiled
from ..sql.compiler import Compiled # noqa
from ..sql.compiler import TypeCompiler as TypeCompiler
from .base import Engine
from .cursor import CursorResult
from .url import URL
+ from ..connectors.asyncio import AsyncIODBAPIConnection
from ..event import _ListenerFnType
from ..event import dispatcher
from ..exc import StatementError
from ..sql.sqltypes import Integer
from ..sql.type_api import _TypeMemoDict
from ..sql.type_api import TypeEngine
+ from ..util.langhelpers import generic_fn_descriptor
ConnectArgsType = Tuple[Sequence[str], MutableMapping[str, Any]]
"""
+class DBAPIModule(Protocol):
+    """Protocol describing a :pep:`249` DBAPI module object.
+
+    Used as the type for ``Dialect.dbapi`` / ``Dialect.loaded_dbapi`` in
+    place of ``types.ModuleType`` so that the standard DBAPI exception
+    classes are statically visible on the module object.
+    """
+
+    class Error(Exception):
+        # driver-specific error attributes (sqlstate, pgcode, etc.) are
+        # not part of pep-249; allow arbitrary attribute access
+        def __getattr__(self, key: str) -> Any: ...
+
+    class OperationalError(Error):
+        pass
+
+    class InterfaceError(Error):
+        pass
+
+    class IntegrityError(Error):
+        pass
+
+    # any other module-level name (connect(), paramstyle, type objects,
+    # additional exception classes) is typed as Any
+    def __getattr__(self, key: str) -> Any: ...
+
+
class DBAPIConnection(Protocol):
"""protocol representing a :pep:`249` database connection.
def rollback(self) -> None: ...
- autocommit: bool
+ def __getattr__(self, key: str) -> Any: ...
+
+ def __setattr__(self, key: str, value: Any) -> None: ...
class DBAPIType(Protocol):
dialect_description: str
- dbapi: Optional[ModuleType]
+ dbapi: Optional[DBAPIModule]
"""A reference to the DBAPI module object itself.
SQLAlchemy dialects import DBAPI modules using the classmethod
"""
@util.non_memoized_property
- def loaded_dbapi(self) -> ModuleType:
+ def loaded_dbapi(self) -> DBAPIModule:
"""same as .dbapi, but is never None; will raise an error if no
DBAPI was set up.
"""The maximum length of constraint names if different from
``max_identifier_length``."""
- supports_server_side_cursors: bool
+ supports_server_side_cursors: Union[generic_fn_descriptor[bool], bool]
"""indicates if the dialect supports server side cursors"""
server_side_cursors: bool
raise NotImplementedError()
@classmethod
- def import_dbapi(cls) -> ModuleType:
+ def import_dbapi(cls) -> DBAPIModule:
"""Import the DBAPI module that is used by this dialect.
The Python module object returned here will be assigned as an
def is_disconnect(
self,
- e: Exception,
+ e: DBAPIModule.Error,
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
cursor: Optional[DBAPICursor],
) -> bool:
"""
return self.on_connect()
- def on_connect(self) -> Optional[Callable[[Any], Any]]:
+ def on_connect(self) -> Optional[Callable[[Any], None]]:
"""return a callable which sets up a newly created DBAPI connection.
The callable should accept a single argument "conn" which is the
__slots__ = ("_connection",)
- _connection: Any
+ _connection: AsyncIODBAPIConnection
@property
def driver_connection(self) -> Any:
def rollback(self) -> None: ...
+ def __getattr__(self, key: str) -> Any: ...
+
@property
def is_valid(self) -> bool:
"""Return True if this :class:`.PoolProxiedConnection` still refers
from .base import Executable
from .cache_key import CacheKey
from .ddl import ExecutableDDLElement
+ from .dml import Delete
from .dml import Insert
from .dml import Update
from .dml import UpdateBase
"criteria within UPDATE"
)
- def update_post_criteria_clause(self, update_stmt, **kw):
+ def update_post_criteria_clause(
+ self, update_stmt: Update, **kw: Any
+ ) -> Optional[str]:
"""provide a hook to override generation after the WHERE criteria
in an UPDATE statement
else:
return None
- def delete_post_criteria_clause(self, delete_stmt, **kw):
+ def delete_post_criteria_clause(
+ self, delete_stmt: Delete, **kw: Any
+ ) -> Optional[str]:
"""provide a hook to override generation after the WHERE criteria
in a DELETE statement
else:
schema_name = None
- index_name = self.preparer.format_index(index)
+ index_name: str = self.preparer.format_index(index)
if schema_name:
index_name = schema_name + "." + index_name
"""
+ element: _SI
+
def __init__(self, element: _SI) -> None:
self.element = self.target = element
self._ddl_if = getattr(element, "_ddl_if", None)
from ..util.typing import TupleAny
from ..util.typing import Unpack
+
if typing.TYPE_CHECKING:
from ._typing import _ByArgument
from ._typing import _ColumnExpressionArgument
from ..engine.interfaces import SchemaTranslateMapType
from ..engine.result import Result
+
_NUMERIC = Union[float, Decimal]
_NUMBER = Union[float, int, Decimal]
else:
return self
- def _with_binary_element_type(self, type_):
- c: Self = ClauseElement._clone(self) # type: ignore[assignment]
+ def _with_binary_element_type(self, type_: TypeEngine[Any]) -> Self:
+ c: Self = ClauseElement._clone(self)
c.type = type_
return c
self.type = sqltypes.BOOLEANTYPE
self.negate = None
self._is_implicitly_boolean = True
- self.modifiers = {}
+ self.modifiers = util.immutabledict({})
@property
def left_expr(self) -> ColumnElement[Any]:
from __future__ import annotations
from enum import Enum
-from types import ModuleType
import typing
from typing import Any
from typing import Callable
from .sqltypes import NUMERICTYPE as NUMERICTYPE # noqa: F401
from .sqltypes import STRINGTYPE as STRINGTYPE # noqa: F401
from .sqltypes import TABLEVALUE as TABLEVALUE # noqa: F401
+ from ..engine.interfaces import DBAPIModule
from ..engine.interfaces import Dialect
from ..util.typing import GenericProtocol
return x == y # type: ignore[no-any-return]
- def get_dbapi_type(self, dbapi: ModuleType) -> Optional[Any]:
+ def get_dbapi_type(self, dbapi: DBAPIModule) -> Optional[Any]:
"""Return the corresponding type object from the underlying DB-API, if
any.
instance.__dict__.update(self.__dict__)
return instance
- def get_dbapi_type(self, dbapi: ModuleType) -> Optional[Any]:
+ def get_dbapi_type(self, dbapi: DBAPIModule) -> Optional[Any]:
"""Return the DBAPI type object represented by this
:class:`.TypeDecorator`.
asyncio = ["greenlet>=1"]
mypy = [
"mypy >= 1.7",
- "types-greenlet >= 2"
+ "types-greenlet >= 2",
]
mssql = ["pyodbc"]
mssql-pymssql = ["pymssql"]
postgresql-psycopg = ["psycopg>=3.0.7,!=3.1.15"]
postgresql-psycopgbinary = ["psycopg[binary]>=3.0.7,!=3.1.15"]
pymysql = ["pymysql"]
+cymysql = ["cymysql"]
aiomysql = [
"greenlet>=1", # same as ".[asyncio]" if this syntax were supported
"aiomysql",