From 66f419bd1891119035939a5a34377f81b3b05c6d Mon Sep 17 00:00:00 2001 From: JacobBarthelmeh Date: Mon, 4 Mar 2024 06:00:07 -0700 Subject: [PATCH] add user ctx to stream IO callbacks --- tests/api.c | 52 +++++++++++++++++++++++---------------- wolfcrypt/src/pkcs7.c | 8 +++--- wolfssl/wolfcrypt/pkcs7.h | 7 +++--- 3 files changed, 40 insertions(+), 27 deletions(-) diff --git a/tests/api.c b/tests/api.c index 5fb4a6a45..fbaab39c9 100644 --- a/tests/api.c +++ b/tests/api.c @@ -26872,32 +26872,37 @@ static int rsaSignRawDigestCb(PKCS7* pkcs7, byte* digest, word32 digestSz, #endif #if defined(HAVE_PKCS7) && defined(ASN_BER_TO_DER) -static byte encodeSignedDataStreamOut[FOURK_BUF*3] = {0}; -static int encodeSignedDataStreamIdx = 0; -static word32 encodeSignedDataStreamOutIdx = 0; +typedef struct encodeSignedDataStream { + byte out[FOURK_BUF*3]; + int idx; + word32 outIdx; +} encodeSignedDataStream; /* content is 8k of partially created bundle */ -static int GetContentCB(PKCS7* pkcs7, byte** content) +static int GetContentCB(PKCS7* pkcs7, byte** content, void* ctx) { int ret = 0; + encodeSignedDataStream* strm = (encodeSignedDataStream*)ctx; - if (encodeSignedDataStreamOutIdx < pkcs7->contentSz) { - ret = (pkcs7->contentSz > encodeSignedDataStreamOutIdx + FOURK_BUF)? - FOURK_BUF : pkcs7->contentSz - encodeSignedDataStreamOutIdx; - *content = encodeSignedDataStreamOut + encodeSignedDataStreamOutIdx; - encodeSignedDataStreamOutIdx += ret; + if (strm->outIdx < pkcs7->contentSz) { + ret = (pkcs7->contentSz > strm->outIdx + FOURK_BUF)? + FOURK_BUF : pkcs7->contentSz - strm->outIdx; + *content = strm->out + strm->outIdx; + strm->outIdx += ret; } (void)pkcs7; return ret; } -static int StreamOutputCB(PKCS7* pkcs7, const byte* output, word32 outputSz) +static int StreamOutputCB(PKCS7* pkcs7, const byte* output, word32 outputSz, + void* ctx) { - XMEMCPY(encodeSignedDataStreamOut + encodeSignedDataStreamIdx, output, - outputSz); - encodeSignedDataStreamIdx += outputSz; + encodeSignedDataStream* strm = (encodeSignedDataStream*)ctx; + + XMEMCPY(strm->out + strm->idx, output, outputSz); + strm->idx += outputSz; (void)pkcs7; return 0; } @@ -27031,6 +27036,7 @@ static int test_wc_PKCS7_EncodeSignedData(void) /* reinitialize and test setting stream mode */ { int signedSz; + encodeSignedDataStream strm; ExpectNotNull(pkcs7 = wc_PKCS7_New(HEAP_HINT, testDevId)); ExpectIntEQ(wc_PKCS7_Init(pkcs7, HEAP_HINT, INVALID_DEVID), 0); @@ -27051,8 +27057,9 @@ static int test_wc_PKCS7_EncodeSignedData(void) pkcs7->rng = &rng; } ExpectIntEQ(wc_PKCS7_GetStreamMode(pkcs7), 0); - ExpectIntEQ(wc_PKCS7_SetStreamMode(pkcs7, 1, NULL, NULL), 0); - ExpectIntEQ(wc_PKCS7_SetStreamMode(NULL, 1, NULL, NULL), BAD_FUNC_ARG); + ExpectIntEQ(wc_PKCS7_SetStreamMode(pkcs7, 1, NULL, NULL, NULL), 0); + ExpectIntEQ(wc_PKCS7_SetStreamMode(NULL, 1, NULL, NULL, NULL), + BAD_FUNC_ARG); ExpectIntEQ(wc_PKCS7_GetStreamMode(pkcs7), 1); ExpectIntGT(signedSz = wc_PKCS7_EncodeSignedData(pkcs7, output, @@ -27085,8 +27092,9 @@ static int test_wc_PKCS7_EncodeSignedData(void) #endif pkcs7->rng = &rng; } + XMEMSET(&strm, 0, sizeof(strm)); ExpectIntEQ(wc_PKCS7_SetStreamMode(pkcs7, 1, GetContentCB, - StreamOutputCB), 0); + StreamOutputCB, (void*)&strm), 0); ExpectIntGT(signedSz = wc_PKCS7_EncodeSignedData(pkcs7, NULL, 0), 0); wc_PKCS7_Free(pkcs7); @@ -27096,8 +27104,7 @@ static int test_wc_PKCS7_EncodeSignedData(void) ExpectIntEQ(wc_PKCS7_InitWithCert(pkcs7, NULL, 0), 0); /* use exact signed buffer size since BER encoded */ - ExpectIntEQ(wc_PKCS7_VerifySignedData(pkcs7, encodeSignedDataStreamOut, - signedSz), 0); + ExpectIntEQ(wc_PKCS7_VerifySignedData(pkcs7, strm.out, signedSz), 0); } #endif @@ -28335,6 +28342,8 @@ static int test_wc_PKCS7_EncodeDecodeEnvelopedData(void) testSz = (int)sizeof(testVectors)/(int)sizeof(pkcs7EnvelopedVector); for (i = 0; i < testSz; i++) { #ifdef ASN_BER_TO_DER + encodeSignedDataStream strm; + /* test setting stream mode, the first one using IO callbacks */ ExpectIntEQ(wc_PKCS7_InitWithCert(pkcs7, (testVectors + i)->cert, (word32)(testVectors + i)->certSz), 0); @@ -28355,12 +28364,13 @@ static int test_wc_PKCS7_EncodeDecodeEnvelopedData(void) } if (i == 0) { + XMEMSET(&strm, 0, sizeof(strm)); ExpectIntEQ(wc_PKCS7_SetStreamMode(pkcs7, 1, GetContentCB, - StreamOutputCB), 0); + StreamOutputCB, (void*)&strm), 0); encodedSz = wc_PKCS7_EncodeEnvelopedData(pkcs7, NULL, 0); } else { - ExpectIntEQ(wc_PKCS7_SetStreamMode(pkcs7, 1, NULL, NULL), 0); + ExpectIntEQ(wc_PKCS7_SetStreamMode(pkcs7, 1, NULL, NULL, NULL), 0); encodedSz = wc_PKCS7_EncodeEnvelopedData(pkcs7, output, (word32)sizeof(output)); } @@ -28396,7 +28406,7 @@ static int test_wc_PKCS7_EncodeDecodeEnvelopedData(void) if (encodedSz > 0) { if (i == 0) { decodedSz = wc_PKCS7_DecodeEnvelopedData(pkcs7, - encodeSignedDataStreamOut, (word32)encodedSz, decoded, + strm.out, (word32)encodedSz, decoded, (word32)sizeof(decoded)); } else { diff --git a/wolfcrypt/src/pkcs7.c b/wolfcrypt/src/pkcs7.c index 11f0fa7d8..8b1141976 100644 --- a/wolfcrypt/src/pkcs7.c +++ b/wolfcrypt/src/pkcs7.c @@ -2497,7 +2497,7 @@ static int wc_PKCS7_EncodeContentStream(PKCS7* pkcs7, ESD* esd, void* aes, #ifdef ASN_BER_TO_DER if (pkcs7->getContentCb) { contentDataRead = pkcs7->getContentCb(pkcs7, - &buf); + &buf, pkcs7->streamCtx); } else #endif @@ -7549,7 +7549,7 @@ int wc_PKCS7_WriteOut(PKCS7* pkcs7, byte* output, const byte* input, #ifdef ASN_BER_TO_DER if (pkcs7->streamOutCb) { - ret = pkcs7->streamOutCb(pkcs7, input, inputSz); + ret = pkcs7->streamOutCb(pkcs7, input, inputSz, pkcs7->streamCtx); /* sanity check on user provided ret value */ if (ret < 0) { WOLFSSL_MSG("Return value error from stream out callback"); @@ -13854,7 +13854,7 @@ int wc_PKCS7_SetDecodeEncryptedCtx(PKCS7* pkcs7, void* ctx) * returns 0 on success */ int wc_PKCS7_SetStreamMode(PKCS7* pkcs7, byte flag, CallbackGetContent getContentCb, - CallbackStreamOut streamOutCb) + CallbackStreamOut streamOutCb, void* ctx) { if (pkcs7 == NULL) { return BAD_FUNC_ARG; @@ -13863,11 +13863,13 @@ int wc_PKCS7_SetStreamMode(PKCS7* pkcs7, byte flag, pkcs7->encodeStream = flag; pkcs7->getContentCb = getContentCb; pkcs7->streamOutCb = streamOutCb; + pkcs7->streamCtx = ctx; return 0; #else (void)flag; (void)getContentCb; (void)streamOutCb; + (void)ctx; return NOT_COMPILED_IN; #endif } diff --git a/wolfssl/wolfcrypt/pkcs7.h b/wolfssl/wolfcrypt/pkcs7.h index d87ebaa79..405771239 100644 --- a/wolfssl/wolfcrypt/pkcs7.h +++ b/wolfssl/wolfcrypt/pkcs7.h @@ -225,9 +225,9 @@ typedef int (*CallbackWrapCEK)(PKCS7* pkcs7, byte* cek, word32 cekSz, int keyWrapAlgo, int type, int dir); /* Callbacks for supporting different stream cases */ -typedef int (*CallbackGetContent)(PKCS7* pkcs7, byte** content); +typedef int (*CallbackGetContent)(PKCS7* pkcs7, byte** content, void* ctx); typedef int (*CallbackStreamOut)(PKCS7* pkcs7, const byte* output, - word32 outputSz); + word32 outputSz, void* ctx); #if defined(HAVE_PKCS7_RSA_RAW_SIGN_CALLBACK) && !defined(NO_RSA) /* RSA sign raw digest callback, user builds DigestInfo */ @@ -254,6 +254,7 @@ struct PKCS7 { word32 derSz; CallbackGetContent getContentCb; CallbackStreamOut streamOutCb; + void* streamCtx; /* passed to getcontentCb and streamOutCb */ #endif byte encodeStream:1; /* use BER when encoding */ byte noCerts:1; /* if certificates should be added into bundle @@ -509,7 +510,7 @@ WOLFSSL_API int wc_PKCS7_SetDecodeEncryptedCtx(PKCS7* pkcs7, void* ctx); WOLFSSL_LOCAL int wc_PKCS7_WriteOut(PKCS7* pkcs7, byte* output, const byte* input, word32 inputSz); WOLFSSL_API int wc_PKCS7_SetStreamMode(PKCS7* pkcs7, byte flag, - CallbackGetContent getContentCb, CallbackStreamOut streamOutCb); + CallbackGetContent getContentCb, CallbackStreamOut streamOutCb, void* ctx); WOLFSSL_API int wc_PKCS7_GetStreamMode(PKCS7* pkcs7); WOLFSSL_API int wc_PKCS7_SetNoCerts(PKCS7* pkcs7, byte flag); WOLFSSL_API int wc_PKCS7_GetNoCerts(PKCS7* pkcs7);