]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-122595: Add more error checks in the compiler (GH-122596)
authorSerhiy Storchaka <storchaka@gmail.com>
Tue, 6 Aug 2024 05:59:44 +0000 (08:59 +0300)
committerGitHub <noreply@github.com>
Tue, 6 Aug 2024 05:59:44 +0000 (08:59 +0300)
Python/compile.c
Python/symtable.c

index 87b2c2705474a443bcbb9144cf827ae56aa3ed7c..9695a99d20114417ef42fdfabee7eb817f99299f 100644 (file)
@@ -505,21 +505,35 @@ dictbytype(PyObject *src, int scope_type, int flag, Py_ssize_t offset)
        deterministic, then the generated bytecode is not deterministic.
     */
     sorted_keys = PyDict_Keys(src);
-    if (sorted_keys == NULL)
+    if (sorted_keys == NULL) {
+        Py_DECREF(dest);
         return NULL;
+    }
     if (PyList_Sort(sorted_keys) != 0) {
         Py_DECREF(sorted_keys);
+        Py_DECREF(dest);
         return NULL;
     }
     num_keys = PyList_GET_SIZE(sorted_keys);
 
     for (key_i = 0; key_i < num_keys; key_i++) {
-        /* XXX this should probably be a macro in symtable.h */
-        long vi;
         k = PyList_GET_ITEM(sorted_keys, key_i);
         v = PyDict_GetItemWithError(src, k);
-        assert(v && PyLong_Check(v));
-        vi = PyLong_AS_LONG(v);
+        if (!v) {
+            if (!PyErr_Occurred()) {
+                PyErr_SetObject(PyExc_KeyError, k);
+            }
+            Py_DECREF(sorted_keys);
+            Py_DECREF(dest);
+            return NULL;
+        }
+        long vi = PyLong_AsLong(v);
+        if (vi == -1 && PyErr_Occurred()) {
+            Py_DECREF(sorted_keys);
+            Py_DECREF(dest);
+            return NULL;
+        }
+        /* XXX this should probably be a macro in symtable.h */
         scope = (vi >> SCOPE_OFFSET) & SCOPE_MASK;
 
         if (scope == scope_type || vi & flag) {
@@ -631,6 +645,7 @@ compiler_set_qualname(struct compiler *c)
 
             scope = _PyST_GetScope(parent->u_ste, mangled);
             Py_DECREF(mangled);
+            RETURN_IF_ERROR(scope);
             assert(scope != GLOBAL_IMPLICIT);
             if (scope == GLOBAL_EXPLICIT)
                 force_global = 1;
@@ -1648,7 +1663,7 @@ dict_lookup_arg(PyObject *dict, PyObject *name)
     if (v == NULL) {
         return ERROR;
     }
-    return PyLong_AS_LONG(v);
+    return PyLong_AsLong(v);
 }
 
 static int
@@ -1671,7 +1686,7 @@ compiler_lookup_arg(struct compiler *c, PyCodeObject *co, PyObject *name)
     else {
         arg = dict_lookup_arg(c->u->u_metadata.u_freevars, name);
     }
-    if (arg == -1) {
+    if (arg == -1 && !PyErr_Occurred()) {
         PyObject *freevars = _PyCode_GetFreevars(co);
         if (freevars == NULL) {
             PyErr_Clear();
@@ -4085,6 +4100,8 @@ compiler_nameop(struct compiler *c, location loc,
     case GLOBAL_EXPLICIT:
         optype = OP_GLOBAL;
         break;
+    case -1:
+        goto error;
     default:
         /* scope can be 0 */
         break;
@@ -4638,6 +4655,7 @@ is_import_originated(struct compiler *c, expr_ty e)
     }
 
     long flags = _PyST_GetSymbol(SYMTABLE(c)->st_top, e->v.Name.id);
+    RETURN_IF_ERROR(flags);
     return flags & DEF_IMPORT;
 }
 
@@ -4657,10 +4675,12 @@ can_optimize_super_call(struct compiler *c, expr_ty attr)
     PyObject *super_name = e->v.Call.func->v.Name.id;
     // detect statically-visible shadowing of 'super' name
     int scope = _PyST_GetScope(SYMTABLE_ENTRY(c), super_name);
+    RETURN_IF_ERROR(scope);
     if (scope != GLOBAL_IMPLICIT) {
         return 0;
     }
     scope = _PyST_GetScope(SYMTABLE(c)->st_top, super_name);
+    RETURN_IF_ERROR(scope);
     if (scope != 0) {
         return 0;
     }
@@ -4767,7 +4787,9 @@ maybe_optimize_method_call(struct compiler *c, expr_ty e)
     }
 
     /* Check that the base object is not something that is imported */
-    if (is_import_originated(c, meth->v.Attribute.value)) {
+    int ret = is_import_originated(c, meth->v.Attribute.value);
+    RETURN_IF_ERROR(ret);
+    if (ret) {
         return 0;
     }
 
@@ -4795,7 +4817,9 @@ maybe_optimize_method_call(struct compiler *c, expr_ty e)
     /* Alright, we can optimize the code. */
     location loc = LOC(meth);
 
-    if (can_optimize_super_call(c, meth)) {
+    ret = can_optimize_super_call(c, meth);
+    RETURN_IF_ERROR(ret);
+    if (ret) {
         RETURN_IF_ERROR(load_args_for_super(c, meth->v.Attribute.value));
         int opcode = asdl_seq_LEN(meth->v.Attribute.value->v.Call.args) ?
             LOAD_SUPER_METHOD : LOAD_ZERO_SUPER_METHOD;
@@ -5367,8 +5391,10 @@ push_inlined_comprehension_state(struct compiler *c, location loc,
     PyObject *k, *v;
     Py_ssize_t pos = 0;
     while (PyDict_Next(entry->ste_symbols, &pos, &k, &v)) {
-        assert(PyLong_Check(v));
-        long symbol = PyLong_AS_LONG(v);
+        long symbol = PyLong_AsLong(v);
+        if (symbol == -1 && PyErr_Occurred()) {
+            return ERROR;
+        }
         long scope = (symbol >> SCOPE_OFFSET) & SCOPE_MASK;
         PyObject *outv = PyDict_GetItemWithError(SYMTABLE_ENTRY(c)->ste_symbols, k);
         if (outv == NULL) {
@@ -5377,8 +5403,11 @@ push_inlined_comprehension_state(struct compiler *c, location loc,
             }
             outv = _PyLong_GetZero();
         }
-        assert(PyLong_CheckExact(outv));
-        long outsc = (PyLong_AS_LONG(outv) >> SCOPE_OFFSET) & SCOPE_MASK;
+        long outsymbol = PyLong_AsLong(outv);
+        if (outsymbol == -1 && PyErr_Occurred()) {
+            return ERROR;
+        }
+        long outsc = (outsymbol >> SCOPE_OFFSET) & SCOPE_MASK;
         // If a name has different scope inside than outside the comprehension,
         // we need to temporarily handle it with the right scope while
         // compiling the comprehension. If it's free in the comprehension
@@ -6064,14 +6093,18 @@ compiler_visit_expr(struct compiler *c, expr_ty e)
         return compiler_formatted_value(c, e);
     /* The following exprs can be assignment targets. */
     case Attribute_kind:
-        if (e->v.Attribute.ctx == Load && can_optimize_super_call(c, e)) {
-            RETURN_IF_ERROR(load_args_for_super(c, e->v.Attribute.value));
-            int opcode = asdl_seq_LEN(e->v.Attribute.value->v.Call.args) ?
-                LOAD_SUPER_ATTR : LOAD_ZERO_SUPER_ATTR;
-            ADDOP_NAME(c, loc, opcode, e->v.Attribute.attr, names);
-            loc = update_start_location_to_match_attr(c, loc, e);
-            ADDOP(c, loc, NOP);
-            return SUCCESS;
+        if (e->v.Attribute.ctx == Load) {
+            int ret = can_optimize_super_call(c, e);
+            RETURN_IF_ERROR(ret);
+            if (ret) {
+                RETURN_IF_ERROR(load_args_for_super(c, e->v.Attribute.value));
+                int opcode = asdl_seq_LEN(e->v.Attribute.value->v.Call.args) ?
+                    LOAD_SUPER_ATTR : LOAD_ZERO_SUPER_ATTR;
+                ADDOP_NAME(c, loc, opcode, e->v.Attribute.attr, names);
+                loc = update_start_location_to_match_attr(c, loc, e);
+                ADDOP(c, loc, NOP);
+                return SUCCESS;
+            }
         }
         RETURN_IF_ERROR(compiler_maybe_add_static_attribute_to_class(c, e));
         VISIT(c, expr, e->v.Attribute.value);
@@ -7300,7 +7333,8 @@ consts_dict_keys_inorder(PyObject *dict)
     if (consts == NULL)
         return NULL;
     while (PyDict_Next(dict, &pos, &k, &v)) {
-        i = PyLong_AS_LONG(v);
+        assert(PyLong_CheckExact(v));
+        i = PyLong_AsLong(v);
         /* The keys of the dictionary can be tuples wrapping a constant.
          * (see dict_add_o and _PyCode_ConstantKey). In that case
          * the object we want is always second. */
index ef81a0799de3aa074f7d4be4f17095d7ecef6296..4acf762f8fca39529f240bb60eb01ecedd082d5a 100644 (file)
@@ -526,17 +526,31 @@ _PySymtable_LookupOptional(struct symtable *st, void *key,
 long
 _PyST_GetSymbol(PySTEntryObject *ste, PyObject *name)
 {
-    PyObject *v = PyDict_GetItemWithError(ste->ste_symbols, name);
-    if (!v)
+    PyObject *v;
+    if (PyDict_GetItemRef(ste->ste_symbols, name, &v) < 0) {
+        return -1;
+    }
+    if (!v) {
         return 0;
-    assert(PyLong_Check(v));
-    return PyLong_AS_LONG(v);
+    }
+    long symbol = PyLong_AsLong(v);
+    Py_DECREF(v);
+    if (symbol < 0) {
+        if (!PyErr_Occurred()) {
+            PyErr_SetString(PyExc_SystemError, "invalid symbol");
+        }
+        return -1;
+    }
+    return symbol;
 }
 
 int
 _PyST_GetScope(PySTEntryObject *ste, PyObject *name)
 {
     long symbol = _PyST_GetSymbol(ste, name);
+    if (symbol < 0) {
+        return -1;
+    }
     return (symbol >> SCOPE_OFFSET) & SCOPE_MASK;
 }
 
@@ -715,11 +729,14 @@ analyze_name(PySTEntryObject *ste, PyObject *scopes, PyObject *name, long flags,
     // global statement), we want to also treat it as a global in this scope.
     if (class_entry != NULL) {
         long class_flags = _PyST_GetSymbol(class_entry, name);
+        if (class_flags < 0) {
+            return 0;
+        }
         if (class_flags & DEF_GLOBAL) {
             SET_SCOPE(scopes, name, GLOBAL_EXPLICIT);
             return 1;
         }
-        else if (class_flags & DEF_BOUND && !(class_flags & DEF_NONLOCAL)) {
+        else if ((class_flags & DEF_BOUND) && !(class_flags & DEF_NONLOCAL)) {
             SET_SCOPE(scopes, name, GLOBAL_IMPLICIT);
             return 1;
         }
@@ -763,6 +780,9 @@ is_free_in_any_child(PySTEntryObject *entry, PyObject *key)
         PySTEntryObject *child_ste = (PySTEntryObject *)PyList_GET_ITEM(
             entry->ste_children, i);
         long scope = _PyST_GetScope(child_ste, key);
+        if (scope < 0) {
+            return -1;
+        }
         if (scope == FREE) {
             return 1;
         }
@@ -781,7 +801,10 @@ inline_comprehension(PySTEntryObject *ste, PySTEntryObject *comp,
 
     while (PyDict_Next(comp->ste_symbols, &pos, &k, &v)) {
         // skip comprehension parameter
-        long comp_flags = PyLong_AS_LONG(v);
+        long comp_flags = PyLong_AsLong(v);
+        if (comp_flags == -1 && PyErr_Occurred()) {
+            return 0;
+        }
         if (comp_flags & DEF_PARAM) {
             assert(_PyUnicode_EqualToASCIIString(k, ".0"));
             continue;
@@ -822,11 +845,19 @@ inline_comprehension(PySTEntryObject *ste, PySTEntryObject *comp,
             SET_SCOPE(scopes, k, scope);
         }
         else {
-            if (PyLong_AsLong(existing) & DEF_BOUND) {
+            long flags = PyLong_AsLong(existing);
+            if (flags == -1 && PyErr_Occurred()) {
+                return 0;
+            }
+            if ((flags & DEF_BOUND) && ste->ste_type != ClassBlock) {
                 // free vars in comprehension that are locals in outer scope can
                 // now simply be locals, unless they are free in comp children,
                 // or if the outer scope is a class block
-                if (!is_free_in_any_child(comp, k) && ste->ste_type != ClassBlock) {
+                int ok = is_free_in_any_child(comp, k);
+                if (ok < 0) {
+                    return 0;
+                }
+                if (!ok) {
                     if (PySet_Discard(comp_free, k) < 0) {
                         return 0;
                     }
@@ -861,9 +892,10 @@ analyze_cells(PyObject *scopes, PyObject *free, PyObject *inlined_cells)
     if (!v_cell)
         return 0;
     while (PyDict_Next(scopes, &pos, &name, &v)) {
-        long scope;
-        assert(PyLong_Check(v));
-        scope = PyLong_AS_LONG(v);
+        long scope = PyLong_AsLong(v);
+        if (scope == -1 && PyErr_Occurred()) {
+            goto error;
+        }
         if (scope != LOCAL)
             continue;
         int contains = PySet_Contains(free, name);
@@ -926,9 +958,10 @@ update_symbols(PyObject *symbols, PyObject *scopes,
 
     /* Update scope information for all symbols in this scope */
     while (PyDict_Next(symbols, &pos, &name, &v)) {
-        long scope, flags;
-        assert(PyLong_Check(v));
-        flags = PyLong_AS_LONG(v);
+        long flags = PyLong_AsLong(v);
+        if (flags == -1 && PyErr_Occurred()) {
+            return 0;
+        }
         int contains = PySet_Contains(inlined_cells, name);
         if (contains < 0) {
             return 0;
@@ -936,9 +969,18 @@ update_symbols(PyObject *symbols, PyObject *scopes,
         if (contains) {
             flags |= DEF_COMP_CELL;
         }
-        v_scope = PyDict_GetItemWithError(scopes, name);
-        assert(v_scope && PyLong_Check(v_scope));
-        scope = PyLong_AS_LONG(v_scope);
+        if (PyDict_GetItemRef(scopes, name, &v_scope) < 0) {
+            return 0;
+        }
+        if (!v_scope) {
+            PyErr_SetObject(PyExc_KeyError, name);
+            return 0;
+        }
+        long scope = PyLong_AsLong(v_scope);
+        Py_DECREF(v_scope);
+        if (scope == -1 && PyErr_Occurred()) {
+            return 0;
+        }
         flags |= (scope << SCOPE_OFFSET);
         v_new = PyLong_FromLong(flags);
         if (!v_new)
@@ -971,7 +1013,11 @@ update_symbols(PyObject *symbols, PyObject *scopes,
                or global in the class scope.
             */
             if  (classflag) {
-                long flags = PyLong_AS_LONG(v) | DEF_FREE_CLASS;
+                long flags = PyLong_AsLong(v);
+                if (flags == -1 && PyErr_Occurred()) {
+                    goto error;
+                }
+                flags |= DEF_FREE_CLASS;
                 v_new = PyLong_FromLong(flags);
                 if (!v_new) {
                     goto error;
@@ -1110,7 +1156,10 @@ analyze_block(PySTEntryObject *ste, PyObject *bound, PyObject *free,
     }
 
     while (PyDict_Next(ste->ste_symbols, &pos, &name, &v)) {
-        long flags = PyLong_AS_LONG(v);
+        long flags = PyLong_AsLong(v);
+        if (flags == -1 && PyErr_Occurred()) {
+            goto error;
+        }
         if (!analyze_name(ste, scopes, name, flags,
                           bound, local, free, global, type_params, class_entry))
             goto error;
@@ -1395,9 +1444,12 @@ symtable_lookup_entry(struct symtable *st, PySTEntryObject *ste, PyObject *name)
 {
     PyObject *mangled = _Py_MaybeMangle(st->st_private, ste, name);
     if (!mangled)
-        return 0;
+        return -1;
     long ret = _PyST_GetSymbol(ste, mangled);
     Py_DECREF(mangled);
+    if (ret < 0) {
+        return -1;
+    }
     return ret;
 }
 
@@ -1420,7 +1472,10 @@ symtable_add_def_helper(struct symtable *st, PyObject *name, int flag, struct _s
         return 0;
     dict = ste->ste_symbols;
     if ((o = PyDict_GetItemWithError(dict, mangled))) {
-        val = PyLong_AS_LONG(o);
+        val = PyLong_AsLong(o);
+        if (val == -1 && PyErr_Occurred()) {
+            goto error;
+        }
         if ((flag & DEF_PARAM) && (val & DEF_PARAM)) {
             /* Is it better to use 'mangled' or 'name' here? */
             PyErr_Format(PyExc_SyntaxError, DUPLICATE_ARGUMENT, name);
@@ -1466,16 +1521,20 @@ symtable_add_def_helper(struct symtable *st, PyObject *name, int flag, struct _s
     if (flag & DEF_PARAM) {
         if (PyList_Append(ste->ste_varnames, mangled) < 0)
             goto error;
-    } else      if (flag & DEF_GLOBAL) {
+    } else if (flag & DEF_GLOBAL) {
         /* XXX need to update DEF_GLOBAL for other flags too;
            perhaps only DEF_FREE_GLOBAL */
-        val = flag;
+        val = 0;
         if ((o = PyDict_GetItemWithError(st->st_global, mangled))) {
-            val |= PyLong_AS_LONG(o);
+            val = PyLong_AsLong(o);
+            if (val == -1 && PyErr_Occurred()) {
+                goto error;
+            }
         }
         else if (PyErr_Occurred()) {
             goto error;
         }
+        val |= flag;
         o = PyLong_FromLong(val);
         if (o == NULL)
             goto error;
@@ -2176,6 +2235,9 @@ symtable_extend_namedexpr_scope(struct symtable *st, expr_ty e)
          */
         if (ste->ste_comprehension) {
             long target_in_scope = symtable_lookup_entry(st, ste, target_name);
+            if (target_in_scope < 0) {
+                return 0;
+            }
             if ((target_in_scope & DEF_COMP_ITER) &&
                 (target_in_scope & DEF_LOCAL)) {
                 PyErr_Format(PyExc_SyntaxError, NAMED_EXPR_COMP_CONFLICT, target_name);
@@ -2188,6 +2250,9 @@ symtable_extend_namedexpr_scope(struct symtable *st, expr_ty e)
         /* If we find a FunctionBlock entry, add as GLOBAL/LOCAL or NONLOCAL/LOCAL */
         if (ste->ste_type == FunctionBlock) {
             long target_in_scope = symtable_lookup_entry(st, ste, target_name);
+            if (target_in_scope < 0) {
+                return 0;
+            }
             if (target_in_scope & DEF_GLOBAL) {
                 if (!symtable_add_def(st, target_name, DEF_GLOBAL, LOCATION(e)))
                     return 0;
@@ -2601,9 +2666,6 @@ symtable_visit_params(struct symtable *st, asdl_arg_seq *args)
 {
     Py_ssize_t i;
 
-    if (!args)
-        return -1;
-
     for (i = 0; i < asdl_seq_LEN(args); i++) {
         arg_ty arg = (arg_ty)asdl_seq_GET(args, i);
         if (!symtable_add_def(st, arg->arg, DEF_PARAM, LOCATION(arg)))
@@ -2650,9 +2712,6 @@ symtable_visit_argannotations(struct symtable *st, asdl_arg_seq *args)
 {
     Py_ssize_t i;
 
-    if (!args)
-        return -1;
-
     for (i = 0; i < asdl_seq_LEN(args); i++) {
         arg_ty arg = (arg_ty)asdl_seq_GET(args, i);
         if (arg->annotation) {