From: Justine Krejcha Date: Sun, 27 Apr 2025 17:17:58 +0000 (-0700) Subject: typing: pg: type `NamedType` `create`/`drop` X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=75c8d81bfb68f45299a9448d45dda446532205d3;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git typing: pg: type `NamedType` `create`/`drop` Also type `SchemaType` `create`/`drop` more generally Fixes #12557 --- diff --git a/lib/sqlalchemy/dialects/postgresql/named_types.py b/lib/sqlalchemy/dialects/postgresql/named_types.py index e1b8e84ce8..42a1128bbb 100644 --- a/lib/sqlalchemy/dialects/postgresql/named_types.py +++ b/lib/sqlalchemy/dialects/postgresql/named_types.py @@ -7,6 +7,7 @@ # mypy: ignore-errors from __future__ import annotations +from types import ModuleType from typing import Any from typing import Optional from typing import Type @@ -25,6 +26,7 @@ from ...sql.ddl import InvokeCreateDDLBase from ...sql.ddl import InvokeDropDDLBase if TYPE_CHECKING: + from ...sql._typing import _CreateDropBind from ...sql._typing import _TypeEngineArgument @@ -36,7 +38,7 @@ class NamedType(sqltypes.TypeEngine): DDLDropper: Type[NamedTypeDropper] create_type: bool - def create(self, bind, checkfirst=True, **kw): + def create(self, bind: _CreateDropBind, checkfirst: bool = True, **kw: Any) -> None: """Emit ``CREATE`` DDL for this type. :param bind: a connectable :class:`_engine.Engine`, @@ -50,7 +52,7 @@ class NamedType(sqltypes.TypeEngine): """ bind._run_ddl_visitor(self.DDLGenerator, self, checkfirst=checkfirst) - def drop(self, bind, checkfirst=True, **kw): + def drop(self, bind: _CreateDropBind, checkfirst: bool = True, **kw: Any) -> None: """Emit ``DROP`` DDL for this type. :param bind: a connectable :class:`_engine.Engine`, @@ -63,7 +65,7 @@ class NamedType(sqltypes.TypeEngine): """ bind._run_ddl_visitor(self.DDLDropper, self, checkfirst=checkfirst) - def _check_for_name_in_memos(self, checkfirst, kw): + def _check_for_name_in_memos(self, checkfirst: bool, **kw: Any) -> None: """Look in the 'ddl runner' for 'memos', then note our name in that collection. @@ -87,7 +89,7 @@ class NamedType(sqltypes.TypeEngine): else: return False - def _on_table_create(self, target, bind, checkfirst=False, **kw): + def _on_table_create(self, target, bind: _CreateDropBind, checkfirst: bool = False, **kw: Any) -> None: if ( checkfirst or ( @@ -97,7 +99,7 @@ class NamedType(sqltypes.TypeEngine): ) and not self._check_for_name_in_memos(checkfirst, kw): self.create(bind=bind, checkfirst=checkfirst) - def _on_table_drop(self, target, bind, checkfirst=False, **kw): + def _on_table_drop(self, target, bind: _CreateDropBind, checkfirst: bool = False, **kw: Any) -> None: if ( not self.metadata and not kw.get("_is_metadata_operation", False) @@ -105,11 +107,11 @@ class NamedType(sqltypes.TypeEngine): ): self.drop(bind=bind, checkfirst=checkfirst) - def _on_metadata_create(self, target, bind, checkfirst=False, **kw): + def _on_metadata_create(self, target, bind: _CreateDropBind, checkfirst: bool = False, **kw: Any) -> None: if not self._check_for_name_in_memos(checkfirst, kw): self.create(bind=bind, checkfirst=checkfirst) - def _on_metadata_drop(self, target, bind, checkfirst=False, **kw): + def _on_metadata_drop(self, target, bind: _CreateDropBind, checkfirst: bool = False, **kw: Any) -> None: if not self._check_for_name_in_memos(checkfirst, kw): self.drop(bind=bind, checkfirst=checkfirst) @@ -314,7 +316,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum): return cls(**kw) - def create(self, bind=None, checkfirst=True): + def create(self, bind: _CreateDropBind, checkfirst: bool = True) -> None: """Emit ``CREATE TYPE`` for this :class:`_postgresql.ENUM`. @@ -335,7 +337,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum): super().create(bind, checkfirst=checkfirst) - def drop(self, bind=None, checkfirst=True): + def drop(self, bind: _CreateDropBind, checkfirst: bool = True) -> None: """Emit ``DROP TYPE`` for this :class:`_postgresql.ENUM`. @@ -355,7 +357,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum): super().drop(bind, checkfirst=checkfirst) - def get_dbapi_type(self, dbapi): + def get_dbapi_type(self, dbapi: ModuleType) -> None: """dont return dbapi.STRING for ENUM in PostgreSQL, since that's a different type""" diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 6fef1766c6..eb5d09ec2d 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -72,7 +72,10 @@ if TYPE_CHECKING: from .sqltypes import TableValueType from .sqltypes import TupleType from .type_api import TypeEngine + from ..engine import Connection from ..engine import Dialect + from ..engine import Engine + from ..engine.mock import MockConnection from ..util.typing import TypeGuard _T = TypeVar("_T", bound=Any) @@ -304,6 +307,8 @@ _LimitOffsetType = Union[int, _ColumnExpressionArgument[int], None] _AutoIncrementType = Union[bool, Literal["auto", "ignore_fk"]] +_CreateDropBind = Union["Engine", "Connection", "MockConnection"] + if TYPE_CHECKING: def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]: ... diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 77047f10b6..5d25cdcb9f 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -91,6 +91,7 @@ from ..util.typing import TypeGuard if typing.TYPE_CHECKING: from ._typing import _AutoIncrementType + from ._typing import _CreateDropBind from ._typing import _DDLColumnArgument from ._typing import _DDLColumnReferenceArgument from ._typing import _InfoType @@ -118,8 +119,6 @@ _SI = TypeVar("_SI", bound="SchemaItem") _TAB = TypeVar("_TAB", bound="Table") -_CreateDropBind = Union["Engine", "Connection", "MockConnection"] - _ConstraintNameArgument = Optional[Union[str, _NoneName]] _ServerDefaultArgument = Union[ diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index f71678a4ab..7c656ab5a9 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -70,6 +70,7 @@ from ..util.typing import TupleAny if TYPE_CHECKING: from ._typing import _ColumnExpressionArgument + from ._typing import _CreateDropBind from ._typing import _TypeEngineArgument from .elements import ColumnElement from .operators import OperatorType @@ -1179,21 +1180,21 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin): kw.setdefault("_adapted_from", self) return super().adapt(cls, **kw) - def create(self, bind, checkfirst=False): + def create(self, bind: _CreateDropBind, checkfirst: bool = False) -> None: """Issue CREATE DDL for this type, if applicable.""" t = self.dialect_impl(bind.dialect) if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t.create(bind, checkfirst=checkfirst) - def drop(self, bind, checkfirst=False): + def drop(self, bind: _CreateDropBind, checkfirst: bool = False) -> None: """Issue DROP DDL for this type, if applicable.""" t = self.dialect_impl(bind.dialect) if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t.drop(bind, checkfirst=checkfirst) - def _on_table_create(self, target, bind, **kw): + def _on_table_create(self, target, bind: _CreateDropBind, **kw: Any) -> None: if not self._is_impl_for_variant(bind.dialect, kw): return @@ -1201,7 +1202,7 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin): if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t._on_table_create(target, bind, **kw) - def _on_table_drop(self, target, bind, **kw): + def _on_table_drop(self, target, bind: _CreateDropBind, **kw: Any) -> None: if not self._is_impl_for_variant(bind.dialect, kw): return @@ -1209,7 +1210,7 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin): if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t._on_table_drop(target, bind, **kw) - def _on_metadata_create(self, target, bind, **kw): + def _on_metadata_create(self, target, bind: _CreateDropBind, **kw: Any) -> None: if not self._is_impl_for_variant(bind.dialect, kw): return @@ -1217,7 +1218,7 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin): if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t._on_metadata_create(target, bind, **kw) - def _on_metadata_drop(self, target, bind, **kw): + def _on_metadata_drop(self, target, bind: _CreateDropBind, **kw: Any) -> None: if not self._is_impl_for_variant(bind.dialect, kw): return @@ -1225,7 +1226,7 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin): if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t._on_metadata_drop(target, bind, **kw) - def _is_impl_for_variant(self, dialect, kw): + def _is_impl_for_variant(self, dialect: Dialect, **kw: Any) -> Optional[bool]: variant_mapping = kw.pop("variant_mapping", None) if not variant_mapping: