From 429ccd54565bd5b3269f06685b41736112f1cf93 Mon Sep 17 00:00:00 2001 From: Josh Holtrop Date: Fri, 11 Jul 2025 16:08:53 -0400 Subject: [PATCH] Add callback functions for custom AES key wrap/unwrap operations --- doc/dox_comments/header_files/pkcs7.h | 48 +++++++++ tests/api.c | 136 ++++++++++++++++++++++++++ wolfcrypt/src/pkcs7.c | 72 ++++++++++---- wolfssl/wolfcrypt/pkcs7.h | 9 ++ 4 files changed, 246 insertions(+), 19 deletions(-) diff --git a/doc/dox_comments/header_files/pkcs7.h b/doc/dox_comments/header_files/pkcs7.h index be5a75c43..577ae7c11 100644 --- a/doc/dox_comments/header_files/pkcs7.h +++ b/doc/dox_comments/header_files/pkcs7.h @@ -1,3 +1,19 @@ +/*! + \ingroup PKCS7 + + \brief Callback used for a custom AES key wrap/unwrap operation. + + key/keySz specify the key to use. + in/inSz specify the input data to wrap/unwrap. + out/outSz specify the output buffer. + + The size of the wrapped/unwrapped key written to the output buffer should + be returned on success. A 0 return value or error code (< 0) indicates a + failure. +*/ +typedef int (*CallbackAESKeyWrap)(const byte* key, word32 keySz, + const byte* in, word32 inSz, byte* out, word32 outSz); + /*! \ingroup PKCS7 @@ -477,6 +493,38 @@ int wc_PKCS7_VerifySignedData_ex(PKCS7* pkcs7, const byte* hashBuf, word32 hashSz, byte* pkiMsgHead, word32 pkiMsgHeadSz, byte* pkiMsgFoot, word32 pkiMsgFootSz); +/*! + \ingroup PKCS7 + + \brief Set the callback function to be used to perform a custom AES key + wrap operation. + + \retval 0 Callback function was set successfully + \retval BAD_FUNC_ARG Parameter pkcs7 is NULL + + \param pkcs7 pointer to the PKCS7 structure + \param aesKeyWrapCb pointer to custom AES key wrap function + + \sa wc_PKCS7_SetAESKeyUnwrapCb +*/ +int wc_PKCS7_SetAESKeyWrapCb(wc_PKCS7* pkcs7, CallbackAESKeyWrap aesKeyWrapCb); + +/*! + \ingroup PKCS7 + + \brief Set the callback function to be used to perform a custom AES key + unwrap operation. + + \retval 0 Callback function was set successfully + \retval BAD_FUNC_ARG Parameter pkcs7 is NULL + + \param pkcs7 pointer to the PKCS7 structure + \param aesKeyUnwrapCb pointer to custom AES key unwrap function + + \sa wc_PKCS7_SetAESKeyWrapCb +*/ +int wc_PKCS7_SetAESKeyUnwrapCb(wc_PKCS7* pkcs7, CallbackAESKeyWrap aesKeyUnwrapCb); + /*! \ingroup PKCS7 diff --git a/tests/api.c b/tests/api.c index 49b6f4634..2ce79d779 100644 --- a/tests/api.c +++ b/tests/api.c @@ -17995,6 +17995,141 @@ static int test_wc_PKCS7_EncodeDecodeEnvelopedData(void) } /* END test_wc_PKCS7_EncodeDecodeEnvelopedData() */ +#if defined(HAVE_PKCS7) && defined(HAVE_ECC) && !defined(NO_SHA256) && defined(WOLFSSL_AES_256) +static int wasAESKeyWrapCbCalled = 0; +static int wasAESKeyUnwrapCbCalled = 0; + +static int testAESKeyWrapCb(const byte* key, word32 keySz, + const byte* in, word32 inSz, byte* out, word32 outSz) +{ + (void)key; + (void)keySz; + wasAESKeyWrapCbCalled = 1; + XMEMSET(out, 0xEE, outSz); + if (inSz <= outSz) { + XMEMCPY(out, in, inSz); + } + return inSz; +} + +static int testAESKeyUnwrapCb(const byte* key, word32 keySz, + const byte* in, word32 inSz, byte* out, word32 outSz) +{ + (void)key; + (void)keySz; + wasAESKeyUnwrapCbCalled = 1; + XMEMSET(out, 0xEE, outSz); + if (inSz <= outSz) { + XMEMCPY(out, in, inSz); + } + return inSz; +} +#endif + + +/* + * Test custom AES key wrap/unwrap callback + */ +static int test_wc_PKCS7_SetAESKeyWrapUnwrapCb(void) +{ + EXPECT_DECLS; +#if defined(HAVE_PKCS7) && defined(HAVE_ECC) && !defined(NO_SHA256) && defined(WOLFSSL_AES_256) + static const char input[] = "Test input for AES key wrapping"; + PKCS7 * pkcs7 = NULL; + byte * eccCert = NULL; + byte * eccPrivKey = NULL; + word32 eccCertSz = 0; + word32 eccPrivKeySz = 0; + byte output[ONEK_BUF]; + byte decoded[sizeof(input)/sizeof(char)]; + int decodedSz = 0; +#ifdef ECC_TIMING_RESISTANT + WC_RNG rng; +#endif + + /* Load test certs */ + #ifdef USE_CERT_BUFFERS_256 + ExpectNotNull(eccCert = (byte*)XMALLOC(TWOK_BUF, HEAP_HINT, + DYNAMIC_TYPE_TMP_BUFFER)); + /* Init buffer. */ + eccCertSz = (word32)sizeof_cliecc_cert_der_256; + if (eccCert != NULL) { + XMEMCPY(eccCert, cliecc_cert_der_256, eccCertSz); + } + ExpectNotNull(eccPrivKey = (byte*)XMALLOC(TWOK_BUF, HEAP_HINT, + DYNAMIC_TYPE_TMP_BUFFER)); + eccPrivKeySz = (word32)sizeof_ecc_clikey_der_256; + if (eccPrivKey != NULL) { + XMEMCPY(eccPrivKey, ecc_clikey_der_256, eccPrivKeySz); + } + #else /* File system. */ + ExpectTrue((certFile = XFOPEN(eccClientCert, "rb")) != XBADFILE); + eccCertSz = (word32)FOURK_BUF; + ExpectNotNull(eccCert = (byte*)XMALLOC(FOURK_BUF, HEAP_HINT, + DYNAMIC_TYPE_TMP_BUFFER)); + ExpectTrue((eccCertSz = (word32)XFREAD(eccCert, 1, eccCertSz, + certFile)) > 0); + if (certFile != XBADFILE) { + XFCLOSE(certFile); + } + ExpectTrue((keyFile = XFOPEN(eccClientKey, "rb")) != XBADFILE); + eccPrivKeySz = (word32)FOURK_BUF; + ExpectNotNull(eccPrivKey = (byte*)XMALLOC(FOURK_BUF, HEAP_HINT, + DYNAMIC_TYPE_TMP_BUFFER)); + ExpectTrue((eccPrivKeySz = (word32)XFREAD(eccPrivKey, 1, eccPrivKeySz, + keyFile)) > 0); + if (keyFile != XBADFILE) { + XFCLOSE(keyFile); + } + #endif /* USE_CERT_BUFFERS_256 */ + + ExpectNotNull(pkcs7 = wc_PKCS7_New(HEAP_HINT, testDevId)); + ExpectIntEQ(wc_PKCS7_InitWithCert(pkcs7, eccCert, eccCertSz), 0); + if (pkcs7 != NULL) { + pkcs7->content = (byte*)input; + pkcs7->contentSz = sizeof(input); + pkcs7->contentOID = DATA; + pkcs7->encryptOID = AES256CBCb; + pkcs7->keyWrapOID = AES256_WRAP; + pkcs7->keyAgreeOID = dhSinglePass_stdDH_sha256kdf_scheme; + pkcs7->privateKey = eccPrivKey; + pkcs7->privateKeySz = eccPrivKeySz; + pkcs7->singleCert = eccCert; + pkcs7->singleCertSz = (word32)eccCertSz; +#ifdef ECC_TIMING_RESISTANT + XMEMSET(&rng, 0, sizeof(WC_RNG)); + ExpectIntEQ(wc_InitRng(&rng), 0); + pkcs7->rng = &rng; +#endif + } + + /* Test custom AES key wrap/unwrap callback */ + ExpectIntEQ(wc_PKCS7_SetAESKeyWrapCb(pkcs7, testAESKeyWrapCb), 0); + ExpectIntEQ(wc_PKCS7_SetAESKeyUnwrapCb(pkcs7, testAESKeyUnwrapCb), 0); + + ExpectIntGE(wc_PKCS7_EncodeEnvelopedData(pkcs7, output, + (word32)sizeof(output)), 0); + + decodedSz = wc_PKCS7_DecodeEnvelopedData(pkcs7, output, + (word32)sizeof(output), decoded, (word32)sizeof(decoded)); + ExpectIntGE(decodedSz, 0); + /* Verify the size of each buffer. */ + ExpectIntEQ((word32)sizeof(input)/sizeof(char), decodedSz); + + ExpectIntEQ(wasAESKeyWrapCbCalled, 1); + ExpectIntEQ(wasAESKeyUnwrapCbCalled, 1); + + wc_PKCS7_Free(pkcs7); + pkcs7 = NULL; + XFREE(eccCert, HEAP_HINT, DYNAMIC_TYPE_TMP_BUFFER); + XFREE(eccPrivKey, HEAP_HINT, DYNAMIC_TYPE_TMP_BUFFER); +#ifdef ECC_TIMING_RESISTANT + DoExpectIntEQ(wc_FreeRng(&rng), 0); +#endif +#endif + return EXPECT_RESULT(); +} + /* * Testing wc_PKCS7_EncodeEncryptedData() */ @@ -67781,6 +67916,7 @@ TEST_CASE testCases[] = { TEST_DECL(test_wc_PKCS7_VerifySignedData_ECC), TEST_DECL(test_wc_PKCS7_DecodeEnvelopedData_stream), TEST_DECL(test_wc_PKCS7_EncodeDecodeEnvelopedData), + TEST_DECL(test_wc_PKCS7_SetAESKeyWrapUnwrapCb), TEST_DECL(test_wc_PKCS7_EncodeEncryptedData), TEST_DECL(test_wc_PKCS7_DecodeEncryptedKeyPackage), TEST_DECL(test_wc_PKCS7_Degenerate), diff --git a/wolfcrypt/src/pkcs7.c b/wolfcrypt/src/pkcs7.c index b2d2bd9f6..6b134e25f 100644 --- a/wolfcrypt/src/pkcs7.c +++ b/wolfcrypt/src/pkcs7.c @@ -6814,8 +6814,9 @@ static int PKCS7_GenerateContentEncryptionKey(wc_PKCS7* pkcs7, word32 len) } -/* wrap CEK (content encryption key) with KEK, 0 on success, < 0 on error */ -static int wc_PKCS7_KeyWrap(byte* cek, word32 cekSz, byte* kek, +/* wrap CEK (content encryption key) with KEK, returns output size (> 0) on + * success, < 0 on error */ +static int wc_PKCS7_KeyWrap(wc_PKCS7 * pkcs7, byte* cek, word32 cekSz, byte* kek, word32 kekSz, byte* out, word32 outSz, int keyWrapAlgo, int direction) { @@ -6837,14 +6838,24 @@ static int wc_PKCS7_KeyWrap(byte* cek, word32 cekSz, byte* kek, #endif if (direction == AES_ENCRYPTION) { - - ret = wc_AesKeyWrap(kek, kekSz, cek, cekSz, - out, outSz, NULL); + if (pkcs7->aesKeyWrapCb != NULL) { + ret = pkcs7->aesKeyWrapCb(kek, kekSz, cek, cekSz, + out, outSz); + } + else { + ret = wc_AesKeyWrap(kek, kekSz, cek, cekSz, + out, outSz, NULL); + } } else if (direction == AES_DECRYPTION) { - - ret = wc_AesKeyUnWrap(kek, kekSz, cek, cekSz, - out, outSz, NULL); + if (pkcs7->aesKeyUnwrapCb != NULL) { + ret = pkcs7->aesKeyUnwrapCb(kek, kekSz, cek, cekSz, + out, outSz); + } + else { + ret = wc_AesKeyUnWrap(kek, kekSz, cek, cekSz, + out, outSz, NULL); + } } else { WOLFSSL_MSG("Bad key un/wrap direction"); return BAD_FUNC_ARG; @@ -7548,7 +7559,7 @@ int wc_PKCS7_AddRecipient_KARI(wc_PKCS7* pkcs7, const byte* cert, word32 certSz, } /* encrypt CEK with KEK */ - keySz = wc_PKCS7_KeyWrap(pkcs7->cek, pkcs7->cekSz, kari->kek, + keySz = wc_PKCS7_KeyWrap(pkcs7, pkcs7->cek, pkcs7->cekSz, kari->kek, kari->kekSz, encryptedKey, encryptedKeySz, keyWrapOID, direction); if (keySz <= 0) { @@ -9630,9 +9641,8 @@ int wc_PKCS7_AddRecipient_KEKRI(wc_PKCS7* pkcs7, int keyWrapOID, byte* kek, direction = DES_ENCRYPTION; #endif - encryptedKeySz = wc_PKCS7_KeyWrap(pkcs7->cek, pkcs7->cekSz, kek, kekSz, - encryptedKey, (word32)encryptedKeySz, keyWrapOID, - direction); + encryptedKeySz = wc_PKCS7_KeyWrap(pkcs7, pkcs7->cek, pkcs7->cekSz, kek, + kekSz, encryptedKey, (word32)encryptedKeySz, keyWrapOID, direction); if (encryptedKeySz < 0) { #ifdef WOLFSSL_SMALL_STACK XFREE(encryptedKey, pkcs7->heap, DYNAMIC_TYPE_PKCS7); @@ -11082,6 +11092,30 @@ int wc_PKCS7_SetWrapCEKCb(wc_PKCS7* pkcs7, CallbackWrapCEK cb) return 0; } + +/* return 0 on success */ +int wc_PKCS7_SetAESKeyWrapCb(wc_PKCS7* pkcs7, CallbackAESKeyWrap aesKeyWrapCb) +{ + if (pkcs7 == NULL) + return BAD_FUNC_ARG; + + pkcs7->aesKeyWrapCb = aesKeyWrapCb; + + return 0; +} + + +/* return 0 on success */ +int wc_PKCS7_SetAESKeyUnwrapCb(wc_PKCS7* pkcs7, CallbackAESKeyWrap aesKeyUnwrapCb) +{ + if (pkcs7 == NULL) + return BAD_FUNC_ARG; + + pkcs7->aesKeyUnwrapCb = aesKeyUnwrapCb; + + return 0; +} + /* Decrypt ASN.1 OtherRecipientInfo (ori), as defined by: * * OtherRecipientInfo ::= SEQUENCE { @@ -11529,10 +11563,9 @@ static int wc_PKCS7_DecryptKekri(wc_PKCS7* pkcs7, byte* in, word32 inSz, (int)PKCS7_KEKRI, direction); } else { - keySz = wc_PKCS7_KeyWrap(pkiMsg + *idx, (word32)length, - pkcs7->privateKey, pkcs7->privateKeySz, - decryptedKey, *decryptedKeySz, - (int)keyWrapOID, direction); + keySz = wc_PKCS7_KeyWrap(pkcs7, pkiMsg + *idx, (word32)length, + pkcs7->privateKey, pkcs7->privateKeySz, decryptedKey, + *decryptedKeySz, (int)keyWrapOID, direction); } if (keySz <= 0) return keySz; @@ -11795,9 +11828,10 @@ static int wc_PKCS7_DecryptKari(wc_PKCS7* pkcs7, byte* in, word32 inSz, } /* decrypt CEK with KEK */ - keySz = wc_PKCS7_KeyWrap(encryptedKey, (word32)encryptedKeySz, - kari->kek, kari->kekSz, decryptedKey, *decryptedKeySz, - (int)keyWrapOID, direction); + keySz = wc_PKCS7_KeyWrap(pkcs7, encryptedKey, + (word32)encryptedKeySz, kari->kek, kari->kekSz, + decryptedKey, *decryptedKeySz, (int)keyWrapOID, + direction); } if (keySz <= 0) { wc_PKCS7_KariFree(kari); diff --git a/wolfssl/wolfcrypt/pkcs7.h b/wolfssl/wolfcrypt/pkcs7.h index 9248eddac..f7f22a691 100644 --- a/wolfssl/wolfcrypt/pkcs7.h +++ b/wolfssl/wolfcrypt/pkcs7.h @@ -213,6 +213,8 @@ typedef int (*CallbackWrapCEK)(wc_PKCS7* pkcs7, byte* cek, word32 cekSz, byte* originKey, word32 originKeySz, byte* out, word32 outSz, int keyWrapAlgo, int type, int dir); +typedef int (*CallbackAESKeyWrap)(const byte* key, word32 keySz, + const byte* in, word32 inSz, byte* out, word32 outSz); /* Callbacks for supporting different stream cases */ typedef int (*CallbackGetContent)(wc_PKCS7* pkcs7, byte** content, void* ctx); @@ -371,6 +373,9 @@ struct wc_PKCS7 { } decryptKey; #endif + CallbackAESKeyWrap aesKeyWrapCb; + CallbackAESKeyWrap aesKeyUnwrapCb; + /* !! NEW DATA MEMBERS MUST BE ADDED AT END !! */ }; @@ -498,6 +503,10 @@ WOLFSSL_API int wc_PKCS7_AddRecipient_ORI(wc_PKCS7* pkcs7, CallbackOriEncrypt c int options); WOLFSSL_API int wc_PKCS7_SetWrapCEKCb(wc_PKCS7* pkcs7, CallbackWrapCEK wrapCEKCb); +WOLFSSL_API int wc_PKCS7_SetAESKeyWrapCb(wc_PKCS7* pkcs7, + CallbackAESKeyWrap aesKeyWrapCb); +WOLFSSL_API int wc_PKCS7_SetAESKeyUnwrapCb(wc_PKCS7* pkcs7, + CallbackAESKeyWrap aesKeyUnwrapCb); #if defined(HAVE_PKCS7_RSA_RAW_SIGN_CALLBACK) && !defined(NO_RSA) WOLFSSL_API int wc_PKCS7_SetRsaSignRawDigestCb(wc_PKCS7* pkcs7,