From a463b1109abb60fc85f8356f30c0351a4e2ed71e Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 18 Feb 2022 10:05:12 -0500 Subject: [PATCH] implement dataclass_transforms Implement a new means of creating a mapped dataclass where instead of applying the `@dataclass` decorator distinctly, the declarative process itself can create the dataclass. MapperProperty and MappedColumn objects themselves take the place of the dataclasses.Field object when constructing the class. The overall approach is made possible at the typing level using pep-681 dataclass transforms [1]. This new approach should be able to completely supersede the previous "dataclasses" approach of embedding metadata into Field() objects, which remains a mutually exclusive declarative setup style (mixing them introduces new issues that are not worth solving). [1] https://peps.python.org/pep-0681/#transform-descriptor-types-example Fixes: #7642 Change-Id: I6ba88a87c5df38270317b4faf085904d91c8a63c --- lib/sqlalchemy/orm/__init__.py | 1 + lib/sqlalchemy/orm/_orm_constructors.py | 167 ++++- lib/sqlalchemy/orm/decl_api.py | 142 +++- lib/sqlalchemy/orm/decl_base.py | 315 ++++++-- lib/sqlalchemy/orm/descriptor_props.py | 48 +- lib/sqlalchemy/orm/instrumentation.py | 9 +- lib/sqlalchemy/orm/interfaces.py | 97 ++- lib/sqlalchemy/orm/properties.py | 63 +- lib/sqlalchemy/orm/relationships.py | 26 +- lib/sqlalchemy/orm/util.py | 45 +- lib/sqlalchemy/testing/fixtures.py | 25 +- lib/sqlalchemy/util/compat.py | 13 +- lib/sqlalchemy/util/typing.py | 8 + pyproject.toml | 1 - setup.cfg | 2 +- test/orm/declarative/test_dc_transforms.py | 816 +++++++++++++++++++++ test/orm/declarative/test_typed_mapping.py | 46 +- 17 files changed, 1661 insertions(+), 163 deletions(-) create mode 100644 test/orm/declarative/test_dc_transforms.py diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index b7d1df5322..4f19ba9460 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -60,6 +60,7 @@ from .decl_api import DeclarativeBaseNoMeta as DeclarativeBaseNoMeta from .decl_api import DeclarativeMeta as DeclarativeMeta from .decl_api import declared_attr as declared_attr from .decl_api import has_inherited_table as has_inherited_table +from .decl_api import MappedAsDataclass as MappedAsDataclass from .decl_api import registry as registry from .decl_api import synonym_for as synonym_for from .descriptor_props import Composite as Composite diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 0692cac09e..ece6a52be8 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -21,9 +21,9 @@ from typing import Union from . import mapperlib as mapperlib from ._typing import _O -from .base import Mapped from .descriptor_props import Composite from .descriptor_props import Synonym +from .interfaces import _AttributeOptions from .properties import ColumnProperty from .properties import MappedColumn from .query import AliasOption @@ -37,6 +37,8 @@ from .util import LoaderCriteriaOption from .. import sql from .. import util from ..exc import InvalidRequestError +from ..sql._typing import _no_kw +from ..sql.base import _NoArg from ..sql.base import SchemaEventTarget from ..sql.schema import SchemaConst from ..sql.selectable import FromClause @@ -105,6 +107,10 @@ def mapped_column( Union[_TypeEngineArgument[Any], SchemaEventTarget] ] = None, *args: SchemaEventTarget, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, nullable: Optional[ Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] ] = SchemaConst.NULL_UNSPECIFIED, @@ -113,7 +119,6 @@ def mapped_column( name: Optional[str] = None, type_: Optional[_TypeEngineArgument[Any]] = None, autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", - default: Optional[Any] = None, doc: Optional[str] = None, key: Optional[str] = None, index: Optional[bool] = None, @@ -300,6 +305,12 @@ def mapped_column( type_=type_, autoincrement=autoincrement, default=default, + attribute_options=_AttributeOptions( + init, + repr, + default, + default_factory, + ), doc=doc, key=key, index=index, @@ -325,6 +336,10 @@ def column_property( deferred: bool = False, raiseload: bool = False, comparator_factory: Optional[Type[PropComparator[_T]]] = None, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, active_history: bool = False, expire_on_flush: bool = True, info: Optional[_InfoType] = None, @@ -416,6 +431,12 @@ def column_property( return ColumnProperty( column, *additional_columns, + attribute_options=_AttributeOptions( + init, + repr, + default, + default_factory, + ), group=group, deferred=deferred, raiseload=raiseload, @@ -429,25 +450,61 @@ def column_property( @overload def composite( - class_: Type[_CC], + _class_or_attr: Type[_CC], *attrs: _CompositeAttrType[Any], - **kwargs: Any, + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None, + active_history: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + **__kw: Any, ) -> Composite[_CC]: ... @overload def composite( + _class_or_attr: _CompositeAttrType[Any], *attrs: _CompositeAttrType[Any], - **kwargs: Any, + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None, + active_history: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + **__kw: Any, ) -> Composite[Any]: ... def composite( - class_: Any = None, + _class_or_attr: Union[ + None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any] + ] = None, *attrs: _CompositeAttrType[Any], - **kwargs: Any, + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None, + active_history: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + **__kw: Any, ) -> Composite[Any]: r"""Return a composite column-based property for use with a Mapper. @@ -497,7 +554,26 @@ def composite( :attr:`.MapperProperty.info` attribute of this object. """ - return Composite(class_, *attrs, **kwargs) + if __kw: + raise _no_kw() + + return Composite( + _class_or_attr, + *attrs, + attribute_options=_AttributeOptions( + init, + repr, + default, + default_factory, + ), + group=group, + deferred=deferred, + raiseload=raiseload, + comparator_factory=comparator_factory, + active_history=active_history, + info=info, + doc=doc, + ) def with_loader_criteria( @@ -700,6 +776,10 @@ def relationship( post_update: bool = False, cascade: str = "save-update, merge", viewonly: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Union[_NoArg, _T] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, lazy: _LazyLoadArgumentType = "select", passive_deletes: Union[Literal["all"], bool] = False, passive_updates: bool = True, @@ -1532,6 +1612,12 @@ def relationship( post_update=post_update, cascade=cascade, viewonly=viewonly, + attribute_options=_AttributeOptions( + init, + repr, + default, + default_factory, + ), lazy=lazy, passive_deletes=passive_deletes, passive_updates=passive_updates, @@ -1559,6 +1645,10 @@ def synonym( map_column: Optional[bool] = None, descriptor: Optional[Any] = None, comparator_factory: Optional[Type[PropComparator[_T]]] = None, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Union[_NoArg, _T] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, info: Optional[_InfoType] = None, doc: Optional[str] = None, ) -> Synonym[Any]: @@ -1670,6 +1760,12 @@ def synonym( map_column=map_column, descriptor=descriptor, comparator_factory=comparator_factory, + attribute_options=_AttributeOptions( + init, + repr, + default, + default_factory, + ), doc=doc, info=info, ) @@ -1784,7 +1880,17 @@ def backref(name: str, **kwargs: Any) -> _ORMBackrefArgument: def deferred( column: _ORMColumnExprArgument[_T], *additional_columns: _ORMColumnExprArgument[Any], - **kw: Any, + group: Optional[str] = None, + raiseload: bool = False, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + active_history: bool = False, + expire_on_flush: bool = True, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, ) -> ColumnProperty[_T]: r"""Indicate a column-based mapped attribute that by default will not load unless accessed. @@ -1803,21 +1909,41 @@ def deferred( :ref:`deferred_raiseload` - :param \**kw: additional keyword arguments passed to - :class:`.ColumnProperty`. + Additional arguments are the same as that of :func:`_orm.column_property`. .. seealso:: :ref:`deferred` """ - kw["deferred"] = True - return ColumnProperty(column, *additional_columns, **kw) + return ColumnProperty( + column, + *additional_columns, + attribute_options=_AttributeOptions( + init, + repr, + default, + default_factory, + ), + group=group, + deferred=True, + raiseload=raiseload, + comparator_factory=comparator_factory, + active_history=active_history, + expire_on_flush=expire_on_flush, + info=info, + doc=doc, + ) def query_expression( default_expr: _ORMColumnExprArgument[_T] = sql.null(), -) -> Mapped[_T]: + *, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + expire_on_flush: bool = True, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, +) -> ColumnProperty[_T]: """Indicate an attribute that populates from a query-time SQL expression. :param default_expr: Optional SQL expression object that will be used in @@ -1840,7 +1966,18 @@ def query_expression( :ref:`mapper_querytime_expression` """ - prop = ColumnProperty(default_expr) + prop = ColumnProperty( + default_expr, + attribute_options=_AttributeOptions( + _NoArg.NO_ARG, + repr, + _NoArg.NO_ARG, + _NoArg.NO_ARG, + ), + expire_on_flush=expire_on_flush, + info=info, + doc=doc, + ) prop.strategy_key = (("query_expression", True),) return prop diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 1c343b04ce..553a50107f 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -33,6 +33,13 @@ from . import clsregistry from . import instrumentation from . import interfaces from . import mapperlib +from ._orm_constructors import column_property +from ._orm_constructors import composite +from ._orm_constructors import deferred +from ._orm_constructors import mapped_column +from ._orm_constructors import query_expression +from ._orm_constructors import relationship +from ._orm_constructors import synonym from .attributes import InstrumentedAttribute from .base import _inspect_mapped_class from .base import Mapped @@ -42,8 +49,13 @@ from .decl_base import _declarative_constructor from .decl_base import _DeferredMapperConfig from .decl_base import _del_attribute from .decl_base import _mapper +from .descriptor_props import Composite +from .descriptor_props import Synonym from .descriptor_props import Synonym as _orm_synonym from .mapper import Mapper +from .properties import ColumnProperty +from .properties import MappedColumn +from .relationships import Relationship from .state import InstanceState from .. import exc from .. import inspection @@ -60,9 +72,9 @@ from ..util.typing import Literal if TYPE_CHECKING: from ._typing import _O from ._typing import _RegistryType - from .descriptor_props import Synonym from .instrumentation import ClassManager from .interfaces import MapperProperty + from .state import InstanceState # noqa from ..sql._typing import _TypeEngineArgument _T = TypeVar("_T", bound=Any) @@ -120,6 +132,26 @@ class DeclarativeAttributeIntercept( """ +@compat_typing.dataclass_transform( + field_descriptors=( + MappedColumn[Any], + Relationship[Any], + Composite[Any], + ColumnProperty[Any], + Synonym[Any], + mapped_column, + relationship, + composite, + column_property, + synonym, + deferred, + query_expression, + ), +) +class DCTransformDeclarative(DeclarativeAttributeIntercept): + """metaclass that includes @dataclass_transforms""" + + class DeclarativeMeta( _DynamicAttributesType, inspection.Inspectable[Mapper[Any]] ): @@ -543,12 +575,42 @@ class DeclarativeBaseNoMeta(inspection.Inspectable[Mapper[Any]]): cls._sa_registry.map_declaratively(cls) +class MappedAsDataclass(metaclass=DCTransformDeclarative): + """Mixin class to indicate when mapping this class, also convert it to be + a dataclass. + + .. seealso:: + + :meth:`_orm.registry.mapped_as_dataclass` + + .. versionadded:: 2.0 + """ + + def __init_subclass__( + cls, + init: bool = True, + repr: bool = True, # noqa: A002 + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + ) -> None: + cls._sa_apply_dc_transforms = { + "init": init, + "repr": repr, + "eq": eq, + "order": order, + "unsafe_hash": unsafe_hash, + } + super().__init_subclass__() + + class DeclarativeBase( inspection.Inspectable[InstanceState[Any]], metaclass=DeclarativeAttributeIntercept, ): """Base class used for declarative class definitions. + The :class:`_orm.DeclarativeBase` allows for the creation of new declarative bases in such a way that is compatible with type checkers:: @@ -1121,7 +1183,7 @@ class registry: bases = not isinstance(cls, tuple) and (cls,) or cls - class_dict = dict(registry=self, metadata=metadata) + class_dict: Dict[str, Any] = dict(registry=self, metadata=metadata) if isinstance(cls, type): class_dict["__doc__"] = cls.__doc__ @@ -1142,6 +1204,78 @@ class registry: return metaclass(name, bases, class_dict) + @compat_typing.dataclass_transform( + field_descriptors=( + MappedColumn[Any], + Relationship[Any], + Composite[Any], + ColumnProperty[Any], + Synonym[Any], + mapped_column, + relationship, + composite, + column_property, + synonym, + deferred, + query_expression, + ), + ) + @overload + def mapped_as_dataclass(self, __cls: Type[_O]) -> Type[_O]: + ... + + @overload + def mapped_as_dataclass( + self, + __cls: Literal[None] = ..., + *, + init: bool = True, + repr: bool = True, # noqa: A002 + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + ) -> Callable[[Type[_O]], Type[_O]]: + ... + + def mapped_as_dataclass( + self, + __cls: Optional[Type[_O]] = None, + *, + init: bool = True, + repr: bool = True, # noqa: A002 + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + ) -> Union[Type[_O], Callable[[Type[_O]], Type[_O]]]: + """Class decorator that will apply the Declarative mapping process + to a given class, and additionally convert the class to be a + Python dataclass. + + .. seealso:: + + :meth:`_orm.registry.mapped` + + .. versionadded:: 2.0 + + + """ + + def decorate(cls: Type[_O]) -> Type[_O]: + cls._sa_apply_dc_transforms = { + "init": init, + "repr": repr, + "eq": eq, + "order": order, + "unsafe_hash": unsafe_hash, + } + _as_declarative(self, cls, cls.__dict__) + return cls + + if __cls: + return decorate(__cls) + else: + return decorate + def mapped(self, cls: Type[_O]) -> Type[_O]: """Class decorator that will apply the Declarative mapping process to a given class. @@ -1174,6 +1308,10 @@ class registry: that will apply Declarative mapping to subclasses automatically using a Python metaclass. + .. seealso:: + + :meth:`_orm.registry.mapped_as_dataclass` + """ _as_declarative(self, cls, cls.__dict__) return cls diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index a66421e225..54a272f86e 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -10,6 +10,8 @@ from __future__ import annotations import collections +import dataclasses +import re from typing import Any from typing import Callable from typing import cast @@ -40,6 +42,7 @@ from .base import _is_mapped_class from .base import InspectionAttr from .descriptor_props import Composite from .descriptor_props import Synonym +from .interfaces import _AttributeOptions from .interfaces import _IntrospectsAnnotations from .interfaces import _MappedAttribute from .interfaces import _MapsColumns @@ -48,15 +51,18 @@ from .mapper import Mapper as mapper from .mapper import Mapper from .properties import ColumnProperty from .properties import MappedColumn +from .util import _extract_mapped_subtype from .util import _is_mapped_annotation from .util import class_mapper from .. import event from .. import exc from .. import util from ..sql import expression +from ..sql.base import _NoArg from ..sql.schema import Column from ..sql.schema import Table from ..util import topological +from ..util.typing import _AnnotationScanType from ..util.typing import Protocol if TYPE_CHECKING: @@ -392,11 +398,13 @@ class _ClassScanMapperConfig(_MapperConfig): "mapper_args", "mapper_args_fn", "inherits", + "allow_dataclass_fields", + "dataclass_setup_arguments", ) registry: _RegistryType clsdict_view: _ClassDict - collected_annotations: Dict[str, Tuple[Any, bool]] + collected_annotations: Dict[str, Tuple[Any, Any, bool]] collected_attributes: Dict[str, Any] local_table: Optional[FromClause] persist_selectable: Optional[FromClause] @@ -411,6 +419,17 @@ class _ClassScanMapperConfig(_MapperConfig): mapper_args_fn: Optional[Callable[[], Dict[str, Any]]] inherits: Optional[Type[Any]] + dataclass_setup_arguments: Optional[Dict[str, Any]] + """if the class has SQLAlchemy native dataclass parameters, where + we will create a SQLAlchemy dataclass (not a real dataclass). + + """ + + allow_dataclass_fields: bool + """if true, look for dataclass-processed Field objects on the target + class as well as superclasses and extract ORM mapping directives from + the "metadata" attribute of each Field""" + def __init__( self, registry: _RegistryType, @@ -434,10 +453,37 @@ class _ClassScanMapperConfig(_MapperConfig): self.declared_columns = util.OrderedSet() self.column_copies = {} + self.dataclass_setup_arguments = dca = getattr( + self.cls, "_sa_apply_dc_transforms", None + ) + + cld = dataclasses.is_dataclass(cls_) + + sdk = _get_immediate_cls_attr(cls_, "__sa_dataclass_metadata_key__") + + # we don't want to consume Field objects from a not-already-dataclass. + # the Field objects won't have their "name" or "type" populated, + # and while it seems like we could just set these on Field as we + # read them, Field is documented as "user read only" and we need to + # stay far away from any off-label use of dataclasses APIs. + if (not cld or dca) and sdk: + raise exc.InvalidRequestError( + "SQLAlchemy mapped dataclasses can't consume mapping " + "information from dataclass.Field() objects if the immediate " + "class is not already a dataclass." + ) + + # if already a dataclass, and __sa_dataclass_metadata_key__ present, + # then also look inside of dataclass.Field() objects yielded by + # dataclasses.get_fields(cls) when scanning for attributes + self.allow_dataclass_fields = bool(sdk and cld) + self._setup_declared_events() self._scan_attributes() + self._setup_dataclasses_transforms() + with mapperlib._CONFIGURE_MUTEX: clsregistry.add_class( self.classname, self.cls, registry._class_registry @@ -477,11 +523,15 @@ class _ClassScanMapperConfig(_MapperConfig): attribute, taking SQLAlchemy-enabled dataclass fields into account. """ - sa_dataclass_metadata_key = _get_immediate_cls_attr( - cls, "__sa_dataclass_metadata_key__" - ) - if sa_dataclass_metadata_key is None: + if self.allow_dataclass_fields: + sa_dataclass_metadata_key = _get_immediate_cls_attr( + cls, "__sa_dataclass_metadata_key__" + ) + else: + sa_dataclass_metadata_key = None + + if not sa_dataclass_metadata_key: def attribute_is_overridden(key: str, obj: Any) -> bool: return getattr(cls, key) is not obj @@ -551,6 +601,7 @@ class _ClassScanMapperConfig(_MapperConfig): "__dict__", "__weakref__", "_sa_class_manager", + "_sa_apply_dc_transforms", "__dict__", "__weakref__", ] @@ -563,10 +614,6 @@ class _ClassScanMapperConfig(_MapperConfig): adjusting for SQLAlchemy fields embedded in dataclass fields. """ - sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr( - cls, "__sa_dataclass_metadata_key__" - ) - cls_annotations = util.get_annotations(cls) cls_vars = vars(cls) @@ -576,7 +623,15 @@ class _ClassScanMapperConfig(_MapperConfig): names = util.merge_lists_w_ordering( [n for n in cls_vars if n not in skip], list(cls_annotations) ) - if sa_dataclass_metadata_key is None: + + if self.allow_dataclass_fields: + sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr( + cls, "__sa_dataclass_metadata_key__" + ) + else: + sa_dataclass_metadata_key = None + + if not sa_dataclass_metadata_key: def local_attributes_for_class() -> Iterable[ Tuple[str, Any, Any, bool] @@ -652,45 +707,51 @@ class _ClassScanMapperConfig(_MapperConfig): name, obj, annotation, - is_dataclass, + is_dataclass_field, ) in local_attributes_for_class(): - if name == "__mapper_args__": - check_decl = _check_declared_props_nocascade( - obj, name, cls - ) - if not mapper_args_fn and (not class_mapped or check_decl): - # don't even invoke __mapper_args__ until - # after we've determined everything about the - # mapped table. - # make a copy of it so a class-level dictionary - # is not overwritten when we update column-based - # arguments. - def _mapper_args_fn() -> Dict[str, Any]: - return dict(cls_as_Decl.__mapper_args__) - - mapper_args_fn = _mapper_args_fn - - elif name == "__tablename__": - check_decl = _check_declared_props_nocascade( - obj, name, cls - ) - if not tablename and (not class_mapped or check_decl): - tablename = cls_as_Decl.__tablename__ - elif name == "__table_args__": - check_decl = _check_declared_props_nocascade( - obj, name, cls - ) - if not table_args and (not class_mapped or check_decl): - table_args = cls_as_Decl.__table_args__ - if not isinstance( - table_args, (tuple, dict, type(None)) + if re.match(r"^__.+__$", name): + if name == "__mapper_args__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not mapper_args_fn and ( + not class_mapped or check_decl ): - raise exc.ArgumentError( - "__table_args__ value must be a tuple, " - "dict, or None" - ) - if base is not cls: - inherited_table_args = True + # don't even invoke __mapper_args__ until + # after we've determined everything about the + # mapped table. + # make a copy of it so a class-level dictionary + # is not overwritten when we update column-based + # arguments. + def _mapper_args_fn() -> Dict[str, Any]: + return dict(cls_as_Decl.__mapper_args__) + + mapper_args_fn = _mapper_args_fn + + elif name == "__tablename__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not tablename and (not class_mapped or check_decl): + tablename = cls_as_Decl.__tablename__ + elif name == "__table_args__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not table_args and (not class_mapped or check_decl): + table_args = cls_as_Decl.__table_args__ + if not isinstance( + table_args, (tuple, dict, type(None)) + ): + raise exc.ArgumentError( + "__table_args__ value must be a tuple, " + "dict, or None" + ) + if base is not cls: + inherited_table_args = True + else: + # skip all other dunder names + continue elif class_mapped: if _is_declarative_props(obj): util.warn( @@ -706,9 +767,8 @@ class _ClassScanMapperConfig(_MapperConfig): # acting like that for now. if isinstance(obj, (Column, MappedColumn)): - self.collected_annotations[name] = ( - annotation, - False, + self._collect_annotation( + name, annotation, is_dataclass_field, True, obj ) # already copied columns to the mapped class. continue @@ -745,7 +805,7 @@ class _ClassScanMapperConfig(_MapperConfig): ] = ret = obj.__get__(obj, cls) setattr(cls, name, ret) else: - if is_dataclass: + if is_dataclass_field: # access attribute using normal class access # first, to see if it's been mapped on a # superclass. note if the dataclasses.field() @@ -789,14 +849,16 @@ class _ClassScanMapperConfig(_MapperConfig): ): ret.doc = obj.__doc__ - self.collected_annotations[name] = ( + self._collect_annotation( + name, obj._collect_return_annotation(), False, + True, + obj, ) elif _is_mapped_annotation(annotation, cls): - self.collected_annotations[name] = ( - annotation, - is_dataclass, + self._collect_annotation( + name, annotation, is_dataclass_field, True, obj ) if obj is None: if not fixed_table: @@ -809,7 +871,7 @@ class _ClassScanMapperConfig(_MapperConfig): # declarative mapping. however, check for some # more common mistakes self._warn_for_decl_attributes(base, name, obj) - elif is_dataclass and ( + elif is_dataclass_field and ( name not in clsdict_view or clsdict_view[name] is not obj ): # here, we are definitely looking at the target class @@ -826,14 +888,12 @@ class _ClassScanMapperConfig(_MapperConfig): obj = obj.fget() collected_attributes[name] = obj - self.collected_annotations[name] = ( - annotation, - True, + self._collect_annotation( + name, annotation, True, False, obj ) else: - self.collected_annotations[name] = ( - annotation, - False, + self._collect_annotation( + name, annotation, False, None, obj ) if ( obj is None @@ -843,6 +903,10 @@ class _ClassScanMapperConfig(_MapperConfig): collected_attributes[name] = MappedColumn() elif name in clsdict_view: collected_attributes[name] = obj + # else if the name is not in the cls.__dict__, + # don't collect it as an attribute. + # we will see the annotation only, which is meaningful + # both for mapping and dataclasses setup if inherited_table_args and not tablename: table_args = None @@ -851,6 +915,77 @@ class _ClassScanMapperConfig(_MapperConfig): self.tablename = tablename self.mapper_args_fn = mapper_args_fn + def _setup_dataclasses_transforms(self) -> None: + + dataclass_setup_arguments = self.dataclass_setup_arguments + if not dataclass_setup_arguments: + return + + manager = instrumentation.manager_of_class(self.cls) + assert manager is not None + + field_list = [ + _AttributeOptions._get_arguments_for_make_dataclass( + key, + anno, + self.collected_attributes.get(key, _NoArg.NO_ARG), + ) + for key, anno in ( + (key, mapped_anno if mapped_anno else raw_anno) + for key, ( + raw_anno, + mapped_anno, + is_dc, + ) in self.collected_annotations.items() + ) + ] + + annotations = {} + defaults = {} + for item in field_list: + if len(item) == 2: + name, tp = item # type: ignore + elif len(item) == 3: + name, tp, spec = item # type: ignore + defaults[name] = spec + else: + assert False + annotations[name] = tp + + for k, v in defaults.items(): + setattr(self.cls, k, v) + self.cls.__annotations__ = annotations + + dataclasses.dataclass(self.cls, **dataclass_setup_arguments) + + def _collect_annotation( + self, + name: str, + raw_annotation: _AnnotationScanType, + is_dataclass: bool, + expect_mapped: Optional[bool], + attr_value: Any, + ) -> None: + + if expect_mapped is None: + expect_mapped = isinstance(attr_value, _MappedAttribute) + + extracted_mapped_annotation = _extract_mapped_subtype( + raw_annotation, + self.cls, + name, + type(attr_value), + required=False, + is_dataclass_field=False, + expect_mapped=expect_mapped and not self.allow_dataclass_fields, + ) + + self.collected_annotations[name] = ( + raw_annotation, + extracted_mapped_annotation, + is_dataclass, + ) + def _warn_for_decl_attributes( self, cls: Type[Any], key: str, c: Any ) -> None: @@ -982,13 +1117,53 @@ class _ClassScanMapperConfig(_MapperConfig): _undefer_column_name( k, self.column_copies.get(value, value) # type: ignore ) - elif isinstance(value, _IntrospectsAnnotations): - annotation, is_dataclass = self.collected_annotations.get( - k, (None, False) - ) - value.declarative_scan( - self.registry, cls, k, annotation, is_dataclass - ) + else: + if isinstance(value, _IntrospectsAnnotations): + ( + annotation, + extracted_mapped_annotation, + is_dataclass, + ) = self.collected_annotations.get(k, (None, None, False)) + value.declarative_scan( + self.registry, + cls, + k, + annotation, + extracted_mapped_annotation, + is_dataclass, + ) + + if ( + isinstance(value, (MapperProperty, _MapsColumns)) + and value._has_dataclass_arguments + and not self.dataclass_setup_arguments + ): + if isinstance(value, MapperProperty): + argnames = [ + "init", + "default_factory", + "repr", + "default", + ] + else: + argnames = ["init", "default_factory", "repr"] + + args = { + a + for a in argnames + if getattr( + value._attribute_options, f"dataclasses_{a}" + ) + is not _NoArg.NO_ARG + } + raise exc.ArgumentError( + f"Attribute '{k}' on class {cls} includes dataclasses " + f"argument(s): " + f"{', '.join(sorted(repr(a) for a in args))} but " + f"class does not specify " + "SQLAlchemy native dataclass configuration." + ) + our_stuff[k] = value def _extract_declared_columns(self) -> None: @@ -997,6 +1172,7 @@ class _ClassScanMapperConfig(_MapperConfig): # extract columns from the class dict declared_columns = self.declared_columns name_to_prop_key = collections.defaultdict(set) + for key, c in list(our_stuff.items()): if isinstance(c, _MapsColumns): @@ -1019,7 +1195,6 @@ class _ClassScanMapperConfig(_MapperConfig): # otherwise, Mapper will map it under the column key. if mp_to_assign is None and key != col.key: our_stuff[key] = col - elif isinstance(c, Column): # undefer previously occurred here, and now occurs earlier. # ensure every column we get here has been named diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 8c89f96aa9..a366a9534f 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -35,11 +35,11 @@ from .base import LoaderCallableStatus from .base import Mapped from .base import PassiveFlag from .base import SQLORMOperations +from .interfaces import _AttributeOptions from .interfaces import _IntrospectsAnnotations from .interfaces import _MapsColumns from .interfaces import MapperProperty from .interfaces import PropComparator -from .util import _extract_mapped_subtype from .util import _none_set from .. import event from .. import exc as sa_exc @@ -200,24 +200,26 @@ class Composite( def __init__( self, - class_: Union[ + _class_or_attr: Union[ None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any] ] = None, *attrs: _CompositeAttrType[Any], + attribute_options: Optional[_AttributeOptions] = None, active_history: bool = False, deferred: bool = False, group: Optional[str] = None, comparator_factory: Optional[Type[Comparator[_CC]]] = None, info: Optional[_InfoType] = None, + **kwargs: Any, ): - super().__init__() + super().__init__(attribute_options=attribute_options) - if isinstance(class_, (Mapped, str, sql.ColumnElement)): - self.attrs = (class_,) + attrs + if isinstance(_class_or_attr, (Mapped, str, sql.ColumnElement)): + self.attrs = (_class_or_attr,) + attrs # will initialize within declarative_scan self.composite_class = None # type: ignore else: - self.composite_class = class_ # type: ignore + self.composite_class = _class_or_attr # type: ignore self.attrs = attrs self.active_history = active_history @@ -332,19 +334,15 @@ class Composite( cls: Type[Any], key: str, annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], is_dataclass_field: bool, ) -> None: - MappedColumn = util.preloaded.orm_properties.MappedColumn - - argument = _extract_mapped_subtype( - annotation, - cls, - key, - MappedColumn, - self.composite_class is None, - is_dataclass_field, - ) - + if ( + self.composite_class is None + and extracted_mapped_annotation is None + ): + self._raise_for_required(key, cls) + argument = extracted_mapped_annotation if argument and self.composite_class is None: if isinstance(argument, str) or hasattr( argument, "__forward_arg__" @@ -371,11 +369,18 @@ class Composite( for param, attr in itertools.zip_longest( insp.parameters.values(), self.attrs ): - if param is None or attr is None: + if param is None: raise sa_exc.ArgumentError( - f"number of arguments to {self.composite_class.__name__} " - f"class and number of attributes don't match" + f"number of composite attributes " + f"{len(self.attrs)} exceeds " + f"that of the number of attributes in class " + f"{self.composite_class.__name__} {len(insp.parameters)}" ) + if attr is None: + # fill in missing attr spots with empty MappedColumn + attr = MappedColumn() + self.attrs += (attr,) + if isinstance(attr, MappedColumn): attr.declarative_scan_for_composite( registry, cls, key, param.name, param.annotation @@ -800,10 +805,11 @@ class Synonym(DescriptorProperty[_T]): map_column: Optional[bool] = None, descriptor: Optional[Any] = None, comparator_factory: Optional[Type[PropComparator[_T]]] = None, + attribute_options: Optional[_AttributeOptions] = None, info: Optional[_InfoType] = None, doc: Optional[str] = None, ): - super().__init__() + super().__init__(attribute_options=attribute_options) self.name = name self.map_column = map_column diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index 4fa61b7cee..33de2aee90 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -113,6 +113,7 @@ class ClassManager( "previously known as deferred_scalar_loader" init_method: Optional[Callable[..., None]] + original_init: Optional[Callable[..., None]] = None factory: Optional[_ManagerFactory] @@ -229,7 +230,7 @@ class ClassManager( if finalize and not self._finalized: self._finalize() - def _finalize(self): + def _finalize(self) -> None: if self._finalized: return self._finalized = True @@ -238,14 +239,14 @@ class ClassManager( _instrumentation_factory.dispatch.class_instrument(self.class_) - def __hash__(self): + def __hash__(self) -> int: # type: ignore[override] return id(self) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return other is self @property - def is_mapped(self): + def is_mapped(self) -> bool: return "mapper" in self.__dict__ @HasMemoized.memoized_attribute diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index b5569ce063..e0034061d4 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -19,6 +19,7 @@ are exposed when inspecting mappings. from __future__ import annotations import collections +import dataclasses import typing from typing import Any from typing import Callable @@ -27,6 +28,8 @@ from typing import ClassVar from typing import Dict from typing import Iterator from typing import List +from typing import NamedTuple +from typing import NoReturn from typing import Optional from typing import Sequence from typing import Set @@ -51,11 +54,13 @@ from .base import ONETOMANY as ONETOMANY # noqa: F401 from .base import RelationshipDirection as RelationshipDirection # noqa: F401 from .base import SQLORMOperations from .. import ColumnElement +from .. import exc as sa_exc from .. import inspection from .. import util from ..sql import operators from ..sql import roles from ..sql import visitors +from ..sql.base import _NoArg from ..sql.base import ExecutableOption from ..sql.cache_key import HasCacheKey from ..sql.schema import Column @@ -141,6 +146,7 @@ class _IntrospectsAnnotations: cls: Type[Any], key: str, annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], is_dataclass_field: bool, ) -> None: """Perform class-specific initializaton at early declarative scanning @@ -150,6 +156,70 @@ class _IntrospectsAnnotations: """ + def _raise_for_required(self, key: str, cls: Type[Any]) -> NoReturn: + raise sa_exc.ArgumentError( + f"Python typing annotation is required for attribute " + f'"{cls.__name__}.{key}" when primary argument(s) for ' + f'"{self.__class__.__name__}" construct are None or not present' + ) + + +class _AttributeOptions(NamedTuple): + """define Python-local attribute behavior options common to all + :class:`.MapperProperty` objects. + + Currently this includes dataclass-generation arguments. + + .. versionadded:: 2.0 + + """ + + dataclasses_init: Union[_NoArg, bool] + dataclasses_repr: Union[_NoArg, bool] + dataclasses_default: Union[_NoArg, Any] + dataclasses_default_factory: Union[_NoArg, Callable[[], Any]] + + def _as_dataclass_field(self) -> Any: + """Return a ``dataclasses.Field`` object given these arguments.""" + + kw: Dict[str, Any] = {} + if self.dataclasses_default_factory is not _NoArg.NO_ARG: + kw["default_factory"] = self.dataclasses_default_factory + if self.dataclasses_default is not _NoArg.NO_ARG: + kw["default"] = self.dataclasses_default + if self.dataclasses_init is not _NoArg.NO_ARG: + kw["init"] = self.dataclasses_init + if self.dataclasses_repr is not _NoArg.NO_ARG: + kw["repr"] = self.dataclasses_repr + + return dataclasses.field(**kw) + + @classmethod + def _get_arguments_for_make_dataclass( + cls, key: str, annotation: Type[Any], elem: _T + ) -> Union[ + Tuple[str, Type[Any]], Tuple[str, Type[Any], dataclasses.Field[Any]] + ]: + """given attribute key, annotation, and value from a class, return + the argument tuple we would pass to dataclasses.make_dataclass() + for this attribute. + + """ + if isinstance(elem, (MapperProperty, _MapsColumns)): + dc_field = elem._attribute_options._as_dataclass_field() + + return (key, annotation, dc_field) + elif elem is not _NoArg.NO_ARG: + # why is typing not erroring on this? + return (key, annotation, elem) + else: + return (key, annotation) + + +_DEFAULT_ATTRIBUTE_OPTIONS = _AttributeOptions( + _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG +) + class _MapsColumns(_MappedAttribute[_T]): """interface for declarative-capable construct that delivers one or more @@ -158,6 +228,9 @@ class _MapsColumns(_MappedAttribute[_T]): __slots__ = () + _attribute_options: _AttributeOptions + _has_dataclass_arguments: bool + @property def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]: """return a MapperProperty to be assigned to the declarative mapping""" @@ -199,6 +272,8 @@ class MapperProperty( __slots__ = ( "_configure_started", "_configure_finished", + "_attribute_options", + "_has_dataclass_arguments", "parent", "key", "info", @@ -241,6 +316,15 @@ class MapperProperty( doc: Optional[str] """optional documentation string""" + _attribute_options: _AttributeOptions + """behavioral options for ORM-enabled Python attributes + + .. versionadded:: 2.0 + + """ + + _has_dataclass_arguments: bool + def _memoized_attr_info(self) -> _InfoType: """Info dictionary associated with the object, allowing user-defined data to be associated with this :class:`.InspectionAttr`. @@ -349,9 +433,20 @@ class MapperProperty( """ - def __init__(self) -> None: + def __init__( + self, attribute_options: Optional[_AttributeOptions] = None + ) -> None: self._configure_started = False self._configure_finished = False + if ( + attribute_options + and attribute_options != _DEFAULT_ATTRIBUTE_OPTIONS + ): + self._has_dataclass_arguments = True + self._attribute_options = attribute_options + else: + self._has_dataclass_arguments = False + self._attribute_options = _DEFAULT_ATTRIBUTE_OPTIONS def init(self) -> None: """Called after all mappers are created to assemble diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index ad3e9f248d..7655f3ae2f 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -30,13 +30,14 @@ from . import strategy_options from .descriptor_props import Composite from .descriptor_props import ConcreteInheritedProperty from .descriptor_props import Synonym +from .interfaces import _AttributeOptions +from .interfaces import _DEFAULT_ATTRIBUTE_OPTIONS from .interfaces import _IntrospectsAnnotations from .interfaces import _MapsColumns from .interfaces import MapperProperty from .interfaces import PropComparator from .interfaces import StrategizedProperty from .relationships import Relationship -from .util import _extract_mapped_subtype from .util import _orm_full_deannotate from .. import exc as sa_exc from .. import ForeignKey @@ -45,6 +46,7 @@ from .. import util from ..sql import coercions from ..sql import roles from ..sql import sqltypes +from ..sql.base import _NoArg from ..sql.elements import SQLCoreOperations from ..sql.schema import Column from ..sql.schema import SchemaConst @@ -131,6 +133,7 @@ class ColumnProperty( self, column: _ORMColumnExprArgument[_T], *additional_columns: _ORMColumnExprArgument[Any], + attribute_options: Optional[_AttributeOptions] = None, group: Optional[str] = None, deferred: bool = False, raiseload: bool = False, @@ -141,7 +144,9 @@ class ColumnProperty( doc: Optional[str] = None, _instrument: bool = True, ): - super(ColumnProperty, self).__init__() + super(ColumnProperty, self).__init__( + attribute_options=attribute_options + ) columns = (column,) + additional_columns self._orig_columns = [ coercions.expect(roles.LabeledColumnExprRole, c) for c in columns @@ -193,6 +198,7 @@ class ColumnProperty( cls: Type[Any], key: str, annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], is_dataclass_field: bool, ) -> None: column = self.columns[0] @@ -487,13 +493,38 @@ class MappedColumn( "foreign_keys", "_has_nullable", "deferred", + "_attribute_options", + "_has_dataclass_arguments", ) deferred: bool column: Column[_T] foreign_keys: Optional[Set[ForeignKey]] + _attribute_options: _AttributeOptions def __init__(self, *arg: Any, **kw: Any): + self._attribute_options = attr_opts = kw.pop( + "attribute_options", _DEFAULT_ATTRIBUTE_OPTIONS + ) + + self._has_dataclass_arguments = False + + if attr_opts is not None and attr_opts != _DEFAULT_ATTRIBUTE_OPTIONS: + if attr_opts.dataclasses_default_factory is not _NoArg.NO_ARG: + self._has_dataclass_arguments = True + kw["default"] = attr_opts.dataclasses_default_factory + elif attr_opts.dataclasses_default is not _NoArg.NO_ARG: + kw["default"] = attr_opts.dataclasses_default + + if ( + attr_opts.dataclasses_init is not _NoArg.NO_ARG + or attr_opts.dataclasses_repr is not _NoArg.NO_ARG + ): + self._has_dataclass_arguments = True + + if "default" in kw and kw["default"] is _NoArg.NO_ARG: + kw.pop("default") + self.deferred = kw.pop("deferred", False) self.column = cast("Column[_T]", Column(*arg, **kw)) self.foreign_keys = self.column.foreign_keys @@ -509,13 +540,19 @@ class MappedColumn( new.deferred = self.deferred new.foreign_keys = new.column.foreign_keys new._has_nullable = self._has_nullable + new._attribute_options = self._attribute_options + new._has_dataclass_arguments = self._has_dataclass_arguments util.set_creation_order(new) return new @property def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: if self.deferred: - return ColumnProperty(self.column, deferred=True) + return ColumnProperty( + self.column, + deferred=True, + attribute_options=self._attribute_options, + ) else: return None @@ -543,6 +580,7 @@ class MappedColumn( cls: Type[Any], key: str, annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], is_dataclass_field: bool, ) -> None: column = self.column @@ -553,18 +591,15 @@ class MappedColumn( sqltype = column.type - argument = _extract_mapped_subtype( - annotation, - cls, - key, - MappedColumn, - sqltype._isnull and not self.column.foreign_keys, - is_dataclass_field, - ) - if argument is None: - return + if extracted_mapped_annotation is None: + if sqltype._isnull and not self.column.foreign_keys: + self._raise_for_required(key, cls) + else: + return - self._init_column_for_annotation(cls, registry, argument) + self._init_column_for_annotation( + cls, registry, extracted_mapped_annotation + ) @util.preload_module("sqlalchemy.orm.decl_base") def declarative_scan_for_composite( diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 1186f0f541..deaf521472 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -49,6 +49,7 @@ from .base import class_mapper from .base import LoaderCallableStatus from .base import PassiveFlag from .base import state_str +from .interfaces import _AttributeOptions from .interfaces import _IntrospectsAnnotations from .interfaces import MANYTOMANY from .interfaces import MANYTOONE @@ -56,7 +57,6 @@ from .interfaces import ONETOMANY from .interfaces import PropComparator from .interfaces import RelationshipDirection from .interfaces import StrategizedProperty -from .util import _extract_mapped_subtype from .util import _orm_annotate from .util import _orm_deannotate from .util import CascadeOptions @@ -355,6 +355,7 @@ class Relationship( post_update: bool = False, cascade: str = "save-update, merge", viewonly: bool = False, + attribute_options: Optional[_AttributeOptions] = None, lazy: _LazyLoadArgumentType = "select", passive_deletes: Union[Literal["all"], bool] = False, passive_updates: bool = True, @@ -380,7 +381,7 @@ class Relationship( _local_remote_pairs: Optional[_ColumnPairs] = None, _legacy_inactive_history_style: bool = False, ): - super(Relationship, self).__init__() + super(Relationship, self).__init__(attribute_options=attribute_options) self.uselist = uselist self.argument = argument @@ -1701,18 +1702,19 @@ class Relationship( cls: Type[Any], key: str, annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], is_dataclass_field: bool, ) -> None: - argument = _extract_mapped_subtype( - annotation, - cls, - key, - Relationship, - self.argument is None, - is_dataclass_field, - ) - if argument is None: - return + argument = extracted_mapped_annotation + + if extracted_mapped_annotation is None: + + if self.argument is None: + self._raise_for_required(key, cls) + else: + return + + argument = extracted_mapped_annotation if hasattr(argument, "__origin__"): diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index c50cc5bac8..520c95672f 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1927,7 +1927,7 @@ def _getitem(iterable_query: Query[Any], item: Any) -> Any: def _is_mapped_annotation( - raw_annotation: Union[type, str], cls: Type[Any] + raw_annotation: _AnnotationScanType, cls: Type[Any] ) -> bool: annotated = de_stringify_annotation(cls, raw_annotation) return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm") @@ -1969,9 +1969,14 @@ def _extract_mapped_subtype( attr_cls: Type[Any], required: bool, is_dataclass_field: bool, - superclasses: Optional[Tuple[Type[Any], ...]] = None, + expect_mapped: bool = True, ) -> Optional[Union[type, str]]: + """given an annotation, figure out if it's ``Mapped[something]`` and if + so, return the ``something`` part. + Includes error raise scenarios and other options. + + """ if raw_annotation is None: if required: @@ -1989,25 +1994,29 @@ def _extract_mapped_subtype( if is_dataclass_field: return annotated else: - # TODO: there don't seem to be tests for the failure - # conditions here - if not hasattr(annotated, "__origin__") or ( - not issubclass( - annotated.__origin__, # type: ignore - superclasses if superclasses else attr_cls, - ) - and not issubclass(attr_cls, annotated.__origin__) # type: ignore + if not hasattr(annotated, "__origin__") or not is_origin_of( + annotated, "Mapped", module="sqlalchemy.orm" ): - our_annotated_str = ( - annotated.__name__ + anno_name = ( + getattr(annotated, "__name__", None) if not isinstance(annotated, str) - else repr(annotated) - ) - raise sa_exc.ArgumentError( - f'Type annotation for "{cls.__name__}.{key}" should use the ' - f'syntax "Mapped[{our_annotated_str}]" or ' - f'"{attr_cls.__name__}[{our_annotated_str}]".' + else None ) + if anno_name is None: + our_annotated_str = repr(annotated) + else: + our_annotated_str = anno_name + + if expect_mapped: + raise sa_exc.ArgumentError( + f'Type annotation for "{cls.__name__}.{key}" ' + "should use the " + f'syntax "Mapped[{our_annotated_str}]" or ' + f'"{attr_cls.__name__}[{our_annotated_str}]".' + ) + + else: + return annotated if len(annotated.__args__) != 1: # type: ignore raise sa_exc.ArgumentError( diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 53f76f3ce2..d4e4d2dcad 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -25,6 +25,7 @@ from .. import event from .. import util from ..orm import declarative_base from ..orm import DeclarativeBase +from ..orm import MappedAsDataclass from ..orm import registry from ..schema import sort_tables_and_constraints @@ -90,7 +91,14 @@ class TestBase: @config.fixture() def registry(self, metadata): - reg = registry(metadata=metadata) + reg = registry( + metadata=metadata, + type_annotation_map={ + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb" + ) + }, + ) yield reg reg.dispose() @@ -109,6 +117,21 @@ class TestBase: yield Base Base.registry.dispose() + @config.fixture + def dc_decl_base(self, metadata): + _md = metadata + + class Base(MappedAsDataclass, DeclarativeBase): + metadata = _md + type_annotation_map = { + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb" + ) + } + + yield Base + Base.registry.dispose() + @config.fixture() def future_connection(self, future_engine, connection): # integrate the future_engine and connection fixtures so diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index adbbf143f9..4ce1e7ff32 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -230,7 +230,11 @@ def inspect_formatargspec( def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]: """Return a sequence of all dataclasses.Field objects associated - with a class.""" + with a class as an already processed dataclass. + + The class must **already be a dataclass** for Field objects to be returned. + + """ if dataclasses.is_dataclass(cls): return dataclasses.fields(cls) @@ -240,7 +244,12 @@ def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]: def local_dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]: """Return a sequence of all dataclasses.Field objects associated with - a class, excluding those that originate from a superclass.""" + an already processed dataclass, excluding those that originate from a + superclass. + + The class must **already be a dataclass** for Field objects to be returned. + + """ if dataclasses.is_dataclass(cls): super_fields: Set[dataclasses.Field[Any]] = set() diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 44e26f6094..454de100bd 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -23,6 +23,14 @@ from typing_extensions import NotRequired as NotRequired # noqa: F401 from . import compat + +# more zimports issues +if True: + from typing_extensions import ( # noqa: F401 + dataclass_transform as dataclass_transform, + ) + + _T = TypeVar("_T", bound=Any) _KT = TypeVar("_KT") _KT_co = TypeVar("_KT_co", covariant=True) diff --git a/pyproject.toml b/pyproject.toml index 29d59ea698..812d60e915 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,6 @@ markers = [ [tool.pyright] - reportPrivateUsage = "none" reportUnusedClass = "none" reportUnusedFunction = "none" diff --git a/setup.cfg b/setup.cfg index 5ef2c6f22c..0df41dc7bd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,7 +37,7 @@ package_dir = install_requires = importlib-metadata;python_version<"3.8" greenlet != 0.4.17;(platform_machine=='aarch64' or (platform_machine=='ppc64le' or (platform_machine=='x86_64' or (platform_machine=='amd64' or (platform_machine=='AMD64' or (platform_machine=='win32' or platform_machine=='WIN32')))))) - typing-extensions >= 4 + typing-extensions >= 4.1.0 [options.extras_require] asyncio = diff --git a/test/orm/declarative/test_dc_transforms.py b/test/orm/declarative/test_dc_transforms.py new file mode 100644 index 0000000000..aac8737232 --- /dev/null +++ b/test/orm/declarative/test_dc_transforms.py @@ -0,0 +1,816 @@ +import dataclasses +import inspect as pyinspect +from typing import Any +from typing import List +from typing import Optional +from typing import Set +from typing import Type +from unittest import mock + +from sqlalchemy import Column +from sqlalchemy import exc +from sqlalchemy import ForeignKey +from sqlalchemy import inspect +from sqlalchemy import Integer +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import testing +from sqlalchemy.orm import column_property +from sqlalchemy.orm import composite +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import deferred +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import MappedAsDataclass +from sqlalchemy.orm import MappedColumn +from sqlalchemy.orm import registry as _RegistryType +from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session +from sqlalchemy.orm import synonym +from sqlalchemy.testing import AssertsCompiledSQL +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import eq_regex +from sqlalchemy.testing import expect_raises +from sqlalchemy.testing import expect_raises_message +from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_false +from sqlalchemy.testing import is_true +from sqlalchemy.testing import ne_ + + +class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): + def test_basic_constructor_repr_base_cls( + self, dc_decl_base: Type[MappedAsDataclass] + ): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + x: Mapped[Optional[int]] = mapped_column(default=None) + + bs: Mapped[List["B"]] = relationship( # noqa: F821 + default_factory=list + ) + + class B(dc_decl_base): + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + a_id = mapped_column(ForeignKey("a.id"), init=False) + data: Mapped[str] + x: Mapped[Optional[int]] = mapped_column(default=None) + + A.__qualname__ = "some_module.A" + B.__qualname__ = "some_module.B" + + eq_( + pyinspect.getfullargspec(A.__init__), + pyinspect.FullArgSpec( + args=["self", "data", "x", "bs"], + varargs=None, + varkw=None, + defaults=(None, mock.ANY), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + eq_( + pyinspect.getfullargspec(B.__init__), + pyinspect.FullArgSpec( + args=["self", "data", "x"], + varargs=None, + varkw=None, + defaults=(None,), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + + a2 = A("10", x=5, bs=[B("data1"), B("data2", x=12)]) + eq_( + repr(a2), + "some_module.A(id=None, data='10', x=5, " + "bs=[some_module.B(id=None, data='data1', a_id=None, x=None), " + "some_module.B(id=None, data='data2', a_id=None, x=12)])", + ) + + a3 = A("data") + eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])") + + def test_basic_constructor_repr_cls_decorator( + self, registry: _RegistryType + ): + @registry.mapped_as_dataclass() + class A: + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + x: Mapped[Optional[int]] = mapped_column(default=None) + + bs: Mapped[List["B"]] = relationship( # noqa: F821 + default_factory=list + ) + + @registry.mapped_as_dataclass() + class B: + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + a_id = mapped_column(ForeignKey("a.id"), init=False) + data: Mapped[str] + x: Mapped[Optional[int]] = mapped_column(default=None) + + A.__qualname__ = "some_module.A" + B.__qualname__ = "some_module.B" + + eq_( + pyinspect.getfullargspec(A.__init__), + pyinspect.FullArgSpec( + args=["self", "data", "x", "bs"], + varargs=None, + varkw=None, + defaults=(None, mock.ANY), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + eq_( + pyinspect.getfullargspec(B.__init__), + pyinspect.FullArgSpec( + args=["self", "data", "x"], + varargs=None, + varkw=None, + defaults=(None,), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + + a2 = A("10", x=5, bs=[B("data1"), B("data2", x=12)]) + eq_( + repr(a2), + "some_module.A(id=None, data='10', x=5, " + "bs=[some_module.B(id=None, data='data1', a_id=None, x=None), " + "some_module.B(id=None, data='data2', a_id=None, x=12)])", + ) + + a3 = A("data") + eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])") + + def test_default_fn(self, dc_decl_base: Type[MappedAsDataclass]): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] = mapped_column(default="d1") + data2: Mapped[str] = mapped_column(default_factory=lambda: "d2") + + a1 = A() + eq_(a1.data, "d1") + eq_(a1.data2, "d2") + + def test_default_factory_vs_collection_class( + self, dc_decl_base: Type[MappedAsDataclass] + ): + # this is currently the error raised by dataclasses. We can instead + # do this validation ourselves, but overall I don't know that we + # can hit every validation and rule that's in dataclasses + with expect_raises_message( + ValueError, "cannot specify both default and default_factory" + ): + + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] = mapped_column( + default="d1", default_factory=lambda: "d2" + ) + + def test_inheritance(self, dc_decl_base: Type[MappedAsDataclass]): + class Person(dc_decl_base): + __tablename__ = "person" + person_id: Mapped[int] = mapped_column( + primary_key=True, init=False + ) + name: Mapped[str] + type: Mapped[str] = mapped_column(init=False) + + __mapper_args__ = {"polymorphic_on": type} + + class Engineer(Person): + __tablename__ = "engineer" + + person_id: Mapped[int] = mapped_column( + ForeignKey("person.person_id"), primary_key=True, init=False + ) + + status: Mapped[str] = mapped_column(String(30)) + engineer_name: Mapped[str] + primary_language: Mapped[str] + + e1 = Engineer("nm", "st", "en", "pl") + eq_(e1.name, "nm") + eq_(e1.status, "st") + eq_(e1.engineer_name, "en") + eq_(e1.primary_language, "pl") + + def test_integrated_dc(self, dc_decl_base: Type[MappedAsDataclass]): + """We will be telling users "this is a dataclass that is also + mapped". Therefore, they will want *any* kind of attribute to do what + it would normally do in a dataclass, including normal types without any + field and explicit use of dataclasses.field(). additionally, we'd like + ``Mapped`` to mean "persist this attribute". So the absence of + ``Mapped`` should also mean something too. + + """ + + class A(dc_decl_base): + __tablename__ = "a" + + ctrl_one: str + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + some_field: int = dataclasses.field(default=5) + + some_none_field: Optional[str] = None + + a1 = A("ctrlone", "datafield") + eq_(a1.some_field, 5) + eq_(a1.some_none_field, None) + + # only Mapped[] is mapped + self.assert_compile(select(A), "SELECT a.id, a.data FROM a") + eq_( + pyinspect.getfullargspec(A.__init__), + pyinspect.FullArgSpec( + args=[ + "self", + "ctrl_one", + "data", + "some_field", + "some_none_field", + ], + varargs=None, + varkw=None, + defaults=(5, None), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + + def test_dc_on_top_of_non_dc(self, decl_base: Type[DeclarativeBase]): + class Person(decl_base): + __tablename__ = "person" + person_id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + type: Mapped[str] = mapped_column() + + __mapper_args__ = {"polymorphic_on": type} + + class Engineer(MappedAsDataclass, Person): + __tablename__ = "engineer" + + person_id: Mapped[int] = mapped_column( + ForeignKey("person.person_id"), primary_key=True, init=False + ) + + status: Mapped[str] = mapped_column(String(30)) + engineer_name: Mapped[str] + primary_language: Mapped[str] + + e1 = Engineer("st", "en", "pl") + eq_(e1.status, "st") + eq_(e1.engineer_name, "en") + eq_(e1.primary_language, "pl") + + eq_( + pyinspect.getfullargspec(Person.__init__), + # the boring **kw __init__ + pyinspect.FullArgSpec( + args=["self"], + varargs=None, + varkw="kwargs", + defaults=None, + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + + eq_( + pyinspect.getfullargspec(Engineer.__init__), + # the exciting dataclasses __init__ + pyinspect.FullArgSpec( + args=["self", "status", "engineer_name", "primary_language"], + varargs=None, + varkw=None, + defaults=None, + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + + +class RelationshipDefaultFactoryTest(fixtures.TestBase): + def test_list(self, dc_decl_base: Type[MappedAsDataclass]): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + bs: Mapped[List["B"]] = relationship( # noqa: F821 + default_factory=lambda: [B(data="hi")] + ) + + class B(dc_decl_base): + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + a_id = mapped_column(ForeignKey("a.id"), init=False) + data: Mapped[str] + + a1 = A() + eq_(a1.bs[0].data, "hi") + + def test_set(self, dc_decl_base: Type[MappedAsDataclass]): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + bs: Mapped[Set["B"]] = relationship( # noqa: F821 + default_factory=lambda: {B(data="hi")} + ) + + class B(dc_decl_base, unsafe_hash=True): + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + a_id = mapped_column(ForeignKey("a.id"), init=False) + data: Mapped[str] + + a1 = A() + eq_(a1.bs.pop().data, "hi") + + def test_oh_no_mismatch(self, dc_decl_base: Type[MappedAsDataclass]): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + bs: Mapped[Set["B"]] = relationship( # noqa: F821 + default_factory=lambda: [B(data="hi")] + ) + + class B(dc_decl_base, unsafe_hash=True): + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + a_id = mapped_column(ForeignKey("a.id"), init=False) + data: Mapped[str] + + # old school collection mismatch error FTW + with expect_raises_message( + TypeError, "Incompatible collection type: list is not set-like" + ): + A() + + def test_replace_operation_works_w_history_etc( + self, registry: _RegistryType + ): + @registry.mapped_as_dataclass + class A: + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + x: Mapped[Optional[int]] = mapped_column(default=None) + + bs: Mapped[List["B"]] = relationship( # noqa: F821 + default_factory=list + ) + + @registry.mapped_as_dataclass + class B: + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + a_id = mapped_column(ForeignKey("a.id"), init=False) + data: Mapped[str] + x: Mapped[Optional[int]] = mapped_column(default=None) + + registry.metadata.create_all(testing.db) + + with Session(testing.db) as sess: + a1 = A("data", 10, [B("b1"), B("b2", x=5), B("b3")]) + sess.add(a1) + sess.commit() + + a2 = dataclasses.replace(a1, x=12, bs=[B("b4")]) + + assert a1 in sess + assert not sess.is_modified(a1, include_collections=True) + assert a2 not in sess + eq_(inspect(a2).attrs.x.history, ([12], (), ())) + sess.add(a2) + sess.commit() + + eq_(sess.scalars(select(A.x).order_by(A.id)).all(), [10, 12]) + eq_( + sess.scalars(select(B.data).order_by(B.id)).all(), + ["b1", "b2", "b3", "b4"], + ) + + def test_post_init(self, registry: _RegistryType): + @registry.mapped_as_dataclass + class A: + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] = mapped_column(init=False) + + def __post_init__(self): + self.data = "some data" + + a1 = A() + eq_(a1.data, "some data") + + def test_no_field_args_w_new_style(self, registry: _RegistryType): + with expect_raises_message( + exc.InvalidRequestError, + "SQLAlchemy mapped dataclasses can't consume mapping information", + ): + + @registry.mapped_as_dataclass() + class A: + __tablename__ = "a" + __sa_dataclass_metadata_key__ = "sa" + + account_id: int = dataclasses.field( + init=False, + metadata={"sa": Column(Integer, primary_key=True)}, + ) + + def test_no_field_args_w_new_style_two(self, registry: _RegistryType): + @dataclasses.dataclass + class Base: + pass + + with expect_raises_message( + exc.InvalidRequestError, + "SQLAlchemy mapped dataclasses can't consume mapping information", + ): + + @registry.mapped_as_dataclass() + class A(Base): + __tablename__ = "a" + __sa_dataclass_metadata_key__ = "sa" + + account_id: int = dataclasses.field( + init=False, + metadata={"sa": Column(Integer, primary_key=True)}, + ) + + +class DataclassArgsTest(fixtures.TestBase): + dc_arg_names = ("init", "repr", "eq", "order", "unsafe_hash") + + @testing.fixture(params=dc_arg_names) + def dc_argument_fixture(self, request: Any, registry: _RegistryType): + name = request.param + + args = {n: n == name for n in self.dc_arg_names} + if args["order"]: + args["eq"] = True + yield args + + @testing.fixture( + params=["mapped_column", "synonym", "deferred", "column_property"] + ) + def mapped_expr_constructor(self, request): + name = request.param + + if name == "mapped_column": + yield mapped_column(default=7, init=True) + elif name == "synonym": + yield synonym("some_int", default=7, init=True) + elif name == "deferred": + yield deferred(Column(Integer), default=7, init=True) + elif name == "column_property": + yield column_property(Column(Integer), default=7, init=True) + + def test_attrs_rejected_if_not_a_dc( + self, mapped_expr_constructor, decl_base: Type[DeclarativeBase] + ): + if isinstance(mapped_expr_constructor, MappedColumn): + unwanted_args = "'init'" + else: + unwanted_args = "'default', 'init'" + with expect_raises_message( + exc.ArgumentError, + r"Attribute 'x' on class .*A.* includes dataclasses " + r"argument\(s\): " + rf"{unwanted_args} but class does not specify SQLAlchemy native " + "dataclass configuration", + ): + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + + x: Mapped[int] = mapped_expr_constructor + + def _assert_cls(self, cls, dc_arguments): + + if dc_arguments["init"]: + + def create(data, x): + return cls(data, x) + + else: + + def create(data, x): + a1 = cls() + a1.data = data + a1.x = x + return a1 + + for n in self.dc_arg_names: + if dc_arguments[n]: + getattr(self, f"_assert_{n}")(cls, create, dc_arguments) + else: + getattr(self, f"_assert_not_{n}")(cls, create, dc_arguments) + + if dc_arguments["init"]: + a1 = cls("some data") + eq_(a1.x, 7) + + a1 = create("some data", 15) + some_int = a1.some_int + eq_( + dataclasses.asdict(a1), + {"data": "some data", "id": None, "some_int": some_int, "x": 15}, + ) + eq_(dataclasses.astuple(a1), (None, "some data", some_int, 15)) + + def _assert_unsafe_hash(self, cls, create, dc_arguments): + a1 = create("d1", 5) + hash(a1) + + def _assert_not_unsafe_hash(self, cls, create, dc_arguments): + a1 = create("d1", 5) + + if dc_arguments["eq"]: + with expect_raises(TypeError): + hash(a1) + else: + hash(a1) + + def _assert_eq(self, cls, create, dc_arguments): + a1 = create("d1", 5) + a2 = create("d2", 10) + a3 = create("d1", 5) + + eq_(a1, a3) + ne_(a1, a2) + + def _assert_not_eq(self, cls, create, dc_arguments): + a1 = create("d1", 5) + a2 = create("d2", 10) + a3 = create("d1", 5) + + eq_(a1, a1) + ne_(a1, a3) + ne_(a1, a2) + + def _assert_order(self, cls, create, dc_arguments): + is_false(create("g", 10) < create("b", 7)) + + is_true(create("g", 10) > create("b", 7)) + + is_false(create("g", 10) <= create("b", 7)) + + is_true(create("g", 10) >= create("b", 7)) + + eq_( + list(sorted([create("g", 10), create("g", 5), create("b", 7)])), + [ + create("b", 7), + create("g", 5), + create("g", 10), + ], + ) + + def _assert_not_order(self, cls, create, dc_arguments): + with expect_raises(TypeError): + create("g", 10) < create("b", 7) + + with expect_raises(TypeError): + create("g", 10) > create("b", 7) + + with expect_raises(TypeError): + create("g", 10) <= create("b", 7) + + with expect_raises(TypeError): + create("g", 10) >= create("b", 7) + + def _assert_repr(self, cls, create, dc_arguments): + a1 = create("some data", 12) + eq_regex(repr(a1), r".*A\(id=None, data='some data', x=12\)") + + def _assert_not_repr(self, cls, create, dc_arguments): + a1 = create("some data", 12) + eq_regex(repr(a1), r"<.*A object at 0x.*>") + + def _assert_init(self, cls, create, dc_arguments): + a1 = cls("some data", 5) + + eq_(a1.data, "some data") + eq_(a1.x, 5) + + a2 = cls(data="some data", x=5) + eq_(a2.data, "some data") + eq_(a2.x, 5) + + a3 = cls(data="some data") + eq_(a3.data, "some data") + eq_(a3.x, 7) + + def _assert_not_init(self, cls, create, dc_arguments): + + with expect_raises(TypeError): + cls("Some data", 5) + + # we run real "dataclasses" on the class. so with init=False, it + # doesn't touch what was there, and the SQLA default constructor + # gets put on. + a1 = cls(data="some data") + eq_(a1.data, "some data") + eq_(a1.x, None) + + a1 = cls() + eq_(a1.data, None) + + # no constructor, it sets None for x...ok + eq_(a1.x, None) + + def test_dc_arguments_decorator( + self, + dc_argument_fixture, + mapped_expr_constructor, + registry: _RegistryType, + ): + @registry.mapped_as_dataclass(**dc_argument_fixture) + class A: + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + some_int: Mapped[int] = mapped_column(init=False, repr=False) + + x: Mapped[Optional[int]] = mapped_expr_constructor + + self._assert_cls(A, dc_argument_fixture) + + def test_dc_arguments_base( + self, + dc_argument_fixture, + mapped_expr_constructor, + registry: _RegistryType, + ): + reg = registry + + class Base(MappedAsDataclass, DeclarativeBase, **dc_argument_fixture): + registry = reg + + class A(Base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + some_int: Mapped[int] = mapped_column(init=False, repr=False) + + x: Mapped[Optional[int]] = mapped_expr_constructor + + self.A = A + + def test_dc_arguments_perclass( + self, + dc_argument_fixture, + mapped_expr_constructor, + decl_base: Type[DeclarativeBase], + ): + class A(MappedAsDataclass, decl_base, **dc_argument_fixture): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + some_int: Mapped[int] = mapped_column(init=False, repr=False) + + x: Mapped[Optional[int]] = mapped_expr_constructor + + self.A = A + + +class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL): + __dialect__ = "default" + + def test_composite_setup(self, dc_decl_base: Type[MappedAsDataclass]): + @dataclasses.dataclass + class Point: + x: int + y: int + + class Edge(dc_decl_base): + __tablename__ = "edge" + id: Mapped[int] = mapped_column(primary_key=True, init=False) + graph_id: Mapped[int] = mapped_column( + ForeignKey("graph.id"), init=False + ) + + start: Mapped[Point] = composite( + Point, mapped_column("x1"), mapped_column("y1"), default=None + ) + + end: Mapped[Point] = composite( + Point, mapped_column("x2"), mapped_column("y2"), default=None + ) + + class Graph(dc_decl_base): + __tablename__ = "graph" + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + edges: Mapped[List[Edge]] = relationship() + + Point.__qualname__ = "mymodel.Point" + Edge.__qualname__ = "mymodel.Edge" + Graph.__qualname__ = "mymodel.Graph" + g = Graph( + edges=[ + Edge(start=Point(1, 2), end=Point(3, 4)), + Edge(start=Point(7, 8), end=Point(5, 6)), + ] + ) + eq_( + repr(g), + "mymodel.Graph(id=None, edges=[mymodel.Edge(id=None, " + "graph_id=None, start=mymodel.Point(x=1, y=2), " + "end=mymodel.Point(x=3, y=4)), " + "mymodel.Edge(id=None, graph_id=None, " + "start=mymodel.Point(x=7, y=8), end=mymodel.Point(x=5, y=6))])", + ) + + def test_named_setup(self, dc_decl_base: Type[MappedAsDataclass]): + @dataclasses.dataclass + class Address: + street: str + state: str + zip_: str + + class User(dc_decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column( + primary_key=True, init=False, repr=False + ) + name: Mapped[str] = mapped_column() + + address: Mapped[Address] = composite( + Address, + mapped_column(), + mapped_column(), + mapped_column("zip"), + default=None, + ) + + Address.__qualname__ = "mymodule.Address" + User.__qualname__ = "mymodule.User" + u = User( + name="user 1", + address=Address("123 anywhere street", "NY", "12345"), + ) + u2 = User("u2") + eq_( + repr(u), + "mymodule.User(name='user 1', " + "address=mymodule.Address(street='123 anywhere street', " + "state='NY', zip_='12345'))", + ) + eq_(repr(u2), "mymodule.User(name='u2', address=None)") diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index d7d19821c6..8657354397 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -190,6 +190,18 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_true(User.__table__.c.data.nullable) assert isinstance(User.__table__.c.created_at.type, DateTime) + def test_column_default(self, decl_base): + class MyClass(decl_base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column(default="some default") + + mc = MyClass() + assert "data" not in mc.__dict__ + + eq_(MyClass.__table__.c.data.default.arg, "some default") + def test_anno_w_fixed_table(self, decl_base): users = Table( "users", @@ -959,7 +971,7 @@ class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL): with expect_raises_message( ArgumentError, r"Type annotation for \"User.address\" should use the syntax " - r"\"Mapped\['Address'\]\" or \"MappedColumn\['Address'\]\"", + r"\"Mapped\['Address'\]\"", ): class User(decl_base): @@ -1068,6 +1080,38 @@ class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL): # round trip! eq_(u1.address, Address("123 anywhere street", "NY", "12345")) + def test_cls_annotated_no_mapped_cols_setup(self, decl_base): + @dataclasses.dataclass + class Address: + street: str + state: str + zip_: str + + class User(decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + + address: Mapped[Address] = composite() + + decl_base.metadata.create_all(testing.db) + + with fixture_session() as sess: + sess.add( + User( + name="user 1", + address=Address("123 anywhere street", "NY", "12345"), + ) + ) + sess.commit() + + with fixture_session() as sess: + u1 = sess.scalar(select(User)) + + # round trip! + eq_(u1.address, Address("123 anywhere street", "NY", "12345")) + def test_one_col_setup(self, decl_base): @dataclasses.dataclass class Address: -- 2.47.2