winhttp: Limit recursion for synchronous callback calls.

Fixes a regression in Hitman 2, Death Stranding introduced
by commit be5acd1c07.

Signed-off-by: Paul Gofman <pgofman@codeweavers.com>
Signed-off-by: Hans Leidekker <hans@codeweavers.com>
Signed-off-by: Alexandre Julliard <julliard@winehq.org>
This commit is contained in:
Paul Gofman 2021-09-20 19:17:20 +03:00 committed by Alexandre Julliard
parent 5d5f11c002
commit 34fea20cd3
4 changed files with 167 additions and 5 deletions

View file

@ -2825,6 +2825,11 @@ static DWORD query_data_ready( struct request *request )
return count;
}
static BOOL skip_async_queue( struct request *request )
{
return request->hdr.recursion_count < 3 && (end_of_read_data( request ) || query_data_ready( request ));
}
static DWORD query_data_available( struct request *request, DWORD *available, BOOL async )
{
DWORD ret = ERROR_SUCCESS, count = 0;
@ -2889,8 +2894,7 @@ BOOL WINAPI WinHttpQueryDataAvailable( HINTERNET hrequest, LPDWORD available )
return FALSE;
}
if ((async = request->connect->hdr.flags & WINHTTP_FLAG_ASYNC) && !end_of_read_data( request )
&& !query_data_ready( request ))
if ((async = request->connect->hdr.flags & WINHTTP_FLAG_ASYNC) && !skip_async_queue( request ))
{
struct query_data *q;
@ -2947,8 +2951,7 @@ BOOL WINAPI WinHttpReadData( HINTERNET hrequest, LPVOID buffer, DWORD to_read, L
return FALSE;
}
if ((async = request->connect->hdr.flags & WINHTTP_FLAG_ASYNC) && !end_of_read_data( request )
&& !query_data_ready( request ))
if ((async = request->connect->hdr.flags & WINHTTP_FLAG_ASYNC) && !skip_async_queue( request ))
{
struct read_data *r;

View file

@ -48,8 +48,10 @@ void send_callback( struct object_header *hdr, DWORD status, void *info, DWORD b
{
if (hdr->callback && (hdr->notify_mask & status))
{
TRACE("%p, 0x%08x, %p, %u\n", hdr, status, info, buflen);
TRACE("%p, 0x%08x, %p, %u, %u\n", hdr, status, info, buflen, hdr->recursion_count);
InterlockedIncrement( &hdr->recursion_count );
hdr->callback( hdr->handle, hdr->context, status, info, buflen );
InterlockedDecrement( &hdr->recursion_count );
TRACE("returning from 0x%08x callback\n", status);
}
}

View file

@ -1212,6 +1212,161 @@ static void test_persistent_connection(int port)
CloseHandle( info.wait );
}
struct test_recursion_context
{
HANDLE request;
HANDLE wait;
LONG recursion_count, max_recursion_query, max_recursion_read;
BOOL read_from_callback;
BOOL have_sync_callback;
};
/* The limit is 128 before Win7 and 3 on newer Windows. */
#define TEST_RECURSION_LIMIT 128
static void CALLBACK test_recursion_callback( HINTERNET handle, DWORD_PTR context_ptr,
DWORD status, void *buffer, DWORD buflen )
{
struct test_recursion_context *context = (struct test_recursion_context *)context_ptr;
DWORD err;
BOOL ret;
BYTE b;
switch (status)
{
case WINHTTP_CALLBACK_STATUS_SENDREQUEST_COMPLETE:
case WINHTTP_CALLBACK_STATUS_HEADERS_AVAILABLE:
SetEvent( context->wait );
break;
case WINHTTP_CALLBACK_STATUS_DATA_AVAILABLE:
if (!context->read_from_callback)
{
SetEvent( context->wait );
break;
}
if (!*(DWORD *)buffer)
{
SetEvent( context->wait );
break;
}
ok(context->recursion_count < TEST_RECURSION_LIMIT,
"Got unexpected context->recursion_count %u, thread %#x.\n",
context->recursion_count, GetCurrentThreadId());
context->max_recursion_query = max( context->max_recursion_query, context->recursion_count );
InterlockedIncrement( &context->recursion_count );
ret = WinHttpReadData( context->request, &b, 1, NULL );
err = GetLastError();
ok(ret, "Failed to read data, GetLastError() %u.\n", err);
ok(err == ERROR_SUCCESS || err == ERROR_IO_PENDING, "Got unexpected err %u.\n", err);
if (err == ERROR_SUCCESS)
context->have_sync_callback = TRUE;
InterlockedDecrement( &context->recursion_count );
break;
case WINHTTP_CALLBACK_STATUS_READ_COMPLETE:
if (!buflen)
{
SetEvent( context->wait );
break;
}
ok(context->recursion_count < TEST_RECURSION_LIMIT,
"Got unexpected context->recursion_count %u, thread %#x.\n",
context->recursion_count, GetCurrentThreadId());
context->max_recursion_read = max( context->max_recursion_read, context->recursion_count );
context->read_from_callback = TRUE;
InterlockedIncrement( &context->recursion_count );
ret = WinHttpQueryDataAvailable( context->request, NULL );
err = GetLastError();
ok(ret, "Failed to query data available, GetLastError() %u.\n", err);
ok(err == ERROR_SUCCESS || err == ERROR_IO_PENDING, "Got unexpected err %u.\n", err);
if (err == ERROR_SUCCESS)
context->have_sync_callback = TRUE;
InterlockedDecrement( &context->recursion_count );
break;
}
}
static void test_recursion(void)
{
struct test_recursion_context context;
HANDLE session, connection, request;
DWORD size, status, err;
BOOL ret;
BYTE b;
memset( &context, 0, sizeof(context) );
context.wait = CreateEventW( NULL, FALSE, FALSE, NULL );
session = WinHttpOpen( L"winetest", 0, NULL, NULL, WINHTTP_FLAG_ASYNC );
ok(!!session, "Failed to open session, GetLastError() %u.\n", GetLastError());
WinHttpSetStatusCallback( session, test_recursion_callback, WINHTTP_CALLBACK_FLAG_ALL_NOTIFICATIONS, 0 );
connection = WinHttpConnect( session, L"test.winehq.org", 0, 0 );
ok(!!connection, "Failed to open a connection, GetLastError() %u.\n", GetLastError());
request = WinHttpOpenRequest( connection, NULL, L"/tests/hello.html", NULL, NULL, NULL, 0 );
ok(!!request, "Failed to open a request, GetLastError() %u.\n", GetLastError());
context.request = request;
ret = WinHttpSendRequest( request, NULL, 0, NULL, 0, 0, (DWORD_PTR)&context );
err = GetLastError();
if (!ret && (err == ERROR_WINHTTP_CANNOT_CONNECT || err == ERROR_WINHTTP_TIMEOUT))
{
skip("Connection failed, skipping\n");
WinHttpSetStatusCallback( session, NULL, WINHTTP_CALLBACK_FLAG_ALL_NOTIFICATIONS, 0 );
WinHttpCloseHandle( request );
WinHttpCloseHandle( connection );
WinHttpCloseHandle( session );
CloseHandle( context.wait );
return;
}
ok(ret, "Failed to send request, GetLastError() %u.\n", GetLastError());
WaitForSingleObject( context.wait, INFINITE );
ret = WinHttpReceiveResponse( request, NULL );
ok(ret, "Failed to receive response, GetLastError() %u.\n", GetLastError());
WaitForSingleObject( context.wait, INFINITE );
size = sizeof(status);
ret = WinHttpQueryHeaders( request, WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER, NULL,
&status, &size, NULL );
ok(ret, "Request failed, GetLastError() %u.\n", GetLastError());
ok(status == 200, "Request failed unexpectedly, status %u.\n", status);
ret = WinHttpQueryDataAvailable( request, NULL );
ok(ret, "Failed to query data available, GetLastError() %u.\n", GetLastError());
WaitForSingleObject( context.wait, INFINITE );
ret = WinHttpReadData( request, &b, 1, NULL );
ok(ret, "Failed to read data, GetLastError() %u.\n", GetLastError());
WaitForSingleObject( context.wait, INFINITE );
if (context.have_sync_callback)
{
ok(context.max_recursion_query >= 2, "Got unexpected max_recursion_query %u.\n", context.max_recursion_query);
ok(context.max_recursion_read >= 2, "Got unexpected max_recursion_read %u.\n", context.max_recursion_read);
}
else
{
skip("No sync callbacks.\n");
}
WinHttpSetStatusCallback( session, NULL, WINHTTP_CALLBACK_FLAG_ALL_NOTIFICATIONS, 0 );
WinHttpCloseHandle( request );
WinHttpCloseHandle( connection );
WinHttpCloseHandle( session );
CloseHandle( context.wait );
}
START_TEST (notification)
{
HMODULE mod = GetModuleHandleA( "winhttp.dll" );
@ -1230,6 +1385,7 @@ START_TEST (notification)
test_redirect();
test_async();
test_websocket();
test_recursion();
si.event = CreateEventW( NULL, 0, 0, NULL );
si.port = 7533;

View file

@ -49,6 +49,7 @@ struct object_header
LONG refs;
WINHTTP_STATUS_CALLBACK callback;
DWORD notify_mask;
LONG recursion_count;
struct list entry;
};