git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-109870: Dataclasses: batch up exec calls (gh-110851)
author: Eric V. Smith <ericvsmith@users.noreply.github.com>
Mon, 25 Mar 2024 23:59:14 +0000 (19:59 -0400)
committer: GitHub <noreply@github.com>
Mon, 25 Mar 2024 23:59:14 +0000 (19:59 -0400)
Instead of calling `exec()` once for each function added to a dataclass, call `exec()` only once per dataclass. This can speed up dataclass creation by up to 20%.

Lib/dataclasses.py
Misc/NEWS.d/next/Core and Builtins/2023-10-14-00-05-17.gh-issue-109870.oKpJ3P.rst [new file with mode: 0644]

index 7db8a4233df8836fd6a920458674ff9c81c4800d..3acd03cd86523473172891b791ef52e6286a4af3 100644 (file)
@@ -426,32 +426,95 @@ def _tuple_str(obj_name, fields):
     return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)'
 
 
-def _create_fn(name, args, body, *, globals=None, locals=None,
-               return_type=MISSING):
-    # Note that we may mutate locals. Callers beware!
-    # The only callers are internal to this module, so no
-    # worries about external callers.
-    if locals is None:
-        locals = {}
-    return_annotation = ''
-    if return_type is not MISSING:
-        locals['__dataclass_return_type__'] = return_type
-        return_annotation = '->__dataclass_return_type__'
-    args = ','.join(args)
-    body = '\n'.join(f'  {b}' for b in body)
-
-    # Compute the text of the entire function.
-    txt = f' def {name}({args}){return_annotation}:\n{body}'
-
-    # Free variables in exec are resolved in the global namespace.
-    # The global namespace we have is user-provided, so we can't modify it for
-    # our purposes. So we put the things we need into locals and introduce a
-    # scope to allow the function we're creating to close over them.
-    local_vars = ', '.join(locals.keys())
-    txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"
-    ns = {}
-    exec(txt, globals, ns)
-    return ns['__create_fn__'](**locals)
+class _FuncBuilder:
+    def __init__(self, globals):
+        self.names = []
+        self.src = []
+        self.globals = globals
+        self.locals = {}
+        self.overwrite_errors = {}
+        self.unconditional_adds = {}
+
+    def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
+               overwrite_error=False, unconditional_add=False, decorator=None):
+        if locals is not None:
+            self.locals.update(locals)
+
+        # Keep track of whether it is an error for this method to already
+        # exist in the class.  The error is method-specific, so keep it with
+        # the name.  We'll use this when we generate all of the functions in
+        # the add_fns_to_class call.  overwrite_error is either True, in which
+        # case we'll raise an error, or it's a string, in which case we'll
+        # raise an error and append this string.
+        if overwrite_error:
+            self.overwrite_errors[name] = overwrite_error
+
+        # Should this function always overwrite anything that's already in the
+        # class?  The default is to not overwrite a function that already
+        # exists.
+        if unconditional_add:
+            self.unconditional_adds[name] = True
+
+        self.names.append(name)
+
+        if return_type is not MISSING:
+            self.locals[f'__dataclass_{name}_return_type__'] = return_type
+            return_annotation = f'->__dataclass_{name}_return_type__'
+        else:
+            return_annotation = ''
+        args = ','.join(args)
+        body = '\n'.join(body)
+
+        # Compute the text of the entire function, add it to the text we're generating.
+        self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}){return_annotation}:\n{body}')
+
+    def add_fns_to_class(self, cls):
+        # The source to all of the functions we're generating.
+        fns_src = '\n'.join(self.src)
+
+        # The locals they use.
+        local_vars = ','.join(self.locals.keys())
+
+        # The names of all of the functions, used for the return value of the
+        # outer function.  Need to handle the 0-tuple specially.
+        if len(self.names) == 0:
+            return_names = '()'
+        else:
+            return_names  =f'({",".join(self.names)},)'
+
+        # txt is the entire function we're going to execute, including the
+        # bodies of the functions we're defining.  Here's a greatly simplified
+        # version:
+        # def __create_fn__():
+        #  def __init__(self, x, y):
+        #   self.x = x
+        #   self.y = y
+        #  @recursive_repr
+        #  def __repr__(self):
+        #   return f"cls(x={self.x!r},y={self.y!r})"
+        #  return (__init__,__repr__,)
+
+        txt = f"def __create_fn__({local_vars}):\n{fns_src}\n return {return_names}"
+        ns = {}
+        exec(txt, self.globals, ns)
+        fns = ns['__create_fn__'](**self.locals)
+
+        # Now that we've generated the functions, assign them into cls.
+        for name, fn in zip(self.names, fns):
+            fn.__qualname__ = f"{cls.__qualname__}.{fn.__name__}"
+            if self.unconditional_adds.get(name, False):
+                setattr(cls, name, fn)
+            else:
+                already_exists = _set_new_attribute(cls, name, fn)
+
+                # See if it's an error to overwrite this particular function.
+                if already_exists and (msg_extra := self.overwrite_errors.get(name)):
+                    error_msg = (f'Cannot overwrite attribute {fn.__name__} '
+                                 f'in class {cls.__name__}')
+                    if not msg_extra is True:
+                        error_msg = f'{error_msg} {msg_extra}'
+
+                    raise TypeError(error_msg)
 
 
 def _field_assign(frozen, name, value, self_name):
@@ -462,8 +525,8 @@ def _field_assign(frozen, name, value, self_name):
     # self_name is what "self" is called in this function: don't
     # hard-code "self", since that might be a field name.
     if frozen:
-        return f'__dataclass_builtins_object__.__setattr__({self_name},{name!r},{value})'
-    return f'{self_name}.{name}={value}'
+        return f'  __dataclass_builtins_object__.__setattr__({self_name},{name!r},{value})'
+    return f'  {self_name}.{name}={value}'
 
 
 def _field_init(f, frozen, globals, self_name, slots):
@@ -546,7 +609,7 @@ def _init_param(f):
 
 
 def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
-             self_name, globals, slots):
+             self_name, func_builder, slots):
     # fields contains both real fields and InitVar pseudo-fields.
 
     # Make sure we don't have fields without defaults following fields
@@ -565,11 +628,11 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
                 raise TypeError(f'non-default argument {f.name!r} '
                                 f'follows default argument {seen_default.name!r}')
 
-    locals = {f'__dataclass_type_{f.name}__': f.type for f in fields}
-    locals.update({
-        '__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
-        '__dataclass_builtins_object__': object,
-    })
+    locals = {**{f'__dataclass_type_{f.name}__': f.type for f in fields},
+              **{'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
+                 '__dataclass_builtins_object__': object,
+                 }
+              }
 
     body_lines = []
     for f in fields:
@@ -583,11 +646,11 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
     if has_post_init:
         params_str = ','.join(f.name for f in fields
                               if f._field_type is _FIELD_INITVAR)
-        body_lines.append(f'{self_name}.{_POST_INIT_NAME}({params_str})')
+        body_lines.append(f'  {self_name}.{_POST_INIT_NAME}({params_str})')
 
     # If no body lines, use 'pass'.
     if not body_lines:
-        body_lines = ['pass']
+        body_lines = ['  pass']
 
     _init_params = [_init_param(f) for f in std_fields]
     if kw_only_fields:
@@ -596,68 +659,34 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
         # (instead of just concatenating the lists together).
         _init_params += ['*']
         _init_params += [_init_param(f) for f in kw_only_fields]
-    return _create_fn('__init__',
-                      [self_name] + _init_params,
-                      body_lines,
-                      locals=locals,
-                      globals=globals,
-                      return_type=None)
-
-
-def _repr_fn(fields, globals):
-    fn = _create_fn('__repr__',
-                    ('self',),
-                    ['return f"{self.__class__.__qualname__}(' +
-                     ', '.join([f"{f.name}={{self.{f.name}!r}}"
-                                for f in fields]) +
-                     ')"'],
-                     globals=globals)
-    return recursive_repr()(fn)
-
-
-def _frozen_get_del_attr(cls, fields, globals):
+    func_builder.add_fn('__init__',
+                        [self_name] + _init_params,
+                        body_lines,
+                        locals=locals,
+                        return_type=None)
+
+
+def _frozen_get_del_attr(cls, fields, func_builder):
     locals = {'cls': cls,
               'FrozenInstanceError': FrozenInstanceError}
     condition = 'type(self) is cls'
     if fields:
         condition += ' or name in {' + ', '.join(repr(f.name) for f in fields) + '}'
-    return (_create_fn('__setattr__',
-                      ('self', 'name', 'value'),
-                      (f'if {condition}:',
-                        ' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
-                       f'super(cls, self).__setattr__(name, value)'),
-                       locals=locals,
-                       globals=globals),
-            _create_fn('__delattr__',
-                      ('self', 'name'),
-                      (f'if {condition}:',
-                        ' raise FrozenInstanceError(f"cannot delete field {name!r}")',
-                       f'super(cls, self).__delattr__(name)'),
-                       locals=locals,
-                       globals=globals),
-            )
-
-
-def _cmp_fn(name, op, self_tuple, other_tuple, globals):
-    # Create a comparison function.  If the fields in the object are
-    # named 'x' and 'y', then self_tuple is the string
-    # '(self.x,self.y)' and other_tuple is the string
-    # '(other.x,other.y)'.
-
-    return _create_fn(name,
-                      ('self', 'other'),
-                      [ 'if other.__class__ is self.__class__:',
-                       f' return {self_tuple}{op}{other_tuple}',
-                        'return NotImplemented'],
-                      globals=globals)
-
 
-def _hash_fn(fields, globals):
-    self_tuple = _tuple_str('self', fields)
-    return _create_fn('__hash__',
-                      ('self',),
-                      [f'return hash({self_tuple})'],
-                      globals=globals)
+    func_builder.add_fn('__setattr__',
+                        ('self', 'name', 'value'),
+                        (f'  if {condition}:',
+                          '   raise FrozenInstanceError(f"cannot assign to field {name!r}")',
+                         f'  super(cls, self).__setattr__(name, value)'),
+                        locals=locals,
+                        overwrite_error=True)
+    func_builder.add_fn('__delattr__',
+                        ('self', 'name'),
+                        (f'  if {condition}:',
+                          '   raise FrozenInstanceError(f"cannot delete field {name!r}")',
+                         f'  super(cls, self).__delattr__(name)'),
+                        locals=locals,
+                        overwrite_error=True)
 
 
 def _is_classvar(a_type, typing):
@@ -834,19 +863,11 @@ def _get_field(cls, a_name, a_type, default_kw_only):
 
     return f
 
-def _set_qualname(cls, value):
-    # Ensure that the functions returned from _create_fn uses the proper
-    # __qualname__ (the class they belong to).
-    if isinstance(value, FunctionType):
-        value.__qualname__ = f"{cls.__qualname__}.{value.__name__}"
-    return value
-
 def _set_new_attribute(cls, name, value):
     # Never overwrites an existing attribute.  Returns True if the
     # attribute already exists.
     if name in cls.__dict__:
         return True
-    _set_qualname(cls, value)
     setattr(cls, name, value)
     return False
 
@@ -856,14 +877,22 @@ def _set_new_attribute(cls, name, value):
 # take.  The common case is to do nothing, so instead of providing a
 # function that is a no-op, use None to signify that.
 
-def _hash_set_none(cls, fields, globals):
-    return None
+def _hash_set_none(cls, fields, func_builder):
+    # It's sort of a hack that I'm setting this here, instead of at
+    # func_builder.add_fns_to_class time, but since this is an exceptional case
+    # (it's not setting an attribute to a function, but to a scalar value),
+    # just do it directly here.  I might come to regret this.
+    cls.__hash__ = None
 
-def _hash_add(cls, fields, globals):
+def _hash_add(cls, fields, func_builder):
     flds = [f for f in fields if (f.compare if f.hash is None else f.hash)]
-    return _set_qualname(cls, _hash_fn(flds, globals))
+    self_tuple = _tuple_str('self', flds)
+    func_builder.add_fn('__hash__',
+                        ('self',),
+                        [f'  return hash({self_tuple})'],
+                        unconditional_add=True)
 
-def _hash_exception(cls, fields, globals):
+def _hash_exception(cls, fields, func_builder):
     # Raise an exception.
     raise TypeError(f'Cannot overwrite attribute __hash__ '
                     f'in class {cls.__name__}')
@@ -1041,24 +1070,26 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
     (std_init_fields,
      kw_only_init_fields) = _fields_in_init_order(all_init_fields)
 
+    func_builder = _FuncBuilder(globals)
+
     if init:
         # Does this class have a post-init function?
         has_post_init = hasattr(cls, _POST_INIT_NAME)
 
-        _set_new_attribute(cls, '__init__',
-                           _init_fn(all_init_fields,
-                                    std_init_fields,
-                                    kw_only_init_fields,
-                                    frozen,
-                                    has_post_init,
-                                    # The name to use for the "self"
-                                    # param in __init__.  Use "self"
-                                    # if possible.
-                                    '__dataclass_self__' if 'self' in fields
-                                            else 'self',
-                                    globals,
-                                    slots,
-                          ))
+        _init_fn(all_init_fields,
+                 std_init_fields,
+                 kw_only_init_fields,
+                 frozen,
+                 has_post_init,
+                 # The name to use for the "self"
+                 # param in __init__.  Use "self"
+                 # if possible.
+                 '__dataclass_self__' if 'self' in fields
+                 else 'self',
+                 func_builder,
+                 slots,
+                 )
+
     _set_new_attribute(cls, '__replace__', _replace)
 
     # Get the fields as a list, and include only real fields.  This is
@@ -1067,7 +1098,13 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
 
     if repr:
         flds = [f for f in field_list if f.repr]
-        _set_new_attribute(cls, '__repr__', _repr_fn(flds, globals))
+        func_builder.add_fn('__repr__',
+                            ('self',),
+                            ['  return f"{self.__class__.__qualname__}(' +
+                             ', '.join([f"{f.name}={{self.{f.name}!r}}"
+                                        for f in flds]) + ')"'],
+                            locals={'__dataclasses_recursive_repr': recursive_repr},
+                            decorator="@__dataclasses_recursive_repr()")
 
     if eq:
         # Create __eq__ method.  There's no need for a __ne__ method,
@@ -1075,16 +1112,13 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
         cmp_fields = (field for field in field_list if field.compare)
         terms = [f'self.{field.name}==other.{field.name}' for field in cmp_fields]
         field_comparisons = ' and '.join(terms) or 'True'
-        body =  [f'if self is other:',
-                 f' return True',
-                 f'if other.__class__ is self.__class__:',
-                 f' return {field_comparisons}',
-                 f'return NotImplemented']
-        func = _create_fn('__eq__',
-                          ('self', 'other'),
-                          body,
-                          globals=globals)
-        _set_new_attribute(cls, '__eq__', func)
+        func_builder.add_fn('__eq__',
+                            ('self', 'other'),
+                            [ '  if self is other:',
+                              '   return True',
+                              '  if other.__class__ is self.__class__:',
+                             f'   return {field_comparisons}',
+                              '  return NotImplemented'])
 
     if order:
         # Create and set the ordering methods.
@@ -1096,18 +1130,19 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
                          ('__gt__', '>'),
                          ('__ge__', '>='),
                          ]:
-            if _set_new_attribute(cls, name,
-                                  _cmp_fn(name, op, self_tuple, other_tuple,
-                                          globals=globals)):
-                raise TypeError(f'Cannot overwrite attribute {name} '
-                                f'in class {cls.__name__}. Consider using '
-                                'functools.total_ordering')
+            # Create a comparison function.  If the fields in the object are
+            # named 'x' and 'y', then self_tuple is the string
+            # '(self.x,self.y)' and other_tuple is the string
+            # '(other.x,other.y)'.
+            func_builder.add_fn(name,
+                            ('self', 'other'),
+                            [ '  if other.__class__ is self.__class__:',
+                             f'   return {self_tuple}{op}{other_tuple}',
+                              '  return NotImplemented'],
+                            overwrite_error='Consider using functools.total_ordering')
 
     if frozen:
-        for fn in _frozen_get_del_attr(cls, field_list, globals):
-            if _set_new_attribute(cls, fn.__name__, fn):
-                raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
-                                f'in class {cls.__name__}')
+        _frozen_get_del_attr(cls, field_list, func_builder)
 
     # Decide if/how we're going to create a hash function.
     hash_action = _hash_action[bool(unsafe_hash),
@@ -1115,9 +1150,12 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
                                bool(frozen),
                                has_explicit_hash]
     if hash_action:
-        # No need to call _set_new_attribute here, since by the time
-        # we're here the overwriting is unconditional.
-        cls.__hash__ = hash_action(cls, field_list, globals)
+        cls.__hash__ = hash_action(cls, field_list, func_builder)
+
+    # Generate the methods and add them to the class.  This needs to be done
+    # before the __doc__ logic below, since inspect will look at the __init__
+    # signature.
+    func_builder.add_fns_to_class(cls)
 
     if not getattr(cls, '__doc__'):
         # Create a class doc-string.
@@ -1130,7 +1168,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
         cls.__doc__ = (cls.__name__ + text_sig)
 
     if match_args:
-        # I could probably compute this once
+        # I could probably compute this once.
         _set_new_attribute(cls, '__match_args__',
                            tuple(f.name for f in std_init_fields))
 
diff --git a/Misc/NEWS.d/next/Core and Builtins/2023-10-14-00-05-17.gh-issue-109870.oKpJ3P.rst b/Misc/NEWS.d/next/Core and Builtins/2023-10-14-00-05-17.gh-issue-109870.oKpJ3P.rst
new file mode 100644 (file)
index 0000000..390bb12
--- /dev/null
@@ -0,0 +1,3 @@
+:mod:`dataclasses` now calls :func:`exec` once per dataclass, instead of
+once per method being added.  This can speed up dataclass creation by up
+to 20%.