From 7111dc0cbaae2070173e9f4a054745ee7d036dfa Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 31 Aug 2025 18:53:01 -0400 Subject: [PATCH] add RegistryEvents Added :class:`_orm.RegistryEvents` event class that allows event listeners to be established on a :class:`_orm.registry` object. The new class provides three events: :meth:`_orm.RegistryEvents.resolve_type_annotation` which allows customization of type annotation resolution that can supplement or replace the use of the :paramref:`.registry.type_annotation_map` dictionary, including that it can be helpful with custom resolution for complex types such as those of :pep:`695`, as well as :meth:`_orm.RegistryEvents.before_configured` and :meth:`_orm.RegistryEvents.after_configured`, which are registry-local forms of the mapper-wide version of these hooks. Fixes: #9832 Change-Id: I32b55de8625ec435edf916a91e65f61fda50cd51 --- doc/build/changelog/migration_21.rst | 46 ++++ doc/build/changelog/unreleased_21/9832.rst | 18 ++ doc/build/orm/declarative_tables.rst | 151 ++++++++++- doc/build/orm/events.rst | 16 ++ lib/sqlalchemy/orm/__init__.py | 2 + lib/sqlalchemy/orm/decl_api.py | 146 ++++++++++- lib/sqlalchemy/orm/events.py | 209 +++++++++++++++- lib/sqlalchemy/orm/mapper.py | 28 ++- lib/sqlalchemy/orm/properties.py | 49 ++-- lib/sqlalchemy/orm/util.py | 11 +- lib/sqlalchemy/sql/sqltypes.py | 2 +- lib/sqlalchemy/sql/type_api.py | 8 +- lib/sqlalchemy/util/typing.py | 26 +- test/orm/test_events.py | 278 +++++++++++++++++++-- 14 files changed, 915 insertions(+), 75 deletions(-) create mode 100644 doc/build/changelog/unreleased_21/9832.rst diff --git a/doc/build/changelog/migration_21.rst b/doc/build/changelog/migration_21.rst index 0e345f4e80..329cc42e77 100644 --- a/doc/build/changelog/migration_21.rst +++ b/doc/build/changelog/migration_21.rst @@ -443,6 +443,52 @@ Annotated Declarative setting from taking place. :ticket:`12570` +.. _change_9832: + +New RegistryEvents System for ORM Mapping Customization +-------------------------------------------------------- + +SQLAlchemy 2.1 introduces :class:`.RegistryEvents`, providing for event +hooks that are specific to a :class:`_orm.registry`. These events include +:meth:`_orm.RegistryEvents.before_configured` and :meth:`_orm.RegistryEvents.after_configured` +to complement the same-named events that can be established on a +:class:`_orm.Mapper`, as well as :meth:`_orm.RegistryEvents.resolve_type_annotation` +that allows programmatic access to the ORM Annotated Declarative type resolution +process. Examples are provided illustrating how to define resolution schemes +for any kind of type hierarchy in an automated fashion, including :pep:`695` +type aliases. + +E.g.:: + + from sqlalchemy import event + from sqlalchemy.orm import DeclarativeBase + + + class Base(DeclarativeBase): + pass + + + @event.listens_for(Base, "resolve_type_annotation") + def resolve_custom_type(resolve_type): + if resolve_type.primary_type is MyCustomType: + return MyCustomSQLType() + else: + return None + + + @event.listens_for(Base, "after_configured") + def after_base_configured(registry): + print(f"Registry {registry} fully configured") + +.. seealso:: + + :ref:`orm_declarative_resolve_type_event` - Complete documentation on using + the :meth:`.RegistryEvents.resolve_type_annotation` event + + :class:`.RegistryEvents` - Complete API reference for all registry events + +:ticket:`9832` + New Features and Improvements - Core ===================================== diff --git a/doc/build/changelog/unreleased_21/9832.rst b/doc/build/changelog/unreleased_21/9832.rst new file mode 100644 index 0000000000..2b894e30b7 --- /dev/null +++ b/doc/build/changelog/unreleased_21/9832.rst @@ -0,0 +1,18 @@ +.. change:: + :tags: feature, orm + :tickets: 9832 + + Added :class:`_orm.RegistryEvents` event class that allows event listeners + to be established on a :class:`_orm.registry` object. The new class + provides three events: :meth:`_orm.RegistryEvents.resolve_type_annotation` + which allows customization of type annotation resolution that can + supplement or replace the use of the + :paramref:`.registry.type_annotation_map` dictionary, including that it can + be helpful with custom resolution for complex types such as those of + :pep:`695`, as well as :meth:`_orm.RegistryEvents.before_configured` and + :meth:`_orm.RegistryEvents.after_configured`, which are registry-local + forms of the mapper-wide version of these hooks. + + .. seealso:: + + :ref:`change_9832` diff --git a/doc/build/orm/declarative_tables.rst b/doc/build/orm/declarative_tables.rst index e56ac8a51f..1a86117605 100644 --- a/doc/build/orm/declarative_tables.rst +++ b/doc/build/orm/declarative_tables.rst @@ -956,8 +956,8 @@ no other directive for nullability is present. .. _orm_declarative_mapped_column_type_map_pep593: -Mapping Multiple Type Configurations to Python Types -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Mapping Multiple Type Configurations to Python Types with pep-593 Annotated +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ As individual Python types may be associated with :class:`_types.TypeEngine` @@ -1064,8 +1064,8 @@ more open ended. .. _orm_declarative_mapped_column_pep593: -Mapping Whole Column Declarations to Python Types -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Mapping Whole Column Declarations to Python Types using pep-593 +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ The previous section illustrated using :pep:`593` ``Annotated`` type @@ -1578,7 +1578,150 @@ In the above configuration, the ``my_literal`` datatype will resolve to a :class:`._sqltypes.JSON` instance. Other ``Literal`` variants will continue to resolve to :class:`_sqltypes.Enum` datatypes. +.. _orm_declarative_resolve_type_event: + +Resolving Types Programmatically with Events +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. versionadded:: 2.1 + +The :paramref:`_orm.registry.type_annotation_map` is the usual +way to customize how :func:`_orm.mapped_column` types are assigned to Python +types. But for automation of whole classes of types or other custom rules, +the type map resolution can be augmented and/or replaced using the +:meth:`.RegistryEvents.resolve_type_annotation` hook. + +This event hook allows for dynamic type resolution that goes beyond the static +mappings possible with :paramref:`_orm.registry.type_annotation_map`. It's +particularly useful when working with generic types, complex type hierarchies, +or when you need to implement custom logic for determining SQL types based +on Python type annotations. + +Basic Type Resolution with :meth:`.RegistryEvents.resolve_type_annotation` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Basic type resolution can be set up by registering the event against +a :class:`_orm.registry` or :class:`_orm.DeclarativeBase` class. The event +receives several parameters that allow inspection of the type annotation +and provide hooks for custom resolution logic. + +The following example shows how to use the hook to resolve custom type aliases +to appropriate SQL types:: + + from __future__ import annotations + + from typing import Annotated, get_origin, get_args + from sqlalchemy import String, Integer, event + from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + + # Define some custom type aliases + UserId = int + Username = str + LongText = Annotated[str, "long"] + + + class Base(DeclarativeBase): + pass + + + @event.listens_for(Base.registry, "resolve_type_annotation") + def resolve_custom_types(resolve_type): + # Handle our custom type aliases + if resolve_type.primary_type is UserId: + return Integer + elif resolve_type.primary_type is Username: + return String(50) + elif resolve_type.pep_593_type: + inner_type, *metadata = get_args(resolve_type.primary_type) + if inner_type is str and "long" in metadata: + return String(1000) + + # Fall back to default resolution + return None + + + class User(Base): + __tablename__ = "user" + + id: Mapped[UserId] = mapped_column(primary_key=True) + name: Mapped[Username] + description: Mapped[LongText] + +In this example, the event handler checks for specific type aliases and +returns appropriate SQL types. When the handler returns ``None``, the +default type resolution logic is used. + +Programmatic Resolution of pep-695 and NewType types +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +As detailed in :ref:`orm_declarative_type_map_pep695_types`, SQLAlchemy now +automatically resolves :pep:`695` ``type`` aliases, but does not +automatically resolve types made using ``typing.NewType`` without +these types being explicitly present in :paramref:`_orm.registry.type_annotation_map`. + +The :meth:`.RegistryEvents.resolve_type_annotation` event provides a way +to programmatically handle these types. This is particularly useful when you have +many ``NewType`` instances that would be cumbersome +to list individually in the type annotation map:: + + from __future__ import annotations + + from typing import Annotated + from typing import NewType + + from sqlalchemy import event + from sqlalchemy import String + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + # Multiple NewType instances + IntPK = NewType("IntPK", int) + UserId = NewType("UserId", int) + ProductId = NewType("ProductId", int) + CategoryName = NewType("CategoryName", str) + + # PEP 695 type alias that recursively refers to another PEP 695 type + type OrderId = Annotated[IntPK, mapped_column(primary_key=True)] + + + class Base(DeclarativeBase): + pass + + + @event.listens_for(Base.registry, "resolve_type_annotation") + def resolve_newtype_and_pep695(resolve_type): + # Handle NewType instances by checking their supertype + if hasattr(resolve_type.primary_type, "__supertype__"): + supertype = resolve_type.primary_type.__supertype__ + if supertype is int: + # return default resolution for int + return resolve_type.resolve(int) + elif supertype is str: + return String(100) + + # detect nested pep-695 IntPK type + if resolve_type.primary_type is IntPK or resolve_type.pep_593_type is IntPK: + return resolve_type.resolve(int) + + return None + + + class Order(Base): + __tablename__ = "order" + + id: Mapped[OrderId] + user_id: Mapped[UserId] + product_id: Mapped[ProductId] + category_name: Mapped[CategoryName] + +This approach allows you to handle entire categories of types programmatically +rather than having to enumerate each one in the type annotation map. + + +.. seealso:: + :meth:`.RegistryEvents.resolve_type_annotation` .. _orm_imperative_table_configuration: diff --git a/doc/build/orm/events.rst b/doc/build/orm/events.rst index 1db1137e08..37e278df32 100644 --- a/doc/build/orm/events.rst +++ b/doc/build/orm/events.rst @@ -70,6 +70,22 @@ Types of things which occur at the :class:`_orm.Mapper` level include: .. autoclass:: sqlalchemy.orm.MapperEvents :members: +Registry Events +--------------- + +Registry event hooks indicate things happening in reference to a particular +:class:`_orm.registry`. These include configurational events +:meth:`_orm.RegistryEvents.before_configured` and +:meth:`_orm.RegistryEvents.after_configured`, as well as a hook to customize +type resolution :meth:`_orm.RegistryEvents.resolve_type_annotation`. + +.. autoclass:: sqlalchemy.orm.RegistryEvents + :members: + +.. autoclass:: sqlalchemy.orm.TypeResolve + :members: + + Instance Events --------------- diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 7771de47eb..54e92927d2 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -66,6 +66,7 @@ from .decl_api import has_inherited_table as has_inherited_table from .decl_api import MappedAsDataclass as MappedAsDataclass from .decl_api import registry as registry from .decl_api import synonym_for as synonym_for +from .decl_api import TypeResolve as TypeResolve from .decl_base import MappedClassProtocol as MappedClassProtocol from .descriptor_props import Composite as Composite from .descriptor_props import CompositeProperty as CompositeProperty @@ -77,6 +78,7 @@ from .events import InstanceEvents as InstanceEvents from .events import InstrumentationEvents as InstrumentationEvents from .events import MapperEvents as MapperEvents from .events import QueryEvents as QueryEvents +from .events import RegistryEvents as RegistryEvents from .events import SessionEvents as SessionEvents from .identity import IdentityMap as IdentityMap from .instrumentation import ClassManager as ClassManager diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index b99906ed61..ec63ca84e3 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -63,6 +63,8 @@ from .state import InstanceState from .. import exc from .. import inspection from .. import util +from ..event import dispatcher +from ..event import EventTarget from ..sql import sqltypes from ..sql.base import _NoArg from ..sql.elements import SQLCoreOperations @@ -86,7 +88,7 @@ if TYPE_CHECKING: from .interfaces import MapperProperty from .state import InstanceState # noqa from ..sql._typing import _TypeEngineArgument - from ..sql.type_api import _MatchedOnType + from ..util.typing import _MatchedOnType _T = TypeVar("_T", bound=Any) @@ -1103,7 +1105,7 @@ def declarative_base( ) -class registry: +class registry(EventTarget): """Generalized registry for mapping classes. The :class:`_orm.registry` serves as the basis for maintaining a collection @@ -1144,6 +1146,7 @@ class registry: _dependents: Set[_RegistryType] _dependencies: Set[_RegistryType] _new_mappers: bool + dispatch: dispatcher["registry"] def __init__( self, @@ -1226,6 +1229,53 @@ class registry: } ) + def _resolve_type_with_events( + self, + cls: Any, + key: str, + raw_annotation: _MatchedOnType, + extracted_type: _MatchedOnType, + *, + pep_593_type: Optional[_MatchedOnType] = None, + pep_695_type: Optional[_MatchedOnType] = None, + ) -> Optional[sqltypes.TypeEngine[Any]]: + """Resolve type with event support for custom type mapping. + + This method fires the resolve_type_annotation event first to allow + custom resolution, then falls back to normal resolution. + + """ + + if self.dispatch.resolve_type_annotation: + type_resolve = TypeResolve( + self, + cls, + key, + raw_annotation, + extracted_type, + pep_593_type, + pep_695_type, + ) + + for fn in self.dispatch.resolve_type_annotation: + result = fn(type_resolve) + if result is not None: + return result # type: ignore[no-any-return] + + if pep_695_type is not None: + sqltype = self._resolve_type(pep_695_type) + if sqltype is not None: + return sqltype + + sqltype = self._resolve_type(extracted_type) + if sqltype is not None: + return sqltype + + if pep_593_type is not None: + sqltype = self._resolve_type(pep_593_type) + + return sqltype + def _resolve_type( self, python_type: _MatchedOnType ) -> Optional[sqltypes.TypeEngine[Any]]: @@ -1815,6 +1865,98 @@ if not TYPE_CHECKING: _RegistryType = registry # noqa +class TypeResolve: + """Primary argument to the :meth:`.RegistryEvents.resolve_type_annotation` + event. + + This object contains all the information needed to resolve a Python + type to a SQLAlchemy type. The :attr:`.TypeResolve.primary_type` is + typically the main type that's resolved. To resolve an arbitrary + Python type against the current type map, the :meth:`.TypeResolve.resolve` + method may be used. + + .. versionadded:: 2.1 + + """ + + __slots__ = ( + "registry", + "cls", + "key", + "raw_type", + "primary_type", + "pep_593_type", + "pep_695_type", + ) + + cls: Any + "The class being processed during declarative mapping" + + registry: "registry" + "The :class:`registry` being used" + + key: str + "String name of the ORM mapped attribute being processed" + + raw_type: _MatchedOnType + """The type annotation object directly from the attribute's annotations. + + It's recommended to look at :attr:`.TypeResolve.primary_type` or + one of :attr:`.TypeResolve.pep_593_type` or + :attr:`.TypeResolve.pep_695_type` rather than the raw type, as the raw + type will not be de-optionalized. + + """ + + primary_type: _MatchedOnType + """The primary located type annotation within the raw annotation, which + will be a de-optionalized, :pep:`695` resolved form of the original type + """ + + pep_593_type: Optional[_MatchedOnType] + """The type extracted from a :pep:`593` ``Annotated`` construct, if the + type referred to one.""" + + pep_695_type: Optional[_MatchedOnType] + "The de-optionalized :pep:`695` type, if the raw type referred to one." + + def __init__( + self, + registry: RegistryType, + cls: Any, + key: str, + raw_type: _MatchedOnType, + primary_type: _MatchedOnType, + pep_593_type: Optional[_MatchedOnType], + pep_695_type: Optional[_MatchedOnType], + ): + self.registry = registry + self.cls = cls + self.key = key + self.raw_type = raw_type + self.primary_type = primary_type + self.pep_593_type = pep_593_type + self.pep_695_type = pep_695_type + + def resolve( + self, python_type: _MatchedOnType + ) -> Optional[sqltypes.TypeEngine[Any]]: + """Resolve the given python type using the type_annotation_map of + the :class:`registry`. + + :param python_type: a Python type (e.g. ``int``, ``str``, etc.) Any + type object that's present in + :paramref:`_orm.registry_type_annotation_map` should produce a + non-``None`` result. + :return: a SQLAlchemy :class:`.TypeEngine` instance + (e.g. :class:`.Integer`, + :class:`.String`, etc.), or ``None`` to indicate no type could be + matched. + + """ + return self.registry._resolve_type(python_type) + + def as_declarative(**kw: Any) -> Callable[[Type[_T]], Type[_T]]: """ Class decorator which will adapt a given class into a diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index b915cdfec8..a8b180955f 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -23,6 +23,7 @@ from typing import TypeVar from typing import Union import weakref +from . import decl_api from . import instrumentation from . import interfaces from . import mapperlib @@ -64,6 +65,7 @@ if TYPE_CHECKING: from ..orm.context import QueryContext from ..orm.decl_api import DeclarativeAttributeIntercept from ..orm.decl_api import DeclarativeMeta + from ..orm.decl_api import registry from ..orm.mapper import Mapper from ..orm.state import InstanceState @@ -813,7 +815,14 @@ class MapperEvents(event.Events[mapperlib.Mapper[Any]]): "event target, use the 'sqlalchemy.orm.Mapper' class.", "2.0", ) - return mapperlib.Mapper + target = mapperlib.Mapper + + if identifier in ("before_configured", "after_configured"): + if target is mapperlib.Mapper: + return target + else: + return None + elif isinstance(target, type): if issubclass(target, mapperlib.Mapper): return target @@ -841,16 +850,6 @@ class MapperEvents(event.Events[mapperlib.Mapper[Any]]): event_key._listen_fn, ) - if ( - identifier in ("before_configured", "after_configured") - and target is not mapperlib.Mapper - ): - util.warn( - "'before_configured' and 'after_configured' ORM events " - "only invoke with the Mapper class " - "as the target." - ) - if not raw or not retval: if not raw: meth = getattr(cls, identifier) @@ -999,6 +998,10 @@ class MapperEvents(event.Events[mapperlib.Mapper[Any]]): :meth:`.MapperEvents.after_configured` + :meth:`.RegistryEvents.before_configured` + + :meth:`.RegistryEvents.after_configured` + :meth:`.MapperEvents.mapper_configured` """ @@ -1051,6 +1054,10 @@ class MapperEvents(event.Events[mapperlib.Mapper[Any]]): :meth:`.MapperEvents.after_configured` + :meth:`.RegistryEvents.before_configured` + + :meth:`.RegistryEvents.after_configured` + :meth:`.MapperEvents.before_mapper_configured` """ @@ -1098,6 +1105,10 @@ class MapperEvents(event.Events[mapperlib.Mapper[Any]]): :meth:`.MapperEvents.after_configured` + :meth:`.RegistryEvents.before_configured` + + :meth:`.RegistryEvents.after_configured` + """ @event._omit_standard_example @@ -1142,6 +1153,10 @@ class MapperEvents(event.Events[mapperlib.Mapper[Any]]): :meth:`.MapperEvents.before_configured` + :meth:`.RegistryEvents.before_configured` + + :meth:`.RegistryEvents.after_configured` + """ def before_insert( @@ -3184,3 +3199,175 @@ class QueryEvents(event.Events[Query[Any]]): wrap._bake_ok = bake_ok # type: ignore [attr-defined] event_key.base_listen(**kw) + + +class RegistryEvents(event.Events["registry"]): + """Define events specific to :class:`_orm.registry` lifecycle. + + The :class:`_orm.RegistryEvents` class defines events that are specific + to the lifecycle and operation of the :class:`_orm.registry` object. + + e.g.:: + + from sqlalchemy import event + from sqlalchemy.orm import registry + + reg = registry() + + + @event.listens_for(reg, "resolve_type_annotation") + def resolve_custom_type(registry, python_type, resolved_type): + if python_type is MyCustomType: + return MyCustomSQLType() + return resolved_type + + The events defined by :class:`_orm.RegistryEvents` include + :meth:`_orm.RegistryEvents.resolve_type_annotation`, + :meth:`_orm.RegistryEvents.before_configured`, and + :meth:`_orm.RegistryEvents.after_configured`.`. These events may be + applied to a :class:`_orm.registry` object as shown in the preceding + example, as well as to a declarative base class directly, which will + automtically locate the registry for the event to be applied:: + + from sqlalchemy import event + from sqlalchemy.orm import DeclarativeBase + + + class Base(DeclarativeBase): + pass + + + @event.listens_for(Base, "resolve_type_annotation") + def resolve_custom_type(resolve_type): + if resolve_type.primary_type is MyCustomType: + return MyCustomSQLType() + else: + return None + + + @event.listens_for(Base, "after_configured") + def after_base_configured(registry): + print(f"Registry {registry} fully configured") + + .. versionadded:: 2.1 + + + """ + + _target_class_doc = "SomeRegistry" + _dispatch_target = decl_api.registry + + @classmethod + def _accept_with( + cls, + target: Any, + identifier: str, + ) -> Any: + # Import here to avoid circular imports + from . import decl_api + + if isinstance(target, decl_api.registry): + return target + elif ( + isinstance(target, type) + and "_sa_registry" in target.__dict__ + and isinstance(target.__dict__["_sa_registry"], decl_api.registry) + ): + return target._sa_registry # type: ignore[attr-defined] + else: + return None + + @classmethod + def _listen( + cls, + event_key: _EventKey["registry"], + **kw: Any, + ) -> None: + identifier = event_key.identifier + + # Only resolve_type_annotation needs retval=True + if identifier == "resolve_type_annotation": + kw["retval"] = True + + event_key.base_listen(**kw) + + def resolve_type_annotation( + self, resolve_type: decl_api.TypeResolve + ) -> Optional[Any]: + """Intercept and customize type annotation resolution. + + This event is fired when the :class:`_orm.registry` attempts to + resolve a Python type annotation to a SQLAlchemy type. This is + particularly useful for handling advanced typing scenarios such as + PEP 695 type aliases. + + The :meth:`.RegistryEvents.resolve_type_annotation` event automatically + sets up ``retval=True`` when the event is set up, so that implementing + functions may return a resolved type (or ``None`` to indicate no type + was resolved). + + :param resolve_type: A :class:`_orm.TypeResolve` object which contains + all the relevant information about the type, including a link to the + registry and its resolver function. + + :return: A SQLAlchemy type to use for the given Python type. If + ``None`` is returned, the default resolution behavior will proceed + from there (equivalent to invoking ``resolver(extracted_type)``, or + None to use the default resolution behavior. + + .. versionadded:: 2.1 + + .. seealso:: + + :ref:`orm_declarative_resolve_type_event` + + """ + + def before_configured(self, registry: "registry") -> None: + """Called before a series of mappers in this registry are configured. + + This event is invoked each time the :func:`_orm.configure_mappers` + function is invoked and this registry has mappers that are part of + the configuration process. + + Compared to the :meth:`.MapperEvents.before_configured` event hook, + this event is local to the mappers within a specific + :class:`_orm.registry` and not for all :class:`.Mapper` objects + globally. :param registry: The :class:`_orm.registry` instance. + + .. versionadded:: 2.1 + + .. seealso:: + + :meth:`.RegistryEvents.after_configured` + + :meth:`.MapperEvents.before_configured` + + :meth:`.MapperEvents.after_configured` + + """ + + def after_configured(self, registry: "registry") -> None: + """Called after a series of mappers in this registry are configured. + + This event is invoked each time the :func:`_orm.configure_mappers` + function completes and this registry had mappers that were part of + the configuration process. + + Compared to the :meth:`.MapperEvents.after_configured` event hook, this + event is local to the mappers within a specific :class:`_orm.registry` + and not for all :class:`.Mapper` objects globally. + + :param registry: The :class:`_orm.registry` instance. + + .. versionadded:: 2.1 + + .. seealso:: + + :meth:`.RegistryEvents.before_configured` + + :meth:`.MapperEvents.before_configured` + + :meth:`.MapperEvents.after_configured` + + """ diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 017b829d8a..f2ab5a9820 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -4108,6 +4108,12 @@ def configure_mappers() -> None: work; this can be used to establish additional options, properties, or related mappings before the operation proceeds. + * :meth:`.RegistryEvents.before_configured` - Like + :meth:`.MapperEvents.before_configured`, but local to a specific + :class:`_orm.registry`. + + .. versionadded:: 2.1 + * :meth:`.MapperEvents.mapper_configured` - called as each individual :class:`_orm.Mapper` is configured within the process; will include all mapper state except for backrefs set up by other mappers that are still @@ -4123,6 +4129,12 @@ def configure_mappers() -> None: if they are in other :class:`_orm.registry` collections not part of the current scope of configuration. + * :meth:`.RegistryEvents.after_configured` - Like + :meth:`.MapperEvents.after_configured`, but local to a specific + :class:`_orm.registry`. + + .. versionadded:: 2.1 + """ _configure_registries(_all_registries(), cascade=True) @@ -4151,26 +4163,35 @@ def _configure_registries( return Mapper.dispatch._for_class(Mapper).before_configured() # type: ignore # noqa: E501 + # initialize properties on all mappers # note that _mapper_registry is unordered, which # may randomly conceal/reveal issues related to # the order of mapper compilation - _do_configure_registries(registries, cascade) + registries_configured = list( + _do_configure_registries(registries, cascade) + ) + finally: _already_compiling = False + for reg in registries_configured: + reg.dispatch.after_configured(reg) Mapper.dispatch._for_class(Mapper).after_configured() # type: ignore @util.preload_module("sqlalchemy.orm.decl_api") def _do_configure_registries( registries: Set[_RegistryType], cascade: bool -) -> None: +) -> Iterator[registry]: registry = util.preloaded.orm_decl_api.registry orig = set(registries) for reg in registry._recurse_with_dependencies(registries): + if reg._new_mappers: + reg.dispatch.before_configured(reg) + has_skip = False for mapper in reg._mappers_to_configure(): @@ -4205,6 +4226,9 @@ def _do_configure_registries( if not hasattr(exc, "_configure_failed"): mapper._configure_failed = exc raise + + if reg._new_mappers: + yield reg if not has_skip: reg._new_mappers = False diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 5623e61623..177464be9c 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -65,6 +65,8 @@ from ..util.typing import is_pep695 from ..util.typing import Self if TYPE_CHECKING: + from typing import ForwardRef + from ._typing import _IdentityKeyType from ._typing import _InstanceDict from ._typing import _ORMColumnExprArgument @@ -80,6 +82,7 @@ if TYPE_CHECKING: from ..sql.elements import NamedColumn from ..sql.operators import OperatorType from ..util.typing import _AnnotationScanType + from ..util.typing import _MatchedOnType from ..util.typing import RODescriptorReference _T = TypeVar("_T", bound=Any) @@ -775,25 +778,30 @@ class MappedColumn( ) -> None: sqltype = self.column.type + de_stringified_argument: _MatchedOnType + if is_fwd_ref( argument, check_generic=True, check_for_plain_string=True ): assert originating_module is not None - argument = de_stringify_annotation( + de_stringified_argument = de_stringify_annotation( cls, argument, originating_module, include_generic=True ) + else: + if TYPE_CHECKING: + assert not isinstance(argument, (str, ForwardRef)) + de_stringified_argument = argument - nullable = includes_none(argument) + nullable = includes_none(de_stringified_argument) if not self._has_nullable: self.column.nullable = nullable find_mapped_in: Tuple[Any, ...] = () - our_type_is_pep593 = False raw_pep_593_type = None raw_pep_695_type = None - our_type: Any = de_optionalize_union_types(argument) + our_type: Any = de_optionalize_union_types(de_stringified_argument) if is_pep695(our_type): raw_pep_695_type = our_type @@ -803,8 +811,6 @@ class MappedColumn( our_type = our_type[our_args] if is_pep593(our_type): - our_type_is_pep593 = True - pep_593_components = get_args(our_type) raw_pep_593_type = pep_593_components[0] if nullable: @@ -899,20 +905,23 @@ class MappedColumn( ) if sqltype._isnull and not self.column.foreign_keys: - checks: List[Any] - if our_type_is_pep593: - checks = [our_type, raw_pep_593_type] - else: - checks = [our_type] - if raw_pep_695_type is not None: - checks.insert(0, raw_pep_695_type) + new_sqltype = registry._resolve_type_with_events( + cls, + key, + de_stringified_argument, + our_type, + pep_593_type=raw_pep_593_type, + pep_695_type=raw_pep_695_type, + ) - for check_type in checks: - new_sqltype = registry._resolve_type(check_type) - if new_sqltype is not None: - break - else: + if new_sqltype is None: + checks = [] + if raw_pep_695_type: + checks.append(raw_pep_695_type) + checks.append(our_type) + if raw_pep_593_type: + checks.append(raw_pep_593_type) if isinstance(our_type, TypeEngine) or ( isinstance(our_type, type) and issubclass(our_type, TypeEngine) @@ -950,8 +959,8 @@ class MappedColumn( raise orm_exc.MappedAnnotationError( f"The object provided inside the {self.column.key!r} " "attribute Mapped annotation is not a Python type, " - f"it's the object {argument!r}. Expected a Python " - "type." + f"it's the object {de_stringified_argument!r}. " + "Expected a Python type." ) self.column._set_type(new_sqltype) diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index aacf2f7362..fa63591c6d 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -92,6 +92,7 @@ from ..util.langhelpers import MemoizedSlots from ..util.typing import de_stringify_annotation as _de_stringify_annotation from ..util.typing import eval_name_only as _eval_name_only from ..util.typing import fixup_container_fwd_refs +from ..util.typing import GenericProtocol from ..util.typing import is_origin_of_cls from ..util.typing import TupleAny from ..util.typing import Unpack @@ -123,6 +124,7 @@ if typing.TYPE_CHECKING: from ..sql.selectable import Selectable from ..sql.visitors import anon_map from ..util.typing import _AnnotationScanType + from ..util.typing import _MatchedOnType _T = TypeVar("_T", bound=Any) @@ -163,7 +165,7 @@ class _DeStringifyAnnotation(Protocol): *, str_cleanup_fn: Optional[Callable[[str, str], str]] = None, include_generic: bool = False, - ) -> Type[Any]: ... + ) -> _MatchedOnType: ... de_stringify_annotation = cast( @@ -2365,15 +2367,16 @@ def _extract_mapped_subtype( else: return annotated, None - if len(annotated.__args__) != 1: + generic_annotated = cast(GenericProtocol[Any], annotated) + if len(generic_annotated.__args__) != 1: raise orm_exc.MappedAnnotationError( "Expected sub-type for Mapped[] annotation" ) return ( # fix dict/list/set args to be ForwardRef, see #11814 - fixup_container_fwd_refs(annotated.__args__[0]), - annotated.__origin__, + fixup_container_fwd_refs(generic_annotated.__args__[0]), + generic_annotated.__origin__, ) diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 916e6444e5..9d62b3a960 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -80,9 +80,9 @@ if TYPE_CHECKING: from .type_api import _BindProcessorType from .type_api import _ComparatorFactory from .type_api import _LiteralProcessorType - from .type_api import _MatchedOnType from .type_api import _ResultProcessorType from ..engine.interfaces import Dialect + from ..util.typing import _MatchedOnType _T = TypeVar("_T", bound="Any") _CT = TypeVar("_CT", bound=Any) diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 2e88542c98..4f3cc4a735 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -18,7 +18,6 @@ from typing import ClassVar from typing import Dict from typing import Generic from typing import Mapping -from typing import NewType from typing import Optional from typing import overload from typing import Protocol @@ -31,6 +30,7 @@ from typing import TypeGuard from typing import TypeVar from typing import Union +from sqlalchemy.util.typing import _MatchedOnType from .base import SchemaEventTarget from .cache_key import CacheConst from .cache_key import NO_CACHE @@ -42,7 +42,6 @@ from .visitors import Visitable from .. import exc from .. import util from ..util.typing import Self -from ..util.typing import TypeAliasType # these are back-assigned by sqltypes. if typing.TYPE_CHECKING: @@ -61,7 +60,6 @@ if typing.TYPE_CHECKING: from .sqltypes import TABLEVALUE as TABLEVALUE # noqa: F401 from ..engine.interfaces import DBAPIModule from ..engine.interfaces import Dialect - from ..util.typing import GenericProtocol _T = TypeVar("_T", bound=Any) _T_co = TypeVar("_T_co", bound=Any, covariant=True) @@ -71,10 +69,6 @@ _TE = TypeVar("_TE", bound="TypeEngine[Any]") _CT = TypeVar("_CT", bound=Any) _RT = TypeVar("_RT", bound=Any) -_MatchedOnType = Union[ - "GenericProtocol[Any]", TypeAliasType, NewType, Type[Any] -] - class _NoValueInList(Enum): NO_VALUE_IN_LIST = 0 diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 91a2380109..a4cef868d6 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -73,6 +73,10 @@ _AnnotationScanType = Union[ Type[Any], str, ForwardRef, NewType, TypeAliasType, "GenericProtocol[Any]" ] +_MatchedOnType = Union[ + "GenericProtocol[Any]", TypeAliasType, NewType, Type[Any] +] + class ArgsTypeProtocol(Protocol): """protocol for types that have ``__args__`` @@ -385,11 +389,27 @@ def pep695_values(type_: _AnnotationScanType) -> Set[Any]: return {res} +@overload +def is_fwd_ref( + type_: _AnnotationScanType, + check_generic: bool = ..., + check_for_plain_string: Literal[False] = ..., +) -> TypeGuard[ForwardRef]: ... + + +@overload +def is_fwd_ref( + type_: _AnnotationScanType, + check_generic: bool = ..., + check_for_plain_string: bool = ..., +) -> TypeGuard[Union[str, ForwardRef]]: ... + + def is_fwd_ref( type_: _AnnotationScanType, check_generic: bool = False, check_for_plain_string: bool = False, -) -> TypeGuard[ForwardRef]: +) -> TypeGuard[Union[str, ForwardRef]]: if check_for_plain_string and isinstance(type_, str): return True elif isinstance(type_, _type_instances.ForwardRef): @@ -413,6 +433,10 @@ def de_optionalize_union_types(type_: str) -> str: ... def de_optionalize_union_types(type_: Type[Any]) -> Type[Any]: ... +@overload +def de_optionalize_union_types(type_: _MatchedOnType) -> _MatchedOnType: ... + + @overload def de_optionalize_union_types( type_: _AnnotationScanType, diff --git a/test/orm/test_events.py b/test/orm/test_events.py index 85a7d0c344..c6f913728f 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -1,3 +1,9 @@ +import re +from typing import Annotated +from typing import Any +from typing import get_args as typing_get_args +from typing import Optional +from typing import TypeVar from unittest.mock import ANY from unittest.mock import call from unittest.mock import Mock @@ -22,14 +28,18 @@ from sqlalchemy.orm import attributes from sqlalchemy.orm import class_mapper from sqlalchemy.orm import configure_mappers from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import deferred from sqlalchemy.orm import EXT_SKIP from sqlalchemy.orm import instrumentation from sqlalchemy.orm import joinedload from sqlalchemy.orm import lazyload +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column from sqlalchemy.orm import Mapper from sqlalchemy.orm import mapperlib from sqlalchemy.orm import query +from sqlalchemy.orm import registry from sqlalchemy.orm import relationship from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session @@ -39,7 +49,6 @@ from sqlalchemy.orm import UserDefinedOption from sqlalchemy.sql.cache_key import NO_CACHE from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message -from sqlalchemy.testing import assert_warns_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises @@ -53,6 +62,8 @@ from sqlalchemy.testing.fixtures import RemoveORMEventsGlobally from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.testing.util import gc_collect +from sqlalchemy.util import typing as util_typing +from sqlalchemy.util.typing import TypeAliasType from test.orm import _fixtures @@ -1314,33 +1325,73 @@ class MapperEventsTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): eq_(canary1, ["before_update", "after_update"]) eq_(canary2, []) - def test_before_after_configured_warn_on_non_mapper(self): + @testing.combinations( + ("before_configured",), ("after_configured",), argnames="event_name" + ) + @testing.variation( + "target_type", + [ + "mappercls", + "mapperinstance", + "registry", + "explicit_base", + "imperative_class", + "declarative_class", + ], + ) + def test_before_after_configured_only_on_mappercls_or_registry( + self, event_name, target_type: testing.Variation + ): User, users = self.classes.User, self.tables.users - m1 = Mock() + reg = registry() - self.mapper_registry.map_imperatively(User, users) - assert_warns_message( - sa.exc.SAWarning, - r"before_configured' and 'after_configured' ORM events only " - r"invoke with the Mapper class as " - r"the target.", - event.listen, - User, - "before_configured", - m1, + expect_success = ( + target_type.mappercls + or target_type.registry + or target_type.explicit_base ) - assert_warns_message( - sa.exc.SAWarning, - r"before_configured' and 'after_configured' ORM events only " - r"invoke with the Mapper class as " - r"the target.", - event.listen, - User, - "after_configured", - m1, - ) + if target_type.mappercls: + target = Mapper + elif target_type.mapperinstance: + reg.map_imperatively(User, users) + target = inspect(User) + elif target_type.registry: + target = reg + elif target_type.imperative_class: + reg.map_imperatively(User, users) + target = User + elif target_type.explicit_base: + + class Base(DeclarativeBase): + registry = reg + + target = Base + elif target_type.declarative_class: + + class Base(DeclarativeBase): + registry = reg + + class User(Base): + __table__ = users + + target = User + else: + target_type.fail() + + m1 = Mock() + if expect_success: + event.listen(target, event_name, m1) + else: + + with expect_raises_message( + sa_exc.InvalidRequestError, + re.escape( + f"No such event {event_name!r} for target '{target}'" + ), + ): + event.listen(target, event_name, m1) def test_before_after_configured(self): User, users = self.classes.User, self.tables.users @@ -3698,3 +3749,184 @@ class RefreshFlushInReturningTest(fixtures.MappedTest): eq_(t1.id, 1) eq_(t1.prefetch_val, 5) eq_(t1.returning_val, 5) + + +class RegistryEventsTest(fixtures.MappedTest): + """Test RegistryEvents functionality.""" + + @testing.variation("scenario", ["direct", "reentrant", "plain"]) + @testing.variation("include_optional", [True, False]) + @testing.variation( + "type_features", + [ + "none", + "plain_pep593", + "plain_pep695", + "generic_pep593", + "plain_pep593_pep695", + "generic_pep593_pep695", + "generic_pep593_pep695_w_compound", + ], + ) + def test_resolve_type_annotation_event( + self, scenario: testing.Variation, include_optional, type_features + ): + reg = registry(type_annotation_map={str: String(70)}) + Base = reg.generate_base() + + MyCustomType: Any + if type_features.none: + MyCustomType = type("MyCustomType", (object,), {}) + elif type_features.plain_pep593: + MyCustomType = Annotated[float, mapped_column()] + elif type_features.plain_pep695: + MyCustomType = TypeAliasType("MyCustomType", float) + elif type_features.generic_pep593: + T = TypeVar("T") + MyCustomType = Annotated[T, mapped_column()] + elif type_features.plain_pep593_pep695: + MyCustomType = TypeAliasType( # type: ignore + "MyCustomType", Annotated[float, mapped_column()] + ) + elif type_features.generic_pep593_pep695: + T = TypeVar("T") + MyCustomType = TypeAliasType( # type: ignore + "MyCustomType", Annotated[T, mapped_column()], type_params=(T,) + ) + elif type_features.generic_pep593_pep695_w_compound: + T = TypeVar("T") + MyCustomType = TypeAliasType( # type: ignore + "MyCustomType", + Annotated[T | float, mapped_column()], + type_params=(T,), + ) + else: + type_features.fail() + + @event.listens_for(reg, "resolve_type_annotation") + def resolve_custom_type(type_resolve): + assert type_resolve.cls.__name__ == "MyClass" + + if type_resolve.primary_type is int: + return None + + if type_features.none: + assert type_resolve.primary_type is MyCustomType + elif type_features.plain_pep593: + assert type_resolve.pep_593_type is float + elif type_features.plain_pep695: + assert type_resolve.pep_695_type is MyCustomType + assert type_resolve.primary_type is float + elif type_features.generic_pep593: + assert type_resolve.pep_695_type is None + assert type_resolve.pep_593_type is str + elif type_features.plain_pep593_pep695: + assert type_resolve.pep_695_type is not None + assert type_resolve.pep_593_type is float + assert util_typing.is_pep593(type_resolve.primary_type) + assert type_resolve.pep_695_type is MyCustomType + elif type_features.generic_pep593_pep695: + assert type_resolve.pep_695_type is not None + assert type_resolve.pep_593_type is str + elif type_features.generic_pep593_pep695_w_compound: + assert type_resolve.pep_695_type.__origin__ is MyCustomType + assert typing_get_args(type_resolve.pep_695_type) == (str,) + assert util_typing.is_pep593(type_resolve.primary_type) + assert type_resolve.pep_593_type == str | float + else: + type_features.fail() + + if scenario.direct: + return String(50) + elif scenario.reentrant: + return type_resolve.resolve(str) + else: + scenario.fail() + + use_type_args = ( + type_features.generic_pep593 + or type_features.generic_pep593_pep695 + or type_features.generic_pep593_pep695_w_compound + ) + + class MyClass(Base): + __tablename__ = "mytable" + id: Mapped[int] = mapped_column(primary_key=True) + + if include_optional: + if scenario.direct or scenario.reentrant: + if use_type_args: + data: Mapped[Optional[MyCustomType[str]]] + else: + data: Mapped[Optional[MyCustomType]] + else: + data: Mapped[Optional[int]] + else: + if scenario.direct or scenario.reentrant: + if use_type_args: + data: Mapped[MyCustomType[str]] + else: + data: Mapped[MyCustomType] + else: + data: Mapped[int] + + result = MyClass.data.expression.type + + if scenario.direct: + assert isinstance(result, String) + eq_(result.length, 50) + elif scenario.reentrant: + assert isinstance(result, String) + eq_(result.length, 70) + elif scenario.plain: + assert isinstance(result, Integer) + + @testing.variation( + "listen_type", ["registry", "generated_base", "explicit_base"] + ) + def test_before_after_configured_events(self, listen_type): + """Test the before_configured and after_configured events.""" + reg = registry() + + if listen_type.generated_base: + Base = reg.generate_base() + else: + + class Base(DeclarativeBase): + registry = reg + + mock = Mock() + + if listen_type.registry: + + @event.listens_for(reg, "before_configured") + def before_configured(registry_inst): + mock.before_configured(registry_inst) + + @event.listens_for(reg, "after_configured") + def after_configured(registry_inst): + mock.after_configured(registry_inst) + + else: + + @event.listens_for(Base, "before_configured") + def before_configured(registry_inst): + mock.before_configured(registry_inst) + + @event.listens_for(Base, "after_configured") + def after_configured(registry_inst): + mock.after_configured(registry_inst) + + # Create a simple mapped class to trigger configuration + class TestClass(Base): + __tablename__ = "test_table" + id = Column(Integer, primary_key=True) + + # Configure the registry + reg.configure() + + # Check that events were fired in the correct order + eq_( + mock.mock_calls, + [call.before_configured(reg), call.after_configured(reg)], + ) -- 2.47.3