]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-47167: Allow overriding a future compliance check in asyncio.Task (GH-32197)
authorAndrew Svetlov <andrew.svetlov@gmail.com>
Fri, 1 Apr 2022 01:25:15 +0000 (04:25 +0300)
committerGitHub <noreply@github.com>
Fri, 1 Apr 2022 01:25:15 +0000 (04:25 +0300)
Doc/library/asyncio-extending.rst
Lib/asyncio/tasks.py
Lib/test/test_asyncio/test_tasks.py
Misc/NEWS.d/next/Library/2022-03-30-18-35-50.bpo-47167.nCNHsB.rst [new file with mode: 0644]
Modules/_asynciomodule.c
Modules/clinic/_asynciomodule.c.h

index 619723e61b5f9e54c7f15a6404196b5c763d22b3..215d215bb14fe7503b7d3bddc11c04833680420a 100644 (file)
@@ -48,16 +48,27 @@ For this purpose the following, *private* constructors are listed:
 
 .. method:: Future.__init__(*, loop=None)
 
-Create a built-in future instance.
+   Create a built-in future instance.
 
-*loop* is an optional event loop instance.
+   *loop* is an optional event loop instance.
 
 .. method:: Task.__init__(coro, *, loop=None, name=None, context=None)
 
-Create a built-in task instance.
+   Create a built-in task instance.
 
-*loop* is an optional event loop instance. The rest of arguments are described in
-:meth:`loop.create_task` description.
+   *loop* is an optional event loop instance. The rest of arguments are described in
+   :meth:`loop.create_task` description.
+
+   .. versionchanged:: 3.11
+
+      *context* argument is added.
+
+.. method:: Tasl._check_future(future)
+
+   Return ``True`` if *future* is attached to the same loop as the task, ``False``
+   otherwise.
+
+   .. versionadded:: 3.11
 
 
 Task lifetime support
index 27fe58da15136a0a04e52a23dab22942e593bd31..3952b5f2a7743defd288196063f7d2ade1cea59d 100644 (file)
@@ -252,6 +252,10 @@ class Task(futures._PyFuture):  # Inherit Python Task implementation
             self._num_cancels_requested -= 1
         return self._num_cancels_requested
 
+    def _check_future(self, future):
+        """Return False if task and future loops are not compatible."""
+        return futures._get_loop(future) is self._loop
+
     def __step(self, exc=None):
         if self.done():
             raise exceptions.InvalidStateError(
@@ -292,7 +296,7 @@ class Task(futures._PyFuture):  # Inherit Python Task implementation
             blocking = getattr(result, '_asyncio_future_blocking', None)
             if blocking is not None:
                 # Yielded Future must come from Future.__iter__().
-                if futures._get_loop(result) is not self._loop:
+                if not self._check_future(result):
                     new_exc = RuntimeError(
                         f'Task {self!r} got Future '
                         f'{result!r} attached to a different loop')
index 8df1957bbe9e7a3d14dcfc2b6d946cb2da68af21..80afb27351362017cb464c8e96f2faa05dc30f19 100644 (file)
@@ -2383,7 +2383,13 @@ def add_subclass_tests(cls):
             return super().add_done_callback(*args, **kwargs)
 
     class Task(CommonFuture, BaseTask):
-        pass
+        def __init__(self, *args, **kwargs):
+            self._check_future_called = 0
+            super().__init__(*args, **kwargs)
+
+        def _check_future(self, future):
+            self._check_future_called += 1
+            return super()._check_future(future)
 
     class Future(CommonFuture, BaseFuture):
         pass
@@ -2409,6 +2415,8 @@ def add_subclass_tests(cls):
             dict(fut.calls),
             {'add_done_callback': 1})
 
+        self.assertEqual(1, task._check_future_called)
+
     # Add patched Task & Future back to the test case
     cls.Task = Task
     cls.Future = Future
diff --git a/Misc/NEWS.d/next/Library/2022-03-30-18-35-50.bpo-47167.nCNHsB.rst b/Misc/NEWS.d/next/Library/2022-03-30-18-35-50.bpo-47167.nCNHsB.rst
new file mode 100644 (file)
index 0000000..a37dd25
--- /dev/null
@@ -0,0 +1 @@
+Allow overriding a future compliance check in :class:`asyncio.Task`.
index 632a4465c224ab99f30922367031568f1539556e..d8d3da91cdd8e0a293f000427c19db949bfde68f 100644 (file)
@@ -23,6 +23,7 @@ _Py_IDENTIFIER(call_soon);
 _Py_IDENTIFIER(cancel);
 _Py_IDENTIFIER(get_event_loop);
 _Py_IDENTIFIER(throw);
+_Py_IDENTIFIER(_check_future);
 
 
 /* State of the _asyncio module */
@@ -1795,6 +1796,8 @@ class _asyncio.Task "TaskObj *" "&Task_Type"
 static int task_call_step_soon(TaskObj *, PyObject *);
 static PyObject * task_wakeup(TaskObj *, PyObject *);
 static PyObject * task_step(TaskObj *, PyObject *);
+static int task_check_future(TaskObj *, PyObject *);
+static int task_check_future_exact(TaskObj *, PyObject *);
 
 /* ----- Task._step wrapper */
 
@@ -2269,7 +2272,6 @@ Returns the remaining number of cancellation requests.
 static PyObject *
 _asyncio_Task_uncancel_impl(TaskObj *self)
 /*[clinic end generated code: output=58184d236a817d3c input=68f81a4b90b46be2]*/
-/*[clinic end generated code]*/
 {
     if (self->task_num_cancels_requested > 0) {
         self->task_num_cancels_requested -= 1;
@@ -2277,6 +2279,21 @@ _asyncio_Task_uncancel_impl(TaskObj *self)
     return PyLong_FromLong(self->task_num_cancels_requested);
 }
 
+/*[clinic input]
+_asyncio.Task._check_future -> bool
+
+    future: object
+
+Return False if task and future loops are not compatible.
+[clinic start generated code]*/
+
+static int
+_asyncio_Task__check_future_impl(TaskObj *self, PyObject *future)
+/*[clinic end generated code: output=a3bfba79295c8d57 input=3b1d6dfd6fe90aa5]*/
+{
+    return task_check_future_exact(self, future);
+}
+
 /*[clinic input]
 _asyncio.Task.get_stack
 
@@ -2502,6 +2519,7 @@ static PyMethodDef TaskType_methods[] = {
     _ASYNCIO_TASK_CANCEL_METHODDEF
     _ASYNCIO_TASK_CANCELLING_METHODDEF
     _ASYNCIO_TASK_UNCANCEL_METHODDEF
+    _ASYNCIO_TASK__CHECK_FUTURE_METHODDEF
     _ASYNCIO_TASK_GET_STACK_METHODDEF
     _ASYNCIO_TASK_PRINT_STACK_METHODDEF
     _ASYNCIO_TASK__MAKE_CANCELLED_ERROR_METHODDEF
@@ -2569,6 +2587,43 @@ TaskObj_dealloc(PyObject *self)
     Py_TYPE(task)->tp_free(task);
 }
 
+static int
+task_check_future_exact(TaskObj *task, PyObject *future)
+{
+    int res;
+    if (Future_CheckExact(future) || Task_CheckExact(future)) {
+        FutureObj *fut = (FutureObj *)future;
+        res = (fut->fut_loop == task->task_loop);
+    } else {
+        PyObject *oloop = get_future_loop(future);
+        if (oloop == NULL) {
+            return -1;
+        }
+        res = (oloop == task->task_loop);
+        Py_DECREF(oloop);
+    }
+    return res;
+}
+
+
+static int
+task_check_future(TaskObj *task, PyObject *future)
+{
+    if (Task_CheckExact(task)) {
+        return task_check_future_exact(task, future);
+    } else {
+        PyObject * ret = _PyObject_CallMethodIdOneArg((PyObject *)task,
+                                                      &PyId__check_future,
+                                                      future);
+        if (ret == NULL) {
+            return -1;
+        }
+        int is_true = PyObject_IsTrue(ret);
+        Py_DECREF(ret);
+        return is_true;
+    }
+}
+
 static int
 task_call_step_soon(TaskObj *task, PyObject *arg)
 {
@@ -2790,7 +2845,11 @@ task_step_impl(TaskObj *task, PyObject *exc)
         FutureObj *fut = (FutureObj*)result;
 
         /* Check if `result` future is attached to a different loop */
-        if (fut->fut_loop != task->task_loop) {
+        res = task_check_future(task, result);
+        if (res == -1) {
+            goto fail;
+        }
+        if (res == 0) {
             goto different_loop;
         }
 
@@ -2862,15 +2921,13 @@ task_step_impl(TaskObj *task, PyObject *exc)
         }
 
         /* Check if `result` future is attached to a different loop */
-        PyObject *oloop = get_future_loop(result);
-        if (oloop == NULL) {
+        res = task_check_future(task, result);
+        if (res == -1) {
             goto fail;
         }
-        if (oloop != task->task_loop) {
-            Py_DECREF(oloop);
+        if (res == 0) {
             goto different_loop;
         }
-        Py_DECREF(oloop);
 
         if (!blocking) {
             goto yield_insteadof_yf;
index 4b64367a3f6312e55aa28375eea3e23c9cb3543f..163b0f95691b31262435ef3809ee7bca716be3c8 100644 (file)
@@ -466,6 +466,43 @@ _asyncio_Task_uncancel(TaskObj *self, PyObject *Py_UNUSED(ignored))
     return _asyncio_Task_uncancel_impl(self);
 }
 
+PyDoc_STRVAR(_asyncio_Task__check_future__doc__,
+"_check_future($self, /, future)\n"
+"--\n"
+"\n"
+"Return False if task and future loops are not compatible.");
+
+#define _ASYNCIO_TASK__CHECK_FUTURE_METHODDEF    \
+    {"_check_future", (PyCFunction)(void(*)(void))_asyncio_Task__check_future, METH_FASTCALL|METH_KEYWORDS, _asyncio_Task__check_future__doc__},
+
+static int
+_asyncio_Task__check_future_impl(TaskObj *self, PyObject *future);
+
+static PyObject *
+_asyncio_Task__check_future(TaskObj *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
+{
+    PyObject *return_value = NULL;
+    static const char * const _keywords[] = {"future", NULL};
+    static _PyArg_Parser _parser = {NULL, _keywords, "_check_future", 0};
+    PyObject *argsbuf[1];
+    PyObject *future;
+    int _return_value;
+
+    args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, 1, 1, 0, argsbuf);
+    if (!args) {
+        goto exit;
+    }
+    future = args[0];
+    _return_value = _asyncio_Task__check_future_impl(self, future);
+    if ((_return_value == -1) && PyErr_Occurred()) {
+        goto exit;
+    }
+    return_value = PyBool_FromLong((long)_return_value);
+
+exit:
+    return return_value;
+}
+
 PyDoc_STRVAR(_asyncio_Task_get_stack__doc__,
 "get_stack($self, /, *, limit=None)\n"
 "--\n"
@@ -890,4 +927,4 @@ _asyncio__leave_task(PyObject *module, PyObject *const *args, Py_ssize_t nargs,
 exit:
     return return_value;
 }
-/*[clinic end generated code: output=64b3836574e8a18c input=a9049054013a1b77]*/
+/*[clinic end generated code: output=fdb7129263a8712e input=a9049054013a1b77]*/