]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-87138: convert SHA-3 object type to heap type (GH-127670)
authorBénédikt Tran <10796600+picnixz@users.noreply.github.com>
Sun, 8 Dec 2024 17:31:10 +0000 (18:31 +0100)
committerGitHub <noreply@github.com>
Sun, 8 Dec 2024 17:31:10 +0000 (09:31 -0800)
Modules/sha3module.c

index ca839dc55e0519f3ba5c3279cb0784e80c5aa01c..b13e6a9de10114515bfece6ab1ac3c6c3de02b3b 100644 (file)
@@ -71,13 +71,13 @@ typedef struct {
 static SHA3object *
 newSHA3object(PyTypeObject *type)
 {
-    SHA3object *newobj;
-    newobj = (SHA3object *)PyObject_New(SHA3object, type);
+    SHA3object *newobj = PyObject_GC_New(SHA3object, type);
     if (newobj == NULL) {
         return NULL;
     }
     HASHLIB_INIT_MUTEX(newobj);
 
+    PyObject_GC_Track(newobj);
     return newobj;
 }
 
@@ -166,15 +166,32 @@ py_sha3_new_impl(PyTypeObject *type, PyObject *data, int usedforsecurity)
 
 /* Internal methods for a hash object */
 
+static int
+SHA3_clear(SHA3object *self)
+{
+    if (self->hash_state != NULL) {
+        Hacl_Hash_SHA3_free(self->hash_state);
+        self->hash_state = NULL;
+    }
+    return 0;
+}
+
 static void
 SHA3_dealloc(SHA3object *self)
 {
-    Hacl_Hash_SHA3_free(self->hash_state);
     PyTypeObject *tp = Py_TYPE(self);
-    PyObject_Free(self);
+    PyObject_GC_UnTrack(self);
+    (void)SHA3_clear(self);
+    tp->tp_free(self);
     Py_DECREF(tp);
 }
 
+static int
+SHA3_traverse(PyObject *self, visitproc visit, void *arg)
+{
+    Py_VISIT(Py_TYPE(self));
+    return 0;
+}
 
 /* External methods for a hash object */
 
@@ -335,6 +352,7 @@ static PyObject *
 SHA3_get_capacity_bits(SHA3object *self, void *closure)
 {
     uint32_t rate = Hacl_Hash_SHA3_block_len(self->hash_state) * 8;
+    assert(rate <= 1600);
     int capacity = 1600 - rate;
     return PyLong_FromLong(capacity);
 }
@@ -366,12 +384,14 @@ static PyGetSetDef SHA3_getseters[] = {
 
 #define SHA3_TYPE_SLOTS(type_slots_obj, type_doc, type_methods, type_getseters) \
     static PyType_Slot type_slots_obj[] = { \
+        {Py_tp_clear, SHA3_clear}, \
         {Py_tp_dealloc, SHA3_dealloc}, \
+        {Py_tp_traverse, SHA3_traverse}, \
         {Py_tp_doc, (char*)type_doc}, \
         {Py_tp_methods, type_methods}, \
         {Py_tp_getset, type_getseters}, \
         {Py_tp_new, py_sha3_new}, \
-        {0,0} \
+        {0, NULL} \
     }
 
 // Using _PyType_GetModuleState() on these types is safe since they
@@ -380,7 +400,8 @@ static PyGetSetDef SHA3_getseters[] = {
     static PyType_Spec type_spec_obj = { \
         .name = "_sha3." type_name, \
         .basicsize = sizeof(SHA3object), \
-        .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_IMMUTABLETYPE, \
+        .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_IMMUTABLETYPE \
+                 | Py_TPFLAGS_HAVE_GC, \
         .slots = type_slots \
     }
 
@@ -444,9 +465,7 @@ _SHAKE_digest(SHA3object *self, unsigned long digestlen, int hex)
         result = PyBytes_FromStringAndSize((const char *)digest,
                                            digestlen);
     }
-    if (digest != NULL) {
-        PyMem_Free(digest);
-    }
+    PyMem_Free(digest);
     return result;
 }
 
@@ -563,7 +582,7 @@ _sha3_clear(PyObject *module)
 static void
 _sha3_free(void *module)
 {
-    _sha3_clear((PyObject *)module);
+    (void)_sha3_clear((PyObject *)module);
 }
 
 static int