gh-106905: Use separate structs to track recursion depth in each PyAST_mod2obj call. (GH-113035)

Co-authored-by: Gregory P. Smith [Google LLC] <greg@krypto.org>
This commit is contained in:
Yilei Yang 2023-12-25 09:36:59 -08:00 committed by GitHub
parent 3f5eb3e6c7
commit 48c49739f5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 412 additions and 339 deletions

View file

@ -16,8 +16,6 @@ extern "C" {
struct ast_state {
_PyOnceFlag once;
int finalized;
int recursion_depth;
int recursion_limit;
PyObject *AST_type;
PyObject *Add_singleton;
PyObject *Add_type;

View file

@ -0,0 +1,7 @@
Use per AST-parser state rather than global state to track recursion depth
within the AST parser to prevent potential race condition due to
simultaneous parsing.
The issue primarily showed up in 3.11 by multithreaded users of
:func:`ast.parse`. In 3.12 a change to when garbage collection can be
triggered prevented the race condition from occurring.

View file

@ -731,7 +731,7 @@ def emit_sequence_constructor(self, name, type):
class PyTypesDeclareVisitor(PickleVisitor):
def visitProduct(self, prod, name):
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, void*);" % name, 0)
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, void*);" % name, 0)
if prod.attributes:
self.emit("static const char * const %s_attributes[] = {" % name, 0)
for a in prod.attributes:
@ -752,7 +752,7 @@ def visitSum(self, sum, name):
ptype = "void*"
if is_simple(sum):
ptype = get_c_type(name)
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, %s);" % (name, ptype), 0)
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s);" % (name, ptype), 0)
for t in sum.types:
self.visitConstructor(t, name)
@ -984,7 +984,8 @@ def visitModule(self, mod):
/* Conversion AST -> Python */
static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject* (*func)(struct ast_state *state, void*))
static PyObject* ast2obj_list(struct ast_state *state, struct validator *vstate, asdl_seq *seq,
PyObject* (*func)(struct ast_state *state, struct validator *vstate, void*))
{
Py_ssize_t i, n = asdl_seq_LEN(seq);
PyObject *result = PyList_New(n);
@ -992,7 +993,7 @@ def visitModule(self, mod):
if (!result)
return NULL;
for (i = 0; i < n; i++) {
value = func(state, asdl_seq_GET_UNTYPED(seq, i));
value = func(state, vstate, asdl_seq_GET_UNTYPED(seq, i));
if (!value) {
Py_DECREF(result);
return NULL;
@ -1002,7 +1003,7 @@ def visitModule(self, mod):
return result;
}
static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o)
static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), void *o)
{
PyObject *op = (PyObject*)o;
if (!op) {
@ -1014,7 +1015,7 @@ def visitModule(self, mod):
#define ast2obj_identifier ast2obj_object
#define ast2obj_string ast2obj_object
static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), long b)
static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), long b)
{
return PyLong_FromLong(b);
}
@ -1116,8 +1117,6 @@ def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
self.file.write(textwrap.dedent('''
state->recursion_depth = 0;
state->recursion_limit = 0;
return 0;
}
'''))
@ -1260,7 +1259,7 @@ class ObjVisitor(PickleVisitor):
def func_begin(self, name):
ctype = get_c_type(name)
self.emit("PyObject*", 0)
self.emit("ast2obj_%s(struct ast_state *state, void* _o)" % (name), 0)
self.emit("ast2obj_%s(struct ast_state *state, struct validator *vstate, void* _o)" % (name), 0)
self.emit("{", 0)
self.emit("%s o = (%s)_o;" % (ctype, ctype), 1)
self.emit("PyObject *result = NULL, *value = NULL;", 1)
@ -1268,17 +1267,17 @@ def func_begin(self, name):
self.emit('if (!o) {', 1)
self.emit("Py_RETURN_NONE;", 2)
self.emit("}", 1)
self.emit("if (++state->recursion_depth > state->recursion_limit) {", 1)
self.emit("if (++vstate->recursion_depth > vstate->recursion_limit) {", 1)
self.emit("PyErr_SetString(PyExc_RecursionError,", 2)
self.emit('"maximum recursion depth exceeded during ast construction");', 3)
self.emit("return NULL;", 2)
self.emit("}", 1)
def func_end(self):
self.emit("state->recursion_depth--;", 1)
self.emit("vstate->recursion_depth--;", 1)
self.emit("return result;", 1)
self.emit("failed:", 0)
self.emit("state->recursion_depth--;", 1)
self.emit("vstate->recursion_depth--;", 1)
self.emit("Py_XDECREF(value);", 1)
self.emit("Py_XDECREF(result);", 1)
self.emit("return NULL;", 1)
@ -1296,7 +1295,7 @@ def visitSum(self, sum, name):
self.visitConstructor(t, i + 1, name)
self.emit("}", 1)
for a in sum.attributes:
self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1)
self.emit("if (!value) goto failed;", 1)
self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a.name, 1)
self.emit('goto failed;', 2)
@ -1304,7 +1303,7 @@ def visitSum(self, sum, name):
self.func_end()
def simpleSum(self, sum, name):
self.emit("PyObject* ast2obj_%s(struct ast_state *state, %s_ty o)" % (name, name), 0)
self.emit("PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s_ty o)" % (name, name), 0)
self.emit("{", 0)
self.emit("switch(o) {", 1)
for t in sum.types:
@ -1322,7 +1321,7 @@ def visitProduct(self, prod, name):
for field in prod.fields:
self.visitField(field, name, 1, True)
for a in prod.attributes:
self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1)
self.emit("if (!value) goto failed;", 1)
self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1)
self.emit('goto failed;', 2)
@ -1363,7 +1362,7 @@ def set(self, field, value, depth):
self.emit("for(i = 0; i < n; i++)", depth+1)
# This cannot fail, so no need for error handling
self.emit(
"PyList_SET_ITEM(value, i, ast2obj_{0}(state, ({0}_ty)asdl_seq_GET({1}, i)));".format(
"PyList_SET_ITEM(value, i, ast2obj_{0}(state, vstate, ({0}_ty)asdl_seq_GET({1}, i)));".format(
field.type,
value
),
@ -1372,9 +1371,9 @@ def set(self, field, value, depth):
)
self.emit("}", depth)
else:
self.emit("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
self.emit("value = ast2obj_list(state, vstate, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
else:
self.emit("value = ast2obj_%s(state, %s);" % (field.type, value), depth, reflow=False)
self.emit("value = ast2obj_%s(state, vstate, %s);" % (field.type, value), depth, reflow=False)
class PartingShots(StaticVisitor):
@ -1394,18 +1393,19 @@ class PartingShots(StaticVisitor):
if (!tstate) {
return NULL;
}
state->recursion_limit = Py_C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
struct validator vstate;
vstate.recursion_limit = Py_C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
int recursion_depth = Py_C_RECURSION_LIMIT - tstate->c_recursion_remaining;
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
state->recursion_depth = starting_recursion_depth;
vstate.recursion_depth = starting_recursion_depth;
PyObject *result = ast2obj_mod(state, t);
PyObject *result = ast2obj_mod(state, &vstate, t);
/* Check that the recursion depth counting balanced correctly */
if (result && state->recursion_depth != starting_recursion_depth) {
if (result && vstate.recursion_depth != starting_recursion_depth) {
PyErr_Format(PyExc_SystemError,
"AST constructor recursion depth mismatch (before=%d, after=%d)",
starting_recursion_depth, state->recursion_depth);
starting_recursion_depth, vstate.recursion_depth);
return NULL;
}
return result;
@ -1475,8 +1475,6 @@ def generate_ast_state(module_state, f):
f.write('struct ast_state {\n')
f.write(' _PyOnceFlag once;\n')
f.write(' int finalized;\n')
f.write(' int recursion_depth;\n')
f.write(' int recursion_limit;\n')
for s in module_state:
f.write(' PyObject *' + s + ';\n')
f.write('};')
@ -1539,6 +1537,11 @@ def generate_module_def(mod, metadata, f, internal_h):
#include "pycore_pystate.h" // _PyInterpreterState_GET()
#include <stddef.h>
struct validator {
int recursion_depth; /* current recursion depth */
int recursion_limit; /* recursion limit */
};
// Forward declaration
static int init_types(struct ast_state *state);

689
Python/Python-ast.c generated

File diff suppressed because it is too large Load diff