]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement dataclass_transforms
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 18 Feb 2022 15:05:12 +0000 (10:05 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 20 May 2022 18:19:02 +0000 (14:19 -0400)
Implement a new means of creating a mapped dataclass where
instead of applying the `@dataclass` decorator distinctly,
the declarative process itself can create the dataclass.

MapperProperty and MappedColumn objects themselves take
the place of the dataclasses.Field object when constructing
the class.

The overall approach is made possible at the typing level
using pep-681 dataclass transforms [1].

This new approach should be able to completely supersede the
previous "dataclasses" approach of embedding metadata into
Field() objects, which remains a mutually exclusive declarative
setup style (mixing them introduces new issues that are not worth
solving).

[1] https://peps.python.org/pep-0681/#transform-descriptor-types-example

Fixes: #7642
Change-Id: I6ba88a87c5df38270317b4faf085904d91c8a63c

17 files changed:
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/_orm_constructors.py
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/orm/decl_base.py
lib/sqlalchemy/orm/descriptor_props.py
lib/sqlalchemy/orm/instrumentation.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/testing/fixtures.py
lib/sqlalchemy/util/compat.py
lib/sqlalchemy/util/typing.py
pyproject.toml
setup.cfg
test/orm/declarative/test_dc_transforms.py [new file with mode: 0644]
test/orm/declarative/test_typed_mapping.py

index b7d1df532234f9de926672b0157b1183da92b37d..4f19ba946013dcd952c83c808b48303c3026ea2a 100644 (file)
@@ -60,6 +60,7 @@ from .decl_api import DeclarativeBaseNoMeta as DeclarativeBaseNoMeta
 from .decl_api import DeclarativeMeta as DeclarativeMeta
 from .decl_api import declared_attr as declared_attr
 from .decl_api import has_inherited_table as has_inherited_table
+from .decl_api import MappedAsDataclass as MappedAsDataclass
 from .decl_api import registry as registry
 from .decl_api import synonym_for as synonym_for
 from .descriptor_props import Composite as Composite
index 0692cac09e5bc47de9c5ca6c86b0c34e87064afd..ece6a52be82807aea4846aa76b03ee26cbd77856 100644 (file)
@@ -21,9 +21,9 @@ from typing import Union
 
 from . import mapperlib as mapperlib
 from ._typing import _O
-from .base import Mapped
 from .descriptor_props import Composite
 from .descriptor_props import Synonym
+from .interfaces import _AttributeOptions
 from .properties import ColumnProperty
 from .properties import MappedColumn
 from .query import AliasOption
@@ -37,6 +37,8 @@ from .util import LoaderCriteriaOption
 from .. import sql
 from .. import util
 from ..exc import InvalidRequestError
+from ..sql._typing import _no_kw
+from ..sql.base import _NoArg
 from ..sql.base import SchemaEventTarget
 from ..sql.schema import SchemaConst
 from ..sql.selectable import FromClause
@@ -105,6 +107,10 @@ def mapped_column(
         Union[_TypeEngineArgument[Any], SchemaEventTarget]
     ] = None,
     *args: SchemaEventTarget,
+    init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+    repr: Union[_NoArg, bool] = _NoArg.NO_ARG,  # noqa: A002
+    default: Optional[Any] = _NoArg.NO_ARG,
+    default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
     nullable: Optional[
         Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
     ] = SchemaConst.NULL_UNSPECIFIED,
@@ -113,7 +119,6 @@ def mapped_column(
     name: Optional[str] = None,
     type_: Optional[_TypeEngineArgument[Any]] = None,
     autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
-    default: Optional[Any] = None,
     doc: Optional[str] = None,
     key: Optional[str] = None,
     index: Optional[bool] = None,
@@ -300,6 +305,12 @@ def mapped_column(
         type_=type_,
         autoincrement=autoincrement,
         default=default,
+        attribute_options=_AttributeOptions(
+            init,
+            repr,
+            default,
+            default_factory,
+        ),
         doc=doc,
         key=key,
         index=index,
@@ -325,6 +336,10 @@ def column_property(
     deferred: bool = False,
     raiseload: bool = False,
     comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+    init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+    repr: Union[_NoArg, bool] = _NoArg.NO_ARG,  # noqa: A002
+    default: Optional[Any] = _NoArg.NO_ARG,
+    default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
     active_history: bool = False,
     expire_on_flush: bool = True,
     info: Optional[_InfoType] = None,
@@ -416,6 +431,12 @@ def column_property(
     return ColumnProperty(
         column,
         *additional_columns,
+        attribute_options=_AttributeOptions(
+            init,
+            repr,
+            default,
+            default_factory,
+        ),
         group=group,
         deferred=deferred,
         raiseload=raiseload,
@@ -429,25 +450,61 @@ def column_property(
 
 @overload
 def composite(
-    class_: Type[_CC],
+    _class_or_attr: Type[_CC],
     *attrs: _CompositeAttrType[Any],
-    **kwargs: Any,
+    group: Optional[str] = None,
+    deferred: bool = False,
+    raiseload: bool = False,
+    comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None,
+    active_history: bool = False,
+    init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+    repr: Union[_NoArg, bool] = _NoArg.NO_ARG,  # noqa: A002
+    default: Optional[Any] = _NoArg.NO_ARG,
+    default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+    info: Optional[_InfoType] = None,
+    doc: Optional[str] = None,
+    **__kw: Any,
 ) -> Composite[_CC]:
     ...
 
 
 @overload
 def composite(
+    _class_or_attr: _CompositeAttrType[Any],
     *attrs: _CompositeAttrType[Any],
-    **kwargs: Any,
+    group: Optional[str] = None,
+    deferred: bool = False,
+    raiseload: bool = False,
+    comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None,
+    active_history: bool = False,
+    init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+    repr: Union[_NoArg, bool] = _NoArg.NO_ARG,  # noqa: A002
+    default: Optional[Any] = _NoArg.NO_ARG,
+    default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+    info: Optional[_InfoType] = None,
+    doc: Optional[str] = None,
+    **__kw: Any,
 ) -> Composite[Any]:
     ...
 
 
 def composite(
-    class_: Any = None,
+    _class_or_attr: Union[
+        None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any]
+    ] = None,
     *attrs: _CompositeAttrType[Any],
-    **kwargs: Any,
+    group: Optional[str] = None,
+    deferred: bool = False,
+    raiseload: bool = False,
+    comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None,
+    active_history: bool = False,
+    init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+    repr: Union[_NoArg, bool] = _NoArg.NO_ARG,  # noqa: A002
+    default: Optional[Any] = _NoArg.NO_ARG,
+    default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+    info: Optional[_InfoType] = None,
+    doc: Optional[str] = None,
+    **__kw: Any,
 ) -> Composite[Any]:
     r"""Return a composite column-based property for use with a Mapper.
 
@@ -497,7 +554,26 @@ def composite(
         :attr:`.MapperProperty.info` attribute of this object.
 
     """
-    return Composite(class_, *attrs, **kwargs)
+    if __kw:
+        raise _no_kw()
+
+    return Composite(
+        _class_or_attr,
+        *attrs,
+        attribute_options=_AttributeOptions(
+            init,
+            repr,
+            default,
+            default_factory,
+        ),
+        group=group,
+        deferred=deferred,
+        raiseload=raiseload,
+        comparator_factory=comparator_factory,
+        active_history=active_history,
+        info=info,
+        doc=doc,
+    )
 
 
 def with_loader_criteria(
@@ -700,6 +776,10 @@ def relationship(
     post_update: bool = False,
     cascade: str = "save-update, merge",
     viewonly: bool = False,
+    init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+    repr: Union[_NoArg, bool] = _NoArg.NO_ARG,  # noqa: A002
+    default: Union[_NoArg, _T] = _NoArg.NO_ARG,
+    default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
     lazy: _LazyLoadArgumentType = "select",
     passive_deletes: Union[Literal["all"], bool] = False,
     passive_updates: bool = True,
@@ -1532,6 +1612,12 @@ def relationship(
         post_update=post_update,
         cascade=cascade,
         viewonly=viewonly,
+        attribute_options=_AttributeOptions(
+            init,
+            repr,
+            default,
+            default_factory,
+        ),
         lazy=lazy,
         passive_deletes=passive_deletes,
         passive_updates=passive_updates,
@@ -1559,6 +1645,10 @@ def synonym(
     map_column: Optional[bool] = None,
     descriptor: Optional[Any] = None,
     comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+    init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+    repr: Union[_NoArg, bool] = _NoArg.NO_ARG,  # noqa: A002
+    default: Union[_NoArg, _T] = _NoArg.NO_ARG,
+    default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
     info: Optional[_InfoType] = None,
     doc: Optional[str] = None,
 ) -> Synonym[Any]:
@@ -1670,6 +1760,12 @@ def synonym(
         map_column=map_column,
         descriptor=descriptor,
         comparator_factory=comparator_factory,
+        attribute_options=_AttributeOptions(
+            init,
+            repr,
+            default,
+            default_factory,
+        ),
         doc=doc,
         info=info,
     )
@@ -1784,7 +1880,17 @@ def backref(name: str, **kwargs: Any) -> _ORMBackrefArgument:
 def deferred(
     column: _ORMColumnExprArgument[_T],
     *additional_columns: _ORMColumnExprArgument[Any],
-    **kw: Any,
+    group: Optional[str] = None,
+    raiseload: bool = False,
+    comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+    init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+    repr: Union[_NoArg, bool] = _NoArg.NO_ARG,  # noqa: A002
+    default: Optional[Any] = _NoArg.NO_ARG,
+    default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+    active_history: bool = False,
+    expire_on_flush: bool = True,
+    info: Optional[_InfoType] = None,
+    doc: Optional[str] = None,
 ) -> ColumnProperty[_T]:
     r"""Indicate a column-based mapped attribute that by default will
     not load unless accessed.
@@ -1803,21 +1909,41 @@ def deferred(
 
         :ref:`deferred_raiseload`
 
-    :param \**kw: additional keyword arguments passed to
-     :class:`.ColumnProperty`.
+    Additional arguments are the same as that of :func:`_orm.column_property`.
 
     .. seealso::
 
         :ref:`deferred`
 
     """
-    kw["deferred"] = True
-    return ColumnProperty(column, *additional_columns, **kw)
+    return ColumnProperty(
+        column,
+        *additional_columns,
+        attribute_options=_AttributeOptions(
+            init,
+            repr,
+            default,
+            default_factory,
+        ),
+        group=group,
+        deferred=True,
+        raiseload=raiseload,
+        comparator_factory=comparator_factory,
+        active_history=active_history,
+        expire_on_flush=expire_on_flush,
+        info=info,
+        doc=doc,
+    )
 
 
 def query_expression(
     default_expr: _ORMColumnExprArgument[_T] = sql.null(),
-) -> Mapped[_T]:
+    *,
+    repr: Union[_NoArg, bool] = _NoArg.NO_ARG,  # noqa: A002
+    expire_on_flush: bool = True,
+    info: Optional[_InfoType] = None,
+    doc: Optional[str] = None,
+) -> ColumnProperty[_T]:
     """Indicate an attribute that populates from a query-time SQL expression.
 
     :param default_expr: Optional SQL expression object that will be used in
@@ -1840,7 +1966,18 @@ def query_expression(
         :ref:`mapper_querytime_expression`
 
     """
-    prop = ColumnProperty(default_expr)
+    prop = ColumnProperty(
+        default_expr,
+        attribute_options=_AttributeOptions(
+            _NoArg.NO_ARG,
+            repr,
+            _NoArg.NO_ARG,
+            _NoArg.NO_ARG,
+        ),
+        expire_on_flush=expire_on_flush,
+        info=info,
+        doc=doc,
+    )
     prop.strategy_key = (("query_expression", True),)
     return prop
 
index 1c343b04ce794777006cd08b0bf25a4d87e075c3..553a50107f950ca8b7a09b5f5176cfd270b15668 100644 (file)
@@ -33,6 +33,13 @@ from . import clsregistry
 from . import instrumentation
 from . import interfaces
 from . import mapperlib
+from ._orm_constructors import column_property
+from ._orm_constructors import composite
+from ._orm_constructors import deferred
+from ._orm_constructors import mapped_column
+from ._orm_constructors import query_expression
+from ._orm_constructors import relationship
+from ._orm_constructors import synonym
 from .attributes import InstrumentedAttribute
 from .base import _inspect_mapped_class
 from .base import Mapped
@@ -42,8 +49,13 @@ from .decl_base import _declarative_constructor
 from .decl_base import _DeferredMapperConfig
 from .decl_base import _del_attribute
 from .decl_base import _mapper
+from .descriptor_props import Composite
+from .descriptor_props import Synonym
 from .descriptor_props import Synonym as _orm_synonym
 from .mapper import Mapper
+from .properties import ColumnProperty
+from .properties import MappedColumn
+from .relationships import Relationship
 from .state import InstanceState
 from .. import exc
 from .. import inspection
@@ -60,9 +72,9 @@ from ..util.typing import Literal
 if TYPE_CHECKING:
     from ._typing import _O
     from ._typing import _RegistryType
-    from .descriptor_props import Synonym
     from .instrumentation import ClassManager
     from .interfaces import MapperProperty
+    from .state import InstanceState  # noqa
     from ..sql._typing import _TypeEngineArgument
 
 _T = TypeVar("_T", bound=Any)
@@ -120,6 +132,26 @@ class DeclarativeAttributeIntercept(
     """
 
 
+@compat_typing.dataclass_transform(
+    field_descriptors=(
+        MappedColumn[Any],
+        Relationship[Any],
+        Composite[Any],
+        ColumnProperty[Any],
+        Synonym[Any],
+        mapped_column,
+        relationship,
+        composite,
+        column_property,
+        synonym,
+        deferred,
+        query_expression,
+    ),
+)
+class DCTransformDeclarative(DeclarativeAttributeIntercept):
+    """metaclass that includes @dataclass_transforms"""
+
+
 class DeclarativeMeta(
     _DynamicAttributesType, inspection.Inspectable[Mapper[Any]]
 ):
@@ -543,12 +575,42 @@ class DeclarativeBaseNoMeta(inspection.Inspectable[Mapper[Any]]):
             cls._sa_registry.map_declaratively(cls)
 
 
+class MappedAsDataclass(metaclass=DCTransformDeclarative):
+    """Mixin class to indicate when mapping this class, also convert it to be
+    a dataclass.
+
+    .. seealso::
+
+        :meth:`_orm.registry.mapped_as_dataclass`
+
+    .. versionadded:: 2.0
+    """
+
+    def __init_subclass__(
+        cls,
+        init: bool = True,
+        repr: bool = True,  # noqa: A002
+        eq: bool = True,
+        order: bool = False,
+        unsafe_hash: bool = False,
+    ) -> None:
+        cls._sa_apply_dc_transforms = {
+            "init": init,
+            "repr": repr,
+            "eq": eq,
+            "order": order,
+            "unsafe_hash": unsafe_hash,
+        }
+        super().__init_subclass__()
+
+
 class DeclarativeBase(
     inspection.Inspectable[InstanceState[Any]],
     metaclass=DeclarativeAttributeIntercept,
 ):
     """Base class used for declarative class definitions.
 
+
     The :class:`_orm.DeclarativeBase` allows for the creation of new
     declarative bases in such a way that is compatible with type checkers::
 
@@ -1121,7 +1183,7 @@ class registry:
 
         bases = not isinstance(cls, tuple) and (cls,) or cls
 
-        class_dict = dict(registry=self, metadata=metadata)
+        class_dict: Dict[str, Any] = dict(registry=self, metadata=metadata)
         if isinstance(cls, type):
             class_dict["__doc__"] = cls.__doc__
 
@@ -1142,6 +1204,78 @@ class registry:
 
         return metaclass(name, bases, class_dict)
 
+    @compat_typing.dataclass_transform(
+        field_descriptors=(
+            MappedColumn[Any],
+            Relationship[Any],
+            Composite[Any],
+            ColumnProperty[Any],
+            Synonym[Any],
+            mapped_column,
+            relationship,
+            composite,
+            column_property,
+            synonym,
+            deferred,
+            query_expression,
+        ),
+    )
+    @overload
+    def mapped_as_dataclass(self, __cls: Type[_O]) -> Type[_O]:
+        ...
+
+    @overload
+    def mapped_as_dataclass(
+        self,
+        __cls: Literal[None] = ...,
+        *,
+        init: bool = True,
+        repr: bool = True,  # noqa: A002
+        eq: bool = True,
+        order: bool = False,
+        unsafe_hash: bool = False,
+    ) -> Callable[[Type[_O]], Type[_O]]:
+        ...
+
+    def mapped_as_dataclass(
+        self,
+        __cls: Optional[Type[_O]] = None,
+        *,
+        init: bool = True,
+        repr: bool = True,  # noqa: A002
+        eq: bool = True,
+        order: bool = False,
+        unsafe_hash: bool = False,
+    ) -> Union[Type[_O], Callable[[Type[_O]], Type[_O]]]:
+        """Class decorator that will apply the Declarative mapping process
+        to a given class, and additionally convert the class to be a
+        Python dataclass.
+
+        .. seealso::
+
+            :meth:`_orm.registry.mapped`
+
+        .. versionadded:: 2.0
+
+
+        """
+
+        def decorate(cls: Type[_O]) -> Type[_O]:
+            cls._sa_apply_dc_transforms = {
+                "init": init,
+                "repr": repr,
+                "eq": eq,
+                "order": order,
+                "unsafe_hash": unsafe_hash,
+            }
+            _as_declarative(self, cls, cls.__dict__)
+            return cls
+
+        if __cls:
+            return decorate(__cls)
+        else:
+            return decorate
+
     def mapped(self, cls: Type[_O]) -> Type[_O]:
         """Class decorator that will apply the Declarative mapping process
         to a given class.
@@ -1174,6 +1308,10 @@ class registry:
             that will apply Declarative mapping to subclasses automatically
             using a Python metaclass.
 
+        .. seealso::
+
+            :meth:`_orm.registry.mapped_as_dataclass`
+
         """
         _as_declarative(self, cls, cls.__dict__)
         return cls
index a66421e2250c599c76660aacb8a315f1977ccafb..54a272f86e9920c79668fc17b784d74fe3c5247c 100644 (file)
@@ -10,6 +10,8 @@
 from __future__ import annotations
 
 import collections
+import dataclasses
+import re
 from typing import Any
 from typing import Callable
 from typing import cast
@@ -40,6 +42,7 @@ from .base import _is_mapped_class
 from .base import InspectionAttr
 from .descriptor_props import Composite
 from .descriptor_props import Synonym
+from .interfaces import _AttributeOptions
 from .interfaces import _IntrospectsAnnotations
 from .interfaces import _MappedAttribute
 from .interfaces import _MapsColumns
@@ -48,15 +51,18 @@ from .mapper import Mapper as mapper
 from .mapper import Mapper
 from .properties import ColumnProperty
 from .properties import MappedColumn
+from .util import _extract_mapped_subtype
 from .util import _is_mapped_annotation
 from .util import class_mapper
 from .. import event
 from .. import exc
 from .. import util
 from ..sql import expression
+from ..sql.base import _NoArg
 from ..sql.schema import Column
 from ..sql.schema import Table
 from ..util import topological
+from ..util.typing import _AnnotationScanType
 from ..util.typing import Protocol
 
 if TYPE_CHECKING:
@@ -392,11 +398,13 @@ class _ClassScanMapperConfig(_MapperConfig):
         "mapper_args",
         "mapper_args_fn",
         "inherits",
+        "allow_dataclass_fields",
+        "dataclass_setup_arguments",
     )
 
     registry: _RegistryType
     clsdict_view: _ClassDict
-    collected_annotations: Dict[str, Tuple[Any, bool]]
+    collected_annotations: Dict[str, Tuple[Any, Any, bool]]
     collected_attributes: Dict[str, Any]
     local_table: Optional[FromClause]
     persist_selectable: Optional[FromClause]
@@ -411,6 +419,17 @@ class _ClassScanMapperConfig(_MapperConfig):
     mapper_args_fn: Optional[Callable[[], Dict[str, Any]]]
     inherits: Optional[Type[Any]]
 
+    dataclass_setup_arguments: Optional[Dict[str, Any]]
+    """if the class has SQLAlchemy native dataclass parameters, where
+    we will create a SQLAlchemy dataclass (not a real dataclass).
+
+    """
+
+    allow_dataclass_fields: bool
+    """if true, look for dataclass-processed Field objects on the target
+    class as well as superclasses and extract ORM mapping directives from
+    the "metadata" attribute of each Field"""
+
     def __init__(
         self,
         registry: _RegistryType,
@@ -434,10 +453,37 @@ class _ClassScanMapperConfig(_MapperConfig):
         self.declared_columns = util.OrderedSet()
         self.column_copies = {}
 
+        self.dataclass_setup_arguments = dca = getattr(
+            self.cls, "_sa_apply_dc_transforms", None
+        )
+
+        cld = dataclasses.is_dataclass(cls_)
+
+        sdk = _get_immediate_cls_attr(cls_, "__sa_dataclass_metadata_key__")
+
+        # we don't want to consume Field objects from a not-already-dataclass.
+        # the Field objects won't have their "name" or "type" populated,
+        # and while it seems like we could just set these on Field as we
+        # read them, Field is documented as "user read only" and we need to
+        # stay far away from any off-label use of dataclasses APIs.
+        if (not cld or dca) and sdk:
+            raise exc.InvalidRequestError(
+                "SQLAlchemy mapped dataclasses can't consume mapping "
+                "information from dataclass.Field() objects if the immediate "
+                "class is not already a dataclass."
+            )
+
+        # if already a dataclass, and __sa_dataclass_metadata_key__ present,
+        # then also look inside of dataclass.Field() objects yielded by
+        # dataclasses.get_fields(cls) when scanning for attributes
+        self.allow_dataclass_fields = bool(sdk and cld)
+
         self._setup_declared_events()
 
         self._scan_attributes()
 
+        self._setup_dataclasses_transforms()
+
         with mapperlib._CONFIGURE_MUTEX:
             clsregistry.add_class(
                 self.classname, self.cls, registry._class_registry
@@ -477,11 +523,15 @@ class _ClassScanMapperConfig(_MapperConfig):
         attribute, taking SQLAlchemy-enabled dataclass fields into account.
 
         """
-        sa_dataclass_metadata_key = _get_immediate_cls_attr(
-            cls, "__sa_dataclass_metadata_key__"
-        )
 
-        if sa_dataclass_metadata_key is None:
+        if self.allow_dataclass_fields:
+            sa_dataclass_metadata_key = _get_immediate_cls_attr(
+                cls, "__sa_dataclass_metadata_key__"
+            )
+        else:
+            sa_dataclass_metadata_key = None
+
+        if not sa_dataclass_metadata_key:
 
             def attribute_is_overridden(key: str, obj: Any) -> bool:
                 return getattr(cls, key) is not obj
@@ -551,6 +601,7 @@ class _ClassScanMapperConfig(_MapperConfig):
             "__dict__",
             "__weakref__",
             "_sa_class_manager",
+            "_sa_apply_dc_transforms",
             "__dict__",
             "__weakref__",
         ]
@@ -563,10 +614,6 @@ class _ClassScanMapperConfig(_MapperConfig):
         adjusting for SQLAlchemy fields embedded in dataclass fields.
 
         """
-        sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr(
-            cls, "__sa_dataclass_metadata_key__"
-        )
-
         cls_annotations = util.get_annotations(cls)
 
         cls_vars = vars(cls)
@@ -576,7 +623,15 @@ class _ClassScanMapperConfig(_MapperConfig):
         names = util.merge_lists_w_ordering(
             [n for n in cls_vars if n not in skip], list(cls_annotations)
         )
-        if sa_dataclass_metadata_key is None:
+
+        if self.allow_dataclass_fields:
+            sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr(
+                cls, "__sa_dataclass_metadata_key__"
+            )
+        else:
+            sa_dataclass_metadata_key = None
+
+        if not sa_dataclass_metadata_key:
 
             def local_attributes_for_class() -> Iterable[
                 Tuple[str, Any, Any, bool]
@@ -652,45 +707,51 @@ class _ClassScanMapperConfig(_MapperConfig):
                 name,
                 obj,
                 annotation,
-                is_dataclass,
+                is_dataclass_field,
             ) in local_attributes_for_class():
-                if name == "__mapper_args__":
-                    check_decl = _check_declared_props_nocascade(
-                        obj, name, cls
-                    )
-                    if not mapper_args_fn and (not class_mapped or check_decl):
-                        # don't even invoke __mapper_args__ until
-                        # after we've determined everything about the
-                        # mapped table.
-                        # make a copy of it so a class-level dictionary
-                        # is not overwritten when we update column-based
-                        # arguments.
-                        def _mapper_args_fn() -> Dict[str, Any]:
-                            return dict(cls_as_Decl.__mapper_args__)
-
-                        mapper_args_fn = _mapper_args_fn
-
-                elif name == "__tablename__":
-                    check_decl = _check_declared_props_nocascade(
-                        obj, name, cls
-                    )
-                    if not tablename and (not class_mapped or check_decl):
-                        tablename = cls_as_Decl.__tablename__
-                elif name == "__table_args__":
-                    check_decl = _check_declared_props_nocascade(
-                        obj, name, cls
-                    )
-                    if not table_args and (not class_mapped or check_decl):
-                        table_args = cls_as_Decl.__table_args__
-                        if not isinstance(
-                            table_args, (tuple, dict, type(None))
+                if re.match(r"^__.+__$", name):
+                    if name == "__mapper_args__":
+                        check_decl = _check_declared_props_nocascade(
+                            obj, name, cls
+                        )
+                        if not mapper_args_fn and (
+                            not class_mapped or check_decl
                         ):
-                            raise exc.ArgumentError(
-                                "__table_args__ value must be a tuple, "
-                                "dict, or None"
-                            )
-                        if base is not cls:
-                            inherited_table_args = True
+                            # don't even invoke __mapper_args__ until
+                            # after we've determined everything about the
+                            # mapped table.
+                            # make a copy of it so a class-level dictionary
+                            # is not overwritten when we update column-based
+                            # arguments.
+                            def _mapper_args_fn() -> Dict[str, Any]:
+                                return dict(cls_as_Decl.__mapper_args__)
+
+                            mapper_args_fn = _mapper_args_fn
+
+                    elif name == "__tablename__":
+                        check_decl = _check_declared_props_nocascade(
+                            obj, name, cls
+                        )
+                        if not tablename and (not class_mapped or check_decl):
+                            tablename = cls_as_Decl.__tablename__
+                    elif name == "__table_args__":
+                        check_decl = _check_declared_props_nocascade(
+                            obj, name, cls
+                        )
+                        if not table_args and (not class_mapped or check_decl):
+                            table_args = cls_as_Decl.__table_args__
+                            if not isinstance(
+                                table_args, (tuple, dict, type(None))
+                            ):
+                                raise exc.ArgumentError(
+                                    "__table_args__ value must be a tuple, "
+                                    "dict, or None"
+                                )
+                            if base is not cls:
+                                inherited_table_args = True
+                    else:
+                        # skip all other dunder names
+                        continue
                 elif class_mapped:
                     if _is_declarative_props(obj):
                         util.warn(
@@ -706,9 +767,8 @@ class _ClassScanMapperConfig(_MapperConfig):
                     # acting like that for now.
 
                     if isinstance(obj, (Column, MappedColumn)):
-                        self.collected_annotations[name] = (
-                            annotation,
-                            False,
+                        self._collect_annotation(
+                            name, annotation, is_dataclass_field, True, obj
                         )
                         # already copied columns to the mapped class.
                         continue
@@ -745,7 +805,7 @@ class _ClassScanMapperConfig(_MapperConfig):
                             ] = ret = obj.__get__(obj, cls)
                             setattr(cls, name, ret)
                         else:
-                            if is_dataclass:
+                            if is_dataclass_field:
                                 # access attribute using normal class access
                                 # first, to see if it's been mapped on a
                                 # superclass.   note if the dataclasses.field()
@@ -789,14 +849,16 @@ class _ClassScanMapperConfig(_MapperConfig):
                         ):
                             ret.doc = obj.__doc__
 
-                        self.collected_annotations[name] = (
+                        self._collect_annotation(
+                            name,
                             obj._collect_return_annotation(),
                             False,
+                            True,
+                            obj,
                         )
                     elif _is_mapped_annotation(annotation, cls):
-                        self.collected_annotations[name] = (
-                            annotation,
-                            is_dataclass,
+                        self._collect_annotation(
+                            name, annotation, is_dataclass_field, True, obj
                         )
                         if obj is None:
                             if not fixed_table:
@@ -809,7 +871,7 @@ class _ClassScanMapperConfig(_MapperConfig):
                         # declarative mapping.  however, check for some
                         # more common mistakes
                         self._warn_for_decl_attributes(base, name, obj)
-                elif is_dataclass and (
+                elif is_dataclass_field and (
                     name not in clsdict_view or clsdict_view[name] is not obj
                 ):
                     # here, we are definitely looking at the target class
@@ -826,14 +888,12 @@ class _ClassScanMapperConfig(_MapperConfig):
                         obj = obj.fget()
 
                     collected_attributes[name] = obj
-                    self.collected_annotations[name] = (
-                        annotation,
-                        True,
+                    self._collect_annotation(
+                        name, annotation, True, False, obj
                     )
                 else:
-                    self.collected_annotations[name] = (
-                        annotation,
-                        False,
+                    self._collect_annotation(
+                        name, annotation, False, None, obj
                     )
                     if (
                         obj is None
@@ -843,6 +903,10 @@ class _ClassScanMapperConfig(_MapperConfig):
                         collected_attributes[name] = MappedColumn()
                     elif name in clsdict_view:
                         collected_attributes[name] = obj
+                    # else if the name is not in the cls.__dict__,
+                    # don't collect it as an attribute.
+                    # we will see the annotation only, which is meaningful
+                    # both for mapping and dataclasses setup
 
         if inherited_table_args and not tablename:
             table_args = None
@@ -851,6 +915,77 @@ class _ClassScanMapperConfig(_MapperConfig):
         self.tablename = tablename
         self.mapper_args_fn = mapper_args_fn
 
+    def _setup_dataclasses_transforms(self) -> None:
+
+        dataclass_setup_arguments = self.dataclass_setup_arguments
+        if not dataclass_setup_arguments:
+            return
+
+        manager = instrumentation.manager_of_class(self.cls)
+        assert manager is not None
+
+        field_list = [
+            _AttributeOptions._get_arguments_for_make_dataclass(
+                key,
+                anno,
+                self.collected_attributes.get(key, _NoArg.NO_ARG),
+            )
+            for key, anno in (
+                (key, mapped_anno if mapped_anno else raw_anno)
+                for key, (
+                    raw_anno,
+                    mapped_anno,
+                    is_dc,
+                ) in self.collected_annotations.items()
+            )
+        ]
+
+        annotations = {}
+        defaults = {}
+        for item in field_list:
+            if len(item) == 2:
+                name, tp = item  # type: ignore
+            elif len(item) == 3:
+                name, tp, spec = item  # type: ignore
+                defaults[name] = spec
+            else:
+                assert False
+            annotations[name] = tp
+
+        for k, v in defaults.items():
+            setattr(self.cls, k, v)
+        self.cls.__annotations__ = annotations
+
+        dataclasses.dataclass(self.cls, **dataclass_setup_arguments)
+
+    def _collect_annotation(
+        self,
+        name: str,
+        raw_annotation: _AnnotationScanType,
+        is_dataclass: bool,
+        expect_mapped: Optional[bool],
+        attr_value: Any,
+    ) -> None:
+
+        if expect_mapped is None:
+            expect_mapped = isinstance(attr_value, _MappedAttribute)
+
+        extracted_mapped_annotation = _extract_mapped_subtype(
+            raw_annotation,
+            self.cls,
+            name,
+            type(attr_value),
+            required=False,
+            is_dataclass_field=False,
+            expect_mapped=expect_mapped and not self.allow_dataclass_fields,
+        )
+
+        self.collected_annotations[name] = (
+            raw_annotation,
+            extracted_mapped_annotation,
+            is_dataclass,
+        )
+
     def _warn_for_decl_attributes(
         self, cls: Type[Any], key: str, c: Any
     ) -> None:
@@ -982,13 +1117,53 @@ class _ClassScanMapperConfig(_MapperConfig):
                 _undefer_column_name(
                     k, self.column_copies.get(value, value)  # type: ignore
                 )
-            elif isinstance(value, _IntrospectsAnnotations):
-                annotation, is_dataclass = self.collected_annotations.get(
-                    k, (None, False)
-                )
-                value.declarative_scan(
-                    self.registry, cls, k, annotation, is_dataclass
-                )
+            else:
+                if isinstance(value, _IntrospectsAnnotations):
+                    (
+                        annotation,
+                        extracted_mapped_annotation,
+                        is_dataclass,
+                    ) = self.collected_annotations.get(k, (None, None, False))
+                    value.declarative_scan(
+                        self.registry,
+                        cls,
+                        k,
+                        annotation,
+                        extracted_mapped_annotation,
+                        is_dataclass,
+                    )
+
+                if (
+                    isinstance(value, (MapperProperty, _MapsColumns))
+                    and value._has_dataclass_arguments
+                    and not self.dataclass_setup_arguments
+                ):
+                    if isinstance(value, MapperProperty):
+                        argnames = [
+                            "init",
+                            "default_factory",
+                            "repr",
+                            "default",
+                        ]
+                    else:
+                        argnames = ["init", "default_factory", "repr"]
+
+                    args = {
+                        a
+                        for a in argnames
+                        if getattr(
+                            value._attribute_options, f"dataclasses_{a}"
+                        )
+                        is not _NoArg.NO_ARG
+                    }
+                    raise exc.ArgumentError(
+                        f"Attribute '{k}' on class {cls} includes dataclasses "
+                        f"argument(s): "
+                        f"{', '.join(sorted(repr(a) for a in args))} but "
+                        f"class does not specify "
+                        "SQLAlchemy native dataclass configuration."
+                    )
+
             our_stuff[k] = value
 
     def _extract_declared_columns(self) -> None:
@@ -997,6 +1172,7 @@ class _ClassScanMapperConfig(_MapperConfig):
         # extract columns from the class dict
         declared_columns = self.declared_columns
         name_to_prop_key = collections.defaultdict(set)
+
         for key, c in list(our_stuff.items()):
             if isinstance(c, _MapsColumns):
 
@@ -1019,7 +1195,6 @@ class _ClassScanMapperConfig(_MapperConfig):
                     # otherwise, Mapper will map it under the column key.
                     if mp_to_assign is None and key != col.key:
                         our_stuff[key] = col
-
             elif isinstance(c, Column):
                 # undefer previously occurred here, and now occurs earlier.
                 # ensure every column we get here has been named
index 8c89f96aa950b7cde53ef4bdf9a84dc711d55fa3..a366a9534f4c87d7265436b528d71b4eaf1b8b1c 100644 (file)
@@ -35,11 +35,11 @@ from .base import LoaderCallableStatus
 from .base import Mapped
 from .base import PassiveFlag
 from .base import SQLORMOperations
+from .interfaces import _AttributeOptions
 from .interfaces import _IntrospectsAnnotations
 from .interfaces import _MapsColumns
 from .interfaces import MapperProperty
 from .interfaces import PropComparator
-from .util import _extract_mapped_subtype
 from .util import _none_set
 from .. import event
 from .. import exc as sa_exc
@@ -200,24 +200,26 @@ class Composite(
 
     def __init__(
         self,
-        class_: Union[
+        _class_or_attr: Union[
             None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any]
         ] = None,
         *attrs: _CompositeAttrType[Any],
+        attribute_options: Optional[_AttributeOptions] = None,
         active_history: bool = False,
         deferred: bool = False,
         group: Optional[str] = None,
         comparator_factory: Optional[Type[Comparator[_CC]]] = None,
         info: Optional[_InfoType] = None,
+        **kwargs: Any,
     ):
-        super().__init__()
+        super().__init__(attribute_options=attribute_options)
 
-        if isinstance(class_, (Mapped, str, sql.ColumnElement)):
-            self.attrs = (class_,) + attrs
+        if isinstance(_class_or_attr, (Mapped, str, sql.ColumnElement)):
+            self.attrs = (_class_or_attr,) + attrs
             # will initialize within declarative_scan
             self.composite_class = None  # type: ignore
         else:
-            self.composite_class = class_  # type: ignore
+            self.composite_class = _class_or_attr  # type: ignore
             self.attrs = attrs
 
         self.active_history = active_history
@@ -332,19 +334,15 @@ class Composite(
         cls: Type[Any],
         key: str,
         annotation: Optional[_AnnotationScanType],
+        extracted_mapped_annotation: Optional[_AnnotationScanType],
         is_dataclass_field: bool,
     ) -> None:
-        MappedColumn = util.preloaded.orm_properties.MappedColumn
-
-        argument = _extract_mapped_subtype(
-            annotation,
-            cls,
-            key,
-            MappedColumn,
-            self.composite_class is None,
-            is_dataclass_field,
-        )
-
+        if (
+            self.composite_class is None
+            and extracted_mapped_annotation is None
+        ):
+            self._raise_for_required(key, cls)
+        argument = extracted_mapped_annotation
         if argument and self.composite_class is None:
             if isinstance(argument, str) or hasattr(
                 argument, "__forward_arg__"
@@ -371,11 +369,18 @@ class Composite(
         for param, attr in itertools.zip_longest(
             insp.parameters.values(), self.attrs
         ):
-            if param is None or attr is None:
+            if param is None:
                 raise sa_exc.ArgumentError(
-                    f"number of arguments to {self.composite_class.__name__} "
-                    f"class and number of attributes don't match"
+                    f"number of composite attributes "
+                    f"{len(self.attrs)} exceeds "
+                    f"that of the number of attributes in class "
+                    f"{self.composite_class.__name__} {len(insp.parameters)}"
                 )
+            if attr is None:
+                # fill in missing attr spots with empty MappedColumn
+                attr = MappedColumn()
+                self.attrs += (attr,)
+
             if isinstance(attr, MappedColumn):
                 attr.declarative_scan_for_composite(
                     registry, cls, key, param.name, param.annotation
@@ -800,10 +805,11 @@ class Synonym(DescriptorProperty[_T]):
         map_column: Optional[bool] = None,
         descriptor: Optional[Any] = None,
         comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+        attribute_options: Optional[_AttributeOptions] = None,
         info: Optional[_InfoType] = None,
         doc: Optional[str] = None,
     ):
-        super().__init__()
+        super().__init__(attribute_options=attribute_options)
 
         self.name = name
         self.map_column = map_column
index 4fa61b7ceef56cb8451b99c3d42419ca799b0314..33de2aee9084007808162c587c11e8b73973ad82 100644 (file)
@@ -113,6 +113,7 @@ class ClassManager(
     "previously known as deferred_scalar_loader"
 
     init_method: Optional[Callable[..., None]]
+    original_init: Optional[Callable[..., None]] = None
 
     factory: Optional[_ManagerFactory]
 
@@ -229,7 +230,7 @@ class ClassManager(
         if finalize and not self._finalized:
             self._finalize()
 
-    def _finalize(self):
+    def _finalize(self) -> None:
         if self._finalized:
             return
         self._finalized = True
@@ -238,14 +239,14 @@ class ClassManager(
 
         _instrumentation_factory.dispatch.class_instrument(self.class_)
 
-    def __hash__(self):
+    def __hash__(self) -> int:  # type: ignore[override]
         return id(self)
 
-    def __eq__(self, other):
+    def __eq__(self, other: Any) -> bool:
         return other is self
 
     @property
-    def is_mapped(self):
+    def is_mapped(self) -> bool:
         return "mapper" in self.__dict__
 
     @HasMemoized.memoized_attribute
index b5569ce063e988726fe8e9b171a618bb6532b1b4..e0034061d4e25e33e580afcb18899566e666c01b 100644 (file)
@@ -19,6 +19,7 @@ are exposed when inspecting mappings.
 from __future__ import annotations
 
 import collections
+import dataclasses
 import typing
 from typing import Any
 from typing import Callable
@@ -27,6 +28,8 @@ from typing import ClassVar
 from typing import Dict
 from typing import Iterator
 from typing import List
+from typing import NamedTuple
+from typing import NoReturn
 from typing import Optional
 from typing import Sequence
 from typing import Set
@@ -51,11 +54,13 @@ from .base import ONETOMANY as ONETOMANY  # noqa: F401
 from .base import RelationshipDirection as RelationshipDirection  # noqa: F401
 from .base import SQLORMOperations
 from .. import ColumnElement
+from .. import exc as sa_exc
 from .. import inspection
 from .. import util
 from ..sql import operators
 from ..sql import roles
 from ..sql import visitors
+from ..sql.base import _NoArg
 from ..sql.base import ExecutableOption
 from ..sql.cache_key import HasCacheKey
 from ..sql.schema import Column
@@ -141,6 +146,7 @@ class _IntrospectsAnnotations:
         cls: Type[Any],
         key: str,
         annotation: Optional[_AnnotationScanType],
+        extracted_mapped_annotation: Optional[_AnnotationScanType],
         is_dataclass_field: bool,
     ) -> None:
         """Perform class-specific initializaton at early declarative scanning
@@ -150,6 +156,70 @@ class _IntrospectsAnnotations:
 
         """
 
+    def _raise_for_required(self, key: str, cls: Type[Any]) -> NoReturn:
+        raise sa_exc.ArgumentError(
+            f"Python typing annotation is required for attribute "
+            f'"{cls.__name__}.{key}" when primary argument(s) for '
+            f'"{self.__class__.__name__}" construct are None or not present'
+        )
+
+
+class _AttributeOptions(NamedTuple):
+    """define Python-local attribute behavior options common to all
+    :class:`.MapperProperty` objects.
+
+    Currently this includes dataclass-generation arguments.
+
+    .. versionadded:: 2.0
+
+    """
+
+    dataclasses_init: Union[_NoArg, bool]
+    dataclasses_repr: Union[_NoArg, bool]
+    dataclasses_default: Union[_NoArg, Any]
+    dataclasses_default_factory: Union[_NoArg, Callable[[], Any]]
+
+    def _as_dataclass_field(self) -> Any:
+        """Return a ``dataclasses.Field`` object given these arguments."""
+
+        kw: Dict[str, Any] = {}
+        if self.dataclasses_default_factory is not _NoArg.NO_ARG:
+            kw["default_factory"] = self.dataclasses_default_factory
+        if self.dataclasses_default is not _NoArg.NO_ARG:
+            kw["default"] = self.dataclasses_default
+        if self.dataclasses_init is not _NoArg.NO_ARG:
+            kw["init"] = self.dataclasses_init
+        if self.dataclasses_repr is not _NoArg.NO_ARG:
+            kw["repr"] = self.dataclasses_repr
+
+        return dataclasses.field(**kw)
+
+    @classmethod
+    def _get_arguments_for_make_dataclass(
+        cls, key: str, annotation: Type[Any], elem: _T
+    ) -> Union[
+        Tuple[str, Type[Any]], Tuple[str, Type[Any], dataclasses.Field[Any]]
+    ]:
+        """given attribute key, annotation, and value from a class, return
+        the argument tuple we would pass to dataclasses.make_dataclass()
+        for this attribute.
+
+        """
+        if isinstance(elem, (MapperProperty, _MapsColumns)):
+            dc_field = elem._attribute_options._as_dataclass_field()
+
+            return (key, annotation, dc_field)
+        elif elem is not _NoArg.NO_ARG:
+            # why is typing not erroring on this?
+            return (key, annotation, elem)
+        else:
+            return (key, annotation)
+
+
+_DEFAULT_ATTRIBUTE_OPTIONS = _AttributeOptions(
+    _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG
+)
+
 
 class _MapsColumns(_MappedAttribute[_T]):
     """interface for declarative-capable construct that delivers one or more
@@ -158,6 +228,9 @@ class _MapsColumns(_MappedAttribute[_T]):
 
     __slots__ = ()
 
+    _attribute_options: _AttributeOptions
+    _has_dataclass_arguments: bool
+
     @property
     def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]:
         """return a MapperProperty to be assigned to the declarative mapping"""
@@ -199,6 +272,8 @@ class MapperProperty(
     __slots__ = (
         "_configure_started",
         "_configure_finished",
+        "_attribute_options",
+        "_has_dataclass_arguments",
         "parent",
         "key",
         "info",
@@ -241,6 +316,15 @@ class MapperProperty(
     doc: Optional[str]
     """optional documentation string"""
 
+    _attribute_options: _AttributeOptions
+    """behavioral options for ORM-enabled Python attributes
+
+    .. versionadded:: 2.0
+
+    """
+
+    _has_dataclass_arguments: bool
+
     def _memoized_attr_info(self) -> _InfoType:
         """Info dictionary associated with the object, allowing user-defined
         data to be associated with this :class:`.InspectionAttr`.
@@ -349,9 +433,20 @@ class MapperProperty(
 
         """
 
-    def __init__(self) -> None:
+    def __init__(
+        self, attribute_options: Optional[_AttributeOptions] = None
+    ) -> None:
         self._configure_started = False
         self._configure_finished = False
+        if (
+            attribute_options
+            and attribute_options != _DEFAULT_ATTRIBUTE_OPTIONS
+        ):
+            self._has_dataclass_arguments = True
+            self._attribute_options = attribute_options
+        else:
+            self._has_dataclass_arguments = False
+            self._attribute_options = _DEFAULT_ATTRIBUTE_OPTIONS
 
     def init(self) -> None:
         """Called after all mappers are created to assemble
index ad3e9f248ddfbb37a444d7b5af3d687cc3255fa9..7655f3ae2f79dea7c259c8c76db31bd9538b44d4 100644 (file)
@@ -30,13 +30,14 @@ from . import strategy_options
 from .descriptor_props import Composite
 from .descriptor_props import ConcreteInheritedProperty
 from .descriptor_props import Synonym
+from .interfaces import _AttributeOptions
+from .interfaces import _DEFAULT_ATTRIBUTE_OPTIONS
 from .interfaces import _IntrospectsAnnotations
 from .interfaces import _MapsColumns
 from .interfaces import MapperProperty
 from .interfaces import PropComparator
 from .interfaces import StrategizedProperty
 from .relationships import Relationship
-from .util import _extract_mapped_subtype
 from .util import _orm_full_deannotate
 from .. import exc as sa_exc
 from .. import ForeignKey
@@ -45,6 +46,7 @@ from .. import util
 from ..sql import coercions
 from ..sql import roles
 from ..sql import sqltypes
+from ..sql.base import _NoArg
 from ..sql.elements import SQLCoreOperations
 from ..sql.schema import Column
 from ..sql.schema import SchemaConst
@@ -131,6 +133,7 @@ class ColumnProperty(
         self,
         column: _ORMColumnExprArgument[_T],
         *additional_columns: _ORMColumnExprArgument[Any],
+        attribute_options: Optional[_AttributeOptions] = None,
         group: Optional[str] = None,
         deferred: bool = False,
         raiseload: bool = False,
@@ -141,7 +144,9 @@ class ColumnProperty(
         doc: Optional[str] = None,
         _instrument: bool = True,
     ):
-        super(ColumnProperty, self).__init__()
+        super(ColumnProperty, self).__init__(
+            attribute_options=attribute_options
+        )
         columns = (column,) + additional_columns
         self._orig_columns = [
             coercions.expect(roles.LabeledColumnExprRole, c) for c in columns
@@ -193,6 +198,7 @@ class ColumnProperty(
         cls: Type[Any],
         key: str,
         annotation: Optional[_AnnotationScanType],
+        extracted_mapped_annotation: Optional[_AnnotationScanType],
         is_dataclass_field: bool,
     ) -> None:
         column = self.columns[0]
@@ -487,13 +493,38 @@ class MappedColumn(
         "foreign_keys",
         "_has_nullable",
         "deferred",
+        "_attribute_options",
+        "_has_dataclass_arguments",
     )
 
     deferred: bool
     column: Column[_T]
     foreign_keys: Optional[Set[ForeignKey]]
+    _attribute_options: _AttributeOptions
 
     def __init__(self, *arg: Any, **kw: Any):
+        self._attribute_options = attr_opts = kw.pop(
+            "attribute_options", _DEFAULT_ATTRIBUTE_OPTIONS
+        )
+
+        self._has_dataclass_arguments = False
+
+        if attr_opts is not None and attr_opts != _DEFAULT_ATTRIBUTE_OPTIONS:
+            if attr_opts.dataclasses_default_factory is not _NoArg.NO_ARG:
+                self._has_dataclass_arguments = True
+                kw["default"] = attr_opts.dataclasses_default_factory
+            elif attr_opts.dataclasses_default is not _NoArg.NO_ARG:
+                kw["default"] = attr_opts.dataclasses_default
+
+            if (
+                attr_opts.dataclasses_init is not _NoArg.NO_ARG
+                or attr_opts.dataclasses_repr is not _NoArg.NO_ARG
+            ):
+                self._has_dataclass_arguments = True
+
+        if "default" in kw and kw["default"] is _NoArg.NO_ARG:
+            kw.pop("default")
+
         self.deferred = kw.pop("deferred", False)
         self.column = cast("Column[_T]", Column(*arg, **kw))
         self.foreign_keys = self.column.foreign_keys
@@ -509,13 +540,19 @@ class MappedColumn(
         new.deferred = self.deferred
         new.foreign_keys = new.column.foreign_keys
         new._has_nullable = self._has_nullable
+        new._attribute_options = self._attribute_options
+        new._has_dataclass_arguments = self._has_dataclass_arguments
         util.set_creation_order(new)
         return new
 
     @property
     def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]:
         if self.deferred:
-            return ColumnProperty(self.column, deferred=True)
+            return ColumnProperty(
+                self.column,
+                deferred=True,
+                attribute_options=self._attribute_options,
+            )
         else:
             return None
 
@@ -543,6 +580,7 @@ class MappedColumn(
         cls: Type[Any],
         key: str,
         annotation: Optional[_AnnotationScanType],
+        extracted_mapped_annotation: Optional[_AnnotationScanType],
         is_dataclass_field: bool,
     ) -> None:
         column = self.column
@@ -553,18 +591,15 @@ class MappedColumn(
 
         sqltype = column.type
 
-        argument = _extract_mapped_subtype(
-            annotation,
-            cls,
-            key,
-            MappedColumn,
-            sqltype._isnull and not self.column.foreign_keys,
-            is_dataclass_field,
-        )
-        if argument is None:
-            return
+        if extracted_mapped_annotation is None:
+            if sqltype._isnull and not self.column.foreign_keys:
+                self._raise_for_required(key, cls)
+            else:
+                return
 
-        self._init_column_for_annotation(cls, registry, argument)
+        self._init_column_for_annotation(
+            cls, registry, extracted_mapped_annotation
+        )
 
     @util.preload_module("sqlalchemy.orm.decl_base")
     def declarative_scan_for_composite(
index 1186f0f541965e6b50131475cc2ec394b70050ad..deaf5214720bbf7080844d5bbb804c646d9ccef7 100644 (file)
@@ -49,6 +49,7 @@ from .base import class_mapper
 from .base import LoaderCallableStatus
 from .base import PassiveFlag
 from .base import state_str
+from .interfaces import _AttributeOptions
 from .interfaces import _IntrospectsAnnotations
 from .interfaces import MANYTOMANY
 from .interfaces import MANYTOONE
@@ -56,7 +57,6 @@ from .interfaces import ONETOMANY
 from .interfaces import PropComparator
 from .interfaces import RelationshipDirection
 from .interfaces import StrategizedProperty
-from .util import _extract_mapped_subtype
 from .util import _orm_annotate
 from .util import _orm_deannotate
 from .util import CascadeOptions
@@ -355,6 +355,7 @@ class Relationship(
         post_update: bool = False,
         cascade: str = "save-update, merge",
         viewonly: bool = False,
+        attribute_options: Optional[_AttributeOptions] = None,
         lazy: _LazyLoadArgumentType = "select",
         passive_deletes: Union[Literal["all"], bool] = False,
         passive_updates: bool = True,
@@ -380,7 +381,7 @@ class Relationship(
         _local_remote_pairs: Optional[_ColumnPairs] = None,
         _legacy_inactive_history_style: bool = False,
     ):
-        super(Relationship, self).__init__()
+        super(Relationship, self).__init__(attribute_options=attribute_options)
 
         self.uselist = uselist
         self.argument = argument
@@ -1701,18 +1702,19 @@ class Relationship(
         cls: Type[Any],
         key: str,
         annotation: Optional[_AnnotationScanType],
+        extracted_mapped_annotation: Optional[_AnnotationScanType],
         is_dataclass_field: bool,
     ) -> None:
-        argument = _extract_mapped_subtype(
-            annotation,
-            cls,
-            key,
-            Relationship,
-            self.argument is None,
-            is_dataclass_field,
-        )
-        if argument is None:
-            return
+        argument = extracted_mapped_annotation
+
+        if extracted_mapped_annotation is None:
+
+            if self.argument is None:
+                self._raise_for_required(key, cls)
+            else:
+                return
+
+        argument = extracted_mapped_annotation
 
         if hasattr(argument, "__origin__"):
 
index c50cc5bac84f6bb995b49c0eede3c8accb9872db..520c95672f42c97c9a94ccd21b95b76513e9edd1 100644 (file)
@@ -1927,7 +1927,7 @@ def _getitem(iterable_query: Query[Any], item: Any) -> Any:
 
 
 def _is_mapped_annotation(
-    raw_annotation: Union[type, str], cls: Type[Any]
+    raw_annotation: _AnnotationScanType, cls: Type[Any]
 ) -> bool:
     annotated = de_stringify_annotation(cls, raw_annotation)
     return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm")
@@ -1969,9 +1969,14 @@ def _extract_mapped_subtype(
     attr_cls: Type[Any],
     required: bool,
     is_dataclass_field: bool,
-    superclasses: Optional[Tuple[Type[Any], ...]] = None,
+    expect_mapped: bool = True,
 ) -> Optional[Union[type, str]]:
+    """given an annotation, figure out if it's ``Mapped[something]`` and if
+    so, return the ``something`` part.
 
+    Includes error raise scenarios and other options.
+
+    """
     if raw_annotation is None:
 
         if required:
@@ -1989,25 +1994,29 @@ def _extract_mapped_subtype(
     if is_dataclass_field:
         return annotated
     else:
-        # TODO: there don't seem to be tests for the failure
-        # conditions here
-        if not hasattr(annotated, "__origin__") or (
-            not issubclass(
-                annotated.__origin__,  # type: ignore
-                superclasses if superclasses else attr_cls,
-            )
-            and not issubclass(attr_cls, annotated.__origin__)  # type: ignore
+        if not hasattr(annotated, "__origin__") or not is_origin_of(
+            annotated, "Mapped", module="sqlalchemy.orm"
         ):
-            our_annotated_str = (
-                annotated.__name__
+            anno_name = (
+                getattr(annotated, "__name__", None)
                 if not isinstance(annotated, str)
-                else repr(annotated)
-            )
-            raise sa_exc.ArgumentError(
-                f'Type annotation for "{cls.__name__}.{key}" should use the '
-                f'syntax "Mapped[{our_annotated_str}]" or '
-                f'"{attr_cls.__name__}[{our_annotated_str}]".'
+                else None
             )
+            if anno_name is None:
+                our_annotated_str = repr(annotated)
+            else:
+                our_annotated_str = anno_name
+
+            if expect_mapped:
+                raise sa_exc.ArgumentError(
+                    f'Type annotation for "{cls.__name__}.{key}" '
+                    "should use the "
+                    f'syntax "Mapped[{our_annotated_str}]" or '
+                    f'"{attr_cls.__name__}[{our_annotated_str}]".'
+                )
+
+            else:
+                return annotated
 
         if len(annotated.__args__) != 1:  # type: ignore
             raise sa_exc.ArgumentError(
index 53f76f3ce2f84456fac0d48dfe2445fb67b0695e..d4e4d2dcad78b4c56212b56f7812ecc191dcfa90 100644 (file)
@@ -25,6 +25,7 @@ from .. import event
 from .. import util
 from ..orm import declarative_base
 from ..orm import DeclarativeBase
+from ..orm import MappedAsDataclass
 from ..orm import registry
 from ..schema import sort_tables_and_constraints
 
@@ -90,7 +91,14 @@ class TestBase:
 
     @config.fixture()
     def registry(self, metadata):
-        reg = registry(metadata=metadata)
+        reg = registry(
+            metadata=metadata,
+            type_annotation_map={
+                str: sa.String().with_variant(
+                    sa.String(50), "mysql", "mariadb"
+                )
+            },
+        )
         yield reg
         reg.dispose()
 
@@ -109,6 +117,21 @@ class TestBase:
         yield Base
         Base.registry.dispose()
 
+    @config.fixture
+    def dc_decl_base(self, metadata):
+        _md = metadata
+
+        class Base(MappedAsDataclass, DeclarativeBase):
+            metadata = _md
+            type_annotation_map = {
+                str: sa.String().with_variant(
+                    sa.String(50), "mysql", "mariadb"
+                )
+            }
+
+        yield Base
+        Base.registry.dispose()
+
     @config.fixture()
     def future_connection(self, future_engine, connection):
         # integrate the future_engine and connection fixtures so
index adbbf143f9d7e9cbdba4696501a59d4994428d73..4ce1e7ff32c063f3c1370144297b8a65ba0c8a1a 100644 (file)
@@ -230,7 +230,11 @@ def inspect_formatargspec(
 
 def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
     """Return a sequence of all dataclasses.Field objects associated
-    with a class."""
+    with a class as an already processed dataclass.
+
+    The class must **already be a dataclass** for Field objects to be returned.
+
+    """
 
     if dataclasses.is_dataclass(cls):
         return dataclasses.fields(cls)
@@ -240,7 +244,12 @@ def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
 
 def local_dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
     """Return a sequence of all dataclasses.Field objects associated with
-    a class, excluding those that originate from a superclass."""
+    an already processed dataclass, excluding those that originate from a
+    superclass.
+
+    The class must **already be a dataclass** for Field objects to be returned.
+
+    """
 
     if dataclasses.is_dataclass(cls):
         super_fields: Set[dataclasses.Field[Any]] = set()
index 44e26f60940cdcd3c9bc33af24545f692ad9ad20..454de100bd2d02959ac6c768da3772df6045a66a 100644 (file)
@@ -23,6 +23,14 @@ from typing_extensions import NotRequired as NotRequired  # noqa: F401
 
 from . import compat
 
+
+# more zimports issues
+if True:
+    from typing_extensions import (  # noqa: F401
+        dataclass_transform as dataclass_transform,
+    )
+
+
 _T = TypeVar("_T", bound=Any)
 _KT = TypeVar("_KT")
 _KT_co = TypeVar("_KT_co", covariant=True)
index 29d59ea698bd38560e50006cfc5d598a04d6842a..812d60e91524301b34ebb60d25a31937b4b31b88 100644 (file)
@@ -40,7 +40,6 @@ markers = [
 
 [tool.pyright]
 
-
 reportPrivateUsage = "none"
 reportUnusedClass = "none"
 reportUnusedFunction = "none"
index 5ef2c6f22c8b8772cf309790c644a0509060bc24..0df41dc7bda582e81c56917590bcaab4ef7a021e 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -37,7 +37,7 @@ package_dir =
 install_requires =
     importlib-metadata;python_version<"3.8"
     greenlet != 0.4.17;(platform_machine=='aarch64' or (platform_machine=='ppc64le' or (platform_machine=='x86_64' or (platform_machine=='amd64' or (platform_machine=='AMD64' or (platform_machine=='win32' or platform_machine=='WIN32'))))))
-    typing-extensions >= 4
+    typing-extensions >= 4.1.0
 
 [options.extras_require]
 asyncio =
diff --git a/test/orm/declarative/test_dc_transforms.py b/test/orm/declarative/test_dc_transforms.py
new file mode 100644 (file)
index 0000000..aac8737
--- /dev/null
@@ -0,0 +1,816 @@
+import dataclasses
+import inspect as pyinspect
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Set
+from typing import Type
+from unittest import mock
+
+from sqlalchemy import Column
+from sqlalchemy import exc
+from sqlalchemy import ForeignKey
+from sqlalchemy import inspect
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy import testing
+from sqlalchemy.orm import column_property
+from sqlalchemy.orm import composite
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import deferred
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import MappedAsDataclass
+from sqlalchemy.orm import MappedColumn
+from sqlalchemy.orm import registry as _RegistryType
+from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
+from sqlalchemy.orm import synonym
+from sqlalchemy.testing import AssertsCompiledSQL
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import eq_regex
+from sqlalchemy.testing import expect_raises
+from sqlalchemy.testing import expect_raises_message
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_false
+from sqlalchemy.testing import is_true
+from sqlalchemy.testing import ne_
+
+
+class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase):
+    def test_basic_constructor_repr_base_cls(
+        self, dc_decl_base: Type[MappedAsDataclass]
+    ):
+        class A(dc_decl_base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            data: Mapped[str]
+
+            x: Mapped[Optional[int]] = mapped_column(default=None)
+
+            bs: Mapped[List["B"]] = relationship(  # noqa: F821
+                default_factory=list
+            )
+
+        class B(dc_decl_base):
+            __tablename__ = "b"
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            a_id = mapped_column(ForeignKey("a.id"), init=False)
+            data: Mapped[str]
+            x: Mapped[Optional[int]] = mapped_column(default=None)
+
+        A.__qualname__ = "some_module.A"
+        B.__qualname__ = "some_module.B"
+
+        eq_(
+            pyinspect.getfullargspec(A.__init__),
+            pyinspect.FullArgSpec(
+                args=["self", "data", "x", "bs"],
+                varargs=None,
+                varkw=None,
+                defaults=(None, mock.ANY),
+                kwonlyargs=[],
+                kwonlydefaults=None,
+                annotations={},
+            ),
+        )
+        eq_(
+            pyinspect.getfullargspec(B.__init__),
+            pyinspect.FullArgSpec(
+                args=["self", "data", "x"],
+                varargs=None,
+                varkw=None,
+                defaults=(None,),
+                kwonlyargs=[],
+                kwonlydefaults=None,
+                annotations={},
+            ),
+        )
+
+        a2 = A("10", x=5, bs=[B("data1"), B("data2", x=12)])
+        eq_(
+            repr(a2),
+            "some_module.A(id=None, data='10', x=5, "
+            "bs=[some_module.B(id=None, data='data1', a_id=None, x=None), "
+            "some_module.B(id=None, data='data2', a_id=None, x=12)])",
+        )
+
+        a3 = A("data")
+        eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])")
+
+    def test_basic_constructor_repr_cls_decorator(
+        self, registry: _RegistryType
+    ):
+        @registry.mapped_as_dataclass()
+        class A:
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            data: Mapped[str]
+
+            x: Mapped[Optional[int]] = mapped_column(default=None)
+
+            bs: Mapped[List["B"]] = relationship(  # noqa: F821
+                default_factory=list
+            )
+
+        @registry.mapped_as_dataclass()
+        class B:
+            __tablename__ = "b"
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            a_id = mapped_column(ForeignKey("a.id"), init=False)
+            data: Mapped[str]
+            x: Mapped[Optional[int]] = mapped_column(default=None)
+
+        A.__qualname__ = "some_module.A"
+        B.__qualname__ = "some_module.B"
+
+        eq_(
+            pyinspect.getfullargspec(A.__init__),
+            pyinspect.FullArgSpec(
+                args=["self", "data", "x", "bs"],
+                varargs=None,
+                varkw=None,
+                defaults=(None, mock.ANY),
+                kwonlyargs=[],
+                kwonlydefaults=None,
+                annotations={},
+            ),
+        )
+        eq_(
+            pyinspect.getfullargspec(B.__init__),
+            pyinspect.FullArgSpec(
+                args=["self", "data", "x"],
+                varargs=None,
+                varkw=None,
+                defaults=(None,),
+                kwonlyargs=[],
+                kwonlydefaults=None,
+                annotations={},
+            ),
+        )
+
+        a2 = A("10", x=5, bs=[B("data1"), B("data2", x=12)])
+        eq_(
+            repr(a2),
+            "some_module.A(id=None, data='10', x=5, "
+            "bs=[some_module.B(id=None, data='data1', a_id=None, x=None), "
+            "some_module.B(id=None, data='data2', a_id=None, x=12)])",
+        )
+
+        a3 = A("data")
+        eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])")
+
+    def test_default_fn(self, dc_decl_base: Type[MappedAsDataclass]):
+        class A(dc_decl_base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            data: Mapped[str] = mapped_column(default="d1")
+            data2: Mapped[str] = mapped_column(default_factory=lambda: "d2")
+
+        a1 = A()
+        eq_(a1.data, "d1")
+        eq_(a1.data2, "d2")
+
+    def test_default_factory_vs_collection_class(
+        self, dc_decl_base: Type[MappedAsDataclass]
+    ):
+        # this is currently the error raised by dataclasses.  We can instead
+        # do this validation ourselves, but overall I don't know that we
+        # can hit every validation and rule that's in dataclasses
+        with expect_raises_message(
+            ValueError, "cannot specify both default and default_factory"
+        ):
+
+            class A(dc_decl_base):
+                __tablename__ = "a"
+
+                id: Mapped[int] = mapped_column(primary_key=True, init=False)
+                data: Mapped[str] = mapped_column(
+                    default="d1", default_factory=lambda: "d2"
+                )
+
+    def test_inheritance(self, dc_decl_base: Type[MappedAsDataclass]):
+        class Person(dc_decl_base):
+            __tablename__ = "person"
+            person_id: Mapped[int] = mapped_column(
+                primary_key=True, init=False
+            )
+            name: Mapped[str]
+            type: Mapped[str] = mapped_column(init=False)
+
+            __mapper_args__ = {"polymorphic_on": type}
+
+        class Engineer(Person):
+            __tablename__ = "engineer"
+
+            person_id: Mapped[int] = mapped_column(
+                ForeignKey("person.person_id"), primary_key=True, init=False
+            )
+
+            status: Mapped[str] = mapped_column(String(30))
+            engineer_name: Mapped[str]
+            primary_language: Mapped[str]
+
+        e1 = Engineer("nm", "st", "en", "pl")
+        eq_(e1.name, "nm")
+        eq_(e1.status, "st")
+        eq_(e1.engineer_name, "en")
+        eq_(e1.primary_language, "pl")
+
+    def test_integrated_dc(self, dc_decl_base: Type[MappedAsDataclass]):
+        """We will be telling users "this is a dataclass that is also
+        mapped". Therefore, they will want *any* kind of attribute to do what
+        it would normally do in a dataclass, including normal types without any
+        field and explicit use of dataclasses.field(). additionally, we'd like
+        ``Mapped`` to mean "persist this attribute". So the absence of
+        ``Mapped`` should also mean something too.
+
+        """
+
+        class A(dc_decl_base):
+            __tablename__ = "a"
+
+            ctrl_one: str
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            data: Mapped[str]
+            some_field: int = dataclasses.field(default=5)
+
+            some_none_field: Optional[str] = None
+
+        a1 = A("ctrlone", "datafield")
+        eq_(a1.some_field, 5)
+        eq_(a1.some_none_field, None)
+
+        # only Mapped[] is mapped
+        self.assert_compile(select(A), "SELECT a.id, a.data FROM a")
+        eq_(
+            pyinspect.getfullargspec(A.__init__),
+            pyinspect.FullArgSpec(
+                args=[
+                    "self",
+                    "ctrl_one",
+                    "data",
+                    "some_field",
+                    "some_none_field",
+                ],
+                varargs=None,
+                varkw=None,
+                defaults=(5, None),
+                kwonlyargs=[],
+                kwonlydefaults=None,
+                annotations={},
+            ),
+        )
+
+    def test_dc_on_top_of_non_dc(self, decl_base: Type[DeclarativeBase]):
+        class Person(decl_base):
+            __tablename__ = "person"
+            person_id: Mapped[int] = mapped_column(primary_key=True)
+            name: Mapped[str]
+            type: Mapped[str] = mapped_column()
+
+            __mapper_args__ = {"polymorphic_on": type}
+
+        class Engineer(MappedAsDataclass, Person):
+            __tablename__ = "engineer"
+
+            person_id: Mapped[int] = mapped_column(
+                ForeignKey("person.person_id"), primary_key=True, init=False
+            )
+
+            status: Mapped[str] = mapped_column(String(30))
+            engineer_name: Mapped[str]
+            primary_language: Mapped[str]
+
+        e1 = Engineer("st", "en", "pl")
+        eq_(e1.status, "st")
+        eq_(e1.engineer_name, "en")
+        eq_(e1.primary_language, "pl")
+
+        eq_(
+            pyinspect.getfullargspec(Person.__init__),
+            # the boring **kw __init__
+            pyinspect.FullArgSpec(
+                args=["self"],
+                varargs=None,
+                varkw="kwargs",
+                defaults=None,
+                kwonlyargs=[],
+                kwonlydefaults=None,
+                annotations={},
+            ),
+        )
+
+        eq_(
+            pyinspect.getfullargspec(Engineer.__init__),
+            # the exciting dataclasses __init__
+            pyinspect.FullArgSpec(
+                args=["self", "status", "engineer_name", "primary_language"],
+                varargs=None,
+                varkw=None,
+                defaults=None,
+                kwonlyargs=[],
+                kwonlydefaults=None,
+                annotations={},
+            ),
+        )
+
+
+class RelationshipDefaultFactoryTest(fixtures.TestBase):
+    def test_list(self, dc_decl_base: Type[MappedAsDataclass]):
+        class A(dc_decl_base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+
+            bs: Mapped[List["B"]] = relationship(  # noqa: F821
+                default_factory=lambda: [B(data="hi")]
+            )
+
+        class B(dc_decl_base):
+            __tablename__ = "b"
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            a_id = mapped_column(ForeignKey("a.id"), init=False)
+            data: Mapped[str]
+
+        a1 = A()
+        eq_(a1.bs[0].data, "hi")
+
+    def test_set(self, dc_decl_base: Type[MappedAsDataclass]):
+        class A(dc_decl_base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+
+            bs: Mapped[Set["B"]] = relationship(  # noqa: F821
+                default_factory=lambda: {B(data="hi")}
+            )
+
+        class B(dc_decl_base, unsafe_hash=True):
+            __tablename__ = "b"
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            a_id = mapped_column(ForeignKey("a.id"), init=False)
+            data: Mapped[str]
+
+        a1 = A()
+        eq_(a1.bs.pop().data, "hi")
+
+    def test_oh_no_mismatch(self, dc_decl_base: Type[MappedAsDataclass]):
+        class A(dc_decl_base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+
+            bs: Mapped[Set["B"]] = relationship(  # noqa: F821
+                default_factory=lambda: [B(data="hi")]
+            )
+
+        class B(dc_decl_base, unsafe_hash=True):
+            __tablename__ = "b"
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            a_id = mapped_column(ForeignKey("a.id"), init=False)
+            data: Mapped[str]
+
+        # old school collection mismatch error FTW
+        with expect_raises_message(
+            TypeError, "Incompatible collection type: list is not set-like"
+        ):
+            A()
+
+    def test_replace_operation_works_w_history_etc(
+        self, registry: _RegistryType
+    ):
+        @registry.mapped_as_dataclass
+        class A:
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            data: Mapped[str]
+
+            x: Mapped[Optional[int]] = mapped_column(default=None)
+
+            bs: Mapped[List["B"]] = relationship(  # noqa: F821
+                default_factory=list
+            )
+
+        @registry.mapped_as_dataclass
+        class B:
+            __tablename__ = "b"
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            a_id = mapped_column(ForeignKey("a.id"), init=False)
+            data: Mapped[str]
+            x: Mapped[Optional[int]] = mapped_column(default=None)
+
+        registry.metadata.create_all(testing.db)
+
+        with Session(testing.db) as sess:
+            a1 = A("data", 10, [B("b1"), B("b2", x=5), B("b3")])
+            sess.add(a1)
+            sess.commit()
+
+            a2 = dataclasses.replace(a1, x=12, bs=[B("b4")])
+
+            assert a1 in sess
+            assert not sess.is_modified(a1, include_collections=True)
+            assert a2 not in sess
+            eq_(inspect(a2).attrs.x.history, ([12], (), ()))
+            sess.add(a2)
+            sess.commit()
+
+            eq_(sess.scalars(select(A.x).order_by(A.id)).all(), [10, 12])
+            eq_(
+                sess.scalars(select(B.data).order_by(B.id)).all(),
+                ["b1", "b2", "b3", "b4"],
+            )
+
+    def test_post_init(self, registry: _RegistryType):
+        @registry.mapped_as_dataclass
+        class A:
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            data: Mapped[str] = mapped_column(init=False)
+
+            def __post_init__(self):
+                self.data = "some data"
+
+        a1 = A()
+        eq_(a1.data, "some data")
+
+    def test_no_field_args_w_new_style(self, registry: _RegistryType):
+        with expect_raises_message(
+            exc.InvalidRequestError,
+            "SQLAlchemy mapped dataclasses can't consume mapping information",
+        ):
+
+            @registry.mapped_as_dataclass()
+            class A:
+                __tablename__ = "a"
+                __sa_dataclass_metadata_key__ = "sa"
+
+                account_id: int = dataclasses.field(
+                    init=False,
+                    metadata={"sa": Column(Integer, primary_key=True)},
+                )
+
+    def test_no_field_args_w_new_style_two(self, registry: _RegistryType):
+        @dataclasses.dataclass
+        class Base:
+            pass
+
+        with expect_raises_message(
+            exc.InvalidRequestError,
+            "SQLAlchemy mapped dataclasses can't consume mapping information",
+        ):
+
+            @registry.mapped_as_dataclass()
+            class A(Base):
+                __tablename__ = "a"
+                __sa_dataclass_metadata_key__ = "sa"
+
+                account_id: int = dataclasses.field(
+                    init=False,
+                    metadata={"sa": Column(Integer, primary_key=True)},
+                )
+
+
+class DataclassArgsTest(fixtures.TestBase):
+    dc_arg_names = ("init", "repr", "eq", "order", "unsafe_hash")
+
+    @testing.fixture(params=dc_arg_names)
+    def dc_argument_fixture(self, request: Any, registry: _RegistryType):
+        name = request.param
+
+        args = {n: n == name for n in self.dc_arg_names}
+        if args["order"]:
+            args["eq"] = True
+        yield args
+
+    @testing.fixture(
+        params=["mapped_column", "synonym", "deferred", "column_property"]
+    )
+    def mapped_expr_constructor(self, request):
+        name = request.param
+
+        if name == "mapped_column":
+            yield mapped_column(default=7, init=True)
+        elif name == "synonym":
+            yield synonym("some_int", default=7, init=True)
+        elif name == "deferred":
+            yield deferred(Column(Integer), default=7, init=True)
+        elif name == "column_property":
+            yield column_property(Column(Integer), default=7, init=True)
+
+    def test_attrs_rejected_if_not_a_dc(
+        self, mapped_expr_constructor, decl_base: Type[DeclarativeBase]
+    ):
+        if isinstance(mapped_expr_constructor, MappedColumn):
+            unwanted_args = "'init'"
+        else:
+            unwanted_args = "'default', 'init'"
+        with expect_raises_message(
+            exc.ArgumentError,
+            r"Attribute 'x' on class .*A.* includes dataclasses "
+            r"argument\(s\): "
+            rf"{unwanted_args} but class does not specify SQLAlchemy native "
+            "dataclass configuration",
+        ):
+
+            class A(decl_base):
+                __tablename__ = "a"
+
+                id: Mapped[int] = mapped_column(primary_key=True)
+
+                x: Mapped[int] = mapped_expr_constructor
+
+    def _assert_cls(self, cls, dc_arguments):
+
+        if dc_arguments["init"]:
+
+            def create(data, x):
+                return cls(data, x)
+
+        else:
+
+            def create(data, x):
+                a1 = cls()
+                a1.data = data
+                a1.x = x
+                return a1
+
+        for n in self.dc_arg_names:
+            if dc_arguments[n]:
+                getattr(self, f"_assert_{n}")(cls, create, dc_arguments)
+            else:
+                getattr(self, f"_assert_not_{n}")(cls, create, dc_arguments)
+
+            if dc_arguments["init"]:
+                a1 = cls("some data")
+                eq_(a1.x, 7)
+
+        a1 = create("some data", 15)
+        some_int = a1.some_int
+        eq_(
+            dataclasses.asdict(a1),
+            {"data": "some data", "id": None, "some_int": some_int, "x": 15},
+        )
+        eq_(dataclasses.astuple(a1), (None, "some data", some_int, 15))
+
+    def _assert_unsafe_hash(self, cls, create, dc_arguments):
+        a1 = create("d1", 5)
+        hash(a1)
+
+    def _assert_not_unsafe_hash(self, cls, create, dc_arguments):
+        a1 = create("d1", 5)
+
+        if dc_arguments["eq"]:
+            with expect_raises(TypeError):
+                hash(a1)
+        else:
+            hash(a1)
+
+    def _assert_eq(self, cls, create, dc_arguments):
+        a1 = create("d1", 5)
+        a2 = create("d2", 10)
+        a3 = create("d1", 5)
+
+        eq_(a1, a3)
+        ne_(a1, a2)
+
+    def _assert_not_eq(self, cls, create, dc_arguments):
+        a1 = create("d1", 5)
+        a2 = create("d2", 10)
+        a3 = create("d1", 5)
+
+        eq_(a1, a1)
+        ne_(a1, a3)
+        ne_(a1, a2)
+
+    def _assert_order(self, cls, create, dc_arguments):
+        is_false(create("g", 10) < create("b", 7))
+
+        is_true(create("g", 10) > create("b", 7))
+
+        is_false(create("g", 10) <= create("b", 7))
+
+        is_true(create("g", 10) >= create("b", 7))
+
+        eq_(
+            list(sorted([create("g", 10), create("g", 5), create("b", 7)])),
+            [
+                create("b", 7),
+                create("g", 5),
+                create("g", 10),
+            ],
+        )
+
+    def _assert_not_order(self, cls, create, dc_arguments):
+        with expect_raises(TypeError):
+            create("g", 10) < create("b", 7)
+
+        with expect_raises(TypeError):
+            create("g", 10) > create("b", 7)
+
+        with expect_raises(TypeError):
+            create("g", 10) <= create("b", 7)
+
+        with expect_raises(TypeError):
+            create("g", 10) >= create("b", 7)
+
+    def _assert_repr(self, cls, create, dc_arguments):
+        a1 = create("some data", 12)
+        eq_regex(repr(a1), r".*A\(id=None, data='some data', x=12\)")
+
+    def _assert_not_repr(self, cls, create, dc_arguments):
+        a1 = create("some data", 12)
+        eq_regex(repr(a1), r"<.*A object at 0x.*>")
+
+    def _assert_init(self, cls, create, dc_arguments):
+        a1 = cls("some data", 5)
+
+        eq_(a1.data, "some data")
+        eq_(a1.x, 5)
+
+        a2 = cls(data="some data", x=5)
+        eq_(a2.data, "some data")
+        eq_(a2.x, 5)
+
+        a3 = cls(data="some data")
+        eq_(a3.data, "some data")
+        eq_(a3.x, 7)
+
+    def _assert_not_init(self, cls, create, dc_arguments):
+
+        with expect_raises(TypeError):
+            cls("Some data", 5)
+
+        # we run real "dataclasses" on the class.  so with init=False, it
+        # doesn't touch what was there, and the SQLA default constructor
+        # gets put on.
+        a1 = cls(data="some data")
+        eq_(a1.data, "some data")
+        eq_(a1.x, None)
+
+        a1 = cls()
+        eq_(a1.data, None)
+
+        # no constructor, it sets None for x...ok
+        eq_(a1.x, None)
+
+    def test_dc_arguments_decorator(
+        self,
+        dc_argument_fixture,
+        mapped_expr_constructor,
+        registry: _RegistryType,
+    ):
+        @registry.mapped_as_dataclass(**dc_argument_fixture)
+        class A:
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            data: Mapped[str]
+
+            some_int: Mapped[int] = mapped_column(init=False, repr=False)
+
+            x: Mapped[Optional[int]] = mapped_expr_constructor
+
+        self._assert_cls(A, dc_argument_fixture)
+
+    def test_dc_arguments_base(
+        self,
+        dc_argument_fixture,
+        mapped_expr_constructor,
+        registry: _RegistryType,
+    ):
+        reg = registry
+
+        class Base(MappedAsDataclass, DeclarativeBase, **dc_argument_fixture):
+            registry = reg
+
+        class A(Base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            data: Mapped[str]
+
+            some_int: Mapped[int] = mapped_column(init=False, repr=False)
+
+            x: Mapped[Optional[int]] = mapped_expr_constructor
+
+        self.A = A
+
+    def test_dc_arguments_perclass(
+        self,
+        dc_argument_fixture,
+        mapped_expr_constructor,
+        decl_base: Type[DeclarativeBase],
+    ):
+        class A(MappedAsDataclass, decl_base, **dc_argument_fixture):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            data: Mapped[str]
+
+            some_int: Mapped[int] = mapped_column(init=False, repr=False)
+
+            x: Mapped[Optional[int]] = mapped_expr_constructor
+
+        self.A = A
+
+
+class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL):
+    __dialect__ = "default"
+
+    def test_composite_setup(self, dc_decl_base: Type[MappedAsDataclass]):
+        @dataclasses.dataclass
+        class Point:
+            x: int
+            y: int
+
+        class Edge(dc_decl_base):
+            __tablename__ = "edge"
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            graph_id: Mapped[int] = mapped_column(
+                ForeignKey("graph.id"), init=False
+            )
+
+            start: Mapped[Point] = composite(
+                Point, mapped_column("x1"), mapped_column("y1"), default=None
+            )
+
+            end: Mapped[Point] = composite(
+                Point, mapped_column("x2"), mapped_column("y2"), default=None
+            )
+
+        class Graph(dc_decl_base):
+            __tablename__ = "graph"
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+
+            edges: Mapped[List[Edge]] = relationship()
+
+        Point.__qualname__ = "mymodel.Point"
+        Edge.__qualname__ = "mymodel.Edge"
+        Graph.__qualname__ = "mymodel.Graph"
+        g = Graph(
+            edges=[
+                Edge(start=Point(1, 2), end=Point(3, 4)),
+                Edge(start=Point(7, 8), end=Point(5, 6)),
+            ]
+        )
+        eq_(
+            repr(g),
+            "mymodel.Graph(id=None, edges=[mymodel.Edge(id=None, "
+            "graph_id=None, start=mymodel.Point(x=1, y=2), "
+            "end=mymodel.Point(x=3, y=4)), "
+            "mymodel.Edge(id=None, graph_id=None, "
+            "start=mymodel.Point(x=7, y=8), end=mymodel.Point(x=5, y=6))])",
+        )
+
+    def test_named_setup(self, dc_decl_base: Type[MappedAsDataclass]):
+        @dataclasses.dataclass
+        class Address:
+            street: str
+            state: str
+            zip_: str
+
+        class User(dc_decl_base):
+            __tablename__ = "user"
+
+            id: Mapped[int] = mapped_column(
+                primary_key=True, init=False, repr=False
+            )
+            name: Mapped[str] = mapped_column()
+
+            address: Mapped[Address] = composite(
+                Address,
+                mapped_column(),
+                mapped_column(),
+                mapped_column("zip"),
+                default=None,
+            )
+
+        Address.__qualname__ = "mymodule.Address"
+        User.__qualname__ = "mymodule.User"
+        u = User(
+            name="user 1",
+            address=Address("123 anywhere street", "NY", "12345"),
+        )
+        u2 = User("u2")
+        eq_(
+            repr(u),
+            "mymodule.User(name='user 1', "
+            "address=mymodule.Address(street='123 anywhere street', "
+            "state='NY', zip_='12345'))",
+        )
+        eq_(repr(u2), "mymodule.User(name='u2', address=None)")
index d7d19821c616c092bfb61fc5e6937c98b35ec3c7..8657354397c5f2463e6431de31a90ee6fc9e97a7 100644 (file)
@@ -190,6 +190,18 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         is_true(User.__table__.c.data.nullable)
         assert isinstance(User.__table__.c.created_at.type, DateTime)
 
+    def test_column_default(self, decl_base):
+        class MyClass(decl_base):
+            __tablename__ = "mytable"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[str] = mapped_column(default="some default")
+
+        mc = MyClass()
+        assert "data" not in mc.__dict__
+
+        eq_(MyClass.__table__.c.data.default.arg, "some default")
+
     def test_anno_w_fixed_table(self, decl_base):
         users = Table(
             "users",
@@ -959,7 +971,7 @@ class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         with expect_raises_message(
             ArgumentError,
             r"Type annotation for \"User.address\" should use the syntax "
-            r"\"Mapped\['Address'\]\" or \"MappedColumn\['Address'\]\"",
+            r"\"Mapped\['Address'\]\"",
         ):
 
             class User(decl_base):
@@ -1068,6 +1080,38 @@ class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             # round trip!
             eq_(u1.address, Address("123 anywhere street", "NY", "12345"))
 
+    def test_cls_annotated_no_mapped_cols_setup(self, decl_base):
+        @dataclasses.dataclass
+        class Address:
+            street: str
+            state: str
+            zip_: str
+
+        class User(decl_base):
+            __tablename__ = "user"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            name: Mapped[str] = mapped_column()
+
+            address: Mapped[Address] = composite()
+
+        decl_base.metadata.create_all(testing.db)
+
+        with fixture_session() as sess:
+            sess.add(
+                User(
+                    name="user 1",
+                    address=Address("123 anywhere street", "NY", "12345"),
+                )
+            )
+            sess.commit()
+
+        with fixture_session() as sess:
+            u1 = sess.scalar(select(User))
+
+            # round trip!
+            eq_(u1.address, Address("123 anywhere street", "NY", "12345"))
+
     def test_one_col_setup(self, decl_base):
         @dataclasses.dataclass
         class Address: