diff --git a/doc/dox_comments/header_files/pkcs7.h b/doc/dox_comments/header_files/pkcs7.h index 5884b97f7..b2c344d9b 100644 --- a/doc/dox_comments/header_files/pkcs7.h +++ b/doc/dox_comments/header_files/pkcs7.h @@ -11,11 +11,12 @@ \param[in] keySz Size of the key to use. \param[in] in Specify the input data to wrap/unwrap. \param[in] inSz Size of the input data. + \param[in] wrap 1 if the requested operation is a key wrap, 0 for unwrap. \param[out] out Specify the output buffer. \param[out] outSz Size of the output buffer. */ -typedef int (*CallbackAESKeyWrap)(const byte* key, word32 keySz, - const byte* in, word32 inSz, byte* out, word32 outSz); +typedef int (*CallbackAESKeyWrapUnwrap)(const byte* key, word32 keySz, + const byte* in, word32 inSz, int wrap, byte* out, word32 outSz); /*! \ingroup PKCS7 @@ -500,33 +501,16 @@ int wc_PKCS7_VerifySignedData_ex(PKCS7* pkcs7, const byte* hashBuf, \ingroup PKCS7 \brief Set the callback function to be used to perform a custom AES key - wrap operation. + wrap/unwrap operation. \retval 0 Callback function was set successfully \retval BAD_FUNC_ARG Parameter pkcs7 is NULL \param pkcs7 pointer to the PKCS7 structure - \param aesKeyWrapCb pointer to custom AES key wrap function - - \sa wc_PKCS7_SetAESKeyUnwrapCb + \param aesKeyWrapCb pointer to custom AES key wrap/unwrap function */ -int wc_PKCS7_SetAESKeyWrapCb(wc_PKCS7* pkcs7, CallbackAESKeyWrap aesKeyWrapCb); - -/*! - \ingroup PKCS7 - - \brief Set the callback function to be used to perform a custom AES key - unwrap operation. - - \retval 0 Callback function was set successfully - \retval BAD_FUNC_ARG Parameter pkcs7 is NULL - - \param pkcs7 pointer to the PKCS7 structure - \param aesKeyUnwrapCb pointer to custom AES key unwrap function - - \sa wc_PKCS7_SetAESKeyWrapCb -*/ -int wc_PKCS7_SetAESKeyUnwrapCb(wc_PKCS7* pkcs7, CallbackAESKeyWrap aesKeyUnwrapCb); +int wc_PKCS7_SetAESKeyWrapUnwrapCb(wc_PKCS7* pkcs7, + CallbackAESKeyWrapUnwrap aesKeyWrapCb); /*! \ingroup PKCS7 diff --git a/tests/api.c b/tests/api.c index 2ce79d779..8b407acbd 100644 --- a/tests/api.c +++ b/tests/api.c @@ -17999,25 +17999,16 @@ static int test_wc_PKCS7_EncodeDecodeEnvelopedData(void) static int wasAESKeyWrapCbCalled = 0; static int wasAESKeyUnwrapCbCalled = 0; -static int testAESKeyWrapCb(const byte* key, word32 keySz, - const byte* in, word32 inSz, byte* out, word32 outSz) +static int testAESKeyWrapUnwrapCb(const byte* key, word32 keySz, + const byte* in, word32 inSz, int wrap, byte* out, word32 outSz) { (void)key; (void)keySz; - wasAESKeyWrapCbCalled = 1; - XMEMSET(out, 0xEE, outSz); - if (inSz <= outSz) { - XMEMCPY(out, in, inSz); - } - return inSz; -} - -static int testAESKeyUnwrapCb(const byte* key, word32 keySz, - const byte* in, word32 inSz, byte* out, word32 outSz) -{ - (void)key; - (void)keySz; - wasAESKeyUnwrapCbCalled = 1; + (void)wrap; + if (wrap) + wasAESKeyWrapCbCalled = 1; + else + wasAESKeyUnwrapCbCalled = 1; XMEMSET(out, 0xEE, outSz); if (inSz <= outSz) { XMEMCPY(out, in, inSz); @@ -18104,8 +18095,7 @@ static int test_wc_PKCS7_SetAESKeyWrapUnwrapCb(void) } /* Test custom AES key wrap/unwrap callback */ - ExpectIntEQ(wc_PKCS7_SetAESKeyWrapCb(pkcs7, testAESKeyWrapCb), 0); - ExpectIntEQ(wc_PKCS7_SetAESKeyUnwrapCb(pkcs7, testAESKeyUnwrapCb), 0); + ExpectIntEQ(wc_PKCS7_SetAESKeyWrapUnwrapCb(pkcs7, testAESKeyWrapUnwrapCb), 0); ExpectIntGE(wc_PKCS7_EncodeEnvelopedData(pkcs7, output, (word32)sizeof(output)), 0); diff --git a/wolfcrypt/src/pkcs7.c b/wolfcrypt/src/pkcs7.c index 7a910c833..57eaec484 100644 --- a/wolfcrypt/src/pkcs7.c +++ b/wolfcrypt/src/pkcs7.c @@ -6838,9 +6838,9 @@ static int wc_PKCS7_KeyWrap(const wc_PKCS7 * pkcs7, byte const * cek, #endif if (direction == AES_ENCRYPTION) { - if (pkcs7->aesKeyWrapCb != NULL) { - ret = pkcs7->aesKeyWrapCb(kek, kekSz, cek, cekSz, - out, outSz); + if (pkcs7->aesKeyWrapUnwrapCb != NULL) { + ret = pkcs7->aesKeyWrapUnwrapCb(kek, kekSz, cek, cekSz, 1, + out, outSz); } else { ret = wc_AesKeyWrap(kek, kekSz, cek, cekSz, @@ -6848,9 +6848,9 @@ static int wc_PKCS7_KeyWrap(const wc_PKCS7 * pkcs7, byte const * cek, } } else if (direction == AES_DECRYPTION) { - if (pkcs7->aesKeyUnwrapCb != NULL) { - ret = pkcs7->aesKeyUnwrapCb(kek, kekSz, cek, cekSz, - out, outSz); + if (pkcs7->aesKeyWrapUnwrapCb != NULL) { + ret = pkcs7->aesKeyWrapUnwrapCb(kek, kekSz, cek, cekSz, 0, + out, outSz); } else { ret = wc_AesKeyUnWrap(kek, kekSz, cek, cekSz, @@ -11094,28 +11094,17 @@ int wc_PKCS7_SetWrapCEKCb(wc_PKCS7* pkcs7, CallbackWrapCEK cb) /* return 0 on success */ -int wc_PKCS7_SetAESKeyWrapCb(wc_PKCS7* pkcs7, CallbackAESKeyWrap aesKeyWrapCb) +int wc_PKCS7_SetAESKeyWrapUnwrapCb(wc_PKCS7* pkcs7, CallbackAESKeyWrapUnwrap aesKeyWrapUnwrapCb) { if (pkcs7 == NULL) return BAD_FUNC_ARG; - pkcs7->aesKeyWrapCb = aesKeyWrapCb; + pkcs7->aesKeyWrapUnwrapCb = aesKeyWrapUnwrapCb; return 0; } -/* return 0 on success */ -int wc_PKCS7_SetAESKeyUnwrapCb(wc_PKCS7* pkcs7, CallbackAESKeyWrap aesKeyUnwrapCb) -{ - if (pkcs7 == NULL) - return BAD_FUNC_ARG; - - pkcs7->aesKeyUnwrapCb = aesKeyUnwrapCb; - - return 0; -} - /* Decrypt ASN.1 OtherRecipientInfo (ori), as defined by: * * OtherRecipientInfo ::= SEQUENCE { diff --git a/wolfssl/wolfcrypt/pkcs7.h b/wolfssl/wolfcrypt/pkcs7.h index f7f22a691..54e428651 100644 --- a/wolfssl/wolfcrypt/pkcs7.h +++ b/wolfssl/wolfcrypt/pkcs7.h @@ -213,8 +213,8 @@ typedef int (*CallbackWrapCEK)(wc_PKCS7* pkcs7, byte* cek, word32 cekSz, byte* originKey, word32 originKeySz, byte* out, word32 outSz, int keyWrapAlgo, int type, int dir); -typedef int (*CallbackAESKeyWrap)(const byte* key, word32 keySz, - const byte* in, word32 inSz, byte* out, word32 outSz); +typedef int (*CallbackAESKeyWrapUnwrap)(const byte* key, word32 keySz, + const byte* in, word32 inSz, int wrap, byte* out, word32 outSz); /* Callbacks for supporting different stream cases */ typedef int (*CallbackGetContent)(wc_PKCS7* pkcs7, byte** content, void* ctx); @@ -373,8 +373,7 @@ struct wc_PKCS7 { } decryptKey; #endif - CallbackAESKeyWrap aesKeyWrapCb; - CallbackAESKeyWrap aesKeyUnwrapCb; + CallbackAESKeyWrapUnwrap aesKeyWrapUnwrapCb; /* !! NEW DATA MEMBERS MUST BE ADDED AT END !! */ }; @@ -503,10 +502,8 @@ WOLFSSL_API int wc_PKCS7_AddRecipient_ORI(wc_PKCS7* pkcs7, CallbackOriEncrypt c int options); WOLFSSL_API int wc_PKCS7_SetWrapCEKCb(wc_PKCS7* pkcs7, CallbackWrapCEK wrapCEKCb); -WOLFSSL_API int wc_PKCS7_SetAESKeyWrapCb(wc_PKCS7* pkcs7, - CallbackAESKeyWrap aesKeyWrapCb); -WOLFSSL_API int wc_PKCS7_SetAESKeyUnwrapCb(wc_PKCS7* pkcs7, - CallbackAESKeyWrap aesKeyUnwrapCb); +WOLFSSL_API int wc_PKCS7_SetAESKeyWrapUnwrapCb(wc_PKCS7* pkcs7, + CallbackAESKeyWrapUnwrap aesKeyWrapUnwrapCb); #if defined(HAVE_PKCS7_RSA_RAW_SIGN_CALLBACK) && !defined(NO_RSA) WOLFSSL_API int wc_PKCS7_SetRsaSignRawDigestCb(wc_PKCS7* pkcs7,