self.dispatcher = singledispatch(func)
self.func = func
- import weakref # see comment in singledispatch function
- self._method_cache = weakref.WeakKeyDictionary()
-
def register(self, cls, method=None):
"""generic_method.register(cls, func) -> func
return self.dispatcher.register(cls, func=method)
def __get__(self, obj, cls=None):
- if self._method_cache is not None:
- try:
- _method = self._method_cache[obj]
- except TypeError:
- self._method_cache = None
- except KeyError:
- pass
- else:
- return _method
+ return _singledispatchmethod_get(self, obj, cls)
- dispatch = self.dispatcher.dispatch
- funcname = getattr(self.func, '__name__', 'singledispatchmethod method')
- def _method(*args, **kwargs):
- if not args:
- raise TypeError(f'{funcname} requires at least '
- '1 positional argument')
- return dispatch(args[0].__class__).__get__(obj, cls)(*args, **kwargs)
+ @property
+ def __isabstractmethod__(self):
+ return getattr(self.func, '__isabstractmethod__', False)
- _method.__isabstractmethod__ = self.__isabstractmethod__
- _method.register = self.register
- update_wrapper(_method, self.func)
- if self._method_cache is not None:
- self._method_cache[obj] = _method
+class _singledispatchmethod_get:
+ def __init__(self, unbound, obj, cls):
+ self._unbound = unbound
+ self._dispatch = unbound.dispatcher.dispatch
+ self._obj = obj
+ self._cls = cls
+ # Set instance attributes which cannot be handled in __getattr__()
+ # because they conflict with type descriptors.
+ func = unbound.func
+ try:
+ self.__module__ = func.__module__
+ except AttributeError:
+ pass
+ try:
+ self.__doc__ = func.__doc__
+ except AttributeError:
+ pass
+
+ def __call__(self, /, *args, **kwargs):
+ if not args:
+ funcname = getattr(self._unbound.func, '__name__',
+ 'singledispatchmethod method')
+ raise TypeError(f'{funcname} requires at least '
+ '1 positional argument')
+ return self._dispatch(args[0].__class__).__get__(self._obj, self._cls)(*args, **kwargs)
- return _method
+ def __getattr__(self, name):
+ # Resolve these attributes lazily to speed up creation of
+ # the _singledispatchmethod_get instance.
+ if name not in {'__name__', '__qualname__', '__isabstractmethod__',
+ '__annotations__', '__type_params__'}:
+ raise AttributeError
+ return getattr(self._unbound.func, name)
@property
- def __isabstractmethod__(self):
- return getattr(self.func, '__isabstractmethod__', False)
+ def register(self):
+ return self._unbound.register
################################################################################
"""My function docstring"""
return str(arg)
+ prefix = A.__qualname__ + '.'
for meth in (
A.func,
A().func,
A().static_func
):
with self.subTest(meth=meth):
+ self.assertEqual(meth.__module__, __name__)
+ self.assertEqual(type(meth).__module__, 'functools')
+ self.assertEqual(meth.__qualname__, prefix + meth.__name__)
self.assertEqual(meth.__doc__,
('My function docstring'
if support.HAVE_DOCSTRINGS
def _(arg: undefined):
return "forward reference"
+ def test_method_equal_instances(self):
+ # gh-127750: Reference to self was cached
+ class A:
+ def __eq__(self, other):
+ return True
+ def __hash__(self):
+ return 1
+ @functools.singledispatchmethod
+ def t(self, arg):
+ return self
+
+ a = A()
+ b = A()
+ self.assertIs(a.t(1), a)
+ self.assertIs(b.t(2), b)
+
+ def test_method_bad_hash(self):
+ class A:
+ def __eq__(self, other):
+ raise AssertionError
+ def __hash__(self):
+ raise AssertionError
+ @functools.singledispatchmethod
+ def t(self, arg):
+ pass
+
+ # Should not raise
+ A().t(1)
+ hash(A().t)
+ A().t == A().t
+
+ def test_method_no_reference_loops(self):
+ # gh-127750: Created a strong reference to self
+ class A:
+ @functools.singledispatchmethod
+ def t(self, arg):
+ return weakref.ref(self)
+
+ a = A()
+ r = a.t(1)
+ self.assertIsNotNone(r())
+ del a # delete a after a.t
+ if not support.check_impl_detail(cpython=True):
+ support.gc_collect()
+ self.assertIsNone(r())
+
+ a = A()
+ t = a.t
+ del a # delete a before a.t
+ support.gc_collect()
+ r = t(1)
+ self.assertIsNotNone(r())
+ del t
+ if not support.check_impl_detail(cpython=True):
+ support.gc_collect()
+ self.assertIsNone(r())
+
+
class CachedCostItem:
_cost = 1