]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
miscelaneous to type mysql dialect
authorPablo Estevez <pablo22estevez@gmail.com>
Wed, 8 Jan 2025 20:31:05 +0000 (17:31 -0300)
committerPablo Estevez <pablo22estevez@gmail.com>
Wed, 8 Jan 2025 20:31:05 +0000 (17:31 -0300)
19 files changed:
lib/sqlalchemy/connectors/asyncio.py
lib/sqlalchemy/connectors/pyodbc.py
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/pool/base.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/crud.py
lib/sqlalchemy/sql/ddl.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/util/_collections.py
lib/sqlalchemy/util/typing.py
test/dialect/oracle/test_dialect.py

index e57f7bfdf2186a1d1196d169aae81d2a37b8e9fc..bce08d9cc353e018eb6940560b7a249fc007a938 100644 (file)
@@ -40,7 +40,7 @@ class AsyncIODBAPIConnection(Protocol):
 
     async def commit(self) -> None: ...
 
-    def cursor(self) -> AsyncIODBAPICursor: ...
+    def cursor(self, *args: Any, **kwargs: Any) -> AsyncIODBAPICursor: ...
 
     async def rollback(self) -> None: ...
 
index 3a32d19c8bbc1b23f3e555ee3b8b6d5c0d0d521c..b26d32fdf7e66b76f9a2a2e990289e7d82752124 100644 (file)
@@ -14,6 +14,7 @@ from typing import Any
 from typing import Dict
 from typing import List
 from typing import Optional
+from typing import Sequence
 from typing import Tuple
 from typing import Union
 
@@ -228,8 +229,8 @@ class PyODBCConnector(Connector):
 
     def get_isolation_level_values(
         self, dbapi_connection: interfaces.DBAPIConnection
-    ) -> List[IsolationLevel]:
-        return super().get_isolation_level_values(dbapi_connection) + [
+    ) -> Sequence[IsolationLevel]:
+        return super().get_isolation_level_values(dbapi_connection) + [  # type: ignore  # NOQA: E501
             "AUTOCOMMIT"
         ]
 
index 56d7ee758855fdda5d1f70ee854885fe27e29b91..dc4f16f5143ace54f7b02a511664a104850caa7c 100644 (file)
@@ -1379,7 +1379,10 @@ class FullyBufferedCursorFetchStrategy(CursorFetchStrategy):
     __slots__ = ("_rowbuffer", "alternate_cursor_description")
 
     def __init__(
-        self, dbapi_cursor, alternate_description=None, initial_buffer=None
+        self,
+        dbapi_cursor: DBAPICursor,
+        alternate_description: _DBAPICursorDescription | None = None,
+        initial_buffer: Any = None,
     ):
         self.alternate_cursor_description = alternate_description
         if initial_buffer is not None:
@@ -1900,7 +1903,7 @@ class CursorResult(Result[Unpack[_Ts]]):
         clone._metadata = clone._metadata._splice_horizontally(other._metadata)
 
         clone.cursor_strategy = FullyBufferedCursorFetchStrategy(
-            None,
+            None,  # type: ignore[arg-type]
             initial_buffer=total_rows,
         )
         clone._reset_memoizations()
@@ -1932,7 +1935,7 @@ class CursorResult(Result[Unpack[_Ts]]):
         )
 
         clone.cursor_strategy = FullyBufferedCursorFetchStrategy(
-            None,
+            None,  # type: ignore[arg-type]
             initial_buffer=total_rows,
         )
         clone._reset_memoizations()
@@ -1961,7 +1964,7 @@ class CursorResult(Result[Unpack[_Ts]]):
         )._remove_processors()
 
         self.cursor_strategy = FullyBufferedCursorFetchStrategy(
-            None,
+            None,  # type: ignore[arg-type]
             # TODO: if these are Row objects, can we save on not having to
             # re-make new Row objects out of them a second time?  is that
             # what's actually happening right now?  maybe look into this
index ba59ac297bca9ddb47e3f92d430b6270eb54fc00..8dce692d7357b1de0e066ab129d27fb47ed3c619 100644 (file)
@@ -80,9 +80,11 @@ if typing.TYPE_CHECKING:
     from .interfaces import _CoreSingleExecuteParams
     from .interfaces import _DBAPICursorDescription
     from .interfaces import _DBAPIMultiExecuteParams
+    from .interfaces import _DBAPISingleExecuteParams
     from .interfaces import _ExecuteOptions
     from .interfaces import _MutableCoreSingleExecuteParams
     from .interfaces import _ParamStyle
+    from .interfaces import ConnectArgsType
     from .interfaces import DBAPIConnection
     from .interfaces import IsolationLevel
     from .row import Row
@@ -101,6 +103,8 @@ if typing.TYPE_CHECKING:
     from ..sql.type_api import _BindProcessorType
     from ..sql.type_api import _ResultProcessorType
     from ..sql.type_api import TypeEngine
+    from ..util.langhelpers import generic_fn_descriptor
+
 
 # When we're handed literal SQL, ensure it's a SELECT query
 SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE)
@@ -245,7 +249,7 @@ class DefaultDialect(Dialect):
 
     supports_is_distinct_from = True
 
-    supports_server_side_cursors = False
+    supports_server_side_cursors: "generic_fn_descriptor[bool] | bool" = False
 
     server_side_cursors = False
 
@@ -440,7 +444,7 @@ class DefaultDialect(Dialect):
     def _bind_typing_render_casts(self):
         return self.bind_typing is interfaces.BindTyping.RENDER_CASTS
 
-    def _ensure_has_table_connection(self, arg):
+    def _ensure_has_table_connection(self, arg: Connection) -> None:
         if not isinstance(arg, Connection):
             raise exc.ArgumentError(
                 "The argument passed to Dialect.has_table() should be a "
@@ -524,7 +528,7 @@ class DefaultDialect(Dialect):
         else:
             return None
 
-    def initialize(self, connection):
+    def initialize(self, connection: Connection) -> None:
         try:
             self.server_version_info = self._get_server_version_info(
                 connection
@@ -560,7 +564,7 @@ class DefaultDialect(Dialect):
                 % (self.label_length, self.max_identifier_length)
             )
 
-    def on_connect(self):
+    def on_connect(self) -> Callable[[Any], None] | None:
         # inherits the docstring from interfaces.Dialect.on_connect
         return None
 
@@ -626,11 +630,11 @@ class DefaultDialect(Dialect):
                 % (ident, self.max_identifier_length)
             )
 
-    def connect(self, *cargs, **cparams):
+    def connect(self, *cargs: Any, **cparams: Any):  # type: ignore[no-untyped-def]  # NOQA: E501
         # inherits the docstring from interfaces.Dialect.connect
         return self.loaded_dbapi.connect(*cargs, **cparams)
 
-    def create_connect_args(self, url):
+    def create_connect_args(self, url: URL) -> "ConnectArgsType":
         # inherits the docstring from interfaces.Dialect.create_connect_args
         opts = url.translate_connect_args()
         opts.update(url.query)
@@ -953,7 +957,14 @@ class DefaultDialect(Dialect):
     def do_execute_no_params(self, cursor, statement, context=None):
         cursor.execute(statement)
 
-    def is_disconnect(self, e, connection, cursor):
+    def is_disconnect(
+        self,
+        e: Exception,
+        connection: (
+            pool.PoolProxiedConnection | interfaces.DBAPIConnection | None
+        ),
+        cursor: interfaces.DBAPICursor | None,
+    ) -> bool:
         return False
 
     @util.memoized_instancemethod
@@ -1669,7 +1680,12 @@ class DefaultExecutionContext(ExecutionContext):
     def no_parameters(self):
         return self.execution_options.get("no_parameters", False)
 
-    def _execute_scalar(self, stmt, type_, parameters=None):
+    def _execute_scalar(
+        self,
+        stmt: str,
+        type_: TypeEngine[Any] | None,
+        parameters: _DBAPISingleExecuteParams | None = None,
+    ) -> Any:
         """Execute a string statement on the current cursor, returning a
         scalar result.
 
@@ -1743,7 +1759,7 @@ class DefaultExecutionContext(ExecutionContext):
 
         return use_server_side
 
-    def create_cursor(self):
+    def create_cursor(self) -> DBAPICursor:
         if (
             # inlining initial preference checks for SS cursors
             self.dialect.supports_server_side_cursors
@@ -1764,10 +1780,10 @@ class DefaultExecutionContext(ExecutionContext):
     def fetchall_for_returning(self, cursor):
         return cursor.fetchall()
 
-    def create_default_cursor(self):
+    def create_default_cursor(self) -> DBAPICursor:
         return self._dbapi_connection.cursor()
 
-    def create_server_side_cursor(self):
+    def create_server_side_cursor(self) -> DBAPICursor:
         raise NotImplementedError()
 
     def pre_exec(self):
index 35c52ae3b942df1a030ee9eb551ca3d625a322dc..74ec272514950e75b078b47dc4aa36507a6e2d34 100644 (file)
@@ -70,6 +70,7 @@ if TYPE_CHECKING:
     from ..sql.sqltypes import Integer
     from ..sql.type_api import _TypeMemoDict
     from ..sql.type_api import TypeEngine
+    from ..util.langhelpers import generic_fn_descriptor
 
 ConnectArgsType = Tuple[Sequence[str], MutableMapping[str, Any]]
 
@@ -122,7 +123,7 @@ class DBAPIConnection(Protocol):
 
     def commit(self) -> None: ...
 
-    def cursor(self) -> DBAPICursor: ...
+    def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor: ...
 
     def rollback(self) -> None: ...
 
@@ -241,7 +242,7 @@ _AnyMultiExecuteParams = _DBAPIMultiExecuteParams
 _AnyExecuteParams = _DBAPIAnyExecuteParams
 
 CompiledCacheType = MutableMapping[Any, "Compiled"]
-SchemaTranslateMapType = Mapping[Optional[str], Optional[str]]
+SchemaTranslateMapType = MutableMapping[Optional[str], Optional[str]]
 
 _ImmutableExecuteOptions = immutabledict[str, Any]
 
@@ -547,6 +548,8 @@ class ReflectedIndex(TypedDict):
     dialect_options: NotRequired[Dict[str, Any]]
     """Additional dialect-specific options detected for this index"""
 
+    type: NotRequired[str]
+
 
 class ReflectedTableComment(TypedDict):
     """Dictionary representing the reflected comment corresponding to
@@ -781,7 +784,7 @@ class Dialect(EventTarget):
     max_identifier_length: int
     """The maximum length of identifier names."""
 
-    supports_server_side_cursors: bool
+    supports_server_side_cursors: "generic_fn_descriptor[bool] | bool"
     """indicates if the dialect supports server side cursors"""
 
     server_side_cursors: bool
@@ -2483,7 +2486,7 @@ class Dialect(EventTarget):
 
     def get_isolation_level_values(
         self, dbapi_conn: DBAPIConnection
-    ) -> List[IsolationLevel]:
+    ) -> Sequence[IsolationLevel]:
         """return a sequence of string isolation level names that are accepted
         by this dialect.
 
index b91048e387911c990c81a31323e431ce3189b065..1eb8b4c217473e20996c4676fa1beb49c8afd0bb 100644 (file)
@@ -1074,7 +1074,7 @@ class PoolProxiedConnection(ManagesConnection):
 
         def commit(self) -> None: ...
 
-        def cursor(self) -> DBAPICursor: ...
+        def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor: ...
 
         def rollback(self) -> None: ...
 
index a93ea4e42e8a0452385c8af5e5bb80d72ad3e26d..aa86b5e880ceec6a38f75bd0c02351101a0c091c 100644 (file)
@@ -656,7 +656,9 @@ class CompileState:
     _ambiguous_table_name_map: Optional[_AmbiguousTableNameMap]
 
     @classmethod
-    def create_for_statement(cls, statement, compiler, **kw):
+    def create_for_statement(
+        cls, statement: Executable, compiler: Compiled, **kw: Any
+    ) -> "CompileState":
         # factory construction.
 
         if statement._propagate_attrs:
index 7119ae1c1f56fd4f7f3dcb9256293b0be53a9c4f..620f53a8b6badb4dade492e7daecb7145ada9e35 100644 (file)
@@ -75,7 +75,7 @@ _StringOnlyR = TypeVar("_StringOnlyR", bound=roles.StringRole)
 _T = TypeVar("_T", bound=Any)
 
 
-def _is_literal(element):
+def _is_literal(element: Any) -> bool:
     """Return whether or not the element is a "literal" in the context
     of a SQL expression construct.
 
index 6010b95862e5d8c6138551f544655d965d7cce7a..a343efbcebe4ad3cb59f273812b20087eec24372 100644 (file)
@@ -28,6 +28,7 @@ from __future__ import annotations
 import collections
 import collections.abc as collections_abc
 import contextlib
+import decimal
 from enum import IntEnum
 import functools
 import itertools
@@ -49,6 +50,7 @@ from typing import MutableMapping
 from typing import NamedTuple
 from typing import NoReturn
 from typing import Optional
+from typing import overload
 from typing import Pattern
 from typing import Protocol
 from typing import Sequence
@@ -79,10 +81,18 @@ from .base import _SentinelDefaultCharacterization
 from .base import Executable
 from .base import NO_ARG
 from .elements import ClauseElement
+from .elements import False_
+from .elements import Null
 from .elements import quoted_name
+from .elements import True_
 from .schema import Column
+from .schema import ForeignKeyConstraint
+from .schema import UniqueConstraint
+from .sqltypes import _UUID_RETURN
 from .sqltypes import TupleType
+from .type_api import TypeDecorator
 from .type_api import TypeEngine
+from .type_api import UserDefinedType
 from .visitors import prefix_anon_map
 from .visitors import Visitable
 from .. import exc
@@ -99,7 +109,9 @@ if typing.TYPE_CHECKING:
     from .cache_key import CacheKey
     from .ddl import ExecutableDDLElement
     from .dml import Insert
+    from .dml import Update
     from .dml import UpdateBase
+    from .dml import UpdateDMLState
     from .dml import ValuesBase
     from .elements import _truncated_label
     from .elements import BindParameter
@@ -107,6 +119,9 @@ if typing.TYPE_CHECKING:
     from .elements import ColumnElement
     from .elements import Label
     from .functions import Function
+    from .schema import Constraint
+    from .schema import Index
+    from .schema import PrimaryKeyConstraint
     from .schema import Table
     from .selectable import AliasedReturnsRows
     from .selectable import CompoundSelectState
@@ -118,6 +133,7 @@ if typing.TYPE_CHECKING:
     from .selectable import SelectState
     from .type_api import _BindProcessorType
     from ..engine.cursor import CursorResultMetaData
+    from ..engine.default import DefaultDialect
     from ..engine.interfaces import _CoreSingleExecuteParams
     from ..engine.interfaces import _DBAPIAnyExecuteParams
     from ..engine.interfaces import _DBAPIMultiExecuteParams
@@ -128,6 +144,7 @@ if typing.TYPE_CHECKING:
     from ..engine.interfaces import Dialect
     from ..engine.interfaces import SchemaTranslateMapType
 
+
 _FromHintsType = Dict["FromClause", str]
 
 RESERVED_WORDS = {
@@ -873,7 +890,7 @@ class Compiled:
 
             if render_schema_translate:
                 self.string = self.preparer._render_schema_translates(
-                    self.string, schema_translate_map
+                    self.string, schema_translate_map  # type: ignore[arg-type]
                 )
 
             self.state = CompilerState.STRING_APPLIED
@@ -2344,7 +2361,7 @@ class SQLCompiler(Compiled):
 
         return get
 
-    def default_from(self):
+    def default_from(self) -> str:
         """Called when a SELECT statement has no froms, and no FROM clause is
         to be appended.
 
@@ -2736,16 +2753,16 @@ class SQLCompiler(Compiled):
 
         return text
 
-    def visit_null(self, expr, **kw):
+    def visit_null(self, expr: Null, **kw: Any) -> str:
         return "NULL"
 
-    def visit_true(self, expr, **kw):
+    def visit_true(self, expr: True_, **kw: Any) -> str:
         if self.dialect.supports_native_boolean:
             return "true"
         else:
             return "1"
 
-    def visit_false(self, expr, **kw):
+    def visit_false(self, expr: False_, **kw: Any) -> str:
         if self.dialect.supports_native_boolean:
             return "false"
         else:
@@ -2973,7 +2990,9 @@ class SQLCompiler(Compiled):
             % self.dialect.name
         )
 
-    def function_argspec(self, func, **kwargs):
+    def function_argspec(
+        self, func: functions.Function[Any], **kwargs: Any
+    ) -> str:
         return func.clause_expr._compiler_dispatch(self, **kwargs)
 
     def visit_compound_select(
@@ -3437,8 +3456,12 @@ class SQLCompiler(Compiled):
         )
 
     def _generate_generic_binary(
-        self, binary, opstring, eager_grouping=False, **kw
-    ):
+        self,
+        binary: elements.BinaryExpression[Any],
+        opstring: str,
+        eager_grouping: bool = False,
+        **kw: Any,
+    ) -> str:
         _in_operator_expression = kw.get("_in_operator_expression", False)
 
         kw["_in_operator_expression"] = True
@@ -3607,19 +3630,25 @@ class SQLCompiler(Compiled):
             **kw,
         )
 
-    def visit_regexp_match_op_binary(self, binary, operator, **kw):
+    def visit_regexp_match_op_binary(
+        self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+    ) -> str:
         raise exc.CompileError(
             "%s dialect does not support regular expressions"
             % self.dialect.name
         )
 
-    def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
+    def visit_not_regexp_match_op_binary(
+        self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+    ) -> str:
         raise exc.CompileError(
             "%s dialect does not support regular expressions"
             % self.dialect.name
         )
 
-    def visit_regexp_replace_op_binary(self, binary, operator, **kw):
+    def visit_regexp_replace_op_binary(
+        self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+    ) -> str:
         raise exc.CompileError(
             "%s dialect does not support regular expression replacements"
             % self.dialect.name
@@ -3826,7 +3855,9 @@ class SQLCompiler(Compiled):
         else:
             return self.render_literal_value(value, bindparam.type)
 
-    def render_literal_value(self, value, type_):
+    def render_literal_value(
+        self, value: str | None, type_: sqltypes.String
+    ) -> str:
         """Render the value of a bind parameter as a quoted literal.
 
         This is used for statement sections that do not accept bind parameters
@@ -4587,7 +4618,7 @@ class SQLCompiler(Compiled):
     def get_select_hint_text(self, byfroms):
         return None
 
-    def get_from_hint_text(self, table, text):
+    def get_from_hint_text(self, table: Any, text: str | None) -> str | None:
         return None
 
     def get_crud_hint_text(self, table, text):
@@ -5072,7 +5103,7 @@ class SQLCompiler(Compiled):
         else:
             return "WITH"
 
-    def get_select_precolumns(self, select, **kw):
+    def get_select_precolumns(self, select: Select[Any], **kw: Any) -> str:
         """Called when building a ``SELECT`` statement, position is just
         before column list.
 
@@ -6098,7 +6129,7 @@ class SQLCompiler(Compiled):
 
         return text
 
-    def update_limit_clause(self, update_stmt):
+    def update_limit_clause(self, update_stmt: "Update") -> str | None:
         """Provide a hook for MySQL to add LIMIT to the UPDATE"""
         return None
 
@@ -6130,11 +6161,14 @@ class SQLCompiler(Compiled):
             "criteria within UPDATE"
         )
 
-    def visit_update(self, update_stmt, visiting_cte=None, **kw):
-        compile_state = update_stmt._compile_state_factory(
-            update_stmt, self, **kw
+    def visit_update(
+        self, update_stmt: "Update", visiting_cte: CTE | None = None, **kw: Any
+    ) -> str:
+        compile_state = update_stmt._compile_state_factory(  # type: ignore
+            update_stmt, self, **kw  # type: ignore
         )
-        update_stmt = compile_state.statement
+        compile_state = cast("UpdateDMLState", compile_state)
+        update_stmt = compile_state.statement  # type: ignore[assignment]
 
         if visiting_cte is not None:
             kw["visiting_cte"] = visiting_cte
@@ -6223,7 +6257,7 @@ class SQLCompiler(Compiled):
             if self.returning_precedes_values:
                 text += " " + self.returning_clause(
                     update_stmt,
-                    self.implicit_returning or update_stmt._returning,
+                    self.implicit_returning or update_stmt._returning,  # type: ignore[arg-type]  # NOQA: E501
                     populate_result_map=toplevel,
                 )
 
@@ -6255,7 +6289,7 @@ class SQLCompiler(Compiled):
         ) and not self.returning_precedes_values:
             text += " " + self.returning_clause(
                 update_stmt,
-                self.implicit_returning or update_stmt._returning,
+                self.implicit_returning or update_stmt._returning,  # type: ignore[arg-type] # noqa: E501
                 populate_result_map=toplevel,
             )
 
@@ -6272,7 +6306,7 @@ class SQLCompiler(Compiled):
         return text
 
     def delete_extra_from_clause(
-        self, update_stmt, from_table, extra_froms, from_hints, **kw
+        self, delete_stmt, from_table, extra_froms, from_hints, **kw
     ):
         """Provide a hook to override the generation of an
         DELETE..FROM clause.
@@ -6515,7 +6549,7 @@ class StrSQLCompiler(SQLCompiler):
         )
 
     def delete_extra_from_clause(
-        self, update_stmt, from_table, extra_froms, from_hints, **kw
+        self, delete_stmt, from_table, extra_froms, from_hints, **kw
     ):
         kw["asfrom"] = True
         return ", " + ", ".join(
@@ -6727,7 +6761,7 @@ class DDLCompiler(Compiled):
     def visit_drop_view(self, drop, **kw):
         return "\nDROP VIEW " + self.preparer.format_table(drop.element)
 
-    def _verify_index_table(self, index):
+    def _verify_index_table(self, index: "Index") -> None:
         if index.table is None:
             raise exc.CompileError(
                 "Index '%s' is not associated with any table." % index.name
@@ -6778,7 +6812,9 @@ class DDLCompiler(Compiled):
 
         return text + self._prepared_index_name(index, include_schema=True)
 
-    def _prepared_index_name(self, index, include_schema=False):
+    def _prepared_index_name(
+        self, index: "Index", include_schema: bool = False
+    ) -> str:
         if index.table is not None:
             effective_schema = self.preparer.schema_for_object(index.table)
         else:
@@ -6788,7 +6824,7 @@ class DDLCompiler(Compiled):
         else:
             schema_name = None
 
-        index_name = self.preparer.format_index(index)
+        index_name: str = self.preparer.format_index(index)
 
         if schema_name:
             index_name = schema_name + "." + index_name
@@ -6925,19 +6961,19 @@ class DDLCompiler(Compiled):
     def post_create_table(self, table):
         return ""
 
-    def get_column_default_string(self, column):
+    def get_column_default_string(self, column: Column[Any]) -> str | None:
         if isinstance(column.server_default, schema.DefaultClause):
             return self.render_default_string(column.server_default.arg)
         else:
             return None
 
-    def render_default_string(self, default):
+    def render_default_string(self, default: Visitable | str) -> str:
         if isinstance(default, str):
-            return self.sql_compiler.render_literal_value(
+            return self.sql_compiler.render_literal_value(  # type: ignore[no-any-return]  # NOQA: E501
                 default, sqltypes.STRINGTYPE
             )
         else:
-            return self.sql_compiler.process(default, literal_binds=True)
+            return self.sql_compiler.process(default, literal_binds=True)  # type: ignore[no-any-return]  # NOQA: E501
 
     def visit_table_or_column_check_constraint(self, constraint, **kw):
         if constraint.is_column_level:
@@ -6969,7 +7005,9 @@ class DDLCompiler(Compiled):
         text += self.define_constraint_deferrability(constraint)
         return text
 
-    def visit_primary_key_constraint(self, constraint, **kw):
+    def visit_primary_key_constraint(
+        self, constraint: "PrimaryKeyConstraint", **kw: Any
+    ) -> str:
         if len(constraint) == 0:
             return ""
         text = ""
@@ -7018,7 +7056,9 @@ class DDLCompiler(Compiled):
 
         return preparer.format_table(table)
 
-    def visit_unique_constraint(self, constraint, **kw):
+    def visit_unique_constraint(
+        self, constraint: UniqueConstraint, **kw: Any
+    ) -> str:
         if len(constraint) == 0:
             return ""
         text = ""
@@ -7033,10 +7073,14 @@ class DDLCompiler(Compiled):
         text += self.define_constraint_deferrability(constraint)
         return text
 
-    def define_unique_constraint_distinct(self, constraint, **kw):
+    def define_unique_constraint_distinct(
+        self, constraint: UniqueConstraint, **kw: Any
+    ) -> str:
         return ""
 
-    def define_constraint_cascades(self, constraint):
+    def define_constraint_cascades(
+        self, constraint: ForeignKeyConstraint
+    ) -> str:
         text = ""
         if constraint.ondelete is not None:
             text += " ON DELETE %s" % self.preparer.validate_sql_phrase(
@@ -7048,7 +7092,7 @@ class DDLCompiler(Compiled):
             )
         return text
 
-    def define_constraint_deferrability(self, constraint):
+    def define_constraint_deferrability(self, constraint: Constraint) -> str:
         text = ""
         if constraint.deferrable is not None:
             if constraint.deferrable:
@@ -7088,19 +7132,31 @@ class DDLCompiler(Compiled):
 
 
 class GenericTypeCompiler(TypeCompiler):
-    def visit_FLOAT(self, type_, **kw):
+    def visit_FLOAT(
+        self, type_: "sqltypes.Float[decimal.Decimal| float]", **kw: Any
+    ) -> str:
         return "FLOAT"
 
-    def visit_DOUBLE(self, type_, **kw):
+    def visit_DOUBLE(
+        self, type_: "sqltypes.Double[decimal.Decimal | float]", **kw: Any
+    ) -> str:
         return "DOUBLE"
 
-    def visit_DOUBLE_PRECISION(self, type_, **kw):
+    def visit_DOUBLE_PRECISION(
+        self,
+        type_: "sqltypes.DOUBLE_PRECISION[decimal.Decimal| float]",
+        **kw: Any,
+    ) -> str:
         return "DOUBLE PRECISION"
 
-    def visit_REAL(self, type_, **kw):
+    def visit_REAL(
+        self, type_: "sqltypes.REAL[decimal.Decimal| float]", **kw: Any
+    ) -> str:
         return "REAL"
 
-    def visit_NUMERIC(self, type_, **kw):
+    def visit_NUMERIC(
+        self, type_: "sqltypes.Numeric[decimal.Decimal| float]", **kw: Any
+    ) -> str:
         if type_.precision is None:
             return "NUMERIC"
         elif type_.scale is None:
@@ -7111,7 +7167,9 @@ class GenericTypeCompiler(TypeCompiler):
                 "scale": type_.scale,
             }
 
-    def visit_DECIMAL(self, type_, **kw):
+    def visit_DECIMAL(
+        self, type_: "sqltypes.DECIMAL[decimal.Decimal| float]", **kw: Any
+    ) -> str:
         if type_.precision is None:
             return "DECIMAL"
         elif type_.scale is None:
@@ -7122,128 +7180,153 @@ class GenericTypeCompiler(TypeCompiler):
                 "scale": type_.scale,
             }
 
-    def visit_INTEGER(self, type_, **kw):
+    def visit_INTEGER(self, type_: "sqltypes.Integer", **kw: Any) -> str:
         return "INTEGER"
 
-    def visit_SMALLINT(self, type_, **kw):
+    def visit_SMALLINT(self, type_: "sqltypes.SmallInteger", **kw: Any) -> str:
         return "SMALLINT"
 
-    def visit_BIGINT(self, type_, **kw):
+    def visit_BIGINT(self, type_: "sqltypes.BigInteger", **kw: Any) -> str:
         return "BIGINT"
 
-    def visit_TIMESTAMP(self, type_, **kw):
+    def visit_TIMESTAMP(self, type_: "sqltypes.TIMESTAMP", **kw: Any) -> str:
         return "TIMESTAMP"
 
-    def visit_DATETIME(self, type_, **kw):
+    def visit_DATETIME(self, type_: "sqltypes.DateTime", **kw: Any) -> str:
         return "DATETIME"
 
-    def visit_DATE(self, type_, **kw):
+    def visit_DATE(self, type_: "sqltypes.Date", **kw: Any) -> str:
         return "DATE"
 
-    def visit_TIME(self, type_, **kw):
+    def visit_TIME(self, type_: "sqltypes.Time", **kw: Any) -> str:
         return "TIME"
 
-    def visit_CLOB(self, type_, **kw):
+    def visit_CLOB(self, type_: "sqltypes.Text", **kw: Any) -> str:
         return "CLOB"
 
-    def visit_NCLOB(self, type_, **kw):
+    def visit_NCLOB(self, type_: "sqltypes.Text", **kw: Any) -> str:
         return "NCLOB"
 
-    def _render_string_type(self, type_, name, length_override=None):
+    def _render_string_type(
+        self,
+        type_: "sqltypes.String | sqltypes.Uuid[_UUID_RETURN]",
+        name: str,
+        length_override: int | None = None,
+    ) -> str:
         text = name
         if length_override:
             text += "(%d)" % length_override
-        elif type_.length:
-            text += "(%d)" % type_.length
+        elif type_.length:  # type: ignore[union-attr]
+            text += "(%d)" % type_.length  # type: ignore[union-attr]
         if type_.collation:
             text += ' COLLATE "%s"' % type_.collation
         return text
 
-    def visit_CHAR(self, type_, **kw):
+    def visit_CHAR(self, type_: "sqltypes.CHAR", **kw: Any) -> str:
         return self._render_string_type(type_, "CHAR")
 
-    def visit_NCHAR(self, type_, **kw):
+    def visit_NCHAR(self, type_: "sqltypes.NCHAR", **kw: Any) -> str:
         return self._render_string_type(type_, "NCHAR")
 
-    def visit_VARCHAR(self, type_, **kw):
+    def visit_VARCHAR(self, type_: "sqltypes.String", **kw: Any) -> str:
         return self._render_string_type(type_, "VARCHAR")
 
-    def visit_NVARCHAR(self, type_, **kw):
+    def visit_NVARCHAR(self, type_: "sqltypes.NVARCHAR", **kw: Any) -> str:
         return self._render_string_type(type_, "NVARCHAR")
 
-    def visit_TEXT(self, type_, **kw):
+    def visit_TEXT(self, type_: "sqltypes.Text", **kw: Any) -> str:
         return self._render_string_type(type_, "TEXT")
 
-    def visit_UUID(self, type_, **kw):
+    def visit_UUID(
+        self, type_: "sqltypes.Uuid[_UUID_RETURN]", **kw: Any
+    ) -> str:
         return "UUID"
 
-    def visit_BLOB(self, type_, **kw):
+    def visit_BLOB(self, type_: "sqltypes.LargeBinary", **kw: Any) -> str:
         return "BLOB"
 
-    def visit_BINARY(self, type_, **kw):
+    def visit_BINARY(self, type_: "sqltypes.BINARY", **kw: Any) -> str:
         return "BINARY" + (type_.length and "(%d)" % type_.length or "")
 
-    def visit_VARBINARY(self, type_, **kw):
+    def visit_VARBINARY(self, type_: "sqltypes.VARBINARY", **kw: Any) -> str:
         return "VARBINARY" + (type_.length and "(%d)" % type_.length or "")
 
-    def visit_BOOLEAN(self, type_, **kw):
+    def visit_BOOLEAN(self, type_: "sqltypes.Boolean", **kw: Any) -> str:
         return "BOOLEAN"
 
-    def visit_uuid(self, type_, **kw):
+    def visit_uuid(
+        self, type_: "sqltypes.Uuid[_UUID_RETURN]", **kw: Any
+    ) -> str:
         if not type_.native_uuid or not self.dialect.supports_native_uuid:
             return self._render_string_type(type_, "CHAR", length_override=32)
         else:
             return self.visit_UUID(type_, **kw)
 
-    def visit_large_binary(self, type_, **kw):
+    def visit_large_binary(
+        self, type_: "sqltypes.LargeBinary", **kw: Any
+    ) -> str:
         return self.visit_BLOB(type_, **kw)
 
-    def visit_boolean(self, type_, **kw):
+    def visit_boolean(self, type_: "sqltypes.Boolean", **kw: Any) -> str:
         return self.visit_BOOLEAN(type_, **kw)
 
-    def visit_time(self, type_, **kw):
+    def visit_time(self, type_: "sqltypes.Time", **kw: Any) -> str:
         return self.visit_TIME(type_, **kw)
 
-    def visit_datetime(self, type_, **kw):
+    def visit_datetime(self, type_: "sqltypes.DateTime", **kw: Any) -> str:
         return self.visit_DATETIME(type_, **kw)
 
-    def visit_date(self, type_, **kw):
+    def visit_date(self, type_: "sqltypes.Date", **kw: Any) -> str:
         return self.visit_DATE(type_, **kw)
 
-    def visit_big_integer(self, type_, **kw):
+    def visit_big_integer(
+        self, type_: "sqltypes.BigInteger", **kw: Any
+    ) -> str:
         return self.visit_BIGINT(type_, **kw)
 
-    def visit_small_integer(self, type_, **kw):
+    def visit_small_integer(
+        self, type_: "sqltypes.SmallInteger", **kw: Any
+    ) -> str:
         return self.visit_SMALLINT(type_, **kw)
 
-    def visit_integer(self, type_, **kw):
+    def visit_integer(self, type_: "sqltypes.Integer", **kw: Any) -> str:
         return self.visit_INTEGER(type_, **kw)
 
-    def visit_real(self, type_, **kw):
+    def visit_real(
+        self, type_: "sqltypes.REAL[decimal.Decimal| float]", **kw: Any
+    ) -> str:
         return self.visit_REAL(type_, **kw)
 
-    def visit_float(self, type_, **kw):
+    def visit_float(
+        self, type_: "sqltypes.Float[decimal.Decimal| float]", **kw: Any
+    ) -> str:
         return self.visit_FLOAT(type_, **kw)
 
-    def visit_double(self, type_, **kw):
+    def visit_double(
+        self, type_: "sqltypes.Double[decimal.Decimal | float]", **kw: Any
+    ) -> str:
         return self.visit_DOUBLE(type_, **kw)
 
-    def visit_numeric(self, type_, **kw):
+    def visit_numeric(
+        self, type_: "sqltypes.Numeric[decimal.Decimal | float]", **kw: Any
+    ) -> str:
         return self.visit_NUMERIC(type_, **kw)
 
-    def visit_string(self, type_, **kw):
+    def visit_string(self, type_: "sqltypes.String", **kw: Any) -> str:
         return self.visit_VARCHAR(type_, **kw)
 
-    def visit_unicode(self, type_, **kw):
+    def visit_unicode(self, type_: "sqltypes.Unicode", **kw: Any) -> str:
         return self.visit_VARCHAR(type_, **kw)
 
-    def visit_text(self, type_, **kw):
+    def visit_text(self, type_: "sqltypes.Text", **kw: Any) -> str:
         return self.visit_TEXT(type_, **kw)
 
-    def visit_unicode_text(self, type_, **kw):
+    def visit_unicode_text(
+        self, type_: "sqltypes.UnicodeText", **kw: Any
+    ) -> str:
         return self.visit_TEXT(type_, **kw)
 
-    def visit_enum(self, type_, **kw):
+    def visit_enum(self, type_: "sqltypes.Enum", **kw: Any) -> str:
         return self.visit_VARCHAR(type_, **kw)
 
     def visit_null(self, type_, **kw):
@@ -7253,10 +7336,14 @@ class GenericTypeCompiler(TypeCompiler):
             "type on this Column?" % type_
         )
 
-    def visit_type_decorator(self, type_, **kw):
+    def visit_type_decorator(
+        self, type_: TypeDecorator[Any], **kw: Any
+    ) -> str:
         return self.process(type_.type_engine(self.dialect), **kw)
 
-    def visit_user_defined(self, type_, **kw):
+    def visit_user_defined(
+        self, type_: UserDefinedType[Any], **kw: Any
+    ) -> str:
         return type_.get_col_spec(**kw)
 
 
@@ -7331,12 +7418,12 @@ class IdentifierPreparer:
 
     def __init__(
         self,
-        dialect,
-        initial_quote='"',
-        final_quote=None,
-        escape_quote='"',
-        quote_case_sensitive_collations=True,
-        omit_schema=False,
+        dialect: DefaultDialect,
+        initial_quote: str = '"',
+        final_quote: str | None = None,
+        escape_quote: str = '"',
+        quote_case_sensitive_collations: bool = True,
+        omit_schema: bool = False,
     ):
         """Construct a new ``IdentifierPreparer`` object.
 
@@ -7389,7 +7476,9 @@ class IdentifierPreparer:
         prep._includes_none_schema_translate = includes_none
         return prep
 
-    def _render_schema_translates(self, statement, schema_translate_map):
+    def _render_schema_translates(
+        self, statement: str, schema_translate_map: SchemaTranslateMapType
+    ) -> str:
         d = schema_translate_map
         if None in d:
             if not self._includes_none_schema_translate:
@@ -7594,7 +7683,9 @@ class IdentifierPreparer:
         else:
             return collation_name
 
-    def format_sequence(self, sequence, use_schema=True):
+    def format_sequence(
+        self, sequence: schema.Sequence, use_schema: bool = True
+    ) -> str:
         name = self.quote(sequence.name)
 
         effective_schema = self.schema_for_object(sequence)
@@ -7631,7 +7722,9 @@ class IdentifierPreparer:
         return ident
 
     @util.preload_module("sqlalchemy.sql.naming")
-    def format_constraint(self, constraint, _alembic_quote=True):
+    def format_constraint(
+        self, constraint: Constraint, _alembic_quote: bool = True
+    ) -> str | None:
         naming = util.preloaded.sql_naming
 
         if constraint.name is _NONE_NAME:
@@ -7653,7 +7746,9 @@ class IdentifierPreparer:
                 name, _alembic_quote=_alembic_quote
             )
 
-    def truncate_and_render_index_name(self, name, _alembic_quote=True):
+    def truncate_and_render_index_name(
+        self, name: str, _alembic_quote: bool = True
+    ) -> str | None:
         # calculate these at format time so that ad-hoc changes
         # to dialect.max_identifier_length etc. can be reflected
         # as IdentifierPreparer is long lived
@@ -7665,7 +7760,9 @@ class IdentifierPreparer:
             name, max_, _alembic_quote
         )
 
-    def truncate_and_render_constraint_name(self, name, _alembic_quote=True):
+    def truncate_and_render_constraint_name(
+        self, name: str, _alembic_quote: bool = True
+    ) -> str | None:
         # calculate these at format time so that ad-hoc changes
         # to dialect.max_identifier_length etc. can be reflected
         # as IdentifierPreparer is long lived
@@ -7677,7 +7774,9 @@ class IdentifierPreparer:
             name, max_, _alembic_quote
         )
 
-    def _truncate_and_render_maxlen_name(self, name, max_, _alembic_quote):
+    def _truncate_and_render_maxlen_name(
+        self, name: str, max_: int, _alembic_quote: bool
+    ) -> str | None:
         if isinstance(name, elements._truncated_label):
             if len(name) > max_:
                 name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:]
@@ -7692,10 +7791,31 @@ class IdentifierPreparer:
     def format_index(self, index):
         return self.format_constraint(index)
 
-    def format_table(self, table, use_schema=True, name=None):
-        """Prepare a quoted table and schema name."""
+    @overload
+    def format_table(
+        self,
+        table: "Table | None",
+        use_schema: bool,
+        name: str,
+    ) -> str: ...
+
+    @overload
+    def format_table(
+        self,
+        table: "Table",
+        use_schema: bool = True,
+        name: None = None,
+    ) -> str: ...
 
+    def format_table(
+        self,
+        table: "Table | None",
+        use_schema: bool = True,
+        name: str | None = None,
+    ) -> str:
+        """Prepare a quoted table and schema name."""
         if name is None:
+            assert table is not None
             name = table.name
 
         result = self.quote(name)
@@ -7727,13 +7847,13 @@ class IdentifierPreparer:
 
     def format_column(
         self,
-        column,
-        use_table=False,
-        name=None,
-        table_name=None,
-        use_schema=False,
-        anon_map=None,
-    ):
+        column: "Column[Any]",
+        use_table: bool = False,
+        name: str | None = None,
+        table_name: str | None = None,
+        use_schema: bool = False,
+        anon_map: Mapping[str, Any] | None = None,
+    ) -> str:
         """Prepare a quoted column name."""
 
         if name is None:
@@ -7805,7 +7925,7 @@ class IdentifierPreparer:
         )
         return r
 
-    def unformat_identifiers(self, identifiers):
+    def unformat_identifiers(self, identifiers: str) -> list[str]:
         """Unpack 'schema.table.column'-like strings into components."""
 
         r = self._r_identifiers
index 19af40ff08038e375d51e87f4d95378287713963..c3a5b015bd2563f1950968d9c2a24ac0b6ed71a5 100644 (file)
@@ -207,7 +207,7 @@ def _get_crud_params(
             [
                 (
                     c,
-                    compiler.preparer.format_column(c),
+                    compiler.preparer.format_column(c),  # type: ignore[arg-type] # noqa: E501
                     _create_bind_param(compiler, c, None, required=True),
                     (c.key,),
                 )
@@ -369,7 +369,7 @@ def _get_crud_params(
         values = [
             (
                 _as_dml_column(stmt.table.columns[0]),
-                compiler.preparer.format_column(stmt.table.columns[0]),
+                compiler.preparer.format_column(stmt.table.columns[0]),  # type: ignore[arg-type] # noqa: E501
                 compiler.dialect.default_metavalue_token,
                 (),
             )
@@ -1136,7 +1136,7 @@ def _append_param_insert_select_hasdefault(
             values.append(
                 (
                     c,
-                    compiler.preparer.format_column(c),
+                    compiler.preparer.format_column(c),  # type: ignore[arg-type] # noqa: E501
                     c.default.next_value(),
                     (),
                 )
@@ -1145,7 +1145,7 @@ def _append_param_insert_select_hasdefault(
         values.append(
             (
                 c,
-                compiler.preparer.format_column(c),
+                compiler.preparer.format_column(c),  # type: ignore[arg-type]
                 c.default.arg.self_group(),
                 (),
             )
@@ -1154,7 +1154,7 @@ def _append_param_insert_select_hasdefault(
         values.append(
             (
                 c,
-                compiler.preparer.format_column(c),
+                compiler.preparer.format_column(c),  # type: ignore[arg-type]
                 _create_insert_prefetch_bind_param(
                     compiler, c, process=False, **kw
                 ),
index 7210d930a188dd545673afda6212f1029d2b1a47..d397331c14e4257b5b7d892724e71fd9fb4744dd 100644 (file)
@@ -38,8 +38,10 @@ if typing.TYPE_CHECKING:
     from .compiler import Compiled
     from .compiler import DDLCompiler
     from .elements import BindParameter
+    from .schema import Column
     from .schema import Constraint
     from .schema import ForeignKeyConstraint
+    from .schema import Index
     from .schema import SchemaItem
     from .schema import Sequence
     from .schema import Table
@@ -712,8 +714,9 @@ class CreateIndex(_CreateBase):
     """Represent a CREATE INDEX statement."""
 
     __visit_name__ = "create_index"
+    element: "Index"
 
-    def __init__(self, element, if_not_exists=False):
+    def __init__(self, element: "Index", if_not_exists: bool = False):
         """Create a :class:`.Createindex` construct.
 
         :param element: a :class:`_schema.Index` that's the subject
@@ -732,7 +735,9 @@ class DropIndex(_DropBase):
 
     __visit_name__ = "drop_index"
 
-    def __init__(self, element, if_exists=False):
+    element: "Index"
+
+    def __init__(self, element: "Index", if_exists: bool = False):
         """Create a :class:`.DropIndex` construct.
 
         :param element: a :class:`_schema.Index` that's the subject
@@ -791,6 +796,7 @@ class SetColumnComment(_CreateDropBase):
     """Represent a COMMENT ON COLUMN IS statement."""
 
     __visit_name__ = "set_column_comment"
+    element: "Column[Any]"
 
 
 class DropColumnComment(_CreateDropBase):
index 41630261edfde3749e730b2604b565eadd38da39..6e660c0be7f85d2255b3bd35f94486f23d0da6a1 100644 (file)
@@ -76,6 +76,7 @@ from .. import inspection
 from .. import util
 from ..util import HasMemoized_ro_memoized_attribute
 from ..util import TypingOnly
+from ..util._immutabledict_cy import immutabledict
 from ..util.typing import Literal
 from ..util.typing import ParamSpec
 from ..util.typing import Self
@@ -118,6 +119,7 @@ if typing.TYPE_CHECKING:
     from ..engine.interfaces import SchemaTranslateMapType
     from ..engine.result import Result
 
+
 _NUMERIC = Union[float, Decimal]
 _NUMBER = Union[float, int, Decimal]
 
@@ -2216,8 +2218,8 @@ class TypeClause(DQLDMLClauseElement):
         ("type", InternalTraversal.dp_type)
     ]
 
-    def __init__(self, type_):
-        self.type = type_
+    def __init__(self, type_: TypeEngine[Any]):
+        self.type: TypeEngine[Any] = type_
 
 
 class TextClause(
@@ -3875,8 +3877,6 @@ class BinaryExpression(OperatorExpression[_T]):
 
     """
 
-    modifiers: Optional[Mapping[str, Any]]
-
     left: ColumnElement[Any]
     right: ColumnElement[Any]
 
@@ -3907,9 +3907,9 @@ class BinaryExpression(OperatorExpression[_T]):
         self._is_implicitly_boolean = operators.is_boolean(operator)
 
         if modifiers is None:
-            self.modifiers = {}
+            self.modifiers: immutabledict[str, str] = immutabledict({})
         else:
-            self.modifiers = modifiers
+            self.modifiers = immutabledict(modifiers)
 
     @property
     def _flattened_operator_clauses(
index b905913d376edb9b9a3f9fdb419dd7933b549d13..c66482179ec076a0897bb2df80ba673c541070b9 100644 (file)
@@ -62,6 +62,7 @@ from .sqltypes import TableValueType
 from .type_api import TypeEngine
 from .visitors import InternalTraversal
 from .. import util
+from ..util._immutabledict_cy import immutabledict
 
 
 if TYPE_CHECKING:
@@ -788,7 +789,7 @@ class FunctionAsBinary(BinaryExpression[Any]):
         self.type = sqltypes.BOOLEANTYPE
         self.negate = None
         self._is_implicitly_boolean = True
-        self.modifiers = {}
+        self.modifiers = immutabledict({})
 
     @property
     def left_expr(self) -> ColumnElement[Any]:
index 212b86ca8ab6cafee20d0fa3d387e9c7d0ac40f7..9f7fb1330f6b6095ade39f7a4b5dd468c1e640d9 100644 (file)
@@ -246,10 +246,14 @@ class String(Concatenable, TypeEngine[str]):
 
         return process
 
-    def bind_processor(self, dialect):
+    def bind_processor(
+        self, dialect: "Dialect"
+    ) -> _BindProcessorType[str] | None:
         return None
 
-    def result_processor(self, dialect, coltype):
+    def result_processor(
+        self, dialect: Dialect, coltype: object
+    ) -> _ResultProcessorType[str] | None:
         return None
 
     @property
@@ -426,7 +430,11 @@ class NumericCommon(HasExpressionLookup, TypeEngineMixin, Generic[_N]):
     if TYPE_CHECKING:
 
         @util.ro_memoized_property
-        def _type_affinity(self) -> Type[NumericCommon[_N]]: ...
+        def _type_affinity(
+            self,
+        ) -> Type[
+            Numeric[decimal.Decimal | float] | Float[decimal.Decimal | float]
+        ]: ...
 
     def __init__(
         self,
@@ -653,7 +661,7 @@ class Float(NumericCommon[_N], TypeEngine[_N]):
 
     __visit_name__ = "float"
 
-    scale = None
+    scale: int | None = None
 
     @overload
     def __init__(
@@ -1325,6 +1333,8 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
 
     __visit_name__ = "enum"
 
+    enum_class: None | str | type[enum.StrEnum]
+
     def __init__(self, *enums: object, **kw: Any):
         r"""Construct an enum.
 
@@ -1457,7 +1467,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
            .. versionchanged:: 2.0 This parameter now defaults to True.
 
         """
-        self._enum_init(enums, kw)
+        self._enum_init(enums, kw)  # type: ignore[arg-type]
 
     @property
     def _enums_argument(self):
@@ -1466,7 +1476,9 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
         else:
             return self.enums
 
-    def _enum_init(self, enums, kw):
+    def _enum_init(
+        self, enums: Sequence[str | type[enum.StrEnum]], kw: dict[str, Any]
+    ) -> None:
         """internal init for :class:`.Enum` and subclasses.
 
         friendly init helper used by subclasses to remove
@@ -1476,7 +1488,9 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
         """
         self.native_enum = kw.pop("native_enum", True)
         self.create_constraint = kw.pop("create_constraint", False)
-        self.values_callable = kw.pop("values_callable", None)
+        self.values_callable: (
+            Callable[[type[enum.StrEnum]], Sequence[str]] | None
+        ) = kw.pop("values_callable", None)
         self._sort_key_function = kw.pop("sort_key_function", NO_ARG)
         length_arg = kw.pop("length", NO_ARG)
         self._omit_aliases = kw.pop("omit_aliases", True)
@@ -1504,7 +1518,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
                 )
             length = length_arg
 
-        self._valid_lookup[None] = self._object_lookup[None] = None
+        self._valid_lookup[None] = self._object_lookup[None] = None  # type: ignore # noqa: E501
 
         super().__init__(length=length)
 
@@ -1513,7 +1527,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
         # this is a template enum that will be used to generate
         # new Enum classes.
         if self.enum_class and values:
-            kw.setdefault("name", self.enum_class.__name__.lower())
+            kw.setdefault("name", self.enum_class.__name__.lower())  # type: ignore[union-attr] # noqa: E501
         SchemaType.__init__(
             self,
             name=kw.pop("name", None),
@@ -1525,7 +1539,9 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
             _adapted_from=kw.pop("_adapted_from", None),
         )
 
-    def _parse_into_values(self, enums, kw):
+    def _parse_into_values(
+        self, enums: Sequence[str | type[enum.StrEnum]], kw: Any
+    ) -> tuple[Sequence[str], Sequence[enum.StrEnum] | Sequence[str]]:
         if not enums and "_enums" in kw:
             enums = kw.pop("_enums")
 
@@ -1540,16 +1556,16 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
                     (n, v) for n, v in _members.items() if v.name == n
                 )
             else:
-                members = _members
+                members = _members  # type: ignore[assignment]
             if self.values_callable:
-                values = self.values_callable(self.enum_class)
+                values = self.values_callable(self.enum_class)  # type: ignore[arg-type] # noqa: E501
             else:
                 values = list(members)
             objects = [members[k] for k in members]
             return values, objects
         else:
             self.enum_class = None
-            return enums, enums
+            return enums, enums  # type: ignore[return-value]
 
     def _resolve_for_literal(self, value: Any) -> Enum:
         tv = type(value)
@@ -1639,12 +1655,19 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
             self._generic_type_affinity(_enums=enum_args, **kw),  # type: ignore  # noqa: E501
         )
 
-    def _setup_for_values(self, values, objects, kw):
+    def _setup_for_values(
+        self,
+        values: Sequence[str],
+        objects: Sequence[enum.StrEnum] | Sequence[str],
+        kw: Any,
+    ) -> None:
         self.enums = list(values)
 
-        self._valid_lookup = dict(zip(reversed(objects), reversed(values)))
+        self._valid_lookup: dict[str, str] = dict(
+            zip(reversed(objects), reversed(values))
+        )
 
-        self._object_lookup = dict(zip(values, objects))
+        self._object_lookup: dict[str, str] = dict(zip(values, objects))
 
         self._valid_lookup.update(
             [
@@ -1706,7 +1729,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
 
     comparator_factory = Comparator
 
-    def _object_value_for_elem(self, elem):
+    def _object_value_for_elem(self, elem: str) -> str:
         try:
             return self._object_lookup[elem]
         except KeyError as err:
@@ -2183,7 +2206,7 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator[dt.timedelta]):
 
     def bind_processor(
         self, dialect: Dialect
-    ) -> _BindProcessorType[dt.timedelta]:
+    ) -> "_BindProcessorType[dt.timedelta]":
         if TYPE_CHECKING:
             assert isinstance(self.impl_instance, DateTime)
         impl_processor = self.impl_instance.bind_processor(dialect)
@@ -3490,6 +3513,7 @@ class BINARY(_Binary):
 class VARBINARY(_Binary):
     """The SQL VARBINARY type."""
 
+    length: int
     __visit_name__ = "VARBINARY"
 
 
@@ -3686,7 +3710,9 @@ class Uuid(Emulated, TypeEngine[_UUID_RETURN]):
         else:
             return super().coerce_compared_value(op, value)
 
-    def bind_processor(self, dialect):
+    def bind_processor(
+        self, dialect: Dialect
+    ) -> "_BindProcessorType[_UUID_RETURN] | None":
         character_based_uuid = (
             not dialect.supports_native_uuid or not self.native_uuid
         )
@@ -3694,18 +3720,18 @@ class Uuid(Emulated, TypeEngine[_UUID_RETURN]):
         if character_based_uuid:
             if self.as_uuid:
 
-                def process(value):
+                def process(value: Any) -> str:
                     if value is not None:
                         value = value.hex
-                    return value
+                    return value  # type: ignore[no-any-return]
 
                 return process
             else:
 
-                def process(value):
+                def process(value: Any) -> str:
                     if value is not None:
                         value = value.replace("-", "")
-                    return value
+                    return value  # type: ignore[no-any-return]
 
                 return process
         else:
index fb72c825e576b1f414a3fe7e3fa16cd7975d7d1b..33f7bc41a159c4d3facfbd4c707cb82f55c7637a 100644 (file)
@@ -1376,6 +1376,9 @@ class UserDefinedType(
 
         return self
 
+    def get_col_spec(self, **kw: Any) -> str:
+        raise NotImplementedError()
+
 
 class Emulated(TypeEngineMixin):
     """Mixin for base types that emulate the behavior of a DB-native type.
index 98990041784297912fc15402bcb72d2e16b4542e..b3b5d0e8c1e34a936317eadeb050ae7164202385 100644 (file)
@@ -21,6 +21,7 @@ from typing import Callable
 from typing import cast
 from typing import Collection
 from typing import Dict
+from typing import Generator
 from typing import Iterable
 from typing import Iterator
 from typing import List
@@ -481,7 +482,9 @@ def surface_selectables(clause):
             stack.append(elem.element)
 
 
-def surface_selectables_only(clause):
+def surface_selectables_only(
+    clause: ClauseElement,
+) -> Generator[ClauseElement]:
     stack = [clause]
     while stack:
         elem = stack.pop()
index 9ca5e60a202478058f9fcb9e88b884172820c37d..34b6b8a29e35f75c02ce07ee931c7cbc29d37392 100644 (file)
@@ -430,7 +430,9 @@ def to_column_set(x: Any) -> Set[Any]:
         return x
 
 
-def update_copy(d, _new=None, **kw):
+def update_copy(
+    d: dict[Any, Any], _new: dict[Any, Any] | None = None, **kw: dict[Any, Any]
+) -> dict[Any, Any]:
     """Copy the given dict and update with the given values."""
 
     d = d.copy()
index 7809c9fcad7cd9766e3abcbfbb2312bc2f9c453b..d2a78b2287b86f6569340d6f51756a3a46579dab 100644 (file)
@@ -56,6 +56,7 @@ if True:  # zimports removes the tailing comments
     from typing_extensions import TypeAliasType as TypeAliasType  # 3.12
     from typing_extensions import Unpack as Unpack  # 3.11
     from typing_extensions import Never as Never  # 3.11
+    from typing_extensions import LiteralString as LiteralString  # 3.11
 
 
 _T = TypeVar("_T", bound=Any)
index 8ea523fb7e56b7307d1dd377d0e576f941ceb996..1f8a23f70dc945793ac2bb3cff9dba59ba9c2e73 100644 (file)
@@ -681,7 +681,6 @@ class CompatFlagsTest(fixtures.TestBase, AssertsCompiledSQL):
 
         dialect._get_server_version_info = server_version_info
         dialect.get_isolation_level = Mock()
-        dialect._check_unicode_returns = Mock()
         dialect._check_unicode_description = Mock()
         dialect._get_default_schema_name = Mock()
         dialect._detect_decimal_char = Mock()