msvcrt: Increase module's reference count before returning from _beginthread[ex]().

Increasing DLL's reference count from the trampoline function makes it
prone to race conditions. The thread can start executing after we have
already returned from _beginthread[ex]() and the DLL might have been
freed.

Fixes rare crash on launch with Baldur's Gate 3.

Signed-off-by: Arkadiusz Hiler <ahiler@codeweavers.com>
Signed-off-by: Piotr Caban <piotr@codeweavers.com>
Signed-off-by: Alexandre Julliard <julliard@winehq.org>
This commit is contained in:
Arkadiusz Hiler 2021-11-08 19:38:33 +02:00 committed by Alexandre Julliard
parent 2bec77cb5f
commit 04bb9112d2

View file

@ -32,6 +32,9 @@ typedef struct {
_beginthreadex_start_routine_t start_address_ex;
};
void *arglist;
#if _MSVCR_VER >= 140
HMODULE module;
#endif
} _beginthread_trampoline_t;
/*********************************************************************
@ -113,16 +116,10 @@ static DWORD CALLBACK _beginthread_trampoline(LPVOID arg)
thread_data_t *data = msvcrt_get_thread_data();
memcpy(&local_trampoline,arg,sizeof(local_trampoline));
data->handle = local_trampoline.thread;
free(arg);
data->handle = local_trampoline.thread;
#if _MSVCR_VER >= 140
if (!GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
(void*)local_trampoline.start_address, &data->module))
{
data->module = NULL;
WARN("failed to get module for the start_address: %d\n", GetLastError());
}
data->module = local_trampoline.module;
#endif
local_trampoline.start_address(local_trampoline.arglist);
@ -162,7 +159,19 @@ uintptr_t CDECL _beginthread(
trampoline->start_address = start_address;
trampoline->arglist = arglist;
#if _MSVCR_VER >= 140
if(!GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
(void*)start_address, &trampoline->module))
{
trampoline->module = NULL;
WARN("failed to get module for the start_address: %d\n", GetLastError());
}
#endif
if(ResumeThread(thread) == -1) {
#if _MSVCR_VER >= 140
FreeLibrary(trampoline->module);
#endif
free(trampoline);
*_errno() = EAGAIN;
return -1;
@ -181,19 +190,10 @@ static DWORD CALLBACK _beginthreadex_trampoline(LPVOID arg)
thread_data_t *data = msvcrt_get_thread_data();
memcpy(&local_trampoline, arg, sizeof(local_trampoline));
data->handle = local_trampoline.thread;
free(arg);
data->handle = local_trampoline.thread;
#if _MSVCR_VER >= 140
{
thread_data_t *data = msvcrt_get_thread_data();
if (!GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
(void*)local_trampoline.start_address_ex, &data->module))
{
data->module = NULL;
WARN("failed to get module for the start_address: %d\n", GetLastError());
}
}
data->module = local_trampoline.module;
#endif
retval = local_trampoline.start_address_ex(local_trampoline.arglist);
@ -225,9 +225,21 @@ uintptr_t CDECL _beginthreadex(
trampoline->start_address_ex = start_address;
trampoline->arglist = arglist;
#if _MSVCR_VER >= 140
if(!GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
(void*)start_address, &trampoline->module))
{
trampoline->module = NULL;
WARN("failed to get module for the start_address: %d\n", GetLastError());
}
#endif
thread = CreateThread(security, stack_size, _beginthreadex_trampoline,
trampoline, initflag, thrdaddr);
if(!thread) {
#if _MSVCR_VER >= 140
FreeLibrary(trampoline->module);
#endif
free(trampoline);
msvcrt_set_errno(GetLastError());
return 0;