except CustomError:
pass
+ def test_weak_cache_descriptor_use_after_free(self):
+ class BombDescriptor:
+ def __get__(self, obj, owner):
+ return {}
+
+ class EvilZoneInfo(self.klass):
+ pass
+
+ # Must be set after the class creation.
+ EvilZoneInfo._weak_cache = BombDescriptor()
+
+ key = "America/Los_Angeles"
+ zone1 = EvilZoneInfo(key)
+ self.assertEqual(str(zone1), key)
+
+ EvilZoneInfo.clear_cache()
+ zone2 = EvilZoneInfo(key)
+ self.assertEqual(str(zone2), key)
+ self.assertIsNot(zone2, zone1)
+
class CZoneInfoCacheTest(ZoneInfoCacheTest):
module = c_zoneinfo
get_weak_cache(zoneinfo_state *state, PyTypeObject *type)
{
if (type == state->ZoneInfoType) {
+ Py_INCREF(state->ZONEINFO_WEAK_CACHE);
return state->ZONEINFO_WEAK_CACHE;
}
else {
- PyObject *cache =
- PyObject_GetAttrString((PyObject *)type, "_weak_cache");
- // We are assuming that the type lives at least as long as the function
- // that calls get_weak_cache, and that it holds a reference to the
- // cache, so we'll return a "borrowed reference".
- Py_XDECREF(cache);
- return cache;
+ return PyObject_GetAttrString((PyObject *)type, "_weak_cache");
}
}
PyObject *weak_cache = get_weak_cache(state, type);
instance = PyObject_CallMethod(weak_cache, "get", "O", key, Py_None);
if (instance == NULL) {
+ Py_DECREF(weak_cache);
return NULL;
}
Py_DECREF(instance);
PyObject *tmp = zoneinfo_new_instance(state, type, key);
if (tmp == NULL) {
+ Py_DECREF(weak_cache);
return NULL;
}
PyObject_CallMethod(weak_cache, "setdefault", "OO", key, tmp);
Py_DECREF(tmp);
if (instance == NULL) {
+ Py_DECREF(weak_cache);
return NULL;
}
((PyZoneInfo_ZoneInfo *)instance)->source = SOURCE_CACHE;
}
update_strong_cache(state, type, key, instance);
+ Py_DECREF(weak_cache);
return instance;
}
PyObject *item = NULL;
PyObject *pop = PyUnicode_FromString("pop");
if (pop == NULL) {
+ Py_DECREF(weak_cache);
return NULL;
}
PyObject *iter = PyObject_GetIter(only_keys);
if (iter == NULL) {
Py_DECREF(pop);
+ Py_DECREF(weak_cache);
return NULL;
}
Py_DECREF(pop);
}
+ Py_DECREF(weak_cache);
if (PyErr_Occurred()) {
return NULL;
}