diff --git a/dlls/secur32/tests/schannel.c b/dlls/secur32/tests/schannel.c index 455fcb97aa7..f0b0b20b08a 100644 --- a/dlls/secur32/tests/schannel.c +++ b/dlls/secur32/tests/schannel.c @@ -1671,6 +1671,146 @@ static void test_application_protocol_negotiation(void) closesocket(sock); } +static void test_server_protocol_negotiation(void) { + BOOL ret; + SECURITY_STATUS status; + ULONG attrs; + SCHANNEL_CRED client_cred, server_cred; + CredHandle client_cred_handle, server_cred_handle; + CtxtHandle client_context, server_context, client_context2, server_context2; + SecPkgContext_ApplicationProtocol protocol; + SecBufferDesc buffers[3]; + PCCERT_CONTEXT cert; + HCRYPTPROV csp; + HCRYPTKEY key; + CRYPT_KEY_PROV_INFO keyProvInfo; + WCHAR ms_def_prov_w[MAX_PATH]; + unsigned buf_size = 8192; + unsigned char *alpn_buffer; + unsigned int *extension_len; + unsigned short *list_len; + int list_start_index, offset = 0; + + if (!pQueryContextAttributesA) + { + win_skip("Required secur32 functions not available\n"); + return; + } + + lstrcpyW(ms_def_prov_w, MS_DEF_PROV_W); + keyProvInfo.pwszContainerName = cspNameW; + keyProvInfo.pwszProvName = ms_def_prov_w; + keyProvInfo.dwProvType = PROV_RSA_FULL; + keyProvInfo.dwFlags = 0; + keyProvInfo.cProvParam = 0; + keyProvInfo.rgProvParam = NULL; + keyProvInfo.dwKeySpec = AT_SIGNATURE; + + cert = CertCreateCertificateContext(X509_ASN_ENCODING, selfSignedCert, sizeof(selfSignedCert)); + ret = CertSetCertificateContextProperty(cert, CERT_KEY_PROV_INFO_PROP_ID, 0, &keyProvInfo); + ok(ret, "CertSetCertificateContextProperty failed: %08lx", GetLastError()); + ret = CryptAcquireContextW(&csp, cspNameW, MS_DEF_PROV_W, PROV_RSA_FULL, CRYPT_NEWKEYSET); + ok(ret, "CryptAcquireContextW failed: %08lx\n", GetLastError()); + ret = CryptImportKey(csp, privKey, sizeof(privKey), 0, 0, &key); + ok(ret, "CryptImportKey failed: %08lx\n", GetLastError()); + if (!ret) return; + + init_cred(&client_cred); + init_cred(&server_cred); + client_cred.grbitEnabledProtocols = SP_PROT_TLS1_CLIENT; + client_cred.dwFlags = SCH_CRED_NO_DEFAULT_CREDS|SCH_CRED_MANUAL_CRED_VALIDATION; + server_cred.grbitEnabledProtocols = SP_PROT_TLS1_SERVER; + server_cred.dwFlags = SCH_CRED_NO_DEFAULT_CREDS|SCH_CRED_MANUAL_CRED_VALIDATION; + server_cred.cCreds = 1; + server_cred.paCred = &cert; + + status = AcquireCredentialsHandleA(NULL, (SEC_CHAR *)UNISP_NAME_A, SECPKG_CRED_OUTBOUND, NULL, &client_cred, NULL, NULL, &client_cred_handle, NULL); + ok(status == SEC_E_OK, "got %08lx\n", status); + if (status != SEC_E_OK) return; + status = AcquireCredentialsHandleA(NULL, (SEC_CHAR *)UNISP_NAME_A, SECPKG_CRED_INBOUND, NULL, &server_cred, NULL, NULL, &server_cred_handle, NULL); + ok(status == SEC_E_OK, "got %08lx\n", status); + if (status != SEC_E_OK) return; + + init_buffers(&buffers[0], 4, buf_size); + init_buffers(&buffers[1], 4, buf_size); + init_buffers(&buffers[2], 1, 128); + + alpn_buffer = buffers[2].pBuffers[0].pvBuffer; + extension_len = (unsigned int *)&alpn_buffer[offset]; + offset += sizeof(*extension_len); + *(unsigned int *)&alpn_buffer[offset] = SecApplicationProtocolNegotiationExt_ALPN; + offset += sizeof(unsigned int); + list_len = (unsigned short *)&alpn_buffer[offset]; + offset += sizeof(*list_len); + list_start_index = offset; + + alpn_buffer[offset++] = sizeof("http/1.1") - 1; + memcpy(&alpn_buffer[offset], "http/1.1", sizeof("http/1.1") - 1); + offset += sizeof("http/1.1") - 1; + alpn_buffer[offset++] = sizeof("h2") - 1; + memcpy(&alpn_buffer[offset], "h2", sizeof("h2") - 1); + offset += sizeof("h2") - 1; + + *list_len = offset - list_start_index; + *extension_len = *list_len + sizeof(*extension_len) + sizeof(*list_len); + + buffers[2].pBuffers[0].BufferType = SECBUFFER_APPLICATION_PROTOCOLS; + buffers[2].pBuffers[0].cbBuffer = offset; + buffers[0].pBuffers[0].BufferType = SECBUFFER_TOKEN; + status = InitializeSecurityContextA(&client_cred_handle, NULL, (SEC_CHAR *)"localhost", ISC_REQ_CONFIDENTIALITY|ISC_REQ_STREAM, 0, 0, &buffers[2], 0, &client_context, &buffers[0], &attrs, NULL); + ok(status == SEC_I_CONTINUE_NEEDED, "got %08lx\n", status); + + buffers[1].pBuffers[0].cbBuffer = buf_size; + buffers[1].pBuffers[0].BufferType = SECBUFFER_TOKEN; + buffers[0].pBuffers[1] = buffers[2].pBuffers[0]; + status = AcceptSecurityContext(&server_cred_handle, NULL, &buffers[0], ASC_REQ_CONFIDENTIALITY|ASC_REQ_STREAM, 0, &server_context, &buffers[1], &attrs, NULL); + ok(status == SEC_I_CONTINUE_NEEDED, "got %08lx\n", status); + memset(&buffers[0].pBuffers[1], 0, sizeof(buffers[0].pBuffers[1])); + + client_context2.dwLower = client_context2.dwUpper = 0xdeadbeef; + buffers[0].pBuffers[0].cbBuffer = buf_size; + status = InitializeSecurityContextA(&client_cred_handle, &client_context, (SEC_CHAR *)"localhost", ISC_REQ_CONFIDENTIALITY|ISC_REQ_STREAM|ISC_REQ_USE_SUPPLIED_CREDS, 0, 0, &buffers[1], 0, &client_context2, &buffers[0], &attrs, NULL); + ok(client_context.dwLower == client_context2.dwLower, "dwLower mismatch, expected %#Ix, got %#Ix\n", client_context.dwLower, client_context2.dwLower); + ok(client_context.dwUpper == client_context2.dwUpper, "dwUpper mismatch, expected %#Ix, got %#Ix\n", client_context.dwUpper, client_context2.dwUpper); + ok(status == SEC_I_CONTINUE_NEEDED, "got %08lx\n", status); + + server_context2.dwLower = server_context2.dwUpper = 0xdeadbeef; + buffers[1].pBuffers[0].cbBuffer = buf_size; + status = AcceptSecurityContext(&server_cred_handle, &server_context, &buffers[0], ASC_REQ_CONFIDENTIALITY|ASC_REQ_STREAM, 0, &server_context2, &buffers[1], &attrs, NULL); + ok(server_context.dwLower == server_context2.dwLower, "dwLower mismatch, expected %#Ix, got %#Ix\n", server_context.dwLower, server_context2.dwLower); + ok(server_context.dwUpper == server_context2.dwUpper, "dwUpper mismatch, expected %#Ix, got %#Ix\n", server_context.dwUpper, server_context2.dwUpper); + ok(status == SEC_E_OK, "got %08lx\n", status); + + buffers[0].pBuffers[0].cbBuffer = buf_size; + status = InitializeSecurityContextA(&client_cred_handle, &client_context, (SEC_CHAR *)"localhost", ISC_REQ_USE_SUPPLIED_CREDS, 0, 0, &buffers[1], 0, NULL, &buffers[0], &attrs, NULL); + ok(status == SEC_E_OK, "got %08lx\n", status); + + memset(&protocol, 0, sizeof(protocol)); + status = pQueryContextAttributesA(&client_context, SECPKG_ATTR_APPLICATION_PROTOCOL, &protocol); + ok(status == SEC_E_OK || broken(status == SEC_E_UNSUPPORTED_FUNCTION) /* win2k8 */, "got %08lx\n", status); + if (status == SEC_E_OK) + { + ok(protocol.ProtoNegoStatus == SecApplicationProtocolNegotiationStatus_Success, "got %u\n", protocol.ProtoNegoStatus); + ok(protocol.ProtoNegoExt == SecApplicationProtocolNegotiationExt_ALPN, "got %u\n", protocol.ProtoNegoExt); + ok(protocol.ProtocolIdSize == 8, "got %u\n", protocol.ProtocolIdSize); + ok(!memcmp(protocol.ProtocolId, "http/1.1", 8), "wrong protocol id\n"); + } + + DeleteSecurityContext(&client_context); + DeleteSecurityContext(&server_context); + FreeCredentialsHandle(&client_cred_handle); + FreeCredentialsHandle(&server_cred_handle); + + free_buffers(&buffers[0]); + free_buffers(&buffers[1]); + free_buffers(&buffers[2]); + + CryptDestroyKey(key); + CryptReleaseContext(csp, 0); + CryptAcquireContextW(&csp, cspNameW, MS_DEF_PROV_W, PROV_RSA_FULL, CRYPT_DELETEKEYSET); + CertFreeCertificateContext(cert); +} + static void init_dtls_output_buffer(SecBufferDesc *buffer) { buffer->pBuffers[0].BufferType = SECBUFFER_TOKEN; @@ -1949,6 +2089,7 @@ START_TEST(schannel) test_InitializeSecurityContext(); test_communication(); test_application_protocol_negotiation(); + test_server_protocol_negotiation(); test_dtls(); test_connection_shutdown(); }