From: Pablo Estevez Date: Sat, 11 Jan 2025 15:57:37 +0000 (-0300) Subject: pr fixes X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b41b1ad94a7c6e84ca4b43949115a77ec1042d94;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git pr fixes --- diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py index b26d32fdf7..e9a93e13d9 100644 --- a/lib/sqlalchemy/connectors/pyodbc.py +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -14,7 +14,6 @@ from typing import Any from typing import Dict from typing import List from typing import Optional -from typing import Sequence from typing import Tuple from typing import Union @@ -229,7 +228,7 @@ class PyODBCConnector(Connector): def get_isolation_level_values( self, dbapi_connection: interfaces.DBAPIConnection - ) -> Sequence[IsolationLevel]: + ) -> list[IsolationLevel]: return super().get_isolation_level_values(dbapi_connection) + [ # type: ignore # NOQA: E501 "AUTOCOMMIT" ] diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index dc4f16f514..11e0bc147c 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -20,6 +20,7 @@ from typing import Any from typing import cast from typing import ClassVar from typing import Dict +from typing import Iterable from typing import Iterator from typing import List from typing import Mapping @@ -1380,14 +1381,17 @@ class FullyBufferedCursorFetchStrategy(CursorFetchStrategy): def __init__( self, - dbapi_cursor: DBAPICursor, + dbapi_cursor: Optional[DBAPICursor], alternate_description: _DBAPICursorDescription | None = None, - initial_buffer: Any = None, + initial_buffer: Optional[Iterable[Any]] = None, ): self.alternate_cursor_description = alternate_description if initial_buffer is not None: self._rowbuffer = collections.deque(initial_buffer) else: + dbapi_cursor = cast( + DBAPICursor, dbapi_cursor + ) # Can't be both None self._rowbuffer = collections.deque(dbapi_cursor.fetchall()) def yield_per(self, result, dbapi_cursor, num): @@ -1903,7 +1907,7 @@ class CursorResult(Result[Unpack[_Ts]]): clone._metadata = clone._metadata._splice_horizontally(other._metadata) clone.cursor_strategy = FullyBufferedCursorFetchStrategy( - None, # type: ignore[arg-type] + None, initial_buffer=total_rows, ) clone._reset_memoizations() @@ -1935,7 +1939,7 @@ class CursorResult(Result[Unpack[_Ts]]): ) clone.cursor_strategy = FullyBufferedCursorFetchStrategy( - None, # type: ignore[arg-type] + None, initial_buffer=total_rows, ) clone._reset_memoizations() @@ -1964,7 +1968,7 @@ class CursorResult(Result[Unpack[_Ts]]): )._remove_processors() self.cursor_strategy = FullyBufferedCursorFetchStrategy( - None, # type: ignore[arg-type] + None, # TODO: if these are Row objects, can we save on not having to # re-make new Row objects out of them a second time? is that # what's actually happening right now? maybe look into this diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 8dce692d73..3582bbc57f 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -249,7 +249,9 @@ class DefaultDialect(Dialect): supports_is_distinct_from = True - supports_server_side_cursors: "generic_fn_descriptor[bool] | bool" = False + supports_server_side_cursors: Union[generic_fn_descriptor[bool], bool] = ( + False + ) server_side_cursors = False @@ -630,11 +632,11 @@ class DefaultDialect(Dialect): % (ident, self.max_identifier_length) ) - def connect(self, *cargs: Any, **cparams: Any): # type: ignore[no-untyped-def] # NOQA: E501 + def connect(self, *cargs: Any, **cparams: Any) -> DBAPIConnection: # inherits the docstring from interfaces.Dialect.connect - return self.loaded_dbapi.connect(*cargs, **cparams) + return self.loaded_dbapi.connect(*cargs, **cparams) # type: ignore[no-any-return] # NOQA: E501 - def create_connect_args(self, url: URL) -> "ConnectArgsType": + def create_connect_args(self, url: URL) -> ConnectArgsType: # inherits the docstring from interfaces.Dialect.create_connect_args opts = url.translate_connect_args() opts.update(url.query) diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 74ec272514..310b8fda0a 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -242,7 +242,7 @@ _AnyMultiExecuteParams = _DBAPIMultiExecuteParams _AnyExecuteParams = _DBAPIAnyExecuteParams CompiledCacheType = MutableMapping[Any, "Compiled"] -SchemaTranslateMapType = MutableMapping[Optional[str], Optional[str]] +SchemaTranslateMapType = Mapping[Optional[str], Optional[str]] _ImmutableExecuteOptions = immutabledict[str, Any] @@ -548,8 +548,6 @@ class ReflectedIndex(TypedDict): dialect_options: NotRequired[Dict[str, Any]] """Additional dialect-specific options detected for this index""" - type: NotRequired[str] - class ReflectedTableComment(TypedDict): """Dictionary representing the reflected comment corresponding to @@ -784,7 +782,7 @@ class Dialect(EventTarget): max_identifier_length: int """The maximum length of identifier names.""" - supports_server_side_cursors: "generic_fn_descriptor[bool] | bool" + supports_server_side_cursors: generic_fn_descriptor[bool] | bool """indicates if the dialect supports server side cursors""" server_side_cursors: bool diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index aa86b5e880..ef42d1c932 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -658,7 +658,7 @@ class CompileState: @classmethod def create_for_statement( cls, statement: Executable, compiler: Compiled, **kw: Any - ) -> "CompileState": + ) -> CompileState: # factory construction. if statement._propagate_attrs: diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index a343efbceb..b6069bc240 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -133,7 +133,6 @@ if typing.TYPE_CHECKING: from .selectable import SelectState from .type_api import _BindProcessorType from ..engine.cursor import CursorResultMetaData - from ..engine.default import DefaultDialect from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.interfaces import _DBAPIAnyExecuteParams from ..engine.interfaces import _DBAPIMultiExecuteParams @@ -889,8 +888,9 @@ class Compiled: self.string = self.process(self.statement, **compile_kwargs) if render_schema_translate: + assert schema_translate_map is not None self.string = self.preparer._render_schema_translates( - self.string, schema_translate_map # type: ignore[arg-type] + self.string, schema_translate_map ) self.state = CompilerState.STRING_APPLIED @@ -4618,7 +4618,9 @@ class SQLCompiler(Compiled): def get_select_hint_text(self, byfroms): return None - def get_from_hint_text(self, table: Any, text: str | None) -> str | None: + def get_from_hint_text( + self, table: FromClause, text: str | None + ) -> str | None: return None def get_crud_hint_text(self, table, text): @@ -6164,8 +6166,8 @@ class SQLCompiler(Compiled): def visit_update( self, update_stmt: "Update", visiting_cte: CTE | None = None, **kw: Any ) -> str: - compile_state = update_stmt._compile_state_factory( # type: ignore - update_stmt, self, **kw # type: ignore + compile_state = update_stmt._compile_state_factory( # type: ignore[call-arg] # NOQA: E501 + update_stmt, self, **kw # type: ignore[arg-type] ) compile_state = cast("UpdateDMLState", compile_state) update_stmt = compile_state.statement # type: ignore[assignment] @@ -6597,7 +6599,7 @@ class DDLCompiler(Compiled): ): ... @util.memoized_property - def sql_compiler(self): + def sql_compiler(self) -> SQLCompiler: return self.dialect.statement_compiler( self.dialect, None, schema_translate_map=self.schema_translate_map ) @@ -6969,11 +6971,11 @@ class DDLCompiler(Compiled): def render_default_string(self, default: Visitable | str) -> str: if isinstance(default, str): - return self.sql_compiler.render_literal_value( # type: ignore[no-any-return] # NOQA: E501 + return self.sql_compiler.render_literal_value( default, sqltypes.STRINGTYPE ) else: - return self.sql_compiler.process(default, literal_binds=True) # type: ignore[no-any-return] # NOQA: E501 + return self.sql_compiler.process(default, literal_binds=True) def visit_table_or_column_check_constraint(self, constraint, **kw): if constraint.is_column_level: @@ -7132,31 +7134,23 @@ class DDLCompiler(Compiled): class GenericTypeCompiler(TypeCompiler): - def visit_FLOAT( - self, type_: "sqltypes.Float[decimal.Decimal| float]", **kw: Any - ) -> str: + def visit_FLOAT(self, type_: TypeEngine[Any], **kw: Any) -> str: return "FLOAT" - def visit_DOUBLE( - self, type_: "sqltypes.Double[decimal.Decimal | float]", **kw: Any - ) -> str: + def visit_DOUBLE(self, type_: TypeEngine[Any], **kw: Any) -> str: return "DOUBLE" def visit_DOUBLE_PRECISION( self, - type_: "sqltypes.DOUBLE_PRECISION[decimal.Decimal| float]", + type_: TypeEngine[Any], **kw: Any, ) -> str: return "DOUBLE PRECISION" - def visit_REAL( - self, type_: "sqltypes.REAL[decimal.Decimal| float]", **kw: Any - ) -> str: + def visit_REAL(self, type_: TypeEngine[Any], **kw: Any) -> str: return "REAL" - def visit_NUMERIC( - self, type_: "sqltypes.Numeric[decimal.Decimal| float]", **kw: Any - ) -> str: + def visit_NUMERIC(self, type_: "sqltypes.Numeric[Any]", **kw: Any) -> str: if type_.precision is None: return "NUMERIC" elif type_.scale is None: @@ -7180,31 +7174,31 @@ class GenericTypeCompiler(TypeCompiler): "scale": type_.scale, } - def visit_INTEGER(self, type_: "sqltypes.Integer", **kw: Any) -> str: + def visit_INTEGER(self, type_: TypeEngine[Any], **kw: Any) -> str: return "INTEGER" - def visit_SMALLINT(self, type_: "sqltypes.SmallInteger", **kw: Any) -> str: + def visit_SMALLINT(self, type_: TypeEngine[Any], **kw: Any) -> str: return "SMALLINT" - def visit_BIGINT(self, type_: "sqltypes.BigInteger", **kw: Any) -> str: + def visit_BIGINT(self, type_: TypeEngine[Any], **kw: Any) -> str: return "BIGINT" - def visit_TIMESTAMP(self, type_: "sqltypes.TIMESTAMP", **kw: Any) -> str: + def visit_TIMESTAMP(self, type_: TypeEngine[Any], **kw: Any) -> str: return "TIMESTAMP" - def visit_DATETIME(self, type_: "sqltypes.DateTime", **kw: Any) -> str: + def visit_DATETIME(self, type_: TypeEngine[Any], **kw: Any) -> str: return "DATETIME" - def visit_DATE(self, type_: "sqltypes.Date", **kw: Any) -> str: + def visit_DATE(self, type_: TypeEngine[Any], **kw: Any) -> str: return "DATE" - def visit_TIME(self, type_: "sqltypes.Time", **kw: Any) -> str: + def visit_TIME(self, type_: TypeEngine[Any], **kw: Any) -> str: return "TIME" - def visit_CLOB(self, type_: "sqltypes.Text", **kw: Any) -> str: + def visit_CLOB(self, type_: TypeEngine[Any], **kw: Any) -> str: return "CLOB" - def visit_NCLOB(self, type_: "sqltypes.Text", **kw: Any) -> str: + def visit_NCLOB(self, type_: TypeEngine[Any], **kw: Any) -> str: return "NCLOB" def _render_string_type( @@ -7222,36 +7216,34 @@ class GenericTypeCompiler(TypeCompiler): text += ' COLLATE "%s"' % type_.collation return text - def visit_CHAR(self, type_: "sqltypes.CHAR", **kw: Any) -> str: + def visit_CHAR(self, type_: sqltypes.String, **kw: Any) -> str: return self._render_string_type(type_, "CHAR") - def visit_NCHAR(self, type_: "sqltypes.NCHAR", **kw: Any) -> str: + def visit_NCHAR(self, type_: sqltypes.String, **kw: Any) -> str: return self._render_string_type(type_, "NCHAR") - def visit_VARCHAR(self, type_: "sqltypes.String", **kw: Any) -> str: + def visit_VARCHAR(self, type_: sqltypes.String, **kw: Any) -> str: return self._render_string_type(type_, "VARCHAR") - def visit_NVARCHAR(self, type_: "sqltypes.NVARCHAR", **kw: Any) -> str: + def visit_NVARCHAR(self, type_: sqltypes.String, **kw: Any) -> str: return self._render_string_type(type_, "NVARCHAR") - def visit_TEXT(self, type_: "sqltypes.Text", **kw: Any) -> str: + def visit_TEXT(self, type_: sqltypes.String, **kw: Any) -> str: return self._render_string_type(type_, "TEXT") - def visit_UUID( - self, type_: "sqltypes.Uuid[_UUID_RETURN]", **kw: Any - ) -> str: + def visit_UUID(self, type_: TypeEngine[Any], **kw: Any) -> str: return "UUID" - def visit_BLOB(self, type_: "sqltypes.LargeBinary", **kw: Any) -> str: + def visit_BLOB(self, type_: TypeEngine[Any], **kw: Any) -> str: return "BLOB" - def visit_BINARY(self, type_: "sqltypes.BINARY", **kw: Any) -> str: + def visit_BINARY(self, type_: "sqltypes._Binary", **kw: Any) -> str: return "BINARY" + (type_.length and "(%d)" % type_.length or "") - def visit_VARBINARY(self, type_: "sqltypes.VARBINARY", **kw: Any) -> str: + def visit_VARBINARY(self, type_: "sqltypes._Binary", **kw: Any) -> str: return "VARBINARY" + (type_.length and "(%d)" % type_.length or "") - def visit_BOOLEAN(self, type_: "sqltypes.Boolean", **kw: Any) -> str: + def visit_BOOLEAN(self, type_: TypeEngine[Any], **kw: Any) -> str: return "BOOLEAN" def visit_uuid( @@ -7262,71 +7254,55 @@ class GenericTypeCompiler(TypeCompiler): else: return self.visit_UUID(type_, **kw) - def visit_large_binary( - self, type_: "sqltypes.LargeBinary", **kw: Any - ) -> str: + def visit_large_binary(self, type_: TypeEngine[Any], **kw: Any) -> str: return self.visit_BLOB(type_, **kw) - def visit_boolean(self, type_: "sqltypes.Boolean", **kw: Any) -> str: + def visit_boolean(self, type_: TypeEngine[Any], **kw: Any) -> str: return self.visit_BOOLEAN(type_, **kw) - def visit_time(self, type_: "sqltypes.Time", **kw: Any) -> str: + def visit_time(self, type_: TypeEngine[Any], **kw: Any) -> str: return self.visit_TIME(type_, **kw) - def visit_datetime(self, type_: "sqltypes.DateTime", **kw: Any) -> str: + def visit_datetime(self, type_: TypeEngine[Any], **kw: Any) -> str: return self.visit_DATETIME(type_, **kw) - def visit_date(self, type_: "sqltypes.Date", **kw: Any) -> str: + def visit_date(self, type_: TypeEngine[Any], **kw: Any) -> str: return self.visit_DATE(type_, **kw) - def visit_big_integer( - self, type_: "sqltypes.BigInteger", **kw: Any - ) -> str: + def visit_big_integer(self, type_: TypeEngine[Any], **kw: Any) -> str: return self.visit_BIGINT(type_, **kw) - def visit_small_integer( - self, type_: "sqltypes.SmallInteger", **kw: Any - ) -> str: + def visit_small_integer(self, type_: TypeEngine[Any], **kw: Any) -> str: return self.visit_SMALLINT(type_, **kw) - def visit_integer(self, type_: "sqltypes.Integer", **kw: Any) -> str: + def visit_integer(self, type_: TypeEngine[Any], **kw: Any) -> str: return self.visit_INTEGER(type_, **kw) - def visit_real( - self, type_: "sqltypes.REAL[decimal.Decimal| float]", **kw: Any - ) -> str: + def visit_real(self, type_: TypeEngine[Any], **kw: Any) -> str: return self.visit_REAL(type_, **kw) - def visit_float( - self, type_: "sqltypes.Float[decimal.Decimal| float]", **kw: Any - ) -> str: + def visit_float(self, type_: TypeEngine[Any], **kw: Any) -> str: return self.visit_FLOAT(type_, **kw) - def visit_double( - self, type_: "sqltypes.Double[decimal.Decimal | float]", **kw: Any - ) -> str: + def visit_double(self, type_: TypeEngine[Any], **kw: Any) -> str: return self.visit_DOUBLE(type_, **kw) - def visit_numeric( - self, type_: "sqltypes.Numeric[decimal.Decimal | float]", **kw: Any - ) -> str: + def visit_numeric(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str: return self.visit_NUMERIC(type_, **kw) - def visit_string(self, type_: "sqltypes.String", **kw: Any) -> str: + def visit_string(self, type_: sqltypes.String, **kw: Any) -> str: return self.visit_VARCHAR(type_, **kw) - def visit_unicode(self, type_: "sqltypes.Unicode", **kw: Any) -> str: + def visit_unicode(self, type_: sqltypes.String, **kw: Any) -> str: return self.visit_VARCHAR(type_, **kw) - def visit_text(self, type_: "sqltypes.Text", **kw: Any) -> str: + def visit_text(self, type_: sqltypes.String, **kw: Any) -> str: return self.visit_TEXT(type_, **kw) - def visit_unicode_text( - self, type_: "sqltypes.UnicodeText", **kw: Any - ) -> str: + def visit_unicode_text(self, type_: sqltypes.String, **kw: Any) -> str: return self.visit_TEXT(type_, **kw) - def visit_enum(self, type_: "sqltypes.Enum", **kw: Any) -> str: + def visit_enum(self, type_: sqltypes.String, **kw: Any) -> str: return self.visit_VARCHAR(type_, **kw) def visit_null(self, type_, **kw): @@ -7418,7 +7394,7 @@ class IdentifierPreparer: def __init__( self, - dialect: DefaultDialect, + dialect: Dialect, initial_quote: str = '"', final_quote: str | None = None, escape_quote: str = '"', @@ -7490,7 +7466,7 @@ class IdentifierPreparer: "schema_translate_map dictionaries." ) - d["_none"] = d[None] + d["_none"] = d[None] # type: ignore[index] def replace(m): name = m.group(2) @@ -7753,7 +7729,7 @@ class IdentifierPreparer: # to dialect.max_identifier_length etc. can be reflected # as IdentifierPreparer is long lived max_ = ( - self.dialect.max_index_name_length + self.dialect.max_index_name_length # type: ignore[attr-defined] or self.dialect.max_identifier_length ) return self._truncate_and_render_maxlen_name( @@ -7767,7 +7743,7 @@ class IdentifierPreparer: # to dialect.max_identifier_length etc. can be reflected # as IdentifierPreparer is long lived max_ = ( - self.dialect.max_constraint_name_length + self.dialect.max_constraint_name_length # type: ignore[attr-defined] # NOQA: E501 or self.dialect.max_identifier_length ) return self._truncate_and_render_maxlen_name( @@ -7781,7 +7757,7 @@ class IdentifierPreparer: if len(name) > max_: name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:] else: - self.dialect.validate_identifier(name) + self.dialect.validate_identifier(name) # type: ignore[attr-defined] # NOQA: E501 if not _alembic_quote: return name @@ -7794,7 +7770,7 @@ class IdentifierPreparer: @overload def format_table( self, - table: "Table | None", + table: "FromClause | None", use_schema: bool, name: str, ) -> str: ... @@ -7802,20 +7778,21 @@ class IdentifierPreparer: @overload def format_table( self, - table: "Table", + table: "NamedFromClause", use_schema: bool = True, name: None = None, ) -> str: ... def format_table( self, - table: "Table | None", + table: "FromClause | None", use_schema: bool = True, name: str | None = None, ) -> str: """Prepare a quoted table and schema name.""" if name is None: assert table is not None + table = cast("NamedFromClause", table) name = table.name result = self.quote(name) @@ -7847,7 +7824,7 @@ class IdentifierPreparer: def format_column( self, - column: "Column[Any]", + column: ColumnElement[Any], use_table: bool = False, name: str | None = None, table_name: str | None = None, @@ -7858,6 +7835,7 @@ class IdentifierPreparer: if name is None: name = column.name + name = cast(str, name) if anon_map is not None and isinstance( name, elements._truncated_label diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index c3a5b015bd..19af40ff08 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -207,7 +207,7 @@ def _get_crud_params( [ ( c, - compiler.preparer.format_column(c), # type: ignore[arg-type] # noqa: E501 + compiler.preparer.format_column(c), _create_bind_param(compiler, c, None, required=True), (c.key,), ) @@ -369,7 +369,7 @@ def _get_crud_params( values = [ ( _as_dml_column(stmt.table.columns[0]), - compiler.preparer.format_column(stmt.table.columns[0]), # type: ignore[arg-type] # noqa: E501 + compiler.preparer.format_column(stmt.table.columns[0]), compiler.dialect.default_metavalue_token, (), ) @@ -1136,7 +1136,7 @@ def _append_param_insert_select_hasdefault( values.append( ( c, - compiler.preparer.format_column(c), # type: ignore[arg-type] # noqa: E501 + compiler.preparer.format_column(c), c.default.next_value(), (), ) @@ -1145,7 +1145,7 @@ def _append_param_insert_select_hasdefault( values.append( ( c, - compiler.preparer.format_column(c), # type: ignore[arg-type] + compiler.preparer.format_column(c), c.default.arg.self_group(), (), ) @@ -1154,7 +1154,7 @@ def _append_param_insert_select_hasdefault( values.append( ( c, - compiler.preparer.format_column(c), # type: ignore[arg-type] + compiler.preparer.format_column(c), _create_insert_prefetch_bind_param( compiler, c, process=False, **kw ), diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index d397331c14..5f195ec1fa 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -17,18 +17,21 @@ import contextlib import typing from typing import Any from typing import Callable +from typing import Generic from typing import Iterable from typing import List from typing import Optional from typing import Protocol from typing import Sequence as typing_Sequence from typing import Tuple +from typing import TypeVar from . import roles from .base import _generative from .base import Executable from .base import SchemaVisitor from .elements import ClauseElement +from .schema import Column from .. import exc from .. import util from ..util import topological @@ -38,7 +41,6 @@ if typing.TYPE_CHECKING: from .compiler import Compiled from .compiler import DDLCompiler from .elements import BindParameter - from .schema import Column from .schema import Constraint from .schema import ForeignKeyConstraint from .schema import Index @@ -52,6 +54,8 @@ if typing.TYPE_CHECKING: from ..engine.interfaces import Dialect from ..engine.interfaces import SchemaTranslateMapType +T = TypeVar("T", bound="SchemaItem") + class BaseDDLElement(ClauseElement): """The root of DDL constructs, including those that are sub-elements @@ -417,7 +421,7 @@ class DDL(ExecutableDDLElement): ) -class _CreateDropBase(ExecutableDDLElement): +class _CreateDropBase(ExecutableDDLElement, Generic[T]): """Base class for DDL constructs that represent CREATE and DROP or equivalents. @@ -429,7 +433,7 @@ class _CreateDropBase(ExecutableDDLElement): def __init__( self, - element, + element: T, ): self.element = self.target = element self._ddl_if = getattr(element, "_ddl_if", None) @@ -449,13 +453,13 @@ class _CreateDropBase(ExecutableDDLElement): return False -class _CreateBase(_CreateDropBase): +class _CreateBase(_CreateDropBase[Any]): def __init__(self, element, if_not_exists=False): super().__init__(element) self.if_not_exists = if_not_exists -class _DropBase(_CreateDropBase): +class _DropBase(_CreateDropBase[Any]): def __init__(self, element, if_exists=False): super().__init__(element) self.if_exists = if_exists @@ -714,9 +718,9 @@ class CreateIndex(_CreateBase): """Represent a CREATE INDEX statement.""" __visit_name__ = "create_index" - element: "Index" + element: Index - def __init__(self, element: "Index", if_not_exists: bool = False): + def __init__(self, element: Index, if_not_exists: bool = False): """Create a :class:`.Createindex` construct. :param element: a :class:`_schema.Index` that's the subject @@ -735,9 +739,9 @@ class DropIndex(_DropBase): __visit_name__ = "drop_index" - element: "Index" + element: Index - def __init__(self, element: "Index", if_exists: bool = False): + def __init__(self, element: Index, if_exists: bool = False): """Create a :class:`.DropIndex` construct. :param element: a :class:`_schema.Index` that's the subject @@ -776,13 +780,13 @@ class DropConstraint(_DropBase): ) -class SetTableComment(_CreateDropBase): +class SetTableComment(_CreateDropBase["Table"]): """Represent a COMMENT ON TABLE IS statement.""" __visit_name__ = "set_table_comment" -class DropTableComment(_CreateDropBase): +class DropTableComment(_CreateDropBase["Table"]): """Represent a COMMENT ON TABLE '' statement. Note this varies a lot across database backends. @@ -792,26 +796,25 @@ class DropTableComment(_CreateDropBase): __visit_name__ = "drop_table_comment" -class SetColumnComment(_CreateDropBase): +class SetColumnComment(_CreateDropBase[Column[Any]]): """Represent a COMMENT ON COLUMN IS statement.""" __visit_name__ = "set_column_comment" - element: "Column[Any]" -class DropColumnComment(_CreateDropBase): +class DropColumnComment(_CreateDropBase[Column[Any]]): """Represent a COMMENT ON COLUMN IS NULL statement.""" __visit_name__ = "drop_column_comment" -class SetConstraintComment(_CreateDropBase): +class SetConstraintComment(_CreateDropBase["Constraint"]): """Represent a COMMENT ON CONSTRAINT IS statement.""" __visit_name__ = "set_constraint_comment" -class DropConstraintComment(_CreateDropBase): +class DropConstraintComment(_CreateDropBase["Constraint"]): """Represent a COMMENT ON CONSTRAINT IS NULL statement.""" __visit_name__ = "drop_constraint_comment" diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 6e660c0be7..9a1db7d69e 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -76,7 +76,6 @@ from .. import inspection from .. import util from ..util import HasMemoized_ro_memoized_attribute from ..util import TypingOnly -from ..util._immutabledict_cy import immutabledict from ..util.typing import Literal from ..util.typing import ParamSpec from ..util.typing import Self @@ -3879,6 +3878,7 @@ class BinaryExpression(OperatorExpression[_T]): left: ColumnElement[Any] right: ColumnElement[Any] + modifiers: Mapping[str, Any] def __init__( self, @@ -3907,9 +3907,9 @@ class BinaryExpression(OperatorExpression[_T]): self._is_implicitly_boolean = operators.is_boolean(operator) if modifiers is None: - self.modifiers: immutabledict[str, str] = immutabledict({}) + self.modifiers = {} else: - self.modifiers = immutabledict(modifiers) + self.modifiers = modifiers @property def _flattened_operator_clauses( diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 9f7fb1330f..ef0396bbe5 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -247,7 +247,7 @@ class String(Concatenable, TypeEngine[str]): return process def bind_processor( - self, dialect: "Dialect" + self, dialect: Dialect ) -> _BindProcessorType[str] | None: return None @@ -1333,9 +1333,9 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): __visit_name__ = "enum" - enum_class: None | str | type[enum.StrEnum] + enum_class: None | str | type[enum.Enum] - def __init__(self, *enums: object, **kw: Any): + def __init__(self, *enums: Union[str, type[enum.Enum]], **kw: Any): r"""Construct an enum. Keyword arguments which don't apply to a specific backend are ignored @@ -1467,7 +1467,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): .. versionchanged:: 2.0 This parameter now defaults to True. """ - self._enum_init(enums, kw) # type: ignore[arg-type] + self._enum_init(enums, kw) @property def _enums_argument(self): @@ -1477,7 +1477,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): return self.enums def _enum_init( - self, enums: Sequence[str | type[enum.StrEnum]], kw: dict[str, Any] + self, enums: Sequence[Union[str, type[enum.Enum]]], kw: dict[str, Any] ) -> None: """internal init for :class:`.Enum` and subclasses. @@ -1489,7 +1489,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): self.native_enum = kw.pop("native_enum", True) self.create_constraint = kw.pop("create_constraint", False) self.values_callable: ( - Callable[[type[enum.StrEnum]], Sequence[str]] | None + Callable[[type[enum.Enum]], Sequence[str]] | None ) = kw.pop("values_callable", None) self._sort_key_function = kw.pop("sort_key_function", NO_ARG) length_arg = kw.pop("length", NO_ARG) @@ -1518,7 +1518,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): ) length = length_arg - self._valid_lookup[None] = self._object_lookup[None] = None # type: ignore # noqa: E501 + self._valid_lookup[None] = self._object_lookup[None] = None super().__init__(length=length) @@ -1540,8 +1540,8 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): ) def _parse_into_values( - self, enums: Sequence[str | type[enum.StrEnum]], kw: Any - ) -> tuple[Sequence[str], Sequence[enum.StrEnum] | Sequence[str]]: + self, enums: Sequence[str | type[enum.Enum]], kw: Any + ) -> tuple[Sequence[str], Sequence[enum.Enum] | Sequence[str]]: if not enums and "_enums" in kw: enums = kw.pop("_enums") @@ -1658,16 +1658,18 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): def _setup_for_values( self, values: Sequence[str], - objects: Sequence[enum.StrEnum] | Sequence[str], + objects: Sequence[enum.Enum] | Sequence[str], kw: Any, ) -> None: self.enums = list(values) - self._valid_lookup: dict[str, str] = dict( + self._valid_lookup: dict[enum.Enum | str | None, str | None] = dict( zip(reversed(objects), reversed(values)) ) - self._object_lookup: dict[str, str] = dict(zip(values, objects)) + self._object_lookup: dict[str | None, enum.Enum | str | None] = dict( + zip(values, objects) + ) self._valid_lookup.update( [ @@ -1729,9 +1731,10 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): comparator_factory = Comparator - def _object_value_for_elem(self, elem: str) -> str: + def _object_value_for_elem(self, elem: str) -> Union[str, enum.Enum]: try: - return self._object_lookup[elem] + # Value will not be None beacuse key is not None + return self._object_lookup[elem] # type: ignore[return-value] except KeyError as err: raise LookupError( "'%s' is not among the defined enum values. " @@ -3513,7 +3516,7 @@ class BINARY(_Binary): class VARBINARY(_Binary): """The SQL VARBINARY type.""" - length: int + length: Optional[int] __visit_name__ = "VARBINARY" @@ -3720,18 +3723,18 @@ class Uuid(Emulated, TypeEngine[_UUID_RETURN]): if character_based_uuid: if self.as_uuid: - def process(value: Any) -> str: + def process(value): if value is not None: value = value.hex - return value # type: ignore[no-any-return] + return value return process else: - def process(value: Any) -> str: + def process(value): if value is not None: value = value.replace("-", "") - return value # type: ignore[no-any-return] + return value return process else: diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 33f7bc41a1..55e8807b9d 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -1376,8 +1376,10 @@ class UserDefinedType( return self - def get_col_spec(self, **kw: Any) -> str: - raise NotImplementedError() + if TYPE_CHECKING: + + def get_col_spec(self, **kw: Any) -> str: + raise NotImplementedError() class Emulated(TypeEngineMixin): diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 34b6b8a29e..97a38393e9 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -431,7 +431,7 @@ def to_column_set(x: Any) -> Set[Any]: def update_copy( - d: dict[Any, Any], _new: dict[Any, Any] | None = None, **kw: dict[Any, Any] + d: dict[Any, Any], _new: dict[Any, Any] | None = None, **kw: Any ) -> dict[Any, Any]: """Copy the given dict and update with the given values."""