]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support kw_only and match_args in dataclass mapping
authorFederico Caselli <cfederico87@gmail.com>
Wed, 3 Aug 2022 21:50:19 +0000 (23:50 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Thu, 4 Aug 2022 16:28:57 +0000 (18:28 +0200)
Fixes: #8346
Change-Id: I964629e3bd25221bf6df6ab31c59b3ce1983cd9a

doc/build/orm/dataclasses.rst
lib/sqlalchemy/orm/_orm_constructors.py
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/orm/decl_base.py
lib/sqlalchemy/orm/interfaces.py
test/orm/declarative/test_dc_transforms.py

index 4f152a0092802af97939f39092da3e1665b98a09..18817d2fe480792dc2fa632de3e1c538d837ea60 100644 (file)
@@ -108,9 +108,9 @@ Class level feature configuration
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 Support for dataclasses features is partial.  Currently **supported** are
-the ``init``, ``repr``, ``eq``, ``order`` and ``unsafe_hash`` features.
-Currently **not supported** are the ``frozen``, ``slots``, ``match_args``,
-and ``kw_only`` features.
+the ``init``, ``repr``, ``eq``, ``order`` and ``unsafe_hash`` features,
+``match_args`` and ``kw_only`` are supported on Python 3.10+.
+Currently **not supported** are the ``frozen`` and ``slots`` features.
 
 When using the mixin class form with :class:`_orm.MappedAsDataclass`,
 class configuration arguments are passed as class-level parameters::
index 2f1128f94a86e541b4b15963dfa138b52cfaaeb1..8dfec0fb1c2c299f3b148103c28e07abd7063c5f 100644 (file)
@@ -111,6 +111,7 @@ def mapped_column(
     repr: Union[_NoArg, bool] = _NoArg.NO_ARG,  # noqa: A002
     default: Optional[Any] = _NoArg.NO_ARG,
     default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+    kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG,
     nullable: Optional[
         Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
     ] = SchemaConst.NULL_UNSPECIFIED,
@@ -242,6 +243,9 @@ def mapped_column(
      specifies a default-value generation function that will take place
      as part of the ``__init__()``
      method as generated by the dataclass process.
+    :param kw_only: Specific to
+     :ref:`orm_declarative_native_dataclasses`, indicates if this field
+     should be marked as keyword-only when generating the ``__init__()``.
 
     :param \**kw: All remaining keyword argments are passed through to the
      constructor for the :class:`_schema.Column`.
@@ -257,10 +261,7 @@ def mapped_column(
         autoincrement=autoincrement,
         insert_default=insert_default,
         attribute_options=_AttributeOptions(
-            init,
-            repr,
-            default,
-            default_factory,
+            init, repr, default, default_factory, kw_only
         ),
         doc=doc,
         key=key,
@@ -293,6 +294,7 @@ def column_property(
     repr: Union[_NoArg, bool] = _NoArg.NO_ARG,  # noqa: A002
     default: Optional[Any] = _NoArg.NO_ARG,
     default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+    kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG,
     active_history: bool = False,
     expire_on_flush: bool = True,
     info: Optional[_InfoType] = None,
@@ -385,10 +387,7 @@ def column_property(
         column,
         *additional_columns,
         attribute_options=_AttributeOptions(
-            init,
-            repr,
-            default,
-            default_factory,
+            init, repr, default, default_factory, kw_only
         ),
         group=group,
         deferred=deferred,
@@ -414,6 +413,7 @@ def composite(
     repr: Union[_NoArg, bool] = _NoArg.NO_ARG,  # noqa: A002
     default: Optional[Any] = _NoArg.NO_ARG,
     default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+    kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG,
     info: Optional[_InfoType] = None,
     doc: Optional[str] = None,
     **__kw: Any,
@@ -434,6 +434,7 @@ def composite(
     repr: Union[_NoArg, bool] = _NoArg.NO_ARG,  # noqa: A002
     default: Optional[Any] = _NoArg.NO_ARG,
     default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+    kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG,
     info: Optional[_InfoType] = None,
     doc: Optional[str] = None,
     **__kw: Any,
@@ -455,6 +456,7 @@ def composite(
     repr: Union[_NoArg, bool] = _NoArg.NO_ARG,  # noqa: A002
     default: Optional[Any] = _NoArg.NO_ARG,
     default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+    kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG,
     info: Optional[_InfoType] = None,
     doc: Optional[str] = None,
     **__kw: Any,
@@ -517,6 +519,9 @@ def composite(
      specifies a default-value generation function that will take place
      as part of the ``__init__()``
      method as generated by the dataclass process.
+    :param kw_only: Specific to
+     :ref:`orm_declarative_native_dataclasses`, indicates if this field
+     should be marked as keyword-only when generating the ``__init__()``.
 
     """
     if __kw:
@@ -526,10 +531,7 @@ def composite(
         _class_or_attr,
         *attrs,
         attribute_options=_AttributeOptions(
-            init,
-            repr,
-            default,
-            default_factory,
+            init, repr, default, default_factory, kw_only
         ),
         group=group,
         deferred=deferred,
@@ -752,6 +754,7 @@ def relationship(
     repr: Union[_NoArg, bool] = _NoArg.NO_ARG,  # noqa: A002
     default: Union[_NoArg, _T] = _NoArg.NO_ARG,
     default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+    kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG,
     lazy: _LazyLoadArgumentType = "select",
     passive_deletes: Union[Literal["all"], bool] = False,
     passive_updates: bool = True,
@@ -1549,7 +1552,9 @@ def relationship(
      specifies a default-value generation function that will take place
      as part of the ``__init__()``
      method as generated by the dataclass process.
-
+    :param kw_only: Specific to
+     :ref:`orm_declarative_native_dataclasses`, indicates if this field
+     should be marked as keyword-only when generating the ``__init__()``.
 
 
     """
@@ -1568,10 +1573,7 @@ def relationship(
         cascade=cascade,
         viewonly=viewonly,
         attribute_options=_AttributeOptions(
-            init,
-            repr,
-            default,
-            default_factory,
+            init, repr, default, default_factory, kw_only
         ),
         lazy=lazy,
         passive_deletes=passive_deletes,
@@ -1604,6 +1606,7 @@ def synonym(
     repr: Union[_NoArg, bool] = _NoArg.NO_ARG,  # noqa: A002
     default: Union[_NoArg, _T] = _NoArg.NO_ARG,
     default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+    kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG,
     info: Optional[_InfoType] = None,
     doc: Optional[str] = None,
 ) -> Synonym[Any]:
@@ -1716,10 +1719,7 @@ def synonym(
         descriptor=descriptor,
         comparator_factory=comparator_factory,
         attribute_options=_AttributeOptions(
-            init,
-            repr,
-            default,
-            default_factory,
+            init, repr, default, default_factory, kw_only
         ),
         doc=doc,
         info=info,
@@ -1848,6 +1848,7 @@ def deferred(
     repr: Union[_NoArg, bool] = _NoArg.NO_ARG,  # noqa: A002
     default: Optional[Any] = _NoArg.NO_ARG,
     default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+    kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG,
     active_history: bool = False,
     expire_on_flush: bool = True,
     info: Optional[_InfoType] = None,
@@ -1884,10 +1885,7 @@ def deferred(
         column,
         *additional_columns,
         attribute_options=_AttributeOptions(
-            init,
-            repr,
-            default,
-            default_factory,
+            init, repr, default, default_factory, kw_only
         ),
         group=group,
         deferred=True,
@@ -1937,6 +1935,7 @@ def query_expression(
             repr,
             _NoArg.NO_ARG,
             _NoArg.NO_ARG,
+            _NoArg.NO_ARG,
         ),
         expire_on_flush=expire_on_flush,
         info=info,
index 500f2786e3867abd67ada266aa675852eaee3432..d34ec8c93e8994a68c0000725a4ed44695ba39d2 100644 (file)
@@ -578,6 +578,8 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative):
         eq: Union[_NoArg, bool] = _NoArg.NO_ARG,
         order: Union[_NoArg, bool] = _NoArg.NO_ARG,
         unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG,
+        match_args: Union[_NoArg, bool] = _NoArg.NO_ARG,
+        kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG,
     ) -> None:
 
         apply_dc_transforms: _DataclassArguments = {
@@ -586,6 +588,8 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative):
             "eq": eq,
             "order": order,
             "unsafe_hash": unsafe_hash,
+            "match_args": match_args,
+            "kw_only": kw_only,
         }
 
         if hasattr(cls, "_sa_apply_dc_transforms"):
@@ -1313,6 +1317,8 @@ class registry:
         eq: Union[_NoArg, bool] = ...,
         order: Union[_NoArg, bool] = ...,
         unsafe_hash: Union[_NoArg, bool] = ...,
+        match_args: Union[_NoArg, bool] = ...,
+        kw_only: Union[_NoArg, bool] = ...,
     ) -> Callable[[Type[_O]], Type[_O]]:
         ...
 
@@ -1325,6 +1331,8 @@ class registry:
         eq: Union[_NoArg, bool] = _NoArg.NO_ARG,
         order: Union[_NoArg, bool] = _NoArg.NO_ARG,
         unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG,
+        match_args: Union[_NoArg, bool] = _NoArg.NO_ARG,
+        kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG,
     ) -> Union[Type[_O], Callable[[Type[_O]], Type[_O]]]:
         """Class decorator that will apply the Declarative mapping process
         to a given class, and additionally convert the class to be a
@@ -1348,6 +1356,8 @@ class registry:
                 "eq": eq,
                 "order": order,
                 "unsafe_hash": unsafe_hash,
+                "match_args": match_args,
+                "kw_only": kw_only,
             }
             _as_declarative(self, cls, cls.__dict__)
             return cls
index 108027dd50fd980e1dd13abaae350de932ec76b7..e8d6e4c1b1ffb0df709ae23b7710d42f2632f469 100644 (file)
@@ -106,6 +106,8 @@ class _DataclassArguments(TypedDict):
     eq: Union[_NoArg, bool]
     order: Union[_NoArg, bool]
     unsafe_hash: Union[_NoArg, bool]
+    match_args: Union[_NoArg, bool]
+    kw_only: Union[_NoArg, bool]
 
 
 def _declared_mapping_info(
@@ -1030,22 +1032,20 @@ class _ClassScanMapperConfig(_MapperConfig):
 
     @classmethod
     def _assert_dc_arguments(cls, arguments: _DataclassArguments) -> None:
-        disallowed_args = set(arguments).difference(
-            {
-                "init",
-                "repr",
-                "order",
-                "eq",
-                "unsafe_hash",
-            }
-        )
+        allowed = {
+            "init",
+            "repr",
+            "order",
+            "eq",
+            "unsafe_hash",
+            "kw_only",
+            "match_args",
+        }
+        disallowed_args = set(arguments).difference(allowed)
         if disallowed_args:
+            msg = ", ".join(f"{arg!r}" for arg in sorted(disallowed_args))
             raise exc.ArgumentError(
-                f"Dataclass argument(s) "
-                f"""{
-                    ', '.join(f'{arg!r}'
-                    for arg in sorted(disallowed_args))
-                } are not accepted"""
+                f"Dataclass argument(s) {msg} are not accepted"
             )
 
     def _collect_annotation(
index a9ae4436f17bba5cfde29ee2b53cea7de76de137..0f66566b0fb1eaf7ba5707f612ca04ddaea69919 100644 (file)
@@ -187,6 +187,7 @@ class _AttributeOptions(NamedTuple):
     dataclasses_repr: Union[_NoArg, bool]
     dataclasses_default: Union[_NoArg, Any]
     dataclasses_default_factory: Union[_NoArg, Callable[[], Any]]
+    dataclasses_kw_only: Union[_NoArg, bool]
 
     def _as_dataclass_field(self) -> Any:
         """Return a ``dataclasses.Field`` object given these arguments."""
@@ -200,6 +201,8 @@ class _AttributeOptions(NamedTuple):
             kw["init"] = self.dataclasses_init
         if self.dataclasses_repr is not _NoArg.NO_ARG:
             kw["repr"] = self.dataclasses_repr
+        if self.dataclasses_kw_only is not _NoArg.NO_ARG:
+            kw["kw_only"] = self.dataclasses_kw_only
 
         return dataclasses.field(**kw)
 
@@ -226,7 +229,7 @@ class _AttributeOptions(NamedTuple):
 
 
 _DEFAULT_ATTRIBUTE_OPTIONS = _AttributeOptions(
-    _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG
+    _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG
 )
 
 
index f5111bfc79915fc1125d136d9fa9e4ac5243871c..ca656a8682ea6129a4fd24a6aeb9762d3e5febb4 100644 (file)
@@ -24,6 +24,7 @@ from sqlalchemy.orm import column_property
 from sqlalchemy.orm import composite
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import deferred
+from sqlalchemy.orm import interfaces
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import MappedAsDataclass
@@ -42,6 +43,7 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import ne_
+from sqlalchemy.util import compat
 
 
 class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase):
@@ -483,6 +485,18 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase):
             ),
         )
 
+    @testing.only_if(lambda: compat.py310, "python 3.10 is required")
+    def test_kw_only(self, dc_decl_base: Type[MappedAsDataclass]):
+        class A(dc_decl_base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[str] = mapped_column(kw_only=True)
+
+        fas = pyinspect.getfullargspec(A.__init__)
+        eq_(fas.args, ["self", "id"])
+        eq_(fas.kwonlyargs, ["data"])
+
 
 class RelationshipDefaultFactoryTest(fixtures.TestBase):
     def test_list(self, dc_decl_base: Type[MappedAsDataclass]):
@@ -679,6 +693,8 @@ class RelationshipDefaultFactoryTest(fixtures.TestBase):
 
 class DataclassArgsTest(fixtures.TestBase):
     dc_arg_names = ("init", "repr", "eq", "order", "unsafe_hash")
+    if compat.py310:
+        dc_arg_names += ("match_args", "kw_only")
 
     @testing.fixture(params=product(dc_arg_names, (True, False)))
     def dc_argument_fixture(self, request: Any, registry: _RegistryType):
@@ -695,6 +711,8 @@ class DataclassArgsTest(fixtures.TestBase):
                 "order": False,
                 "unsafe_hash": False,
             }
+            if compat.py310:
+                default |= {"match_args": True, "kw_only": False}
             to_apply = {k: v for k, v in args.items() if v}
             effective = {**default, **to_apply}
             return to_apply, effective
@@ -743,7 +761,10 @@ class DataclassArgsTest(fixtures.TestBase):
         if dc_arguments["init"]:
 
             def create(data, x):
-                return cls(data, x)
+                if dc_arguments.get("kw_only"):
+                    return cls(data=data, x=x)
+                else:
+                    return cls(data, x)
 
         else:
 
@@ -760,7 +781,7 @@ class DataclassArgsTest(fixtures.TestBase):
                 getattr(self, f"_assert_not_{n}")(cls, create, dc_arguments)
 
             if dc_arguments["init"]:
-                a1 = cls("some data")
+                a1 = cls(data="some data")
                 eq_(a1.x, 7)
 
         a1 = create("some data", 15)
@@ -841,10 +862,11 @@ class DataclassArgsTest(fixtures.TestBase):
         eq_regex(repr(a1), r"<.*A object at 0x.*>")
 
     def _assert_init(self, cls, create, dc_arguments):
-        a1 = cls("some data", 5)
+        if not dc_arguments.get("kw_only", False):
+            a1 = cls("some data", 5)
 
-        eq_(a1.data, "some data")
-        eq_(a1.x, 5)
+            eq_(a1.data, "some data")
+            eq_(a1.x, 5)
 
         a2 = cls(data="some data", x=5)
         eq_(a2.data, "some data")
@@ -872,6 +894,31 @@ class DataclassArgsTest(fixtures.TestBase):
         # no constructor, it sets None for x...ok
         eq_(a1.x, None)
 
+    def _assert_match_args(self, cls, create, dc_arguments):
+        if not dc_arguments["kw_only"]:
+            is_true(len(cls.__match_args__) > 0)
+
+    def _assert_not_match_args(self, cls, create, dc_arguments):
+        is_false(hasattr(cls, "__match_args__"))
+
+    def _assert_kw_only(self, cls, create, dc_arguments):
+        if dc_arguments["init"]:
+            fas = pyinspect.getfullargspec(cls.__init__)
+            eq_(fas.args, ["self"])
+            eq_(
+                len(fas.kwonlyargs),
+                len(pyinspect.signature(cls.__init__).parameters) - 1,
+            )
+
+    def _assert_not_kw_only(self, cls, create, dc_arguments):
+        if dc_arguments["init"]:
+            fas = pyinspect.getfullargspec(cls.__init__)
+            eq_(
+                len(fas.args),
+                len(pyinspect.signature(cls.__init__).parameters),
+            )
+            eq_(fas.kwonlyargs, [])
+
     def test_dc_arguments_decorator(
         self,
         dc_argument_fixture,
@@ -957,6 +1004,8 @@ class DataclassArgsTest(fixtures.TestBase):
             "order": True,
             "unsafe_hash": False,
         }
+        if compat.py310:
+            effective |= {"match_args": True, "kw_only": False}
         self._assert_cls(A, effective)
 
     def test_dc_base_unsupported_argument(self, registry: _RegistryType):
@@ -1033,6 +1082,31 @@ class DataclassArgsTest(fixtures.TestBase):
 
                 id: Mapped[int] = mapped_column(primary_key=True, init=False)
 
+    @testing.combinations(True, False)
+    def test_attribute_options(self, args):
+        if args:
+            kw = {
+                "init": True,
+                "repr": True,
+                "default": True,
+                "default_factory": list,
+                "kw_only": True,
+            }
+            exp = interfaces._AttributeOptions(True, True, True, list, True)
+        else:
+            kw = {}
+            exp = interfaces._DEFAULT_ATTRIBUTE_OPTIONS
+
+        for prop in [
+            mapped_column(**kw),
+            synonym("some_int", **kw),
+            column_property(Column(Integer), **kw),
+            deferred(Column(Integer), **kw),
+            composite("foo", **kw),
+            relationship("Foo", **kw),
+        ]:
+            eq_(prop._attribute_options, exp)
+
 
 class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"