]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
miscellaneous to type dialects
authorPablo Estevez <pablo22estevez@gmail.com>
Sat, 8 Feb 2025 15:46:24 +0000 (10:46 -0500)
committerFederico Caselli <cfederico87@gmail.com>
Sat, 15 Mar 2025 14:12:27 +0000 (15:12 +0100)
Type of certain methods that are called by dialect, so typing dialects is easier.

Related to https://github.com/sqlalchemy/sqlalchemy/pull/12164

breaking changes:

- Change modifiers from TextClause to InmutableDict, from Mapping, as is in the other classes

Closes: #12231
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12231
Pull-request-sha: 514fe4751c7b1ceefffed2a4ef9c8df339bd9c25

Change-Id: I29314045b2c7eb5428f8d6fec8911c4b6d5ae73e

18 files changed:
lib/sqlalchemy/connectors/asyncio.py
lib/sqlalchemy/connectors/pyodbc.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/pool/base.py
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/ddl.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/util/_collections.py
lib/sqlalchemy/util/typing.py
pyproject.toml
test/dialect/oracle/test_dialect.py

index e57f7bfdf2186a1d1196d169aae81d2a37b8e9fc..bce08d9cc353e018eb6940560b7a249fc007a938 100644 (file)
@@ -40,7 +40,7 @@ class AsyncIODBAPIConnection(Protocol):
 
     async def commit(self) -> None: ...
 
-    def cursor(self) -> AsyncIODBAPICursor: ...
+    def cursor(self, *args: Any, **kwargs: Any) -> AsyncIODBAPICursor: ...
 
     async def rollback(self) -> None: ...
 
index 3a32d19c8bbc1b23f3e555ee3b8b6d5c0d0d521c..8aaf223d4d9d1aabe2023650ea17dfd0bdc31759 100644 (file)
@@ -227,11 +227,9 @@ class PyODBCConnector(Connector):
         )
 
     def get_isolation_level_values(
-        self, dbapi_connection: interfaces.DBAPIConnection
+        self, dbapi_conn: interfaces.DBAPIConnection
     ) -> List[IsolationLevel]:
-        return super().get_isolation_level_values(dbapi_connection) + [
-            "AUTOCOMMIT"
-        ]
+        return [*super().get_isolation_level_values(dbapi_conn), "AUTOCOMMIT"]
 
     def set_isolation_level(
         self,
index 1f00127bfa63e2ca16751b5aeb2eea2c81271680..d25ad83552e8b7c74527650af6aea2317eeceb46 100644 (file)
@@ -1482,6 +1482,7 @@ from functools import lru_cache
 import re
 from typing import Any
 from typing import cast
+from typing import Dict
 from typing import List
 from typing import Optional
 from typing import Tuple
@@ -3738,8 +3739,8 @@ class PGDialect(default.DefaultDialect):
     def _reflect_type(
         self,
         format_type: Optional[str],
-        domains: dict[str, ReflectedDomain],
-        enums: dict[str, ReflectedEnum],
+        domains: Dict[str, ReflectedDomain],
+        enums: Dict[str, ReflectedEnum],
         type_description: str,
     ) -> sqltypes.TypeEngine[Any]:
         """
index 56d7ee758855fdda5d1f70ee854885fe27e29b91..bff473ac5a9ad1d1ff93f510d61db650147bbab8 100644 (file)
@@ -20,6 +20,7 @@ from typing import Any
 from typing import cast
 from typing import ClassVar
 from typing import Dict
+from typing import Iterable
 from typing import Iterator
 from typing import List
 from typing import Mapping
@@ -1379,12 +1380,16 @@ class FullyBufferedCursorFetchStrategy(CursorFetchStrategy):
     __slots__ = ("_rowbuffer", "alternate_cursor_description")
 
     def __init__(
-        self, dbapi_cursor, alternate_description=None, initial_buffer=None
+        self,
+        dbapi_cursor: Optional[DBAPICursor],
+        alternate_description: Optional[_DBAPICursorDescription] = None,
+        initial_buffer: Optional[Iterable[Any]] = None,
     ):
         self.alternate_cursor_description = alternate_description
         if initial_buffer is not None:
             self._rowbuffer = collections.deque(initial_buffer)
         else:
+            assert dbapi_cursor is not None
             self._rowbuffer = collections.deque(dbapi_cursor.fetchall())
 
     def yield_per(self, result, dbapi_cursor, num):
index ba59ac297bca9ddb47e3f92d430b6270eb54fc00..4023019cfce4cbc2fd9bc8cc68280873902192fd 100644 (file)
@@ -80,9 +80,11 @@ if typing.TYPE_CHECKING:
     from .interfaces import _CoreSingleExecuteParams
     from .interfaces import _DBAPICursorDescription
     from .interfaces import _DBAPIMultiExecuteParams
+    from .interfaces import _DBAPISingleExecuteParams
     from .interfaces import _ExecuteOptions
     from .interfaces import _MutableCoreSingleExecuteParams
     from .interfaces import _ParamStyle
+    from .interfaces import ConnectArgsType
     from .interfaces import DBAPIConnection
     from .interfaces import IsolationLevel
     from .row import Row
@@ -102,6 +104,7 @@ if typing.TYPE_CHECKING:
     from ..sql.type_api import _ResultProcessorType
     from ..sql.type_api import TypeEngine
 
+
 # When we're handed literal SQL, ensure it's a SELECT query
 SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE)
 
@@ -440,7 +443,7 @@ class DefaultDialect(Dialect):
     def _bind_typing_render_casts(self):
         return self.bind_typing is interfaces.BindTyping.RENDER_CASTS
 
-    def _ensure_has_table_connection(self, arg):
+    def _ensure_has_table_connection(self, arg: Connection) -> None:
         if not isinstance(arg, Connection):
             raise exc.ArgumentError(
                 "The argument passed to Dialect.has_table() should be a "
@@ -524,7 +527,7 @@ class DefaultDialect(Dialect):
         else:
             return None
 
-    def initialize(self, connection):
+    def initialize(self, connection: Connection) -> None:
         try:
             self.server_version_info = self._get_server_version_info(
                 connection
@@ -560,7 +563,7 @@ class DefaultDialect(Dialect):
                 % (self.label_length, self.max_identifier_length)
             )
 
-    def on_connect(self):
+    def on_connect(self) -> Optional[Callable[[Any], Any]]:
         # inherits the docstring from interfaces.Dialect.on_connect
         return None
 
@@ -619,18 +622,18 @@ class DefaultDialect(Dialect):
     ) -> bool:
         return schema_name in self.get_schema_names(connection, **kw)
 
-    def validate_identifier(self, ident):
+    def validate_identifier(self, ident: str) -> None:
         if len(ident) > self.max_identifier_length:
             raise exc.IdentifierError(
                 "Identifier '%s' exceeds maximum length of %d characters"
                 % (ident, self.max_identifier_length)
             )
 
-    def connect(self, *cargs, **cparams):
+    def connect(self, *cargs: Any, **cparams: Any) -> DBAPIConnection:
         # inherits the docstring from interfaces.Dialect.connect
-        return self.loaded_dbapi.connect(*cargs, **cparams)
+        return self.loaded_dbapi.connect(*cargs, **cparams)  # type: ignore[no-any-return]  # NOQA: E501
 
-    def create_connect_args(self, url):
+    def create_connect_args(self, url: URL) -> ConnectArgsType:
         # inherits the docstring from interfaces.Dialect.create_connect_args
         opts = url.translate_connect_args()
         opts.update(url.query)
@@ -953,7 +956,14 @@ class DefaultDialect(Dialect):
     def do_execute_no_params(self, cursor, statement, context=None):
         cursor.execute(statement)
 
-    def is_disconnect(self, e, connection, cursor):
+    def is_disconnect(
+        self,
+        e: Exception,
+        connection: Union[
+            pool.PoolProxiedConnection, interfaces.DBAPIConnection, None
+        ],
+        cursor: Optional[interfaces.DBAPICursor],
+    ) -> bool:
         return False
 
     @util.memoized_instancemethod
@@ -1669,7 +1679,12 @@ class DefaultExecutionContext(ExecutionContext):
     def no_parameters(self):
         return self.execution_options.get("no_parameters", False)
 
-    def _execute_scalar(self, stmt, type_, parameters=None):
+    def _execute_scalar(
+        self,
+        stmt: str,
+        type_: Optional[TypeEngine[Any]],
+        parameters: Optional[_DBAPISingleExecuteParams] = None,
+    ) -> Any:
         """Execute a string statement on the current cursor, returning a
         scalar result.
 
@@ -1743,7 +1758,7 @@ class DefaultExecutionContext(ExecutionContext):
 
         return use_server_side
 
-    def create_cursor(self):
+    def create_cursor(self) -> DBAPICursor:
         if (
             # inlining initial preference checks for SS cursors
             self.dialect.supports_server_side_cursors
@@ -1764,10 +1779,10 @@ class DefaultExecutionContext(ExecutionContext):
     def fetchall_for_returning(self, cursor):
         return cursor.fetchall()
 
-    def create_default_cursor(self):
+    def create_default_cursor(self) -> DBAPICursor:
         return self._dbapi_connection.cursor()
 
-    def create_server_side_cursor(self):
+    def create_server_side_cursor(self) -> DBAPICursor:
         raise NotImplementedError()
 
     def pre_exec(self):
index 35c52ae3b942df1a030ee9eb551ca3d625a322dc..464c6677b89d9d7bb009e56fbbce567a6fdc18d3 100644 (file)
@@ -122,7 +122,7 @@ class DBAPIConnection(Protocol):
 
     def commit(self) -> None: ...
 
-    def cursor(self) -> DBAPICursor: ...
+    def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor: ...
 
     def rollback(self) -> None: ...
 
@@ -780,6 +780,12 @@ class Dialect(EventTarget):
 
     max_identifier_length: int
     """The maximum length of identifier names."""
+    max_index_name_length: Optional[int]
+    """The maximum length of index names if different from
+    ``max_identifier_length``."""
+    max_constraint_name_length: Optional[int]
+    """The maximum length of constraint names if different from
+    ``max_identifier_length``."""
 
     supports_server_side_cursors: bool
     """indicates if the dialect supports server side cursors"""
@@ -1283,8 +1289,6 @@ class Dialect(EventTarget):
 
         """
 
-        pass
-
     if TYPE_CHECKING:
 
         def _overrides_default(self, method_name: str) -> bool: ...
@@ -2483,7 +2487,7 @@ class Dialect(EventTarget):
 
     def get_isolation_level_values(
         self, dbapi_conn: DBAPIConnection
-    ) -> List[IsolationLevel]:
+    ) -> Sequence[IsolationLevel]:
         """return a sequence of string isolation level names that are accepted
         by this dialect.
 
@@ -2657,6 +2661,9 @@ class Dialect(EventTarget):
         """return a Pool class to use for a given URL"""
         raise NotImplementedError()
 
+    def validate_identifier(self, ident: str) -> None:
+        """Validates an identifier name, raising an exception if invalid"""
+
 
 class CreateEnginePlugin:
     """A set of hooks intended to augment the construction of an
index 511eca923467cec62d14a79a4499039c2e72e672..29c28e1bb6dda36a53b3603519e3c0700f0934c8 100644 (file)
@@ -1075,7 +1075,7 @@ class PoolProxiedConnection(ManagesConnection):
 
         def commit(self) -> None: ...
 
-        def cursor(self) -> DBAPICursor: ...
+        def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor: ...
 
         def rollback(self) -> None: ...
 
index fc3614c06ba4218421249946d61759b9fbc4dc94..f643960e73cc71b649cc6bbba15471d402edc7e5 100644 (file)
@@ -76,7 +76,7 @@ _StringOnlyR = TypeVar("_StringOnlyR", bound=roles.StringRole)
 _T = TypeVar("_T", bound=Any)
 
 
-def _is_literal(element):
+def _is_literal(element: Any) -> bool:
     """Return whether or not the element is a "literal" in the context
     of a SQL expression construct.
 
index 32043dd7bb4534c632e1cf7fa66bfc6dae708c89..5f27ce05b733c5055482fdd9f981650dbc24c2c6 100644 (file)
@@ -76,19 +76,15 @@ from .base import _de_clone
 from .base import _from_objects
 from .base import _NONE_NAME
 from .base import _SentinelDefaultCharacterization
-from .base import Executable
 from .base import NO_ARG
-from .elements import ClauseElement
 from .elements import quoted_name
-from .schema import Column
 from .sqltypes import TupleType
-from .type_api import TypeEngine
 from .visitors import prefix_anon_map
-from .visitors import Visitable
 from .. import exc
 from .. import util
 from ..util import FastIntFlag
 from ..util.typing import Literal
+from ..util.typing import Self
 from ..util.typing import TupleAny
 from ..util.typing import Unpack
 
@@ -96,18 +92,33 @@ if typing.TYPE_CHECKING:
     from .annotation import _AnnotationDict
     from .base import _AmbiguousTableNameMap
     from .base import CompileState
+    from .base import Executable
     from .cache_key import CacheKey
     from .ddl import ExecutableDDLElement
     from .dml import Insert
+    from .dml import Update
     from .dml import UpdateBase
+    from .dml import UpdateDMLState
     from .dml import ValuesBase
     from .elements import _truncated_label
+    from .elements import BinaryExpression
     from .elements import BindParameter
+    from .elements import ClauseElement
     from .elements import ColumnClause
     from .elements import ColumnElement
+    from .elements import False_
     from .elements import Label
+    from .elements import Null
+    from .elements import True_
     from .functions import Function
+    from .schema import Column
+    from .schema import Constraint
+    from .schema import ForeignKeyConstraint
+    from .schema import Index
+    from .schema import PrimaryKeyConstraint
     from .schema import Table
+    from .schema import UniqueConstraint
+    from .selectable import _ColumnsClauseElement
     from .selectable import AliasedReturnsRows
     from .selectable import CompoundSelectState
     from .selectable import CTE
@@ -117,6 +128,10 @@ if typing.TYPE_CHECKING:
     from .selectable import Select
     from .selectable import SelectState
     from .type_api import _BindProcessorType
+    from .type_api import TypeDecorator
+    from .type_api import TypeEngine
+    from .type_api import UserDefinedType
+    from .visitors import Visitable
     from ..engine.cursor import CursorResultMetaData
     from ..engine.interfaces import _CoreSingleExecuteParams
     from ..engine.interfaces import _DBAPIAnyExecuteParams
@@ -128,6 +143,7 @@ if typing.TYPE_CHECKING:
     from ..engine.interfaces import Dialect
     from ..engine.interfaces import SchemaTranslateMapType
 
+
 _FromHintsType = Dict["FromClause", str]
 
 RESERVED_WORDS = {
@@ -872,6 +888,7 @@ class Compiled:
             self.string = self.process(self.statement, **compile_kwargs)
 
             if render_schema_translate:
+                assert schema_translate_map is not None
                 self.string = self.preparer._render_schema_translates(
                     self.string, schema_translate_map
                 )
@@ -904,7 +921,7 @@ class Compiled:
         raise exc.UnsupportedCompilationError(self, type(element)) from err
 
     @property
-    def sql_compiler(self):
+    def sql_compiler(self) -> SQLCompiler:
         """Return a Compiled that is capable of processing SQL expressions.
 
         If this compiler is one, it would likely just return 'self'.
@@ -1793,7 +1810,7 @@ class SQLCompiler(Compiled):
         return len(self.stack) > 1
 
     @property
-    def sql_compiler(self):
+    def sql_compiler(self) -> Self:
         return self
 
     def construct_expanded_state(
@@ -2344,7 +2361,7 @@ class SQLCompiler(Compiled):
 
         return get
 
-    def default_from(self):
+    def default_from(self) -> str:
         """Called when a SELECT statement has no froms, and no FROM clause is
         to be appended.
 
@@ -2736,16 +2753,16 @@ class SQLCompiler(Compiled):
 
         return text
 
-    def visit_null(self, expr, **kw):
+    def visit_null(self, expr: Null, **kw: Any) -> str:
         return "NULL"
 
-    def visit_true(self, expr, **kw):
+    def visit_true(self, expr: True_, **kw: Any) -> str:
         if self.dialect.supports_native_boolean:
             return "true"
         else:
             return "1"
 
-    def visit_false(self, expr, **kw):
+    def visit_false(self, expr: False_, **kw: Any) -> str:
         if self.dialect.supports_native_boolean:
             return "false"
         else:
@@ -2976,7 +2993,7 @@ class SQLCompiler(Compiled):
             % self.dialect.name
         )
 
-    def function_argspec(self, func, **kwargs):
+    def function_argspec(self, func: Function[Any], **kwargs: Any) -> str:
         return func.clause_expr._compiler_dispatch(self, **kwargs)
 
     def visit_compound_select(
@@ -3440,8 +3457,12 @@ class SQLCompiler(Compiled):
         )
 
     def _generate_generic_binary(
-        self, binary, opstring, eager_grouping=False, **kw
-    ):
+        self,
+        binary: BinaryExpression[Any],
+        opstring: str,
+        eager_grouping: bool = False,
+        **kw: Any,
+    ) -> str:
         _in_operator_expression = kw.get("_in_operator_expression", False)
 
         kw["_in_operator_expression"] = True
@@ -3610,19 +3631,25 @@ class SQLCompiler(Compiled):
             **kw,
         )
 
-    def visit_regexp_match_op_binary(self, binary, operator, **kw):
+    def visit_regexp_match_op_binary(
+        self, binary: BinaryExpression[Any], operator: Any, **kw: Any
+    ) -> str:
         raise exc.CompileError(
             "%s dialect does not support regular expressions"
             % self.dialect.name
         )
 
-    def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
+    def visit_not_regexp_match_op_binary(
+        self, binary: BinaryExpression[Any], operator: Any, **kw: Any
+    ) -> str:
         raise exc.CompileError(
             "%s dialect does not support regular expressions"
             % self.dialect.name
         )
 
-    def visit_regexp_replace_op_binary(self, binary, operator, **kw):
+    def visit_regexp_replace_op_binary(
+        self, binary: BinaryExpression[Any], operator: Any, **kw: Any
+    ) -> str:
         raise exc.CompileError(
             "%s dialect does not support regular expression replacements"
             % self.dialect.name
@@ -3829,7 +3856,9 @@ class SQLCompiler(Compiled):
         else:
             return self.render_literal_value(value, bindparam.type)
 
-    def render_literal_value(self, value, type_):
+    def render_literal_value(
+        self, value: Any, type_: sqltypes.TypeEngine[Any]
+    ) -> str:
         """Render the value of a bind parameter as a quoted literal.
 
         This is used for statement sections that do not accept bind parameters
@@ -4603,7 +4632,9 @@ class SQLCompiler(Compiled):
     def get_select_hint_text(self, byfroms):
         return None
 
-    def get_from_hint_text(self, table, text):
+    def get_from_hint_text(
+        self, table: FromClause, text: Optional[str]
+    ) -> Optional[str]:
         return None
 
     def get_crud_hint_text(self, table, text):
@@ -5109,7 +5140,7 @@ class SQLCompiler(Compiled):
         else:
             return "WITH"
 
-    def get_select_precolumns(self, select, **kw):
+    def get_select_precolumns(self, select: Select[Any], **kw: Any) -> str:
         """Called when building a ``SELECT`` statement, position is just
         before column list.
 
@@ -5154,7 +5185,7 @@ class SQLCompiler(Compiled):
     def returning_clause(
         self,
         stmt: UpdateBase,
-        returning_cols: Sequence[ColumnElement[Any]],
+        returning_cols: Sequence[_ColumnsClauseElement],
         *,
         populate_result_map: bool,
         **kw: Any,
@@ -6187,11 +6218,18 @@ class SQLCompiler(Compiled):
         else:
             return None
 
-    def visit_update(self, update_stmt, visiting_cte=None, **kw):
-        compile_state = update_stmt._compile_state_factory(
-            update_stmt, self, **kw
+    def visit_update(
+        self,
+        update_stmt: Update,
+        visiting_cte: Optional[CTE] = None,
+        **kw: Any,
+    ) -> str:
+        compile_state = update_stmt._compile_state_factory(  # type: ignore[call-arg] # noqa: E501
+            update_stmt, self, **kw  # type: ignore[arg-type]
         )
-        update_stmt = compile_state.statement
+        if TYPE_CHECKING:
+            assert isinstance(compile_state, UpdateDMLState)
+        update_stmt = compile_state.statement  # type: ignore[assignment]
 
         if visiting_cte is not None:
             kw["visiting_cte"] = visiting_cte
@@ -6331,7 +6369,7 @@ class SQLCompiler(Compiled):
         return text
 
     def delete_extra_from_clause(
-        self, update_stmt, from_table, extra_froms, from_hints, **kw
+        self, delete_stmt, from_table, extra_froms, from_hints, **kw
     ):
         """Provide a hook to override the generation of an
         DELETE..FROM clause.
@@ -6555,7 +6593,7 @@ class StrSQLCompiler(SQLCompiler):
     def returning_clause(
         self,
         stmt: UpdateBase,
-        returning_cols: Sequence[ColumnElement[Any]],
+        returning_cols: Sequence[_ColumnsClauseElement],
         *,
         populate_result_map: bool,
         **kw: Any,
@@ -6576,7 +6614,7 @@ class StrSQLCompiler(SQLCompiler):
         )
 
     def delete_extra_from_clause(
-        self, update_stmt, from_table, extra_froms, from_hints, **kw
+        self, delete_stmt, from_table, extra_froms, from_hints, **kw
     ):
         kw["asfrom"] = True
         return ", " + ", ".join(
@@ -6623,8 +6661,8 @@ class DDLCompiler(Compiled):
             compile_kwargs: Mapping[str, Any] = ...,
         ): ...
 
-    @util.memoized_property
-    def sql_compiler(self):
+    @util.ro_memoized_property
+    def sql_compiler(self) -> SQLCompiler:
         return self.dialect.statement_compiler(
             self.dialect, None, schema_translate_map=self.schema_translate_map
         )
@@ -6788,7 +6826,7 @@ class DDLCompiler(Compiled):
     def visit_drop_view(self, drop, **kw):
         return "\nDROP VIEW " + self.preparer.format_table(drop.element)
 
-    def _verify_index_table(self, index):
+    def _verify_index_table(self, index: Index) -> None:
         if index.table is None:
             raise exc.CompileError(
                 "Index '%s' is not associated with any table." % index.name
@@ -6839,7 +6877,9 @@ class DDLCompiler(Compiled):
 
         return text + self._prepared_index_name(index, include_schema=True)
 
-    def _prepared_index_name(self, index, include_schema=False):
+    def _prepared_index_name(
+        self, index: Index, include_schema: bool = False
+    ) -> str:
         if index.table is not None:
             effective_schema = self.preparer.schema_for_object(index.table)
         else:
@@ -6986,13 +7026,13 @@ class DDLCompiler(Compiled):
     def post_create_table(self, table):
         return ""
 
-    def get_column_default_string(self, column):
+    def get_column_default_string(self, column: Column[Any]) -> Optional[str]:
         if isinstance(column.server_default, schema.DefaultClause):
             return self.render_default_string(column.server_default.arg)
         else:
             return None
 
-    def render_default_string(self, default):
+    def render_default_string(self, default: Union[Visitable, str]) -> str:
         if isinstance(default, str):
             return self.sql_compiler.render_literal_value(
                 default, sqltypes.STRINGTYPE
@@ -7030,7 +7070,9 @@ class DDLCompiler(Compiled):
         text += self.define_constraint_deferrability(constraint)
         return text
 
-    def visit_primary_key_constraint(self, constraint, **kw):
+    def visit_primary_key_constraint(
+        self, constraint: PrimaryKeyConstraint, **kw: Any
+    ) -> str:
         if len(constraint) == 0:
             return ""
         text = ""
@@ -7079,7 +7121,9 @@ class DDLCompiler(Compiled):
 
         return preparer.format_table(table)
 
-    def visit_unique_constraint(self, constraint, **kw):
+    def visit_unique_constraint(
+        self, constraint: UniqueConstraint, **kw: Any
+    ) -> str:
         if len(constraint) == 0:
             return ""
         text = ""
@@ -7094,10 +7138,14 @@ class DDLCompiler(Compiled):
         text += self.define_constraint_deferrability(constraint)
         return text
 
-    def define_unique_constraint_distinct(self, constraint, **kw):
+    def define_unique_constraint_distinct(
+        self, constraint: UniqueConstraint, **kw: Any
+    ) -> str:
         return ""
 
-    def define_constraint_cascades(self, constraint):
+    def define_constraint_cascades(
+        self, constraint: ForeignKeyConstraint
+    ) -> str:
         text = ""
         if constraint.ondelete is not None:
             text += " ON DELETE %s" % self.preparer.validate_sql_phrase(
@@ -7109,7 +7157,7 @@ class DDLCompiler(Compiled):
             )
         return text
 
-    def define_constraint_deferrability(self, constraint):
+    def define_constraint_deferrability(self, constraint: Constraint) -> str:
         text = ""
         if constraint.deferrable is not None:
             if constraint.deferrable:
@@ -7149,19 +7197,21 @@ class DDLCompiler(Compiled):
 
 
 class GenericTypeCompiler(TypeCompiler):
-    def visit_FLOAT(self, type_, **kw):
+    def visit_FLOAT(self, type_: sqltypes.Float[Any], **kw: Any) -> str:
         return "FLOAT"
 
-    def visit_DOUBLE(self, type_, **kw):
+    def visit_DOUBLE(self, type_: sqltypes.Double[Any], **kw: Any) -> str:
         return "DOUBLE"
 
-    def visit_DOUBLE_PRECISION(self, type_, **kw):
+    def visit_DOUBLE_PRECISION(
+        self, type_: sqltypes.DOUBLE_PRECISION[Any], **kw: Any
+    ) -> str:
         return "DOUBLE PRECISION"
 
-    def visit_REAL(self, type_, **kw):
+    def visit_REAL(self, type_: sqltypes.REAL[Any], **kw: Any) -> str:
         return "REAL"
 
-    def visit_NUMERIC(self, type_, **kw):
+    def visit_NUMERIC(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str:
         if type_.precision is None:
             return "NUMERIC"
         elif type_.scale is None:
@@ -7172,7 +7222,7 @@ class GenericTypeCompiler(TypeCompiler):
                 "scale": type_.scale,
             }
 
-    def visit_DECIMAL(self, type_, **kw):
+    def visit_DECIMAL(self, type_: sqltypes.DECIMAL[Any], **kw: Any) -> str:
         if type_.precision is None:
             return "DECIMAL"
         elif type_.scale is None:
@@ -7183,128 +7233,138 @@ class GenericTypeCompiler(TypeCompiler):
                 "scale": type_.scale,
             }
 
-    def visit_INTEGER(self, type_, **kw):
+    def visit_INTEGER(self, type_: sqltypes.Integer, **kw: Any) -> str:
         return "INTEGER"
 
-    def visit_SMALLINT(self, type_, **kw):
+    def visit_SMALLINT(self, type_: sqltypes.SmallInteger, **kw: Any) -> str:
         return "SMALLINT"
 
-    def visit_BIGINT(self, type_, **kw):
+    def visit_BIGINT(self, type_: sqltypes.BigInteger, **kw: Any) -> str:
         return "BIGINT"
 
-    def visit_TIMESTAMP(self, type_, **kw):
+    def visit_TIMESTAMP(self, type_: sqltypes.TIMESTAMP, **kw: Any) -> str:
         return "TIMESTAMP"
 
-    def visit_DATETIME(self, type_, **kw):
+    def visit_DATETIME(self, type_: sqltypes.DateTime, **kw: Any) -> str:
         return "DATETIME"
 
-    def visit_DATE(self, type_, **kw):
+    def visit_DATE(self, type_: sqltypes.Date, **kw: Any) -> str:
         return "DATE"
 
-    def visit_TIME(self, type_, **kw):
+    def visit_TIME(self, type_: sqltypes.Time, **kw: Any) -> str:
         return "TIME"
 
-    def visit_CLOB(self, type_, **kw):
+    def visit_CLOB(self, type_: sqltypes.CLOB, **kw: Any) -> str:
         return "CLOB"
 
-    def visit_NCLOB(self, type_, **kw):
+    def visit_NCLOB(self, type_: sqltypes.Text, **kw: Any) -> str:
         return "NCLOB"
 
-    def _render_string_type(self, type_, name, length_override=None):
+    def _render_string_type(
+        self, name: str, length: Optional[int], collation: Optional[str]
+    ) -> str:
         text = name
-        if length_override:
-            text += "(%d)" % length_override
-        elif type_.length:
-            text += "(%d)" % type_.length
-        if type_.collation:
-            text += ' COLLATE "%s"' % type_.collation
+        if length:
+            text += f"({length})"
+        if collation:
+            text += f' COLLATE "{collation}"'
         return text
 
-    def visit_CHAR(self, type_, **kw):
-        return self._render_string_type(type_, "CHAR")
+    def visit_CHAR(self, type_: sqltypes.CHAR, **kw: Any) -> str:
+        return self._render_string_type("CHAR", type_.length, type_.collation)
 
-    def visit_NCHAR(self, type_, **kw):
-        return self._render_string_type(type_, "NCHAR")
+    def visit_NCHAR(self, type_: sqltypes.NCHAR, **kw: Any) -> str:
+        return self._render_string_type("NCHAR", type_.length, type_.collation)
 
-    def visit_VARCHAR(self, type_, **kw):
-        return self._render_string_type(type_, "VARCHAR")
+    def visit_VARCHAR(self, type_: sqltypes.String, **kw: Any) -> str:
+        return self._render_string_type(
+            "VARCHAR", type_.length, type_.collation
+        )
 
-    def visit_NVARCHAR(self, type_, **kw):
-        return self._render_string_type(type_, "NVARCHAR")
+    def visit_NVARCHAR(self, type_: sqltypes.NVARCHAR, **kw: Any) -> str:
+        return self._render_string_type(
+            "NVARCHAR", type_.length, type_.collation
+        )
 
-    def visit_TEXT(self, type_, **kw):
-        return self._render_string_type(type_, "TEXT")
+    def visit_TEXT(self, type_: sqltypes.Text, **kw: Any) -> str:
+        return self._render_string_type("TEXT", type_.length, type_.collation)
 
-    def visit_UUID(self, type_, **kw):
+    def visit_UUID(self, type_: sqltypes.Uuid[Any], **kw: Any) -> str:
         return "UUID"
 
-    def visit_BLOB(self, type_, **kw):
+    def visit_BLOB(self, type_: sqltypes.LargeBinary, **kw: Any) -> str:
         return "BLOB"
 
-    def visit_BINARY(self, type_, **kw):
+    def visit_BINARY(self, type_: sqltypes.BINARY, **kw: Any) -> str:
         return "BINARY" + (type_.length and "(%d)" % type_.length or "")
 
-    def visit_VARBINARY(self, type_, **kw):
+    def visit_VARBINARY(self, type_: sqltypes.VARBINARY, **kw: Any) -> str:
         return "VARBINARY" + (type_.length and "(%d)" % type_.length or "")
 
-    def visit_BOOLEAN(self, type_, **kw):
+    def visit_BOOLEAN(self, type_: sqltypes.Boolean, **kw: Any) -> str:
         return "BOOLEAN"
 
-    def visit_uuid(self, type_, **kw):
+    def visit_uuid(self, type_: sqltypes.Uuid[Any], **kw: Any) -> str:
         if not type_.native_uuid or not self.dialect.supports_native_uuid:
-            return self._render_string_type(type_, "CHAR", length_override=32)
+            return self._render_string_type("CHAR", length=32, collation=None)
         else:
             return self.visit_UUID(type_, **kw)
 
-    def visit_large_binary(self, type_, **kw):
+    def visit_large_binary(
+        self, type_: sqltypes.LargeBinary, **kw: Any
+    ) -> str:
         return self.visit_BLOB(type_, **kw)
 
-    def visit_boolean(self, type_, **kw):
+    def visit_boolean(self, type_: sqltypes.Boolean, **kw: Any) -> str:
         return self.visit_BOOLEAN(type_, **kw)
 
-    def visit_time(self, type_, **kw):
+    def visit_time(self, type_: sqltypes.Time, **kw: Any) -> str:
         return self.visit_TIME(type_, **kw)
 
-    def visit_datetime(self, type_, **kw):
+    def visit_datetime(self, type_: sqltypes.DateTime, **kw: Any) -> str:
         return self.visit_DATETIME(type_, **kw)
 
-    def visit_date(self, type_, **kw):
+    def visit_date(self, type_: sqltypes.Date, **kw: Any) -> str:
         return self.visit_DATE(type_, **kw)
 
-    def visit_big_integer(self, type_, **kw):
+    def visit_big_integer(self, type_: sqltypes.BigInteger, **kw: Any) -> str:
         return self.visit_BIGINT(type_, **kw)
 
-    def visit_small_integer(self, type_, **kw):
+    def visit_small_integer(
+        self, type_: sqltypes.SmallInteger, **kw: Any
+    ) -> str:
         return self.visit_SMALLINT(type_, **kw)
 
-    def visit_integer(self, type_, **kw):
+    def visit_integer(self, type_: sqltypes.Integer, **kw: Any) -> str:
         return self.visit_INTEGER(type_, **kw)
 
-    def visit_real(self, type_, **kw):
+    def visit_real(self, type_: sqltypes.REAL[Any], **kw: Any) -> str:
         return self.visit_REAL(type_, **kw)
 
-    def visit_float(self, type_, **kw):
+    def visit_float(self, type_: sqltypes.Float[Any], **kw: Any) -> str:
         return self.visit_FLOAT(type_, **kw)
 
-    def visit_double(self, type_, **kw):
+    def visit_double(self, type_: sqltypes.Double[Any], **kw: Any) -> str:
         return self.visit_DOUBLE(type_, **kw)
 
-    def visit_numeric(self, type_, **kw):
+    def visit_numeric(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str:
         return self.visit_NUMERIC(type_, **kw)
 
-    def visit_string(self, type_, **kw):
+    def visit_string(self, type_: sqltypes.String, **kw: Any) -> str:
         return self.visit_VARCHAR(type_, **kw)
 
-    def visit_unicode(self, type_, **kw):
+    def visit_unicode(self, type_: sqltypes.Unicode, **kw: Any) -> str:
         return self.visit_VARCHAR(type_, **kw)
 
-    def visit_text(self, type_, **kw):
+    def visit_text(self, type_: sqltypes.Text, **kw: Any) -> str:
         return self.visit_TEXT(type_, **kw)
 
-    def visit_unicode_text(self, type_, **kw):
+    def visit_unicode_text(
+        self, type_: sqltypes.UnicodeText, **kw: Any
+    ) -> str:
         return self.visit_TEXT(type_, **kw)
 
-    def visit_enum(self, type_, **kw):
+    def visit_enum(self, type_: sqltypes.Enum, **kw: Any) -> str:
         return self.visit_VARCHAR(type_, **kw)
 
     def visit_null(self, type_, **kw):
@@ -7314,10 +7374,14 @@ class GenericTypeCompiler(TypeCompiler):
             "type on this Column?" % type_
         )
 
-    def visit_type_decorator(self, type_, **kw):
+    def visit_type_decorator(
+        self, type_: TypeDecorator[Any], **kw: Any
+    ) -> str:
         return self.process(type_.type_engine(self.dialect), **kw)
 
-    def visit_user_defined(self, type_, **kw):
+    def visit_user_defined(
+        self, type_: UserDefinedType[Any], **kw: Any
+    ) -> str:
         return type_.get_col_spec(**kw)
 
 
@@ -7392,12 +7456,12 @@ class IdentifierPreparer:
 
     def __init__(
         self,
-        dialect,
-        initial_quote='"',
-        final_quote=None,
-        escape_quote='"',
-        quote_case_sensitive_collations=True,
-        omit_schema=False,
+        dialect: Dialect,
+        initial_quote: str = '"',
+        final_quote: Optional[str] = None,
+        escape_quote: str = '"',
+        quote_case_sensitive_collations: bool = True,
+        omit_schema: bool = False,
     ):
         """Construct a new ``IdentifierPreparer`` object.
 
@@ -7450,7 +7514,9 @@ class IdentifierPreparer:
         prep._includes_none_schema_translate = includes_none
         return prep
 
-    def _render_schema_translates(self, statement, schema_translate_map):
+    def _render_schema_translates(
+        self, statement: str, schema_translate_map: SchemaTranslateMapType
+    ) -> str:
         d = schema_translate_map
         if None in d:
             if not self._includes_none_schema_translate:
@@ -7462,7 +7528,7 @@ class IdentifierPreparer:
                     "schema_translate_map dictionaries."
                 )
 
-            d["_none"] = d[None]
+            d["_none"] = d[None]  # type: ignore[index]
 
         def replace(m):
             name = m.group(2)
@@ -7655,7 +7721,9 @@ class IdentifierPreparer:
         else:
             return collation_name
 
-    def format_sequence(self, sequence, use_schema=True):
+    def format_sequence(
+        self, sequence: schema.Sequence, use_schema: bool = True
+    ) -> str:
         name = self.quote(sequence.name)
 
         effective_schema = self.schema_for_object(sequence)
@@ -7692,7 +7760,9 @@ class IdentifierPreparer:
         return ident
 
     @util.preload_module("sqlalchemy.sql.naming")
-    def format_constraint(self, constraint, _alembic_quote=True):
+    def format_constraint(
+        self, constraint: Union[Constraint, Index], _alembic_quote: bool = True
+    ) -> Optional[str]:
         naming = util.preloaded.sql_naming
 
         if constraint.name is _NONE_NAME:
@@ -7705,6 +7775,7 @@ class IdentifierPreparer:
         else:
             name = constraint.name
 
+        assert name is not None
         if constraint.__visit_name__ == "index":
             return self.truncate_and_render_index_name(
                 name, _alembic_quote=_alembic_quote
@@ -7714,7 +7785,9 @@ class IdentifierPreparer:
                 name, _alembic_quote=_alembic_quote
             )
 
-    def truncate_and_render_index_name(self, name, _alembic_quote=True):
+    def truncate_and_render_index_name(
+        self, name: str, _alembic_quote: bool = True
+    ) -> str:
         # calculate these at format time so that ad-hoc changes
         # to dialect.max_identifier_length etc. can be reflected
         # as IdentifierPreparer is long lived
@@ -7726,7 +7799,9 @@ class IdentifierPreparer:
             name, max_, _alembic_quote
         )
 
-    def truncate_and_render_constraint_name(self, name, _alembic_quote=True):
+    def truncate_and_render_constraint_name(
+        self, name: str, _alembic_quote: bool = True
+    ) -> str:
         # calculate these at format time so that ad-hoc changes
         # to dialect.max_identifier_length etc. can be reflected
         # as IdentifierPreparer is long lived
@@ -7738,7 +7813,9 @@ class IdentifierPreparer:
             name, max_, _alembic_quote
         )
 
-    def _truncate_and_render_maxlen_name(self, name, max_, _alembic_quote):
+    def _truncate_and_render_maxlen_name(
+        self, name: str, max_: int, _alembic_quote: bool
+    ) -> str:
         if isinstance(name, elements._truncated_label):
             if len(name) > max_:
                 name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:]
@@ -7750,13 +7827,21 @@ class IdentifierPreparer:
         else:
             return self.quote(name)
 
-    def format_index(self, index):
-        return self.format_constraint(index)
+    def format_index(self, index: Index) -> str:
+        name = self.format_constraint(index)
+        assert name is not None
+        return name
 
-    def format_table(self, table, use_schema=True, name=None):
+    def format_table(
+        self,
+        table: FromClause,
+        use_schema: bool = True,
+        name: Optional[str] = None,
+    ) -> str:
         """Prepare a quoted table and schema name."""
-
         if name is None:
+            if TYPE_CHECKING:
+                assert isinstance(table, NamedFromClause)
             name = table.name
 
         result = self.quote(name)
@@ -7788,17 +7873,18 @@ class IdentifierPreparer:
 
     def format_column(
         self,
-        column,
-        use_table=False,
-        name=None,
-        table_name=None,
-        use_schema=False,
-        anon_map=None,
-    ):
+        column: ColumnElement[Any],
+        use_table: bool = False,
+        name: Optional[str] = None,
+        table_name: Optional[str] = None,
+        use_schema: bool = False,
+        anon_map: Optional[Mapping[str, Any]] = None,
+    ) -> str:
         """Prepare a quoted column name."""
 
         if name is None:
             name = column.name
+            assert name is not None
 
         if anon_map is not None and isinstance(
             name, elements._truncated_label
@@ -7866,7 +7952,7 @@ class IdentifierPreparer:
         )
         return r
 
-    def unformat_identifiers(self, identifiers):
+    def unformat_identifiers(self, identifiers: str) -> Sequence[str]:
         """Unpack 'schema.table.column'-like strings into components."""
 
         r = self._r_identifiers
index 4e1973ea02455a5a9b8ffb6245f0570c1ceb9ae6..b1a115f49dfac027fa83e82c1ef6b06d2a681852 100644 (file)
@@ -17,12 +17,15 @@ import contextlib
 import typing
 from typing import Any
 from typing import Callable
+from typing import Generic
 from typing import Iterable
 from typing import List
 from typing import Optional
 from typing import Protocol
 from typing import Sequence as typing_Sequence
 from typing import Tuple
+from typing import TypeVar
+from typing import Union
 
 from . import roles
 from .base import _generative
@@ -38,10 +41,12 @@ if typing.TYPE_CHECKING:
     from .compiler import Compiled
     from .compiler import DDLCompiler
     from .elements import BindParameter
+    from .schema import Column
     from .schema import Constraint
     from .schema import ForeignKeyConstraint
+    from .schema import Index
     from .schema import SchemaItem
-    from .schema import Sequence
+    from .schema import Sequence as Sequence  # noqa: F401
     from .schema import Table
     from .selectable import TableClause
     from ..engine.base import Connection
@@ -50,6 +55,8 @@ if typing.TYPE_CHECKING:
     from ..engine.interfaces import Dialect
     from ..engine.interfaces import SchemaTranslateMapType
 
+_SI = TypeVar("_SI", bound=Union["SchemaItem", str])
+
 
 class BaseDDLElement(ClauseElement):
     """The root of DDL constructs, including those that are sub-elements
@@ -87,7 +94,7 @@ class DDLIfCallable(Protocol):
     def __call__(
         self,
         ddl: BaseDDLElement,
-        target: SchemaItem,
+        target: Union[SchemaItem, str],
         bind: Optional[Connection],
         tables: Optional[List[Table]] = None,
         state: Optional[Any] = None,
@@ -106,7 +113,7 @@ class DDLIf(typing.NamedTuple):
     def _should_execute(
         self,
         ddl: BaseDDLElement,
-        target: SchemaItem,
+        target: Union[SchemaItem, str],
         bind: Optional[Connection],
         compiler: Optional[DDLCompiler] = None,
         **kw: Any,
@@ -172,7 +179,7 @@ class ExecutableDDLElement(roles.DDLRole, Executable, BaseDDLElement):
     """
 
     _ddl_if: Optional[DDLIf] = None
-    target: Optional[SchemaItem] = None
+    target: Union[SchemaItem, str, None] = None
 
     def _execute_on_connection(
         self, connection, distilled_params, execution_options
@@ -415,7 +422,7 @@ class DDL(ExecutableDDLElement):
         )
 
 
-class _CreateDropBase(ExecutableDDLElement):
+class _CreateDropBase(ExecutableDDLElement, Generic[_SI]):
     """Base class for DDL constructs that represent CREATE and DROP or
     equivalents.
 
@@ -425,15 +432,13 @@ class _CreateDropBase(ExecutableDDLElement):
 
     """
 
-    def __init__(
-        self,
-        element,
-    ):
+    def __init__(self, element: _SI) -> None:
         self.element = self.target = element
         self._ddl_if = getattr(element, "_ddl_if", None)
 
     @property
     def stringify_dialect(self):
+        assert not isinstance(self.element, str)
         return self.element.create_drop_stringify_dialect
 
     def _create_rule_disable(self, compiler):
@@ -447,19 +452,19 @@ class _CreateDropBase(ExecutableDDLElement):
         return False
 
 
-class _CreateBase(_CreateDropBase):
-    def __init__(self, element, if_not_exists=False):
+class _CreateBase(_CreateDropBase[_SI]):
+    def __init__(self, element: _SI, if_not_exists: bool = False) -> None:
         super().__init__(element)
         self.if_not_exists = if_not_exists
 
 
-class _DropBase(_CreateDropBase):
-    def __init__(self, element, if_exists=False):
+class _DropBase(_CreateDropBase[_SI]):
+    def __init__(self, element: _SI, if_exists: bool = False) -> None:
         super().__init__(element)
         self.if_exists = if_exists
 
 
-class CreateSchema(_CreateBase):
+class CreateSchema(_CreateBase[str]):
     """Represent a CREATE SCHEMA statement.
 
     The argument here is the string name of the schema.
@@ -474,13 +479,13 @@ class CreateSchema(_CreateBase):
         self,
         name: str,
         if_not_exists: bool = False,
-    ):
+    ) -> None:
         """Create a new :class:`.CreateSchema` construct."""
 
         super().__init__(element=name, if_not_exists=if_not_exists)
 
 
-class DropSchema(_DropBase):
+class DropSchema(_DropBase[str]):
     """Represent a DROP SCHEMA statement.
 
     The argument here is the string name of the schema.
@@ -496,14 +501,14 @@ class DropSchema(_DropBase):
         name: str,
         cascade: bool = False,
         if_exists: bool = False,
-    ):
+    ) -> None:
         """Create a new :class:`.DropSchema` construct."""
 
         super().__init__(element=name, if_exists=if_exists)
         self.cascade = cascade
 
 
-class CreateTable(_CreateBase):
+class CreateTable(_CreateBase["Table"]):
     """Represent a CREATE TABLE statement."""
 
     __visit_name__ = "create_table"
@@ -515,7 +520,7 @@ class CreateTable(_CreateBase):
             typing_Sequence[ForeignKeyConstraint]
         ] = None,
         if_not_exists: bool = False,
-    ):
+    ) -> None:
         """Create a :class:`.CreateTable` construct.
 
         :param element: a :class:`_schema.Table` that's the subject
@@ -537,7 +542,7 @@ class CreateTable(_CreateBase):
         self.include_foreign_key_constraints = include_foreign_key_constraints
 
 
-class _DropView(_DropBase):
+class _DropView(_DropBase["Table"]):
     """Semi-public 'DROP VIEW' construct.
 
     Used by the test suite for dialect-agnostic drops of views.
@@ -549,7 +554,9 @@ class _DropView(_DropBase):
 
 
 class CreateConstraint(BaseDDLElement):
-    def __init__(self, element: Constraint):
+    element: Constraint
+
+    def __init__(self, element: Constraint) -> None:
         self.element = element
 
 
@@ -666,16 +673,18 @@ class CreateColumn(BaseDDLElement):
 
     __visit_name__ = "create_column"
 
-    def __init__(self, element):
+    element: Column[Any]
+
+    def __init__(self, element: Column[Any]) -> None:
         self.element = element
 
 
-class DropTable(_DropBase):
+class DropTable(_DropBase["Table"]):
     """Represent a DROP TABLE statement."""
 
     __visit_name__ = "drop_table"
 
-    def __init__(self, element: Table, if_exists: bool = False):
+    def __init__(self, element: Table, if_exists: bool = False) -> None:
         """Create a :class:`.DropTable` construct.
 
         :param element: a :class:`_schema.Table` that's the subject
@@ -690,30 +699,24 @@ class DropTable(_DropBase):
         super().__init__(element, if_exists=if_exists)
 
 
-class CreateSequence(_CreateBase):
+class CreateSequence(_CreateBase["Sequence"]):
     """Represent a CREATE SEQUENCE statement."""
 
     __visit_name__ = "create_sequence"
 
-    def __init__(self, element: Sequence, if_not_exists: bool = False):
-        super().__init__(element, if_not_exists=if_not_exists)
-
 
-class DropSequence(_DropBase):
+class DropSequence(_DropBase["Sequence"]):
     """Represent a DROP SEQUENCE statement."""
 
     __visit_name__ = "drop_sequence"
 
-    def __init__(self, element: Sequence, if_exists: bool = False):
-        super().__init__(element, if_exists=if_exists)
 
-
-class CreateIndex(_CreateBase):
+class CreateIndex(_CreateBase["Index"]):
     """Represent a CREATE INDEX statement."""
 
     __visit_name__ = "create_index"
 
-    def __init__(self, element, if_not_exists=False):
+    def __init__(self, element: Index, if_not_exists: bool = False) -> None:
         """Create a :class:`.Createindex` construct.
 
         :param element: a :class:`_schema.Index` that's the subject
@@ -727,12 +730,12 @@ class CreateIndex(_CreateBase):
         super().__init__(element, if_not_exists=if_not_exists)
 
 
-class DropIndex(_DropBase):
+class DropIndex(_DropBase["Index"]):
     """Represent a DROP INDEX statement."""
 
     __visit_name__ = "drop_index"
 
-    def __init__(self, element, if_exists=False):
+    def __init__(self, element: Index, if_exists: bool = False) -> None:
         """Create a :class:`.DropIndex` construct.
 
         :param element: a :class:`_schema.Index` that's the subject
@@ -746,7 +749,7 @@ class DropIndex(_DropBase):
         super().__init__(element, if_exists=if_exists)
 
 
-class AddConstraint(_CreateBase):
+class AddConstraint(_CreateBase["Constraint"]):
     """Represent an ALTER TABLE ADD CONSTRAINT statement."""
 
     __visit_name__ = "add_constraint"
@@ -756,7 +759,7 @@ class AddConstraint(_CreateBase):
         element: Constraint,
         *,
         isolate_from_table: bool = True,
-    ):
+    ) -> None:
         """Construct a new :class:`.AddConstraint` construct.
 
         :param element: a :class:`.Constraint` object
@@ -780,7 +783,7 @@ class AddConstraint(_CreateBase):
             )
 
 
-class DropConstraint(_DropBase):
+class DropConstraint(_DropBase["Constraint"]):
     """Represent an ALTER TABLE DROP CONSTRAINT statement."""
 
     __visit_name__ = "drop_constraint"
@@ -793,7 +796,7 @@ class DropConstraint(_DropBase):
         if_exists: bool = False,
         isolate_from_table: bool = True,
         **kw: Any,
-    ):
+    ) -> None:
         """Construct a new :class:`.DropConstraint` construct.
 
         :param element: a :class:`.Constraint` object
@@ -821,13 +824,13 @@ class DropConstraint(_DropBase):
             )
 
 
-class SetTableComment(_CreateDropBase):
+class SetTableComment(_CreateDropBase["Table"]):
     """Represent a COMMENT ON TABLE IS statement."""
 
     __visit_name__ = "set_table_comment"
 
 
-class DropTableComment(_CreateDropBase):
+class DropTableComment(_CreateDropBase["Table"]):
     """Represent a COMMENT ON TABLE '' statement.
 
     Note this varies a lot across database backends.
@@ -837,25 +840,25 @@ class DropTableComment(_CreateDropBase):
     __visit_name__ = "drop_table_comment"
 
 
-class SetColumnComment(_CreateDropBase):
+class SetColumnComment(_CreateDropBase["Column[Any]"]):
     """Represent a COMMENT ON COLUMN IS statement."""
 
     __visit_name__ = "set_column_comment"
 
 
-class DropColumnComment(_CreateDropBase):
+class DropColumnComment(_CreateDropBase["Column[Any]"]):
     """Represent a COMMENT ON COLUMN IS NULL statement."""
 
     __visit_name__ = "drop_column_comment"
 
 
-class SetConstraintComment(_CreateDropBase):
+class SetConstraintComment(_CreateDropBase["Constraint"]):
     """Represent a COMMENT ON CONSTRAINT IS statement."""
 
     __visit_name__ = "set_constraint_comment"
 
 
-class DropConstraintComment(_CreateDropBase):
+class DropConstraintComment(_CreateDropBase["Constraint"]):
     """Represent a COMMENT ON CONSTRAINT IS NULL statement."""
 
     __visit_name__ = "drop_constraint_comment"
index 8d256ea3772f077a0636306b9637af77e4d601a3..e394f73f4fde9ae3a53124d6d2e88cee9ba1855a 100644 (file)
@@ -2225,8 +2225,9 @@ class TypeClause(DQLDMLClauseElement):
     _traverse_internals: _TraverseInternalsType = [
         ("type", InternalTraversal.dp_type)
     ]
+    type: TypeEngine[Any]
 
-    def __init__(self, type_):
+    def __init__(self, type_: TypeEngine[Any]):
         self.type = type_
 
 
@@ -3913,10 +3914,9 @@ class BinaryExpression(OperatorExpression[_T]):
 
     """
 
-    modifiers: Optional[Mapping[str, Any]]
-
     left: ColumnElement[Any]
     right: ColumnElement[Any]
+    modifiers: Mapping[str, Any]
 
     def __init__(
         self,
index 3fcf22ee6865ecdf5cd21b26d43a913635582346..131a0f2e28153448b6a8a32631d086c4aa3042c0 100644 (file)
@@ -23,6 +23,7 @@ from typing import cast
 from typing import Dict
 from typing import Generic
 from typing import List
+from typing import Mapping
 from typing import Optional
 from typing import overload
 from typing import Sequence
@@ -246,10 +247,14 @@ class String(Concatenable, TypeEngine[str]):
 
         return process
 
-    def bind_processor(self, dialect):
+    def bind_processor(
+        self, dialect: Dialect
+    ) -> Optional[_BindProcessorType[str]]:
         return None
 
-    def result_processor(self, dialect, coltype):
+    def result_processor(
+        self, dialect: Dialect, coltype: object
+    ) -> Optional[_ResultProcessorType[str]]:
         return None
 
     @property
@@ -426,7 +431,7 @@ class NumericCommon(HasExpressionLookup, TypeEngineMixin, Generic[_N]):
     if TYPE_CHECKING:
 
         @util.ro_memoized_property
-        def _type_affinity(self) -> Type[NumericCommon[_N]]: ...
+        def _type_affinity(self) -> Type[Union[Numeric[_N], Float[_N]]]: ...
 
     def __init__(
         self,
@@ -653,8 +658,6 @@ class Float(NumericCommon[_N], TypeEngine[_N]):
 
     __visit_name__ = "float"
 
-    scale = None
-
     @overload
     def __init__(
         self: Float[float],
@@ -925,6 +928,8 @@ class Time(_RenderISO8601NoT, HasExpressionLookup, TypeEngine[dt.time]):
 class _Binary(TypeEngine[bytes]):
     """Define base behavior for binary types."""
 
+    length: Optional[int]
+
     def __init__(self, length: Optional[int] = None):
         self.length = length
 
@@ -1249,6 +1254,9 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin):
             return _we_are_the_impl(variant_mapping["_default"])
 
 
+_EnumTupleArg = Union[Sequence[enum.Enum], Sequence[str]]
+
+
 class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
     """Generic Enum Type.
 
@@ -1325,7 +1333,18 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
 
     __visit_name__ = "enum"
 
-    def __init__(self, *enums: object, **kw: Any):
+    values_callable: Optional[Callable[[Type[enum.Enum]], Sequence[str]]]
+    enum_class: Optional[Type[enum.Enum]]
+    _valid_lookup: Dict[Union[enum.Enum, str, None], Optional[str]]
+    _object_lookup: Dict[Optional[str], Union[enum.Enum, str, None]]
+
+    @overload
+    def __init__(self, enums: Type[enum.Enum], **kw: Any) -> None: ...
+
+    @overload
+    def __init__(self, *enums: str, **kw: Any) -> None: ...
+
+    def __init__(self, *enums: Union[str, Type[enum.Enum]], **kw: Any) -> None:
         r"""Construct an enum.
 
         Keyword arguments which don't apply to a specific backend are ignored
@@ -1457,7 +1476,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
            .. versionchanged:: 2.0 This parameter now defaults to True.
 
         """
-        self._enum_init(enums, kw)
+        self._enum_init(enums, kw)  # type: ignore[arg-type]
 
     @property
     def _enums_argument(self):
@@ -1466,7 +1485,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
         else:
             return self.enums
 
-    def _enum_init(self, enums, kw):
+    def _enum_init(self, enums: _EnumTupleArg, kw: Dict[str, Any]) -> None:
         """internal init for :class:`.Enum` and subclasses.
 
         friendly init helper used by subclasses to remove
@@ -1525,15 +1544,19 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
             _adapted_from=kw.pop("_adapted_from", None),
         )
 
-    def _parse_into_values(self, enums, kw):
+    def _parse_into_values(
+        self, enums: _EnumTupleArg, kw: Any
+    ) -> Tuple[Sequence[str], _EnumTupleArg]:
         if not enums and "_enums" in kw:
             enums = kw.pop("_enums")
 
         if len(enums) == 1 and hasattr(enums[0], "__members__"):
-            self.enum_class = enums[0]
+            self.enum_class = enums[0]  # type: ignore[assignment]
+            assert self.enum_class is not None
 
             _members = self.enum_class.__members__
 
+            members: Mapping[str, enum.Enum]
             if self._omit_aliases is True:
                 # remove aliases
                 members = OrderedDict(
@@ -1549,7 +1572,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
             return values, objects
         else:
             self.enum_class = None
-            return enums, enums
+            return enums, enums  # type: ignore[return-value]
 
     def _resolve_for_literal(self, value: Any) -> Enum:
         tv = type(value)
@@ -1625,7 +1648,12 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
             self._generic_type_affinity(_enums=enum_args, **kw),  # type: ignore  # noqa: E501
         )
 
-    def _setup_for_values(self, values, objects, kw):
+    def _setup_for_values(
+        self,
+        values: Sequence[str],
+        objects: _EnumTupleArg,
+        kw: Any,
+    ) -> None:
         self.enums = list(values)
 
         self._valid_lookup = dict(zip(reversed(objects), reversed(values)))
@@ -1692,9 +1720,10 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
 
     comparator_factory = Comparator
 
-    def _object_value_for_elem(self, elem):
+    def _object_value_for_elem(self, elem: str) -> Union[str, enum.Enum]:
         try:
-            return self._object_lookup[elem]
+            # Value will not be None beacuse key is not None
+            return self._object_lookup[elem]  # type: ignore[return-value]
         except KeyError as err:
             raise LookupError(
                 "'%s' is not among the defined enum values. "
@@ -3625,6 +3654,7 @@ class Uuid(Emulated, TypeEngine[_UUID_RETURN]):
 
     __visit_name__ = "uuid"
 
+    length: Optional[int] = None
     collation: Optional[str] = None
 
     @overload
@@ -3676,7 +3706,9 @@ class Uuid(Emulated, TypeEngine[_UUID_RETURN]):
         else:
             return super().coerce_compared_value(op, value)
 
-    def bind_processor(self, dialect):
+    def bind_processor(
+        self, dialect: Dialect
+    ) -> Optional[_BindProcessorType[_UUID_RETURN]]:
         character_based_uuid = (
             not dialect.supports_native_uuid or not self.native_uuid
         )
index bdc56b46ac479e1f3a2a1df5e875a26077370632..911071cc99b608b28a8ea6e2c84d0d38622fcef2 100644 (file)
@@ -1392,6 +1392,10 @@ class UserDefinedType(
 
         return self
 
+    if TYPE_CHECKING:
+
+        def get_col_spec(self, **kw: Any) -> str: ...
+
 
 class Emulated(TypeEngineMixin):
     """Mixin for base types that emulate the behavior of a DB-native type.
index 98990041784297912fc15402bcb72d2e16b4542e..a98b51c1dee5d8ae7312356d69a33ef1084cfbc2 100644 (file)
@@ -481,7 +481,7 @@ def surface_selectables(clause):
             stack.append(elem.element)
 
 
-def surface_selectables_only(clause):
+def surface_selectables_only(clause: ClauseElement) -> Iterator[ClauseElement]:
     stack = [clause]
     while stack:
         elem = stack.pop()
index 9ca5e60a202478058f9fcb9e88b884172820c37d..36ca6a56a92844fe19c69453c87e863f5ea1ff93 100644 (file)
@@ -430,7 +430,9 @@ def to_column_set(x: Any) -> Set[Any]:
         return x
 
 
-def update_copy(d, _new=None, **kw):
+def update_copy(
+    d: Dict[Any, Any], _new: Optional[Dict[Any, Any]] = None, **kw: Any
+) -> Dict[Any, Any]:
     """Copy the given dict and update with the given values."""
 
     d = d.copy()
index 01569cebdaf3adbfa8e087c53c49829e11ff69af..8980a8506296e0cf9034ed320c5249df951998ba 100644 (file)
@@ -56,6 +56,7 @@ if True:  # zimports removes the tailing comments
     from typing_extensions import TypeAliasType as TypeAliasType  # 3.12
     from typing_extensions import Unpack as Unpack  # 3.11
     from typing_extensions import Never as Never  # 3.11
+    from typing_extensions import LiteralString as LiteralString  # 3.11
 
 
 _T = TypeVar("_T", bound=Any)
index ade402dd6bec69892d06c42caa956fc8336627af..9a9b5658c87343870b240627af32f24fd163b11e 100644 (file)
@@ -176,6 +176,8 @@ reportTypedDictNotRequiredAccess = "warning"
 mypy_path = "./lib/"
 show_error_codes = true
 incremental = true
+# would be nice to enable this but too many error are surfaceds
+# enable_error_code = "ignore-without-code"
 
 [[tool.mypy.overrides]]
 
index 8ea523fb7e56b7307d1dd377d0e576f941ceb996..1f8a23f70dc945793ac2bb3cff9dba59ba9c2e73 100644 (file)
@@ -681,7 +681,6 @@ class CompatFlagsTest(fixtures.TestBase, AssertsCompiledSQL):
 
         dialect._get_server_version_info = server_version_info
         dialect.get_isolation_level = Mock()
-        dialect._check_unicode_returns = Mock()
         dialect._check_unicode_description = Mock()
         dialect._get_default_schema_name = Mock()
         dialect._detect_decimal_char = Mock()