]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
typing: pg: type NamedType create/drops (fixes #12557)
authorJustine Krejcha <justine@justinekrejcha.com>
Tue, 6 May 2025 19:18:02 +0000 (15:18 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 8 May 2025 01:40:18 +0000 (21:40 -0400)
Type the `create` and `drop` functions for `NamedType`s

Also partially type the SchemaType create/drop functions more generally

One change to this is that the default parameter of `None` is removed. It doesn't work and will fail with a `AttributeError` at runtime since it immediately tries to access a property of `None` which doesn't exist.

Fixes #12557

This pull request is:

- [X] A documentation / typographical / small typing error fix
- Good to go, no issue or tests are needed
- [X] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [ ] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

**Have a nice day!**

Closes: #12558
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12558
Pull-request-sha: 75c8d81bfb68f45299a9448d45dda446532205d3

Change-Id: I173771d365f34f54ab474b9661e1cdc70cc4de84

lib/sqlalchemy/dialects/postgresql/named_types.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/mock.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/ddl.py
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/sql/sqltypes.py
test/sql/test_types.py

index e1b8e84ce858807e7377b4be7a5a9244c270b747..c9d6e5844cf117ad1c346a2b7540f539d40d0235 100644 (file)
@@ -7,7 +7,9 @@
 # mypy: ignore-errors
 from __future__ import annotations
 
+from types import ModuleType
 from typing import Any
+from typing import Dict
 from typing import Optional
 from typing import Type
 from typing import TYPE_CHECKING
@@ -25,10 +27,11 @@ from ...sql.ddl import InvokeCreateDDLBase
 from ...sql.ddl import InvokeDropDDLBase
 
 if TYPE_CHECKING:
+    from ...sql._typing import _CreateDropBind
     from ...sql._typing import _TypeEngineArgument
 
 
-class NamedType(sqltypes.TypeEngine):
+class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine):
     """Base for named types."""
 
     __abstract__ = True
@@ -36,7 +39,9 @@ class NamedType(sqltypes.TypeEngine):
     DDLDropper: Type[NamedTypeDropper]
     create_type: bool
 
-    def create(self, bind, checkfirst=True, **kw):
+    def create(
+        self, bind: _CreateDropBind, checkfirst: bool = True, **kw: Any
+    ) -> None:
         """Emit ``CREATE`` DDL for this type.
 
         :param bind: a connectable :class:`_engine.Engine`,
@@ -50,7 +55,9 @@ class NamedType(sqltypes.TypeEngine):
         """
         bind._run_ddl_visitor(self.DDLGenerator, self, checkfirst=checkfirst)
 
-    def drop(self, bind, checkfirst=True, **kw):
+    def drop(
+        self, bind: _CreateDropBind, checkfirst: bool = True, **kw: Any
+    ) -> None:
         """Emit ``DROP`` DDL for this type.
 
         :param bind: a connectable :class:`_engine.Engine`,
@@ -63,7 +70,9 @@ class NamedType(sqltypes.TypeEngine):
         """
         bind._run_ddl_visitor(self.DDLDropper, self, checkfirst=checkfirst)
 
-    def _check_for_name_in_memos(self, checkfirst, kw):
+    def _check_for_name_in_memos(
+        self, checkfirst: bool, kw: Dict[str, Any]
+    ) -> bool:
         """Look in the 'ddl runner' for 'memos', then
         note our name in that collection.
 
@@ -87,7 +96,13 @@ class NamedType(sqltypes.TypeEngine):
         else:
             return False
 
-    def _on_table_create(self, target, bind, checkfirst=False, **kw):
+    def _on_table_create(
+        self,
+        target: Any,
+        bind: _CreateDropBind,
+        checkfirst: bool = False,
+        **kw: Any,
+    ) -> None:
         if (
             checkfirst
             or (
@@ -97,7 +112,13 @@ class NamedType(sqltypes.TypeEngine):
         ) and not self._check_for_name_in_memos(checkfirst, kw):
             self.create(bind=bind, checkfirst=checkfirst)
 
-    def _on_table_drop(self, target, bind, checkfirst=False, **kw):
+    def _on_table_drop(
+        self,
+        target: Any,
+        bind: _CreateDropBind,
+        checkfirst: bool = False,
+        **kw: Any,
+    ) -> None:
         if (
             not self.metadata
             and not kw.get("_is_metadata_operation", False)
@@ -105,11 +126,23 @@ class NamedType(sqltypes.TypeEngine):
         ):
             self.drop(bind=bind, checkfirst=checkfirst)
 
-    def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
+    def _on_metadata_create(
+        self,
+        target: Any,
+        bind: _CreateDropBind,
+        checkfirst: bool = False,
+        **kw: Any,
+    ) -> None:
         if not self._check_for_name_in_memos(checkfirst, kw):
             self.create(bind=bind, checkfirst=checkfirst)
 
-    def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
+    def _on_metadata_drop(
+        self,
+        target: Any,
+        bind: _CreateDropBind,
+        checkfirst: bool = False,
+        **kw: Any,
+    ) -> None:
         if not self._check_for_name_in_memos(checkfirst, kw):
             self.drop(bind=bind, checkfirst=checkfirst)
 
@@ -314,7 +347,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
 
         return cls(**kw)
 
-    def create(self, bind=None, checkfirst=True):
+    def create(self, bind: _CreateDropBind, checkfirst: bool = True) -> None:
         """Emit ``CREATE TYPE`` for this
         :class:`_postgresql.ENUM`.
 
@@ -335,7 +368,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
 
         super().create(bind, checkfirst=checkfirst)
 
-    def drop(self, bind=None, checkfirst=True):
+    def drop(self, bind: _CreateDropBind, checkfirst: bool = True) -> None:
         """Emit ``DROP TYPE`` for this
         :class:`_postgresql.ENUM`.
 
@@ -355,7 +388,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
 
         super().drop(bind, checkfirst=checkfirst)
 
-    def get_dbapi_type(self, dbapi):
+    def get_dbapi_type(self, dbapi: ModuleType) -> None:
         """dont return dbapi.STRING for ENUM in PostgreSQL, since that's
         a different type"""
 
index 5b5339036bb9353a77edc7589fca413465cf57b7..5e562bcb1384bfa8419d7c315e2a0404c093aae5 100644 (file)
@@ -73,12 +73,11 @@ if typing.TYPE_CHECKING:
     from ..sql._typing import _InfoType
     from ..sql.compiler import Compiled
     from ..sql.ddl import ExecutableDDLElement
-    from ..sql.ddl import SchemaDropper
-    from ..sql.ddl import SchemaGenerator
+    from ..sql.ddl import InvokeDDLBase
     from ..sql.functions import FunctionElement
     from ..sql.schema import DefaultGenerator
     from ..sql.schema import HasSchemaAttr
-    from ..sql.schema import SchemaItem
+    from ..sql.schema import SchemaVisitable
     from ..sql.selectable import TypedReturnsRows
 
 
@@ -2450,8 +2449,8 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
 
     def _run_ddl_visitor(
         self,
-        visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
-        element: SchemaItem,
+        visitorcallable: Type[InvokeDDLBase],
+        element: SchemaVisitable,
         **kwargs: Any,
     ) -> None:
         """run a DDL visitor.
@@ -2460,7 +2459,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         options given to the visitor so that "checkfirst" is skipped.
 
         """
-        visitorcallable(self.dialect, self, **kwargs).traverse_single(element)
+        visitorcallable(
+            dialect=self.dialect, connection=self, **kwargs
+        ).traverse_single(element)
 
 
 class ExceptionContextImpl(ExceptionContext):
@@ -3246,8 +3247,8 @@ class Engine(
 
     def _run_ddl_visitor(
         self,
-        visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
-        element: SchemaItem,
+        visitorcallable: Type[InvokeDDLBase],
+        element: SchemaVisitable,
         **kwargs: Any,
     ) -> None:
         with self.begin() as conn:
index 08dba5a6456de92677e477f66c063dde94995a2b..a96af36ccda5b530c688bef08d04e9c9713c053f 100644 (file)
@@ -27,10 +27,9 @@ if typing.TYPE_CHECKING:
     from .interfaces import Dialect
     from .url import URL
     from ..sql.base import Executable
-    from ..sql.ddl import SchemaDropper
-    from ..sql.ddl import SchemaGenerator
+    from ..sql.ddl import InvokeDDLBase
     from ..sql.schema import HasSchemaAttr
-    from ..sql.schema import SchemaItem
+    from ..sql.visitors import Visitable
 
 
 class MockConnection:
@@ -53,12 +52,14 @@ class MockConnection:
 
     def _run_ddl_visitor(
         self,
-        visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
-        element: SchemaItem,
+        visitorcallable: Type[InvokeDDLBase],
+        element: Visitable,
         **kwargs: Any,
     ) -> None:
         kwargs["checkfirst"] = False
-        visitorcallable(self.dialect, self, **kwargs).traverse_single(element)
+        visitorcallable(
+            dialect=self.dialect, connection=self, **kwargs
+        ).traverse_single(element)
 
     def execute(
         self,
index 32adc9bb218286fed3688c5ea0be8bbcfbf3e0de..16f7ec37b3c539aa6ef8a9760658f59f78874eea 100644 (file)
@@ -65,6 +65,7 @@ from .sql.schema import MetaData as MetaData
 from .sql.schema import PrimaryKeyConstraint as PrimaryKeyConstraint
 from .sql.schema import SchemaConst as SchemaConst
 from .sql.schema import SchemaItem as SchemaItem
+from .sql.schema import SchemaVisitable as SchemaVisitable
 from .sql.schema import Sequence as Sequence
 from .sql.schema import Table as Table
 from .sql.schema import UniqueConstraint as UniqueConstraint
index 6fef1766c6df03ec3826dc8015dd3258f538f371..eb5d09ec2dac8fe56aef38fe48ac1701c353e1d8 100644 (file)
@@ -72,7 +72,10 @@ if TYPE_CHECKING:
     from .sqltypes import TableValueType
     from .sqltypes import TupleType
     from .type_api import TypeEngine
+    from ..engine import Connection
     from ..engine import Dialect
+    from ..engine import Engine
+    from ..engine.mock import MockConnection
     from ..util.typing import TypeGuard
 
 _T = TypeVar("_T", bound=Any)
@@ -304,6 +307,8 @@ _LimitOffsetType = Union[int, _ColumnExpressionArgument[int], None]
 
 _AutoIncrementType = Union[bool, Literal["auto", "ignore_fk"]]
 
+_CreateDropBind = Union["Engine", "Connection", "MockConnection"]
+
 if TYPE_CHECKING:
 
     def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]: ...
index 38eea2d772df5dabfcd88ca59d5b976d90e448f2..e4279964a05ad5305f03cdacd4de8c378419816c 100644 (file)
@@ -1540,8 +1540,19 @@ class SchemaEventTarget(event.EventTarget):
         self.dispatch.after_parent_attach(self, parent)
 
 
+class SchemaVisitable(SchemaEventTarget, visitors.Visitable):
+    """Base class for elements that are targets of a :class:`.SchemaVisitor`.
+
+    .. versionadded:: 2.0.41
+
+    """
+
+
 class SchemaVisitor(ClauseVisitor):
-    """Define the visiting for ``SchemaItem`` objects."""
+    """Define the visiting for ``SchemaItem`` and more
+    generally ``SchemaVisitable`` objects.
+
+    """
 
     __traverse_options__ = {"schema_visitor": True}
 
index e96dfea2bab6d96b3311ee8cdc24fa6b9201cc78..8748c7c7be818bd157fac7b0176dbba3358add30 100644 (file)
@@ -865,8 +865,9 @@ class DropConstraintComment(_CreateDropBase["Constraint"]):
 
 
 class InvokeDDLBase(SchemaVisitor):
-    def __init__(self, connection):
+    def __init__(self, connection, **kw):
         self.connection = connection
+        assert not kw, f"Unexpected keywords: {kw.keys()}"
 
     @contextlib.contextmanager
     def with_ddl_events(self, target, **kw):
index 77047f10b633de22708654b42e2c343531c6f843..7f5f5e346ec4eec3668c370b1e6b37f62e9fe51f 100644 (file)
@@ -71,6 +71,7 @@ from .base import DedupeColumnCollection
 from .base import DialectKWArgs
 from .base import Executable
 from .base import SchemaEventTarget as SchemaEventTarget
+from .base import SchemaVisitable as SchemaVisitable
 from .coercions import _document_text_coercion
 from .elements import ClauseElement
 from .elements import ColumnClause
@@ -91,6 +92,7 @@ from ..util.typing import TypeGuard
 
 if typing.TYPE_CHECKING:
     from ._typing import _AutoIncrementType
+    from ._typing import _CreateDropBind
     from ._typing import _DDLColumnArgument
     from ._typing import _DDLColumnReferenceArgument
     from ._typing import _InfoType
@@ -109,7 +111,6 @@ if typing.TYPE_CHECKING:
     from ..engine.interfaces import _CoreMultiExecuteParams
     from ..engine.interfaces import CoreExecuteOptionsParameter
     from ..engine.interfaces import ExecutionContext
-    from ..engine.mock import MockConnection
     from ..engine.reflection import _ReflectionInfo
     from ..sql.selectable import FromClause
 
@@ -118,8 +119,6 @@ _SI = TypeVar("_SI", bound="SchemaItem")
 _TAB = TypeVar("_TAB", bound="Table")
 
 
-_CreateDropBind = Union["Engine", "Connection", "MockConnection"]
-
 _ConstraintNameArgument = Optional[Union[str, _NoneName]]
 
 _ServerDefaultArgument = Union[
@@ -213,7 +212,7 @@ def _copy_expression(
 
 
 @inspection._self_inspects
-class SchemaItem(SchemaEventTarget, visitors.Visitable):
+class SchemaItem(SchemaVisitable):
     """Base class for items that define a database schema."""
 
     __visit_name__ = "schema_item"
index f71678a4ab41b4b6767fb3f50c7bec1754a8d6d0..90c93bcef1bb69f1789ebfeccd44cfddc01ecb37 100644 (file)
@@ -70,6 +70,7 @@ from ..util.typing import TupleAny
 
 if TYPE_CHECKING:
     from ._typing import _ColumnExpressionArgument
+    from ._typing import _CreateDropBind
     from ._typing import _TypeEngineArgument
     from .elements import ColumnElement
     from .operators import OperatorType
@@ -1179,21 +1180,23 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin):
         kw.setdefault("_adapted_from", self)
         return super().adapt(cls, **kw)
 
-    def create(self, bind, checkfirst=False):
+    def create(self, bind: _CreateDropBind, checkfirst: bool = False) -> None:
         """Issue CREATE DDL for this type, if applicable."""
 
         t = self.dialect_impl(bind.dialect)
         if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
             t.create(bind, checkfirst=checkfirst)
 
-    def drop(self, bind, checkfirst=False):
+    def drop(self, bind: _CreateDropBind, checkfirst: bool = False) -> None:
         """Issue DROP DDL for this type, if applicable."""
 
         t = self.dialect_impl(bind.dialect)
         if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
             t.drop(bind, checkfirst=checkfirst)
 
-    def _on_table_create(self, target, bind, **kw):
+    def _on_table_create(
+        self, target: Any, bind: _CreateDropBind, **kw: Any
+    ) -> None:
         if not self._is_impl_for_variant(bind.dialect, kw):
             return
 
@@ -1201,7 +1204,9 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin):
         if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
             t._on_table_create(target, bind, **kw)
 
-    def _on_table_drop(self, target, bind, **kw):
+    def _on_table_drop(
+        self, target: Any, bind: _CreateDropBind, **kw: Any
+    ) -> None:
         if not self._is_impl_for_variant(bind.dialect, kw):
             return
 
@@ -1209,7 +1214,9 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin):
         if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
             t._on_table_drop(target, bind, **kw)
 
-    def _on_metadata_create(self, target, bind, **kw):
+    def _on_metadata_create(
+        self, target: Any, bind: _CreateDropBind, **kw: Any
+    ) -> None:
         if not self._is_impl_for_variant(bind.dialect, kw):
             return
 
@@ -1217,7 +1224,9 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin):
         if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
             t._on_metadata_create(target, bind, **kw)
 
-    def _on_metadata_drop(self, target, bind, **kw):
+    def _on_metadata_drop(
+        self, target: Any, bind: _CreateDropBind, **kw: Any
+    ) -> None:
         if not self._is_impl_for_variant(bind.dialect, kw):
             return
 
@@ -1225,7 +1234,9 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin):
         if isinstance(t, SchemaType) and t.__class__ is not self.__class__:
             t._on_metadata_drop(target, bind, **kw)
 
-    def _is_impl_for_variant(self, dialect, kw):
+    def _is_impl_for_variant(
+        self, dialect: Dialect, kw: Dict[str, Any]
+    ) -> Optional[bool]:
         variant_mapping = kw.pop("variant_mapping", None)
 
         if not variant_mapping:
@@ -1242,7 +1253,7 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin):
 
         # since PostgreSQL is the only DB that has ARRAY this can only
         # be integration tested by PG-specific tests
-        def _we_are_the_impl(typ):
+        def _we_are_the_impl(typ: SchemaType) -> bool:
             return (
                 typ is self
                 or isinstance(typ, ARRAY)
@@ -1255,6 +1266,8 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin):
             return True
         elif dialect.name not in variant_mapping:
             return _we_are_the_impl(variant_mapping["_default"])
+        else:
+            return None
 
 
 _EnumTupleArg = Union[Sequence[enum.Enum], Sequence[str]]
index e6e2a18f160d7f7e9050b3ad7fcf01d2ab4aba4e..eb4b420129f63a7644bb883f56e966cfdcfc13fe 100644 (file)
@@ -298,6 +298,7 @@ class AdaptTest(fixtures.TestBase):
                     "schema",
                     "metadata",
                     "name",
+                    "dispatch",
                 ):
                     continue
                 # assert each value was copied, or that