crypt32: Keep track of state in each message type's update function rather than in CryptMsgUpdate.

This commit is contained in:
Juan Lang 2008-03-25 11:15:08 -07:00 committed by Alexandre Julliard
parent 3617819bf6
commit 856270972f

View file

@ -188,7 +188,9 @@ static BOOL CDataEncodeMsg_Update(HCRYPTMSG hCryptMsg, const BYTE *pbData,
CDataEncodeMsg *msg = (CDataEncodeMsg *)hCryptMsg; CDataEncodeMsg *msg = (CDataEncodeMsg *)hCryptMsg;
BOOL ret = FALSE; BOOL ret = FALSE;
if (msg->base.streamed) if (msg->base.state == MsgStateFinalized)
SetLastError(CRYPT_E_MSG_ERROR);
else if (msg->base.streamed)
{ {
__TRY __TRY
{ {
@ -225,11 +227,15 @@ static BOOL CDataEncodeMsg_Update(HCRYPTMSG hCryptMsg, const BYTE *pbData,
} }
} }
if (!fFinal) if (!fFinal)
{
ret = msg->base.stream_info.pfnStreamOutput( ret = msg->base.stream_info.pfnStreamOutput(
msg->base.stream_info.pvArg, (BYTE *)pbData, cbData, msg->base.stream_info.pvArg, (BYTE *)pbData, cbData,
FALSE); FALSE);
msg->base.state = MsgStateUpdated;
}
else else
{ {
msg->base.state = MsgStateFinalized;
if (msg->base.stream_info.cbContent == 0xffffffff) if (msg->base.stream_info.cbContent == 0xffffffff)
{ {
BYTE indefinite_trailer[6] = { 0 }; BYTE indefinite_trailer[6] = { 0 };
@ -265,6 +271,7 @@ static BOOL CDataEncodeMsg_Update(HCRYPTMSG hCryptMsg, const BYTE *pbData,
} }
else else
{ {
msg->base.state = MsgStateFinalized;
if (!cbData) if (!cbData)
SetLastError(E_INVALIDARG); SetLastError(E_INVALIDARG);
else else
@ -504,12 +511,15 @@ static BOOL CHashEncodeMsg_Update(HCRYPTMSG hCryptMsg, const BYTE *pbData,
TRACE("(%p, %p, %d, %d)\n", hCryptMsg, pbData, cbData, fFinal); TRACE("(%p, %p, %d, %d)\n", hCryptMsg, pbData, cbData, fFinal);
if (msg->base.streamed || (msg->base.open_flags & CMSG_DETACHED_FLAG)) if (msg->base.state == MsgStateFinalized)
SetLastError(CRYPT_E_MSG_ERROR);
else if (msg->base.streamed || (msg->base.open_flags & CMSG_DETACHED_FLAG))
{ {
/* Doesn't do much, as stream output is never called, and you /* Doesn't do much, as stream output is never called, and you
* can't get the content. * can't get the content.
*/ */
ret = CryptHashData(msg->hash, pbData, cbData, 0); ret = CryptHashData(msg->hash, pbData, cbData, 0);
msg->base.state = fFinal ? MsgStateFinalized : MsgStateUpdated;
} }
else else
{ {
@ -529,6 +539,7 @@ static BOOL CHashEncodeMsg_Update(HCRYPTMSG hCryptMsg, const BYTE *pbData,
else else
ret = FALSE; ret = FALSE;
} }
msg->base.state = MsgStateFinalized;
} }
} }
return ret; return ret;
@ -1183,12 +1194,15 @@ static BOOL CSignedEncodeMsg_Update(HCRYPTMSG hCryptMsg, const BYTE *pbData,
CSignedEncodeMsg *msg = (CSignedEncodeMsg *)hCryptMsg; CSignedEncodeMsg *msg = (CSignedEncodeMsg *)hCryptMsg;
BOOL ret = FALSE; BOOL ret = FALSE;
if (msg->base.streamed || (msg->base.open_flags & CMSG_DETACHED_FLAG)) if (msg->base.state == MsgStateFinalized)
SetLastError(CRYPT_E_MSG_ERROR);
else if (msg->base.streamed || (msg->base.open_flags & CMSG_DETACHED_FLAG))
{ {
ret = CSignedMsgData_Update(&msg->msg_data, pbData, cbData, fFinal, ret = CSignedMsgData_Update(&msg->msg_data, pbData, cbData, fFinal,
Sign); Sign);
if (msg->base.streamed) if (msg->base.streamed)
FIXME("streamed partial stub\n"); FIXME("streamed partial stub\n");
msg->base.state = fFinal ? MsgStateFinalized : MsgStateUpdated;
} }
else else
{ {
@ -1211,6 +1225,7 @@ static BOOL CSignedEncodeMsg_Update(HCRYPTMSG hCryptMsg, const BYTE *pbData,
if (ret) if (ret)
ret = CSignedMsgData_Update(&msg->msg_data, pbData, cbData, ret = CSignedMsgData_Update(&msg->msg_data, pbData, cbData,
fFinal, Sign); fFinal, Sign);
msg->base.state = MsgStateFinalized;
} }
} }
return ret; return ret;
@ -1644,11 +1659,14 @@ static BOOL CDecodeMsg_Update(HCRYPTMSG hCryptMsg, const BYTE *pbData,
TRACE("(%p, %p, %d, %d)\n", hCryptMsg, pbData, cbData, fFinal); TRACE("(%p, %p, %d, %d)\n", hCryptMsg, pbData, cbData, fFinal);
if (msg->base.streamed) if (msg->base.state == MsgStateFinalized)
SetLastError(CRYPT_E_MSG_ERROR);
else if (msg->base.streamed)
{ {
ret = CDecodeMsg_CopyData(msg, pbData, cbData); ret = CDecodeMsg_CopyData(msg, pbData, cbData);
FIXME("(%p, %p, %d, %d): streamed update stub\n", hCryptMsg, pbData, FIXME("(%p, %p, %d, %d): streamed update stub\n", hCryptMsg, pbData,
cbData, fFinal); cbData, fFinal);
msg->base.state = fFinal ? MsgStateFinalized : MsgStateUpdated;
} }
else else
{ {
@ -1659,7 +1677,7 @@ static BOOL CDecodeMsg_Update(HCRYPTMSG hCryptMsg, const BYTE *pbData,
ret = CDecodeMsg_CopyData(msg, pbData, cbData); ret = CDecodeMsg_CopyData(msg, pbData, cbData);
if (ret) if (ret)
ret = CDecodeMsg_DecodeContent(msg, &msg->msg_data, msg->type); ret = CDecodeMsg_DecodeContent(msg, &msg->msg_data, msg->type);
msg->base.state = MsgStateFinalized;
} }
} }
return ret; return ret;
@ -2360,20 +2378,10 @@ BOOL WINAPI CryptMsgUpdate(HCRYPTMSG hCryptMsg, const BYTE *pbData,
DWORD cbData, BOOL fFinal) DWORD cbData, BOOL fFinal)
{ {
CryptMsgBase *msg = (CryptMsgBase *)hCryptMsg; CryptMsgBase *msg = (CryptMsgBase *)hCryptMsg;
BOOL ret = FALSE;
TRACE("(%p, %p, %d, %d)\n", hCryptMsg, pbData, cbData, fFinal); TRACE("(%p, %p, %d, %d)\n", hCryptMsg, pbData, cbData, fFinal);
if (msg->state == MsgStateFinalized) return msg->update(hCryptMsg, pbData, cbData, fFinal);
SetLastError(CRYPT_E_MSG_ERROR);
else
{
ret = msg->update(hCryptMsg, pbData, cbData, fFinal);
msg->state = MsgStateUpdated;
if (fFinal)
msg->state = MsgStateFinalized;
}
return ret;
} }
BOOL WINAPI CryptMsgGetParam(HCRYPTMSG hCryptMsg, DWORD dwParamType, BOOL WINAPI CryptMsgGetParam(HCRYPTMSG hCryptMsg, DWORD dwParamType,