]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
typing: pg: type `NamedType` `create`/`drop` 12558/head
authorJustine Krejcha <justine@justinekrejcha.com>
Sun, 27 Apr 2025 17:17:58 +0000 (10:17 -0700)
committerJustine Krejcha <justine@justinekrejcha.com>
Tue, 29 Apr 2025 22:12:34 +0000 (15:12 -0700)
Also type `SchemaType` `create`/`drop` more generally

Fixes #12557

lib/sqlalchemy/dialects/postgresql/named_types.py
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/sql/sqltypes.py

index e1b8e84ce858807e7377b4be7a5a9244c270b747..42a1128bbbbab6b7febdcf461e4c7c0205459c8a 100644 (file)
@@ -7,6 +7,7 @@
 # mypy: ignore-errors
 from __future__ import annotations
 
+from types import ModuleType
 from typing import Any
 from typing import Optional
 from typing import Type
@@ -25,6 +26,7 @@ from ...sql.ddl import InvokeCreateDDLBase
 from ...sql.ddl import InvokeDropDDLBase
 
 if TYPE_CHECKING:
+    from ...sql._typing import _CreateDropBind
     from ...sql._typing import _TypeEngineArgument
 
 
@@ -36,7 +38,7 @@ class NamedType(sqltypes.TypeEngine):
     DDLDropper: Type[NamedTypeDropper]
     create_type: bool
 
-    def create(self, bind, checkfirst=True, **kw):
+    def create(self, bind: _CreateDropBind, checkfirst: bool = True, **kw: Any) -> None:
         """Emit ``CREATE`` DDL for this type.
 
         :param bind: a connectable :class:`_engine.Engine`,
@@ -50,7 +52,7 @@ class NamedType(sqltypes.TypeEngine):
         """
         bind._run_ddl_visitor(self.DDLGenerator, self, checkfirst=checkfirst)
 
-    def drop(self, bind, checkfirst=True, **kw):
+    def drop(self, bind: _CreateDropBind, checkfirst: bool = True, **kw: Any) -> None:
         """Emit ``DROP`` DDL for this type.
 
         :param bind: a connectable :class:`_engine.Engine`,
@@ -63,7 +65,7 @@ class NamedType(sqltypes.TypeEngine):
         """
         bind._run_ddl_visitor(self.DDLDropper, self, checkfirst=checkfirst)
 
-    def _check_for_name_in_memos(self, checkfirst, kw):
+    def _check_for_name_in_memos(self, checkfirst: bool, **kw: Any) -> None:
         """Look in the 'ddl runner' for 'memos', then
         note our name in that collection.
 
@@ -87,7 +89,7 @@ class NamedType(sqltypes.TypeEngine):
         else:
             return False
 
-    def _on_table_create(self, target, bind, checkfirst=False, **kw):
+    def _on_table_create(self, target, bind: _CreateDropBind, checkfirst: bool = False, **kw: Any) -> None:
         if (
             checkfirst
             or (
@@ -97,7 +99,7 @@ class NamedType(sqltypes.TypeEngine):
         ) and not self._check_for_name_in_memos(checkfirst, kw):
             self.create(bind=bind, checkfirst=checkfirst)
 
-    def _on_table_drop(self, target, bind, checkfirst=False, **kw):
+    def _on_table_drop(self, target, bind: _CreateDropBind, checkfirst: bool = False, **kw: Any) -> None:
         if (
             not self.metadata
             and not kw.get("_is_metadata_operation", False)
@@ -105,11 +107,11 @@ class NamedType(sqltypes.TypeEngine):
         ):
             self.drop(bind=bind, checkfirst=checkfirst)
 
-    def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
+    def _on_metadata_create(self, target, bind: _CreateDropBind, checkfirst: bool = False, **kw: Any) -> None:
         if not self._check_for_name_in_memos(checkfirst, kw):
             self.create(bind=bind, checkfirst=checkfirst)
 
-    def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
+    def _on_metadata_drop(self, target, bind: _CreateDropBind, checkfirst: bool = False, **kw: Any) -> None:
         if not self._check_for_name_in_memos(checkfirst, kw):
             self.drop(bind=bind, checkfirst=checkfirst)
 
@@ -314,7 +316,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
 
         return cls(**kw)
 
-    def create(self, bind=None, checkfirst=True):
+    def create(self, bind: _CreateDropBind, checkfirst: bool = True) -> None:
         """Emit ``CREATE TYPE`` for this
         :class:`_postgresql.ENUM`.
 
@@ -335,7 +337,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
 
         super().create(bind, checkfirst=checkfirst)
 
-    def drop(self, bind=None, checkfirst=True):
+    def drop(self, bind: _CreateDropBind, checkfirst: bool = True) -> None:
         """Emit ``DROP TYPE`` for this
         :class:`_postgresql.ENUM`.
 
@@ -355,7 +357,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
 
         super().drop(bind, checkfirst=checkfirst)
 
-    def get_dbapi_type(self, dbapi):
+    def get_dbapi_type(self, dbapi: ModuleType) -> None:
         """dont return dbapi.STRING for ENUM in PostgreSQL, since that's
         a different type"""
 
index 6fef1766c6df03ec3826dc8015dd3258f538f371..eb5d09ec2dac8fe56aef38fe48ac1701c353e1d8 100644 (file)
@@ -72,7 +72,10 @@ if TYPE_CHECKING:
     from .sqltypes import TableValueType
     from .sqltypes import TupleType
     from .type_api import TypeEngine
+    from ..engine import Connection
     from ..engine import Dialect
+    from ..engine import Engine
+    from ..engine.mock import MockConnection
     from ..util.typing import TypeGuard
 
 _T = TypeVar("_T", bound=Any)
@@ -304,6 +307,8 @@ _LimitOffsetType = Union[int, _ColumnExpressionArgument[int], None]
 
 _AutoIncrementType = Union[bool, Literal["auto", "ignore_fk"]]
 
+_CreateDropBind = Union["Engine", "Connection", "MockConnection"]
+
 if TYPE_CHECKING:
 
     def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]: ...
index 77047f10b633de22708654b42e2c343531c6f843..5d25cdcb9f4e7a1c607f85e9ada76ee28a18b014 100644 (file)
@@ -91,6 +91,7 @@ from ..util.typing import TypeGuard
 
 if typing.TYPE_CHECKING:
     from ._typing import _AutoIncrementType
+    from ._typing import _CreateDropBind
     from ._typing import _DDLColumnArgument
     from ._typing import _DDLColumnReferenceArgument
     from ._typing import _InfoType
@@ -118,8 +119,6 @@ _SI = TypeVar("_SI", bound="SchemaItem")
 _TAB = TypeVar("_TAB", bound="Table")
 
 
-_CreateDropBind = Union["Engine", "Connection", "MockConnection"]
-
 _ConstraintNameArgument = Optional[Union[str, _NoneName]]
 
 _ServerDefaultArgument = Union[
index f71678a4ab41b4b6767fb3f50c7bec1754a8d6d0..7c656ab5a9d9fd49261da189c33dc7ace63056c1 100644 (file)
@@ -70,6 +70,7 @@ from ..util.typing import TupleAny
 
 if TYPE_CHECKING:
     from ._typing import _ColumnExpressionArgument
+    from ._typing import _CreateDropBind
     from ._typing import _TypeEngineArgument
     from .elements import ColumnElement
     from .operators import OperatorType
@@ -1179,21 +1180,21 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin):
         kw.setdefault("_adapted_from", self)
         return super().adapt(cls, **kw)
 
-    def create(self, bind, checkfirst=False):
+    def create(self, bind: _CreateDropBind, checkfirst: bool = False) -> None:
         """Issue CREATE DDL for this type, if applicable."""
 
         t = self.dialect_impl(bind.dialect)
         if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
             t.create(bind, checkfirst=checkfirst)
 
-    def drop(self, bind, checkfirst=False):
+    def drop(self, bind: _CreateDropBind, checkfirst: bool = False) -> None:
         """Issue DROP DDL for this type, if applicable."""
 
         t = self.dialect_impl(bind.dialect)
         if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
             t.drop(bind, checkfirst=checkfirst)
 
-    def _on_table_create(self, target, bind, **kw):
+    def _on_table_create(self, target, bind: _CreateDropBind, **kw: Any) -> None:
         if not self._is_impl_for_variant(bind.dialect, kw):
             return
 
@@ -1201,7 +1202,7 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin):
         if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
             t._on_table_create(target, bind, **kw)
 
-    def _on_table_drop(self, target, bind, **kw):
+    def _on_table_drop(self, target, bind: _CreateDropBind, **kw: Any) -> None:
         if not self._is_impl_for_variant(bind.dialect, kw):
             return
 
@@ -1209,7 +1210,7 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin):
         if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
             t._on_table_drop(target, bind, **kw)
 
-    def _on_metadata_create(self, target, bind, **kw):
+    def _on_metadata_create(self, target, bind: _CreateDropBind, **kw: Any) -> None:
         if not self._is_impl_for_variant(bind.dialect, kw):
             return
 
@@ -1217,7 +1218,7 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin):
         if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
             t._on_metadata_create(target, bind, **kw)
 
-    def _on_metadata_drop(self, target, bind, **kw):
+    def _on_metadata_drop(self, target, bind: _CreateDropBind, **kw: Any) -> None:
         if not self._is_impl_for_variant(bind.dialect, kw):
             return
 
@@ -1225,7 +1226,7 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin):
         if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
             t._on_metadata_drop(target, bind, **kw)
 
-    def _is_impl_for_variant(self, dialect, kw):
+    def _is_impl_for_variant(self, dialect: Dialect, **kw: Any) -> Optional[bool]:
         variant_mapping = kw.pop("variant_mapping", None)
 
         if not variant_mapping: