]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-121272: move async for/with validation from compiler to symtable (#121361)
authorIrit Katriel <1055913+iritkatriel@users.noreply.github.com>
Thu, 4 Jul 2024 13:47:21 +0000 (14:47 +0100)
committerGitHub <noreply@github.com>
Thu, 4 Jul 2024 13:47:21 +0000 (14:47 +0100)
Python/compile.c
Python/symtable.c

index 30708e1dda9d43cc5d08f2004e62496d806cd86e..1d6b54d411daf16b8cc49f5b512c6d25132a0748 100644 (file)
@@ -3058,11 +3058,6 @@ static int
 compiler_async_for(struct compiler *c, stmt_ty s)
 {
     location loc = LOC(s);
-    if (IS_TOP_LEVEL_AWAIT(c)){
-        assert(c->u->u_ste->ste_coroutine == 1);
-    } else if (c->u->u_scope_type != COMPILER_SCOPE_ASYNC_FUNCTION) {
-        return compiler_error(c, loc, "'async for' outside async function");
-    }
 
     NEW_JUMP_TARGET_LABEL(c, start);
     NEW_JUMP_TARGET_LABEL(c, except);
@@ -5781,9 +5776,6 @@ compiler_comprehension(struct compiler *c, expr_ty e, int type,
 
     co = optimize_and_assemble(c, 1);
     compiler_exit_scope(c);
-    if (is_top_level_await && is_async_generator){
-        assert(c->u->u_ste->ste_coroutine == 1);
-    }
     if (co == NULL) {
         goto error;
     }
@@ -5925,11 +5917,6 @@ compiler_async_with(struct compiler *c, stmt_ty s, int pos)
     withitem_ty item = asdl_seq_GET(s->v.AsyncWith.items, pos);
 
     assert(s->kind == AsyncWith_kind);
-    if (IS_TOP_LEVEL_AWAIT(c)){
-        assert(c->u->u_ste->ste_coroutine == 1);
-    } else if (c->u->u_scope_type != COMPILER_SCOPE_ASYNC_FUNCTION){
-        return compiler_error(c, loc, "'async with' outside async function");
-    }
 
     NEW_JUMP_TARGET_LABEL(c, block);
     NEW_JUMP_TARGET_LABEL(c, final);
index 6ff07077d4d0ed93a9835063fa914cdfc14db954..10103dbc2582a2bafde64eceb6c3fbcf8087c99b 100644 (file)
 #define DUPLICATE_TYPE_PARAM \
 "duplicate type parameter '%U'"
 
+#define ASYNC_WITH_OUTISDE_ASYNC_FUNC \
+"'async with' outside async function"
+
+#define ASYNC_FOR_OUTISDE_ASYNC_FUNC \
+"'async for' outside async function"
 
 #define LOCATION(x) SRC_LOCATION_FROM_AST(x)
 
@@ -251,6 +256,7 @@ static int symtable_visit_withitem(struct symtable *st, withitem_ty item);
 static int symtable_visit_match_case(struct symtable *st, match_case_ty m);
 static int symtable_visit_pattern(struct symtable *st, pattern_ty s);
 static int symtable_raise_if_annotation_block(struct symtable *st, const char *, expr_ty);
+static int symtable_raise_if_not_coroutine(struct symtable *st, const char *msg, _Py_SourceLocation loc);
 static int symtable_raise_if_comprehension_block(struct symtable *st, expr_ty);
 
 /* For debugging purposes only */
@@ -2048,11 +2054,17 @@ symtable_visit_stmt(struct symtable *st, stmt_ty s)
     }
     case AsyncWith_kind:
         maybe_set_ste_coroutine_for_module(st, s);
+        if (!symtable_raise_if_not_coroutine(st, ASYNC_WITH_OUTISDE_ASYNC_FUNC, LOCATION(s))) {
+            VISIT_QUIT(st, 0);
+        }
         VISIT_SEQ(st, withitem, s->v.AsyncWith.items);
         VISIT_SEQ(st, stmt, s->v.AsyncWith.body);
         break;
     case AsyncFor_kind:
         maybe_set_ste_coroutine_for_module(st, s);
+        if (!symtable_raise_if_not_coroutine(st, ASYNC_FOR_OUTISDE_ASYNC_FUNC, LOCATION(s))) {
+            VISIT_QUIT(st, 0);
+        }
         VISIT(st, expr, s->v.AsyncFor.target);
         VISIT(st, expr, s->v.AsyncFor.iter);
         VISIT_SEQ(st, stmt, s->v.AsyncFor.body);
@@ -2865,6 +2877,16 @@ symtable_raise_if_comprehension_block(struct symtable *st, expr_ty e) {
     VISIT_QUIT(st, 0);
 }
 
+static int
+symtable_raise_if_not_coroutine(struct symtable *st, const char *msg, _Py_SourceLocation loc) {
+    if (!st->st_cur->ste_coroutine) {
+        PyErr_SetString(PyExc_SyntaxError, msg);
+        SET_ERROR_LOCATION(st->st_filename, loc);
+        return 0;
+    }
+    return 1;
+}
+
 struct symtable *
 _Py_SymtableStringObjectFlags(const char *str, PyObject *filename,
                               int start, PyCompilerFlags *flags)