self.assertEqual(len(mss), len(mss2))
self.assertEqual(list(mss), list(mss2))
+ def test_illegal_patma_flags(self):
+ with self.assertRaises(TypeError):
+ class Both(Collection):
+ __abc_tpflags__ = (Sequence.__flags__ | Mapping.__flags__)
+
+
################################################################################
### Counter
self.assertEqual(f((False, range(10, 20), True)), alts[4])
+class TestInheritance(unittest.TestCase):
+
+ def test_multiple_inheritance(self):
+ class C:
+ pass
+ class S1(collections.UserList, collections.abc.Mapping):
+ pass
+ class S2(C, collections.UserList, collections.abc.Mapping):
+ pass
+ class S3(list, C, collections.abc.Mapping):
+ pass
+ class S4(collections.UserList, dict, C):
+ pass
+ class M1(collections.UserDict, collections.abc.Sequence):
+ pass
+ class M2(C, collections.UserDict, collections.abc.Sequence):
+ pass
+ class M3(collections.UserDict, C, list):
+ pass
+ class M4(dict, collections.abc.Sequence, C):
+ pass
+ def f(x):
+ match x:
+ case []:
+ return "seq"
+ case {}:
+ return "map"
+ def g(x):
+ match x:
+ case {}:
+ return "map"
+ case []:
+ return "seq"
+ for Seq in (S1, S2, S3, S4):
+ self.assertEqual(f(Seq()), "seq")
+ self.assertEqual(g(Seq()), "seq")
+ for Map in (M1, M2, M3, M4):
+ self.assertEqual(f(Map()), "map")
+ self.assertEqual(g(Map()), "map")
+
+
class PerfPatma(TestPatma):
def assertEqual(*_, **__):
--- /dev/null
+Prevent classes being both a sequence and a mapping when pattern matching.
if (val == -1 && PyErr_Occurred()) {
return NULL;
}
+ if ((val & COLLECTION_FLAGS) == COLLECTION_FLAGS) {
+ PyErr_SetString(PyExc_TypeError, "__abc_tpflags__ cannot be both Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING");
+ return NULL;
+ }
((PyTypeObject *)self)->tp_flags |= (val & COLLECTION_FLAGS);
}
if (_PyDict_DelItemId(cls->tp_dict, &PyId___abc_tpflags__) < 0) {
/* Invalidate negative cache */
get_abc_state(module)->abc_invalidation_counter++;
- if (PyType_Check(subclass) && PyType_Check(self) &&
- !PyType_HasFeature((PyTypeObject *)subclass, Py_TPFLAGS_IMMUTABLETYPE))
+ /* Set Py_TPFLAGS_SEQUENCE or Py_TPFLAGS_MAPPING flag */
+ if (PyType_Check(self) &&
+ !PyType_HasFeature((PyTypeObject *)subclass, Py_TPFLAGS_IMMUTABLETYPE) &&
+ ((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS)
{
+ ((PyTypeObject *)subclass)->tp_flags &= ~COLLECTION_FLAGS;
((PyTypeObject *)subclass)->tp_flags |= (((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS);
}
Py_INCREF(subclass);
if (PyType_HasFeature(base, _Py_TPFLAGS_MATCH_SELF)) {
type->tp_flags |= _Py_TPFLAGS_MATCH_SELF;
}
- if (PyType_HasFeature(base, Py_TPFLAGS_SEQUENCE)) {
- type->tp_flags |= Py_TPFLAGS_SEQUENCE;
- }
- if (PyType_HasFeature(base, Py_TPFLAGS_MAPPING)) {
- type->tp_flags |= Py_TPFLAGS_MAPPING;
- }
}
static int
static int add_operators(PyTypeObject *);
static int add_tp_new_wrapper(PyTypeObject *type);
+#define COLLECTION_FLAGS (Py_TPFLAGS_SEQUENCE | Py_TPFLAGS_MAPPING)
static int
type_ready_checks(PyTypeObject *type)
_PyObject_ASSERT((PyObject *)type, type->tp_as_async->am_send != NULL);
}
+ /* Consistency checks for pattern matching
+ * Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING are mutually exclusive */
+ _PyObject_ASSERT((PyObject *)type, (type->tp_flags & COLLECTION_FLAGS) != COLLECTION_FLAGS);
+
if (type->tp_name == NULL) {
PyErr_Format(PyExc_SystemError,
"Type does not define the tp_name field.");
}
}
+static void
+inherit_patma_flags(PyTypeObject *type, PyTypeObject *base) {
+ if ((type->tp_flags & COLLECTION_FLAGS) == 0) {
+ type->tp_flags |= base->tp_flags & COLLECTION_FLAGS;
+ }
+}
static int
type_ready_inherit(PyTypeObject *type)
if (inherit_slots(type, (PyTypeObject *)b) < 0) {
return -1;
}
+ inherit_patma_flags(type, (PyTypeObject *)b);
}
}