Merge pull request #9002 from holtrop/aes-key-wrap-callbacks

Add callback functions for custom AES key wrap/unwrap operations
This commit is contained in:
David Garske
2025-07-23 16:01:49 -07:00
committed by GitHub
4 changed files with 212 additions and 22 deletions

View File

@@ -1,3 +1,23 @@
/*!
\ingroup PKCS7
\brief Callback used for a custom AES key wrap/unwrap operation.
\return The size of the wrapped/unwrapped key written to the output buffer
should be returned on success. A 0 return value or error code (< 0)
indicates a failure.
\param[in] key Specify the key to use.
\param[in] keySz Size of the key to use.
\param[in] in Specify the input data to wrap/unwrap.
\param[in] inSz Size of the input data.
\param[in] wrap 1 if the requested operation is a key wrap, 0 for unwrap.
\param[out] out Specify the output buffer.
\param[out] outSz Size of the output buffer.
*/
typedef int (*CallbackAESKeyWrapUnwrap)(const byte* key, word32 keySz,
const byte* in, word32 inSz, int wrap, byte* out, word32 outSz);
/*!
\ingroup PKCS7
@@ -477,6 +497,21 @@ int wc_PKCS7_VerifySignedData_ex(PKCS7* pkcs7, const byte* hashBuf,
word32 hashSz, byte* pkiMsgHead, word32 pkiMsgHeadSz, byte* pkiMsgFoot,
word32 pkiMsgFootSz);
/*!
\ingroup PKCS7
\brief Set the callback function to be used to perform a custom AES key
wrap/unwrap operation.
\retval 0 Callback function was set successfully
\retval BAD_FUNC_ARG Parameter pkcs7 is NULL
\param pkcs7 pointer to the PKCS7 structure
\param aesKeyWrapCb pointer to custom AES key wrap/unwrap function
*/
int wc_PKCS7_SetAESKeyWrapUnwrapCb(wc_PKCS7* pkcs7,
CallbackAESKeyWrapUnwrap aesKeyWrapCb);
/*!
\ingroup PKCS7

View File

@@ -18179,6 +18179,131 @@ static int test_wc_PKCS7_EncodeDecodeEnvelopedData(void)
} /* END test_wc_PKCS7_EncodeDecodeEnvelopedData() */
#if defined(HAVE_PKCS7) && defined(HAVE_ECC) && !defined(NO_SHA256) && defined(WOLFSSL_AES_256)
static int wasAESKeyWrapCbCalled = 0;
static int wasAESKeyUnwrapCbCalled = 0;
static int testAESKeyWrapUnwrapCb(const byte* key, word32 keySz,
const byte* in, word32 inSz, int wrap, byte* out, word32 outSz)
{
(void)key;
(void)keySz;
(void)wrap;
if (wrap)
wasAESKeyWrapCbCalled = 1;
else
wasAESKeyUnwrapCbCalled = 1;
XMEMSET(out, 0xEE, outSz);
if (inSz <= outSz) {
XMEMCPY(out, in, inSz);
}
return inSz;
}
#endif
/*
* Test custom AES key wrap/unwrap callback
*/
static int test_wc_PKCS7_SetAESKeyWrapUnwrapCb(void)
{
EXPECT_DECLS;
#if defined(HAVE_PKCS7) && defined(HAVE_ECC) && !defined(NO_SHA256) && defined(WOLFSSL_AES_256)
static const char input[] = "Test input for AES key wrapping";
PKCS7 * pkcs7 = NULL;
byte * eccCert = NULL;
byte * eccPrivKey = NULL;
word32 eccCertSz = 0;
word32 eccPrivKeySz = 0;
byte output[ONEK_BUF];
byte decoded[sizeof(input)/sizeof(char)];
int decodedSz = 0;
#ifdef ECC_TIMING_RESISTANT
WC_RNG rng;
#endif
/* Load test certs */
#ifdef USE_CERT_BUFFERS_256
ExpectNotNull(eccCert = (byte*)XMALLOC(TWOK_BUF, HEAP_HINT,
DYNAMIC_TYPE_TMP_BUFFER));
/* Init buffer. */
eccCertSz = (word32)sizeof_cliecc_cert_der_256;
if (eccCert != NULL) {
XMEMCPY(eccCert, cliecc_cert_der_256, eccCertSz);
}
ExpectNotNull(eccPrivKey = (byte*)XMALLOC(TWOK_BUF, HEAP_HINT,
DYNAMIC_TYPE_TMP_BUFFER));
eccPrivKeySz = (word32)sizeof_ecc_clikey_der_256;
if (eccPrivKey != NULL) {
XMEMCPY(eccPrivKey, ecc_clikey_der_256, eccPrivKeySz);
}
#else /* File system. */
ExpectTrue((certFile = XFOPEN(eccClientCert, "rb")) != XBADFILE);
eccCertSz = (word32)FOURK_BUF;
ExpectNotNull(eccCert = (byte*)XMALLOC(FOURK_BUF, HEAP_HINT,
DYNAMIC_TYPE_TMP_BUFFER));
ExpectTrue((eccCertSz = (word32)XFREAD(eccCert, 1, eccCertSz,
certFile)) > 0);
if (certFile != XBADFILE) {
XFCLOSE(certFile);
}
ExpectTrue((keyFile = XFOPEN(eccClientKey, "rb")) != XBADFILE);
eccPrivKeySz = (word32)FOURK_BUF;
ExpectNotNull(eccPrivKey = (byte*)XMALLOC(FOURK_BUF, HEAP_HINT,
DYNAMIC_TYPE_TMP_BUFFER));
ExpectTrue((eccPrivKeySz = (word32)XFREAD(eccPrivKey, 1, eccPrivKeySz,
keyFile)) > 0);
if (keyFile != XBADFILE) {
XFCLOSE(keyFile);
}
#endif /* USE_CERT_BUFFERS_256 */
ExpectNotNull(pkcs7 = wc_PKCS7_New(HEAP_HINT, testDevId));
ExpectIntEQ(wc_PKCS7_InitWithCert(pkcs7, eccCert, eccCertSz), 0);
if (pkcs7 != NULL) {
pkcs7->content = (byte*)input;
pkcs7->contentSz = sizeof(input);
pkcs7->contentOID = DATA;
pkcs7->encryptOID = AES256CBCb;
pkcs7->keyWrapOID = AES256_WRAP;
pkcs7->keyAgreeOID = dhSinglePass_stdDH_sha256kdf_scheme;
pkcs7->privateKey = eccPrivKey;
pkcs7->privateKeySz = eccPrivKeySz;
pkcs7->singleCert = eccCert;
pkcs7->singleCertSz = (word32)eccCertSz;
#ifdef ECC_TIMING_RESISTANT
XMEMSET(&rng, 0, sizeof(WC_RNG));
ExpectIntEQ(wc_InitRng(&rng), 0);
pkcs7->rng = &rng;
#endif
}
/* Test custom AES key wrap/unwrap callback */
ExpectIntEQ(wc_PKCS7_SetAESKeyWrapUnwrapCb(pkcs7, testAESKeyWrapUnwrapCb), 0);
ExpectIntGE(wc_PKCS7_EncodeEnvelopedData(pkcs7, output,
(word32)sizeof(output)), 0);
decodedSz = wc_PKCS7_DecodeEnvelopedData(pkcs7, output,
(word32)sizeof(output), decoded, (word32)sizeof(decoded));
ExpectIntGE(decodedSz, 0);
/* Verify the size of each buffer. */
ExpectIntEQ((word32)sizeof(input)/sizeof(char), decodedSz);
ExpectIntEQ(wasAESKeyWrapCbCalled, 1);
ExpectIntEQ(wasAESKeyUnwrapCbCalled, 1);
wc_PKCS7_Free(pkcs7);
pkcs7 = NULL;
XFREE(eccCert, HEAP_HINT, DYNAMIC_TYPE_TMP_BUFFER);
XFREE(eccPrivKey, HEAP_HINT, DYNAMIC_TYPE_TMP_BUFFER);
#ifdef ECC_TIMING_RESISTANT
DoExpectIntEQ(wc_FreeRng(&rng), 0);
#endif
#endif
return EXPECT_RESULT();
}
/*
* Testing wc_PKCS7_EncodeEncryptedData()
*/
@@ -68016,6 +68141,7 @@ TEST_CASE testCases[] = {
TEST_DECL(test_wc_PKCS7_VerifySignedData_ECC),
TEST_DECL(test_wc_PKCS7_DecodeEnvelopedData_stream),
TEST_DECL(test_wc_PKCS7_EncodeDecodeEnvelopedData),
TEST_DECL(test_wc_PKCS7_SetAESKeyWrapUnwrapCb),
TEST_DECL(test_wc_PKCS7_EncodeEncryptedData),
TEST_DECL(test_wc_PKCS7_DecodeEncryptedKeyPackage),
TEST_DECL(test_wc_PKCS7_Degenerate),

View File

@@ -6814,14 +6814,15 @@ static int PKCS7_GenerateContentEncryptionKey(wc_PKCS7* pkcs7, word32 len)
}
/* wrap CEK (content encryption key) with KEK, 0 on success, < 0 on error */
static int wc_PKCS7_KeyWrap(byte* cek, word32 cekSz, byte* kek,
word32 kekSz, byte* out, word32 outSz,
int keyWrapAlgo, int direction)
/* wrap CEK (content encryption key) with KEK, returns output size (> 0) on
* success, < 0 on error */
static int wc_PKCS7_KeyWrap(const wc_PKCS7 * pkcs7, const byte * cek,
word32 cekSz, const byte * kek, word32 kekSz, byte * out, word32 outSz,
int keyWrapAlgo, int direction)
{
int ret = 0;
if (cek == NULL || kek == NULL || out == NULL)
if (pkcs7 == NULL || cek == NULL || kek == NULL || out == NULL)
return BAD_FUNC_ARG;
switch (keyWrapAlgo) {
@@ -6837,14 +6838,24 @@ static int wc_PKCS7_KeyWrap(byte* cek, word32 cekSz, byte* kek,
#endif
if (direction == AES_ENCRYPTION) {
ret = wc_AesKeyWrap(kek, kekSz, cek, cekSz,
out, outSz, NULL);
if (pkcs7->aesKeyWrapUnwrapCb != NULL) {
ret = pkcs7->aesKeyWrapUnwrapCb(kek, kekSz, cek, cekSz, 1,
out, outSz);
}
else {
ret = wc_AesKeyWrap(kek, kekSz, cek, cekSz,
out, outSz, NULL);
}
} else if (direction == AES_DECRYPTION) {
ret = wc_AesKeyUnWrap(kek, kekSz, cek, cekSz,
out, outSz, NULL);
if (pkcs7->aesKeyWrapUnwrapCb != NULL) {
ret = pkcs7->aesKeyWrapUnwrapCb(kek, kekSz, cek, cekSz, 0,
out, outSz);
}
else {
ret = wc_AesKeyUnWrap(kek, kekSz, cek, cekSz,
out, outSz, NULL);
}
} else {
WOLFSSL_MSG("Bad key un/wrap direction");
return BAD_FUNC_ARG;
@@ -7548,7 +7559,7 @@ int wc_PKCS7_AddRecipient_KARI(wc_PKCS7* pkcs7, const byte* cert, word32 certSz,
}
/* encrypt CEK with KEK */
keySz = wc_PKCS7_KeyWrap(pkcs7->cek, pkcs7->cekSz, kari->kek,
keySz = wc_PKCS7_KeyWrap(pkcs7, pkcs7->cek, pkcs7->cekSz, kari->kek,
kari->kekSz, encryptedKey, encryptedKeySz,
keyWrapOID, direction);
if (keySz <= 0) {
@@ -9630,9 +9641,8 @@ int wc_PKCS7_AddRecipient_KEKRI(wc_PKCS7* pkcs7, int keyWrapOID, byte* kek,
direction = DES_ENCRYPTION;
#endif
encryptedKeySz = wc_PKCS7_KeyWrap(pkcs7->cek, pkcs7->cekSz, kek, kekSz,
encryptedKey, (word32)encryptedKeySz, keyWrapOID,
direction);
encryptedKeySz = wc_PKCS7_KeyWrap(pkcs7, pkcs7->cek, pkcs7->cekSz, kek,
kekSz, encryptedKey, (word32)encryptedKeySz, keyWrapOID, direction);
if (encryptedKeySz < 0) {
#ifdef WOLFSSL_SMALL_STACK
XFREE(encryptedKey, pkcs7->heap, DYNAMIC_TYPE_PKCS7);
@@ -11082,6 +11092,19 @@ int wc_PKCS7_SetWrapCEKCb(wc_PKCS7* pkcs7, CallbackWrapCEK cb)
return 0;
}
/* return 0 on success */
int wc_PKCS7_SetAESKeyWrapUnwrapCb(wc_PKCS7* pkcs7, CallbackAESKeyWrapUnwrap aesKeyWrapUnwrapCb)
{
if (pkcs7 == NULL)
return BAD_FUNC_ARG;
pkcs7->aesKeyWrapUnwrapCb = aesKeyWrapUnwrapCb;
return 0;
}
/* Decrypt ASN.1 OtherRecipientInfo (ori), as defined by:
*
* OtherRecipientInfo ::= SEQUENCE {
@@ -11529,10 +11552,9 @@ static int wc_PKCS7_DecryptKekri(wc_PKCS7* pkcs7, byte* in, word32 inSz,
(int)PKCS7_KEKRI, direction);
}
else {
keySz = wc_PKCS7_KeyWrap(pkiMsg + *idx, (word32)length,
pkcs7->privateKey, pkcs7->privateKeySz,
decryptedKey, *decryptedKeySz,
(int)keyWrapOID, direction);
keySz = wc_PKCS7_KeyWrap(pkcs7, pkiMsg + *idx, (word32)length,
pkcs7->privateKey, pkcs7->privateKeySz, decryptedKey,
*decryptedKeySz, (int)keyWrapOID, direction);
}
if (keySz <= 0)
return keySz;
@@ -11795,9 +11817,10 @@ static int wc_PKCS7_DecryptKari(wc_PKCS7* pkcs7, byte* in, word32 inSz,
}
/* decrypt CEK with KEK */
keySz = wc_PKCS7_KeyWrap(encryptedKey, (word32)encryptedKeySz,
kari->kek, kari->kekSz, decryptedKey, *decryptedKeySz,
(int)keyWrapOID, direction);
keySz = wc_PKCS7_KeyWrap(pkcs7, encryptedKey,
(word32)encryptedKeySz, kari->kek, kari->kekSz,
decryptedKey, *decryptedKeySz, (int)keyWrapOID,
direction);
}
if (keySz <= 0) {
wc_PKCS7_KariFree(kari);

View File

@@ -213,6 +213,8 @@ typedef int (*CallbackWrapCEK)(wc_PKCS7* pkcs7, byte* cek, word32 cekSz,
byte* originKey, word32 originKeySz,
byte* out, word32 outSz,
int keyWrapAlgo, int type, int dir);
typedef int (*CallbackAESKeyWrapUnwrap)(const byte* key, word32 keySz,
const byte* in, word32 inSz, int wrap, byte* out, word32 outSz);
/* Callbacks for supporting different stream cases */
typedef int (*CallbackGetContent)(wc_PKCS7* pkcs7, byte** content, void* ctx);
@@ -371,6 +373,8 @@ struct wc_PKCS7 {
} decryptKey;
#endif
CallbackAESKeyWrapUnwrap aesKeyWrapUnwrapCb;
/* !! NEW DATA MEMBERS MUST BE ADDED AT END !! */
};
@@ -498,6 +502,8 @@ WOLFSSL_API int wc_PKCS7_AddRecipient_ORI(wc_PKCS7* pkcs7, CallbackOriEncrypt c
int options);
WOLFSSL_API int wc_PKCS7_SetWrapCEKCb(wc_PKCS7* pkcs7,
CallbackWrapCEK wrapCEKCb);
WOLFSSL_API int wc_PKCS7_SetAESKeyWrapUnwrapCb(wc_PKCS7* pkcs7,
CallbackAESKeyWrapUnwrap aesKeyWrapUnwrapCb);
#if defined(HAVE_PKCS7_RSA_RAW_SIGN_CALLBACK) && !defined(NO_RSA)
WOLFSSL_API int wc_PKCS7_SetRsaSignRawDigestCb(wc_PKCS7* pkcs7,