]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-121272: set ste_coroutine during symtable construction (#121297)
authorIrit Katriel <1055913+iritkatriel@users.noreply.github.com>
Wed, 3 Jul 2024 09:18:34 +0000 (10:18 +0100)
committerGitHub <noreply@github.com>
Wed, 3 Jul 2024 09:18:34 +0000 (10:18 +0100)
compiler no longer modifies the symtable after this.

Python/compile.c
Python/symtable.c

index d33db69f4253610e3bf5528607fe70f39a639a4a..30708e1dda9d43cc5d08f2004e62496d806cd86e 100644 (file)
@@ -3059,7 +3059,7 @@ compiler_async_for(struct compiler *c, stmt_ty s)
 {
     location loc = LOC(s);
     if (IS_TOP_LEVEL_AWAIT(c)){
-        c->u->u_ste->ste_coroutine = 1;
+        assert(c->u->u_ste->ste_coroutine == 1);
     } else if (c->u->u_scope_type != COMPILER_SCOPE_ASYNC_FUNCTION) {
         return compiler_error(c, loc, "'async for' outside async function");
     }
@@ -5782,7 +5782,7 @@ compiler_comprehension(struct compiler *c, expr_ty e, int type,
     co = optimize_and_assemble(c, 1);
     compiler_exit_scope(c);
     if (is_top_level_await && is_async_generator){
-        c->u->u_ste->ste_coroutine = 1;
+        assert(c->u->u_ste->ste_coroutine == 1);
     }
     if (co == NULL) {
         goto error;
@@ -5926,7 +5926,7 @@ compiler_async_with(struct compiler *c, stmt_ty s, int pos)
 
     assert(s->kind == AsyncWith_kind);
     if (IS_TOP_LEVEL_AWAIT(c)){
-        c->u->u_ste->ste_coroutine = 1;
+        assert(c->u->u_ste->ste_coroutine == 1);
     } else if (c->u->u_scope_type != COMPILER_SCOPE_ASYNC_FUNCTION){
         return compiler_error(c, loc, "'async with' outside async function");
     }
index 61fa5c6fdf923cc3eaecde13c93a836d4e61a4f4..65677f86092b0b30ec086dcec8c2f8cbb0c3b56c 100644 (file)
@@ -1681,6 +1681,16 @@ check_import_from(struct symtable *st, stmt_ty s)
     return 1;
 }
 
+static void
+maybe_set_ste_coroutine_for_module(struct symtable *st, stmt_ty s)
+{
+    if ((st->st_future->ff_features & PyCF_ALLOW_TOP_LEVEL_AWAIT) &&
+        (st->st_cur->ste_type == ModuleBlock))
+    {
+        st->st_cur->ste_coroutine = 1;
+    }
+}
+
 static int
 symtable_visit_stmt(struct symtable *st, stmt_ty s)
 {
@@ -2074,10 +2084,12 @@ symtable_visit_stmt(struct symtable *st, stmt_ty s)
         break;
     }
     case AsyncWith_kind:
+        maybe_set_ste_coroutine_for_module(st, s);
         VISIT_SEQ(st, withitem, s->v.AsyncWith.items);
         VISIT_SEQ(st, stmt, s->v.AsyncWith.body);
         break;
     case AsyncFor_kind:
+        maybe_set_ste_coroutine_for_module(st, s);
         VISIT(st, expr, s->v.AsyncFor.target);
         VISIT(st, expr, s->v.AsyncFor.iter);
         VISIT_SEQ(st, stmt, s->v.AsyncFor.body);