secur32: Prepare schan_send() buffers on PE side.

Signed-off-by: Nikolay Sivov <nsivov@codeweavers.com>
This commit is contained in:
Nikolay Sivov 2022-05-31 09:38:26 +03:00 committed by Alexandre Julliard
parent 0a62c7bd40
commit 088c288214
2 changed files with 46 additions and 83 deletions

View file

@ -1299,8 +1299,12 @@ static SECURITY_STATUS SEC_ENTRY schan_EncryptMessage(PCtxtHandle context_handle
SIZE_T data_size;
SIZE_T length;
char *data;
int idx, output_buffer_idx = -1;
int output_buffer_idx = -1;
ULONG output_offset = 0;
SecBufferDesc output_desc = { 0 };
SecBuffer output_buffers[3];
int header_idx, data_idx, trailer_idx = -1;
int buffer_index[3];
TRACE("context_handle %p, quality %ld, message %p, message_seq_no %ld\n",
context_handle, quality, message, message_seq_no);
@ -1310,29 +1314,56 @@ static SECURITY_STATUS SEC_ENTRY schan_EncryptMessage(PCtxtHandle context_handle
dump_buffer_desc(message);
idx = schan_find_sec_buffer_idx(message, 0, SECBUFFER_DATA);
if (idx == -1)
data_idx = schan_find_sec_buffer_idx(message, 0, SECBUFFER_DATA);
if (data_idx == -1)
{
WARN("No data buffer passed\n");
return SEC_E_INTERNAL_ERROR;
}
buffer = &message->pBuffers[idx];
buffer = &message->pBuffers[data_idx];
data_size = buffer->cbBuffer;
data = malloc(data_size);
memcpy(data, buffer->pvBuffer, data_size);
/* Use { STREAM_HEADER, DATA, STREAM_TRAILER } or { TOKEN, DATA, TOKEN } buffers. */
output_desc.pBuffers = output_buffers;
if ((header_idx = schan_find_sec_buffer_idx(message, 0, SECBUFFER_STREAM_HEADER)) == -1)
{
if ((header_idx = schan_find_sec_buffer_idx(message, 0, SECBUFFER_TOKEN)) != -1)
{
output_buffers[output_desc.cBuffers++] = message->pBuffers[header_idx];
output_buffers[output_desc.cBuffers++] = message->pBuffers[data_idx];
trailer_idx = schan_find_sec_buffer_idx(message, header_idx + 1, SECBUFFER_TOKEN);
if (trailer_idx != -1)
output_buffers[output_desc.cBuffers++] = message->pBuffers[trailer_idx];
}
}
else
{
output_buffers[output_desc.cBuffers++] = message->pBuffers[header_idx];
output_buffers[output_desc.cBuffers++] = message->pBuffers[data_idx];
trailer_idx = schan_find_sec_buffer_idx(message, 0, SECBUFFER_STREAM_TRAILER);
if (trailer_idx != -1)
output_buffers[output_desc.cBuffers++] = message->pBuffers[trailer_idx];
}
buffer_index[0] = header_idx;
buffer_index[1] = data_idx;
buffer_index[2] = trailer_idx;
length = data_size;
params.session = ctx->session;
params.output = message;
params.output = &output_desc;
params.buffer = data;
params.length = &length;
params.output_buffer_idx = &output_buffer_idx;
params.output_offset = &output_offset;
status = GNUTLS_CALL( send, &params );
if (!status && output_buffer_idx != -1)
message->pBuffers[output_buffer_idx].cbBuffer = output_offset;
if (!status)
message->pBuffers[buffer_index[output_buffer_idx]].cbBuffer = output_offset;
TRACE("Sent %Id bytes.\n", length);

View file

@ -236,78 +236,13 @@ static void init_schan_buffers(struct schan_buffers *s, const PSecBufferDesc des
s->get_next_buffer = get_next_buffer;
}
static int schan_find_sec_buffer_idx(const SecBufferDesc *desc, unsigned int start_idx, ULONG buffer_type)
static int common_get_next_buffer(struct schan_buffers *s)
{
unsigned int i;
PSecBuffer buffer;
for (i = start_idx; i < desc->cBuffers; ++i)
{
buffer = &desc->pBuffers[i];
if ((buffer->BufferType | SECBUFFER_ATTRMASK) == (buffer_type | SECBUFFER_ATTRMASK))
return i;
}
return -1;
}
static int handshake_get_next_buffer(struct schan_buffers *s)
{
if (s->current_buffer_idx != -1)
return -1;
return s->desc->cBuffers ? 0 : -1;
}
static int send_message_get_next_buffer(struct schan_buffers *s)
{
SecBuffer *b;
if (s->current_buffer_idx == -1)
return schan_find_sec_buffer_idx(s->desc, 0, SECBUFFER_STREAM_HEADER);
b = &s->desc->pBuffers[s->current_buffer_idx];
if (b->BufferType == SECBUFFER_STREAM_HEADER)
return schan_find_sec_buffer_idx(s->desc, 0, SECBUFFER_DATA);
if (b->BufferType == SECBUFFER_DATA)
return schan_find_sec_buffer_idx(s->desc, 0, SECBUFFER_STREAM_TRAILER);
return -1;
}
static int send_message_get_next_buffer_token(struct schan_buffers *s)
{
SecBuffer *b;
if (s->current_buffer_idx == -1)
return schan_find_sec_buffer_idx(s->desc, 0, SECBUFFER_TOKEN);
b = &s->desc->pBuffers[s->current_buffer_idx];
if (b->BufferType == SECBUFFER_TOKEN)
{
int idx = schan_find_sec_buffer_idx(s->desc, 0, SECBUFFER_TOKEN);
if (idx != s->current_buffer_idx) return -1;
return schan_find_sec_buffer_idx(s->desc, 0, SECBUFFER_DATA);
}
if (b->BufferType == SECBUFFER_DATA)
{
int idx = schan_find_sec_buffer_idx(s->desc, 0, SECBUFFER_TOKEN);
if (idx != -1)
idx = schan_find_sec_buffer_idx(s->desc, idx + 1, SECBUFFER_TOKEN);
return idx;
}
return -1;
}
static int recv_message_get_next_buffer(struct schan_buffers *s)
{
if (s->current_buffer_idx != -1)
return s->desc->cBuffers ? 0 : -1;
if (s->current_buffer_idx == s->desc->cBuffers - 1)
return -1;
return s->desc->cBuffers ? 0 : -1;
return s->current_buffer_idx + 1;
}
static char *get_buffer(struct schan_buffers *s, SIZE_T *count)
@ -583,9 +518,9 @@ static NTSTATUS schan_handshake( void *args )
NTSTATUS status;
int err;
init_schan_buffers(&t->in, params->input, handshake_get_next_buffer);
init_schan_buffers(&t->in, params->input, common_get_next_buffer);
t->in.limit = params->input_size;
init_schan_buffers(&t->out, params->output, handshake_get_next_buffer);
init_schan_buffers(&t->out, params->output, common_get_next_buffer);
while (1)
{
@ -850,10 +785,7 @@ static NTSTATUS schan_send( void *args )
struct schan_transport *t = (struct schan_transport *)pgnutls_transport_get_ptr(s);
SSIZE_T ret, total = 0;
if (schan_find_sec_buffer_idx(params->output, 0, SECBUFFER_STREAM_HEADER) != -1)
init_schan_buffers(&t->out, params->output, send_message_get_next_buffer);
else
init_schan_buffers(&t->out, params->output, send_message_get_next_buffer_token);
init_schan_buffers(&t->out, params->output, common_get_next_buffer);
for (;;)
{
@ -893,7 +825,7 @@ static NTSTATUS schan_recv( void *args )
ssize_t ret;
SECURITY_STATUS status = SEC_E_OK;
init_schan_buffers(&t->in, params->input, recv_message_get_next_buffer);
init_schan_buffers(&t->in, params->input, common_get_next_buffer);
t->in.limit = params->input_size;
while (received < data_size)