From: Pablo Estevez Date: Wed, 8 Jan 2025 20:31:05 +0000 (-0300) Subject: miscelaneous to type mysql dialect X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c2f88226312658626ca88caa16a1d644c9033baa;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git miscelaneous to type mysql dialect --- diff --git a/lib/sqlalchemy/connectors/asyncio.py b/lib/sqlalchemy/connectors/asyncio.py index e57f7bfdf2..bce08d9cc3 100644 --- a/lib/sqlalchemy/connectors/asyncio.py +++ b/lib/sqlalchemy/connectors/asyncio.py @@ -40,7 +40,7 @@ class AsyncIODBAPIConnection(Protocol): async def commit(self) -> None: ... - def cursor(self) -> AsyncIODBAPICursor: ... + def cursor(self, *args: Any, **kwargs: Any) -> AsyncIODBAPICursor: ... async def rollback(self) -> None: ... diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py index 3a32d19c8b..b26d32fdf7 100644 --- a/lib/sqlalchemy/connectors/pyodbc.py +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -14,6 +14,7 @@ from typing import Any from typing import Dict from typing import List from typing import Optional +from typing import Sequence from typing import Tuple from typing import Union @@ -228,8 +229,8 @@ class PyODBCConnector(Connector): def get_isolation_level_values( self, dbapi_connection: interfaces.DBAPIConnection - ) -> List[IsolationLevel]: - return super().get_isolation_level_values(dbapi_connection) + [ + ) -> Sequence[IsolationLevel]: + return super().get_isolation_level_values(dbapi_connection) + [ # type: ignore # NOQA: E501 "AUTOCOMMIT" ] diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 56d7ee7588..dc4f16f514 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -1379,7 +1379,10 @@ class FullyBufferedCursorFetchStrategy(CursorFetchStrategy): __slots__ = ("_rowbuffer", "alternate_cursor_description") def __init__( - self, dbapi_cursor, alternate_description=None, initial_buffer=None + self, + dbapi_cursor: DBAPICursor, + alternate_description: _DBAPICursorDescription | None = None, + initial_buffer: Any = None, ): self.alternate_cursor_description = alternate_description if initial_buffer is not None: @@ -1900,7 +1903,7 @@ class CursorResult(Result[Unpack[_Ts]]): clone._metadata = clone._metadata._splice_horizontally(other._metadata) clone.cursor_strategy = FullyBufferedCursorFetchStrategy( - None, + None, # type: ignore[arg-type] initial_buffer=total_rows, ) clone._reset_memoizations() @@ -1932,7 +1935,7 @@ class CursorResult(Result[Unpack[_Ts]]): ) clone.cursor_strategy = FullyBufferedCursorFetchStrategy( - None, + None, # type: ignore[arg-type] initial_buffer=total_rows, ) clone._reset_memoizations() @@ -1961,7 +1964,7 @@ class CursorResult(Result[Unpack[_Ts]]): )._remove_processors() self.cursor_strategy = FullyBufferedCursorFetchStrategy( - None, + None, # type: ignore[arg-type] # TODO: if these are Row objects, can we save on not having to # re-make new Row objects out of them a second time? is that # what's actually happening right now? maybe look into this diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index ba59ac297b..8dce692d73 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -80,9 +80,11 @@ if typing.TYPE_CHECKING: from .interfaces import _CoreSingleExecuteParams from .interfaces import _DBAPICursorDescription from .interfaces import _DBAPIMultiExecuteParams + from .interfaces import _DBAPISingleExecuteParams from .interfaces import _ExecuteOptions from .interfaces import _MutableCoreSingleExecuteParams from .interfaces import _ParamStyle + from .interfaces import ConnectArgsType from .interfaces import DBAPIConnection from .interfaces import IsolationLevel from .row import Row @@ -101,6 +103,8 @@ if typing.TYPE_CHECKING: from ..sql.type_api import _BindProcessorType from ..sql.type_api import _ResultProcessorType from ..sql.type_api import TypeEngine + from ..util.langhelpers import generic_fn_descriptor + # When we're handed literal SQL, ensure it's a SELECT query SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE) @@ -245,7 +249,7 @@ class DefaultDialect(Dialect): supports_is_distinct_from = True - supports_server_side_cursors = False + supports_server_side_cursors: "generic_fn_descriptor[bool] | bool" = False server_side_cursors = False @@ -440,7 +444,7 @@ class DefaultDialect(Dialect): def _bind_typing_render_casts(self): return self.bind_typing is interfaces.BindTyping.RENDER_CASTS - def _ensure_has_table_connection(self, arg): + def _ensure_has_table_connection(self, arg: Connection) -> None: if not isinstance(arg, Connection): raise exc.ArgumentError( "The argument passed to Dialect.has_table() should be a " @@ -524,7 +528,7 @@ class DefaultDialect(Dialect): else: return None - def initialize(self, connection): + def initialize(self, connection: Connection) -> None: try: self.server_version_info = self._get_server_version_info( connection @@ -560,7 +564,7 @@ class DefaultDialect(Dialect): % (self.label_length, self.max_identifier_length) ) - def on_connect(self): + def on_connect(self) -> Callable[[Any], None] | None: # inherits the docstring from interfaces.Dialect.on_connect return None @@ -626,11 +630,11 @@ class DefaultDialect(Dialect): % (ident, self.max_identifier_length) ) - def connect(self, *cargs, **cparams): + def connect(self, *cargs: Any, **cparams: Any): # type: ignore[no-untyped-def] # NOQA: E501 # inherits the docstring from interfaces.Dialect.connect return self.loaded_dbapi.connect(*cargs, **cparams) - def create_connect_args(self, url): + def create_connect_args(self, url: URL) -> "ConnectArgsType": # inherits the docstring from interfaces.Dialect.create_connect_args opts = url.translate_connect_args() opts.update(url.query) @@ -953,7 +957,14 @@ class DefaultDialect(Dialect): def do_execute_no_params(self, cursor, statement, context=None): cursor.execute(statement) - def is_disconnect(self, e, connection, cursor): + def is_disconnect( + self, + e: Exception, + connection: ( + pool.PoolProxiedConnection | interfaces.DBAPIConnection | None + ), + cursor: interfaces.DBAPICursor | None, + ) -> bool: return False @util.memoized_instancemethod @@ -1669,7 +1680,12 @@ class DefaultExecutionContext(ExecutionContext): def no_parameters(self): return self.execution_options.get("no_parameters", False) - def _execute_scalar(self, stmt, type_, parameters=None): + def _execute_scalar( + self, + stmt: str, + type_: TypeEngine[Any] | None, + parameters: _DBAPISingleExecuteParams | None = None, + ) -> Any: """Execute a string statement on the current cursor, returning a scalar result. @@ -1743,7 +1759,7 @@ class DefaultExecutionContext(ExecutionContext): return use_server_side - def create_cursor(self): + def create_cursor(self) -> DBAPICursor: if ( # inlining initial preference checks for SS cursors self.dialect.supports_server_side_cursors @@ -1764,10 +1780,10 @@ class DefaultExecutionContext(ExecutionContext): def fetchall_for_returning(self, cursor): return cursor.fetchall() - def create_default_cursor(self): + def create_default_cursor(self) -> DBAPICursor: return self._dbapi_connection.cursor() - def create_server_side_cursor(self): + def create_server_side_cursor(self) -> DBAPICursor: raise NotImplementedError() def pre_exec(self): diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 35c52ae3b9..74ec272514 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -70,6 +70,7 @@ if TYPE_CHECKING: from ..sql.sqltypes import Integer from ..sql.type_api import _TypeMemoDict from ..sql.type_api import TypeEngine + from ..util.langhelpers import generic_fn_descriptor ConnectArgsType = Tuple[Sequence[str], MutableMapping[str, Any]] @@ -122,7 +123,7 @@ class DBAPIConnection(Protocol): def commit(self) -> None: ... - def cursor(self) -> DBAPICursor: ... + def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor: ... def rollback(self) -> None: ... @@ -241,7 +242,7 @@ _AnyMultiExecuteParams = _DBAPIMultiExecuteParams _AnyExecuteParams = _DBAPIAnyExecuteParams CompiledCacheType = MutableMapping[Any, "Compiled"] -SchemaTranslateMapType = Mapping[Optional[str], Optional[str]] +SchemaTranslateMapType = MutableMapping[Optional[str], Optional[str]] _ImmutableExecuteOptions = immutabledict[str, Any] @@ -547,6 +548,8 @@ class ReflectedIndex(TypedDict): dialect_options: NotRequired[Dict[str, Any]] """Additional dialect-specific options detected for this index""" + type: NotRequired[str] + class ReflectedTableComment(TypedDict): """Dictionary representing the reflected comment corresponding to @@ -781,7 +784,7 @@ class Dialect(EventTarget): max_identifier_length: int """The maximum length of identifier names.""" - supports_server_side_cursors: bool + supports_server_side_cursors: "generic_fn_descriptor[bool] | bool" """indicates if the dialect supports server side cursors""" server_side_cursors: bool @@ -2483,7 +2486,7 @@ class Dialect(EventTarget): def get_isolation_level_values( self, dbapi_conn: DBAPIConnection - ) -> List[IsolationLevel]: + ) -> Sequence[IsolationLevel]: """return a sequence of string isolation level names that are accepted by this dialect. diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index b91048e387..1eb8b4c217 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -1074,7 +1074,7 @@ class PoolProxiedConnection(ManagesConnection): def commit(self) -> None: ... - def cursor(self) -> DBAPICursor: ... + def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor: ... def rollback(self) -> None: ... diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index a93ea4e42e..aa86b5e880 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -656,7 +656,9 @@ class CompileState: _ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] @classmethod - def create_for_statement(cls, statement, compiler, **kw): + def create_for_statement( + cls, statement: Executable, compiler: Compiled, **kw: Any + ) -> "CompileState": # factory construction. if statement._propagate_attrs: diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 7119ae1c1f..620f53a8b6 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -75,7 +75,7 @@ _StringOnlyR = TypeVar("_StringOnlyR", bound=roles.StringRole) _T = TypeVar("_T", bound=Any) -def _is_literal(element): +def _is_literal(element: Any) -> bool: """Return whether or not the element is a "literal" in the context of a SQL expression construct. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 6010b95862..a343efbceb 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -28,6 +28,7 @@ from __future__ import annotations import collections import collections.abc as collections_abc import contextlib +import decimal from enum import IntEnum import functools import itertools @@ -49,6 +50,7 @@ from typing import MutableMapping from typing import NamedTuple from typing import NoReturn from typing import Optional +from typing import overload from typing import Pattern from typing import Protocol from typing import Sequence @@ -79,10 +81,18 @@ from .base import _SentinelDefaultCharacterization from .base import Executable from .base import NO_ARG from .elements import ClauseElement +from .elements import False_ +from .elements import Null from .elements import quoted_name +from .elements import True_ from .schema import Column +from .schema import ForeignKeyConstraint +from .schema import UniqueConstraint +from .sqltypes import _UUID_RETURN from .sqltypes import TupleType +from .type_api import TypeDecorator from .type_api import TypeEngine +from .type_api import UserDefinedType from .visitors import prefix_anon_map from .visitors import Visitable from .. import exc @@ -99,7 +109,9 @@ if typing.TYPE_CHECKING: from .cache_key import CacheKey from .ddl import ExecutableDDLElement from .dml import Insert + from .dml import Update from .dml import UpdateBase + from .dml import UpdateDMLState from .dml import ValuesBase from .elements import _truncated_label from .elements import BindParameter @@ -107,6 +119,9 @@ if typing.TYPE_CHECKING: from .elements import ColumnElement from .elements import Label from .functions import Function + from .schema import Constraint + from .schema import Index + from .schema import PrimaryKeyConstraint from .schema import Table from .selectable import AliasedReturnsRows from .selectable import CompoundSelectState @@ -118,6 +133,7 @@ if typing.TYPE_CHECKING: from .selectable import SelectState from .type_api import _BindProcessorType from ..engine.cursor import CursorResultMetaData + from ..engine.default import DefaultDialect from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.interfaces import _DBAPIAnyExecuteParams from ..engine.interfaces import _DBAPIMultiExecuteParams @@ -128,6 +144,7 @@ if typing.TYPE_CHECKING: from ..engine.interfaces import Dialect from ..engine.interfaces import SchemaTranslateMapType + _FromHintsType = Dict["FromClause", str] RESERVED_WORDS = { @@ -873,7 +890,7 @@ class Compiled: if render_schema_translate: self.string = self.preparer._render_schema_translates( - self.string, schema_translate_map + self.string, schema_translate_map # type: ignore[arg-type] ) self.state = CompilerState.STRING_APPLIED @@ -2344,7 +2361,7 @@ class SQLCompiler(Compiled): return get - def default_from(self): + def default_from(self) -> str: """Called when a SELECT statement has no froms, and no FROM clause is to be appended. @@ -2736,16 +2753,16 @@ class SQLCompiler(Compiled): return text - def visit_null(self, expr, **kw): + def visit_null(self, expr: Null, **kw: Any) -> str: return "NULL" - def visit_true(self, expr, **kw): + def visit_true(self, expr: True_, **kw: Any) -> str: if self.dialect.supports_native_boolean: return "true" else: return "1" - def visit_false(self, expr, **kw): + def visit_false(self, expr: False_, **kw: Any) -> str: if self.dialect.supports_native_boolean: return "false" else: @@ -2973,7 +2990,9 @@ class SQLCompiler(Compiled): % self.dialect.name ) - def function_argspec(self, func, **kwargs): + def function_argspec( + self, func: functions.Function[Any], **kwargs: Any + ) -> str: return func.clause_expr._compiler_dispatch(self, **kwargs) def visit_compound_select( @@ -3437,8 +3456,12 @@ class SQLCompiler(Compiled): ) def _generate_generic_binary( - self, binary, opstring, eager_grouping=False, **kw - ): + self, + binary: elements.BinaryExpression[Any], + opstring: str, + eager_grouping: bool = False, + **kw: Any, + ) -> str: _in_operator_expression = kw.get("_in_operator_expression", False) kw["_in_operator_expression"] = True @@ -3607,19 +3630,25 @@ class SQLCompiler(Compiled): **kw, ) - def visit_regexp_match_op_binary(self, binary, operator, **kw): + def visit_regexp_match_op_binary( + self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: raise exc.CompileError( "%s dialect does not support regular expressions" % self.dialect.name ) - def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + def visit_not_regexp_match_op_binary( + self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: raise exc.CompileError( "%s dialect does not support regular expressions" % self.dialect.name ) - def visit_regexp_replace_op_binary(self, binary, operator, **kw): + def visit_regexp_replace_op_binary( + self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: raise exc.CompileError( "%s dialect does not support regular expression replacements" % self.dialect.name @@ -3826,7 +3855,9 @@ class SQLCompiler(Compiled): else: return self.render_literal_value(value, bindparam.type) - def render_literal_value(self, value, type_): + def render_literal_value( + self, value: str | None, type_: sqltypes.String + ) -> str: """Render the value of a bind parameter as a quoted literal. This is used for statement sections that do not accept bind parameters @@ -4587,7 +4618,7 @@ class SQLCompiler(Compiled): def get_select_hint_text(self, byfroms): return None - def get_from_hint_text(self, table, text): + def get_from_hint_text(self, table: Any, text: str | None) -> str | None: return None def get_crud_hint_text(self, table, text): @@ -5072,7 +5103,7 @@ class SQLCompiler(Compiled): else: return "WITH" - def get_select_precolumns(self, select, **kw): + def get_select_precolumns(self, select: Select[Any], **kw: Any) -> str: """Called when building a ``SELECT`` statement, position is just before column list. @@ -6098,7 +6129,7 @@ class SQLCompiler(Compiled): return text - def update_limit_clause(self, update_stmt): + def update_limit_clause(self, update_stmt: "Update") -> str | None: """Provide a hook for MySQL to add LIMIT to the UPDATE""" return None @@ -6130,11 +6161,14 @@ class SQLCompiler(Compiled): "criteria within UPDATE" ) - def visit_update(self, update_stmt, visiting_cte=None, **kw): - compile_state = update_stmt._compile_state_factory( - update_stmt, self, **kw + def visit_update( + self, update_stmt: "Update", visiting_cte: CTE | None = None, **kw: Any + ) -> str: + compile_state = update_stmt._compile_state_factory( # type: ignore + update_stmt, self, **kw # type: ignore ) - update_stmt = compile_state.statement + compile_state = cast("UpdateDMLState", compile_state) + update_stmt = compile_state.statement # type: ignore[assignment] if visiting_cte is not None: kw["visiting_cte"] = visiting_cte @@ -6223,7 +6257,7 @@ class SQLCompiler(Compiled): if self.returning_precedes_values: text += " " + self.returning_clause( update_stmt, - self.implicit_returning or update_stmt._returning, + self.implicit_returning or update_stmt._returning, # type: ignore[arg-type] # NOQA: E501 populate_result_map=toplevel, ) @@ -6255,7 +6289,7 @@ class SQLCompiler(Compiled): ) and not self.returning_precedes_values: text += " " + self.returning_clause( update_stmt, - self.implicit_returning or update_stmt._returning, + self.implicit_returning or update_stmt._returning, # type: ignore[arg-type] # noqa: E501 populate_result_map=toplevel, ) @@ -6272,7 +6306,7 @@ class SQLCompiler(Compiled): return text def delete_extra_from_clause( - self, update_stmt, from_table, extra_froms, from_hints, **kw + self, delete_stmt, from_table, extra_froms, from_hints, **kw ): """Provide a hook to override the generation of an DELETE..FROM clause. @@ -6515,7 +6549,7 @@ class StrSQLCompiler(SQLCompiler): ) def delete_extra_from_clause( - self, update_stmt, from_table, extra_froms, from_hints, **kw + self, delete_stmt, from_table, extra_froms, from_hints, **kw ): kw["asfrom"] = True return ", " + ", ".join( @@ -6727,7 +6761,7 @@ class DDLCompiler(Compiled): def visit_drop_view(self, drop, **kw): return "\nDROP VIEW " + self.preparer.format_table(drop.element) - def _verify_index_table(self, index): + def _verify_index_table(self, index: "Index") -> None: if index.table is None: raise exc.CompileError( "Index '%s' is not associated with any table." % index.name @@ -6778,7 +6812,9 @@ class DDLCompiler(Compiled): return text + self._prepared_index_name(index, include_schema=True) - def _prepared_index_name(self, index, include_schema=False): + def _prepared_index_name( + self, index: "Index", include_schema: bool = False + ) -> str: if index.table is not None: effective_schema = self.preparer.schema_for_object(index.table) else: @@ -6788,7 +6824,7 @@ class DDLCompiler(Compiled): else: schema_name = None - index_name = self.preparer.format_index(index) + index_name: str = self.preparer.format_index(index) if schema_name: index_name = schema_name + "." + index_name @@ -6925,19 +6961,19 @@ class DDLCompiler(Compiled): def post_create_table(self, table): return "" - def get_column_default_string(self, column): + def get_column_default_string(self, column: Column[Any]) -> str | None: if isinstance(column.server_default, schema.DefaultClause): return self.render_default_string(column.server_default.arg) else: return None - def render_default_string(self, default): + def render_default_string(self, default: Visitable | str) -> str: if isinstance(default, str): - return self.sql_compiler.render_literal_value( + return self.sql_compiler.render_literal_value( # type: ignore[no-any-return] # NOQA: E501 default, sqltypes.STRINGTYPE ) else: - return self.sql_compiler.process(default, literal_binds=True) + return self.sql_compiler.process(default, literal_binds=True) # type: ignore[no-any-return] # NOQA: E501 def visit_table_or_column_check_constraint(self, constraint, **kw): if constraint.is_column_level: @@ -6969,7 +7005,9 @@ class DDLCompiler(Compiled): text += self.define_constraint_deferrability(constraint) return text - def visit_primary_key_constraint(self, constraint, **kw): + def visit_primary_key_constraint( + self, constraint: "PrimaryKeyConstraint", **kw: Any + ) -> str: if len(constraint) == 0: return "" text = "" @@ -7018,7 +7056,9 @@ class DDLCompiler(Compiled): return preparer.format_table(table) - def visit_unique_constraint(self, constraint, **kw): + def visit_unique_constraint( + self, constraint: UniqueConstraint, **kw: Any + ) -> str: if len(constraint) == 0: return "" text = "" @@ -7033,10 +7073,14 @@ class DDLCompiler(Compiled): text += self.define_constraint_deferrability(constraint) return text - def define_unique_constraint_distinct(self, constraint, **kw): + def define_unique_constraint_distinct( + self, constraint: UniqueConstraint, **kw: Any + ) -> str: return "" - def define_constraint_cascades(self, constraint): + def define_constraint_cascades( + self, constraint: ForeignKeyConstraint + ) -> str: text = "" if constraint.ondelete is not None: text += " ON DELETE %s" % self.preparer.validate_sql_phrase( @@ -7048,7 +7092,7 @@ class DDLCompiler(Compiled): ) return text - def define_constraint_deferrability(self, constraint): + def define_constraint_deferrability(self, constraint: Constraint) -> str: text = "" if constraint.deferrable is not None: if constraint.deferrable: @@ -7088,19 +7132,31 @@ class DDLCompiler(Compiled): class GenericTypeCompiler(TypeCompiler): - def visit_FLOAT(self, type_, **kw): + def visit_FLOAT( + self, type_: "sqltypes.Float[decimal.Decimal| float]", **kw: Any + ) -> str: return "FLOAT" - def visit_DOUBLE(self, type_, **kw): + def visit_DOUBLE( + self, type_: "sqltypes.Double[decimal.Decimal | float]", **kw: Any + ) -> str: return "DOUBLE" - def visit_DOUBLE_PRECISION(self, type_, **kw): + def visit_DOUBLE_PRECISION( + self, + type_: "sqltypes.DOUBLE_PRECISION[decimal.Decimal| float]", + **kw: Any, + ) -> str: return "DOUBLE PRECISION" - def visit_REAL(self, type_, **kw): + def visit_REAL( + self, type_: "sqltypes.REAL[decimal.Decimal| float]", **kw: Any + ) -> str: return "REAL" - def visit_NUMERIC(self, type_, **kw): + def visit_NUMERIC( + self, type_: "sqltypes.Numeric[decimal.Decimal| float]", **kw: Any + ) -> str: if type_.precision is None: return "NUMERIC" elif type_.scale is None: @@ -7111,7 +7167,9 @@ class GenericTypeCompiler(TypeCompiler): "scale": type_.scale, } - def visit_DECIMAL(self, type_, **kw): + def visit_DECIMAL( + self, type_: "sqltypes.DECIMAL[decimal.Decimal| float]", **kw: Any + ) -> str: if type_.precision is None: return "DECIMAL" elif type_.scale is None: @@ -7122,128 +7180,153 @@ class GenericTypeCompiler(TypeCompiler): "scale": type_.scale, } - def visit_INTEGER(self, type_, **kw): + def visit_INTEGER(self, type_: "sqltypes.Integer", **kw: Any) -> str: return "INTEGER" - def visit_SMALLINT(self, type_, **kw): + def visit_SMALLINT(self, type_: "sqltypes.SmallInteger", **kw: Any) -> str: return "SMALLINT" - def visit_BIGINT(self, type_, **kw): + def visit_BIGINT(self, type_: "sqltypes.BigInteger", **kw: Any) -> str: return "BIGINT" - def visit_TIMESTAMP(self, type_, **kw): + def visit_TIMESTAMP(self, type_: "sqltypes.TIMESTAMP", **kw: Any) -> str: return "TIMESTAMP" - def visit_DATETIME(self, type_, **kw): + def visit_DATETIME(self, type_: "sqltypes.DateTime", **kw: Any) -> str: return "DATETIME" - def visit_DATE(self, type_, **kw): + def visit_DATE(self, type_: "sqltypes.Date", **kw: Any) -> str: return "DATE" - def visit_TIME(self, type_, **kw): + def visit_TIME(self, type_: "sqltypes.Time", **kw: Any) -> str: return "TIME" - def visit_CLOB(self, type_, **kw): + def visit_CLOB(self, type_: "sqltypes.Text", **kw: Any) -> str: return "CLOB" - def visit_NCLOB(self, type_, **kw): + def visit_NCLOB(self, type_: "sqltypes.Text", **kw: Any) -> str: return "NCLOB" - def _render_string_type(self, type_, name, length_override=None): + def _render_string_type( + self, + type_: "sqltypes.String | sqltypes.Uuid[_UUID_RETURN]", + name: str, + length_override: int | None = None, + ) -> str: text = name if length_override: text += "(%d)" % length_override - elif type_.length: - text += "(%d)" % type_.length + elif type_.length: # type: ignore[union-attr] + text += "(%d)" % type_.length # type: ignore[union-attr] if type_.collation: text += ' COLLATE "%s"' % type_.collation return text - def visit_CHAR(self, type_, **kw): + def visit_CHAR(self, type_: "sqltypes.CHAR", **kw: Any) -> str: return self._render_string_type(type_, "CHAR") - def visit_NCHAR(self, type_, **kw): + def visit_NCHAR(self, type_: "sqltypes.NCHAR", **kw: Any) -> str: return self._render_string_type(type_, "NCHAR") - def visit_VARCHAR(self, type_, **kw): + def visit_VARCHAR(self, type_: "sqltypes.String", **kw: Any) -> str: return self._render_string_type(type_, "VARCHAR") - def visit_NVARCHAR(self, type_, **kw): + def visit_NVARCHAR(self, type_: "sqltypes.NVARCHAR", **kw: Any) -> str: return self._render_string_type(type_, "NVARCHAR") - def visit_TEXT(self, type_, **kw): + def visit_TEXT(self, type_: "sqltypes.Text", **kw: Any) -> str: return self._render_string_type(type_, "TEXT") - def visit_UUID(self, type_, **kw): + def visit_UUID( + self, type_: "sqltypes.Uuid[_UUID_RETURN]", **kw: Any + ) -> str: return "UUID" - def visit_BLOB(self, type_, **kw): + def visit_BLOB(self, type_: "sqltypes.LargeBinary", **kw: Any) -> str: return "BLOB" - def visit_BINARY(self, type_, **kw): + def visit_BINARY(self, type_: "sqltypes.BINARY", **kw: Any) -> str: return "BINARY" + (type_.length and "(%d)" % type_.length or "") - def visit_VARBINARY(self, type_, **kw): + def visit_VARBINARY(self, type_: "sqltypes.VARBINARY", **kw: Any) -> str: return "VARBINARY" + (type_.length and "(%d)" % type_.length or "") - def visit_BOOLEAN(self, type_, **kw): + def visit_BOOLEAN(self, type_: "sqltypes.Boolean", **kw: Any) -> str: return "BOOLEAN" - def visit_uuid(self, type_, **kw): + def visit_uuid( + self, type_: "sqltypes.Uuid[_UUID_RETURN]", **kw: Any + ) -> str: if not type_.native_uuid or not self.dialect.supports_native_uuid: return self._render_string_type(type_, "CHAR", length_override=32) else: return self.visit_UUID(type_, **kw) - def visit_large_binary(self, type_, **kw): + def visit_large_binary( + self, type_: "sqltypes.LargeBinary", **kw: Any + ) -> str: return self.visit_BLOB(type_, **kw) - def visit_boolean(self, type_, **kw): + def visit_boolean(self, type_: "sqltypes.Boolean", **kw: Any) -> str: return self.visit_BOOLEAN(type_, **kw) - def visit_time(self, type_, **kw): + def visit_time(self, type_: "sqltypes.Time", **kw: Any) -> str: return self.visit_TIME(type_, **kw) - def visit_datetime(self, type_, **kw): + def visit_datetime(self, type_: "sqltypes.DateTime", **kw: Any) -> str: return self.visit_DATETIME(type_, **kw) - def visit_date(self, type_, **kw): + def visit_date(self, type_: "sqltypes.Date", **kw: Any) -> str: return self.visit_DATE(type_, **kw) - def visit_big_integer(self, type_, **kw): + def visit_big_integer( + self, type_: "sqltypes.BigInteger", **kw: Any + ) -> str: return self.visit_BIGINT(type_, **kw) - def visit_small_integer(self, type_, **kw): + def visit_small_integer( + self, type_: "sqltypes.SmallInteger", **kw: Any + ) -> str: return self.visit_SMALLINT(type_, **kw) - def visit_integer(self, type_, **kw): + def visit_integer(self, type_: "sqltypes.Integer", **kw: Any) -> str: return self.visit_INTEGER(type_, **kw) - def visit_real(self, type_, **kw): + def visit_real( + self, type_: "sqltypes.REAL[decimal.Decimal| float]", **kw: Any + ) -> str: return self.visit_REAL(type_, **kw) - def visit_float(self, type_, **kw): + def visit_float( + self, type_: "sqltypes.Float[decimal.Decimal| float]", **kw: Any + ) -> str: return self.visit_FLOAT(type_, **kw) - def visit_double(self, type_, **kw): + def visit_double( + self, type_: "sqltypes.Double[decimal.Decimal | float]", **kw: Any + ) -> str: return self.visit_DOUBLE(type_, **kw) - def visit_numeric(self, type_, **kw): + def visit_numeric( + self, type_: "sqltypes.Numeric[decimal.Decimal | float]", **kw: Any + ) -> str: return self.visit_NUMERIC(type_, **kw) - def visit_string(self, type_, **kw): + def visit_string(self, type_: "sqltypes.String", **kw: Any) -> str: return self.visit_VARCHAR(type_, **kw) - def visit_unicode(self, type_, **kw): + def visit_unicode(self, type_: "sqltypes.Unicode", **kw: Any) -> str: return self.visit_VARCHAR(type_, **kw) - def visit_text(self, type_, **kw): + def visit_text(self, type_: "sqltypes.Text", **kw: Any) -> str: return self.visit_TEXT(type_, **kw) - def visit_unicode_text(self, type_, **kw): + def visit_unicode_text( + self, type_: "sqltypes.UnicodeText", **kw: Any + ) -> str: return self.visit_TEXT(type_, **kw) - def visit_enum(self, type_, **kw): + def visit_enum(self, type_: "sqltypes.Enum", **kw: Any) -> str: return self.visit_VARCHAR(type_, **kw) def visit_null(self, type_, **kw): @@ -7253,10 +7336,14 @@ class GenericTypeCompiler(TypeCompiler): "type on this Column?" % type_ ) - def visit_type_decorator(self, type_, **kw): + def visit_type_decorator( + self, type_: TypeDecorator[Any], **kw: Any + ) -> str: return self.process(type_.type_engine(self.dialect), **kw) - def visit_user_defined(self, type_, **kw): + def visit_user_defined( + self, type_: UserDefinedType[Any], **kw: Any + ) -> str: return type_.get_col_spec(**kw) @@ -7331,12 +7418,12 @@ class IdentifierPreparer: def __init__( self, - dialect, - initial_quote='"', - final_quote=None, - escape_quote='"', - quote_case_sensitive_collations=True, - omit_schema=False, + dialect: DefaultDialect, + initial_quote: str = '"', + final_quote: str | None = None, + escape_quote: str = '"', + quote_case_sensitive_collations: bool = True, + omit_schema: bool = False, ): """Construct a new ``IdentifierPreparer`` object. @@ -7389,7 +7476,9 @@ class IdentifierPreparer: prep._includes_none_schema_translate = includes_none return prep - def _render_schema_translates(self, statement, schema_translate_map): + def _render_schema_translates( + self, statement: str, schema_translate_map: SchemaTranslateMapType + ) -> str: d = schema_translate_map if None in d: if not self._includes_none_schema_translate: @@ -7594,7 +7683,9 @@ class IdentifierPreparer: else: return collation_name - def format_sequence(self, sequence, use_schema=True): + def format_sequence( + self, sequence: schema.Sequence, use_schema: bool = True + ) -> str: name = self.quote(sequence.name) effective_schema = self.schema_for_object(sequence) @@ -7631,7 +7722,9 @@ class IdentifierPreparer: return ident @util.preload_module("sqlalchemy.sql.naming") - def format_constraint(self, constraint, _alembic_quote=True): + def format_constraint( + self, constraint: Constraint, _alembic_quote: bool = True + ) -> str | None: naming = util.preloaded.sql_naming if constraint.name is _NONE_NAME: @@ -7653,7 +7746,9 @@ class IdentifierPreparer: name, _alembic_quote=_alembic_quote ) - def truncate_and_render_index_name(self, name, _alembic_quote=True): + def truncate_and_render_index_name( + self, name: str, _alembic_quote: bool = True + ) -> str | None: # calculate these at format time so that ad-hoc changes # to dialect.max_identifier_length etc. can be reflected # as IdentifierPreparer is long lived @@ -7665,7 +7760,9 @@ class IdentifierPreparer: name, max_, _alembic_quote ) - def truncate_and_render_constraint_name(self, name, _alembic_quote=True): + def truncate_and_render_constraint_name( + self, name: str, _alembic_quote: bool = True + ) -> str | None: # calculate these at format time so that ad-hoc changes # to dialect.max_identifier_length etc. can be reflected # as IdentifierPreparer is long lived @@ -7677,7 +7774,9 @@ class IdentifierPreparer: name, max_, _alembic_quote ) - def _truncate_and_render_maxlen_name(self, name, max_, _alembic_quote): + def _truncate_and_render_maxlen_name( + self, name: str, max_: int, _alembic_quote: bool + ) -> str | None: if isinstance(name, elements._truncated_label): if len(name) > max_: name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:] @@ -7692,10 +7791,31 @@ class IdentifierPreparer: def format_index(self, index): return self.format_constraint(index) - def format_table(self, table, use_schema=True, name=None): - """Prepare a quoted table and schema name.""" + @overload + def format_table( + self, + table: "Table | None", + use_schema: bool, + name: str, + ) -> str: ... + + @overload + def format_table( + self, + table: "Table", + use_schema: bool = True, + name: None = None, + ) -> str: ... + def format_table( + self, + table: "Table | None", + use_schema: bool = True, + name: str | None = None, + ) -> str: + """Prepare a quoted table and schema name.""" if name is None: + assert table is not None name = table.name result = self.quote(name) @@ -7727,13 +7847,13 @@ class IdentifierPreparer: def format_column( self, - column, - use_table=False, - name=None, - table_name=None, - use_schema=False, - anon_map=None, - ): + column: "Column[Any]", + use_table: bool = False, + name: str | None = None, + table_name: str | None = None, + use_schema: bool = False, + anon_map: Mapping[str, Any] | None = None, + ) -> str: """Prepare a quoted column name.""" if name is None: @@ -7805,7 +7925,7 @@ class IdentifierPreparer: ) return r - def unformat_identifiers(self, identifiers): + def unformat_identifiers(self, identifiers: str) -> list[str]: """Unpack 'schema.table.column'-like strings into components.""" r = self._r_identifiers diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 19af40ff08..c3a5b015bd 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -207,7 +207,7 @@ def _get_crud_params( [ ( c, - compiler.preparer.format_column(c), + compiler.preparer.format_column(c), # type: ignore[arg-type] # noqa: E501 _create_bind_param(compiler, c, None, required=True), (c.key,), ) @@ -369,7 +369,7 @@ def _get_crud_params( values = [ ( _as_dml_column(stmt.table.columns[0]), - compiler.preparer.format_column(stmt.table.columns[0]), + compiler.preparer.format_column(stmt.table.columns[0]), # type: ignore[arg-type] # noqa: E501 compiler.dialect.default_metavalue_token, (), ) @@ -1136,7 +1136,7 @@ def _append_param_insert_select_hasdefault( values.append( ( c, - compiler.preparer.format_column(c), + compiler.preparer.format_column(c), # type: ignore[arg-type] # noqa: E501 c.default.next_value(), (), ) @@ -1145,7 +1145,7 @@ def _append_param_insert_select_hasdefault( values.append( ( c, - compiler.preparer.format_column(c), + compiler.preparer.format_column(c), # type: ignore[arg-type] c.default.arg.self_group(), (), ) @@ -1154,7 +1154,7 @@ def _append_param_insert_select_hasdefault( values.append( ( c, - compiler.preparer.format_column(c), + compiler.preparer.format_column(c), # type: ignore[arg-type] _create_insert_prefetch_bind_param( compiler, c, process=False, **kw ), diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 7210d930a1..d397331c14 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -38,8 +38,10 @@ if typing.TYPE_CHECKING: from .compiler import Compiled from .compiler import DDLCompiler from .elements import BindParameter + from .schema import Column from .schema import Constraint from .schema import ForeignKeyConstraint + from .schema import Index from .schema import SchemaItem from .schema import Sequence from .schema import Table @@ -712,8 +714,9 @@ class CreateIndex(_CreateBase): """Represent a CREATE INDEX statement.""" __visit_name__ = "create_index" + element: "Index" - def __init__(self, element, if_not_exists=False): + def __init__(self, element: "Index", if_not_exists: bool = False): """Create a :class:`.Createindex` construct. :param element: a :class:`_schema.Index` that's the subject @@ -732,7 +735,9 @@ class DropIndex(_DropBase): __visit_name__ = "drop_index" - def __init__(self, element, if_exists=False): + element: "Index" + + def __init__(self, element: "Index", if_exists: bool = False): """Create a :class:`.DropIndex` construct. :param element: a :class:`_schema.Index` that's the subject @@ -791,6 +796,7 @@ class SetColumnComment(_CreateDropBase): """Represent a COMMENT ON COLUMN IS statement.""" __visit_name__ = "set_column_comment" + element: "Column[Any]" class DropColumnComment(_CreateDropBase): diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 41630261ed..6e660c0be7 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -76,6 +76,7 @@ from .. import inspection from .. import util from ..util import HasMemoized_ro_memoized_attribute from ..util import TypingOnly +from ..util._immutabledict_cy import immutabledict from ..util.typing import Literal from ..util.typing import ParamSpec from ..util.typing import Self @@ -118,6 +119,7 @@ if typing.TYPE_CHECKING: from ..engine.interfaces import SchemaTranslateMapType from ..engine.result import Result + _NUMERIC = Union[float, Decimal] _NUMBER = Union[float, int, Decimal] @@ -2216,8 +2218,8 @@ class TypeClause(DQLDMLClauseElement): ("type", InternalTraversal.dp_type) ] - def __init__(self, type_): - self.type = type_ + def __init__(self, type_: TypeEngine[Any]): + self.type: TypeEngine[Any] = type_ class TextClause( @@ -3875,8 +3877,6 @@ class BinaryExpression(OperatorExpression[_T]): """ - modifiers: Optional[Mapping[str, Any]] - left: ColumnElement[Any] right: ColumnElement[Any] @@ -3907,9 +3907,9 @@ class BinaryExpression(OperatorExpression[_T]): self._is_implicitly_boolean = operators.is_boolean(operator) if modifiers is None: - self.modifiers = {} + self.modifiers: immutabledict[str, str] = immutabledict({}) else: - self.modifiers = modifiers + self.modifiers = immutabledict(modifiers) @property def _flattened_operator_clauses( diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index b905913d37..c66482179e 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -62,6 +62,7 @@ from .sqltypes import TableValueType from .type_api import TypeEngine from .visitors import InternalTraversal from .. import util +from ..util._immutabledict_cy import immutabledict if TYPE_CHECKING: @@ -788,7 +789,7 @@ class FunctionAsBinary(BinaryExpression[Any]): self.type = sqltypes.BOOLEANTYPE self.negate = None self._is_implicitly_boolean = True - self.modifiers = {} + self.modifiers = immutabledict({}) @property def left_expr(self) -> ColumnElement[Any]: diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 212b86ca8a..9f7fb1330f 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -246,10 +246,14 @@ class String(Concatenable, TypeEngine[str]): return process - def bind_processor(self, dialect): + def bind_processor( + self, dialect: "Dialect" + ) -> _BindProcessorType[str] | None: return None - def result_processor(self, dialect, coltype): + def result_processor( + self, dialect: Dialect, coltype: object + ) -> _ResultProcessorType[str] | None: return None @property @@ -426,7 +430,11 @@ class NumericCommon(HasExpressionLookup, TypeEngineMixin, Generic[_N]): if TYPE_CHECKING: @util.ro_memoized_property - def _type_affinity(self) -> Type[NumericCommon[_N]]: ... + def _type_affinity( + self, + ) -> Type[ + Numeric[decimal.Decimal | float] | Float[decimal.Decimal | float] + ]: ... def __init__( self, @@ -653,7 +661,7 @@ class Float(NumericCommon[_N], TypeEngine[_N]): __visit_name__ = "float" - scale = None + scale: int | None = None @overload def __init__( @@ -1325,6 +1333,8 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): __visit_name__ = "enum" + enum_class: None | str | type[enum.StrEnum] + def __init__(self, *enums: object, **kw: Any): r"""Construct an enum. @@ -1457,7 +1467,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): .. versionchanged:: 2.0 This parameter now defaults to True. """ - self._enum_init(enums, kw) + self._enum_init(enums, kw) # type: ignore[arg-type] @property def _enums_argument(self): @@ -1466,7 +1476,9 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): else: return self.enums - def _enum_init(self, enums, kw): + def _enum_init( + self, enums: Sequence[str | type[enum.StrEnum]], kw: dict[str, Any] + ) -> None: """internal init for :class:`.Enum` and subclasses. friendly init helper used by subclasses to remove @@ -1476,7 +1488,9 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): """ self.native_enum = kw.pop("native_enum", True) self.create_constraint = kw.pop("create_constraint", False) - self.values_callable = kw.pop("values_callable", None) + self.values_callable: ( + Callable[[type[enum.StrEnum]], Sequence[str]] | None + ) = kw.pop("values_callable", None) self._sort_key_function = kw.pop("sort_key_function", NO_ARG) length_arg = kw.pop("length", NO_ARG) self._omit_aliases = kw.pop("omit_aliases", True) @@ -1504,7 +1518,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): ) length = length_arg - self._valid_lookup[None] = self._object_lookup[None] = None + self._valid_lookup[None] = self._object_lookup[None] = None # type: ignore # noqa: E501 super().__init__(length=length) @@ -1513,7 +1527,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): # this is a template enum that will be used to generate # new Enum classes. if self.enum_class and values: - kw.setdefault("name", self.enum_class.__name__.lower()) + kw.setdefault("name", self.enum_class.__name__.lower()) # type: ignore[union-attr] # noqa: E501 SchemaType.__init__( self, name=kw.pop("name", None), @@ -1525,7 +1539,9 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): _adapted_from=kw.pop("_adapted_from", None), ) - def _parse_into_values(self, enums, kw): + def _parse_into_values( + self, enums: Sequence[str | type[enum.StrEnum]], kw: Any + ) -> tuple[Sequence[str], Sequence[enum.StrEnum] | Sequence[str]]: if not enums and "_enums" in kw: enums = kw.pop("_enums") @@ -1540,16 +1556,16 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): (n, v) for n, v in _members.items() if v.name == n ) else: - members = _members + members = _members # type: ignore[assignment] if self.values_callable: - values = self.values_callable(self.enum_class) + values = self.values_callable(self.enum_class) # type: ignore[arg-type] # noqa: E501 else: values = list(members) objects = [members[k] for k in members] return values, objects else: self.enum_class = None - return enums, enums + return enums, enums # type: ignore[return-value] def _resolve_for_literal(self, value: Any) -> Enum: tv = type(value) @@ -1639,12 +1655,19 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): self._generic_type_affinity(_enums=enum_args, **kw), # type: ignore # noqa: E501 ) - def _setup_for_values(self, values, objects, kw): + def _setup_for_values( + self, + values: Sequence[str], + objects: Sequence[enum.StrEnum] | Sequence[str], + kw: Any, + ) -> None: self.enums = list(values) - self._valid_lookup = dict(zip(reversed(objects), reversed(values))) + self._valid_lookup: dict[str, str] = dict( + zip(reversed(objects), reversed(values)) + ) - self._object_lookup = dict(zip(values, objects)) + self._object_lookup: dict[str, str] = dict(zip(values, objects)) self._valid_lookup.update( [ @@ -1706,7 +1729,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): comparator_factory = Comparator - def _object_value_for_elem(self, elem): + def _object_value_for_elem(self, elem: str) -> str: try: return self._object_lookup[elem] except KeyError as err: @@ -2183,7 +2206,7 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator[dt.timedelta]): def bind_processor( self, dialect: Dialect - ) -> _BindProcessorType[dt.timedelta]: + ) -> "_BindProcessorType[dt.timedelta]": if TYPE_CHECKING: assert isinstance(self.impl_instance, DateTime) impl_processor = self.impl_instance.bind_processor(dialect) @@ -3490,6 +3513,7 @@ class BINARY(_Binary): class VARBINARY(_Binary): """The SQL VARBINARY type.""" + length: int __visit_name__ = "VARBINARY" @@ -3686,7 +3710,9 @@ class Uuid(Emulated, TypeEngine[_UUID_RETURN]): else: return super().coerce_compared_value(op, value) - def bind_processor(self, dialect): + def bind_processor( + self, dialect: Dialect + ) -> "_BindProcessorType[_UUID_RETURN] | None": character_based_uuid = ( not dialect.supports_native_uuid or not self.native_uuid ) @@ -3694,18 +3720,18 @@ class Uuid(Emulated, TypeEngine[_UUID_RETURN]): if character_based_uuid: if self.as_uuid: - def process(value): + def process(value: Any) -> str: if value is not None: value = value.hex - return value + return value # type: ignore[no-any-return] return process else: - def process(value): + def process(value: Any) -> str: if value is not None: value = value.replace("-", "") - return value + return value # type: ignore[no-any-return] return process else: diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index fb72c825e5..33f7bc41a1 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -1376,6 +1376,9 @@ class UserDefinedType( return self + def get_col_spec(self, **kw: Any) -> str: + raise NotImplementedError() + class Emulated(TypeEngineMixin): """Mixin for base types that emulate the behavior of a DB-native type. diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 9899004178..b3b5d0e8c1 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -21,6 +21,7 @@ from typing import Callable from typing import cast from typing import Collection from typing import Dict +from typing import Generator from typing import Iterable from typing import Iterator from typing import List @@ -481,7 +482,9 @@ def surface_selectables(clause): stack.append(elem.element) -def surface_selectables_only(clause): +def surface_selectables_only( + clause: ClauseElement, +) -> Generator[ClauseElement]: stack = [clause] while stack: elem = stack.pop() diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 9ca5e60a20..34b6b8a29e 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -430,7 +430,9 @@ def to_column_set(x: Any) -> Set[Any]: return x -def update_copy(d, _new=None, **kw): +def update_copy( + d: dict[Any, Any], _new: dict[Any, Any] | None = None, **kw: dict[Any, Any] +) -> dict[Any, Any]: """Copy the given dict and update with the given values.""" d = d.copy() diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 7809c9fcad..d2a78b2287 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -56,6 +56,7 @@ if True: # zimports removes the tailing comments from typing_extensions import TypeAliasType as TypeAliasType # 3.12 from typing_extensions import Unpack as Unpack # 3.11 from typing_extensions import Never as Never # 3.11 + from typing_extensions import LiteralString as LiteralString # 3.11 _T = TypeVar("_T", bound=Any) diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py index 8ea523fb7e..1f8a23f70d 100644 --- a/test/dialect/oracle/test_dialect.py +++ b/test/dialect/oracle/test_dialect.py @@ -681,7 +681,6 @@ class CompatFlagsTest(fixtures.TestBase, AssertsCompiledSQL): dialect._get_server_version_info = server_version_info dialect.get_isolation_level = Mock() - dialect._check_unicode_returns = Mock() dialect._check_unicode_description = Mock() dialect._get_default_schema_name = Mock() dialect._detect_decimal_char = Mock()