secur32: Add test for schannel AcceptSecurityContext.

This commit is contained in:
Evan Tang 2022-12-12 16:39:45 -06:00 committed by Alexandre Julliard
parent 610fd134b7
commit d423d76f10

View file

@ -1671,6 +1671,146 @@ static void test_application_protocol_negotiation(void)
closesocket(sock);
}
static void test_server_protocol_negotiation(void) {
BOOL ret;
SECURITY_STATUS status;
ULONG attrs;
SCHANNEL_CRED client_cred, server_cred;
CredHandle client_cred_handle, server_cred_handle;
CtxtHandle client_context, server_context, client_context2, server_context2;
SecPkgContext_ApplicationProtocol protocol;
SecBufferDesc buffers[3];
PCCERT_CONTEXT cert;
HCRYPTPROV csp;
HCRYPTKEY key;
CRYPT_KEY_PROV_INFO keyProvInfo;
WCHAR ms_def_prov_w[MAX_PATH];
unsigned buf_size = 8192;
unsigned char *alpn_buffer;
unsigned int *extension_len;
unsigned short *list_len;
int list_start_index, offset = 0;
if (!pQueryContextAttributesA)
{
win_skip("Required secur32 functions not available\n");
return;
}
lstrcpyW(ms_def_prov_w, MS_DEF_PROV_W);
keyProvInfo.pwszContainerName = cspNameW;
keyProvInfo.pwszProvName = ms_def_prov_w;
keyProvInfo.dwProvType = PROV_RSA_FULL;
keyProvInfo.dwFlags = 0;
keyProvInfo.cProvParam = 0;
keyProvInfo.rgProvParam = NULL;
keyProvInfo.dwKeySpec = AT_SIGNATURE;
cert = CertCreateCertificateContext(X509_ASN_ENCODING, selfSignedCert, sizeof(selfSignedCert));
ret = CertSetCertificateContextProperty(cert, CERT_KEY_PROV_INFO_PROP_ID, 0, &keyProvInfo);
ok(ret, "CertSetCertificateContextProperty failed: %08lx", GetLastError());
ret = CryptAcquireContextW(&csp, cspNameW, MS_DEF_PROV_W, PROV_RSA_FULL, CRYPT_NEWKEYSET);
ok(ret, "CryptAcquireContextW failed: %08lx\n", GetLastError());
ret = CryptImportKey(csp, privKey, sizeof(privKey), 0, 0, &key);
ok(ret, "CryptImportKey failed: %08lx\n", GetLastError());
if (!ret) return;
init_cred(&client_cred);
init_cred(&server_cred);
client_cred.grbitEnabledProtocols = SP_PROT_TLS1_CLIENT;
client_cred.dwFlags = SCH_CRED_NO_DEFAULT_CREDS|SCH_CRED_MANUAL_CRED_VALIDATION;
server_cred.grbitEnabledProtocols = SP_PROT_TLS1_SERVER;
server_cred.dwFlags = SCH_CRED_NO_DEFAULT_CREDS|SCH_CRED_MANUAL_CRED_VALIDATION;
server_cred.cCreds = 1;
server_cred.paCred = &cert;
status = AcquireCredentialsHandleA(NULL, (SEC_CHAR *)UNISP_NAME_A, SECPKG_CRED_OUTBOUND, NULL, &client_cred, NULL, NULL, &client_cred_handle, NULL);
ok(status == SEC_E_OK, "got %08lx\n", status);
if (status != SEC_E_OK) return;
status = AcquireCredentialsHandleA(NULL, (SEC_CHAR *)UNISP_NAME_A, SECPKG_CRED_INBOUND, NULL, &server_cred, NULL, NULL, &server_cred_handle, NULL);
ok(status == SEC_E_OK, "got %08lx\n", status);
if (status != SEC_E_OK) return;
init_buffers(&buffers[0], 4, buf_size);
init_buffers(&buffers[1], 4, buf_size);
init_buffers(&buffers[2], 1, 128);
alpn_buffer = buffers[2].pBuffers[0].pvBuffer;
extension_len = (unsigned int *)&alpn_buffer[offset];
offset += sizeof(*extension_len);
*(unsigned int *)&alpn_buffer[offset] = SecApplicationProtocolNegotiationExt_ALPN;
offset += sizeof(unsigned int);
list_len = (unsigned short *)&alpn_buffer[offset];
offset += sizeof(*list_len);
list_start_index = offset;
alpn_buffer[offset++] = sizeof("http/1.1") - 1;
memcpy(&alpn_buffer[offset], "http/1.1", sizeof("http/1.1") - 1);
offset += sizeof("http/1.1") - 1;
alpn_buffer[offset++] = sizeof("h2") - 1;
memcpy(&alpn_buffer[offset], "h2", sizeof("h2") - 1);
offset += sizeof("h2") - 1;
*list_len = offset - list_start_index;
*extension_len = *list_len + sizeof(*extension_len) + sizeof(*list_len);
buffers[2].pBuffers[0].BufferType = SECBUFFER_APPLICATION_PROTOCOLS;
buffers[2].pBuffers[0].cbBuffer = offset;
buffers[0].pBuffers[0].BufferType = SECBUFFER_TOKEN;
status = InitializeSecurityContextA(&client_cred_handle, NULL, (SEC_CHAR *)"localhost", ISC_REQ_CONFIDENTIALITY|ISC_REQ_STREAM, 0, 0, &buffers[2], 0, &client_context, &buffers[0], &attrs, NULL);
ok(status == SEC_I_CONTINUE_NEEDED, "got %08lx\n", status);
buffers[1].pBuffers[0].cbBuffer = buf_size;
buffers[1].pBuffers[0].BufferType = SECBUFFER_TOKEN;
buffers[0].pBuffers[1] = buffers[2].pBuffers[0];
status = AcceptSecurityContext(&server_cred_handle, NULL, &buffers[0], ASC_REQ_CONFIDENTIALITY|ASC_REQ_STREAM, 0, &server_context, &buffers[1], &attrs, NULL);
ok(status == SEC_I_CONTINUE_NEEDED, "got %08lx\n", status);
memset(&buffers[0].pBuffers[1], 0, sizeof(buffers[0].pBuffers[1]));
client_context2.dwLower = client_context2.dwUpper = 0xdeadbeef;
buffers[0].pBuffers[0].cbBuffer = buf_size;
status = InitializeSecurityContextA(&client_cred_handle, &client_context, (SEC_CHAR *)"localhost", ISC_REQ_CONFIDENTIALITY|ISC_REQ_STREAM|ISC_REQ_USE_SUPPLIED_CREDS, 0, 0, &buffers[1], 0, &client_context2, &buffers[0], &attrs, NULL);
ok(client_context.dwLower == client_context2.dwLower, "dwLower mismatch, expected %#Ix, got %#Ix\n", client_context.dwLower, client_context2.dwLower);
ok(client_context.dwUpper == client_context2.dwUpper, "dwUpper mismatch, expected %#Ix, got %#Ix\n", client_context.dwUpper, client_context2.dwUpper);
ok(status == SEC_I_CONTINUE_NEEDED, "got %08lx\n", status);
server_context2.dwLower = server_context2.dwUpper = 0xdeadbeef;
buffers[1].pBuffers[0].cbBuffer = buf_size;
status = AcceptSecurityContext(&server_cred_handle, &server_context, &buffers[0], ASC_REQ_CONFIDENTIALITY|ASC_REQ_STREAM, 0, &server_context2, &buffers[1], &attrs, NULL);
ok(server_context.dwLower == server_context2.dwLower, "dwLower mismatch, expected %#Ix, got %#Ix\n", server_context.dwLower, server_context2.dwLower);
ok(server_context.dwUpper == server_context2.dwUpper, "dwUpper mismatch, expected %#Ix, got %#Ix\n", server_context.dwUpper, server_context2.dwUpper);
ok(status == SEC_E_OK, "got %08lx\n", status);
buffers[0].pBuffers[0].cbBuffer = buf_size;
status = InitializeSecurityContextA(&client_cred_handle, &client_context, (SEC_CHAR *)"localhost", ISC_REQ_USE_SUPPLIED_CREDS, 0, 0, &buffers[1], 0, NULL, &buffers[0], &attrs, NULL);
ok(status == SEC_E_OK, "got %08lx\n", status);
memset(&protocol, 0, sizeof(protocol));
status = pQueryContextAttributesA(&client_context, SECPKG_ATTR_APPLICATION_PROTOCOL, &protocol);
ok(status == SEC_E_OK || broken(status == SEC_E_UNSUPPORTED_FUNCTION) /* win2k8 */, "got %08lx\n", status);
if (status == SEC_E_OK)
{
ok(protocol.ProtoNegoStatus == SecApplicationProtocolNegotiationStatus_Success, "got %u\n", protocol.ProtoNegoStatus);
ok(protocol.ProtoNegoExt == SecApplicationProtocolNegotiationExt_ALPN, "got %u\n", protocol.ProtoNegoExt);
ok(protocol.ProtocolIdSize == 8, "got %u\n", protocol.ProtocolIdSize);
ok(!memcmp(protocol.ProtocolId, "http/1.1", 8), "wrong protocol id\n");
}
DeleteSecurityContext(&client_context);
DeleteSecurityContext(&server_context);
FreeCredentialsHandle(&client_cred_handle);
FreeCredentialsHandle(&server_cred_handle);
free_buffers(&buffers[0]);
free_buffers(&buffers[1]);
free_buffers(&buffers[2]);
CryptDestroyKey(key);
CryptReleaseContext(csp, 0);
CryptAcquireContextW(&csp, cspNameW, MS_DEF_PROV_W, PROV_RSA_FULL, CRYPT_DELETEKEYSET);
CertFreeCertificateContext(cert);
}
static void init_dtls_output_buffer(SecBufferDesc *buffer)
{
buffer->pBuffers[0].BufferType = SECBUFFER_TOKEN;
@ -1949,6 +2089,7 @@ START_TEST(schannel)
test_InitializeSecurityContext();
test_communication();
test_application_protocol_negotiation();
test_server_protocol_negotiation();
test_dtls();
test_connection_shutdown();
}