From: Justine Krejcha Date: Tue, 6 May 2025 19:18:02 +0000 (-0400) Subject: typing: pg: type NamedType create/drops (fixes #12557) X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b4d7bf7a2f74db73e12f47ca4cb45666bf08439e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git typing: pg: type NamedType create/drops (fixes #12557) Type the `create` and `drop` functions for `NamedType`s Also partially type the SchemaType create/drop functions more generally One change to this is that the default parameter of `None` is removed. It doesn't work and will fail with a `AttributeError` at runtime since it immediately tries to access a property of `None` which doesn't exist. Fixes #12557 This pull request is: - [X] A documentation / typographical / small typing error fix - Good to go, no issue or tests are needed - [X] A short code fix - please include the issue number, and create an issue if none exists, which must include a complete example of the issue. one line code fixes without an issue and demonstration will not be accepted. - Please include: `Fixes: #` in the commit message - please include tests. one line code fixes without tests will not be accepted. - [ ] A new feature implementation - please include the issue number, and create an issue if none exists, which must include a complete example of how the feature would look. - Please include: `Fixes: #` in the commit message - please include tests. **Have a nice day!** Closes: #12558 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12558 Pull-request-sha: 75c8d81bfb68f45299a9448d45dda446532205d3 Change-Id: I173771d365f34f54ab474b9661e1cdc70cc4de84 --- diff --git a/lib/sqlalchemy/dialects/postgresql/named_types.py b/lib/sqlalchemy/dialects/postgresql/named_types.py index e1b8e84ce8..c9d6e5844c 100644 --- a/lib/sqlalchemy/dialects/postgresql/named_types.py +++ b/lib/sqlalchemy/dialects/postgresql/named_types.py @@ -7,7 +7,9 @@ # mypy: ignore-errors from __future__ import annotations +from types import ModuleType from typing import Any +from typing import Dict from typing import Optional from typing import Type from typing import TYPE_CHECKING @@ -25,10 +27,11 @@ from ...sql.ddl import InvokeCreateDDLBase from ...sql.ddl import InvokeDropDDLBase if TYPE_CHECKING: + from ...sql._typing import _CreateDropBind from ...sql._typing import _TypeEngineArgument -class NamedType(sqltypes.TypeEngine): +class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine): """Base for named types.""" __abstract__ = True @@ -36,7 +39,9 @@ class NamedType(sqltypes.TypeEngine): DDLDropper: Type[NamedTypeDropper] create_type: bool - def create(self, bind, checkfirst=True, **kw): + def create( + self, bind: _CreateDropBind, checkfirst: bool = True, **kw: Any + ) -> None: """Emit ``CREATE`` DDL for this type. :param bind: a connectable :class:`_engine.Engine`, @@ -50,7 +55,9 @@ class NamedType(sqltypes.TypeEngine): """ bind._run_ddl_visitor(self.DDLGenerator, self, checkfirst=checkfirst) - def drop(self, bind, checkfirst=True, **kw): + def drop( + self, bind: _CreateDropBind, checkfirst: bool = True, **kw: Any + ) -> None: """Emit ``DROP`` DDL for this type. :param bind: a connectable :class:`_engine.Engine`, @@ -63,7 +70,9 @@ class NamedType(sqltypes.TypeEngine): """ bind._run_ddl_visitor(self.DDLDropper, self, checkfirst=checkfirst) - def _check_for_name_in_memos(self, checkfirst, kw): + def _check_for_name_in_memos( + self, checkfirst: bool, kw: Dict[str, Any] + ) -> bool: """Look in the 'ddl runner' for 'memos', then note our name in that collection. @@ -87,7 +96,13 @@ class NamedType(sqltypes.TypeEngine): else: return False - def _on_table_create(self, target, bind, checkfirst=False, **kw): + def _on_table_create( + self, + target: Any, + bind: _CreateDropBind, + checkfirst: bool = False, + **kw: Any, + ) -> None: if ( checkfirst or ( @@ -97,7 +112,13 @@ class NamedType(sqltypes.TypeEngine): ) and not self._check_for_name_in_memos(checkfirst, kw): self.create(bind=bind, checkfirst=checkfirst) - def _on_table_drop(self, target, bind, checkfirst=False, **kw): + def _on_table_drop( + self, + target: Any, + bind: _CreateDropBind, + checkfirst: bool = False, + **kw: Any, + ) -> None: if ( not self.metadata and not kw.get("_is_metadata_operation", False) @@ -105,11 +126,23 @@ class NamedType(sqltypes.TypeEngine): ): self.drop(bind=bind, checkfirst=checkfirst) - def _on_metadata_create(self, target, bind, checkfirst=False, **kw): + def _on_metadata_create( + self, + target: Any, + bind: _CreateDropBind, + checkfirst: bool = False, + **kw: Any, + ) -> None: if not self._check_for_name_in_memos(checkfirst, kw): self.create(bind=bind, checkfirst=checkfirst) - def _on_metadata_drop(self, target, bind, checkfirst=False, **kw): + def _on_metadata_drop( + self, + target: Any, + bind: _CreateDropBind, + checkfirst: bool = False, + **kw: Any, + ) -> None: if not self._check_for_name_in_memos(checkfirst, kw): self.drop(bind=bind, checkfirst=checkfirst) @@ -314,7 +347,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum): return cls(**kw) - def create(self, bind=None, checkfirst=True): + def create(self, bind: _CreateDropBind, checkfirst: bool = True) -> None: """Emit ``CREATE TYPE`` for this :class:`_postgresql.ENUM`. @@ -335,7 +368,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum): super().create(bind, checkfirst=checkfirst) - def drop(self, bind=None, checkfirst=True): + def drop(self, bind: _CreateDropBind, checkfirst: bool = True) -> None: """Emit ``DROP TYPE`` for this :class:`_postgresql.ENUM`. @@ -355,7 +388,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum): super().drop(bind, checkfirst=checkfirst) - def get_dbapi_type(self, dbapi): + def get_dbapi_type(self, dbapi: ModuleType) -> None: """dont return dbapi.STRING for ENUM in PostgreSQL, since that's a different type""" diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 5b5339036b..5e562bcb13 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -73,12 +73,11 @@ if typing.TYPE_CHECKING: from ..sql._typing import _InfoType from ..sql.compiler import Compiled from ..sql.ddl import ExecutableDDLElement - from ..sql.ddl import SchemaDropper - from ..sql.ddl import SchemaGenerator + from ..sql.ddl import InvokeDDLBase from ..sql.functions import FunctionElement from ..sql.schema import DefaultGenerator from ..sql.schema import HasSchemaAttr - from ..sql.schema import SchemaItem + from ..sql.schema import SchemaVisitable from ..sql.selectable import TypedReturnsRows @@ -2450,8 +2449,8 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): def _run_ddl_visitor( self, - visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], - element: SchemaItem, + visitorcallable: Type[InvokeDDLBase], + element: SchemaVisitable, **kwargs: Any, ) -> None: """run a DDL visitor. @@ -2460,7 +2459,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): options given to the visitor so that "checkfirst" is skipped. """ - visitorcallable(self.dialect, self, **kwargs).traverse_single(element) + visitorcallable( + dialect=self.dialect, connection=self, **kwargs + ).traverse_single(element) class ExceptionContextImpl(ExceptionContext): @@ -3246,8 +3247,8 @@ class Engine( def _run_ddl_visitor( self, - visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], - element: SchemaItem, + visitorcallable: Type[InvokeDDLBase], + element: SchemaVisitable, **kwargs: Any, ) -> None: with self.begin() as conn: diff --git a/lib/sqlalchemy/engine/mock.py b/lib/sqlalchemy/engine/mock.py index 08dba5a645..a96af36ccd 100644 --- a/lib/sqlalchemy/engine/mock.py +++ b/lib/sqlalchemy/engine/mock.py @@ -27,10 +27,9 @@ if typing.TYPE_CHECKING: from .interfaces import Dialect from .url import URL from ..sql.base import Executable - from ..sql.ddl import SchemaDropper - from ..sql.ddl import SchemaGenerator + from ..sql.ddl import InvokeDDLBase from ..sql.schema import HasSchemaAttr - from ..sql.schema import SchemaItem + from ..sql.visitors import Visitable class MockConnection: @@ -53,12 +52,14 @@ class MockConnection: def _run_ddl_visitor( self, - visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], - element: SchemaItem, + visitorcallable: Type[InvokeDDLBase], + element: Visitable, **kwargs: Any, ) -> None: kwargs["checkfirst"] = False - visitorcallable(self.dialect, self, **kwargs).traverse_single(element) + visitorcallable( + dialect=self.dialect, connection=self, **kwargs + ).traverse_single(element) def execute( self, diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 32adc9bb21..16f7ec37b3 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -65,6 +65,7 @@ from .sql.schema import MetaData as MetaData from .sql.schema import PrimaryKeyConstraint as PrimaryKeyConstraint from .sql.schema import SchemaConst as SchemaConst from .sql.schema import SchemaItem as SchemaItem +from .sql.schema import SchemaVisitable as SchemaVisitable from .sql.schema import Sequence as Sequence from .sql.schema import Table as Table from .sql.schema import UniqueConstraint as UniqueConstraint diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 6fef1766c6..eb5d09ec2d 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -72,7 +72,10 @@ if TYPE_CHECKING: from .sqltypes import TableValueType from .sqltypes import TupleType from .type_api import TypeEngine + from ..engine import Connection from ..engine import Dialect + from ..engine import Engine + from ..engine.mock import MockConnection from ..util.typing import TypeGuard _T = TypeVar("_T", bound=Any) @@ -304,6 +307,8 @@ _LimitOffsetType = Union[int, _ColumnExpressionArgument[int], None] _AutoIncrementType = Union[bool, Literal["auto", "ignore_fk"]] +_CreateDropBind = Union["Engine", "Connection", "MockConnection"] + if TYPE_CHECKING: def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]: ... diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 38eea2d772..e4279964a0 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -1540,8 +1540,19 @@ class SchemaEventTarget(event.EventTarget): self.dispatch.after_parent_attach(self, parent) +class SchemaVisitable(SchemaEventTarget, visitors.Visitable): + """Base class for elements that are targets of a :class:`.SchemaVisitor`. + + .. versionadded:: 2.0.41 + + """ + + class SchemaVisitor(ClauseVisitor): - """Define the visiting for ``SchemaItem`` objects.""" + """Define the visiting for ``SchemaItem`` and more + generally ``SchemaVisitable`` objects. + + """ __traverse_options__ = {"schema_visitor": True} diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index e96dfea2ba..8748c7c7be 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -865,8 +865,9 @@ class DropConstraintComment(_CreateDropBase["Constraint"]): class InvokeDDLBase(SchemaVisitor): - def __init__(self, connection): + def __init__(self, connection, **kw): self.connection = connection + assert not kw, f"Unexpected keywords: {kw.keys()}" @contextlib.contextmanager def with_ddl_events(self, target, **kw): diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 77047f10b6..7f5f5e346e 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -71,6 +71,7 @@ from .base import DedupeColumnCollection from .base import DialectKWArgs from .base import Executable from .base import SchemaEventTarget as SchemaEventTarget +from .base import SchemaVisitable as SchemaVisitable from .coercions import _document_text_coercion from .elements import ClauseElement from .elements import ColumnClause @@ -91,6 +92,7 @@ from ..util.typing import TypeGuard if typing.TYPE_CHECKING: from ._typing import _AutoIncrementType + from ._typing import _CreateDropBind from ._typing import _DDLColumnArgument from ._typing import _DDLColumnReferenceArgument from ._typing import _InfoType @@ -109,7 +111,6 @@ if typing.TYPE_CHECKING: from ..engine.interfaces import _CoreMultiExecuteParams from ..engine.interfaces import CoreExecuteOptionsParameter from ..engine.interfaces import ExecutionContext - from ..engine.mock import MockConnection from ..engine.reflection import _ReflectionInfo from ..sql.selectable import FromClause @@ -118,8 +119,6 @@ _SI = TypeVar("_SI", bound="SchemaItem") _TAB = TypeVar("_TAB", bound="Table") -_CreateDropBind = Union["Engine", "Connection", "MockConnection"] - _ConstraintNameArgument = Optional[Union[str, _NoneName]] _ServerDefaultArgument = Union[ @@ -213,7 +212,7 @@ def _copy_expression( @inspection._self_inspects -class SchemaItem(SchemaEventTarget, visitors.Visitable): +class SchemaItem(SchemaVisitable): """Base class for items that define a database schema.""" __visit_name__ = "schema_item" diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index f71678a4ab..90c93bcef1 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -70,6 +70,7 @@ from ..util.typing import TupleAny if TYPE_CHECKING: from ._typing import _ColumnExpressionArgument + from ._typing import _CreateDropBind from ._typing import _TypeEngineArgument from .elements import ColumnElement from .operators import OperatorType @@ -1179,21 +1180,23 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin): kw.setdefault("_adapted_from", self) return super().adapt(cls, **kw) - def create(self, bind, checkfirst=False): + def create(self, bind: _CreateDropBind, checkfirst: bool = False) -> None: """Issue CREATE DDL for this type, if applicable.""" t = self.dialect_impl(bind.dialect) if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t.create(bind, checkfirst=checkfirst) - def drop(self, bind, checkfirst=False): + def drop(self, bind: _CreateDropBind, checkfirst: bool = False) -> None: """Issue DROP DDL for this type, if applicable.""" t = self.dialect_impl(bind.dialect) if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t.drop(bind, checkfirst=checkfirst) - def _on_table_create(self, target, bind, **kw): + def _on_table_create( + self, target: Any, bind: _CreateDropBind, **kw: Any + ) -> None: if not self._is_impl_for_variant(bind.dialect, kw): return @@ -1201,7 +1204,9 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin): if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t._on_table_create(target, bind, **kw) - def _on_table_drop(self, target, bind, **kw): + def _on_table_drop( + self, target: Any, bind: _CreateDropBind, **kw: Any + ) -> None: if not self._is_impl_for_variant(bind.dialect, kw): return @@ -1209,7 +1214,9 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin): if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t._on_table_drop(target, bind, **kw) - def _on_metadata_create(self, target, bind, **kw): + def _on_metadata_create( + self, target: Any, bind: _CreateDropBind, **kw: Any + ) -> None: if not self._is_impl_for_variant(bind.dialect, kw): return @@ -1217,7 +1224,9 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin): if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t._on_metadata_create(target, bind, **kw) - def _on_metadata_drop(self, target, bind, **kw): + def _on_metadata_drop( + self, target: Any, bind: _CreateDropBind, **kw: Any + ) -> None: if not self._is_impl_for_variant(bind.dialect, kw): return @@ -1225,7 +1234,9 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin): if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t._on_metadata_drop(target, bind, **kw) - def _is_impl_for_variant(self, dialect, kw): + def _is_impl_for_variant( + self, dialect: Dialect, kw: Dict[str, Any] + ) -> Optional[bool]: variant_mapping = kw.pop("variant_mapping", None) if not variant_mapping: @@ -1242,7 +1253,7 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin): # since PostgreSQL is the only DB that has ARRAY this can only # be integration tested by PG-specific tests - def _we_are_the_impl(typ): + def _we_are_the_impl(typ: SchemaType) -> bool: return ( typ is self or isinstance(typ, ARRAY) @@ -1255,6 +1266,8 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin): return True elif dialect.name not in variant_mapping: return _we_are_the_impl(variant_mapping["_default"]) + else: + return None _EnumTupleArg = Union[Sequence[enum.Enum], Sequence[str]] diff --git a/test/sql/test_types.py b/test/sql/test_types.py index e6e2a18f16..eb4b420129 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -298,6 +298,7 @@ class AdaptTest(fixtures.TestBase): "schema", "metadata", "name", + "dispatch", ): continue # assert each value was copied, or that