add user ctx to stream IO callbacks

This commit is contained in:
JacobBarthelmeh
2024-03-04 06:00:07 -07:00
parent fbf1b783da
commit 66f419bd18
3 changed files with 40 additions and 27 deletions

View File

@@ -26872,32 +26872,37 @@ static int rsaSignRawDigestCb(PKCS7* pkcs7, byte* digest, word32 digestSz,
#endif #endif
#if defined(HAVE_PKCS7) && defined(ASN_BER_TO_DER) #if defined(HAVE_PKCS7) && defined(ASN_BER_TO_DER)
static byte encodeSignedDataStreamOut[FOURK_BUF*3] = {0}; typedef struct encodeSignedDataStream {
static int encodeSignedDataStreamIdx = 0; byte out[FOURK_BUF*3];
static word32 encodeSignedDataStreamOutIdx = 0; int idx;
word32 outIdx;
} encodeSignedDataStream;
/* content is 8k of partially created bundle */ /* content is 8k of partially created bundle */
static int GetContentCB(PKCS7* pkcs7, byte** content) static int GetContentCB(PKCS7* pkcs7, byte** content, void* ctx)
{ {
int ret = 0; int ret = 0;
encodeSignedDataStream* strm = (encodeSignedDataStream*)ctx;
if (encodeSignedDataStreamOutIdx < pkcs7->contentSz) { if (strm->outIdx < pkcs7->contentSz) {
ret = (pkcs7->contentSz > encodeSignedDataStreamOutIdx + FOURK_BUF)? ret = (pkcs7->contentSz > strm->outIdx + FOURK_BUF)?
FOURK_BUF : pkcs7->contentSz - encodeSignedDataStreamOutIdx; FOURK_BUF : pkcs7->contentSz - strm->outIdx;
*content = encodeSignedDataStreamOut + encodeSignedDataStreamOutIdx; *content = strm->out + strm->outIdx;
encodeSignedDataStreamOutIdx += ret; strm->outIdx += ret;
} }
(void)pkcs7; (void)pkcs7;
return ret; return ret;
} }
static int StreamOutputCB(PKCS7* pkcs7, const byte* output, word32 outputSz) static int StreamOutputCB(PKCS7* pkcs7, const byte* output, word32 outputSz,
void* ctx)
{ {
XMEMCPY(encodeSignedDataStreamOut + encodeSignedDataStreamIdx, output, encodeSignedDataStream* strm = (encodeSignedDataStream*)ctx;
outputSz);
encodeSignedDataStreamIdx += outputSz; XMEMCPY(strm->out + strm->idx, output, outputSz);
strm->idx += outputSz;
(void)pkcs7; (void)pkcs7;
return 0; return 0;
} }
@@ -27031,6 +27036,7 @@ static int test_wc_PKCS7_EncodeSignedData(void)
/* reinitialize and test setting stream mode */ /* reinitialize and test setting stream mode */
{ {
int signedSz; int signedSz;
encodeSignedDataStream strm;
ExpectNotNull(pkcs7 = wc_PKCS7_New(HEAP_HINT, testDevId)); ExpectNotNull(pkcs7 = wc_PKCS7_New(HEAP_HINT, testDevId));
ExpectIntEQ(wc_PKCS7_Init(pkcs7, HEAP_HINT, INVALID_DEVID), 0); ExpectIntEQ(wc_PKCS7_Init(pkcs7, HEAP_HINT, INVALID_DEVID), 0);
@@ -27051,8 +27057,9 @@ static int test_wc_PKCS7_EncodeSignedData(void)
pkcs7->rng = &rng; pkcs7->rng = &rng;
} }
ExpectIntEQ(wc_PKCS7_GetStreamMode(pkcs7), 0); ExpectIntEQ(wc_PKCS7_GetStreamMode(pkcs7), 0);
ExpectIntEQ(wc_PKCS7_SetStreamMode(pkcs7, 1, NULL, NULL), 0); ExpectIntEQ(wc_PKCS7_SetStreamMode(pkcs7, 1, NULL, NULL, NULL), 0);
ExpectIntEQ(wc_PKCS7_SetStreamMode(NULL, 1, NULL, NULL), BAD_FUNC_ARG); ExpectIntEQ(wc_PKCS7_SetStreamMode(NULL, 1, NULL, NULL, NULL),
BAD_FUNC_ARG);
ExpectIntEQ(wc_PKCS7_GetStreamMode(pkcs7), 1); ExpectIntEQ(wc_PKCS7_GetStreamMode(pkcs7), 1);
ExpectIntGT(signedSz = wc_PKCS7_EncodeSignedData(pkcs7, output, ExpectIntGT(signedSz = wc_PKCS7_EncodeSignedData(pkcs7, output,
@@ -27085,8 +27092,9 @@ static int test_wc_PKCS7_EncodeSignedData(void)
#endif #endif
pkcs7->rng = &rng; pkcs7->rng = &rng;
} }
XMEMSET(&strm, 0, sizeof(strm));
ExpectIntEQ(wc_PKCS7_SetStreamMode(pkcs7, 1, GetContentCB, ExpectIntEQ(wc_PKCS7_SetStreamMode(pkcs7, 1, GetContentCB,
StreamOutputCB), 0); StreamOutputCB, (void*)&strm), 0);
ExpectIntGT(signedSz = wc_PKCS7_EncodeSignedData(pkcs7, NULL, 0), 0); ExpectIntGT(signedSz = wc_PKCS7_EncodeSignedData(pkcs7, NULL, 0), 0);
wc_PKCS7_Free(pkcs7); wc_PKCS7_Free(pkcs7);
@@ -27096,8 +27104,7 @@ static int test_wc_PKCS7_EncodeSignedData(void)
ExpectIntEQ(wc_PKCS7_InitWithCert(pkcs7, NULL, 0), 0); ExpectIntEQ(wc_PKCS7_InitWithCert(pkcs7, NULL, 0), 0);
/* use exact signed buffer size since BER encoded */ /* use exact signed buffer size since BER encoded */
ExpectIntEQ(wc_PKCS7_VerifySignedData(pkcs7, encodeSignedDataStreamOut, ExpectIntEQ(wc_PKCS7_VerifySignedData(pkcs7, strm.out, signedSz), 0);
signedSz), 0);
} }
#endif #endif
@@ -28335,6 +28342,8 @@ static int test_wc_PKCS7_EncodeDecodeEnvelopedData(void)
testSz = (int)sizeof(testVectors)/(int)sizeof(pkcs7EnvelopedVector); testSz = (int)sizeof(testVectors)/(int)sizeof(pkcs7EnvelopedVector);
for (i = 0; i < testSz; i++) { for (i = 0; i < testSz; i++) {
#ifdef ASN_BER_TO_DER #ifdef ASN_BER_TO_DER
encodeSignedDataStream strm;
/* test setting stream mode, the first one using IO callbacks */ /* test setting stream mode, the first one using IO callbacks */
ExpectIntEQ(wc_PKCS7_InitWithCert(pkcs7, (testVectors + i)->cert, ExpectIntEQ(wc_PKCS7_InitWithCert(pkcs7, (testVectors + i)->cert,
(word32)(testVectors + i)->certSz), 0); (word32)(testVectors + i)->certSz), 0);
@@ -28355,12 +28364,13 @@ static int test_wc_PKCS7_EncodeDecodeEnvelopedData(void)
} }
if (i == 0) { if (i == 0) {
XMEMSET(&strm, 0, sizeof(strm));
ExpectIntEQ(wc_PKCS7_SetStreamMode(pkcs7, 1, GetContentCB, ExpectIntEQ(wc_PKCS7_SetStreamMode(pkcs7, 1, GetContentCB,
StreamOutputCB), 0); StreamOutputCB, (void*)&strm), 0);
encodedSz = wc_PKCS7_EncodeEnvelopedData(pkcs7, NULL, 0); encodedSz = wc_PKCS7_EncodeEnvelopedData(pkcs7, NULL, 0);
} }
else { else {
ExpectIntEQ(wc_PKCS7_SetStreamMode(pkcs7, 1, NULL, NULL), 0); ExpectIntEQ(wc_PKCS7_SetStreamMode(pkcs7, 1, NULL, NULL, NULL), 0);
encodedSz = wc_PKCS7_EncodeEnvelopedData(pkcs7, output, encodedSz = wc_PKCS7_EncodeEnvelopedData(pkcs7, output,
(word32)sizeof(output)); (word32)sizeof(output));
} }
@@ -28396,7 +28406,7 @@ static int test_wc_PKCS7_EncodeDecodeEnvelopedData(void)
if (encodedSz > 0) { if (encodedSz > 0) {
if (i == 0) { if (i == 0) {
decodedSz = wc_PKCS7_DecodeEnvelopedData(pkcs7, decodedSz = wc_PKCS7_DecodeEnvelopedData(pkcs7,
encodeSignedDataStreamOut, (word32)encodedSz, decoded, strm.out, (word32)encodedSz, decoded,
(word32)sizeof(decoded)); (word32)sizeof(decoded));
} }
else { else {

View File

@@ -2497,7 +2497,7 @@ static int wc_PKCS7_EncodeContentStream(PKCS7* pkcs7, ESD* esd, void* aes,
#ifdef ASN_BER_TO_DER #ifdef ASN_BER_TO_DER
if (pkcs7->getContentCb) { if (pkcs7->getContentCb) {
contentDataRead = pkcs7->getContentCb(pkcs7, contentDataRead = pkcs7->getContentCb(pkcs7,
&buf); &buf, pkcs7->streamCtx);
} }
else else
#endif #endif
@@ -7549,7 +7549,7 @@ int wc_PKCS7_WriteOut(PKCS7* pkcs7, byte* output, const byte* input,
#ifdef ASN_BER_TO_DER #ifdef ASN_BER_TO_DER
if (pkcs7->streamOutCb) { if (pkcs7->streamOutCb) {
ret = pkcs7->streamOutCb(pkcs7, input, inputSz); ret = pkcs7->streamOutCb(pkcs7, input, inputSz, pkcs7->streamCtx);
/* sanity check on user provided ret value */ /* sanity check on user provided ret value */
if (ret < 0) { if (ret < 0) {
WOLFSSL_MSG("Return value error from stream out callback"); WOLFSSL_MSG("Return value error from stream out callback");
@@ -13854,7 +13854,7 @@ int wc_PKCS7_SetDecodeEncryptedCtx(PKCS7* pkcs7, void* ctx)
* returns 0 on success */ * returns 0 on success */
int wc_PKCS7_SetStreamMode(PKCS7* pkcs7, byte flag, int wc_PKCS7_SetStreamMode(PKCS7* pkcs7, byte flag,
CallbackGetContent getContentCb, CallbackGetContent getContentCb,
CallbackStreamOut streamOutCb) CallbackStreamOut streamOutCb, void* ctx)
{ {
if (pkcs7 == NULL) { if (pkcs7 == NULL) {
return BAD_FUNC_ARG; return BAD_FUNC_ARG;
@@ -13863,11 +13863,13 @@ int wc_PKCS7_SetStreamMode(PKCS7* pkcs7, byte flag,
pkcs7->encodeStream = flag; pkcs7->encodeStream = flag;
pkcs7->getContentCb = getContentCb; pkcs7->getContentCb = getContentCb;
pkcs7->streamOutCb = streamOutCb; pkcs7->streamOutCb = streamOutCb;
pkcs7->streamCtx = ctx;
return 0; return 0;
#else #else
(void)flag; (void)flag;
(void)getContentCb; (void)getContentCb;
(void)streamOutCb; (void)streamOutCb;
(void)ctx;
return NOT_COMPILED_IN; return NOT_COMPILED_IN;
#endif #endif
} }

View File

@@ -225,9 +225,9 @@ typedef int (*CallbackWrapCEK)(PKCS7* pkcs7, byte* cek, word32 cekSz,
int keyWrapAlgo, int type, int dir); int keyWrapAlgo, int type, int dir);
/* Callbacks for supporting different stream cases */ /* Callbacks for supporting different stream cases */
typedef int (*CallbackGetContent)(PKCS7* pkcs7, byte** content); typedef int (*CallbackGetContent)(PKCS7* pkcs7, byte** content, void* ctx);
typedef int (*CallbackStreamOut)(PKCS7* pkcs7, const byte* output, typedef int (*CallbackStreamOut)(PKCS7* pkcs7, const byte* output,
word32 outputSz); word32 outputSz, void* ctx);
#if defined(HAVE_PKCS7_RSA_RAW_SIGN_CALLBACK) && !defined(NO_RSA) #if defined(HAVE_PKCS7_RSA_RAW_SIGN_CALLBACK) && !defined(NO_RSA)
/* RSA sign raw digest callback, user builds DigestInfo */ /* RSA sign raw digest callback, user builds DigestInfo */
@@ -254,6 +254,7 @@ struct PKCS7 {
word32 derSz; word32 derSz;
CallbackGetContent getContentCb; CallbackGetContent getContentCb;
CallbackStreamOut streamOutCb; CallbackStreamOut streamOutCb;
void* streamCtx; /* passed to getcontentCb and streamOutCb */
#endif #endif
byte encodeStream:1; /* use BER when encoding */ byte encodeStream:1; /* use BER when encoding */
byte noCerts:1; /* if certificates should be added into bundle byte noCerts:1; /* if certificates should be added into bundle
@@ -509,7 +510,7 @@ WOLFSSL_API int wc_PKCS7_SetDecodeEncryptedCtx(PKCS7* pkcs7, void* ctx);
WOLFSSL_LOCAL int wc_PKCS7_WriteOut(PKCS7* pkcs7, byte* output, WOLFSSL_LOCAL int wc_PKCS7_WriteOut(PKCS7* pkcs7, byte* output,
const byte* input, word32 inputSz); const byte* input, word32 inputSz);
WOLFSSL_API int wc_PKCS7_SetStreamMode(PKCS7* pkcs7, byte flag, WOLFSSL_API int wc_PKCS7_SetStreamMode(PKCS7* pkcs7, byte flag,
CallbackGetContent getContentCb, CallbackStreamOut streamOutCb); CallbackGetContent getContentCb, CallbackStreamOut streamOutCb, void* ctx);
WOLFSSL_API int wc_PKCS7_GetStreamMode(PKCS7* pkcs7); WOLFSSL_API int wc_PKCS7_GetStreamMode(PKCS7* pkcs7);
WOLFSSL_API int wc_PKCS7_SetNoCerts(PKCS7* pkcs7, byte flag); WOLFSSL_API int wc_PKCS7_SetNoCerts(PKCS7* pkcs7, byte flag);
WOLFSSL_API int wc_PKCS7_GetNoCerts(PKCS7* pkcs7); WOLFSSL_API int wc_PKCS7_GetNoCerts(PKCS7* pkcs7);