gh-111956: Add thread-safe one-time initialization. (gh-111960)

This commit is contained in:
Sam Gross 2023-11-16 14:19:54 -05:00 committed by GitHub
parent f66afa395a
commit 446f18a911
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 1061 additions and 955 deletions

View file

@ -2,6 +2,9 @@
#ifndef Py_INTERNAL_AST_STATE_H
#define Py_INTERNAL_AST_STATE_H
#include "pycore_lock.h" // _PyOnceFlag
#ifdef __cplusplus
extern "C" {
#endif
@ -11,7 +14,8 @@ extern "C" {
#endif
struct ast_state {
int initialized;
_PyOnceFlag once;
int finalized;
int recursion_depth;
int recursion_limit;
PyObject *AST_type;

View file

@ -46,6 +46,7 @@ typedef struct _PyMutex PyMutex;
#define _Py_UNLOCKED 0
#define _Py_LOCKED 1
#define _Py_HAS_PARKED 2
#define _Py_ONCE_INITIALIZED 4
// (private) slow path for locking the mutex
PyAPI_FUNC(void) _PyMutex_LockSlow(PyMutex *m);
@ -166,6 +167,35 @@ _PyRawMutex_Unlock(_PyRawMutex *m)
_PyRawMutex_UnlockSlow(m);
}
// A data structure that can be used to run initialization code once in a
// thread-safe manner. The C++11 equivalent is std::call_once.
//
// A zero-initialized flag (`{0}`) is in the "not yet run" state
// (_Py_UNLOCKED is 0).
typedef struct {
// State byte, read and written with atomic operations. It holds
// _Py_UNLOCKED (initializer not run, or last run failed), _Py_LOCKED
// (initializer in progress, optionally OR-ed with _Py_HAS_PARKED when
// another thread is waiting), or _Py_ONCE_INITIALIZED (initializer succeeded).
uint8_t v;
} _PyOnceFlag;
// Type signature for one-time initialization functions. The function should
// return 0 on success and -1 on failure. Any other return value is rejected
// with Py_FatalError() when the result is published to the once flag.
typedef int _Py_once_fn_t(void *arg);
// (private) slow path for one time initialization. Taken by
// _PyOnceFlag_CallOnce() when the flag is not yet _Py_ONCE_INITIALIZED.
// Returns 0 on success and -1 if `fn` failed.
PyAPI_FUNC(int)
_PyOnceFlag_CallOnceSlow(_PyOnceFlag *flag, _Py_once_fn_t *fn, void *arg);
// Runs `fn(arg)` at most once successfully, using `flag` to track progress.
//
// Result: 0 on success, -1 on failure.
//
// Once `fn` has returned 0, every later call returns 0 immediately without
// invoking `fn` again. If `fn` returns -1, the flag is left un-initialized
// and a later call will invoke `fn` again.
static inline int
_PyOnceFlag_CallOnce(_PyOnceFlag *flag, _Py_once_fn_t *fn, void *arg)
{
    uint8_t state = _Py_atomic_load_uint8(&flag->v);
    if (state != _Py_ONCE_INITIALIZED) {
        // Not (yet) initialized: defer to the out-of-line slow path.
        return _PyOnceFlag_CallOnceSlow(flag, fn, arg);
    }
    // Fast path: initialization already completed.
    return 0;
}
#ifdef __cplusplus
}
#endif

View file

@ -1,5 +1,8 @@
#ifndef Py_INTERNAL_MODSUPPORT_H
#define Py_INTERNAL_MODSUPPORT_H
#include "pycore_lock.h" // _PyOnceFlag
#ifdef __cplusplus
extern "C" {
#endif
@ -65,15 +68,16 @@ PyAPI_FUNC(void) _PyArg_BadArgument(
// --- _PyArg_Parser API ---------------------------------------------------
typedef struct _PyArg_Parser {
int initialized;
const char *format;
const char * const *keywords;
const char *fname;
const char *custom_msg;
int pos; /* number of positional-only arguments */
int min; /* minimal number of arguments */
int max; /* maximal number of positional arguments */
PyObject *kwtuple; /* tuple of keyword parameter names */
_PyOnceFlag once; /* atomic one-time initialization flag */
int is_kwtuple_owned; /* does this parser own the kwtuple object? */
int pos; /* number of positional-only arguments */
int min; /* minimal number of arguments */
int max; /* maximal number of positional arguments */
PyObject *kwtuple; /* tuple of keyword parameter names */
struct _PyArg_Parser *next;
} _PyArg_Parser;

View file

@ -27,7 +27,6 @@ extern "C" {
#include "pycore_unicodeobject.h" // struct _Py_unicode_runtime_state
struct _getargs_runtime_state {
PyThread_type_lock mutex;
struct _PyArg_Parser *static_parsers;
};

View file

@ -0,0 +1,2 @@
Add internal-only one-time initialization API: ``_PyOnceFlag`` and
``_PyOnceFlag_CallOnce``.

View file

@ -341,6 +341,37 @@ test_lock_benchmark(PyObject *module, PyObject *obj)
Py_RETURN_NONE;
}
// One-time initializer used by test_lock_once().
//
// `arg` points to an int call counter, which is incremented on every call.
// The first four calls report failure (-1); the fifth reports success (0).
// Being called more than five times trips the assertion.
static int
init_maybe_fail(void *arg)
{
    int *calls = arg;
    int attempt = ++(*calls);
    if (attempt >= 5) {
        assert(attempt == 5);
        return 0;
    }
    // failure
    return -1;
}
// Exercises _PyOnceFlag_CallOnce() with an initializer (init_maybe_fail)
// that fails on its first four calls and succeeds on the fifth.
//
// Expected: the first four attempts return -1 and are retried; from the fifth
// attempt on the call returns 0 and the initializer is not run again, so the
// call counter stays at 5.
static PyObject *
test_lock_once(PyObject *self, PyObject *obj)
{
    _PyOnceFlag flag = {0};
    int calls = 0;
    int attempt = 0;
    while (attempt < 10) {
        int rc = _PyOnceFlag_CallOnce(&flag, init_maybe_fail, &calls);
        if (attempt >= 4) {
            assert(rc == 0);
            assert(calls == 5);
        }
        else {
            assert(rc == -1);
        }
        attempt++;
    }
    Py_RETURN_NONE;
}
static PyMethodDef test_methods[] = {
{"test_lock_basic", test_lock_basic, METH_NOARGS},
{"test_lock_two_threads", test_lock_two_threads, METH_NOARGS},
@ -348,6 +379,7 @@ static PyMethodDef test_methods[] = {
{"test_lock_counter_slow", test_lock_counter_slow, METH_NOARGS},
_TESTINTERNALCAPI_BENCHMARK_LOCKS_METHODDEF
{"test_lock_benchmark", test_lock_benchmark, METH_NOARGS},
{"test_lock_once", test_lock_once, METH_NOARGS},
{NULL, NULL} /* sentinel */
};

View file

@ -518,7 +518,7 @@ def sumTrailer(self, name, add_label=False):
if add_label:
self.emit("failed:", 1)
self.emit("Py_XDECREF(tmp);", 1)
self.emit("return 1;", 1)
self.emit("return -1;", 1)
self.emit("}", 0)
self.emit("", 0)
@ -529,7 +529,7 @@ def simpleSum(self, sum, name):
"state->%s_type);")
self.emit(line % (t.name,), 1)
self.emit("if (isinstance == -1) {", 1)
self.emit("return 1;", 2)
self.emit("return -1;", 2)
self.emit("}", 1)
self.emit("if (isinstance) {", 1)
self.emit("*out = %s;" % t.name, 2)
@ -558,7 +558,7 @@ def complexSum(self, sum, name):
self.emit("tp = state->%s_type;" % (t.name,), 1)
self.emit("isinstance = PyObject_IsInstance(obj, tp);", 1)
self.emit("if (isinstance == -1) {", 1)
self.emit("return 1;", 2)
self.emit("return -1;", 2)
self.emit("}", 1)
self.emit("if (isinstance) {", 1)
for f in t.fields:
@ -605,7 +605,7 @@ def visitProduct(self, prod, name):
self.emit("return 0;", 1)
self.emit("failed:", 0)
self.emit("Py_XDECREF(tmp);", 1)
self.emit("return 1;", 1)
self.emit("return -1;", 1)
self.emit("}", 0)
self.emit("", 0)
@ -631,13 +631,13 @@ def visitField(self, field, name, sum=None, prod=None, depth=0):
ctype = get_c_type(field.type)
line = "if (PyObject_GetOptionalAttr(obj, state->%s, &tmp) < 0) {"
self.emit(line % field.name, depth)
self.emit("return 1;", depth+1)
self.emit("return -1;", depth+1)
self.emit("}", depth)
if field.seq:
self.emit("if (tmp == NULL) {", depth)
self.emit("tmp = PyList_New(0);", depth+1)
self.emit("if (tmp == NULL) {", depth+1)
self.emit("return 1;", depth+2)
self.emit("return -1;", depth+2)
self.emit("}", depth+1)
self.emit("}", depth)
self.emit("{", depth)
@ -647,7 +647,7 @@ def visitField(self, field, name, sum=None, prod=None, depth=0):
message = "required field \\\"%s\\\" missing from %s" % (field.name, name)
format = "PyErr_SetString(PyExc_TypeError, \"%s\");"
self.emit(format % message, depth+1, reflow=False)
self.emit("return 1;", depth+1)
self.emit("return -1;", depth+1)
else:
self.emit("if (tmp == NULL || tmp == Py_None) {", depth)
self.emit("Py_CLEAR(tmp);", depth+1)
@ -968,16 +968,16 @@ def visitModule(self, mod):
int i, result;
PyObject *s, *l = PyTuple_New(num_fields);
if (!l)
return 0;
return -1;
for (i = 0; i < num_fields; i++) {
s = PyUnicode_InternFromString(attrs[i]);
if (!s) {
Py_DECREF(l);
return 0;
return -1;
}
PyTuple_SET_ITEM(l, i, s);
}
result = PyObject_SetAttr(type, state->_attributes, l) >= 0;
result = PyObject_SetAttr(type, state->_attributes, l);
Py_DECREF(l);
return result;
}
@ -1052,7 +1052,7 @@ def visitModule(self, mod):
{
if (!PyUnicode_CheckExact(obj) && obj != Py_None) {
PyErr_SetString(PyExc_TypeError, "AST identifier must be of type str");
return 1;
return -1;
}
return obj2ast_object(state, obj, out, arena);
}
@ -1061,7 +1061,7 @@ def visitModule(self, mod):
{
if (!PyUnicode_CheckExact(obj) && !PyBytes_CheckExact(obj)) {
PyErr_SetString(PyExc_TypeError, "AST string must be of type str");
return 1;
return -1;
}
return obj2ast_object(state, obj, out, arena);
}
@ -1071,12 +1071,12 @@ def visitModule(self, mod):
int i;
if (!PyLong_Check(obj)) {
PyErr_Format(PyExc_ValueError, "invalid integer value: %R", obj);
return 1;
return -1;
}
i = PyLong_AsInt(obj);
if (i == -1 && PyErr_Occurred())
return 1;
return -1;
*out = i;
return 0;
}
@ -1102,22 +1102,15 @@ def visitModule(self, mod):
static int
init_types(struct ast_state *state)
{
// init_types() must not be called after _PyAST_Fini()
// has been called
assert(state->initialized >= 0);
if (state->initialized) {
return 1;
}
if (init_identifiers(state) < 0) {
return 0;
return -1;
}
state->AST_type = PyType_FromSpec(&AST_type_spec);
if (!state->AST_type) {
return 0;
return -1;
}
if (add_ast_fields(state) < 0) {
return 0;
return -1;
}
'''))
for dfn in mod.dfns:
@ -1125,8 +1118,7 @@ def visitModule(self, mod):
self.file.write(textwrap.dedent('''
state->recursion_depth = 0;
state->recursion_limit = 0;
state->initialized = 1;
return 1;
return 0;
}
'''))
@ -1138,12 +1130,12 @@ def visitProduct(self, prod, name):
self.emit('state->%s_type = make_type(state, "%s", state->AST_type, %s, %d,' %
(name, name, fields, len(prod.fields)), 1)
self.emit('%s);' % reflow_c_string(asdl_of(name, prod), 2), 2, reflow=False)
self.emit("if (!state->%s_type) return 0;" % name, 1)
self.emit("if (!state->%s_type) return -1;" % name, 1)
if prod.attributes:
self.emit("if (!add_attributes(state, state->%s_type, %s_attributes, %d)) return 0;" %
self.emit("if (add_attributes(state, state->%s_type, %s_attributes, %d) < 0) return -1;" %
(name, name, len(prod.attributes)), 1)
else:
self.emit("if (!add_attributes(state, state->%s_type, NULL, 0)) return 0;" % name, 1)
self.emit("if (add_attributes(state, state->%s_type, NULL, 0) < 0) return -1;" % name, 1)
self.emit_defaults(name, prod.fields, 1)
self.emit_defaults(name, prod.attributes, 1)
@ -1151,12 +1143,12 @@ def visitSum(self, sum, name):
self.emit('state->%s_type = make_type(state, "%s", state->AST_type, NULL, 0,' %
(name, name), 1)
self.emit('%s);' % reflow_c_string(asdl_of(name, sum), 2), 2, reflow=False)
self.emit("if (!state->%s_type) return 0;" % name, 1)
self.emit("if (!state->%s_type) return -1;" % name, 1)
if sum.attributes:
self.emit("if (!add_attributes(state, state->%s_type, %s_attributes, %d)) return 0;" %
self.emit("if (add_attributes(state, state->%s_type, %s_attributes, %d) < 0) return -1;" %
(name, name, len(sum.attributes)), 1)
else:
self.emit("if (!add_attributes(state, state->%s_type, NULL, 0)) return 0;" % name, 1)
self.emit("if (add_attributes(state, state->%s_type, NULL, 0) < 0) return -1;" % name, 1)
self.emit_defaults(name, sum.attributes, 1)
simple = is_simple(sum)
for t in sum.types:
@ -1170,20 +1162,20 @@ def visitConstructor(self, cons, name, simple):
self.emit('state->%s_type = make_type(state, "%s", state->%s_type, %s, %d,' %
(cons.name, cons.name, name, fields, len(cons.fields)), 1)
self.emit('%s);' % reflow_c_string(asdl_of(cons.name, cons), 2), 2, reflow=False)
self.emit("if (!state->%s_type) return 0;" % cons.name, 1)
self.emit("if (!state->%s_type) return -1;" % cons.name, 1)
self.emit_defaults(cons.name, cons.fields, 1)
if simple:
self.emit("state->%s_singleton = PyType_GenericNew((PyTypeObject *)"
"state->%s_type, NULL, NULL);" %
(cons.name, cons.name), 1)
self.emit("if (!state->%s_singleton) return 0;" % cons.name, 1)
self.emit("if (!state->%s_singleton) return -1;" % cons.name, 1)
def emit_defaults(self, name, fields, depth):
for field in fields:
if field.opt:
self.emit('if (PyObject_SetAttr(state->%s_type, state->%s, Py_None) == -1)' %
(name, field.name), depth)
self.emit("return 0;", depth+1)
self.emit("return -1;", depth+1)
class ASTModuleVisitor(PickleVisitor):
@ -1279,7 +1271,7 @@ def func_begin(self, name):
self.emit("if (++state->recursion_depth > state->recursion_limit) {", 1)
self.emit("PyErr_SetString(PyExc_RecursionError,", 2)
self.emit('"maximum recursion depth exceeded during ast construction");', 3)
self.emit("return 0;", 2)
self.emit("return NULL;", 2)
self.emit("}", 1)
def func_end(self):
@ -1400,7 +1392,7 @@ class PartingShots(StaticVisitor):
int COMPILER_STACK_FRAME_SCALE = 2;
PyThreadState *tstate = _PyThreadState_GET();
if (!tstate) {
return 0;
return NULL;
}
state->recursion_limit = Py_C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
int recursion_depth = Py_C_RECURSION_LIMIT - tstate->c_recursion_remaining;
@ -1414,7 +1406,7 @@ class PartingShots(StaticVisitor):
PyErr_Format(PyExc_SystemError,
"AST constructor recursion depth mismatch (before=%d, after=%d)",
starting_recursion_depth, state->recursion_depth);
return 0;
return NULL;
}
return result;
}
@ -1481,7 +1473,8 @@ def visit(self, object):
def generate_ast_state(module_state, f):
f.write('struct ast_state {\n')
f.write(' int initialized;\n')
f.write(' _PyOnceFlag once;\n')
f.write(' int finalized;\n')
f.write(' int recursion_depth;\n')
f.write(' int recursion_limit;\n')
for s in module_state:
@ -1501,11 +1494,8 @@ def generate_ast_fini(module_state, f):
f.write(textwrap.dedent("""
Py_CLEAR(_Py_INTERP_CACHED_OBJECT(interp, str_replace_inf));
#if !defined(NDEBUG)
state->initialized = -1;
#else
state->initialized = 0;
#endif
state->finalized = 1;
state->once = (_PyOnceFlag){0};
}
"""))
@ -1544,6 +1534,7 @@ def generate_module_def(mod, metadata, f, internal_h):
#include "pycore_ast.h"
#include "pycore_ast_state.h" // struct ast_state
#include "pycore_ceval.h" // _Py_EnterRecursiveCall
#include "pycore_lock.h" // _PyOnceFlag
#include "pycore_interp.h" // _PyInterpreterState.ast
#include "pycore_pystate.h" // _PyInterpreterState_GET()
#include <stddef.h>
@ -1556,7 +1547,8 @@ def generate_module_def(mod, metadata, f, internal_h):
{
PyInterpreterState *interp = _PyInterpreterState_GET();
struct ast_state *state = &interp->ast;
if (!init_types(state)) {
assert(!state->finalized);
if (_PyOnceFlag_CallOnce(&state->once, (_Py_once_fn_t *)&init_types, state) < 0) {
return NULL;
}
return state;
@ -1570,8 +1562,8 @@ def generate_module_def(mod, metadata, f, internal_h):
for identifier in state_strings:
f.write(' if ((state->' + identifier)
f.write(' = PyUnicode_InternFromString("')
f.write(identifier + '")) == NULL) return 0;\n')
f.write(' return 1;\n')
f.write(identifier + '")) == NULL) return -1;\n')
f.write(' return 0;\n')
f.write('};\n\n')
def write_header(mod, metadata, f):
@ -1629,6 +1621,9 @@ def write_internal_h_header(mod, f):
print(textwrap.dedent("""
#ifndef Py_INTERNAL_AST_STATE_H
#define Py_INTERNAL_AST_STATE_H
#include "pycore_lock.h" // _PyOnceFlag
#ifdef __cplusplus
extern "C" {
#endif

1735
Python/Python-ast.c generated

File diff suppressed because it is too large Load diff

View file

@ -1877,8 +1877,9 @@ new_kwtuple(const char * const *keywords, int total, int pos)
}
static int
_parser_init(struct _PyArg_Parser *parser)
_parser_init(void *arg)
{
struct _PyArg_Parser *parser = (struct _PyArg_Parser *)arg;
const char * const *keywords = parser->keywords;
assert(keywords != NULL);
assert(parser->pos == 0 &&
@ -1889,7 +1890,7 @@ _parser_init(struct _PyArg_Parser *parser)
int len, pos;
if (scan_keywords(keywords, &len, &pos) < 0) {
return 0;
return -1;
}
const char *fname, *custommsg = NULL;
@ -1898,7 +1899,7 @@ _parser_init(struct _PyArg_Parser *parser)
assert(parser->fname == NULL);
if (parse_format(parser->format, len, pos,
&fname, &custommsg, &min, &max) < 0) {
return 0;
return -1;
}
}
else {
@ -1911,7 +1912,7 @@ _parser_init(struct _PyArg_Parser *parser)
if (kwtuple == NULL) {
kwtuple = new_kwtuple(keywords, len, pos);
if (kwtuple == NULL) {
return 0;
return -1;
}
owned = 1;
}
@ -1925,40 +1926,27 @@ _parser_init(struct _PyArg_Parser *parser)
parser->min = min;
parser->max = max;
parser->kwtuple = kwtuple;
parser->initialized = owned ? 1 : -1;
parser->is_kwtuple_owned = owned;
assert(parser->next == NULL);
parser->next = _PyRuntime.getargs.static_parsers;
_PyRuntime.getargs.static_parsers = parser;
return 1;
// Push this parser onto the global list of static parsers with a lock-free
// compare-exchange loop: retry until the list head is swapped from the value
// cached in parser->next to `parser`. A successful exchange must END the
// loop. Looping on success would issue a second exchange that fails and
// stores the new head (`parser` itself) into parser->next, leaving the list
// self-referential.
parser->next = _Py_atomic_load_ptr(&_PyRuntime.getargs.static_parsers);
do {
    // compare-exchange updates parser->next on failure
} while (!_Py_atomic_compare_exchange_ptr(&_PyRuntime.getargs.static_parsers,
                                          &parser->next, parser));
return 0;
}
static int
parser_init(struct _PyArg_Parser *parser)
{
// volatile as it can be modified by other threads
// and should not be optimized or reordered by compiler
if (*((volatile int *)&parser->initialized)) {
assert(parser->kwtuple != NULL);
return 1;
}
PyThread_acquire_lock(_PyRuntime.getargs.mutex, WAIT_LOCK);
// Check again if another thread initialized the parser
// while we were waiting for the lock.
if (*((volatile int *)&parser->initialized)) {
assert(parser->kwtuple != NULL);
PyThread_release_lock(_PyRuntime.getargs.mutex);
return 1;
}
int ret = _parser_init(parser);
PyThread_release_lock(_PyRuntime.getargs.mutex);
return ret;
return _PyOnceFlag_CallOnce(&parser->once, &_parser_init, parser);
}
static void
parser_clear(struct _PyArg_Parser *parser)
{
if (parser->initialized == 1) {
if (parser->is_kwtuple_owned) {
Py_CLEAR(parser->kwtuple);
}
}
@ -2025,7 +2013,7 @@ vgetargskeywordsfast_impl(PyObject *const *args, Py_ssize_t nargs,
return 0;
}
if (!parser_init(parser)) {
if (parser_init(parser) < 0) {
return 0;
}
@ -2258,7 +2246,7 @@ _PyArg_UnpackKeywords(PyObject *const *args, Py_ssize_t nargs,
args = buf;
}
if (!parser_init(parser)) {
if (parser_init(parser) < 0) {
return NULL;
}
@ -2435,7 +2423,7 @@ _PyArg_UnpackKeywordsWithVararg(PyObject *const *args, Py_ssize_t nargs,
args = buf;
}
if (!parser_init(parser)) {
if (parser_init(parser) < 0) {
return NULL;
}

View file

@ -295,3 +295,61 @@ PyEvent_WaitTimed(PyEvent *evt, _PyTime_t timeout_ns)
return _Py_atomic_load_uint8(&evt->v) == _Py_LOCKED;
}
}
// Publishes the outcome of a one-time initializer to the once flag `o` and
// wakes any threads parked on it.
//
// res == 0  (success): the flag becomes _Py_ONCE_INITIALIZED.
// res == -1 (failure): the flag is reset to _Py_UNLOCKED so a later call
//                      can retry.
// Any other value is a fatal error. Returns `res` unchanged.
static int
unlock_once(_PyOnceFlag *o, int res)
{
    uint8_t next_state;
    if (res == 0) {
        next_state = _Py_ONCE_INITIALIZED;
    }
    else if (res == -1) {
        next_state = _Py_UNLOCKED;
    }
    else {
        Py_FatalError("invalid result from _PyOnceFlag_CallOnce");
        Py_UNREACHABLE();
    }

    // Swap in the new state; the previous value tells us whether any thread
    // registered itself as a waiter while the initializer was running.
    uint8_t prev_state = _Py_atomic_exchange_uint8(&o->v, next_state);
    if (prev_state & _Py_HAS_PARKED) {
        // wake up anyone waiting on the once flag
        _PyParkingLot_UnparkAll(&o->v);
    }
    return res;
}
// Slow path of _PyOnceFlag_CallOnce().
//
// If the flag is unclaimed, this thread claims it, runs `fn(arg)` and
// publishes the result through unlock_once(); the return value is then fn's
// result (0 or -1). If another thread is already running the initializer,
// this thread waits and re-examines the flag afterwards: it returns 0 once
// the flag is _Py_ONCE_INITIALIZED, or claims the flag and runs `fn` itself
// if the other thread's attempt failed and reset the flag to _Py_UNLOCKED.
int
_PyOnceFlag_CallOnceSlow(_PyOnceFlag *flag, _Py_once_fn_t *fn, void *arg)
{
uint8_t v = _Py_atomic_load_uint8(&flag->v);
for (;;) {
if (v == _Py_UNLOCKED) {
// Try to claim the flag. There is no explicit reload after a failed
// compare-exchange, so it is relied on to leave the current value in `v`
// before the loop re-dispatches on it.
if (!_Py_atomic_compare_exchange_uint8(&flag->v, &v, _Py_LOCKED)) {
continue;
}
// This thread owns the flag: run the initializer, then publish the
// outcome and wake any waiters.
int res = fn(arg);
return unlock_once(flag, res);
}
if (v == _Py_ONCE_INITIALIZED) {
// Another thread completed initialization in the meantime.
return 0;
}
// The once flag is initializing (locked).
assert((v & _Py_LOCKED));
if (!(v & _Py_HAS_PARKED)) {
// We are the first waiter. Set the _Py_HAS_PARKED flag so that
// unlock_once() knows it has to wake somebody up.
uint8_t new_value = v | _Py_HAS_PARKED;
if (!_Py_atomic_compare_exchange_uint8(&flag->v, &v, new_value)) {
// State changed underneath us; re-dispatch on the fresh value in `v`.
continue;
}
v = new_value;
}
// Wait for initialization to finish.
// NOTE(review): presumably this blocks only while flag->v still equals `v`,
// with -1 meaning "no timeout" -- confirm against the parking lot API.
_PyParkingLot_Park(&flag->v, &v, sizeof(v), -1, NULL, 1);
// Re-read the state after waking and go around again.
v = _Py_atomic_load_uint8(&flag->v);
}
}

View file

@ -379,12 +379,11 @@ _Py_COMP_DIAG_IGNORE_DEPR_DECLS
static const _PyRuntimeState initial = _PyRuntimeState_INIT(_PyRuntime);
_Py_COMP_DIAG_POP
#define NUMLOCKS 9
#define NUMLOCKS 8
#define LOCKS_INIT(runtime) \
{ \
&(runtime)->interpreters.mutex, \
&(runtime)->xi.registry.mutex, \
&(runtime)->getargs.mutex, \
&(runtime)->unicode_state.ids.lock, \
&(runtime)->imports.extensions.mutex, \
&(runtime)->ceval.pending_mainthread.lock, \