diff --git a/dlls/secur32/schannel.c b/dlls/secur32/schannel.c index 4699f79ac1f..76b35c2419e 100644 --- a/dlls/secur32/schannel.c +++ b/dlls/secur32/schannel.c @@ -945,13 +945,6 @@ static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW( ctx->req_ctx_attr = fContextReq; /* Perform the TLS handshake */ - if (fContextReq & ISC_REQ_ALLOCATE_MEMORY) - { - alloc_buffer.cbBuffer = extra_size; - alloc_buffer.BufferType = SECBUFFER_TOKEN; - alloc_buffer.pvBuffer = RtlAllocateHeap( GetProcessHeap(), 0, extra_size ); - } - memset(&input_desc, 0, sizeof(input_desc)); if (pInput && (idx = schan_find_sec_buffer_idx(pInput, 0, SECBUFFER_TOKEN)) != -1) { @@ -967,8 +960,13 @@ static SECURITY_STATUS SEC_ENTRY schan_InitializeSecurityContextW( { output_desc.cBuffers = 1; output_desc.pBuffers = &pOutput->pBuffers[idx]; - if (!output_desc.pBuffers->pvBuffer) + if (!output_desc.pBuffers->pvBuffer || (fContextReq & ISC_REQ_ALLOCATE_MEMORY)) + { + alloc_buffer.cbBuffer = extra_size; + alloc_buffer.BufferType = SECBUFFER_TOKEN; + alloc_buffer.pvBuffer = RtlAllocateHeap( GetProcessHeap(), 0, extra_size ); output_desc.pBuffers = &alloc_buffer; + } } params.session = ctx->session; diff --git a/dlls/secur32/tests/schannel.c b/dlls/secur32/tests/schannel.c index 9d911ac9603..c22cc0ab3c8 100644 --- a/dlls/secur32/tests/schannel.c +++ b/dlls/secur32/tests/schannel.c @@ -695,11 +695,12 @@ static void test_context_output_buffer_size(DWORD protocol, DWORD flags, ULONG c 0, 0, &in_buffers, 0, &context, &out_buffers, &attrs, NULL); ok(status == SEC_E_INSUFFICIENT_MEMORY, "%d: Expected SEC_E_INSUFFICIENT_MEMORY, got %08lx\n", i, status); - if (i) init_sec_buffer(&out_buffers.pBuffers[0], buf_size, NULL); + if (i) init_sec_buffer(&out_buffers.pBuffers[0], buf_size, (void *)0xdeadbeef); init_sec_buffer(buffer, 0, NULL); status = InitializeSecurityContextA(&cred_handle, NULL, (SEC_CHAR *)"localhost", ctxt_flags_req | ISC_REQ_ALLOCATE_MEMORY, 0, 0, &in_buffers, 0, &context, &out_buffers, &attrs, NULL); ok(status == SEC_I_CONTINUE_NEEDED, "%d: Expected SEC_I_CONTINUE_NEEDED, got %08lx\n", i, status); + ok(out_buffers.pBuffers[0].pvBuffer != (void *)0xdeadbeef, "got %p.\n", out_buffers.pBuffers[0].pvBuffer); if (i) FreeContextBuffer(out_buffers.pBuffers[0].pvBuffer); FreeContextBuffer(buffer->pvBuffer); DeleteSecurityContext(&context); @@ -1811,7 +1812,7 @@ static void test_connection_shutdown(void) ISC_REQ_CONFIDENTIALITY | ISC_REQ_STREAM, 0, 0, &buffers[0], 0, &context, &buffers[1], &attrs, NULL ); ok( status == SEC_I_CONTINUE_NEEDED, "Expected SEC_I_CONTINUE_NEEDED, got %08lx\n", status ); - todo_wine ok( !!buffers[1].pBuffers[0].pvBuffer, "Got NULL buffer.\n" ); + ok( !!buffers[1].pBuffers[0].pvBuffer, "Got NULL buffer.\n" ); FreeContextBuffer( buffers[1].pBuffers[0].pvBuffer ); buffers[1].pBuffers[0].pvBuffer = tmp; @@ -1869,9 +1870,9 @@ static void test_connection_shutdown(void) status = InitializeSecurityContextA( &cred_handle, &context, NULL, 0, 0, 0, &buffers[1], 0, &context2, &buffers[0], &attrs, NULL ); todo_wine ok( status == SEC_E_OK, "got %08lx.\n", status ); - todo_wine ok( context.dwLower == context2.dwLower, "dwLower mismatch, expected %#Ix, got %#Ix\n", + ok( context.dwLower == context2.dwLower, "dwLower mismatch, expected %#Ix, got %#Ix\n", context.dwLower, context2.dwLower ); - todo_wine ok( context.dwUpper == context2.dwUpper, "dwUpper mismatch, expected %#Ix, got %#Ix\n", + ok( context.dwUpper == context2.dwUpper, "dwUpper mismatch, expected %#Ix, got %#Ix\n", context.dwUpper, context2.dwUpper ); todo_wine ok( buf->cbBuffer == sizeof(message), "got cbBuffer %#lx.\n", buf->cbBuffer ); todo_wine ok( !memcmp( buf->pvBuffer, message, sizeof(message) ), "message data mismatch.\n" ); @@ -1881,7 +1882,7 @@ static void test_connection_shutdown(void) context2.dwLower = context2.dwUpper = 0xdeadbeef; status = InitializeSecurityContextA( &cred_handle, &context, NULL, 0, 0, 0, NULL, 0, &context2, &buffers[1], &attrs, NULL ); - todo_wine ok( status == SEC_E_INCOMPLETE_MESSAGE, "got %08lx.\n", status ); + ok( status == SEC_E_INCOMPLETE_MESSAGE, "got %08lx.\n", status ); ok( buf->cbBuffer == 1000, "got cbBuffer %#lx.\n", buf->cbBuffer ); ok( context2.dwLower == 0xdeadbeef, "dwLower mismatch, got %#Ix\n", context2.dwLower ); ok( context2.dwUpper == 0xdeadbeef, "dwUpper mismatch, got %#Ix\n", context2.dwUpper );