]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-117953: Other Cleanups in the Extensions Machinery (gh-118206)
authorEric Snow <ericsnowcurrently@gmail.com>
Fri, 3 May 2024 00:51:43 +0000 (18:51 -0600)
committerGitHub <noreply@github.com>
Fri, 3 May 2024 00:51:43 +0000 (00:51 +0000)
This change will make some later changes simpler.

Lib/test/test_import/__init__.py
Modules/_testsinglephase.c
Python/import.c

index f88b51f916cc8188e083ffe4194c6be13be02364..b2aae774a1220569ae0268c86e20b3d40762338f 100644 (file)
@@ -2285,6 +2285,107 @@ class SubinterpImportTests(unittest.TestCase):
 
 
 class TestSinglePhaseSnapshot(ModuleSnapshot):
+    """A representation of a single-phase init module for testing.
+
+    Fields from ModuleSnapshot:
+
+    * id - id(mod)
+    * module - mod or a SimpleNamespace with __file__ & __spec__
+    * ns - a shallow copy of mod.__dict__
+    * ns_id - id(mod.__dict__)
+    * cached - sys.modules[name] (or None if not there or not snapshotable)
+    * cached_id - id(sys.modules[name]) (or None if not there)
+
+    Extra fields:
+
+    * summed - the result of calling "mod.sum(1, 2)"
+    * lookedup - the result of calling "mod.look_up_self()"
+    * lookedup_id - the object ID of self.lookedup
+    * state_initialized - the result of calling "mod.state_initialized()"
+    * init_count - (optional) the result of calling "mod.initialized_count()"
+
+    Overridden methods from ModuleSnapshot:
+
+    * from_module()
+    * parse()
+
+    Other methods from ModuleSnapshot:
+
+    * build_script()
+    * from_subinterp()
+
+    ----
+
+    There are 5 modules in Modules/_testsinglephase.c:
+
+    * _testsinglephase
+       * has global state
+       * extra loads skip the init function, copy def.m_base.m_copy
+       * counts calls to init function
+    * _testsinglephase_basic_wrapper
+       * _testsinglephase by another name (and separate init function symbol)
+    * _testsinglephase_basic_copy
+       * same as _testsinglephase but with own def (and init func)
+    * _testsinglephase_with_reinit
+       * has no global or module state
+       * mod.state_initialized returns None
+       * an extra load in the main interpreter calls the cached init func
+       * an extra load in legacy subinterpreters does a full load
+    * _testsinglephase_with_state
+       * has module state
+       * an extra load in the main interpreter calls the cached init func
+       * an extra load in legacy subinterpreters does a full load
+
+    (See Modules/_testsinglephase.c for more info.)
+
+    For all those modules, the snapshot after the initial load (not in
+    the global extensions cache) would look like the following:
+
+    * initial load
+       * id: ID of nww module object
+       * ns: exactly what the module init put there
+       * ns_id: ID of new module's __dict__
+       * cached_id: same as self.id
+       * summed: 3  (never changes)
+       * lookedup_id: same as self.id
+       * state_initialized: a timestamp between the time of the load
+         and the time of the snapshot
+       * init_count: 1  (None for _testsinglephase_with_reinit)
+
+    For the other scenarios it varies.
+
+    For the _testsinglephase, _testsinglephase_basic_wrapper, and
+    _testsinglephase_basic_copy modules, the snapshot should look
+    like the following:
+
+    * reloaded
+       * id: no change
+       * ns: matches what the module init function put there,
+         including the IDs of all contained objects,
+         plus any extra attributes added before the reload
+       * ns_id: no change
+       * cached_id: no change
+       * lookedup_id: no change
+       * state_initialized: no change
+       * init_count: no change
+    * already loaded
+       * (same as initial load except for ns and state_initialized)
+       * ns: matches the initial load, incl. IDs of contained objects
+       * state_initialized: no change from initial load
+
+    For _testsinglephase_with_reinit:
+
+    * reloaded: same as initial load (old module & ns is discarded)
+    * already loaded: same as initial load (old module & ns is discarded)
+
+    For _testsinglephase_with_state:
+
+    * reloaded
+       * (same as initial load (old module & ns is discarded),
+         except init_count)
+       * init_count: increase by 1
+    * already loaded: same as reloaded
+    """
 
     @classmethod
     def from_module(cls, mod):
@@ -2901,17 +3002,18 @@ class SinglephaseInitTests(unittest.TestCase):
         #  * module's global state was initialized but cleared
 
         # Start with an interpreter that gets destroyed right away.
-        base = self.import_in_subinterp(postscript='''
-            # Attrs set after loading are not in m_copy.
-            mod.spam = 'spam, spam, mash, spam, eggs, and spam'
-        ''')
+        base = self.import_in_subinterp(
+            postscript='''
+                # Attrs set after loading are not in m_copy.
+                mod.spam = 'spam, spam, mash, spam, eggs, and spam'
+                ''')
         self.check_common(base)
         self.check_fresh(base)
 
         # At this point:
         #  * alive in 0 interpreters
         #  * module def in _PyRuntime.imports.extensions
-        #  * mod init func ran again
+        #  * mod init func ran for the first time (since reset)
         #  * m_copy is NULL (claered when the interpreter was destroyed)
         #  * module's global state was initialized, not reset
 
@@ -2923,7 +3025,7 @@ class SinglephaseInitTests(unittest.TestCase):
         # At this point:
         #  * alive in 1 interpreter (interp1)
         #  * module def still in _PyRuntime.imports.extensions
-        #  * mod init func ran again
+        #  * mod init func ran for the second time (since reset)
         #  * m_copy was copied from interp1 (was NULL)
         #  * module's global state was updated, not reset
 
@@ -2935,7 +3037,7 @@ class SinglephaseInitTests(unittest.TestCase):
         # At this point:
         #  * alive in 2 interpreters (interp1, interp2)
         #  * module def still in _PyRuntime.imports.extensions
-        #  * mod init func ran again
+        #  * mod init func did not run again
         #  * m_copy was copied from interp2 (was from interp1)
         #  * module's global state was updated, not reset
 
index 092673a9ea43e110125ac3d2fdfc193cbc3707b9..ff533e44a82730cf70ccb9393338dacf0a2d296e 100644 (file)
@@ -1,6 +1,172 @@
 
 /* Testing module for single-phase initialization of extension modules
- */
+
+This file contains 5 distinct modules, meaning each as its own name
+and its own init function (PyInit_...).  The default import system will
+only find the one matching the filename: _testsinglephase.  To load the
+others you must do so manually.  For example:
+
+```python
+name = '_testsinglephase_base_wrapper'
+filename = _testsinglephase.__file__
+loader = importlib.machinery.ExtensionFileLoader(name, filename)
+spec = importlib.util.spec_from_file_location(name, filename, loader=loader)
+mod = importlib._bootstrap._load(spec)
+```
+
+Here are the 5 modules:
+
+* _testsinglephase
+   * def: _testsinglephase_basic,
+      * m_name: "_testsinglephase"
+      * m_size: -1
+   * state
+      * process-global
+         * <int> initialized_count  (default to -1; will never be 0)
+         * <module_state> module  (see module state below)
+      * module state: no
+      * initial __dict__: see common initial __dict__ below
+   * init function
+      1. create module
+      2. clear <global>.module
+      3. initialize <global>.module: see module state below
+      4. initialize module: set initial __dict__
+      5. increment <global>.initialized_count
+   * functions
+      * (3 common, see below)
+      * initialized_count() - return <global>.module.initialized_count
+   * import system
+      * caches
+         * global extensions cache: yes
+         * def.m_base.m_copy: yes
+         * def.m_base.m_init: no
+         * per-interpreter cache: yes  (all single-phase init modules)
+      * load in main interpreter
+         * initial  (not already in global cache)
+            1. get init function from shared object file
+            2. run init function
+            3. copy __dict__ into def.m_base.m_copy
+            4. set entry in global cache
+            5. set entry in per-interpreter cache
+            6. set entry in sys.modules
+         * reload  (already in sys.modules)
+            1. get def from global cache
+            2. get module from sys.modules
+            3. update module with contents of def.m_base.m_copy
+         * already loaded in other interpreter  (already in global cache)
+            * same as reload, but create new module and update *it*
+         * not in any sys.modules, still in global cache
+            * same as already loaded
+      * load in legacy (non-isolated) interpreter
+         * same as main interpreter
+      * unload: never  (all single-phase init modules)
+* _testsinglephase_basic_wrapper
+   * identical to _testsinglephase except module name
+* _testsinglephase_basic_copy
+   * def: static local variable in init function
+      * m_name: "_testsinglephase_basic_copy"
+      * m_size: -1
+   * state: same as _testsinglephase
+   * init function: same as _testsinglephase
+   * functions: same as _testsinglephase
+   * import system: same as _testsinglephase
+* _testsinglephase_with_reinit
+   * def: _testsinglephase_with_reinit,
+      * m_name: "_testsinglephase_with_reinit"
+      * m_size: 0
+   * state
+      * process-global state: no
+      * module state: no
+      * initial __dict__: see common initial __dict__ below
+   * init function
+      1. create module
+      2. initialize temporary module state (local var): see module state below
+      3. initialize module: set initial __dict__
+   * functions: see common functions below
+   * import system
+      * caches
+         * global extensions cache: only if loaded in main interpreter
+         * def.m_base.m_copy: no
+         * def.m_base.m_init: only if loaded in the main interpreter
+         * per-interpreter cache: yes  (all single-phase init modules)
+      * load in main interpreter
+         * initial  (not already in global cache)
+            * (same as _testsinglephase except step 3)
+            1. get init function from shared object file
+            2. run init function
+            3. set def.m_base.m_init to the init function
+            4. set entry in global cache
+            5. set entry in per-interpreter cache
+            6. set entry in sys.modules
+         * reload  (already in sys.modules)
+            1. get def from global cache
+            2. call def->m_base.m_init to get a new module object
+            3. replace the existing module in sys.modules
+         * already loaded in other interpreter  (already in global cache)
+            * same as reload (since will only be in cache for main interp)
+         * not in any sys.modules, still in global cache
+            * same as already loaded
+      * load in legacy (non-isolated) interpreter
+         * initial  (not already in global cache)
+            * (same as main interpreter except skip steps 3 & 4 there)
+            1. get init function from shared object file
+            2. run init function
+            ...
+            5. set entry in per-interpreter cache
+            6. set entry in sys.modules
+         * reload  (already in sys.modules)
+            * same as initial  (load from scratch)
+         * already loaded in other interpreter  (already in global cache)
+            * same as initial  (load from scratch)
+         * not in any sys.modules, still in global cache
+            * same as initial  (load from scratch)
+      * unload: never  (all single-phase init modules)
+* _testsinglephase_with_state
+   * def: _testsinglephase_with_state,
+      * m_name: "_testsinglephase_with_state"
+      * m_size: sizeof(module_state)
+   * state
+      * process-global: no
+      * module state: see module state below
+      * initial __dict__: see common initial __dict__ below
+   * init function
+      1. create module
+      3. initialize module state: see module state below
+      4. initialize module: set initial __dict__
+      5. increment <global>.initialized_count
+   * functions: see common functions below
+   * import system: same as _testsinglephase_basic_copy
+
+Module state:
+
+* fields
+   * <PyTime_t> initialized - when the module was first initialized
+   * <PyObject> *error
+   * <PyObject> *int_const
+   * <PyObject> *str_const
+* initialization
+   1. set state.initialized to the current time
+   2. set state.error to a new exception class
+   3. set state->int_const to int(1969)
+   4. set state->str_const to "something different"
+
+Common initial __dict__:
+
+* error: state.error
+* int_const: state.int_const
+* str_const: state.str_const
+* _module_initialized: state.initialized
+
+Common functions:
+
+* look_up_self() - return the module from the per-interpreter "by-index" cache
+* sum() - return a + b
+* state_initialized() - return state->initialized (or None if m_size == 0)
+
+See Python/import.c, especially the long comments, for more about
+single-phase init modules.
+*/
+
 #ifndef Py_BUILD_CORE_BUILTIN
 #  define Py_BUILD_CORE_MODULE 1
 #endif
index 0c51ffc6285a9c618150b7006f50f97c422a3061..f120a3841b495da4034f142fdc78e0b2cc2a676c 100644 (file)
@@ -645,33 +645,33 @@ _PyImport_ClearModulesByIndex(PyInterpreterState *interp)
        K.         PyModule_CreateInitialized() -> PyModule_SetDocString()
        L.       PyModule_CreateInitialized():  set mod->md_def
        M.       <module init func>:  initialize the module, etc.
-       N.     _PyImport_RunModInitFunc():  set def->m_base.m_init
-       O.   import_run_extension()
+       N.   import_run_extension()
                 -> _PyImport_CheckSubinterpIncompatibleExtensionAllowed()
-       P.   import_run_extension():  set __file__
-       Q.   import_run_extension() -> update_global_state_for_extension()
-       R.     update_global_state_for_extension():
+       O.   import_run_extension():  set __file__
+       P.   import_run_extension() -> update_global_state_for_extension()
+       Q.     update_global_state_for_extension():
                       copy __dict__ into def->m_base.m_copy
-       S.     update_global_state_for_extension():
+       R.     update_global_state_for_extension():
                       add it to _PyRuntime.imports.extensions
-       T.   import_run_extension() -> finish_singlephase_extension()
-       U.     finish_singlephase_extension():
+       S.   import_run_extension() -> finish_singlephase_extension()
+       T.     finish_singlephase_extension():
                       add it to interp->imports.modules_by_index
-       V.     finish_singlephase_extension():  add it to sys.modules
+       U.     finish_singlephase_extension():  add it to sys.modules
 
        Step (Q) is skipped for core modules (sys/builtins).
 
     (6). subsequent times  (found in _PyRuntime.imports.extensions):
        A. _imp_create_dynamic_impl() -> import_find_extension()
-       B.   import_find_extension()
-                -> _PyImport_CheckSubinterpIncompatibleExtensionAllowed()
-       C.   import_find_extension() -> import_add_module()
-       D.     if name in sys.modules:  use that module
-       E.     else:
-                1. import_add_module() -> PyModule_NewObject()
-                2. import_add_module():  set it on sys.modules
-       F.   import_find_extension():  copy the "m_copy" dict into __dict__
-       G.   import_find_extension():  add to modules_by_index
+       B.   import_find_extension() -> reload_singlephase_extension()
+       C.     reload_singlephase_extension()
+                  -> _PyImport_CheckSubinterpIncompatibleExtensionAllowed()
+       D.     reload_singlephase_extension() -> import_add_module()
+       E.       if name in sys.modules:  use that module
+       F.       else:
+                  1. import_add_module() -> PyModule_NewObject()
+                  2. import_add_module():  set it on sys.modules
+       G.     reload_singlephase_extension():  copy the "m_copy" dict into __dict__
+       H.     reload_singlephase_extension():  add to modules_by_index
 
     (10). (every time):
        A. noop
@@ -681,21 +681,23 @@ _PyImport_ClearModulesByIndex(PyInterpreterState *interp)
 
     (6). not main interpreter and never loaded there - every time  (not found in _PyRuntime.imports.extensions):
        A-P. (same as for m_size == -1)
-       Q-S. (skipped)
-       T-V. (same as for m_size == -1)
+       Q.     _PyImport_RunModInitFunc():  set def->m_base.m_init
+       R. (skipped)
+       S-U. (same as for m_size == -1)
 
     (6). main interpreter - first time  (not found in _PyRuntime.imports.extensions):
-       A-R. (same as for m_size == -1)
-       S. (skipped)
-       T-V. (same as for m_size == -1)
+       A-P. (same as for m_size == -1)
+       Q.     _PyImport_RunModInitFunc():  set def->m_base.m_init
+       R-U. (same as for m_size == -1)
 
     (6). subsequent times  (found in _PyRuntime.imports.extensions):
        A. _imp_create_dynamic_impl() -> import_find_extension()
-       B.   import_find_extension()
-                -> _PyImport_CheckSubinterpIncompatibleExtensionAllowed()
-       C.   import_find_extension():  call def->m_base.m_init  (see above)
-       D.   import_find_extension():  add the module to sys.modules
-       E.   import_find_extension():  add to modules_by_index
+       B.   import_find_extension() -> reload_singlephase_extension()
+       C.     reload_singlephase_extension()
+                  -> _PyImport_CheckSubinterpIncompatibleExtensionAllowed()
+       D.     reload_singlephase_extension():  call def->m_base.m_init  (see above)
+       E.     reload_singlephase_extension():  add the module to sys.modules
+       F.     reload_singlephase_extension():  add to modules_by_index
 
     (10). every time:
        A. noop
@@ -984,84 +986,103 @@ hashtable_destroy_str(void *ptr)
 
 #define HTSEP ':'
 
-static PyModuleDef *
-_extensions_cache_get(PyObject *filename, PyObject *name)
-{
-    PyModuleDef *def = NULL;
-    void *key = NULL;
-    extensions_lock_acquire();
-
+static int
+_extensions_cache_init(void)
+{
+    _Py_hashtable_allocator_t alloc = {PyMem_RawMalloc, PyMem_RawFree};
+    EXTENSIONS.hashtable = _Py_hashtable_new_full(
+        hashtable_hash_str,
+        hashtable_compare_str,
+        hashtable_destroy_str,  // key
+        /* There's no need to decref the def since it's immortal. */
+        NULL,  // value
+        &alloc
+    );
     if (EXTENSIONS.hashtable == NULL) {
-        goto finally;
+        PyErr_NoMemory();
+        return -1;
     }
+    return 0;
+}
 
-    key = hashtable_key_from_2_strings(filename, name, HTSEP);
+static _Py_hashtable_entry_t *
+_extensions_cache_find_unlocked(PyObject *path, PyObject *name,
+                                void **p_key)
+{
+    if (EXTENSIONS.hashtable == NULL) {
+        return NULL;
+    }
+    void *key = hashtable_key_from_2_strings(path, name, HTSEP);
     if (key == NULL) {
-        goto finally;
+        return NULL;
+    }
+    _Py_hashtable_entry_t *entry =
+            _Py_hashtable_get_entry(EXTENSIONS.hashtable, key);
+    if (p_key != NULL) {
+        *p_key = key;
     }
-    _Py_hashtable_entry_t *entry = _Py_hashtable_get_entry(
-            EXTENSIONS.hashtable, key);
+    else {
+        hashtable_destroy_str(key);
+    }
+    return entry;
+}
+
+static PyModuleDef *
+_extensions_cache_get(PyObject *path, PyObject *name)
+{
+    PyModuleDef *def = NULL;
+    extensions_lock_acquire();
+
+    _Py_hashtable_entry_t *entry =
+            _extensions_cache_find_unlocked(path, name, NULL);
     if (entry == NULL) {
+        /* It was never added. */
         goto finally;
     }
     def = (PyModuleDef *)entry->value;
 
 finally:
     extensions_lock_release();
-    if (key != NULL) {
-        PyMem_RawFree(key);
-    }
     return def;
 }
 
 static int
-_extensions_cache_set(PyObject *filename, PyObject *name, PyModuleDef *def)
+_extensions_cache_set(PyObject *path, PyObject *name, PyModuleDef *def)
 {
     int res = -1;
+    assert(def != NULL);
     extensions_lock_acquire();
 
     if (EXTENSIONS.hashtable == NULL) {
-        _Py_hashtable_allocator_t alloc = {PyMem_RawMalloc, PyMem_RawFree};
-        EXTENSIONS.hashtable = _Py_hashtable_new_full(
-            hashtable_hash_str,
-            hashtable_compare_str,
-            hashtable_destroy_str,  // key
-            /* There's no need to decref the def since it's immortal. */
-            NULL,  // value
-            &alloc
-        );
-        if (EXTENSIONS.hashtable == NULL) {
-            PyErr_NoMemory();
+        if (_extensions_cache_init() < 0) {
             goto finally;
         }
     }
 
-    void *key = hashtable_key_from_2_strings(filename, name, HTSEP);
-    if (key == NULL) {
-        goto finally;
-    }
-
     int already_set = 0;
-    _Py_hashtable_entry_t *entry = _Py_hashtable_get_entry(
-            EXTENSIONS.hashtable, key);
+    void *key = NULL;
+    _Py_hashtable_entry_t *entry =
+            _extensions_cache_find_unlocked(path, name, &key);
     if (entry == NULL) {
+        /* It was never added. */
         if (_Py_hashtable_set(EXTENSIONS.hashtable, key, def) < 0) {
-            PyMem_RawFree(key);
             PyErr_NoMemory();
             goto finally;
         }
+        /* The hashtable owns the key now. */
+        key = NULL;
+    }
+    else if (entry->value == NULL) {
+        /* It was previously deleted. */
+        entry->value = def;
     }
     else {
-        if (entry->value == NULL) {
-            entry->value = def;
-        }
-        else {
-            /* We expect it to be static, so it must be the same pointer. */
-            assert((PyModuleDef *)entry->value == def);
-            already_set = 1;
-        }
-        PyMem_RawFree(key);
+        /* We expect it to be static, so it must be the same pointer. */
+        assert((PyModuleDef *)entry->value == def);
+        /* It was already added. */
+        already_set = 1;
     }
+
     if (!already_set) {
         /* We assume that all module defs are statically allocated
            and will never be freed.  Otherwise, we would incref here. */
@@ -1071,13 +1092,15 @@ _extensions_cache_set(PyObject *filename, PyObject *name, PyModuleDef *def)
 
 finally:
     extensions_lock_release();
+    if (key != NULL) {
+        hashtable_destroy_str(key);
+    }
     return res;
 }
 
 static void
-_extensions_cache_delete(PyObject *filename, PyObject *name)
+_extensions_cache_delete(PyObject *path, PyObject *name)
 {
-    void *key = NULL;
     extensions_lock_acquire();
 
     if (EXTENSIONS.hashtable == NULL) {
@@ -1085,13 +1108,8 @@ _extensions_cache_delete(PyObject *filename, PyObject *name)
         goto finally;
     }
 
-    key = hashtable_key_from_2_strings(filename, name, HTSEP);
-    if (key == NULL) {
-        goto finally;
-    }
-
-    _Py_hashtable_entry_t *entry = _Py_hashtable_get_entry(
-            EXTENSIONS.hashtable, key);
+    _Py_hashtable_entry_t *entry =
+            _extensions_cache_find_unlocked(path, name, NULL);
     if (entry == NULL) {
         /* It was never added. */
         goto finally;
@@ -1109,9 +1127,6 @@ _extensions_cache_delete(PyObject *filename, PyObject *name)
 
 finally:
     extensions_lock_release();
-    if (key != NULL) {
-        PyMem_RawFree(key);
-    }
 }
 
 static void
@@ -1359,15 +1374,11 @@ finish_singlephase_extension(PyThreadState *tstate,
 
 
 static PyObject *
-import_find_extension(PyThreadState *tstate,
-                      struct _Py_ext_module_loader_info *info)
+reload_singlephase_extension(PyThreadState *tstate, PyModuleDef *def,
+                             struct _Py_ext_module_loader_info *info)
 {
-    /* Only single-phase init modules will be in the cache. */
-    PyModuleDef *def = _extensions_cache_get(info->path, info->name);
-    if (def == NULL) {
-        return NULL;
-    }
     assert_singlephase(def);
+    PyObject *mod = NULL;
 
     /* It may have been successfully imported previously
        in an interpreter that allows legacy modules
@@ -1378,9 +1389,7 @@ import_find_extension(PyThreadState *tstate,
         return NULL;
     }
 
-    PyObject *mod, *mdict;
     PyObject *modules = get_modules_dict(tstate, true);
-
     if (def->m_size == -1) {
         PyObject *m_copy = def->m_base.m_copy;
         /* Module does not support repeated initialization */
@@ -1390,6 +1399,7 @@ import_find_extension(PyThreadState *tstate,
             m_copy = get_core_module_dict(
                     tstate->interp, info->name, info->path);
             if (m_copy == NULL) {
+                assert(!PyErr_Occurred());
                 return NULL;
             }
         }
@@ -1397,7 +1407,7 @@ import_find_extension(PyThreadState *tstate,
         if (mod == NULL) {
             return NULL;
         }
-        mdict = PyModule_GetDict(mod);
+        PyObject *mdict = PyModule_GetDict(mod);
         if (mdict == NULL) {
             Py_DECREF(mod);
             return NULL;
@@ -1416,6 +1426,7 @@ import_find_extension(PyThreadState *tstate,
     }
     else {
         if (def->m_base.m_init == NULL) {
+            assert(!PyErr_Occurred());
             return NULL;
         }
         struct _Py_ext_module_loader_result res;
@@ -1445,12 +1456,41 @@ import_find_extension(PyThreadState *tstate,
             return NULL;
         }
     }
+
     if (_modules_by_index_set(tstate->interp, def, mod) < 0) {
         PyMapping_DelItem(modules, info->name);
         Py_DECREF(mod);
         return NULL;
     }
 
+    return mod;
+}
+
+static PyObject *
+import_find_extension(PyThreadState *tstate,
+                      struct _Py_ext_module_loader_info *info)
+{
+    /* Only single-phase init modules will be in the cache. */
+    PyModuleDef *def = _extensions_cache_get(info->path, info->name);
+    if (def == NULL) {
+        return NULL;
+    }
+    assert_singlephase(def);
+
+    /* It may have been successfully imported previously
+       in an interpreter that allows legacy modules
+       but is not allowed in the current interpreter. */
+    const char *name_buf = PyUnicode_AsUTF8(info->name);
+    assert(name_buf != NULL);
+    if (_PyImport_CheckSubinterpIncompatibleExtensionAllowed(name_buf) < 0) {
+        return NULL;
+    }
+
+    PyObject *mod = reload_singlephase_extension(tstate, def, info);
+    if (mod == NULL) {
+        return NULL;
+    }
+
     int verbose = _PyInterpreterState_GetConfig(tstate->interp)->verbose;
     if (verbose) {
         PySys_FormatStderr("import %U # previously loaded (%R)\n",