From: Pablo Estevez Date: Sat, 8 Feb 2025 15:46:24 +0000 (-0500) Subject: miscellaneous to type dialects X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0ee4b08b111f65602f260c672ef88617f82f0009;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git miscellaneous to type dialects Type of certain methods that are called by dialect, so typing dialects is easier. Related to https://github.com/sqlalchemy/sqlalchemy/pull/12164 breaking changes: - Change modifiers from TextClause to InmutableDict, from Mapping, as is in the other classes Closes: #12231 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12231 Pull-request-sha: 514fe4751c7b1ceefffed2a4ef9c8df339bd9c25 Change-Id: I29314045b2c7eb5428f8d6fec8911c4b6d5ae73e --- diff --git a/lib/sqlalchemy/connectors/asyncio.py b/lib/sqlalchemy/connectors/asyncio.py index e57f7bfdf2..bce08d9cc3 100644 --- a/lib/sqlalchemy/connectors/asyncio.py +++ b/lib/sqlalchemy/connectors/asyncio.py @@ -40,7 +40,7 @@ class AsyncIODBAPIConnection(Protocol): async def commit(self) -> None: ... - def cursor(self) -> AsyncIODBAPICursor: ... + def cursor(self, *args: Any, **kwargs: Any) -> AsyncIODBAPICursor: ... async def rollback(self) -> None: ... diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py index 3a32d19c8b..8aaf223d4d 100644 --- a/lib/sqlalchemy/connectors/pyodbc.py +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -227,11 +227,9 @@ class PyODBCConnector(Connector): ) def get_isolation_level_values( - self, dbapi_connection: interfaces.DBAPIConnection + self, dbapi_conn: interfaces.DBAPIConnection ) -> List[IsolationLevel]: - return super().get_isolation_level_values(dbapi_connection) + [ - "AUTOCOMMIT" - ] + return [*super().get_isolation_level_values(dbapi_conn), "AUTOCOMMIT"] def set_isolation_level( self, diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 1f00127bfa..d25ad83552 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1482,6 +1482,7 @@ from functools import lru_cache import re from typing import Any from typing import cast +from typing import Dict from typing import List from typing import Optional from typing import Tuple @@ -3738,8 +3739,8 @@ class PGDialect(default.DefaultDialect): def _reflect_type( self, format_type: Optional[str], - domains: dict[str, ReflectedDomain], - enums: dict[str, ReflectedEnum], + domains: Dict[str, ReflectedDomain], + enums: Dict[str, ReflectedEnum], type_description: str, ) -> sqltypes.TypeEngine[Any]: """ diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 56d7ee7588..bff473ac5a 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -20,6 +20,7 @@ from typing import Any from typing import cast from typing import ClassVar from typing import Dict +from typing import Iterable from typing import Iterator from typing import List from typing import Mapping @@ -1379,12 +1380,16 @@ class FullyBufferedCursorFetchStrategy(CursorFetchStrategy): __slots__ = ("_rowbuffer", "alternate_cursor_description") def __init__( - self, dbapi_cursor, alternate_description=None, initial_buffer=None + self, + dbapi_cursor: Optional[DBAPICursor], + alternate_description: Optional[_DBAPICursorDescription] = None, + initial_buffer: Optional[Iterable[Any]] = None, ): self.alternate_cursor_description = alternate_description if initial_buffer is not None: self._rowbuffer = collections.deque(initial_buffer) else: + assert dbapi_cursor is not None self._rowbuffer = collections.deque(dbapi_cursor.fetchall()) def yield_per(self, result, dbapi_cursor, num): diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index ba59ac297b..4023019cfc 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -80,9 +80,11 @@ if typing.TYPE_CHECKING: from .interfaces import _CoreSingleExecuteParams from .interfaces import _DBAPICursorDescription from .interfaces import _DBAPIMultiExecuteParams + from .interfaces import _DBAPISingleExecuteParams from .interfaces import _ExecuteOptions from .interfaces import _MutableCoreSingleExecuteParams from .interfaces import _ParamStyle + from .interfaces import ConnectArgsType from .interfaces import DBAPIConnection from .interfaces import IsolationLevel from .row import Row @@ -102,6 +104,7 @@ if typing.TYPE_CHECKING: from ..sql.type_api import _ResultProcessorType from ..sql.type_api import TypeEngine + # When we're handed literal SQL, ensure it's a SELECT query SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE) @@ -440,7 +443,7 @@ class DefaultDialect(Dialect): def _bind_typing_render_casts(self): return self.bind_typing is interfaces.BindTyping.RENDER_CASTS - def _ensure_has_table_connection(self, arg): + def _ensure_has_table_connection(self, arg: Connection) -> None: if not isinstance(arg, Connection): raise exc.ArgumentError( "The argument passed to Dialect.has_table() should be a " @@ -524,7 +527,7 @@ class DefaultDialect(Dialect): else: return None - def initialize(self, connection): + def initialize(self, connection: Connection) -> None: try: self.server_version_info = self._get_server_version_info( connection @@ -560,7 +563,7 @@ class DefaultDialect(Dialect): % (self.label_length, self.max_identifier_length) ) - def on_connect(self): + def on_connect(self) -> Optional[Callable[[Any], Any]]: # inherits the docstring from interfaces.Dialect.on_connect return None @@ -619,18 +622,18 @@ class DefaultDialect(Dialect): ) -> bool: return schema_name in self.get_schema_names(connection, **kw) - def validate_identifier(self, ident): + def validate_identifier(self, ident: str) -> None: if len(ident) > self.max_identifier_length: raise exc.IdentifierError( "Identifier '%s' exceeds maximum length of %d characters" % (ident, self.max_identifier_length) ) - def connect(self, *cargs, **cparams): + def connect(self, *cargs: Any, **cparams: Any) -> DBAPIConnection: # inherits the docstring from interfaces.Dialect.connect - return self.loaded_dbapi.connect(*cargs, **cparams) + return self.loaded_dbapi.connect(*cargs, **cparams) # type: ignore[no-any-return] # NOQA: E501 - def create_connect_args(self, url): + def create_connect_args(self, url: URL) -> ConnectArgsType: # inherits the docstring from interfaces.Dialect.create_connect_args opts = url.translate_connect_args() opts.update(url.query) @@ -953,7 +956,14 @@ class DefaultDialect(Dialect): def do_execute_no_params(self, cursor, statement, context=None): cursor.execute(statement) - def is_disconnect(self, e, connection, cursor): + def is_disconnect( + self, + e: Exception, + connection: Union[ + pool.PoolProxiedConnection, interfaces.DBAPIConnection, None + ], + cursor: Optional[interfaces.DBAPICursor], + ) -> bool: return False @util.memoized_instancemethod @@ -1669,7 +1679,12 @@ class DefaultExecutionContext(ExecutionContext): def no_parameters(self): return self.execution_options.get("no_parameters", False) - def _execute_scalar(self, stmt, type_, parameters=None): + def _execute_scalar( + self, + stmt: str, + type_: Optional[TypeEngine[Any]], + parameters: Optional[_DBAPISingleExecuteParams] = None, + ) -> Any: """Execute a string statement on the current cursor, returning a scalar result. @@ -1743,7 +1758,7 @@ class DefaultExecutionContext(ExecutionContext): return use_server_side - def create_cursor(self): + def create_cursor(self) -> DBAPICursor: if ( # inlining initial preference checks for SS cursors self.dialect.supports_server_side_cursors @@ -1764,10 +1779,10 @@ class DefaultExecutionContext(ExecutionContext): def fetchall_for_returning(self, cursor): return cursor.fetchall() - def create_default_cursor(self): + def create_default_cursor(self) -> DBAPICursor: return self._dbapi_connection.cursor() - def create_server_side_cursor(self): + def create_server_side_cursor(self) -> DBAPICursor: raise NotImplementedError() def pre_exec(self): diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 35c52ae3b9..464c6677b8 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -122,7 +122,7 @@ class DBAPIConnection(Protocol): def commit(self) -> None: ... - def cursor(self) -> DBAPICursor: ... + def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor: ... def rollback(self) -> None: ... @@ -780,6 +780,12 @@ class Dialect(EventTarget): max_identifier_length: int """The maximum length of identifier names.""" + max_index_name_length: Optional[int] + """The maximum length of index names if different from + ``max_identifier_length``.""" + max_constraint_name_length: Optional[int] + """The maximum length of constraint names if different from + ``max_identifier_length``.""" supports_server_side_cursors: bool """indicates if the dialect supports server side cursors""" @@ -1283,8 +1289,6 @@ class Dialect(EventTarget): """ - pass - if TYPE_CHECKING: def _overrides_default(self, method_name: str) -> bool: ... @@ -2483,7 +2487,7 @@ class Dialect(EventTarget): def get_isolation_level_values( self, dbapi_conn: DBAPIConnection - ) -> List[IsolationLevel]: + ) -> Sequence[IsolationLevel]: """return a sequence of string isolation level names that are accepted by this dialect. @@ -2657,6 +2661,9 @@ class Dialect(EventTarget): """return a Pool class to use for a given URL""" raise NotImplementedError() + def validate_identifier(self, ident: str) -> None: + """Validates an identifier name, raising an exception if invalid""" + class CreateEnginePlugin: """A set of hooks intended to augment the construction of an diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index 511eca9234..29c28e1bb6 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -1075,7 +1075,7 @@ class PoolProxiedConnection(ManagesConnection): def commit(self) -> None: ... - def cursor(self) -> DBAPICursor: ... + def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor: ... def rollback(self) -> None: ... diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index fc3614c06b..f643960e73 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -76,7 +76,7 @@ _StringOnlyR = TypeVar("_StringOnlyR", bound=roles.StringRole) _T = TypeVar("_T", bound=Any) -def _is_literal(element): +def _is_literal(element: Any) -> bool: """Return whether or not the element is a "literal" in the context of a SQL expression construct. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 32043dd7bb..5f27ce05b7 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -76,19 +76,15 @@ from .base import _de_clone from .base import _from_objects from .base import _NONE_NAME from .base import _SentinelDefaultCharacterization -from .base import Executable from .base import NO_ARG -from .elements import ClauseElement from .elements import quoted_name -from .schema import Column from .sqltypes import TupleType -from .type_api import TypeEngine from .visitors import prefix_anon_map -from .visitors import Visitable from .. import exc from .. import util from ..util import FastIntFlag from ..util.typing import Literal +from ..util.typing import Self from ..util.typing import TupleAny from ..util.typing import Unpack @@ -96,18 +92,33 @@ if typing.TYPE_CHECKING: from .annotation import _AnnotationDict from .base import _AmbiguousTableNameMap from .base import CompileState + from .base import Executable from .cache_key import CacheKey from .ddl import ExecutableDDLElement from .dml import Insert + from .dml import Update from .dml import UpdateBase + from .dml import UpdateDMLState from .dml import ValuesBase from .elements import _truncated_label + from .elements import BinaryExpression from .elements import BindParameter + from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement + from .elements import False_ from .elements import Label + from .elements import Null + from .elements import True_ from .functions import Function + from .schema import Column + from .schema import Constraint + from .schema import ForeignKeyConstraint + from .schema import Index + from .schema import PrimaryKeyConstraint from .schema import Table + from .schema import UniqueConstraint + from .selectable import _ColumnsClauseElement from .selectable import AliasedReturnsRows from .selectable import CompoundSelectState from .selectable import CTE @@ -117,6 +128,10 @@ if typing.TYPE_CHECKING: from .selectable import Select from .selectable import SelectState from .type_api import _BindProcessorType + from .type_api import TypeDecorator + from .type_api import TypeEngine + from .type_api import UserDefinedType + from .visitors import Visitable from ..engine.cursor import CursorResultMetaData from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.interfaces import _DBAPIAnyExecuteParams @@ -128,6 +143,7 @@ if typing.TYPE_CHECKING: from ..engine.interfaces import Dialect from ..engine.interfaces import SchemaTranslateMapType + _FromHintsType = Dict["FromClause", str] RESERVED_WORDS = { @@ -872,6 +888,7 @@ class Compiled: self.string = self.process(self.statement, **compile_kwargs) if render_schema_translate: + assert schema_translate_map is not None self.string = self.preparer._render_schema_translates( self.string, schema_translate_map ) @@ -904,7 +921,7 @@ class Compiled: raise exc.UnsupportedCompilationError(self, type(element)) from err @property - def sql_compiler(self): + def sql_compiler(self) -> SQLCompiler: """Return a Compiled that is capable of processing SQL expressions. If this compiler is one, it would likely just return 'self'. @@ -1793,7 +1810,7 @@ class SQLCompiler(Compiled): return len(self.stack) > 1 @property - def sql_compiler(self): + def sql_compiler(self) -> Self: return self def construct_expanded_state( @@ -2344,7 +2361,7 @@ class SQLCompiler(Compiled): return get - def default_from(self): + def default_from(self) -> str: """Called when a SELECT statement has no froms, and no FROM clause is to be appended. @@ -2736,16 +2753,16 @@ class SQLCompiler(Compiled): return text - def visit_null(self, expr, **kw): + def visit_null(self, expr: Null, **kw: Any) -> str: return "NULL" - def visit_true(self, expr, **kw): + def visit_true(self, expr: True_, **kw: Any) -> str: if self.dialect.supports_native_boolean: return "true" else: return "1" - def visit_false(self, expr, **kw): + def visit_false(self, expr: False_, **kw: Any) -> str: if self.dialect.supports_native_boolean: return "false" else: @@ -2976,7 +2993,7 @@ class SQLCompiler(Compiled): % self.dialect.name ) - def function_argspec(self, func, **kwargs): + def function_argspec(self, func: Function[Any], **kwargs: Any) -> str: return func.clause_expr._compiler_dispatch(self, **kwargs) def visit_compound_select( @@ -3440,8 +3457,12 @@ class SQLCompiler(Compiled): ) def _generate_generic_binary( - self, binary, opstring, eager_grouping=False, **kw - ): + self, + binary: BinaryExpression[Any], + opstring: str, + eager_grouping: bool = False, + **kw: Any, + ) -> str: _in_operator_expression = kw.get("_in_operator_expression", False) kw["_in_operator_expression"] = True @@ -3610,19 +3631,25 @@ class SQLCompiler(Compiled): **kw, ) - def visit_regexp_match_op_binary(self, binary, operator, **kw): + def visit_regexp_match_op_binary( + self, binary: BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: raise exc.CompileError( "%s dialect does not support regular expressions" % self.dialect.name ) - def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + def visit_not_regexp_match_op_binary( + self, binary: BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: raise exc.CompileError( "%s dialect does not support regular expressions" % self.dialect.name ) - def visit_regexp_replace_op_binary(self, binary, operator, **kw): + def visit_regexp_replace_op_binary( + self, binary: BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: raise exc.CompileError( "%s dialect does not support regular expression replacements" % self.dialect.name @@ -3829,7 +3856,9 @@ class SQLCompiler(Compiled): else: return self.render_literal_value(value, bindparam.type) - def render_literal_value(self, value, type_): + def render_literal_value( + self, value: Any, type_: sqltypes.TypeEngine[Any] + ) -> str: """Render the value of a bind parameter as a quoted literal. This is used for statement sections that do not accept bind parameters @@ -4603,7 +4632,9 @@ class SQLCompiler(Compiled): def get_select_hint_text(self, byfroms): return None - def get_from_hint_text(self, table, text): + def get_from_hint_text( + self, table: FromClause, text: Optional[str] + ) -> Optional[str]: return None def get_crud_hint_text(self, table, text): @@ -5109,7 +5140,7 @@ class SQLCompiler(Compiled): else: return "WITH" - def get_select_precolumns(self, select, **kw): + def get_select_precolumns(self, select: Select[Any], **kw: Any) -> str: """Called when building a ``SELECT`` statement, position is just before column list. @@ -5154,7 +5185,7 @@ class SQLCompiler(Compiled): def returning_clause( self, stmt: UpdateBase, - returning_cols: Sequence[ColumnElement[Any]], + returning_cols: Sequence[_ColumnsClauseElement], *, populate_result_map: bool, **kw: Any, @@ -6187,11 +6218,18 @@ class SQLCompiler(Compiled): else: return None - def visit_update(self, update_stmt, visiting_cte=None, **kw): - compile_state = update_stmt._compile_state_factory( - update_stmt, self, **kw + def visit_update( + self, + update_stmt: Update, + visiting_cte: Optional[CTE] = None, + **kw: Any, + ) -> str: + compile_state = update_stmt._compile_state_factory( # type: ignore[call-arg] # noqa: E501 + update_stmt, self, **kw # type: ignore[arg-type] ) - update_stmt = compile_state.statement + if TYPE_CHECKING: + assert isinstance(compile_state, UpdateDMLState) + update_stmt = compile_state.statement # type: ignore[assignment] if visiting_cte is not None: kw["visiting_cte"] = visiting_cte @@ -6331,7 +6369,7 @@ class SQLCompiler(Compiled): return text def delete_extra_from_clause( - self, update_stmt, from_table, extra_froms, from_hints, **kw + self, delete_stmt, from_table, extra_froms, from_hints, **kw ): """Provide a hook to override the generation of an DELETE..FROM clause. @@ -6555,7 +6593,7 @@ class StrSQLCompiler(SQLCompiler): def returning_clause( self, stmt: UpdateBase, - returning_cols: Sequence[ColumnElement[Any]], + returning_cols: Sequence[_ColumnsClauseElement], *, populate_result_map: bool, **kw: Any, @@ -6576,7 +6614,7 @@ class StrSQLCompiler(SQLCompiler): ) def delete_extra_from_clause( - self, update_stmt, from_table, extra_froms, from_hints, **kw + self, delete_stmt, from_table, extra_froms, from_hints, **kw ): kw["asfrom"] = True return ", " + ", ".join( @@ -6623,8 +6661,8 @@ class DDLCompiler(Compiled): compile_kwargs: Mapping[str, Any] = ..., ): ... - @util.memoized_property - def sql_compiler(self): + @util.ro_memoized_property + def sql_compiler(self) -> SQLCompiler: return self.dialect.statement_compiler( self.dialect, None, schema_translate_map=self.schema_translate_map ) @@ -6788,7 +6826,7 @@ class DDLCompiler(Compiled): def visit_drop_view(self, drop, **kw): return "\nDROP VIEW " + self.preparer.format_table(drop.element) - def _verify_index_table(self, index): + def _verify_index_table(self, index: Index) -> None: if index.table is None: raise exc.CompileError( "Index '%s' is not associated with any table." % index.name @@ -6839,7 +6877,9 @@ class DDLCompiler(Compiled): return text + self._prepared_index_name(index, include_schema=True) - def _prepared_index_name(self, index, include_schema=False): + def _prepared_index_name( + self, index: Index, include_schema: bool = False + ) -> str: if index.table is not None: effective_schema = self.preparer.schema_for_object(index.table) else: @@ -6986,13 +7026,13 @@ class DDLCompiler(Compiled): def post_create_table(self, table): return "" - def get_column_default_string(self, column): + def get_column_default_string(self, column: Column[Any]) -> Optional[str]: if isinstance(column.server_default, schema.DefaultClause): return self.render_default_string(column.server_default.arg) else: return None - def render_default_string(self, default): + def render_default_string(self, default: Union[Visitable, str]) -> str: if isinstance(default, str): return self.sql_compiler.render_literal_value( default, sqltypes.STRINGTYPE @@ -7030,7 +7070,9 @@ class DDLCompiler(Compiled): text += self.define_constraint_deferrability(constraint) return text - def visit_primary_key_constraint(self, constraint, **kw): + def visit_primary_key_constraint( + self, constraint: PrimaryKeyConstraint, **kw: Any + ) -> str: if len(constraint) == 0: return "" text = "" @@ -7079,7 +7121,9 @@ class DDLCompiler(Compiled): return preparer.format_table(table) - def visit_unique_constraint(self, constraint, **kw): + def visit_unique_constraint( + self, constraint: UniqueConstraint, **kw: Any + ) -> str: if len(constraint) == 0: return "" text = "" @@ -7094,10 +7138,14 @@ class DDLCompiler(Compiled): text += self.define_constraint_deferrability(constraint) return text - def define_unique_constraint_distinct(self, constraint, **kw): + def define_unique_constraint_distinct( + self, constraint: UniqueConstraint, **kw: Any + ) -> str: return "" - def define_constraint_cascades(self, constraint): + def define_constraint_cascades( + self, constraint: ForeignKeyConstraint + ) -> str: text = "" if constraint.ondelete is not None: text += " ON DELETE %s" % self.preparer.validate_sql_phrase( @@ -7109,7 +7157,7 @@ class DDLCompiler(Compiled): ) return text - def define_constraint_deferrability(self, constraint): + def define_constraint_deferrability(self, constraint: Constraint) -> str: text = "" if constraint.deferrable is not None: if constraint.deferrable: @@ -7149,19 +7197,21 @@ class DDLCompiler(Compiled): class GenericTypeCompiler(TypeCompiler): - def visit_FLOAT(self, type_, **kw): + def visit_FLOAT(self, type_: sqltypes.Float[Any], **kw: Any) -> str: return "FLOAT" - def visit_DOUBLE(self, type_, **kw): + def visit_DOUBLE(self, type_: sqltypes.Double[Any], **kw: Any) -> str: return "DOUBLE" - def visit_DOUBLE_PRECISION(self, type_, **kw): + def visit_DOUBLE_PRECISION( + self, type_: sqltypes.DOUBLE_PRECISION[Any], **kw: Any + ) -> str: return "DOUBLE PRECISION" - def visit_REAL(self, type_, **kw): + def visit_REAL(self, type_: sqltypes.REAL[Any], **kw: Any) -> str: return "REAL" - def visit_NUMERIC(self, type_, **kw): + def visit_NUMERIC(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str: if type_.precision is None: return "NUMERIC" elif type_.scale is None: @@ -7172,7 +7222,7 @@ class GenericTypeCompiler(TypeCompiler): "scale": type_.scale, } - def visit_DECIMAL(self, type_, **kw): + def visit_DECIMAL(self, type_: sqltypes.DECIMAL[Any], **kw: Any) -> str: if type_.precision is None: return "DECIMAL" elif type_.scale is None: @@ -7183,128 +7233,138 @@ class GenericTypeCompiler(TypeCompiler): "scale": type_.scale, } - def visit_INTEGER(self, type_, **kw): + def visit_INTEGER(self, type_: sqltypes.Integer, **kw: Any) -> str: return "INTEGER" - def visit_SMALLINT(self, type_, **kw): + def visit_SMALLINT(self, type_: sqltypes.SmallInteger, **kw: Any) -> str: return "SMALLINT" - def visit_BIGINT(self, type_, **kw): + def visit_BIGINT(self, type_: sqltypes.BigInteger, **kw: Any) -> str: return "BIGINT" - def visit_TIMESTAMP(self, type_, **kw): + def visit_TIMESTAMP(self, type_: sqltypes.TIMESTAMP, **kw: Any) -> str: return "TIMESTAMP" - def visit_DATETIME(self, type_, **kw): + def visit_DATETIME(self, type_: sqltypes.DateTime, **kw: Any) -> str: return "DATETIME" - def visit_DATE(self, type_, **kw): + def visit_DATE(self, type_: sqltypes.Date, **kw: Any) -> str: return "DATE" - def visit_TIME(self, type_, **kw): + def visit_TIME(self, type_: sqltypes.Time, **kw: Any) -> str: return "TIME" - def visit_CLOB(self, type_, **kw): + def visit_CLOB(self, type_: sqltypes.CLOB, **kw: Any) -> str: return "CLOB" - def visit_NCLOB(self, type_, **kw): + def visit_NCLOB(self, type_: sqltypes.Text, **kw: Any) -> str: return "NCLOB" - def _render_string_type(self, type_, name, length_override=None): + def _render_string_type( + self, name: str, length: Optional[int], collation: Optional[str] + ) -> str: text = name - if length_override: - text += "(%d)" % length_override - elif type_.length: - text += "(%d)" % type_.length - if type_.collation: - text += ' COLLATE "%s"' % type_.collation + if length: + text += f"({length})" + if collation: + text += f' COLLATE "{collation}"' return text - def visit_CHAR(self, type_, **kw): - return self._render_string_type(type_, "CHAR") + def visit_CHAR(self, type_: sqltypes.CHAR, **kw: Any) -> str: + return self._render_string_type("CHAR", type_.length, type_.collation) - def visit_NCHAR(self, type_, **kw): - return self._render_string_type(type_, "NCHAR") + def visit_NCHAR(self, type_: sqltypes.NCHAR, **kw: Any) -> str: + return self._render_string_type("NCHAR", type_.length, type_.collation) - def visit_VARCHAR(self, type_, **kw): - return self._render_string_type(type_, "VARCHAR") + def visit_VARCHAR(self, type_: sqltypes.String, **kw: Any) -> str: + return self._render_string_type( + "VARCHAR", type_.length, type_.collation + ) - def visit_NVARCHAR(self, type_, **kw): - return self._render_string_type(type_, "NVARCHAR") + def visit_NVARCHAR(self, type_: sqltypes.NVARCHAR, **kw: Any) -> str: + return self._render_string_type( + "NVARCHAR", type_.length, type_.collation + ) - def visit_TEXT(self, type_, **kw): - return self._render_string_type(type_, "TEXT") + def visit_TEXT(self, type_: sqltypes.Text, **kw: Any) -> str: + return self._render_string_type("TEXT", type_.length, type_.collation) - def visit_UUID(self, type_, **kw): + def visit_UUID(self, type_: sqltypes.Uuid[Any], **kw: Any) -> str: return "UUID" - def visit_BLOB(self, type_, **kw): + def visit_BLOB(self, type_: sqltypes.LargeBinary, **kw: Any) -> str: return "BLOB" - def visit_BINARY(self, type_, **kw): + def visit_BINARY(self, type_: sqltypes.BINARY, **kw: Any) -> str: return "BINARY" + (type_.length and "(%d)" % type_.length or "") - def visit_VARBINARY(self, type_, **kw): + def visit_VARBINARY(self, type_: sqltypes.VARBINARY, **kw: Any) -> str: return "VARBINARY" + (type_.length and "(%d)" % type_.length or "") - def visit_BOOLEAN(self, type_, **kw): + def visit_BOOLEAN(self, type_: sqltypes.Boolean, **kw: Any) -> str: return "BOOLEAN" - def visit_uuid(self, type_, **kw): + def visit_uuid(self, type_: sqltypes.Uuid[Any], **kw: Any) -> str: if not type_.native_uuid or not self.dialect.supports_native_uuid: - return self._render_string_type(type_, "CHAR", length_override=32) + return self._render_string_type("CHAR", length=32, collation=None) else: return self.visit_UUID(type_, **kw) - def visit_large_binary(self, type_, **kw): + def visit_large_binary( + self, type_: sqltypes.LargeBinary, **kw: Any + ) -> str: return self.visit_BLOB(type_, **kw) - def visit_boolean(self, type_, **kw): + def visit_boolean(self, type_: sqltypes.Boolean, **kw: Any) -> str: return self.visit_BOOLEAN(type_, **kw) - def visit_time(self, type_, **kw): + def visit_time(self, type_: sqltypes.Time, **kw: Any) -> str: return self.visit_TIME(type_, **kw) - def visit_datetime(self, type_, **kw): + def visit_datetime(self, type_: sqltypes.DateTime, **kw: Any) -> str: return self.visit_DATETIME(type_, **kw) - def visit_date(self, type_, **kw): + def visit_date(self, type_: sqltypes.Date, **kw: Any) -> str: return self.visit_DATE(type_, **kw) - def visit_big_integer(self, type_, **kw): + def visit_big_integer(self, type_: sqltypes.BigInteger, **kw: Any) -> str: return self.visit_BIGINT(type_, **kw) - def visit_small_integer(self, type_, **kw): + def visit_small_integer( + self, type_: sqltypes.SmallInteger, **kw: Any + ) -> str: return self.visit_SMALLINT(type_, **kw) - def visit_integer(self, type_, **kw): + def visit_integer(self, type_: sqltypes.Integer, **kw: Any) -> str: return self.visit_INTEGER(type_, **kw) - def visit_real(self, type_, **kw): + def visit_real(self, type_: sqltypes.REAL[Any], **kw: Any) -> str: return self.visit_REAL(type_, **kw) - def visit_float(self, type_, **kw): + def visit_float(self, type_: sqltypes.Float[Any], **kw: Any) -> str: return self.visit_FLOAT(type_, **kw) - def visit_double(self, type_, **kw): + def visit_double(self, type_: sqltypes.Double[Any], **kw: Any) -> str: return self.visit_DOUBLE(type_, **kw) - def visit_numeric(self, type_, **kw): + def visit_numeric(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str: return self.visit_NUMERIC(type_, **kw) - def visit_string(self, type_, **kw): + def visit_string(self, type_: sqltypes.String, **kw: Any) -> str: return self.visit_VARCHAR(type_, **kw) - def visit_unicode(self, type_, **kw): + def visit_unicode(self, type_: sqltypes.Unicode, **kw: Any) -> str: return self.visit_VARCHAR(type_, **kw) - def visit_text(self, type_, **kw): + def visit_text(self, type_: sqltypes.Text, **kw: Any) -> str: return self.visit_TEXT(type_, **kw) - def visit_unicode_text(self, type_, **kw): + def visit_unicode_text( + self, type_: sqltypes.UnicodeText, **kw: Any + ) -> str: return self.visit_TEXT(type_, **kw) - def visit_enum(self, type_, **kw): + def visit_enum(self, type_: sqltypes.Enum, **kw: Any) -> str: return self.visit_VARCHAR(type_, **kw) def visit_null(self, type_, **kw): @@ -7314,10 +7374,14 @@ class GenericTypeCompiler(TypeCompiler): "type on this Column?" % type_ ) - def visit_type_decorator(self, type_, **kw): + def visit_type_decorator( + self, type_: TypeDecorator[Any], **kw: Any + ) -> str: return self.process(type_.type_engine(self.dialect), **kw) - def visit_user_defined(self, type_, **kw): + def visit_user_defined( + self, type_: UserDefinedType[Any], **kw: Any + ) -> str: return type_.get_col_spec(**kw) @@ -7392,12 +7456,12 @@ class IdentifierPreparer: def __init__( self, - dialect, - initial_quote='"', - final_quote=None, - escape_quote='"', - quote_case_sensitive_collations=True, - omit_schema=False, + dialect: Dialect, + initial_quote: str = '"', + final_quote: Optional[str] = None, + escape_quote: str = '"', + quote_case_sensitive_collations: bool = True, + omit_schema: bool = False, ): """Construct a new ``IdentifierPreparer`` object. @@ -7450,7 +7514,9 @@ class IdentifierPreparer: prep._includes_none_schema_translate = includes_none return prep - def _render_schema_translates(self, statement, schema_translate_map): + def _render_schema_translates( + self, statement: str, schema_translate_map: SchemaTranslateMapType + ) -> str: d = schema_translate_map if None in d: if not self._includes_none_schema_translate: @@ -7462,7 +7528,7 @@ class IdentifierPreparer: "schema_translate_map dictionaries." ) - d["_none"] = d[None] + d["_none"] = d[None] # type: ignore[index] def replace(m): name = m.group(2) @@ -7655,7 +7721,9 @@ class IdentifierPreparer: else: return collation_name - def format_sequence(self, sequence, use_schema=True): + def format_sequence( + self, sequence: schema.Sequence, use_schema: bool = True + ) -> str: name = self.quote(sequence.name) effective_schema = self.schema_for_object(sequence) @@ -7692,7 +7760,9 @@ class IdentifierPreparer: return ident @util.preload_module("sqlalchemy.sql.naming") - def format_constraint(self, constraint, _alembic_quote=True): + def format_constraint( + self, constraint: Union[Constraint, Index], _alembic_quote: bool = True + ) -> Optional[str]: naming = util.preloaded.sql_naming if constraint.name is _NONE_NAME: @@ -7705,6 +7775,7 @@ class IdentifierPreparer: else: name = constraint.name + assert name is not None if constraint.__visit_name__ == "index": return self.truncate_and_render_index_name( name, _alembic_quote=_alembic_quote @@ -7714,7 +7785,9 @@ class IdentifierPreparer: name, _alembic_quote=_alembic_quote ) - def truncate_and_render_index_name(self, name, _alembic_quote=True): + def truncate_and_render_index_name( + self, name: str, _alembic_quote: bool = True + ) -> str: # calculate these at format time so that ad-hoc changes # to dialect.max_identifier_length etc. can be reflected # as IdentifierPreparer is long lived @@ -7726,7 +7799,9 @@ class IdentifierPreparer: name, max_, _alembic_quote ) - def truncate_and_render_constraint_name(self, name, _alembic_quote=True): + def truncate_and_render_constraint_name( + self, name: str, _alembic_quote: bool = True + ) -> str: # calculate these at format time so that ad-hoc changes # to dialect.max_identifier_length etc. can be reflected # as IdentifierPreparer is long lived @@ -7738,7 +7813,9 @@ class IdentifierPreparer: name, max_, _alembic_quote ) - def _truncate_and_render_maxlen_name(self, name, max_, _alembic_quote): + def _truncate_and_render_maxlen_name( + self, name: str, max_: int, _alembic_quote: bool + ) -> str: if isinstance(name, elements._truncated_label): if len(name) > max_: name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:] @@ -7750,13 +7827,21 @@ class IdentifierPreparer: else: return self.quote(name) - def format_index(self, index): - return self.format_constraint(index) + def format_index(self, index: Index) -> str: + name = self.format_constraint(index) + assert name is not None + return name - def format_table(self, table, use_schema=True, name=None): + def format_table( + self, + table: FromClause, + use_schema: bool = True, + name: Optional[str] = None, + ) -> str: """Prepare a quoted table and schema name.""" - if name is None: + if TYPE_CHECKING: + assert isinstance(table, NamedFromClause) name = table.name result = self.quote(name) @@ -7788,17 +7873,18 @@ class IdentifierPreparer: def format_column( self, - column, - use_table=False, - name=None, - table_name=None, - use_schema=False, - anon_map=None, - ): + column: ColumnElement[Any], + use_table: bool = False, + name: Optional[str] = None, + table_name: Optional[str] = None, + use_schema: bool = False, + anon_map: Optional[Mapping[str, Any]] = None, + ) -> str: """Prepare a quoted column name.""" if name is None: name = column.name + assert name is not None if anon_map is not None and isinstance( name, elements._truncated_label @@ -7866,7 +7952,7 @@ class IdentifierPreparer: ) return r - def unformat_identifiers(self, identifiers): + def unformat_identifiers(self, identifiers: str) -> Sequence[str]: """Unpack 'schema.table.column'-like strings into components.""" r = self._r_identifiers diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 4e1973ea02..b1a115f49d 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -17,12 +17,15 @@ import contextlib import typing from typing import Any from typing import Callable +from typing import Generic from typing import Iterable from typing import List from typing import Optional from typing import Protocol from typing import Sequence as typing_Sequence from typing import Tuple +from typing import TypeVar +from typing import Union from . import roles from .base import _generative @@ -38,10 +41,12 @@ if typing.TYPE_CHECKING: from .compiler import Compiled from .compiler import DDLCompiler from .elements import BindParameter + from .schema import Column from .schema import Constraint from .schema import ForeignKeyConstraint + from .schema import Index from .schema import SchemaItem - from .schema import Sequence + from .schema import Sequence as Sequence # noqa: F401 from .schema import Table from .selectable import TableClause from ..engine.base import Connection @@ -50,6 +55,8 @@ if typing.TYPE_CHECKING: from ..engine.interfaces import Dialect from ..engine.interfaces import SchemaTranslateMapType +_SI = TypeVar("_SI", bound=Union["SchemaItem", str]) + class BaseDDLElement(ClauseElement): """The root of DDL constructs, including those that are sub-elements @@ -87,7 +94,7 @@ class DDLIfCallable(Protocol): def __call__( self, ddl: BaseDDLElement, - target: SchemaItem, + target: Union[SchemaItem, str], bind: Optional[Connection], tables: Optional[List[Table]] = None, state: Optional[Any] = None, @@ -106,7 +113,7 @@ class DDLIf(typing.NamedTuple): def _should_execute( self, ddl: BaseDDLElement, - target: SchemaItem, + target: Union[SchemaItem, str], bind: Optional[Connection], compiler: Optional[DDLCompiler] = None, **kw: Any, @@ -172,7 +179,7 @@ class ExecutableDDLElement(roles.DDLRole, Executable, BaseDDLElement): """ _ddl_if: Optional[DDLIf] = None - target: Optional[SchemaItem] = None + target: Union[SchemaItem, str, None] = None def _execute_on_connection( self, connection, distilled_params, execution_options @@ -415,7 +422,7 @@ class DDL(ExecutableDDLElement): ) -class _CreateDropBase(ExecutableDDLElement): +class _CreateDropBase(ExecutableDDLElement, Generic[_SI]): """Base class for DDL constructs that represent CREATE and DROP or equivalents. @@ -425,15 +432,13 @@ class _CreateDropBase(ExecutableDDLElement): """ - def __init__( - self, - element, - ): + def __init__(self, element: _SI) -> None: self.element = self.target = element self._ddl_if = getattr(element, "_ddl_if", None) @property def stringify_dialect(self): + assert not isinstance(self.element, str) return self.element.create_drop_stringify_dialect def _create_rule_disable(self, compiler): @@ -447,19 +452,19 @@ class _CreateDropBase(ExecutableDDLElement): return False -class _CreateBase(_CreateDropBase): - def __init__(self, element, if_not_exists=False): +class _CreateBase(_CreateDropBase[_SI]): + def __init__(self, element: _SI, if_not_exists: bool = False) -> None: super().__init__(element) self.if_not_exists = if_not_exists -class _DropBase(_CreateDropBase): - def __init__(self, element, if_exists=False): +class _DropBase(_CreateDropBase[_SI]): + def __init__(self, element: _SI, if_exists: bool = False) -> None: super().__init__(element) self.if_exists = if_exists -class CreateSchema(_CreateBase): +class CreateSchema(_CreateBase[str]): """Represent a CREATE SCHEMA statement. The argument here is the string name of the schema. @@ -474,13 +479,13 @@ class CreateSchema(_CreateBase): self, name: str, if_not_exists: bool = False, - ): + ) -> None: """Create a new :class:`.CreateSchema` construct.""" super().__init__(element=name, if_not_exists=if_not_exists) -class DropSchema(_DropBase): +class DropSchema(_DropBase[str]): """Represent a DROP SCHEMA statement. The argument here is the string name of the schema. @@ -496,14 +501,14 @@ class DropSchema(_DropBase): name: str, cascade: bool = False, if_exists: bool = False, - ): + ) -> None: """Create a new :class:`.DropSchema` construct.""" super().__init__(element=name, if_exists=if_exists) self.cascade = cascade -class CreateTable(_CreateBase): +class CreateTable(_CreateBase["Table"]): """Represent a CREATE TABLE statement.""" __visit_name__ = "create_table" @@ -515,7 +520,7 @@ class CreateTable(_CreateBase): typing_Sequence[ForeignKeyConstraint] ] = None, if_not_exists: bool = False, - ): + ) -> None: """Create a :class:`.CreateTable` construct. :param element: a :class:`_schema.Table` that's the subject @@ -537,7 +542,7 @@ class CreateTable(_CreateBase): self.include_foreign_key_constraints = include_foreign_key_constraints -class _DropView(_DropBase): +class _DropView(_DropBase["Table"]): """Semi-public 'DROP VIEW' construct. Used by the test suite for dialect-agnostic drops of views. @@ -549,7 +554,9 @@ class _DropView(_DropBase): class CreateConstraint(BaseDDLElement): - def __init__(self, element: Constraint): + element: Constraint + + def __init__(self, element: Constraint) -> None: self.element = element @@ -666,16 +673,18 @@ class CreateColumn(BaseDDLElement): __visit_name__ = "create_column" - def __init__(self, element): + element: Column[Any] + + def __init__(self, element: Column[Any]) -> None: self.element = element -class DropTable(_DropBase): +class DropTable(_DropBase["Table"]): """Represent a DROP TABLE statement.""" __visit_name__ = "drop_table" - def __init__(self, element: Table, if_exists: bool = False): + def __init__(self, element: Table, if_exists: bool = False) -> None: """Create a :class:`.DropTable` construct. :param element: a :class:`_schema.Table` that's the subject @@ -690,30 +699,24 @@ class DropTable(_DropBase): super().__init__(element, if_exists=if_exists) -class CreateSequence(_CreateBase): +class CreateSequence(_CreateBase["Sequence"]): """Represent a CREATE SEQUENCE statement.""" __visit_name__ = "create_sequence" - def __init__(self, element: Sequence, if_not_exists: bool = False): - super().__init__(element, if_not_exists=if_not_exists) - -class DropSequence(_DropBase): +class DropSequence(_DropBase["Sequence"]): """Represent a DROP SEQUENCE statement.""" __visit_name__ = "drop_sequence" - def __init__(self, element: Sequence, if_exists: bool = False): - super().__init__(element, if_exists=if_exists) - -class CreateIndex(_CreateBase): +class CreateIndex(_CreateBase["Index"]): """Represent a CREATE INDEX statement.""" __visit_name__ = "create_index" - def __init__(self, element, if_not_exists=False): + def __init__(self, element: Index, if_not_exists: bool = False) -> None: """Create a :class:`.Createindex` construct. :param element: a :class:`_schema.Index` that's the subject @@ -727,12 +730,12 @@ class CreateIndex(_CreateBase): super().__init__(element, if_not_exists=if_not_exists) -class DropIndex(_DropBase): +class DropIndex(_DropBase["Index"]): """Represent a DROP INDEX statement.""" __visit_name__ = "drop_index" - def __init__(self, element, if_exists=False): + def __init__(self, element: Index, if_exists: bool = False) -> None: """Create a :class:`.DropIndex` construct. :param element: a :class:`_schema.Index` that's the subject @@ -746,7 +749,7 @@ class DropIndex(_DropBase): super().__init__(element, if_exists=if_exists) -class AddConstraint(_CreateBase): +class AddConstraint(_CreateBase["Constraint"]): """Represent an ALTER TABLE ADD CONSTRAINT statement.""" __visit_name__ = "add_constraint" @@ -756,7 +759,7 @@ class AddConstraint(_CreateBase): element: Constraint, *, isolate_from_table: bool = True, - ): + ) -> None: """Construct a new :class:`.AddConstraint` construct. :param element: a :class:`.Constraint` object @@ -780,7 +783,7 @@ class AddConstraint(_CreateBase): ) -class DropConstraint(_DropBase): +class DropConstraint(_DropBase["Constraint"]): """Represent an ALTER TABLE DROP CONSTRAINT statement.""" __visit_name__ = "drop_constraint" @@ -793,7 +796,7 @@ class DropConstraint(_DropBase): if_exists: bool = False, isolate_from_table: bool = True, **kw: Any, - ): + ) -> None: """Construct a new :class:`.DropConstraint` construct. :param element: a :class:`.Constraint` object @@ -821,13 +824,13 @@ class DropConstraint(_DropBase): ) -class SetTableComment(_CreateDropBase): +class SetTableComment(_CreateDropBase["Table"]): """Represent a COMMENT ON TABLE IS statement.""" __visit_name__ = "set_table_comment" -class DropTableComment(_CreateDropBase): +class DropTableComment(_CreateDropBase["Table"]): """Represent a COMMENT ON TABLE '' statement. Note this varies a lot across database backends. @@ -837,25 +840,25 @@ class DropTableComment(_CreateDropBase): __visit_name__ = "drop_table_comment" -class SetColumnComment(_CreateDropBase): +class SetColumnComment(_CreateDropBase["Column[Any]"]): """Represent a COMMENT ON COLUMN IS statement.""" __visit_name__ = "set_column_comment" -class DropColumnComment(_CreateDropBase): +class DropColumnComment(_CreateDropBase["Column[Any]"]): """Represent a COMMENT ON COLUMN IS NULL statement.""" __visit_name__ = "drop_column_comment" -class SetConstraintComment(_CreateDropBase): +class SetConstraintComment(_CreateDropBase["Constraint"]): """Represent a COMMENT ON CONSTRAINT IS statement.""" __visit_name__ = "set_constraint_comment" -class DropConstraintComment(_CreateDropBase): +class DropConstraintComment(_CreateDropBase["Constraint"]): """Represent a COMMENT ON CONSTRAINT IS NULL statement.""" __visit_name__ = "drop_constraint_comment" diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 8d256ea377..e394f73f4f 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -2225,8 +2225,9 @@ class TypeClause(DQLDMLClauseElement): _traverse_internals: _TraverseInternalsType = [ ("type", InternalTraversal.dp_type) ] + type: TypeEngine[Any] - def __init__(self, type_): + def __init__(self, type_: TypeEngine[Any]): self.type = type_ @@ -3913,10 +3914,9 @@ class BinaryExpression(OperatorExpression[_T]): """ - modifiers: Optional[Mapping[str, Any]] - left: ColumnElement[Any] right: ColumnElement[Any] + modifiers: Mapping[str, Any] def __init__( self, diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 3fcf22ee68..131a0f2e28 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -23,6 +23,7 @@ from typing import cast from typing import Dict from typing import Generic from typing import List +from typing import Mapping from typing import Optional from typing import overload from typing import Sequence @@ -246,10 +247,14 @@ class String(Concatenable, TypeEngine[str]): return process - def bind_processor(self, dialect): + def bind_processor( + self, dialect: Dialect + ) -> Optional[_BindProcessorType[str]]: return None - def result_processor(self, dialect, coltype): + def result_processor( + self, dialect: Dialect, coltype: object + ) -> Optional[_ResultProcessorType[str]]: return None @property @@ -426,7 +431,7 @@ class NumericCommon(HasExpressionLookup, TypeEngineMixin, Generic[_N]): if TYPE_CHECKING: @util.ro_memoized_property - def _type_affinity(self) -> Type[NumericCommon[_N]]: ... + def _type_affinity(self) -> Type[Union[Numeric[_N], Float[_N]]]: ... def __init__( self, @@ -653,8 +658,6 @@ class Float(NumericCommon[_N], TypeEngine[_N]): __visit_name__ = "float" - scale = None - @overload def __init__( self: Float[float], @@ -925,6 +928,8 @@ class Time(_RenderISO8601NoT, HasExpressionLookup, TypeEngine[dt.time]): class _Binary(TypeEngine[bytes]): """Define base behavior for binary types.""" + length: Optional[int] + def __init__(self, length: Optional[int] = None): self.length = length @@ -1249,6 +1254,9 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin): return _we_are_the_impl(variant_mapping["_default"]) +_EnumTupleArg = Union[Sequence[enum.Enum], Sequence[str]] + + class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): """Generic Enum Type. @@ -1325,7 +1333,18 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): __visit_name__ = "enum" - def __init__(self, *enums: object, **kw: Any): + values_callable: Optional[Callable[[Type[enum.Enum]], Sequence[str]]] + enum_class: Optional[Type[enum.Enum]] + _valid_lookup: Dict[Union[enum.Enum, str, None], Optional[str]] + _object_lookup: Dict[Optional[str], Union[enum.Enum, str, None]] + + @overload + def __init__(self, enums: Type[enum.Enum], **kw: Any) -> None: ... + + @overload + def __init__(self, *enums: str, **kw: Any) -> None: ... + + def __init__(self, *enums: Union[str, Type[enum.Enum]], **kw: Any) -> None: r"""Construct an enum. Keyword arguments which don't apply to a specific backend are ignored @@ -1457,7 +1476,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): .. versionchanged:: 2.0 This parameter now defaults to True. """ - self._enum_init(enums, kw) + self._enum_init(enums, kw) # type: ignore[arg-type] @property def _enums_argument(self): @@ -1466,7 +1485,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): else: return self.enums - def _enum_init(self, enums, kw): + def _enum_init(self, enums: _EnumTupleArg, kw: Dict[str, Any]) -> None: """internal init for :class:`.Enum` and subclasses. friendly init helper used by subclasses to remove @@ -1525,15 +1544,19 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): _adapted_from=kw.pop("_adapted_from", None), ) - def _parse_into_values(self, enums, kw): + def _parse_into_values( + self, enums: _EnumTupleArg, kw: Any + ) -> Tuple[Sequence[str], _EnumTupleArg]: if not enums and "_enums" in kw: enums = kw.pop("_enums") if len(enums) == 1 and hasattr(enums[0], "__members__"): - self.enum_class = enums[0] + self.enum_class = enums[0] # type: ignore[assignment] + assert self.enum_class is not None _members = self.enum_class.__members__ + members: Mapping[str, enum.Enum] if self._omit_aliases is True: # remove aliases members = OrderedDict( @@ -1549,7 +1572,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): return values, objects else: self.enum_class = None - return enums, enums + return enums, enums # type: ignore[return-value] def _resolve_for_literal(self, value: Any) -> Enum: tv = type(value) @@ -1625,7 +1648,12 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): self._generic_type_affinity(_enums=enum_args, **kw), # type: ignore # noqa: E501 ) - def _setup_for_values(self, values, objects, kw): + def _setup_for_values( + self, + values: Sequence[str], + objects: _EnumTupleArg, + kw: Any, + ) -> None: self.enums = list(values) self._valid_lookup = dict(zip(reversed(objects), reversed(values))) @@ -1692,9 +1720,10 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): comparator_factory = Comparator - def _object_value_for_elem(self, elem): + def _object_value_for_elem(self, elem: str) -> Union[str, enum.Enum]: try: - return self._object_lookup[elem] + # Value will not be None beacuse key is not None + return self._object_lookup[elem] # type: ignore[return-value] except KeyError as err: raise LookupError( "'%s' is not among the defined enum values. " @@ -3625,6 +3654,7 @@ class Uuid(Emulated, TypeEngine[_UUID_RETURN]): __visit_name__ = "uuid" + length: Optional[int] = None collation: Optional[str] = None @overload @@ -3676,7 +3706,9 @@ class Uuid(Emulated, TypeEngine[_UUID_RETURN]): else: return super().coerce_compared_value(op, value) - def bind_processor(self, dialect): + def bind_processor( + self, dialect: Dialect + ) -> Optional[_BindProcessorType[_UUID_RETURN]]: character_based_uuid = ( not dialect.supports_native_uuid or not self.native_uuid ) diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index bdc56b46ac..911071cc99 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -1392,6 +1392,10 @@ class UserDefinedType( return self + if TYPE_CHECKING: + + def get_col_spec(self, **kw: Any) -> str: ... + class Emulated(TypeEngineMixin): """Mixin for base types that emulate the behavior of a DB-native type. diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 9899004178..a98b51c1de 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -481,7 +481,7 @@ def surface_selectables(clause): stack.append(elem.element) -def surface_selectables_only(clause): +def surface_selectables_only(clause: ClauseElement) -> Iterator[ClauseElement]: stack = [clause] while stack: elem = stack.pop() diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 9ca5e60a20..36ca6a56a9 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -430,7 +430,9 @@ def to_column_set(x: Any) -> Set[Any]: return x -def update_copy(d, _new=None, **kw): +def update_copy( + d: Dict[Any, Any], _new: Optional[Dict[Any, Any]] = None, **kw: Any +) -> Dict[Any, Any]: """Copy the given dict and update with the given values.""" d = d.copy() diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 01569cebda..8980a85062 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -56,6 +56,7 @@ if True: # zimports removes the tailing comments from typing_extensions import TypeAliasType as TypeAliasType # 3.12 from typing_extensions import Unpack as Unpack # 3.11 from typing_extensions import Never as Never # 3.11 + from typing_extensions import LiteralString as LiteralString # 3.11 _T = TypeVar("_T", bound=Any) diff --git a/pyproject.toml b/pyproject.toml index ade402dd6b..9a9b5658c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -176,6 +176,8 @@ reportTypedDictNotRequiredAccess = "warning" mypy_path = "./lib/" show_error_codes = true incremental = true +# would be nice to enable this but too many error are surfaceds +# enable_error_code = "ignore-without-code" [[tool.mypy.overrides]] diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py index 8ea523fb7e..1f8a23f70d 100644 --- a/test/dialect/oracle/test_dialect.py +++ b/test/dialect/oracle/test_dialect.py @@ -681,7 +681,6 @@ class CompatFlagsTest(fixtures.TestBase, AssertsCompiledSQL): dialect._get_server_version_info = server_version_info dialect.get_isolation_level = Mock() - dialect._check_unicode_returns = Mock() dialect._check_unicode_description = Mock() dialect._get_default_schema_name = Mock() dialect._detect_decimal_char = Mock()