From: Federico Caselli Date: Sat, 3 Jun 2023 10:40:00 +0000 (+0200) Subject: Support dataclass default with init=False X-Git-Tag: rel_2_0_16~13^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8f720e1529af2b4810b0ae5379e13b343155eded;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Support dataclass default with init=False Fixed an issue where generating dataclasses fields that specified a default`` value and set ``init=False`` would not work. The dataclasses behavior in this case is to set the default value on the class, that's not compatible with the descriptors used by SQLAlchemy. To support this case the default is transformed to a ``default_factory`` when generating the dataclass. Fixes: #9879 Change-Id: I5151d388232eacd506a100ba18ce26970bf83cf3 --- diff --git a/doc/build/changelog/unreleased_20/9879.rst b/doc/build/changelog/unreleased_20/9879.rst new file mode 100644 index 0000000000..d1112305ba --- /dev/null +++ b/doc/build/changelog/unreleased_20/9879.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, orm, dataclasses + :tickets: 9879 + + Fixed an issue where generating dataclasses fields that specified a + ``default`` value and set ``init=False`` would not work. + The dataclasses behavior in this case is to set the default + value on the class, that's not compatible with the descriptors used + by SQLAlchemy. To support this case the default is transformed to + a ``default_factory`` when generating the dataclass. diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 4da8f63e68..d9df2b3b19 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -217,6 +217,16 @@ class _AttributeOptions(NamedTuple): if self.dataclasses_kw_only is not _NoArg.NO_ARG: kw["kw_only"] = self.dataclasses_kw_only + if ( + "init" in kw + and not kw["init"] + and "default" in kw + and "default_factory" not in kw # illegal but let field raise + ): + # fix for #9879 + default = kw.pop("default") + kw["default_factory"] = lambda: default + return dataclasses.field(**kw) @classmethod diff --git a/test/orm/declarative/test_dc_transforms.py b/test/orm/declarative/test_dc_transforms.py index 576ee7fbfe..678dc51a27 100644 --- a/test/orm/declarative/test_dc_transforms.py +++ b/test/orm/declarative/test_dc_transforms.py @@ -758,6 +758,38 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): is_true(isinstance(ec.error.__cause__, TypeError)) + def test_dataclass_default(self, dc_decl_base): + """test for #9879""" + + def c10(): + return 10 + + def c20(): + return 20 + + class A(dc_decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + def_init: Mapped[int] = mapped_column(default=42) + call_init: Mapped[int] = mapped_column(default_factory=c10) + def_no_init: Mapped[int] = mapped_column(default=13, init=False) + call_no_init: Mapped[int] = mapped_column( + default_factory=c20, init=False + ) + + a = A(id=100) + eq_(a.def_init, 42) + eq_(a.call_init, 10) + eq_(a.def_no_init, 13) + eq_(a.call_no_init, 20) + + fields = {f.name: f for f in dataclasses.fields(A)} + eq_(fields["def_init"].default, 42) + eq_(fields["call_init"].default_factory, c10) + eq_(fields["def_no_init"].default, dataclasses.MISSING) + ne_(fields["def_no_init"].default_factory, dataclasses.MISSING) + eq_(fields["call_no_init"].default_factory, c20) + class RelationshipDefaultFactoryTest(fixtures.TestBase): def test_list(self, dc_decl_base: Type[MappedAsDataclass]):