"""
def __init__(self, func):
+ import weakref # see comment in singledispatch function
if not callable(func) and not hasattr(func, "__get__"):
raise TypeError(f"{func!r} is not callable or a descriptor")
self.dispatcher = singledispatch(func)
self.func = func
+ self._method_cache = weakref.WeakKeyDictionary()
+ self._all_weakrefable_instances = True
def register(self, cls, method=None):
"""generic_method.register(cls, func) -> func
return self.dispatcher.register(cls, func=method)
def __get__(self, obj, cls=None):
+ if self._all_weakrefable_instances:
+ try:
+ _method = self._method_cache[obj]
+ except TypeError:
+ self._all_weakrefable_instances = False
+ except KeyError:
+ pass
+ else:
+ return _method
+
+ dispatch = self.dispatcher.dispatch
def _method(*args, **kwargs):
- method = self.dispatcher.dispatch(args[0].__class__)
- return method.__get__(obj, cls)(*args, **kwargs)
+ return dispatch(args[0].__class__).__get__(obj, cls)(*args, **kwargs)
_method.__isabstractmethod__ = self.__isabstractmethod__
_method.register = self.register
update_wrapper(_method, self.func)
+
+ if self._all_weakrefable_instances:
+ self._method_cache[obj] = _method
+
return _method
@property
self.assertTrue(A.t(''))
self.assertEqual(A.t(0.0), 0.0)
+ def test_slotted_class(self):
+ class Slot:
+ __slots__ = ('a', 'b')
+ @functools.singledispatchmethod
+ def go(self, item, arg):
+ pass
+
+ @go.register
+ def _(self, item: int, arg):
+ return item + arg
+
+ s = Slot()
+ self.assertEqual(s.go(1, 1), 2)
+
+ def test_classmethod_slotted_class(self):
+ class Slot:
+ __slots__ = ('a', 'b')
+ @functools.singledispatchmethod
+ @classmethod
+ def go(cls, item, arg):
+ pass
+
+ @go.register
+ @classmethod
+ def _(cls, item: int, arg):
+ return item + arg
+
+ s = Slot()
+ self.assertEqual(s.go(1, 1), 2)
+ self.assertEqual(Slot.go(1, 1), 2)
+
+ def test_staticmethod_slotted_class(self):
+ class A:
+ __slots__ = ['a']
+ @functools.singledispatchmethod
+ @staticmethod
+ def t(arg):
+ return arg
+ @t.register(int)
+ @staticmethod
+ def _(arg):
+ return isinstance(arg, int)
+ @t.register(str)
+ @staticmethod
+ def _(arg):
+ return isinstance(arg, str)
+ a = A()
+
+ self.assertTrue(A.t(0))
+ self.assertTrue(A.t(''))
+ self.assertEqual(A.t(0.0), 0.0)
+ self.assertTrue(a.t(0))
+ self.assertTrue(a.t(''))
+ self.assertEqual(a.t(0.0), 0.0)
+
+ def test_assignment_behavior(self):
+ # see gh-106448
+ class A:
+ @functools.singledispatchmethod
+ def t(arg):
+ return arg
+
+ a = A()
+ a.t.foo = 'bar'
+ a2 = A()
+ with self.assertRaises(AttributeError):
+ a2.t.foo
+
def test_classmethod_register(self):
class A:
def __init__(self, arg):