]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-102103: add `module` argument to `dataclasses.make_dataclass` (#102104)
authorNikita Sobolev <mail@sobolevn.me>
Sat, 11 Mar 2023 00:26:46 +0000 (03:26 +0300)
committerGitHub <noreply@github.com>
Sat, 11 Mar 2023 00:26:46 +0000 (17:26 -0700)
Doc/library/dataclasses.rst
Lib/dataclasses.py
Lib/test/test_dataclasses.py
Misc/NEWS.d/next/Library/2023-02-21-11-56-16.gh-issue-102103.Dj0WEj.rst [new file with mode: 0644]

index 82faa7b77450fb6dbef00dc36f64b24dce0282b8..5f4dc25bfd7877199ff73a43bfab8cf40eb76abd 100644 (file)
@@ -389,7 +389,7 @@ Module contents
    :func:`astuple` raises :exc:`TypeError` if ``obj`` is not a dataclass
    instance.
 
-.. function:: make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, match_args=True, kw_only=False, slots=False, weakref_slot=False)
+.. function:: make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, match_args=True, kw_only=False, slots=False, weakref_slot=False, module=None)
 
    Creates a new dataclass with name ``cls_name``, fields as defined
    in ``fields``, base classes as given in ``bases``, and initialized
@@ -401,6 +401,10 @@ Module contents
    ``match_args``, ``kw_only``, ``slots``, and ``weakref_slot`` have
    the same meaning as they do in :func:`dataclass`.
 
+   If ``module`` is defined, the ``__module__`` attribute
+   of the dataclass is set to that value.
+   By default, it is set to the module name of the caller.
+
    This function is not strictly required, because any Python
    mechanism for creating a new class with ``__annotations__`` can
    then apply the :func:`dataclass` function to convert that class to
index 78a126f051e2e76e584b4eead4eef9d682e3db05..24f3779ebb8ec8e7ae2a563d92e45cbaeec5e319 100644 (file)
@@ -1391,7 +1391,7 @@ def _astuple_inner(obj, tuple_factory):
 def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
                    repr=True, eq=True, order=False, unsafe_hash=False,
                    frozen=False, match_args=True, kw_only=False, slots=False,
-                   weakref_slot=False):
+                   weakref_slot=False, module=None):
     """Return a new dynamically created dataclass.
 
     The dataclass name will be 'cls_name'.  'fields' is an iterable
@@ -1455,6 +1455,19 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
     # of generic dataclasses.
     cls = types.new_class(cls_name, bases, {}, exec_body_callback)
 
+    # For pickling to work, the __module__ variable needs to be set to the frame
+    # where the dataclass is created.
+    if module is None:
+        try:
+            module = sys._getframemodulename(1) or '__main__'
+        except AttributeError:
+            try:
+                module = sys._getframe(1).f_globals.get('__name__', '__main__')
+            except (AttributeError, ValueError):
+                pass
+    if module is not None:
+        cls.__module__ = module
+
     # Apply the normal decorator.
     return dataclass(cls, init=init, repr=repr, eq=eq, order=order,
                      unsafe_hash=unsafe_hash, frozen=frozen,
index 5486b2ef3f47e5183ebc521c6cd888097edb0d3d..76bed0c33146734acdeafdfc5610d0e9a7839ba8 100644 (file)
@@ -3606,6 +3606,15 @@ class TestStringAnnotations(unittest.TestCase):
              'return': type(None)})
 
 
+ByMakeDataClass = make_dataclass('ByMakeDataClass', [('x', int)])
+ManualModuleMakeDataClass = make_dataclass('ManualModuleMakeDataClass',
+                                           [('x', int)],
+                                           module='test.test_dataclasses')
+WrongNameMakeDataclass = make_dataclass('Wrong', [('x', int)])
+WrongModuleMakeDataclass = make_dataclass('WrongModuleMakeDataclass',
+                                          [('x', int)],
+                                          module='custom')
+
 class TestMakeDataclass(unittest.TestCase):
     def test_simple(self):
         C = make_dataclass('C',
@@ -3715,6 +3724,36 @@ class TestMakeDataclass(unittest.TestCase):
                                              'y': int,
                                              'z': 'typing.Any'})
 
+    def test_module_attr(self):
+        self.assertEqual(ByMakeDataClass.__module__, __name__)
+        self.assertEqual(ByMakeDataClass(1).__module__, __name__)
+        self.assertEqual(WrongModuleMakeDataclass.__module__, "custom")
+        Nested = make_dataclass('Nested', [])
+        self.assertEqual(Nested.__module__, __name__)
+        self.assertEqual(Nested().__module__, __name__)
+
+    def test_pickle_support(self):
+        for klass in [ByMakeDataClass, ManualModuleMakeDataClass]:
+            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+                with self.subTest(proto=proto):
+                    self.assertEqual(
+                        pickle.loads(pickle.dumps(klass, proto)),
+                        klass,
+                    )
+                    self.assertEqual(
+                        pickle.loads(pickle.dumps(klass(1), proto)),
+                        klass(1),
+                    )
+
+    def test_cannot_be_pickled(self):
+        for klass in [WrongNameMakeDataclass, WrongModuleMakeDataclass]:
+            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+                with self.subTest(proto=proto):
+                    with self.assertRaises(pickle.PickleError):
+                        pickle.dumps(klass, proto)
+                    with self.assertRaises(pickle.PickleError):
+                        pickle.dumps(klass(1), proto)
+
     def test_invalid_type_specification(self):
         for bad_field in [(),
                           (1, 2, 3, 4),
diff --git a/Misc/NEWS.d/next/Library/2023-02-21-11-56-16.gh-issue-102103.Dj0WEj.rst b/Misc/NEWS.d/next/Library/2023-02-21-11-56-16.gh-issue-102103.Dj0WEj.rst
new file mode 100644 (file)
index 0000000..feba433
--- /dev/null
@@ -0,0 +1,2 @@
+Add ``module`` argument to :func:`dataclasses.make_dataclass` and make
+classes produced by it pickleable.