]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-76785: Add Interpreter.prepare_main() (gh-113021)
authorEric Snow <ericsnowcurrently@gmail.com>
Tue, 12 Dec 2023 18:06:06 +0000 (11:06 -0700)
committerGitHub <noreply@github.com>
Tue, 12 Dec 2023 18:06:06 +0000 (18:06 +0000)
This is one of the last pieces to get test.support.interpreters in sync with PEP 734.

Lib/test/support/interpreters/__init__.py
Lib/test/test__xxinterpchannels.py
Lib/test/test__xxsubinterpreters.py
Lib/test/test_interpreters/test_api.py
Lib/test/test_interpreters/utils.py
Modules/_xxsubinterpretersmodule.c

index 2d6376deb5907e73813f6a695b90e6fe41d3fded..9cd1c3de0274d2733ff77c9d1b37b8d582d797fc 100644 (file)
@@ -130,7 +130,15 @@ class Interpreter:
         """
         return _interpreters.destroy(self._id)
 
-    def exec_sync(self, code, /, channels=None):
+    def prepare_main(self, ns=None, /, **kwargs):
+        """Bind the given values into the interpreter's __main__.
+
+        The values must be shareable.
+        """
+        ns = dict(ns, **kwargs) if ns is not None else kwargs
+        _interpreters.set___main___attrs(self._id, ns)
+
+    def exec_sync(self, code, /):
         """Run the given source code in the interpreter.
 
         This is essentially the same as calling the builtin "exec"
@@ -148,13 +156,13 @@ class Interpreter:
         that time, the previous interpreter is allowed to run
         in other threads.
         """
-        excinfo = _interpreters.exec(self._id, code, channels)
+        excinfo = _interpreters.exec(self._id, code)
         if excinfo is not None:
             raise ExecFailure(excinfo)
 
-    def run(self, code, /, channels=None):
+    def run(self, code, /):
         def task():
-            self.exec_sync(code, channels=channels)
+            self.exec_sync(code)
         t = threading.Thread(target=task)
         t.start()
         return t
index 13c8a10296e50246da122227a96f0778afd116d7..cc2ed7849b0c0f869bb2a14ca44799c6015f4e7b 100644 (file)
@@ -586,12 +586,12 @@ class ChannelTests(TestBase):
         cid = channels.create()
         interp = interpreters.create()
 
+        interpreters.set___main___attrs(interp, dict(cid=cid.send))
         out = _run_output(interp, dedent("""
             import _xxinterpchannels as _channels
             print(cid.end)
             _channels.send(cid, b'spam', blocking=False)
-            """),
-            dict(cid=cid.send))
+            """))
         obj = channels.recv(cid)
 
         self.assertEqual(obj, b'spam')
index 260ab64b07cb2d1afa7ed811a1fcd03b9f97600b..a76e4d0ade5b8ae46c16571328ce823363499c89 100644 (file)
@@ -33,10 +33,10 @@ def _captured_script(script):
     return wrapped, open(r, encoding="utf-8")
 
 
-def _run_output(interp, request, shared=None):
+def _run_output(interp, request):
     script, rpipe = _captured_script(request)
     with rpipe:
-        interpreters.run_string(interp, script, shared)
+        interpreters.run_string(interp, script)
         return rpipe.read()
 
 
@@ -630,10 +630,10 @@ class RunStringTests(TestBase):
         ]
         for obj in objects:
             with self.subTest(obj):
+                interpreters.set___main___attrs(interp, dict(obj=obj))
                 interpreters.run_string(
                     interp,
                     f'assert(obj == {obj!r})',
-                    shared=dict(obj=obj),
                 )
 
     def test_os_exec(self):
@@ -721,7 +721,8 @@ class RunStringTests(TestBase):
             with open({w}, 'wb') as chan:
                 pickle.dump(ns, chan)
             """)
-        interpreters.run_string(self.id, script, shared)
+        interpreters.set___main___attrs(self.id, shared)
+        interpreters.run_string(self.id, script)
         with open(r, 'rb') as chan:
             ns = pickle.load(chan)
 
@@ -742,7 +743,8 @@ class RunStringTests(TestBase):
             ns2 = dict(vars())
             del ns2['__builtins__']
         """)
-        interpreters.run_string(self.id, script, shared)
+        interpreters.set___main___attrs(self.id, shared)
+        interpreters.run_string(self.id, script)
 
         r, w = os.pipe()
         script = dedent(f"""
@@ -773,7 +775,8 @@ class RunStringTests(TestBase):
             with open({w}, 'wb') as chan:
                 pickle.dump(ns, chan)
             """)
-        interpreters.run_string(self.id, script, shared)
+        interpreters.set___main___attrs(self.id, shared)
+        interpreters.run_string(self.id, script)
         with open(r, 'rb') as chan:
             ns = pickle.load(chan)
 
@@ -1036,7 +1039,8 @@ class RunFuncTests(TestBase):
             with open(w, 'w', encoding="utf-8") as spipe:
                 with contextlib.redirect_stdout(spipe):
                     print('it worked!', end='')
-        interpreters.run_func(self.id, script, shared=dict(w=w))
+        interpreters.set___main___attrs(self.id, dict(w=w))
+        interpreters.run_func(self.id, script)
 
         with open(r, encoding="utf-8") as outfile:
             out = outfile.read()
@@ -1052,7 +1056,8 @@ class RunFuncTests(TestBase):
                 with contextlib.redirect_stdout(spipe):
                     print('it worked!', end='')
         def f():
-            interpreters.run_func(self.id, script, shared=dict(w=w))
+            interpreters.set___main___attrs(self.id, dict(w=w))
+            interpreters.run_func(self.id, script)
         t = threading.Thread(target=f)
         t.start()
         t.join()
@@ -1072,7 +1077,8 @@ class RunFuncTests(TestBase):
                 with contextlib.redirect_stdout(spipe):
                     print('it worked!', end='')
         code = script.__code__
-        interpreters.run_func(self.id, code, shared=dict(w=w))
+        interpreters.set___main___attrs(self.id, dict(w=w))
+        interpreters.run_func(self.id, code)
 
         with open(r, encoding="utf-8") as outfile:
             out = outfile.read()
index e4ae9d005b5282eaf0dd4e88025856fc38b67f21..b702338c3de1ad397426ef8a5d67050db8c127b1 100644 (file)
@@ -452,6 +452,63 @@ class TestInterpreterClose(TestBase):
         self.assertEqual(os.read(r_interp, 1), FINISHED)
 
 
+class TestInterpreterPrepareMain(TestBase):
+
+    def test_empty(self):
+        interp = interpreters.create()
+        with self.assertRaises(ValueError):
+            interp.prepare_main()
+
+    def test_dict(self):
+        values = {'spam': 42, 'eggs': 'ham'}
+        interp = interpreters.create()
+        interp.prepare_main(values)
+        out = _run_output(interp, dedent("""
+            print(spam, eggs)
+            """))
+        self.assertEqual(out.strip(), '42 ham')
+
+    def test_tuple(self):
+        values = {'spam': 42, 'eggs': 'ham'}
+        values = tuple(values.items())
+        interp = interpreters.create()
+        interp.prepare_main(values)
+        out = _run_output(interp, dedent("""
+            print(spam, eggs)
+            """))
+        self.assertEqual(out.strip(), '42 ham')
+
+    def test_kwargs(self):
+        values = {'spam': 42, 'eggs': 'ham'}
+        interp = interpreters.create()
+        interp.prepare_main(**values)
+        out = _run_output(interp, dedent("""
+            print(spam, eggs)
+            """))
+        self.assertEqual(out.strip(), '42 ham')
+
+    def test_dict_and_kwargs(self):
+        values = {'spam': 42, 'eggs': 'ham'}
+        interp = interpreters.create()
+        interp.prepare_main(values, foo='bar')
+        out = _run_output(interp, dedent("""
+            print(spam, eggs, foo)
+            """))
+        self.assertEqual(out.strip(), '42 ham bar')
+
+    def test_not_shareable(self):
+        interp = interpreters.create()
+        # XXX TypeError?
+        with self.assertRaises(ValueError):
+            interp.prepare_main(spam={'spam': 'eggs', 'foo': 'bar'})
+
+        # Make sure neither was actually bound.
+        with self.assertRaises(interpreters.ExecFailure):
+            interp.exec_sync('print(foo)')
+        with self.assertRaises(interpreters.ExecFailure):
+            interp.exec_sync('print(spam)')
+
+
 class TestInterpreterExecSync(TestBase):
 
     def test_success(self):
index 623c8737b79831c9da5bfad43430ad35836584bc..11b6f126dff0f439a09355eb43fe21258f0f9e2a 100644 (file)
@@ -29,10 +29,12 @@ def clean_up_interpreters():
             pass  # already destroyed
 
 
-def _run_output(interp, request, channels=None):
+def _run_output(interp, request, init=None):
     script, rpipe = _captured_script(request)
     with rpipe:
-        interp.exec_sync(script, channels=channels)
+        if init:
+            interp.prepare_main(init)
+        interp.exec_sync(script)
         return rpipe.read()
 
 
index 37959e953ee4f5e891fda5ba40ac3c4114cdd434..4bb54c93b0a61b8f9d8d011b145a4ad1439545af 100644 (file)
@@ -685,6 +685,60 @@ PyDoc_STRVAR(get_main_doc,
 \n\
 Return the ID of main interpreter.");
 
+static PyObject *
+interp_set___main___attrs(PyObject *self, PyObject *args)
+{
+    PyObject *id, *updates;
+    if (!PyArg_ParseTuple(args, "OO:" MODULE_NAME ".set___main___attrs",
+                          &id, &updates))
+    {
+        return NULL;
+    }
+
+    // Look up the interpreter.
+    PyInterpreterState *interp = PyInterpreterID_LookUp(id);
+    if (interp == NULL) {
+        return NULL;
+    }
+
+    // Check the updates.
+    if (updates != Py_None) {
+        Py_ssize_t size = PyObject_Size(updates);
+        if (size < 0) {
+            return NULL;
+        }
+        if (size == 0) {
+            PyErr_SetString(PyExc_ValueError,
+                            "arg 2 must be a non-empty mapping");
+            return NULL;
+        }
+    }
+
+    _PyXI_session session = {0};
+
+    // Prep and switch interpreters, including apply the updates.
+    if (_PyXI_Enter(&session, interp, updates) < 0) {
+        if (!PyErr_Occurred()) {
+            _PyXI_ApplyCapturedException(&session);
+            assert(PyErr_Occurred());
+        }
+        else {
+            assert(!_PyXI_HasCapturedException(&session));
+        }
+        return NULL;
+    }
+
+    // Clean up and switch back.
+    _PyXI_Exit(&session);
+
+    Py_RETURN_NONE;
+}
+
+PyDoc_STRVAR(set___main___attrs_doc,
+"set___main___attrs(id, ns)\n\
+\n\
+Bind the given attributes in the interpreter's __main__ module.");
+
 static PyUnicodeObject *
 convert_script_arg(PyObject *arg, const char *fname, const char *displayname,
                    const char *expected)
@@ -1033,6 +1087,8 @@ static PyMethodDef module_functions[] = {
     {"run_func",                  _PyCFunction_CAST(interp_run_func),
      METH_VARARGS | METH_KEYWORDS, run_func_doc},
 
+    {"set___main___attrs",        _PyCFunction_CAST(interp_set___main___attrs),
+     METH_VARARGS, set___main___attrs_doc},
     {"is_shareable",              _PyCFunction_CAST(object_is_shareable),
      METH_VARARGS | METH_KEYWORDS, is_shareable_doc},