secur32: Allocate buffer for either ISC_REQ_ALLOCATE_MEMORY or NULL output in schan_InitializeSecurityContextW().

This commit is contained in:
Paul Gofman 2022-11-07 19:55:30 -06:00 committed by Alexandre Julliard
parent 86b3fafe82
commit ac5968790a
2 changed files with 12 additions and 13 deletions

View file

@ -945,13 +945,6 @@ static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW(
ctx->req_ctx_attr = fContextReq;
/* Perform the TLS handshake */
if (fContextReq & ISC_REQ_ALLOCATE_MEMORY)
{
alloc_buffer.cbBuffer = extra_size;
alloc_buffer.BufferType = SECBUFFER_TOKEN;
alloc_buffer.pvBuffer = RtlAllocateHeap( GetProcessHeap(), 0, extra_size );
}
memset(&input_desc, 0, sizeof(input_desc));
if (pInput && (idx = schan_find_sec_buffer_idx(pInput, 0, SECBUFFER_TOKEN)) != -1)
{
@ -967,8 +960,13 @@ static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW(
{
output_desc.cBuffers = 1;
output_desc.pBuffers = &pOutput->pBuffers[idx];
if (!output_desc.pBuffers->pvBuffer)
if (!output_desc.pBuffers->pvBuffer || (fContextReq & ISC_REQ_ALLOCATE_MEMORY))
{
alloc_buffer.cbBuffer = extra_size;
alloc_buffer.BufferType = SECBUFFER_TOKEN;
alloc_buffer.pvBuffer = RtlAllocateHeap( GetProcessHeap(), 0, extra_size );
output_desc.pBuffers = &alloc_buffer;
}
}
params.session = ctx->session;

View file

@ -695,11 +695,12 @@ static void test_context_output_buffer_size(DWORD protocol, DWORD flags, ULONG c
0, 0, &in_buffers, 0, &context, &out_buffers, &attrs, NULL);
ok(status == SEC_E_INSUFFICIENT_MEMORY, "%d: Expected SEC_E_INSUFFICIENT_MEMORY, got %08lx\n", i, status);
if (i) init_sec_buffer(&out_buffers.pBuffers[0], buf_size, NULL);
if (i) init_sec_buffer(&out_buffers.pBuffers[0], buf_size, (void *)0xdeadbeef);
init_sec_buffer(buffer, 0, NULL);
status = InitializeSecurityContextA(&cred_handle, NULL, (SEC_CHAR *)"localhost",
ctxt_flags_req | ISC_REQ_ALLOCATE_MEMORY, 0, 0, &in_buffers, 0, &context, &out_buffers, &attrs, NULL);
ok(status == SEC_I_CONTINUE_NEEDED, "%d: Expected SEC_I_CONTINUE_NEEDED, got %08lx\n", i, status);
ok(out_buffers.pBuffers[0].pvBuffer != (void *)0xdeadbeef, "got %p.\n", out_buffers.pBuffers[0].pvBuffer);
if (i) FreeContextBuffer(out_buffers.pBuffers[0].pvBuffer);
FreeContextBuffer(buffer->pvBuffer);
DeleteSecurityContext(&context);
@ -1811,7 +1812,7 @@ static void test_connection_shutdown(void)
ISC_REQ_CONFIDENTIALITY | ISC_REQ_STREAM,
0, 0, &buffers[0], 0, &context, &buffers[1], &attrs, NULL );
ok( status == SEC_I_CONTINUE_NEEDED, "Expected SEC_I_CONTINUE_NEEDED, got %08lx\n", status );
todo_wine ok( !!buffers[1].pBuffers[0].pvBuffer, "Got NULL buffer.\n" );
ok( !!buffers[1].pBuffers[0].pvBuffer, "Got NULL buffer.\n" );
FreeContextBuffer( buffers[1].pBuffers[0].pvBuffer );
buffers[1].pBuffers[0].pvBuffer = tmp;
@ -1869,9 +1870,9 @@ static void test_connection_shutdown(void)
status = InitializeSecurityContextA( &cred_handle, &context, NULL, 0, 0, 0, &buffers[1], 0,
&context2, &buffers[0], &attrs, NULL );
todo_wine ok( status == SEC_E_OK, "got %08lx.\n", status );
todo_wine ok( context.dwLower == context2.dwLower, "dwLower mismatch, expected %#Ix, got %#Ix\n",
ok( context.dwLower == context2.dwLower, "dwLower mismatch, expected %#Ix, got %#Ix\n",
context.dwLower, context2.dwLower );
todo_wine ok( context.dwUpper == context2.dwUpper, "dwUpper mismatch, expected %#Ix, got %#Ix\n",
ok( context.dwUpper == context2.dwUpper, "dwUpper mismatch, expected %#Ix, got %#Ix\n",
context.dwUpper, context2.dwUpper );
todo_wine ok( buf->cbBuffer == sizeof(message), "got cbBuffer %#lx.\n", buf->cbBuffer );
todo_wine ok( !memcmp( buf->pvBuffer, message, sizeof(message) ), "message data mismatch.\n" );
@ -1881,7 +1882,7 @@ static void test_connection_shutdown(void)
context2.dwLower = context2.dwUpper = 0xdeadbeef;
status = InitializeSecurityContextA( &cred_handle, &context, NULL, 0, 0, 0, NULL, 0,
&context2, &buffers[1], &attrs, NULL );
todo_wine ok( status == SEC_E_INCOMPLETE_MESSAGE, "got %08lx.\n", status );
ok( status == SEC_E_INCOMPLETE_MESSAGE, "got %08lx.\n", status );
ok( buf->cbBuffer == 1000, "got cbBuffer %#lx.\n", buf->cbBuffer );
ok( context2.dwLower == 0xdeadbeef, "dwLower mismatch, got %#Ix\n", context2.dwLower );
ok( context2.dwUpper == 0xdeadbeef, "dwUpper mismatch, got %#Ix\n", context2.dwUpper );