]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
pr fixes
authorPablo Estevez <pablo22estevez@gmail.com>
Sat, 11 Jan 2025 15:57:37 +0000 (12:57 -0300)
committerPablo Estevez <pablo22estevez@gmail.com>
Sat, 11 Jan 2025 15:57:37 +0000 (12:57 -0300)
12 files changed:
lib/sqlalchemy/connectors/pyodbc.py
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/crud.py
lib/sqlalchemy/sql/ddl.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/util/_collections.py

index b26d32fdf7e66b76f9a2a2e990289e7d82752124..e9a93e13d92668f1f96e1e3bc6fc39d9cbe27412 100644 (file)
@@ -14,7 +14,6 @@ from typing import Any
 from typing import Dict
 from typing import List
 from typing import Optional
-from typing import Sequence
 from typing import Tuple
 from typing import Union
 
@@ -229,7 +228,7 @@ class PyODBCConnector(Connector):
 
     def get_isolation_level_values(
         self, dbapi_connection: interfaces.DBAPIConnection
-    ) -> Sequence[IsolationLevel]:
+    ) -> list[IsolationLevel]:
         return super().get_isolation_level_values(dbapi_connection) + [  # type: ignore  # NOQA: E501
             "AUTOCOMMIT"
         ]
index dc4f16f5143ace54f7b02a511664a104850caa7c..11e0bc147cc7945b1f3409f70bc12672b561afa7 100644 (file)
@@ -20,6 +20,7 @@ from typing import Any
 from typing import cast
 from typing import ClassVar
 from typing import Dict
+from typing import Iterable
 from typing import Iterator
 from typing import List
 from typing import Mapping
@@ -1380,14 +1381,17 @@ class FullyBufferedCursorFetchStrategy(CursorFetchStrategy):
 
     def __init__(
         self,
-        dbapi_cursor: DBAPICursor,
+        dbapi_cursor: Optional[DBAPICursor],
         alternate_description: _DBAPICursorDescription | None = None,
-        initial_buffer: Any = None,
+        initial_buffer: Optional[Iterable[Any]] = None,
     ):
         self.alternate_cursor_description = alternate_description
         if initial_buffer is not None:
             self._rowbuffer = collections.deque(initial_buffer)
         else:
+            dbapi_cursor = cast(
+                DBAPICursor, dbapi_cursor
+            )  # Can't be both None
             self._rowbuffer = collections.deque(dbapi_cursor.fetchall())
 
     def yield_per(self, result, dbapi_cursor, num):
@@ -1903,7 +1907,7 @@ class CursorResult(Result[Unpack[_Ts]]):
         clone._metadata = clone._metadata._splice_horizontally(other._metadata)
 
         clone.cursor_strategy = FullyBufferedCursorFetchStrategy(
-            None,  # type: ignore[arg-type]
+            None,
             initial_buffer=total_rows,
         )
         clone._reset_memoizations()
@@ -1935,7 +1939,7 @@ class CursorResult(Result[Unpack[_Ts]]):
         )
 
         clone.cursor_strategy = FullyBufferedCursorFetchStrategy(
-            None,  # type: ignore[arg-type]
+            None,
             initial_buffer=total_rows,
         )
         clone._reset_memoizations()
@@ -1964,7 +1968,7 @@ class CursorResult(Result[Unpack[_Ts]]):
         )._remove_processors()
 
         self.cursor_strategy = FullyBufferedCursorFetchStrategy(
-            None,  # type: ignore[arg-type]
+            None,
             # TODO: if these are Row objects, can we save on not having to
             # re-make new Row objects out of them a second time?  is that
             # what's actually happening right now?  maybe look into this
index 8dce692d7357b1de0e066ab129d27fb47ed3c619..3582bbc57fe2f53de3869e284018db4a20732356 100644 (file)
@@ -249,7 +249,9 @@ class DefaultDialect(Dialect):
 
     supports_is_distinct_from = True
 
-    supports_server_side_cursors: "generic_fn_descriptor[bool] | bool" = False
+    supports_server_side_cursors: Union[generic_fn_descriptor[bool], bool] = (
+        False
+    )
 
     server_side_cursors = False
 
@@ -630,11 +632,11 @@ class DefaultDialect(Dialect):
                 % (ident, self.max_identifier_length)
             )
 
-    def connect(self, *cargs: Any, **cparams: Any):  # type: ignore[no-untyped-def]  # NOQA: E501
+    def connect(self, *cargs: Any, **cparams: Any) -> DBAPIConnection:
         # inherits the docstring from interfaces.Dialect.connect
-        return self.loaded_dbapi.connect(*cargs, **cparams)
+        return self.loaded_dbapi.connect(*cargs, **cparams)  # type: ignore[no-any-return]  # NOQA: E501
 
-    def create_connect_args(self, url: URL) -> "ConnectArgsType":
+    def create_connect_args(self, url: URL) -> ConnectArgsType:
         # inherits the docstring from interfaces.Dialect.create_connect_args
         opts = url.translate_connect_args()
         opts.update(url.query)
index 74ec272514950e75b078b47dc4aa36507a6e2d34..310b8fda0aabbbd07d52a81f197c3ec8aac5c30f 100644 (file)
@@ -242,7 +242,7 @@ _AnyMultiExecuteParams = _DBAPIMultiExecuteParams
 _AnyExecuteParams = _DBAPIAnyExecuteParams
 
 CompiledCacheType = MutableMapping[Any, "Compiled"]
-SchemaTranslateMapType = MutableMapping[Optional[str], Optional[str]]
+SchemaTranslateMapType = Mapping[Optional[str], Optional[str]]
 
 _ImmutableExecuteOptions = immutabledict[str, Any]
 
@@ -548,8 +548,6 @@ class ReflectedIndex(TypedDict):
     dialect_options: NotRequired[Dict[str, Any]]
     """Additional dialect-specific options detected for this index"""
 
-    type: NotRequired[str]
-
 
 class ReflectedTableComment(TypedDict):
     """Dictionary representing the reflected comment corresponding to
@@ -784,7 +782,7 @@ class Dialect(EventTarget):
     max_identifier_length: int
     """The maximum length of identifier names."""
 
-    supports_server_side_cursors: "generic_fn_descriptor[bool] | bool"
+    supports_server_side_cursors: generic_fn_descriptor[bool] | bool
     """indicates if the dialect supports server side cursors"""
 
     server_side_cursors: bool
index aa86b5e880ceec6a38f75bd0c02351101a0c091c..ef42d1c932174b5b7f7bf3a24911ad1d2f0f6894 100644 (file)
@@ -658,7 +658,7 @@ class CompileState:
     @classmethod
     def create_for_statement(
         cls, statement: Executable, compiler: Compiled, **kw: Any
-    ) -> "CompileState":
+    ) -> CompileState:
         # factory construction.
 
         if statement._propagate_attrs:
index a343efbcebe4ad3cb59f273812b20087eec24372..b6069bc2402773f2bb766390eb7e4193084d531c 100644 (file)
@@ -133,7 +133,6 @@ if typing.TYPE_CHECKING:
     from .selectable import SelectState
     from .type_api import _BindProcessorType
     from ..engine.cursor import CursorResultMetaData
-    from ..engine.default import DefaultDialect
     from ..engine.interfaces import _CoreSingleExecuteParams
     from ..engine.interfaces import _DBAPIAnyExecuteParams
     from ..engine.interfaces import _DBAPIMultiExecuteParams
@@ -889,8 +888,9 @@ class Compiled:
             self.string = self.process(self.statement, **compile_kwargs)
 
             if render_schema_translate:
+                assert schema_translate_map is not None
                 self.string = self.preparer._render_schema_translates(
-                    self.string, schema_translate_map  # type: ignore[arg-type]
+                    self.string, schema_translate_map
                 )
 
             self.state = CompilerState.STRING_APPLIED
@@ -4618,7 +4618,9 @@ class SQLCompiler(Compiled):
     def get_select_hint_text(self, byfroms):
         return None
 
-    def get_from_hint_text(self, table: Any, text: str | None) -> str | None:
+    def get_from_hint_text(
+        self, table: FromClause, text: str | None
+    ) -> str | None:
         return None
 
     def get_crud_hint_text(self, table, text):
@@ -6164,8 +6166,8 @@ class SQLCompiler(Compiled):
     def visit_update(
         self, update_stmt: "Update", visiting_cte: CTE | None = None, **kw: Any
     ) -> str:
-        compile_state = update_stmt._compile_state_factory(  # type: ignore
-            update_stmt, self, **kw  # type: ignore
+        compile_state = update_stmt._compile_state_factory(  # type: ignore[call-arg] # NOQA: E501
+            update_stmt, self, **kw  # type: ignore[arg-type]
         )
         compile_state = cast("UpdateDMLState", compile_state)
         update_stmt = compile_state.statement  # type: ignore[assignment]
@@ -6597,7 +6599,7 @@ class DDLCompiler(Compiled):
         ): ...
 
     @util.memoized_property
-    def sql_compiler(self):
+    def sql_compiler(self) -> SQLCompiler:
         return self.dialect.statement_compiler(
             self.dialect, None, schema_translate_map=self.schema_translate_map
         )
@@ -6969,11 +6971,11 @@ class DDLCompiler(Compiled):
 
     def render_default_string(self, default: Visitable | str) -> str:
         if isinstance(default, str):
-            return self.sql_compiler.render_literal_value(  # type: ignore[no-any-return]  # NOQA: E501
+            return self.sql_compiler.render_literal_value(
                 default, sqltypes.STRINGTYPE
             )
         else:
-            return self.sql_compiler.process(default, literal_binds=True)  # type: ignore[no-any-return]  # NOQA: E501
+            return self.sql_compiler.process(default, literal_binds=True)
 
     def visit_table_or_column_check_constraint(self, constraint, **kw):
         if constraint.is_column_level:
@@ -7132,31 +7134,23 @@ class DDLCompiler(Compiled):
 
 
 class GenericTypeCompiler(TypeCompiler):
-    def visit_FLOAT(
-        self, type_: "sqltypes.Float[decimal.Decimal| float]", **kw: Any
-    ) -> str:
+    def visit_FLOAT(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return "FLOAT"
 
-    def visit_DOUBLE(
-        self, type_: "sqltypes.Double[decimal.Decimal | float]", **kw: Any
-    ) -> str:
+    def visit_DOUBLE(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return "DOUBLE"
 
     def visit_DOUBLE_PRECISION(
         self,
-        type_: "sqltypes.DOUBLE_PRECISION[decimal.Decimal| float]",
+        type_: TypeEngine[Any],
         **kw: Any,
     ) -> str:
         return "DOUBLE PRECISION"
 
-    def visit_REAL(
-        self, type_: "sqltypes.REAL[decimal.Decimal| float]", **kw: Any
-    ) -> str:
+    def visit_REAL(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return "REAL"
 
-    def visit_NUMERIC(
-        self, type_: "sqltypes.Numeric[decimal.Decimal| float]", **kw: Any
-    ) -> str:
+    def visit_NUMERIC(self, type_: "sqltypes.Numeric[Any]", **kw: Any) -> str:
         if type_.precision is None:
             return "NUMERIC"
         elif type_.scale is None:
@@ -7180,31 +7174,31 @@ class GenericTypeCompiler(TypeCompiler):
                 "scale": type_.scale,
             }
 
-    def visit_INTEGER(self, type_: "sqltypes.Integer", **kw: Any) -> str:
+    def visit_INTEGER(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return "INTEGER"
 
-    def visit_SMALLINT(self, type_: "sqltypes.SmallInteger", **kw: Any) -> str:
+    def visit_SMALLINT(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return "SMALLINT"
 
-    def visit_BIGINT(self, type_: "sqltypes.BigInteger", **kw: Any) -> str:
+    def visit_BIGINT(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return "BIGINT"
 
-    def visit_TIMESTAMP(self, type_: "sqltypes.TIMESTAMP", **kw: Any) -> str:
+    def visit_TIMESTAMP(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return "TIMESTAMP"
 
-    def visit_DATETIME(self, type_: "sqltypes.DateTime", **kw: Any) -> str:
+    def visit_DATETIME(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return "DATETIME"
 
-    def visit_DATE(self, type_: "sqltypes.Date", **kw: Any) -> str:
+    def visit_DATE(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return "DATE"
 
-    def visit_TIME(self, type_: "sqltypes.Time", **kw: Any) -> str:
+    def visit_TIME(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return "TIME"
 
-    def visit_CLOB(self, type_: "sqltypes.Text", **kw: Any) -> str:
+    def visit_CLOB(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return "CLOB"
 
-    def visit_NCLOB(self, type_: "sqltypes.Text", **kw: Any) -> str:
+    def visit_NCLOB(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return "NCLOB"
 
     def _render_string_type(
@@ -7222,36 +7216,34 @@ class GenericTypeCompiler(TypeCompiler):
             text += ' COLLATE "%s"' % type_.collation
         return text
 
-    def visit_CHAR(self, type_: "sqltypes.CHAR", **kw: Any) -> str:
+    def visit_CHAR(self, type_: sqltypes.String, **kw: Any) -> str:
         return self._render_string_type(type_, "CHAR")
 
-    def visit_NCHAR(self, type_: "sqltypes.NCHAR", **kw: Any) -> str:
+    def visit_NCHAR(self, type_: sqltypes.String, **kw: Any) -> str:
         return self._render_string_type(type_, "NCHAR")
 
-    def visit_VARCHAR(self, type_: "sqltypes.String", **kw: Any) -> str:
+    def visit_VARCHAR(self, type_: sqltypes.String, **kw: Any) -> str:
         return self._render_string_type(type_, "VARCHAR")
 
-    def visit_NVARCHAR(self, type_: "sqltypes.NVARCHAR", **kw: Any) -> str:
+    def visit_NVARCHAR(self, type_: sqltypes.String, **kw: Any) -> str:
         return self._render_string_type(type_, "NVARCHAR")
 
-    def visit_TEXT(self, type_: "sqltypes.Text", **kw: Any) -> str:
+    def visit_TEXT(self, type_: sqltypes.String, **kw: Any) -> str:
         return self._render_string_type(type_, "TEXT")
 
-    def visit_UUID(
-        self, type_: "sqltypes.Uuid[_UUID_RETURN]", **kw: Any
-    ) -> str:
+    def visit_UUID(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return "UUID"
 
-    def visit_BLOB(self, type_: "sqltypes.LargeBinary", **kw: Any) -> str:
+    def visit_BLOB(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return "BLOB"
 
-    def visit_BINARY(self, type_: "sqltypes.BINARY", **kw: Any) -> str:
+    def visit_BINARY(self, type_: "sqltypes._Binary", **kw: Any) -> str:
         return "BINARY" + (type_.length and "(%d)" % type_.length or "")
 
-    def visit_VARBINARY(self, type_: "sqltypes.VARBINARY", **kw: Any) -> str:
+    def visit_VARBINARY(self, type_: "sqltypes._Binary", **kw: Any) -> str:
         return "VARBINARY" + (type_.length and "(%d)" % type_.length or "")
 
-    def visit_BOOLEAN(self, type_: "sqltypes.Boolean", **kw: Any) -> str:
+    def visit_BOOLEAN(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return "BOOLEAN"
 
     def visit_uuid(
@@ -7262,71 +7254,55 @@ class GenericTypeCompiler(TypeCompiler):
         else:
             return self.visit_UUID(type_, **kw)
 
-    def visit_large_binary(
-        self, type_: "sqltypes.LargeBinary", **kw: Any
-    ) -> str:
+    def visit_large_binary(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return self.visit_BLOB(type_, **kw)
 
-    def visit_boolean(self, type_: "sqltypes.Boolean", **kw: Any) -> str:
+    def visit_boolean(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return self.visit_BOOLEAN(type_, **kw)
 
-    def visit_time(self, type_: "sqltypes.Time", **kw: Any) -> str:
+    def visit_time(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return self.visit_TIME(type_, **kw)
 
-    def visit_datetime(self, type_: "sqltypes.DateTime", **kw: Any) -> str:
+    def visit_datetime(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return self.visit_DATETIME(type_, **kw)
 
-    def visit_date(self, type_: "sqltypes.Date", **kw: Any) -> str:
+    def visit_date(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return self.visit_DATE(type_, **kw)
 
-    def visit_big_integer(
-        self, type_: "sqltypes.BigInteger", **kw: Any
-    ) -> str:
+    def visit_big_integer(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return self.visit_BIGINT(type_, **kw)
 
-    def visit_small_integer(
-        self, type_: "sqltypes.SmallInteger", **kw: Any
-    ) -> str:
+    def visit_small_integer(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return self.visit_SMALLINT(type_, **kw)
 
-    def visit_integer(self, type_: "sqltypes.Integer", **kw: Any) -> str:
+    def visit_integer(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return self.visit_INTEGER(type_, **kw)
 
-    def visit_real(
-        self, type_: "sqltypes.REAL[decimal.Decimal| float]", **kw: Any
-    ) -> str:
+    def visit_real(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return self.visit_REAL(type_, **kw)
 
-    def visit_float(
-        self, type_: "sqltypes.Float[decimal.Decimal| float]", **kw: Any
-    ) -> str:
+    def visit_float(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return self.visit_FLOAT(type_, **kw)
 
-    def visit_double(
-        self, type_: "sqltypes.Double[decimal.Decimal | float]", **kw: Any
-    ) -> str:
+    def visit_double(self, type_: TypeEngine[Any], **kw: Any) -> str:
         return self.visit_DOUBLE(type_, **kw)
 
-    def visit_numeric(
-        self, type_: "sqltypes.Numeric[decimal.Decimal | float]", **kw: Any
-    ) -> str:
+    def visit_numeric(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str:
         return self.visit_NUMERIC(type_, **kw)
 
-    def visit_string(self, type_: "sqltypes.String", **kw: Any) -> str:
+    def visit_string(self, type_: sqltypes.String, **kw: Any) -> str:
         return self.visit_VARCHAR(type_, **kw)
 
-    def visit_unicode(self, type_: "sqltypes.Unicode", **kw: Any) -> str:
+    def visit_unicode(self, type_: sqltypes.String, **kw: Any) -> str:
         return self.visit_VARCHAR(type_, **kw)
 
-    def visit_text(self, type_: "sqltypes.Text", **kw: Any) -> str:
+    def visit_text(self, type_: sqltypes.String, **kw: Any) -> str:
         return self.visit_TEXT(type_, **kw)
 
-    def visit_unicode_text(
-        self, type_: "sqltypes.UnicodeText", **kw: Any
-    ) -> str:
+    def visit_unicode_text(self, type_: sqltypes.String, **kw: Any) -> str:
         return self.visit_TEXT(type_, **kw)
 
-    def visit_enum(self, type_: "sqltypes.Enum", **kw: Any) -> str:
+    def visit_enum(self, type_: sqltypes.String, **kw: Any) -> str:
         return self.visit_VARCHAR(type_, **kw)
 
     def visit_null(self, type_, **kw):
@@ -7418,7 +7394,7 @@ class IdentifierPreparer:
 
     def __init__(
         self,
-        dialect: DefaultDialect,
+        dialect: Dialect,
         initial_quote: str = '"',
         final_quote: str | None = None,
         escape_quote: str = '"',
@@ -7490,7 +7466,7 @@ class IdentifierPreparer:
                     "schema_translate_map dictionaries."
                 )
 
-            d["_none"] = d[None]
+            d["_none"] = d[None]  # type: ignore[index]
 
         def replace(m):
             name = m.group(2)
@@ -7753,7 +7729,7 @@ class IdentifierPreparer:
         # to dialect.max_identifier_length etc. can be reflected
         # as IdentifierPreparer is long lived
         max_ = (
-            self.dialect.max_index_name_length
+            self.dialect.max_index_name_length  # type: ignore[attr-defined]
             or self.dialect.max_identifier_length
         )
         return self._truncate_and_render_maxlen_name(
@@ -7767,7 +7743,7 @@ class IdentifierPreparer:
         # to dialect.max_identifier_length etc. can be reflected
         # as IdentifierPreparer is long lived
         max_ = (
-            self.dialect.max_constraint_name_length
+            self.dialect.max_constraint_name_length  # type: ignore[attr-defined]  # NOQA: E501
             or self.dialect.max_identifier_length
         )
         return self._truncate_and_render_maxlen_name(
@@ -7781,7 +7757,7 @@ class IdentifierPreparer:
             if len(name) > max_:
                 name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:]
         else:
-            self.dialect.validate_identifier(name)
+            self.dialect.validate_identifier(name)  # type: ignore[attr-defined]  # NOQA: E501
 
         if not _alembic_quote:
             return name
@@ -7794,7 +7770,7 @@ class IdentifierPreparer:
     @overload
     def format_table(
         self,
-        table: "Table | None",
+        table: "FromClause | None",
         use_schema: bool,
         name: str,
     ) -> str: ...
@@ -7802,20 +7778,21 @@ class IdentifierPreparer:
     @overload
     def format_table(
         self,
-        table: "Table",
+        table: "NamedFromClause",
         use_schema: bool = True,
         name: None = None,
     ) -> str: ...
 
     def format_table(
         self,
-        table: "Table | None",
+        table: "FromClause | None",
         use_schema: bool = True,
         name: str | None = None,
     ) -> str:
         """Prepare a quoted table and schema name."""
         if name is None:
             assert table is not None
+            table = cast("NamedFromClause", table)
             name = table.name
 
         result = self.quote(name)
@@ -7847,7 +7824,7 @@ class IdentifierPreparer:
 
     def format_column(
         self,
-        column: "Column[Any]",
+        column: ColumnElement[Any],
         use_table: bool = False,
         name: str | None = None,
         table_name: str | None = None,
@@ -7858,6 +7835,7 @@ class IdentifierPreparer:
 
         if name is None:
             name = column.name
+            name = cast(str, name)
 
         if anon_map is not None and isinstance(
             name, elements._truncated_label
index c3a5b015bd2563f1950968d9c2a24ac0b6ed71a5..19af40ff08038e375d51e87f4d95378287713963 100644 (file)
@@ -207,7 +207,7 @@ def _get_crud_params(
             [
                 (
                     c,
-                    compiler.preparer.format_column(c),  # type: ignore[arg-type] # noqa: E501
+                    compiler.preparer.format_column(c),
                     _create_bind_param(compiler, c, None, required=True),
                     (c.key,),
                 )
@@ -369,7 +369,7 @@ def _get_crud_params(
         values = [
             (
                 _as_dml_column(stmt.table.columns[0]),
-                compiler.preparer.format_column(stmt.table.columns[0]),  # type: ignore[arg-type] # noqa: E501
+                compiler.preparer.format_column(stmt.table.columns[0]),
                 compiler.dialect.default_metavalue_token,
                 (),
             )
@@ -1136,7 +1136,7 @@ def _append_param_insert_select_hasdefault(
             values.append(
                 (
                     c,
-                    compiler.preparer.format_column(c),  # type: ignore[arg-type] # noqa: E501
+                    compiler.preparer.format_column(c),
                     c.default.next_value(),
                     (),
                 )
@@ -1145,7 +1145,7 @@ def _append_param_insert_select_hasdefault(
         values.append(
             (
                 c,
-                compiler.preparer.format_column(c),  # type: ignore[arg-type]
+                compiler.preparer.format_column(c),
                 c.default.arg.self_group(),
                 (),
             )
@@ -1154,7 +1154,7 @@ def _append_param_insert_select_hasdefault(
         values.append(
             (
                 c,
-                compiler.preparer.format_column(c),  # type: ignore[arg-type]
+                compiler.preparer.format_column(c),
                 _create_insert_prefetch_bind_param(
                     compiler, c, process=False, **kw
                 ),
index d397331c14e4257b5b7d892724e71fd9fb4744dd..5f195ec1fa9eac228d620debfd7405121553a283 100644 (file)
@@ -17,18 +17,21 @@ import contextlib
 import typing
 from typing import Any
 from typing import Callable
+from typing import Generic
 from typing import Iterable
 from typing import List
 from typing import Optional
 from typing import Protocol
 from typing import Sequence as typing_Sequence
 from typing import Tuple
+from typing import TypeVar
 
 from . import roles
 from .base import _generative
 from .base import Executable
 from .base import SchemaVisitor
 from .elements import ClauseElement
+from .schema import Column
 from .. import exc
 from .. import util
 from ..util import topological
@@ -38,7 +41,6 @@ if typing.TYPE_CHECKING:
     from .compiler import Compiled
     from .compiler import DDLCompiler
     from .elements import BindParameter
-    from .schema import Column
     from .schema import Constraint
     from .schema import ForeignKeyConstraint
     from .schema import Index
@@ -52,6 +54,8 @@ if typing.TYPE_CHECKING:
     from ..engine.interfaces import Dialect
     from ..engine.interfaces import SchemaTranslateMapType
 
+T = TypeVar("T", bound="SchemaItem")
+
 
 class BaseDDLElement(ClauseElement):
     """The root of DDL constructs, including those that are sub-elements
@@ -417,7 +421,7 @@ class DDL(ExecutableDDLElement):
         )
 
 
-class _CreateDropBase(ExecutableDDLElement):
+class _CreateDropBase(ExecutableDDLElement, Generic[T]):
     """Base class for DDL constructs that represent CREATE and DROP or
     equivalents.
 
@@ -429,7 +433,7 @@ class _CreateDropBase(ExecutableDDLElement):
 
     def __init__(
         self,
-        element,
+        element: T,
     ):
         self.element = self.target = element
         self._ddl_if = getattr(element, "_ddl_if", None)
@@ -449,13 +453,13 @@ class _CreateDropBase(ExecutableDDLElement):
         return False
 
 
-class _CreateBase(_CreateDropBase):
+class _CreateBase(_CreateDropBase[Any]):
     def __init__(self, element, if_not_exists=False):
         super().__init__(element)
         self.if_not_exists = if_not_exists
 
 
-class _DropBase(_CreateDropBase):
+class _DropBase(_CreateDropBase[Any]):
     def __init__(self, element, if_exists=False):
         super().__init__(element)
         self.if_exists = if_exists
@@ -714,9 +718,9 @@ class CreateIndex(_CreateBase):
     """Represent a CREATE INDEX statement."""
 
     __visit_name__ = "create_index"
-    element: "Index"
+    element: Index
 
-    def __init__(self, element: "Index", if_not_exists: bool = False):
+    def __init__(self, element: Index, if_not_exists: bool = False):
         """Create a :class:`.Createindex` construct.
 
         :param element: a :class:`_schema.Index` that's the subject
@@ -735,9 +739,9 @@ class DropIndex(_DropBase):
 
     __visit_name__ = "drop_index"
 
-    element: "Index"
+    element: Index
 
-    def __init__(self, element: "Index", if_exists: bool = False):
+    def __init__(self, element: Index, if_exists: bool = False):
         """Create a :class:`.DropIndex` construct.
 
         :param element: a :class:`_schema.Index` that's the subject
@@ -776,13 +780,13 @@ class DropConstraint(_DropBase):
         )
 
 
-class SetTableComment(_CreateDropBase):
+class SetTableComment(_CreateDropBase["Table"]):
     """Represent a COMMENT ON TABLE IS statement."""
 
     __visit_name__ = "set_table_comment"
 
 
-class DropTableComment(_CreateDropBase):
+class DropTableComment(_CreateDropBase["Table"]):
     """Represent a COMMENT ON TABLE '' statement.
 
     Note this varies a lot across database backends.
@@ -792,26 +796,25 @@ class DropTableComment(_CreateDropBase):
     __visit_name__ = "drop_table_comment"
 
 
-class SetColumnComment(_CreateDropBase):
+class SetColumnComment(_CreateDropBase[Column[Any]]):
     """Represent a COMMENT ON COLUMN IS statement."""
 
     __visit_name__ = "set_column_comment"
-    element: "Column[Any]"
 
 
-class DropColumnComment(_CreateDropBase):
+class DropColumnComment(_CreateDropBase[Column[Any]]):
     """Represent a COMMENT ON COLUMN IS NULL statement."""
 
     __visit_name__ = "drop_column_comment"
 
 
-class SetConstraintComment(_CreateDropBase):
+class SetConstraintComment(_CreateDropBase["Constraint"]):
     """Represent a COMMENT ON CONSTRAINT IS statement."""
 
     __visit_name__ = "set_constraint_comment"
 
 
-class DropConstraintComment(_CreateDropBase):
+class DropConstraintComment(_CreateDropBase["Constraint"]):
     """Represent a COMMENT ON CONSTRAINT IS NULL statement."""
 
     __visit_name__ = "drop_constraint_comment"
index 6e660c0be7f85d2255b3bd35f94486f23d0da6a1..9a1db7d69eeb36553bbf3bdfe5084d4c6c82e193 100644 (file)
@@ -76,7 +76,6 @@ from .. import inspection
 from .. import util
 from ..util import HasMemoized_ro_memoized_attribute
 from ..util import TypingOnly
-from ..util._immutabledict_cy import immutabledict
 from ..util.typing import Literal
 from ..util.typing import ParamSpec
 from ..util.typing import Self
@@ -3879,6 +3878,7 @@ class BinaryExpression(OperatorExpression[_T]):
 
     left: ColumnElement[Any]
     right: ColumnElement[Any]
+    modifiers: Mapping[str, Any]
 
     def __init__(
         self,
@@ -3907,9 +3907,9 @@ class BinaryExpression(OperatorExpression[_T]):
         self._is_implicitly_boolean = operators.is_boolean(operator)
 
         if modifiers is None:
-            self.modifiers: immutabledict[str, str] = immutabledict({})
+            self.modifiers = {}
         else:
-            self.modifiers = immutabledict(modifiers)
+            self.modifiers = modifiers
 
     @property
     def _flattened_operator_clauses(
index 9f7fb1330f6b6095ade39f7a4b5dd468c1e640d9..ef0396bbe5cea66658c4ace458866bc19296934c 100644 (file)
@@ -247,7 +247,7 @@ class String(Concatenable, TypeEngine[str]):
         return process
 
     def bind_processor(
-        self, dialect: "Dialect"
+        self, dialect: Dialect
     ) -> _BindProcessorType[str] | None:
         return None
 
@@ -1333,9 +1333,9 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
 
     __visit_name__ = "enum"
 
-    enum_class: None | str | type[enum.StrEnum]
+    enum_class: None | str | type[enum.Enum]
 
-    def __init__(self, *enums: object, **kw: Any):
+    def __init__(self, *enums: Union[str, type[enum.Enum]], **kw: Any):
         r"""Construct an enum.
 
         Keyword arguments which don't apply to a specific backend are ignored
@@ -1467,7 +1467,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
            .. versionchanged:: 2.0 This parameter now defaults to True.
 
         """
-        self._enum_init(enums, kw)  # type: ignore[arg-type]
+        self._enum_init(enums, kw)
 
     @property
     def _enums_argument(self):
@@ -1477,7 +1477,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
             return self.enums
 
     def _enum_init(
-        self, enums: Sequence[str | type[enum.StrEnum]], kw: dict[str, Any]
+        self, enums: Sequence[Union[str, type[enum.Enum]]], kw: dict[str, Any]
     ) -> None:
         """internal init for :class:`.Enum` and subclasses.
 
@@ -1489,7 +1489,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
         self.native_enum = kw.pop("native_enum", True)
         self.create_constraint = kw.pop("create_constraint", False)
         self.values_callable: (
-            Callable[[type[enum.StrEnum]], Sequence[str]] | None
+            Callable[[type[enum.Enum]], Sequence[str]] | None
         ) = kw.pop("values_callable", None)
         self._sort_key_function = kw.pop("sort_key_function", NO_ARG)
         length_arg = kw.pop("length", NO_ARG)
@@ -1518,7 +1518,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
                 )
             length = length_arg
 
-        self._valid_lookup[None] = self._object_lookup[None] = None  # type: ignore # noqa: E501
+        self._valid_lookup[None] = self._object_lookup[None] = None
 
         super().__init__(length=length)
 
@@ -1540,8 +1540,8 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
         )
 
     def _parse_into_values(
-        self, enums: Sequence[str | type[enum.StrEnum]], kw: Any
-    ) -> tuple[Sequence[str], Sequence[enum.StrEnum] | Sequence[str]]:
+        self, enums: Sequence[str | type[enum.Enum]], kw: Any
+    ) -> tuple[Sequence[str], Sequence[enum.Enum] | Sequence[str]]:
         if not enums and "_enums" in kw:
             enums = kw.pop("_enums")
 
@@ -1658,16 +1658,18 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
     def _setup_for_values(
         self,
         values: Sequence[str],
-        objects: Sequence[enum.StrEnum] | Sequence[str],
+        objects: Sequence[enum.Enum] | Sequence[str],
         kw: Any,
     ) -> None:
         self.enums = list(values)
 
-        self._valid_lookup: dict[str, str] = dict(
+        self._valid_lookup: dict[enum.Enum | str | None, str | None] = dict(
             zip(reversed(objects), reversed(values))
         )
 
-        self._object_lookup: dict[str, str] = dict(zip(values, objects))
+        self._object_lookup: dict[str | None, enum.Enum | str | None] = dict(
+            zip(values, objects)
+        )
 
         self._valid_lookup.update(
             [
@@ -1729,9 +1731,10 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
 
     comparator_factory = Comparator
 
-    def _object_value_for_elem(self, elem: str) -> str:
+    def _object_value_for_elem(self, elem: str) -> Union[str, enum.Enum]:
         try:
-            return self._object_lookup[elem]
+            # Value will not be None beacuse key is not None
+            return self._object_lookup[elem]  # type: ignore[return-value]
         except KeyError as err:
             raise LookupError(
                 "'%s' is not among the defined enum values. "
@@ -3513,7 +3516,7 @@ class BINARY(_Binary):
 class VARBINARY(_Binary):
     """The SQL VARBINARY type."""
 
-    length: int
+    length: Optional[int]
     __visit_name__ = "VARBINARY"
 
 
@@ -3720,18 +3723,18 @@ class Uuid(Emulated, TypeEngine[_UUID_RETURN]):
         if character_based_uuid:
             if self.as_uuid:
 
-                def process(value: Any) -> str:
+                def process(value):
                     if value is not None:
                         value = value.hex
-                    return value  # type: ignore[no-any-return]
+                    return value
 
                 return process
             else:
 
-                def process(value: Any) -> str:
+                def process(value):
                     if value is not None:
                         value = value.replace("-", "")
-                    return value  # type: ignore[no-any-return]
+                    return value
 
                 return process
         else:
index 33f7bc41a159c4d3facfbd4c707cb82f55c7637a..55e8807b9dfd1117173e964a98400e28b60e097b 100644 (file)
@@ -1376,8 +1376,10 @@ class UserDefinedType(
 
         return self
 
-    def get_col_spec(self, **kw: Any) -> str:
-        raise NotImplementedError()
+    if TYPE_CHECKING:
+
+        def get_col_spec(self, **kw: Any) -> str:
+            raise NotImplementedError()
 
 
 class Emulated(TypeEngineMixin):
index 34b6b8a29e35f75c02ce07ee931c7cbc29d37392..97a38393e98a39839661ac9ee013a7e18617d293 100644 (file)
@@ -431,7 +431,7 @@ def to_column_set(x: Any) -> Set[Any]:
 
 
 def update_copy(
-    d: dict[Any, Any], _new: dict[Any, Any] | None = None, **kw: dict[Any, Any]
+    d: dict[Any, Any], _new: dict[Any, Any] | None = None, **kw: Any
 ) -> dict[Any, Any]:
     """Copy the given dict and update with the given values."""