From: Federico Caselli Date: Wed, 3 Aug 2022 21:50:19 +0000 (+0200) Subject: Support kw_only and match_args in dataclass mapping X-Git-Tag: rel_2_0_0b1~130^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d2887d03a28b09e9be7db17d7603b6b0a4715df3;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Support kw_only and match_args in dataclass mapping Fixes: #8346 Change-Id: I964629e3bd25221bf6df6ab31c59b3ce1983cd9a --- diff --git a/doc/build/orm/dataclasses.rst b/doc/build/orm/dataclasses.rst index 4f152a0092..18817d2fe4 100644 --- a/doc/build/orm/dataclasses.rst +++ b/doc/build/orm/dataclasses.rst @@ -108,9 +108,9 @@ Class level feature configuration ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Support for dataclasses features is partial. Currently **supported** are -the ``init``, ``repr``, ``eq``, ``order`` and ``unsafe_hash`` features. -Currently **not supported** are the ``frozen``, ``slots``, ``match_args``, -and ``kw_only`` features. +the ``init``, ``repr``, ``eq``, ``order`` and ``unsafe_hash`` features, +``match_args`` and ``kw_only`` are supported on Python 3.10+. +Currently **not supported** are the ``frozen`` and ``slots`` features. When using the mixin class form with :class:`_orm.MappedAsDataclass`, class configuration arguments are passed as class-level parameters:: diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 2f1128f94a..8dfec0fb1c 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -111,6 +111,7 @@ def mapped_column( repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 default: Optional[Any] = _NoArg.NO_ARG, default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, nullable: Optional[ Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] ] = SchemaConst.NULL_UNSPECIFIED, @@ -242,6 +243,9 @@ def mapped_column( specifies a default-value generation function that will take place as part of the ``__init__()`` method as generated by the dataclass process. + :param kw_only: Specific to + :ref:`orm_declarative_native_dataclasses`, indicates if this field + should be marked as keyword-only when generating the ``__init__()``. :param \**kw: All remaining keyword argments are passed through to the constructor for the :class:`_schema.Column`. @@ -257,10 +261,7 @@ def mapped_column( autoincrement=autoincrement, insert_default=insert_default, attribute_options=_AttributeOptions( - init, - repr, - default, - default_factory, + init, repr, default, default_factory, kw_only ), doc=doc, key=key, @@ -293,6 +294,7 @@ def column_property( repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 default: Optional[Any] = _NoArg.NO_ARG, default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, active_history: bool = False, expire_on_flush: bool = True, info: Optional[_InfoType] = None, @@ -385,10 +387,7 @@ def column_property( column, *additional_columns, attribute_options=_AttributeOptions( - init, - repr, - default, - default_factory, + init, repr, default, default_factory, kw_only ), group=group, deferred=deferred, @@ -414,6 +413,7 @@ def composite( repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 default: Optional[Any] = _NoArg.NO_ARG, default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, info: Optional[_InfoType] = None, doc: Optional[str] = None, **__kw: Any, @@ -434,6 +434,7 @@ def composite( repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 default: Optional[Any] = _NoArg.NO_ARG, default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, info: Optional[_InfoType] = None, doc: Optional[str] = None, **__kw: Any, @@ -455,6 +456,7 @@ def composite( repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 default: Optional[Any] = _NoArg.NO_ARG, default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, info: Optional[_InfoType] = None, doc: Optional[str] = None, **__kw: Any, @@ -517,6 +519,9 @@ def composite( specifies a default-value generation function that will take place as part of the ``__init__()`` method as generated by the dataclass process. + :param kw_only: Specific to + :ref:`orm_declarative_native_dataclasses`, indicates if this field + should be marked as keyword-only when generating the ``__init__()``. """ if __kw: @@ -526,10 +531,7 @@ def composite( _class_or_attr, *attrs, attribute_options=_AttributeOptions( - init, - repr, - default, - default_factory, + init, repr, default, default_factory, kw_only ), group=group, deferred=deferred, @@ -752,6 +754,7 @@ def relationship( repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 default: Union[_NoArg, _T] = _NoArg.NO_ARG, default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, lazy: _LazyLoadArgumentType = "select", passive_deletes: Union[Literal["all"], bool] = False, passive_updates: bool = True, @@ -1549,7 +1552,9 @@ def relationship( specifies a default-value generation function that will take place as part of the ``__init__()`` method as generated by the dataclass process. - + :param kw_only: Specific to + :ref:`orm_declarative_native_dataclasses`, indicates if this field + should be marked as keyword-only when generating the ``__init__()``. """ @@ -1568,10 +1573,7 @@ def relationship( cascade=cascade, viewonly=viewonly, attribute_options=_AttributeOptions( - init, - repr, - default, - default_factory, + init, repr, default, default_factory, kw_only ), lazy=lazy, passive_deletes=passive_deletes, @@ -1604,6 +1606,7 @@ def synonym( repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 default: Union[_NoArg, _T] = _NoArg.NO_ARG, default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, info: Optional[_InfoType] = None, doc: Optional[str] = None, ) -> Synonym[Any]: @@ -1716,10 +1719,7 @@ def synonym( descriptor=descriptor, comparator_factory=comparator_factory, attribute_options=_AttributeOptions( - init, - repr, - default, - default_factory, + init, repr, default, default_factory, kw_only ), doc=doc, info=info, @@ -1848,6 +1848,7 @@ def deferred( repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 default: Optional[Any] = _NoArg.NO_ARG, default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, active_history: bool = False, expire_on_flush: bool = True, info: Optional[_InfoType] = None, @@ -1884,10 +1885,7 @@ def deferred( column, *additional_columns, attribute_options=_AttributeOptions( - init, - repr, - default, - default_factory, + init, repr, default, default_factory, kw_only ), group=group, deferred=True, @@ -1937,6 +1935,7 @@ def query_expression( repr, _NoArg.NO_ARG, _NoArg.NO_ARG, + _NoArg.NO_ARG, ), expire_on_flush=expire_on_flush, info=info, diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 500f2786e3..d34ec8c93e 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -578,6 +578,8 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative): eq: Union[_NoArg, bool] = _NoArg.NO_ARG, order: Union[_NoArg, bool] = _NoArg.NO_ARG, unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG, + match_args: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, ) -> None: apply_dc_transforms: _DataclassArguments = { @@ -586,6 +588,8 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative): "eq": eq, "order": order, "unsafe_hash": unsafe_hash, + "match_args": match_args, + "kw_only": kw_only, } if hasattr(cls, "_sa_apply_dc_transforms"): @@ -1313,6 +1317,8 @@ class registry: eq: Union[_NoArg, bool] = ..., order: Union[_NoArg, bool] = ..., unsafe_hash: Union[_NoArg, bool] = ..., + match_args: Union[_NoArg, bool] = ..., + kw_only: Union[_NoArg, bool] = ..., ) -> Callable[[Type[_O]], Type[_O]]: ... @@ -1325,6 +1331,8 @@ class registry: eq: Union[_NoArg, bool] = _NoArg.NO_ARG, order: Union[_NoArg, bool] = _NoArg.NO_ARG, unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG, + match_args: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, ) -> Union[Type[_O], Callable[[Type[_O]], Type[_O]]]: """Class decorator that will apply the Declarative mapping process to a given class, and additionally convert the class to be a @@ -1348,6 +1356,8 @@ class registry: "eq": eq, "order": order, "unsafe_hash": unsafe_hash, + "match_args": match_args, + "kw_only": kw_only, } _as_declarative(self, cls, cls.__dict__) return cls diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 108027dd50..e8d6e4c1b1 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -106,6 +106,8 @@ class _DataclassArguments(TypedDict): eq: Union[_NoArg, bool] order: Union[_NoArg, bool] unsafe_hash: Union[_NoArg, bool] + match_args: Union[_NoArg, bool] + kw_only: Union[_NoArg, bool] def _declared_mapping_info( @@ -1030,22 +1032,20 @@ class _ClassScanMapperConfig(_MapperConfig): @classmethod def _assert_dc_arguments(cls, arguments: _DataclassArguments) -> None: - disallowed_args = set(arguments).difference( - { - "init", - "repr", - "order", - "eq", - "unsafe_hash", - } - ) + allowed = { + "init", + "repr", + "order", + "eq", + "unsafe_hash", + "kw_only", + "match_args", + } + disallowed_args = set(arguments).difference(allowed) if disallowed_args: + msg = ", ".join(f"{arg!r}" for arg in sorted(disallowed_args)) raise exc.ArgumentError( - f"Dataclass argument(s) " - f"""{ - ', '.join(f'{arg!r}' - for arg in sorted(disallowed_args)) - } are not accepted""" + f"Dataclass argument(s) {msg} are not accepted" ) def _collect_annotation( diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index a9ae4436f1..0f66566b0f 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -187,6 +187,7 @@ class _AttributeOptions(NamedTuple): dataclasses_repr: Union[_NoArg, bool] dataclasses_default: Union[_NoArg, Any] dataclasses_default_factory: Union[_NoArg, Callable[[], Any]] + dataclasses_kw_only: Union[_NoArg, bool] def _as_dataclass_field(self) -> Any: """Return a ``dataclasses.Field`` object given these arguments.""" @@ -200,6 +201,8 @@ class _AttributeOptions(NamedTuple): kw["init"] = self.dataclasses_init if self.dataclasses_repr is not _NoArg.NO_ARG: kw["repr"] = self.dataclasses_repr + if self.dataclasses_kw_only is not _NoArg.NO_ARG: + kw["kw_only"] = self.dataclasses_kw_only return dataclasses.field(**kw) @@ -226,7 +229,7 @@ class _AttributeOptions(NamedTuple): _DEFAULT_ATTRIBUTE_OPTIONS = _AttributeOptions( - _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG + _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG ) diff --git a/test/orm/declarative/test_dc_transforms.py b/test/orm/declarative/test_dc_transforms.py index f5111bfc79..ca656a8682 100644 --- a/test/orm/declarative/test_dc_transforms.py +++ b/test/orm/declarative/test_dc_transforms.py @@ -24,6 +24,7 @@ from sqlalchemy.orm import column_property from sqlalchemy.orm import composite from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import deferred +from sqlalchemy.orm import interfaces from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import MappedAsDataclass @@ -42,6 +43,7 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_true from sqlalchemy.testing import ne_ +from sqlalchemy.util import compat class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): @@ -483,6 +485,18 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): ), ) + @testing.only_if(lambda: compat.py310, "python 3.10 is required") + def test_kw_only(self, dc_decl_base: Type[MappedAsDataclass]): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column(kw_only=True) + + fas = pyinspect.getfullargspec(A.__init__) + eq_(fas.args, ["self", "id"]) + eq_(fas.kwonlyargs, ["data"]) + class RelationshipDefaultFactoryTest(fixtures.TestBase): def test_list(self, dc_decl_base: Type[MappedAsDataclass]): @@ -679,6 +693,8 @@ class RelationshipDefaultFactoryTest(fixtures.TestBase): class DataclassArgsTest(fixtures.TestBase): dc_arg_names = ("init", "repr", "eq", "order", "unsafe_hash") + if compat.py310: + dc_arg_names += ("match_args", "kw_only") @testing.fixture(params=product(dc_arg_names, (True, False))) def dc_argument_fixture(self, request: Any, registry: _RegistryType): @@ -695,6 +711,8 @@ class DataclassArgsTest(fixtures.TestBase): "order": False, "unsafe_hash": False, } + if compat.py310: + default |= {"match_args": True, "kw_only": False} to_apply = {k: v for k, v in args.items() if v} effective = {**default, **to_apply} return to_apply, effective @@ -743,7 +761,10 @@ class DataclassArgsTest(fixtures.TestBase): if dc_arguments["init"]: def create(data, x): - return cls(data, x) + if dc_arguments.get("kw_only"): + return cls(data=data, x=x) + else: + return cls(data, x) else: @@ -760,7 +781,7 @@ class DataclassArgsTest(fixtures.TestBase): getattr(self, f"_assert_not_{n}")(cls, create, dc_arguments) if dc_arguments["init"]: - a1 = cls("some data") + a1 = cls(data="some data") eq_(a1.x, 7) a1 = create("some data", 15) @@ -841,10 +862,11 @@ class DataclassArgsTest(fixtures.TestBase): eq_regex(repr(a1), r"<.*A object at 0x.*>") def _assert_init(self, cls, create, dc_arguments): - a1 = cls("some data", 5) + if not dc_arguments.get("kw_only", False): + a1 = cls("some data", 5) - eq_(a1.data, "some data") - eq_(a1.x, 5) + eq_(a1.data, "some data") + eq_(a1.x, 5) a2 = cls(data="some data", x=5) eq_(a2.data, "some data") @@ -872,6 +894,31 @@ class DataclassArgsTest(fixtures.TestBase): # no constructor, it sets None for x...ok eq_(a1.x, None) + def _assert_match_args(self, cls, create, dc_arguments): + if not dc_arguments["kw_only"]: + is_true(len(cls.__match_args__) > 0) + + def _assert_not_match_args(self, cls, create, dc_arguments): + is_false(hasattr(cls, "__match_args__")) + + def _assert_kw_only(self, cls, create, dc_arguments): + if dc_arguments["init"]: + fas = pyinspect.getfullargspec(cls.__init__) + eq_(fas.args, ["self"]) + eq_( + len(fas.kwonlyargs), + len(pyinspect.signature(cls.__init__).parameters) - 1, + ) + + def _assert_not_kw_only(self, cls, create, dc_arguments): + if dc_arguments["init"]: + fas = pyinspect.getfullargspec(cls.__init__) + eq_( + len(fas.args), + len(pyinspect.signature(cls.__init__).parameters), + ) + eq_(fas.kwonlyargs, []) + def test_dc_arguments_decorator( self, dc_argument_fixture, @@ -957,6 +1004,8 @@ class DataclassArgsTest(fixtures.TestBase): "order": True, "unsafe_hash": False, } + if compat.py310: + effective |= {"match_args": True, "kw_only": False} self._assert_cls(A, effective) def test_dc_base_unsupported_argument(self, registry: _RegistryType): @@ -1033,6 +1082,31 @@ class DataclassArgsTest(fixtures.TestBase): id: Mapped[int] = mapped_column(primary_key=True, init=False) + @testing.combinations(True, False) + def test_attribute_options(self, args): + if args: + kw = { + "init": True, + "repr": True, + "default": True, + "default_factory": list, + "kw_only": True, + } + exp = interfaces._AttributeOptions(True, True, True, list, True) + else: + kw = {} + exp = interfaces._DEFAULT_ATTRIBUTE_OPTIONS + + for prop in [ + mapped_column(**kw), + synonym("some_int", **kw), + column_property(Column(Integer), **kw), + deferred(Column(Integer), **kw), + composite("foo", **kw), + relationship("Foo", **kw), + ]: + eq_(prop._attribute_options, exp) + class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default"